grep_tool.rs

  1use crate::schema::json_schema_for;
  2use anyhow::{Result, anyhow};
  3use assistant_tool::{ActionLog, Tool, ToolResult};
  4use futures::StreamExt;
  5use gpui::{AnyWindowHandle, App, Entity, Task};
  6use language::{OffsetRangeExt, ParseStatus, Point};
  7use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
  8use project::{
  9    Project,
 10    search::{SearchQuery, SearchResult},
 11};
 12use schemars::JsonSchema;
 13use serde::{Deserialize, Serialize};
 14use std::{cmp, fmt::Write, sync::Arc};
 15use ui::IconName;
 16use util::RangeExt;
 17use util::markdown::MarkdownInlineCode;
 18use util::paths::PathMatcher;
 19
// Arguments of the `grep` tool, deserialized from the JSON payload of a tool call
// (see `GrepTool::ui_text` and `GrepTool::run`).
//
// NOTE(review): this type derives `JsonSchema` and is handed to `json_schema_for`, and
// schemars' derive is documented to copy `///` doc comments into the schema as
// `description` values. The field comments below are therefore presumably part of what
// the model sees — confirm before rewording them. Maintainer-only notes like this one
// are written as `//` comments so they stay out of the schema.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct GrepToolInput {
    /// A regex pattern to search for in the entire project. Note that the regex
    /// will be parsed by the Rust `regex` crate.
    ///
    /// Do NOT specify a path here! This will only be matched against the code **content**.
    pub regex: String,

    /// A glob pattern for the paths of files to include in the search.
    /// Supports standard glob patterns like "**/*.rs" or "src/**/*.ts".
    /// If omitted, all files in the project will be searched.
    pub include_pattern: Option<String>,

    /// Optional starting position for paginated results (0-based).
    /// When not provided, starts from the beginning.
    // `#[serde(default)]` makes a missing `offset` deserialize as 0.
    #[serde(default)]
    pub offset: u32,

    /// Whether the regex is case-sensitive. Defaults to false (case-insensitive).
    // `#[serde(default)]` makes a missing `case_sensitive` deserialize as `false`.
    #[serde(default)]
    pub case_sensitive: bool,
}
 42
 43impl GrepToolInput {
 44    /// Which page of search results this is.
 45    pub fn page(&self) -> u32 {
 46        1 + (self.offset / RESULTS_PER_PAGE)
 47    }
 48}
 49
/// Page size of the grep tool: `GrepTool::run` stops after reporting this many matches,
/// and `GrepToolInput::page` divides the offset by it to derive the page number.
const RESULTS_PER_PAGE: u32 = 20;
 51
