grep_tool.rs

  1use crate::schema::json_schema_for;
  2use anyhow::{Result, anyhow};
  3use assistant_tool::{ActionLog, Tool, ToolResult};
  4use futures::StreamExt;
  5use gpui::{AnyWindowHandle, App, Entity, Task};
  6use language::{OffsetRangeExt, ParseStatus, Point};
  7use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
  8use project::{
  9    Project,
 10    search::{SearchQuery, SearchResult},
 11};
 12use schemars::JsonSchema;
 13use serde::{Deserialize, Serialize};
 14use std::{cmp, fmt::Write, sync::Arc};
 15use ui::IconName;
 16use util::RangeExt;
 17use util::markdown::MarkdownInlineCode;
 18use util::paths::PathMatcher;
 19
 20#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 21pub struct GrepToolInput {
 22    /// A regex pattern to search for in the entire project. Note that the regex
 23    /// will be parsed by the Rust `regex` crate.
 24    ///
 25    /// Do NOT specify a path here! This will only be matched against the code **content**.
 26    pub regex: String,
 27
 28    /// A glob pattern for the paths of files to include in the search.
 29    /// Supports standard glob patterns like "**/*.rs" or "src/**/*.ts".
 30    /// If omitted, all files in the project will be searched.
 31    pub include_pattern: Option<String>,
 32
 33    /// Optional starting position for paginated results (0-based).
 34    /// When not provided, starts from the beginning.
 35    #[serde(default)]
 36    pub offset: u32,
 37
 38    /// Whether the regex is case-sensitive. Defaults to false (case-insensitive).
 39    #[serde(default)]
 40    pub case_sensitive: bool,
 41}
 42
 43impl GrepToolInput {
 44    /// Which page of search results this is.
 45    pub fn page(&self) -> u32 {
 46        1 + (self.offset / RESULTS_PER_PAGE)
 47    }
 48}
 49
 50const RESULTS_PER_PAGE: u32 = 20;
 51
 52pub struct GrepTool;
 53
 54impl Tool for GrepTool {
 55    fn name(&self) -> String {
 56        "grep".into()
 57    }
 58
 59    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 60        false
 61    }
 62
 63    fn may_perform_edits(&self) -> bool {
 64        false
 65    }
 66
 67    fn description(&self) -> String {
 68        include_str!("./grep_tool/description.md").into()
 69    }
 70
 71    fn icon(&self) -> IconName {
 72        IconName::Regex
 73    }
 74
 75    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 76        json_schema_for::<GrepToolInput>(format)
 77    }
 78
 79    fn ui_text(&self, input: &serde_json::Value) -> String {
 80        match serde_json::from_value::<GrepToolInput>(input.clone()) {
 81            Ok(input) => {
 82                let page = input.page();
 83                let regex_str = MarkdownInlineCode(&input.regex);
 84                let case_info = if input.case_sensitive {
 85                    " (case-sensitive)"
 86                } else {
 87                    ""
 88                };
 89
 90                if page > 1 {
 91                    format!("Get page {page} of search results for regex {regex_str}{case_info}")
 92                } else {
 93                    format!("Search files for regex {regex_str}{case_info}")
 94                }
 95            }
 96            Err(_) => "Search with regex".to_string(),
 97        }
 98    }
 99
