grep_tool.rs

  1use crate::schema::json_schema_for;
  2use anyhow::{Result, anyhow};
  3use assistant_tool::{ActionLog, Tool, ToolResult};
  4use futures::StreamExt;
  5use gpui::{AnyWindowHandle, App, Entity, Task};
  6use language::{OffsetRangeExt, ParseStatus, Point};
  7use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
  8use project::{
  9    Project,
 10    search::{SearchQuery, SearchResult},
 11};
 12use schemars::JsonSchema;
 13use serde::{Deserialize, Serialize};
 14use std::{cmp, fmt::Write, sync::Arc};
 15use ui::IconName;
 16use util::RangeExt;
 17use util::markdown::MarkdownInlineCode;
 18use util::paths::PathMatcher;
 19
 20#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 21pub struct GrepToolInput {
 22    /// A regex pattern to search for in the entire project. Note that the regex
 23    /// will be parsed by the Rust `regex` crate.
 24    ///
 25    /// Do NOT specify a path here! This will only be matched against the code **content**.
 26    pub regex: String,
 27
 28    /// A glob pattern for the paths of files to include in the search.
 29    /// Supports standard glob patterns like "**/*.rs" or "src/**/*.ts".
 30    /// If omitted, all files in the project will be searched.
 31    pub include_pattern: Option<String>,
 32
 33    /// Optional starting position for paginated results (0-based).
 34    /// When not provided, starts from the beginning.
 35    #[serde(default)]
 36    pub offset: u32,
 37
 38    /// Whether the regex is case-sensitive. Defaults to false (case-insensitive).
 39    #[serde(default)]
 40    pub case_sensitive: bool,
 41}
 42
 43impl GrepToolInput {
 44    /// Which page of search results this is.
 45    pub fn page(&self) -> u32 {
 46        1 + (self.offset / RESULTS_PER_PAGE)
 47    }
 48}
 49
 50const RESULTS_PER_PAGE: u32 = 20;
 51
 52pub struct GrepTool;
 53
 54impl Tool for GrepTool {
 55    fn name(&self) -> String {
 56        "grep".into()
 57    }
 58
 59    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 60        false
 61    }
 62
 63    fn description(&self) -> String {
 64        include_str!("./grep_tool/description.md").into()
 65    }
 66
 67    fn icon(&self) -> IconName {
 68        IconName::Regex
 69    }
 70
 71    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 72        json_schema_for::<GrepToolInput>(format)
 73    }
 74
 75    fn ui_text(&self, input: &serde_json::Value) -> String {
 76        match serde_json::from_value::<GrepToolInput>(input.clone()) {
 77            Ok(input) => {
 78                let page = input.page();
 79                let regex_str = MarkdownInlineCode(&input.regex);
 80                let case_info = if input.case_sensitive {
 81                    " (case-sensitive)"
 82                } else {
 83                    ""
 84                };
 85
 86                if page > 1 {
 87                    format!("Get page {page} of search results for regex {regex_str}{case_info}")
 88                } else {
 89                    format!("Search files for regex {regex_str}{case_info}")
 90                }
 91            }
 92            Err(_) => "Search with regex".to_string(),
 93        }
 94    }
 95
 96    fn run(
 97        self: Arc<Self>,
 98        input: serde_json::Value,
 99        _messages: &[LanguageModelRequestMessage],
100        project: Entity<Project>,
101        _action_log: Entity<ActionLog>,
102        _window: Option<AnyWindowHandle>,
103        cx: &mut App,
104    ) -> ToolResult {
105        const CONTEXT_LINES: u32 = 2;
106        const MAX_ANCESTOR_LINES: u32 = 10;
107
108        let input = match serde_json::from_value::<GrepToolInput>(input) {
109            Ok(input) => input,
110            Err(error) => {
111                return Task::ready(Err(anyhow!("Failed to parse input: {}", error))).into();
112            }
113        };
114
115        let include_matcher = match PathMatcher::new(
116            input
117                .include_pattern
118                .as_ref()
119                .into_iter()
120                .collect::<Vec<_>>(),
121        ) {
122            Ok(matcher) => matcher,
123            Err(error) => {
124                return Task::ready(Err(anyhow!("invalid include glob pattern: {}", error))).into();
125            }
126        };
127
128        let query = match SearchQuery::regex(
129            &input.regex,
130            false,
131            input.case_sensitive,
132            false,
133            false,
134            include_matcher,
135            PathMatcher::default(), // For now, keep it simple and don't enable an exclude pattern.
136            true, // Always match file include pattern against *full project paths* that start with a project root.
137            None,
138        ) {
139            Ok(query) => query,
140            Err(error) => return Task::ready(Err(error)).into(),
141        };
142
143        let results = project.update(cx, |project, cx| project.search(query, cx));
144
145        cx.spawn(async move |cx|  {
146            futures::pin_mut!(results);
147
148            let mut output = String::new();
149            let mut skips_remaining = input.offset;
150            let mut matches_found = 0;
151            let mut has_more_matches = false;
152
153            'outer: while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
154                if ranges.is_empty() {
155                    continue;
156                }
157
158                let (Some(path), mut parse_status) = buffer.read_with(cx, |buffer, cx| {
159                    (buffer.file().map(|file| file.full_path(cx)), buffer.parse_status())
160                })? else {
161                    continue;
162                };
163
164
165                while *parse_status.borrow() != ParseStatus::Idle {
166                    parse_status.changed().await?;
167                }
168
169                let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
170
171                let mut ranges = ranges
172                    .into_iter()
173                    .map(|range| {
174                        let matched = range.to_point(&snapshot);
175                        let matched_end_line_len = snapshot.line_len(matched.end.row);
176                        let full_lines = Point::new(matched.start.row, 0)..Point::new(matched.end.row, matched_end_line_len);
177                        let symbols = snapshot.symbols_containing(matched.start, None);
178
179                        if let Some(ancestor_node) = snapshot.syntax_ancestor(full_lines.clone()) {
180                            let full_ancestor_range = ancestor_node.byte_range().to_point(&snapshot);
181                            let end_row = full_ancestor_range.end.row.min(full_ancestor_range.start.row + MAX_ANCESTOR_LINES);
182                            let end_col = snapshot.line_len(end_row);
183                            let capped_ancestor_range = Point::new(full_ancestor_range.start.row, 0)..Point::new(end_row, end_col);
184
185                            if capped_ancestor_range.contains_inclusive(&full_lines) {
186                                return (capped_ancestor_range, Some(full_ancestor_range), symbols)
187                            }
188                        }
189
190                        let mut matched = matched;
191                        matched.start.column = 0;
192                        matched.start.row =
193                            matched.start.row.saturating_sub(CONTEXT_LINES);
194                        matched.end.row = cmp::min(
195                            snapshot.max_point().row,
196                            matched.end.row + CONTEXT_LINES,
197                        );
198                        matched.end.column = snapshot.line_len(matched.end.row);
199
200                        (matched, None, symbols)
201                    })
202                    .peekable();
203
204                let mut file_header_written = false;
205
206                while let Some((mut range, ancestor_range, parent_symbols)) = ranges.next(){
207                    if skips_remaining > 0 {
208                        skips_remaining -= 1;
209                        continue;
210                    }
211
212                    // We'd already found a full page of matches, and we just found one more.
213                    if matches_found >= RESULTS_PER_PAGE {
214                        has_more_matches = true;
215                        break 'outer;
216                    }
217
218                    while let Some((next_range, _, _)) = ranges.peek() {
219                        if range.end.row >= next_range.start.row {
220                            range.end = next_range.end;
221                            ranges.next();
222                        } else {
223                            break;
224                        }
225                    }
226
227                    if !file_header_written {
228                        writeln!(output, "\n## Matches in {}", path.display())?;
229                        file_header_written = true;
230                    }
231
232                    let end_row = range.end.row;
233                    output.push_str("\n### ");
234
235                    if let Some(parent_symbols) = &parent_symbols {
236                        for symbol in parent_symbols {
237                            write!(output, "{} › ", symbol.text)?;
238                        }
239                    }
240
241                    if range.start.row == end_row {
242                        writeln!(output, "L{}", range.start.row + 1)?;
243                    } else {
244                        writeln!(output, "L{}-{}", range.start.row + 1, end_row + 1)?;
245                    }
246
247                    output.push_str("```\n");
248                    output.extend(snapshot.text_for_range(range));
249                    output.push_str("\n```\n");
250
251                    if let Some(ancestor_range) = ancestor_range {
252                        if end_row < ancestor_range.end.row {
253                            let remaining_lines = ancestor_range.end.row - end_row;
254                            writeln!(output, "\n{} lines remaining in ancestor node. Read the file to see all.", remaining_lines)?;
255                        }
256                    }
257
258                    matches_found += 1;
259                }
260            }
261
262            if matches_found == 0 {
263                Ok("No matches found".to_string().into())
264            } else if has_more_matches {
265                Ok(format!(
266                    "Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}",
267                    input.offset + 1,
268                    input.offset + matches_found,
269                    input.offset + RESULTS_PER_PAGE,
270                ).into())
271            } else {
272                Ok(format!("Found {matches_found} matches:\n{output}").into())
273            }
274        }).into()
275    }
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use assistant_tool::Tool;
    use gpui::{AppContext, TestAppContext};
    use language::{Language, LanguageConfig, LanguageMatcher};
    use project::{FakeFs, Project};
    use settings::SettingsStore;
    use unindent::Unindent;
    use util::path;

    #[gpui::test]
    async fn test_grep_tool_with_include_pattern(cx: &mut TestAppContext) {
        init_test(cx);
        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor().clone());
        fs.insert_tree(
            path!("/root"),
            serde_json::json!({
                "src": {
                    "main.rs": "fn main() {\n    println!(\"Hello, world!\");\n}",
                    "utils": {
                        "helper.rs": "fn helper() {\n    println!(\"I'm a helper!\");\n}",
                    },
                },
                "tests": {
                    "test_main.rs": "fn test_main() {\n    assert!(true);\n}",
                }
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        // Test with include pattern for Rust files inside the root of the project
        let input = serde_json::to_value(GrepToolInput {
            regex: "println".to_string(),
            include_pattern: Some("root/**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(result.contains("main.rs"), "Should find matches in main.rs");
        assert!(
            result.contains("helper.rs"),
            "Should find matches in helper.rs"
        );
        assert!(
            !result.contains("test_main.rs"),
            "Should not include test_main.rs even though it's a .rs file (because its content doesn't match the regex)"
        );

        // Test with include pattern for src directory only
        let input = serde_json::to_value(GrepToolInput {
            regex: "fn".to_string(),
            include_pattern: Some("root/**/src/**".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            result.contains("main.rs"),
            "Should find matches in src/main.rs"
        );
        assert!(
            result.contains("helper.rs"),
            "Should find matches in src/utils/helper.rs"
        );
        assert!(
            !result.contains("test_main.rs"),
            "Should not include test_main.rs as it's not in src directory"
        );

        // Test without an include pattern (should default to all files)
        let input = serde_json::to_value(GrepToolInput {
            regex: "fn".to_string(),
            include_pattern: None,
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(result.contains("main.rs"), "Should find matches in main.rs");
        assert!(
            result.contains("helper.rs"),
            "Should find matches in helper.rs"
        );
        assert!(
            result.contains("test_main.rs"),
            "Should include test_main.rs"
        );
    }

    #[gpui::test]
    async fn test_grep_tool_with_case_sensitivity(cx: &mut TestAppContext) {
        init_test(cx);
        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor().clone());
        fs.insert_tree(
            path!("/root"),
            serde_json::json!({
                "case_test.txt": "This file has UPPERCASE and lowercase text.\nUPPERCASE patterns should match only with case_sensitive: true",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        // Test case-insensitive search (default)
        let input = serde_json::to_value(GrepToolInput {
            regex: "uppercase".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            result.contains("UPPERCASE"),
            "Case-insensitive search should match uppercase"
        );

        // Test case-sensitive search: a lowercase pattern must not match uppercase text
        let input = serde_json::to_value(GrepToolInput {
            regex: "uppercase".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: true,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            !result.contains("UPPERCASE"),
            "Case-sensitive search should not match uppercase"
        );

        // Test case-sensitive search: an uppercase pattern must not match lowercase text
        let input = serde_json::to_value(GrepToolInput {
            regex: "LOWERCASE".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: true,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;

        assert!(
            !result.contains("lowercase"),
            "Case-sensitive search should not match lowercase"
        );

        // Test case-sensitive search for lowercase pattern
        let input = serde_json::to_value(GrepToolInput {
            regex: "lowercase".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: true,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            result.contains("lowercase"),
            "Case-sensitive search should match lowercase text"
        );
    }

    /// Creates a project containing a single Rust file with nested syntax structures and
    /// registers a Rust language (grammar + outline query) so syntax-aware output can be
    /// exercised.
    async fn setup_syntax_test(cx: &mut TestAppContext) -> Entity<Project> {
        init_test(cx);
        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor().clone());

        // Create test file with syntax structures
        fs.insert_tree(
            path!("/root"),
            serde_json::json!({
                "test_syntax.rs": r#"
                    fn top_level_function() {
                        println!("This is at the top level");
                    }

                    mod feature_module {
                        pub mod nested_module {
                            pub fn nested_function(
                                first_arg: String,
                                second_arg: i32,
                            ) {
                                println!("Function in nested module");
                                println!("{first_arg}");
                                println!("{second_arg}");
                            }
                        }
                    }

                    struct MyStruct {
                        field1: String,
                        field2: i32,
                    }

                    impl MyStruct {
                        fn method_with_block() {
                            let condition = true;
                            if condition {
                                println!("Inside if block");
                            }
                        }

                        fn long_function() {
                            println!("Line 1");
                            println!("Line 2");
                            println!("Line 3");
                            println!("Line 4");
                            println!("Line 5");
                            println!("Line 6");
                            println!("Line 7");
                            println!("Line 8");
                            println!("Line 9");
                            println!("Line 10");
                            println!("Line 11");
                            println!("Line 12");
                        }
                    }

                    trait Processor {
                        fn process(&self, input: &str) -> String;
                    }

                    impl Processor for MyStruct {
                        fn process(&self, input: &str) -> String {
                            format!("Processed: {}", input)
                        }
                    }
                "#.unindent().trim(),
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        project.update(cx, |project, _cx| {
            project.languages().add(rust_lang().into())
        });

        project
    }

    #[gpui::test]
    async fn test_grep_top_level_function(cx: &mut TestAppContext) {
        let project = setup_syntax_test(cx).await;

        // Test: Line at the top level of the file
        let input = serde_json::to_value(GrepToolInput {
            regex: "This is at the top level".to_string(),
            include_pattern: Some("**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        let expected = r#"
            Found 1 matches:

            ## Matches in root/test_syntax.rs

            ### fn top_level_function › L1-3
            ```
            fn top_level_function() {
                println!("This is at the top level");
            }
            ```
            "#
        .unindent();
        assert_eq!(result, expected);
    }

    #[gpui::test]
    async fn test_grep_function_body(cx: &mut TestAppContext) {
        let project = setup_syntax_test(cx).await;

        // Test: Line inside a function body
        let input = serde_json::to_value(GrepToolInput {
            regex: "Function in nested module".to_string(),
            include_pattern: Some("**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        let expected = r#"
            Found 1 matches:

            ## Matches in root/test_syntax.rs

            ### mod feature_module › pub mod nested_module › pub fn nested_function › L10-14
            ```
                    ) {
                        println!("Function in nested module");
                        println!("{first_arg}");
                        println!("{second_arg}");
                    }
            ```
            "#
        .unindent();
        assert_eq!(result, expected);
    }

    #[gpui::test]
    async fn test_grep_function_args_and_body(cx: &mut TestAppContext) {
        let project = setup_syntax_test(cx).await;

        // Test: Line with a function argument
        let input = serde_json::to_value(GrepToolInput {
            regex: "second_arg".to_string(),
            include_pattern: Some("**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        let expected = r#"
            Found 1 matches:

            ## Matches in root/test_syntax.rs

            ### mod feature_module › pub mod nested_module › pub fn nested_function › L7-14
            ```
                    pub fn nested_function(
                        first_arg: String,
                        second_arg: i32,
                    ) {
                        println!("Function in nested module");
                        println!("{first_arg}");
                        println!("{second_arg}");
                    }
            ```
            "#
        .unindent();
        assert_eq!(result, expected);
    }

    #[gpui::test]
    async fn test_grep_if_block(cx: &mut TestAppContext) {
        let project = setup_syntax_test(cx).await;

        // Test: Line inside an if block
        let input = serde_json::to_value(GrepToolInput {
            regex: "Inside if block".to_string(),
            include_pattern: Some("**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        let expected = r#"
            Found 1 matches:

            ## Matches in root/test_syntax.rs

            ### impl MyStruct › fn method_with_block › L26-28
            ```
                    if condition {
                        println!("Inside if block");
                    }
            ```
            "#
        .unindent();
        assert_eq!(result, expected);
    }

    #[gpui::test]
    async fn test_grep_long_function_top(cx: &mut TestAppContext) {
        let project = setup_syntax_test(cx).await;

        // Test: Line in the middle of a long function - should show message about remaining lines
        let input = serde_json::to_value(GrepToolInput {
            regex: "Line 5".to_string(),
            include_pattern: Some("**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        let expected = r#"
            Found 1 matches:

            ## Matches in root/test_syntax.rs

            ### impl MyStruct › fn long_function › L31-41
            ```
                fn long_function() {
                    println!("Line 1");
                    println!("Line 2");
                    println!("Line 3");
                    println!("Line 4");
                    println!("Line 5");
                    println!("Line 6");
                    println!("Line 7");
                    println!("Line 8");
                    println!("Line 9");
                    println!("Line 10");
            ```

            3 lines remaining in ancestor node. Read the file to see all.
            "#
        .unindent();
        assert_eq!(result, expected);
    }

    #[gpui::test]
    async fn test_grep_long_function_bottom(cx: &mut TestAppContext) {
        let project = setup_syntax_test(cx).await;

        // Test: Line near the end of a long function (outside the capped ancestor range,
        // so plain context lines are shown instead)
        let input = serde_json::to_value(GrepToolInput {
            regex: "Line 12".to_string(),
            include_pattern: Some("**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        let expected = r#"
            Found 1 matches:

            ## Matches in root/test_syntax.rs

            ### impl MyStruct › fn long_function › L41-45
            ```
                    println!("Line 10");
                    println!("Line 11");
                    println!("Line 12");
                }
            }
            ```
            "#
        .unindent();
        assert_eq!(result, expected);
    }

    /// Runs the grep tool with `input` against `project` and returns its text output,
    /// normalizing Windows path separators after `root` so expectations are portable.
    /// Panics if the tool returns an error.
    async fn run_grep_tool(
        input: serde_json::Value,
        project: Entity<Project>,
        cx: &mut TestAppContext,
    ) -> String {
        let tool = Arc::new(GrepTool);
        let action_log = cx.new(|_cx| ActionLog::new(project.clone()));
        let task = cx.update(|cx| tool.run(input, &[], project, action_log, None, cx));

        match task.output.await {
            Ok(result) => {
                if cfg!(windows) {
                    result.content.replace("root\\", "root/")
                } else {
                    result.content
                }
            }
            Err(e) => panic!("Failed to run grep tool: {}", e),
        }
    }

    /// Installs a test settings store and initializes the language and project settings
    /// the tool depends on.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });
    }

    /// Builds a Rust language with the tree-sitter grammar and Zed's Rust outline query,
    /// so that syntax ancestors and outline symbols are available in tests.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
        .unwrap()
    }
}