/// Assistant tool named `grep`: runs a regex search over a project's files and returns
/// the matches, with surrounding lines, as Markdown text.
pub struct GrepTool;
 53
 54impl Tool for GrepTool {
 55    fn name(&self) -> String {
 56        "grep".into()
 57    }
 58
    /// Always returns `false`; neither the input nor the app context is consulted.
    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
        false
    }
 62
 63    fn description(&self) -> String {
 64        include_str!("./grep_tool/description.md").into()
 65    }
 66
    /// Returns the icon associated with this tool: `IconName::Regex`.
    fn icon(&self) -> IconName {
        IconName::Regex
    }
 70
    /// Returns the JSON schema for [`GrepToolInput`] in the requested `format`, as
    /// produced by `json_schema_for`; any error from that helper is passed through.
    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
        json_schema_for::<GrepToolInput>(format)
    }
 74
 75    fn ui_text(&self, input: &serde_json::Value) -> String {
 76        match serde_json::from_value::<GrepToolInput>(input.clone()) {
 77            Ok(input) => {
 78                let page = input.page();
 79                let regex_str = MarkdownInlineCode(&input.regex);
 80                let case_info = if input.case_sensitive {
 81                    " (case-sensitive)"
 82                } else {
 83                    ""
 84                };
 85
 86                if page > 1 {
 87                    format!("Get page {page} of search results for regex {regex_str}{case_info}")
 88                } else {
 89                    format!("Search files for regex {regex_str}{case_info}")
 90                }
 91            }
 92            Err(_) => "Search with regex".to_string(),
 93        }
 94    }
 95
 96    fn run(
 97        self: Arc<Self>,
 98        input: serde_json::Value,
 99        _request: Arc<LanguageModelRequest>,
100        project: Entity<Project>,
101        _action_log: Entity<ActionLog>,
102        _model: Arc<dyn LanguageModel>,
103        _window: Option<AnyWindowHandle>,
104        cx: &mut App,
105    ) -> ToolResult {
106        const CONTEXT_LINES: u32 = 2;
107        const MAX_ANCESTOR_LINES: u32 = 10;
108
109        let input = match serde_json::from_value::<GrepToolInput>(input) {
110            Ok(input) => input,
111            Err(error) => {
112                return Task::ready(Err(anyhow!("Failed to parse input: {error}"))).into();
113            }
114        };
115
116        let include_matcher = match PathMatcher::new(
117            input
118                .include_pattern
119                .as_ref()
120                .into_iter()
121                .collect::<Vec<_>>(),
122        ) {
123            Ok(matcher) => matcher,
124            Err(error) => {
125                return Task::ready(Err(anyhow!("invalid include glob pattern: {error}"))).into();
126            }
127        };
128
129        let query = match SearchQuery::regex(
130            &input.regex,
131            false,
132            input.case_sensitive,
133            false,
134            false,
135            include_matcher,
136            PathMatcher::default(), // For now, keep it simple and don't enable an exclude pattern.
137            true, // Always match file include pattern against *full project paths* that start with a project root.
138            None,
139        ) {
140            Ok(query) => query,
141            Err(error) => return Task::ready(Err(error)).into(),
142        };
143
144        let results = project.update(cx, |project, cx| project.search(query, cx));
145
146        cx.spawn(async move |cx|  {
147            futures::pin_mut!(results);
148
149            let mut output = String::new();
150            let mut skips_remaining = input.offset;
151            let mut matches_found = 0;
152            let mut has_more_matches = false;
153
154            'outer: while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
155                if ranges.is_empty() {
156                    continue;
157                }
158
159                let (Some(path), mut parse_status) = buffer.read_with(cx, |buffer, cx| {
160                    (buffer.file().map(|file| file.full_path(cx)), buffer.parse_status())
161                })? else {
162                    continue;
163                };
164
165
166                while *parse_status.borrow() != ParseStatus::Idle {
167                    parse_status.changed().await?;
168                }
169
170                let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
171
172                let mut ranges = ranges
173                    .into_iter()
174                    .map(|range| {
175                        let matched = range.to_point(&snapshot);
176                        let matched_end_line_len = snapshot.line_len(matched.end.row);
177                        let full_lines = Point::new(matched.start.row, 0)..Point::new(matched.end.row, matched_end_line_len);
178                        let symbols = snapshot.symbols_containing(matched.start, None);
179
180                        if let Some(ancestor_node) = snapshot.syntax_ancestor(full_lines.clone()) {
181                            let full_ancestor_range = ancestor_node.byte_range().to_point(&snapshot);
182                            let end_row = full_ancestor_range.end.row.min(full_ancestor_range.start.row + MAX_ANCESTOR_LINES);
183                            let end_col = snapshot.line_len(end_row);
184                            let capped_ancestor_range = Point::new(full_ancestor_range.start.row, 0)..Point::new(end_row, end_col);
185
186                            if capped_ancestor_range.contains_inclusive(&full_lines) {
187                                return (capped_ancestor_range, Some(full_ancestor_range), symbols)
188                            }
189                        }
190
191                        let mut matched = matched;
192                        matched.start.column = 0;
193                        matched.start.row =
194                            matched.start.row.saturating_sub(CONTEXT_LINES);
195                        matched.end.row = cmp::min(
196                            snapshot.max_point().row,
197                            matched.end.row + CONTEXT_LINES,
198                        );
199                        matched.end.column = snapshot.line_len(matched.end.row);
200
201                        (matched, None, symbols)
202                    })
203                    .peekable();
204
205                let mut file_header_written = false;
206
207                while let Some((mut range, ancestor_range, parent_symbols)) = ranges.next(){
208                    if skips_remaining > 0 {
209                        skips_remaining -= 1;
210                        continue;
211                    }
212
213                    // We'd already found a full page of matches, and we just found one more.
214                    if matches_found >= RESULTS_PER_PAGE {
215                        has_more_matches = true;
216                        break 'outer;
217                    }
218
219                    while let Some((next_range, _, _)) = ranges.peek() {
220                        if range.end.row >= next_range.start.row {
221                            range.end = next_range.end;
222                            ranges.next();
223                        } else {
224                            break;
225                        }
226                    }
227
228                    if !file_header_written {
229                        writeln!(output, "\n## Matches in {}", path.display())?;
230                        file_header_written = true;
231                    }
232
233                    let end_row = range.end.row;
234                    output.push_str("\n### ");
235
236                    if let Some(parent_symbols) = &parent_symbols {
237                        for symbol in parent_symbols {
238                            write!(output, "{} › ", symbol.text)?;
239                        }
240                    }
241
242                    if range.start.row == end_row {
243                        writeln!(output, "L{}", range.start.row + 1)?;
244                    } else {
245                        writeln!(output, "L{}-{}", range.start.row + 1, end_row + 1)?;
246                    }
247
248                    output.push_str("```\n");
249                    output.extend(snapshot.text_for_range(range));
250                    output.push_str("\n```\n");
251
252                    if let Some(ancestor_range) = ancestor_range {
253                        if end_row < ancestor_range.end.row {
254                            let remaining_lines = ancestor_range.end.row - end_row;
255                            writeln!(output, "\n{} lines remaining in ancestor node. Read the file to see all.", remaining_lines)?;
256                        }
257                    }
258
259                    matches_found += 1;
260                }
261            }
262
263            if matches_found == 0 {
264                Ok("No matches found".to_string().into())
265            } else if has_more_matches {
266                Ok(format!(
267                    "Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}",
268                    input.offset + 1,
269                    input.offset + matches_found,
270                    input.offset + RESULTS_PER_PAGE,
271                ).into())
272            } else {
273                Ok(format!("Found {matches_found} matches:\n{output}").into())
274            }
275        }).into()
276    }
277}
278
279#[cfg(test)]
280mod tests {
281    use super::*;
282    use assistant_tool::Tool;
283    use gpui::{AppContext, TestAppContext};
284    use language::{Language, LanguageConfig, LanguageMatcher};
285    use language_model::fake_provider::FakeLanguageModel;
286    use project::{FakeFs, Project};
287    use settings::SettingsStore;
288    use unindent::Unindent;
289    use util::path;
290
291    #[gpui::test]
292    async fn test_grep_tool_with_include_pattern(cx: &mut TestAppContext) {
293        init_test(cx);
294        cx.executor().allow_parking();
295
296        let fs = FakeFs::new(cx.executor().clone());
297        fs.insert_tree(
298            "/root",
299            serde_json::json!({
300                "src": {
301                    "main.rs": "fn main() {\n    println!(\"Hello, world!\");\n}",
302                    "utils": {
303                        "helper.rs": "fn helper() {\n    println!(\"I'm a helper!\");\n}",
304                    },
305                },
306                "tests": {
307                    "test_main.rs": "fn test_main() {\n    assert!(true);\n}",
308                }
309            }),
310        )
311        .await;
312
313        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
314
315        // Test with include pattern for Rust files inside the root of the project
316        let input = serde_json::to_value(GrepToolInput {
317            regex: "println".to_string(),
318            include_pattern: Some("root/**/*.rs".to_string()),
319            offset: 0,
320            case_sensitive: false,
321        })
322        .unwrap();
323
324        let result = run_grep_tool(input, project.clone(), cx).await;
325        assert!(result.contains("main.rs"), "Should find matches in main.rs");
326        assert!(
327            result.contains("helper.rs"),
328            "Should find matches in helper.rs"
329        );
330        assert!(
331            !result.contains("test_main.rs"),
332            "Should not include test_main.rs even though it's a .rs file (because it doesn't have the pattern)"
333        );
334
335        // Test with include pattern for src directory only
336        let input = serde_json::to_value(GrepToolInput {
337            regex: "fn".to_string(),
338            include_pattern: Some("root/**/src/**".to_string()),
339            offset: 0,
340            case_sensitive: false,
341        })
342        .unwrap();
343
344        let result = run_grep_tool(input, project.clone(), cx).await;
345        assert!(
346            result.contains("main.rs"),
347            "Should find matches in src/main.rs"
348        );
349        assert!(
350            result.contains("helper.rs"),
351            "Should find matches in src/utils/helper.rs"
352        );
353        assert!(
354            !result.contains("test_main.rs"),
355            "Should not include test_main.rs as it's not in src directory"
356        );
357
358        // Test with empty include pattern (should default to all files)
359        let input = serde_json::to_value(GrepToolInput {
360            regex: "fn".to_string(),
361            include_pattern: None,
362            offset: 0,
363            case_sensitive: false,
364        })
365        .unwrap();
366
367        let result = run_grep_tool(input, project.clone(), cx).await;
368        assert!(result.contains("main.rs"), "Should find matches in main.rs");
369        assert!(
370            result.contains("helper.rs"),
371            "Should find matches in helper.rs"
372        );
373        assert!(
374            result.contains("test_main.rs"),
375            "Should include test_main.rs"
376        );
377    }
378
379    #[gpui::test]
380    async fn test_grep_tool_with_case_sensitivity(cx: &mut TestAppContext) {
381        init_test(cx);
382        cx.executor().allow_parking();
383
384        let fs = FakeFs::new(cx.executor().clone());
385        fs.insert_tree(
386            "/root",
387            serde_json::json!({
388                "case_test.txt": "This file has UPPERCASE and lowercase text.\nUPPERCASE patterns should match only with case_sensitive: true",
389            }),
390        )
391        .await;
392
393        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
394
395        // Test case-insensitive search (default)
396        let input = serde_json::to_value(GrepToolInput {
397            regex: "uppercase".to_string(),
398            include_pattern: Some("**/*.txt".to_string()),
399            offset: 0,
400            case_sensitive: false,
401        })
402        .unwrap();
403
404        let result = run_grep_tool(input, project.clone(), cx).await;
405        assert!(
406            result.contains("UPPERCASE"),
407            "Case-insensitive search should match uppercase"
408        );
409
410        // Test case-sensitive search
411        let input = serde_json::to_value(GrepToolInput {
412            regex: "uppercase".to_string(),
413            include_pattern: Some("**/*.txt".to_string()),
414            offset: 0,
415            case_sensitive: true,
416        })
417        .unwrap();
418
419        let result = run_grep_tool(input, project.clone(), cx).await;
420        assert!(
421            !result.contains("UPPERCASE"),
422            "Case-sensitive search should not match uppercase"
423        );
424
425        // Test case-sensitive search
426        let input = serde_json::to_value(GrepToolInput {
427            regex: "LOWERCASE".to_string(),
428            include_pattern: Some("**/*.txt".to_string()),
429            offset: 0,
430            case_sensitive: true,
431        })
432        .unwrap();
433
434        let result = run_grep_tool(input, project.clone(), cx).await;
435
436        assert!(
437            !result.contains("lowercase"),
438            "Case-sensitive search should match lowercase"
439        );
440
441        // Test case-sensitive search for lowercase pattern
442        let input = serde_json::to_value(GrepToolInput {
443            regex: "lowercase".to_string(),
444            include_pattern: Some("**/*.txt".to_string()),
445            offset: 0,
446            case_sensitive: true,
447        })
448        .unwrap();
449
450        let result = run_grep_tool(input, project.clone(), cx).await;
451        assert!(
452            result.contains("lowercase"),
453            "Case-sensitive search should match lowercase text"
454        );
455    }
456
457    /// Helper function to set up a syntax test environment
458    async fn setup_syntax_test(cx: &mut TestAppContext) -> Entity<Project> {
459        use unindent::Unindent;
460        init_test(cx);
461        cx.executor().allow_parking();
462
463        let fs = FakeFs::new(cx.executor().clone());
464
465        // Create test file with syntax structures
466        fs.insert_tree(
467            "/root",
468            serde_json::json!({
469                "test_syntax.rs": r#"
470                    fn top_level_function() {
471                        println!("This is at the top level");
472                    }
473
474                    mod feature_module {
475                        pub mod nested_module {
476                            pub fn nested_function(
477                                first_arg: String,
478                                second_arg: i32,
479                            ) {
480                                println!("Function in nested module");
481                                println!("{first_arg}");
482                                println!("{second_arg}");
483                            }
484                        }
485                    }
486
487                    struct MyStruct {
488                        field1: String,
489                        field2: i32,
490                    }
491
492                    impl MyStruct {
493                        fn method_with_block() {
494                            let condition = true;
495                            if condition {
496                                println!("Inside if block");
497                            }
498                        }
499
500                        fn long_function() {
501                            println!("Line 1");
502                            println!("Line 2");
503                            println!("Line 3");
504                            println!("Line 4");
505                            println!("Line 5");
506                            println!("Line 6");
507                            println!("Line 7");
508                            println!("Line 8");
509                            println!("Line 9");
510                            println!("Line 10");
511                            println!("Line 11");
512                            println!("Line 12");
513                        }
514                    }
515
516                    trait Processor {
517                        fn process(&self, input: &str) -> String;
518                    }
519
520                    impl Processor for MyStruct {
521                        fn process(&self, input: &str) -> String {
522                            format!("Processed: {}", input)
523                        }
524                    }
525                "#.unindent().trim(),
526            }),
527        )
528        .await;
529
530        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
531
532        project.update(cx, |project, _cx| {
533            project.languages().add(rust_lang().into())
534        });
535
536        project
537    }
538
539    #[gpui::test]
540    async fn test_grep_top_level_function(cx: &mut TestAppContext) {
541        let project = setup_syntax_test(cx).await;
542
543        // Test: Line at the top level of the file
544        let input = serde_json::to_value(GrepToolInput {
545            regex: "This is at the top level".to_string(),
546            include_pattern: Some("**/*.rs".to_string()),
547            offset: 0,
548            case_sensitive: false,
549        })
550        .unwrap();
551
552        let result = run_grep_tool(input, project.clone(), cx).await;
553        let expected = r#"
554            Found 1 matches:
555
556            ## Matches in root/test_syntax.rs
557
558            ### fn top_level_function › L1-3
559            ```
560            fn top_level_function() {
561                println!("This is at the top level");
562            }
563            ```
564            "#
565        .unindent();
566        assert_eq!(result, expected);
567    }
568
569    #[gpui::test]
570    async fn test_grep_function_body(cx: &mut TestAppContext) {
571        let project = setup_syntax_test(cx).await;
572
573        // Test: Line inside a function body
574        let input = serde_json::to_value(GrepToolInput {
575            regex: "Function in nested module".to_string(),
576            include_pattern: Some("**/*.rs".to_string()),
577            offset: 0,
578            case_sensitive: false,
579        })
580        .unwrap();
581
582        let result = run_grep_tool(input, project.clone(), cx).await;
583        let expected = r#"
584            Found 1 matches:
585
586            ## Matches in root/test_syntax.rs
587
588            ### mod feature_module › pub mod nested_module › pub fn nested_function › L10-14
589            ```
590                    ) {
591                        println!("Function in nested module");
592                        println!("{first_arg}");
593                        println!("{second_arg}");
594                    }
595            ```
596            "#
597        .unindent();
598        assert_eq!(result, expected);
599    }
600
601    #[gpui::test]
602    async fn test_grep_function_args_and_body(cx: &mut TestAppContext) {
603        let project = setup_syntax_test(cx).await;
604
605        // Test: Line with a function argument
606        let input = serde_json::to_value(GrepToolInput {
607            regex: "second_arg".to_string(),
608            include_pattern: Some("**/*.rs".to_string()),
609            offset: 0,
610            case_sensitive: false,
611        })
612        .unwrap();
613
614        let result = run_grep_tool(input, project.clone(), cx).await;
615        let expected = r#"
616            Found 1 matches:
617
618            ## Matches in root/test_syntax.rs
619
620            ### mod feature_module › pub mod nested_module › pub fn nested_function › L7-14
621            ```
622                    pub fn nested_function(
623                        first_arg: String,
624                        second_arg: i32,
625                    ) {
626                        println!("Function in nested module");
627                        println!("{first_arg}");
628                        println!("{second_arg}");
629                    }
630            ```
631            "#
632        .unindent();
633        assert_eq!(result, expected);
634    }
635
636    #[gpui::test]
637    async fn test_grep_if_block(cx: &mut TestAppContext) {
638        use unindent::Unindent;
639        let project = setup_syntax_test(cx).await;
640
641        // Test: Line inside an if block
642        let input = serde_json::to_value(GrepToolInput {
643            regex: "Inside if block".to_string(),
644            include_pattern: Some("**/*.rs".to_string()),
645            offset: 0,
646            case_sensitive: false,
647        })
648        .unwrap();
649
650        let result = run_grep_tool(input, project.clone(), cx).await;
651        let expected = r#"
652            Found 1 matches:
653
654            ## Matches in root/test_syntax.rs
655
656            ### impl MyStruct › fn method_with_block › L26-28
657            ```
658                    if condition {
659                        println!("Inside if block");
660                    }
661            ```
662            "#
663        .unindent();
664        assert_eq!(result, expected);
665    }
666
667    #[gpui::test]
668    async fn test_grep_long_function_top(cx: &mut TestAppContext) {
669        use unindent::Unindent;
670        let project = setup_syntax_test(cx).await;
671
672        // Test: Line in the middle of a long function - should show message about remaining lines
673        let input = serde_json::to_value(GrepToolInput {
674            regex: "Line 5".to_string(),
675            include_pattern: Some("**/*.rs".to_string()),
676            offset: 0,
677            case_sensitive: false,
678        })
679        .unwrap();
680
681        let result = run_grep_tool(input, project.clone(), cx).await;
682        let expected = r#"
683            Found 1 matches:
684
685            ## Matches in root/test_syntax.rs
686
687            ### impl MyStruct › fn long_function › L31-41
688            ```
689                fn long_function() {
690                    println!("Line 1");
691                    println!("Line 2");
692                    println!("Line 3");
693                    println!("Line 4");
694                    println!("Line 5");
695                    println!("Line 6");
696                    println!("Line 7");
697                    println!("Line 8");
698                    println!("Line 9");
699                    println!("Line 10");
700            ```
701
702            3 lines remaining in ancestor node. Read the file to see all.
703            "#
704        .unindent();
705        assert_eq!(result, expected);
706    }
707
708    #[gpui::test]
709    async fn test_grep_long_function_bottom(cx: &mut TestAppContext) {
710        use unindent::Unindent;
711        let project = setup_syntax_test(cx).await;
712
713        // Test: Line in the long function
714        let input = serde_json::to_value(GrepToolInput {
715            regex: "Line 12".to_string(),
716            include_pattern: Some("**/*.rs".to_string()),
717            offset: 0,
718            case_sensitive: false,
719        })
720        .unwrap();
721
722        let result = run_grep_tool(input, project.clone(), cx).await;
723        let expected = r#"
724            Found 1 matches:
725
726            ## Matches in root/test_syntax.rs
727
728            ### impl MyStruct › fn long_function › L41-45
729            ```
730                    println!("Line 10");
731                    println!("Line 11");
732                    println!("Line 12");
733                }
734            }
735            ```
736            "#
737        .unindent();
738        assert_eq!(result, expected);
739    }
740
741    async fn run_grep_tool(
742        input: serde_json::Value,
743        project: Entity<Project>,
744        cx: &mut TestAppContext,
745    ) -> String {
746        let tool = Arc::new(GrepTool);
747        let action_log = cx.new(|_cx| ActionLog::new(project.clone()));
748        let model = Arc::new(FakeLanguageModel::default());
749        let task =
750            cx.update(|cx| tool.run(input, Arc::default(), project, action_log, model, None, cx));
751
752        match task.output.await {
753            Ok(result) => {
754                if cfg!(windows) {
755                    result.content.as_str().unwrap().replace("root\\", "root/")
756                } else {
757                    result.content.as_str().unwrap().to_string()
758                }
759            }
760            Err(e) => panic!("Failed to run grep tool: {}", e),
761        }
762    }
763
    /// Prepares the test app context: installs a test `SettingsStore` as a global, then
    /// runs `language::init` and `Project::init_settings`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            // NOTE(review): the store is set before the two init calls below, presumably
            // because they register settings with it — confirm before reordering.
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });
    }
772
773    fn rust_lang() -> Language {
774        Language::new(
775            LanguageConfig {
776                name: "Rust".into(),
777                matcher: LanguageMatcher {
778                    path_suffixes: vec!["rs".to_string()],
779                    ..Default::default()
780                },
781                ..Default::default()
782            },
783            Some(tree_sitter_rust::LANGUAGE.into()),
784        )
785        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
786        .unwrap()
787    }
788}