100    fn run(
101        self: Arc<Self>,
102        input: serde_json::Value,
103        _request: Arc<LanguageModelRequest>,
104        project: Entity<Project>,
105        _action_log: Entity<ActionLog>,
106        _model: Arc<dyn LanguageModel>,
107        _window: Option<AnyWindowHandle>,
108        cx: &mut App,
109    ) -> ToolResult {
110        const CONTEXT_LINES: u32 = 2;
111        const MAX_ANCESTOR_LINES: u32 = 10;
112
113        let input = match serde_json::from_value::<GrepToolInput>(input) {
114            Ok(input) => input,
115            Err(error) => {
116                return Task::ready(Err(anyhow!("Failed to parse input: {error}"))).into();
117            }
118        };
119
120        let include_matcher = match PathMatcher::new(
121            input
122                .include_pattern
123                .as_ref()
124                .into_iter()
125                .collect::<Vec<_>>(),
126        ) {
127            Ok(matcher) => matcher,
128            Err(error) => {
129                return Task::ready(Err(anyhow!("invalid include glob pattern: {error}"))).into();
130            }
131        };
132
133        let query = match SearchQuery::regex(
134            &input.regex,
135            false,
136            input.case_sensitive,
137            false,
138            false,
139            include_matcher,
140            PathMatcher::default(), // For now, keep it simple and don't enable an exclude pattern.
141            true, // Always match file include pattern against *full project paths* that start with a project root.
142            None,
143        ) {
144            Ok(query) => query,
145            Err(error) => return Task::ready(Err(error)).into(),
146        };
147
148        let results = project.update(cx, |project, cx| project.search(query, cx));
149
150        cx.spawn(async move |cx|  {
151            futures::pin_mut!(results);
152
153            let mut output = String::new();
154            let mut skips_remaining = input.offset;
155            let mut matches_found = 0;
156            let mut has_more_matches = false;
157
158            'outer: while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
159                if ranges.is_empty() {
160                    continue;
161                }
162
163                let (Some(path), mut parse_status) = buffer.read_with(cx, |buffer, cx| {
164                    (buffer.file().map(|file| file.full_path(cx)), buffer.parse_status())
165                })? else {
166                    continue;
167                };
168
169
170                while *parse_status.borrow() != ParseStatus::Idle {
171                    parse_status.changed().await?;
172                }
173
174                let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
175
176                let mut ranges = ranges
177                    .into_iter()
178                    .map(|range| {
179                        let matched = range.to_point(&snapshot);
180                        let matched_end_line_len = snapshot.line_len(matched.end.row);
181                        let full_lines = Point::new(matched.start.row, 0)..Point::new(matched.end.row, matched_end_line_len);
182                        let symbols = snapshot.symbols_containing(matched.start, None);
183
184                        if let Some(ancestor_node) = snapshot.syntax_ancestor(full_lines.clone()) {
185                            let full_ancestor_range = ancestor_node.byte_range().to_point(&snapshot);
186                            let end_row = full_ancestor_range.end.row.min(full_ancestor_range.start.row + MAX_ANCESTOR_LINES);
187                            let end_col = snapshot.line_len(end_row);
188                            let capped_ancestor_range = Point::new(full_ancestor_range.start.row, 0)..Point::new(end_row, end_col);
189
190                            if capped_ancestor_range.contains_inclusive(&full_lines) {
191                                return (capped_ancestor_range, Some(full_ancestor_range), symbols)
192                            }
193                        }
194
195                        let mut matched = matched;
196                        matched.start.column = 0;
197                        matched.start.row =
198                            matched.start.row.saturating_sub(CONTEXT_LINES);
199                        matched.end.row = cmp::min(
200                            snapshot.max_point().row,
201                            matched.end.row + CONTEXT_LINES,
202                        );
203                        matched.end.column = snapshot.line_len(matched.end.row);
204
205                        (matched, None, symbols)
206                    })
207                    .peekable();
208
209                let mut file_header_written = false;
210
211                while let Some((mut range, ancestor_range, parent_symbols)) = ranges.next(){
212                    if skips_remaining > 0 {
213                        skips_remaining -= 1;
214                        continue;
215                    }
216
217                    // We'd already found a full page of matches, and we just found one more.
218                    if matches_found >= RESULTS_PER_PAGE {
219                        has_more_matches = true;
220                        break 'outer;
221                    }
222
223                    while let Some((next_range, _, _)) = ranges.peek() {
224                        if range.end.row >= next_range.start.row {
225                            range.end = next_range.end;
226                            ranges.next();
227                        } else {
228                            break;
229                        }
230                    }
231
232                    if !file_header_written {
233                        writeln!(output, "\n## Matches in {}", path.display())?;
234                        file_header_written = true;
235                    }
236
237                    let end_row = range.end.row;
238                    output.push_str("\n### ");
239
240                    if let Some(parent_symbols) = &parent_symbols {
241                        for symbol in parent_symbols {
242                            write!(output, "{} › ", symbol.text)?;
243                        }
244                    }
245
246                    if range.start.row == end_row {
247                        writeln!(output, "L{}", range.start.row + 1)?;
248                    } else {
249                        writeln!(output, "L{}-{}", range.start.row + 1, end_row + 1)?;
250                    }
251
252                    output.push_str("```\n");
253                    output.extend(snapshot.text_for_range(range));
254                    output.push_str("\n```\n");
255
256                    if let Some(ancestor_range) = ancestor_range {
257                        if end_row < ancestor_range.end.row {
258                            let remaining_lines = ancestor_range.end.row - end_row;
259                            writeln!(output, "\n{} lines remaining in ancestor node. Read the file to see all.", remaining_lines)?;
260                        }
261                    }
262
263                    matches_found += 1;
264                }
265            }
266
267            if matches_found == 0 {
268                Ok("No matches found".to_string().into())
269            } else if has_more_matches {
270                Ok(format!(
271                    "Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}",
272                    input.offset + 1,
273                    input.offset + matches_found,
274                    input.offset + RESULTS_PER_PAGE,
275                ).into())
276            } else {
277                Ok(format!("Found {matches_found} matches:\n{output}").into())
278            }
279        }).into()
280    }
281}
282
#[cfg(test)]
mod tests {
    use super::*;
    use assistant_tool::Tool;
    use gpui::{AppContext, TestAppContext};
    use language::{Language, LanguageConfig, LanguageMatcher};
    use language_model::fake_provider::FakeLanguageModel;
    use project::{FakeFs, Project};
    use settings::SettingsStore;
    use unindent::Unindent;
    use util::path;

    /// `include_pattern` restricts which files are searched; `None` searches everything.
    #[gpui::test]
    async fn test_grep_tool_with_include_pattern(cx: &mut TestAppContext) {
        init_test(cx);
        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor().clone());
        fs.insert_tree(
            "/root",
            serde_json::json!({
                "src": {
                    "main.rs": "fn main() {\n    println!(\"Hello, world!\");\n}",
                    "utils": {
                        "helper.rs": "fn helper() {\n    println!(\"I'm a helper!\");\n}",
                    },
                },
                "tests": {
                    "test_main.rs": "fn test_main() {\n    assert!(true);\n}",
                }
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        // Test with include pattern for Rust files inside the root of the project
        let input = serde_json::to_value(GrepToolInput {
            regex: "println".to_string(),
            include_pattern: Some("root/**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(result.contains("main.rs"), "Should find matches in main.rs");
        assert!(
            result.contains("helper.rs"),
            "Should find matches in helper.rs"
        );
        assert!(
            !result.contains("test_main.rs"),
            "Should not include test_main.rs even though it's a .rs file (because its content doesn't contain `println`)"
        );

        // Test with include pattern for src directory only
        let input = serde_json::to_value(GrepToolInput {
            regex: "fn".to_string(),
            include_pattern: Some("root/**/src/**".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            result.contains("main.rs"),
            "Should find matches in src/main.rs"
        );
        assert!(
            result.contains("helper.rs"),
            "Should find matches in src/utils/helper.rs"
        );
        assert!(
            !result.contains("test_main.rs"),
            "Should not include test_main.rs as it's not in src directory"
        );

        // Test without an include pattern (should default to all files)
        let input = serde_json::to_value(GrepToolInput {
            regex: "fn".to_string(),
            include_pattern: None,
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(result.contains("main.rs"), "Should find matches in main.rs");
        assert!(
            result.contains("helper.rs"),
            "Should find matches in helper.rs"
        );
        assert!(
            result.contains("test_main.rs"),
            "Should include test_main.rs"
        );
    }

    /// `case_sensitive` toggles whether letter case must match exactly.
    #[gpui::test]
    async fn test_grep_tool_with_case_sensitivity(cx: &mut TestAppContext) {
        init_test(cx);
        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor().clone());
        fs.insert_tree(
            "/root",
            serde_json::json!({
                "case_test.txt": "This file has UPPERCASE and lowercase text.\nUPPERCASE patterns should match only with case_sensitive: true",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        // Test case-insensitive search (default)
        let input = serde_json::to_value(GrepToolInput {
            regex: "uppercase".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            result.contains("UPPERCASE"),
            "Case-insensitive search should match uppercase"
        );

        // Test case-sensitive search
        let input = serde_json::to_value(GrepToolInput {
            regex: "uppercase".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: true,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            !result.contains("UPPERCASE"),
            "Case-sensitive search should not match uppercase"
        );

        // Test case-sensitive search with an uppercase pattern against lowercase text
        let input = serde_json::to_value(GrepToolInput {
            regex: "LOWERCASE".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: true,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;

        assert!(
            !result.contains("lowercase"),
            "Case-sensitive search for LOWERCASE should not match lowercase text"
        );

        // Test case-sensitive search for lowercase pattern
        let input = serde_json::to_value(GrepToolInput {
            regex: "lowercase".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: true,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            result.contains("lowercase"),
            "Case-sensitive search should match lowercase text"
        );
    }

    /// Helper function to set up a syntax test environment
    async fn setup_syntax_test(cx: &mut TestAppContext) -> Entity<Project> {
        init_test(cx);
        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor().clone());

        // Create test file with syntax structures
        fs.insert_tree(
            "/root",
            serde_json::json!({
                "test_syntax.rs": r#"
                    fn top_level_function() {
                        println!("This is at the top level");
                    }

                    mod feature_module {
                        pub mod nested_module {
                            pub fn nested_function(
                                first_arg: String,
                                second_arg: i32,
                            ) {
                                println!("Function in nested module");
                                println!("{first_arg}");
                                println!("{second_arg}");
                            }
                        }
                    }

                    struct MyStruct {
                        field1: String,
                        field2: i32,
                    }

                    impl MyStruct {
                        fn method_with_block() {
                            let condition = true;
                            if condition {
                                println!("Inside if block");
                            }
                        }

                        fn long_function() {
                            println!("Line 1");
                            println!("Line 2");
                            println!("Line 3");
                            println!("Line 4");
                            println!("Line 5");
                            println!("Line 6");
                            println!("Line 7");
                            println!("Line 8");
                            println!("Line 9");
                            println!("Line 10");
                            println!("Line 11");
                            println!("Line 12");
                        }
                    }

                    trait Processor {
                        fn process(&self, input: &str) -> String;
                    }

                    impl Processor for MyStruct {
                        fn process(&self, input: &str) -> String {
                            format!("Processed: {}", input)
                        }
                    }
                "#.unindent().trim(),
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        project.update(cx, |project, _cx| {
            project.languages().add(rust_lang().into())
        });

        project
    }

    #[gpui::test]
    async fn test_grep_top_level_function(cx: &mut TestAppContext) {
        let project = setup_syntax_test(cx).await;

        // Test: Line at the top level of the file
        let input = serde_json::to_value(GrepToolInput {
            regex: "This is at the top level".to_string(),
            include_pattern: Some("**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        let expected = r#"
            Found 1 matches:

            ## Matches in root/test_syntax.rs

            ### fn top_level_function › L1-3
            ```
            fn top_level_function() {
                println!("This is at the top level");
            }
            ```
            "#
        .unindent();
        assert_eq!(result, expected);
    }

    #[gpui::test]
    async fn test_grep_function_body(cx: &mut TestAppContext) {
        let project = setup_syntax_test(cx).await;

        // Test: Line inside a function body
        let input = serde_json::to_value(GrepToolInput {
            regex: "Function in nested module".to_string(),
            include_pattern: Some("**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        let expected = r#"
            Found 1 matches:

            ## Matches in root/test_syntax.rs

            ### mod feature_module › pub mod nested_module › pub fn nested_function › L10-14
            ```
                    ) {
                        println!("Function in nested module");
                        println!("{first_arg}");
                        println!("{second_arg}");
                    }
            ```
            "#
        .unindent();
        assert_eq!(result, expected);
    }

    #[gpui::test]
    async fn test_grep_function_args_and_body(cx: &mut TestAppContext) {
        let project = setup_syntax_test(cx).await;

        // Test: Line with a function argument
        let input = serde_json::to_value(GrepToolInput {
            regex: "second_arg".to_string(),
            include_pattern: Some("**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        let expected = r#"
            Found 1 matches:

            ## Matches in root/test_syntax.rs

            ### mod feature_module › pub mod nested_module › pub fn nested_function › L7-14
            ```
                    pub fn nested_function(
                        first_arg: String,
                        second_arg: i32,
                    ) {
                        println!("Function in nested module");
                        println!("{first_arg}");
                        println!("{second_arg}");
                    }
            ```
            "#
        .unindent();
        assert_eq!(result, expected);
    }

    #[gpui::test]
    async fn test_grep_if_block(cx: &mut TestAppContext) {
        let project = setup_syntax_test(cx).await;

        // Test: Line inside an if block
        let input = serde_json::to_value(GrepToolInput {
            regex: "Inside if block".to_string(),
            include_pattern: Some("**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        let expected = r#"
            Found 1 matches:

            ## Matches in root/test_syntax.rs

            ### impl MyStruct › fn method_with_block › L26-28
            ```
                    if condition {
                        println!("Inside if block");
                    }
            ```
            "#
        .unindent();
        assert_eq!(result, expected);
    }

    #[gpui::test]
    async fn test_grep_long_function_top(cx: &mut TestAppContext) {
        let project = setup_syntax_test(cx).await;

        // Test: Line in the middle of a long function - should show message about remaining lines
        let input = serde_json::to_value(GrepToolInput {
            regex: "Line 5".to_string(),
            include_pattern: Some("**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        let expected = r#"
            Found 1 matches:

            ## Matches in root/test_syntax.rs

            ### impl MyStruct › fn long_function › L31-41
            ```
                fn long_function() {
                    println!("Line 1");
                    println!("Line 2");
                    println!("Line 3");
                    println!("Line 4");
                    println!("Line 5");
                    println!("Line 6");
                    println!("Line 7");
                    println!("Line 8");
                    println!("Line 9");
                    println!("Line 10");
            ```

            3 lines remaining in ancestor node. Read the file to see all.
            "#
        .unindent();
        assert_eq!(result, expected);
    }

    #[gpui::test]
    async fn test_grep_long_function_bottom(cx: &mut TestAppContext) {
        let project = setup_syntax_test(cx).await;

        // Test: Line in the long function
        let input = serde_json::to_value(GrepToolInput {
            regex: "Line 12".to_string(),
            include_pattern: Some("**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        let expected = r#"
            Found 1 matches:

            ## Matches in root/test_syntax.rs

            ### impl MyStruct › fn long_function › L41-45
            ```
                    println!("Line 10");
                    println!("Line 11");
                    println!("Line 12");
                }
            }
            ```
            "#
        .unindent();
        assert_eq!(result, expected);
    }

    /// Runs the grep tool with `input` against `project` and returns its text output,
    /// normalizing Windows path separators after `root` so expectations are portable.
    async fn run_grep_tool(
        input: serde_json::Value,
        project: Entity<Project>,
        cx: &mut TestAppContext,
    ) -> String {
        let tool = Arc::new(GrepTool);
        let action_log = cx.new(|_cx| ActionLog::new(project.clone()));
        let model = Arc::new(FakeLanguageModel::default());
        let task =
            cx.update(|cx| tool.run(input, Arc::default(), project, action_log, model, None, cx));

        match task.output.await {
            Ok(result) => {
                if cfg!(windows) {
                    result.content.as_str().unwrap().replace("root\\", "root/")
                } else {
                    result.content.as_str().unwrap().to_string()
                }
            }
            Err(e) => panic!("Failed to run grep tool: {}", e),
        }
    }

    /// Installs the global settings store and initializes language/project settings.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });
    }

    /// A Rust language with the real outline query, so symbol breadcrumbs are produced.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
        .unwrap()
    }
}