1use crate::schema::json_schema_for;
2use anyhow::{Result, anyhow};
3use assistant_tool::{ActionLog, Tool, ToolResult};
4use futures::StreamExt;
5use gpui::{AnyWindowHandle, App, Entity, Task};
6use language::{OffsetRangeExt, ParseStatus, Point};
7use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
8use project::{
9 Project,
10 search::{SearchQuery, SearchResult},
11};
12use schemars::JsonSchema;
13use serde::{Deserialize, Serialize};
14use std::{cmp, fmt::Write, sync::Arc};
15use ui::IconName;
16use util::RangeExt;
17use util::markdown::MarkdownInlineCode;
18use util::paths::PathMatcher;
19
// Deserialized arguments of a single `grep` tool call.
//
// NOTE: the `///` comments on the fields are not only rustdoc. The `JsonSchema`
// derive (schemars) uses doc comments as schema descriptions, and `input_schema`
// builds the tool schema from this type via `json_schema_for`. Editing them
// changes what the schema reports at runtime.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct GrepToolInput {
    /// A regex pattern to search for in the entire project. Note that the regex
    /// will be parsed by the Rust `regex` crate.
    ///
    /// Do NOT specify a path here! This will only be matched against the code **content**.
    pub regex: String,

    // NOTE(review): `run` builds the query with full-path matching enabled, so
    // include patterns are matched against paths that start with the worktree
    // root name (the tests use e.g. "root/**/*.rs"). An example such as
    // "src/**/*.ts" may therefore not match as written; confirm against
    // `PathMatcher` semantics before relying on it.
    /// A glob pattern for the paths of files to include in the search.
    /// Supports standard glob patterns like "**/*.rs" or "src/**/*.ts".
    /// If omitted, all files in the project will be searched.
    pub include_pattern: Option<String>,

    // Missing in the JSON input => 0 via `#[serde(default)]`.
    /// Optional starting position for paginated results (0-based).
    /// When not provided, starts from the beginning.
    #[serde(default)]
    pub offset: u32,

    // Missing in the JSON input => false via `#[serde(default)]`.
    /// Whether the regex is case-sensitive. Defaults to false (case-insensitive).
    #[serde(default)]
    pub case_sensitive: bool,
}
42
43impl GrepToolInput {
44 /// Which page of search results this is.
45 pub fn page(&self) -> u32 {
46 1 + (self.offset / RESULTS_PER_PAGE)
47 }
48}
49
/// Maximum number of match snippets rendered per call of the grep tool.
///
/// Also the page stride: `GrepToolInput::page` divides `offset` by it, and the
/// "more matches" message suggests `offset + RESULTS_PER_PAGE` for the next page.
const RESULTS_PER_PAGE: u32 = 20;
51
/// Stateless type implementing the agent's `grep` tool (project-wide regex
/// search over file contents). All behavior lives in its `Tool` impl.
pub struct GrepTool;
53
54impl Tool for GrepTool {
55 fn name(&self) -> String {
56 "grep".into()
57 }
58
59 fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
60 false
61 }
62
63 fn may_perform_edits(&self) -> bool {
64 false
65 }
66
67 fn description(&self) -> String {
68 include_str!("./grep_tool/description.md").into()
69 }
70
71 fn icon(&self) -> IconName {
72 IconName::Regex
73 }
74
75 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
76 json_schema_for::<GrepToolInput>(format)
77 }
78
79 fn ui_text(&self, input: &serde_json::Value) -> String {
80 match serde_json::from_value::<GrepToolInput>(input.clone()) {
81 Ok(input) => {
82 let page = input.page();
83 let regex_str = MarkdownInlineCode(&input.regex);
84 let case_info = if input.case_sensitive {
85 " (case-sensitive)"
86 } else {
87 ""
88 };
89
90 if page > 1 {
91 format!("Get page {page} of search results for regex {regex_str}{case_info}")
92 } else {
93 format!("Search files for regex {regex_str}{case_info}")
94 }
95 }
96 Err(_) => "Search with regex".to_string(),
97 }
98 }
99
100 fn run(
101 self: Arc<Self>,
102 input: serde_json::Value,
103 _request: Arc<LanguageModelRequest>,
104 project: Entity<Project>,
105 _action_log: Entity<ActionLog>,
106 _model: Arc<dyn LanguageModel>,
107 _window: Option<AnyWindowHandle>,
108 cx: &mut App,
109 ) -> ToolResult {
110 const CONTEXT_LINES: u32 = 2;
111 const MAX_ANCESTOR_LINES: u32 = 10;
112
113 let input = match serde_json::from_value::<GrepToolInput>(input) {
114 Ok(input) => input,
115 Err(error) => {
116 return Task::ready(Err(anyhow!("Failed to parse input: {error}"))).into();
117 }
118 };
119
120 let include_matcher = match PathMatcher::new(
121 input
122 .include_pattern
123 .as_ref()
124 .into_iter()
125 .collect::<Vec<_>>(),
126 ) {
127 Ok(matcher) => matcher,
128 Err(error) => {
129 return Task::ready(Err(anyhow!("invalid include glob pattern: {error}"))).into();
130 }
131 };
132
133 let query = match SearchQuery::regex(
134 &input.regex,
135 false,
136 input.case_sensitive,
137 false,
138 false,
139 include_matcher,
140 PathMatcher::default(), // For now, keep it simple and don't enable an exclude pattern.
141 true, // Always match file include pattern against *full project paths* that start with a project root.
142 None,
143 ) {
144 Ok(query) => query,
145 Err(error) => return Task::ready(Err(error)).into(),
146 };
147
148 let results = project.update(cx, |project, cx| project.search(query, cx));
149
150 cx.spawn(async move |cx| {
151 futures::pin_mut!(results);
152
153 let mut output = String::new();
154 let mut skips_remaining = input.offset;
155 let mut matches_found = 0;
156 let mut has_more_matches = false;
157
158 'outer: while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
159 if ranges.is_empty() {
160 continue;
161 }
162
163 let (Some(path), mut parse_status) = buffer.read_with(cx, |buffer, cx| {
164 (buffer.file().map(|file| file.full_path(cx)), buffer.parse_status())
165 })? else {
166 continue;
167 };
168
169
170 while *parse_status.borrow() != ParseStatus::Idle {
171 parse_status.changed().await?;
172 }
173
174 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
175
176 let mut ranges = ranges
177 .into_iter()
178 .map(|range| {
179 let matched = range.to_point(&snapshot);
180 let matched_end_line_len = snapshot.line_len(matched.end.row);
181 let full_lines = Point::new(matched.start.row, 0)..Point::new(matched.end.row, matched_end_line_len);
182 let symbols = snapshot.symbols_containing(matched.start, None);
183
184 if let Some(ancestor_node) = snapshot.syntax_ancestor(full_lines.clone()) {
185 let full_ancestor_range = ancestor_node.byte_range().to_point(&snapshot);
186 let end_row = full_ancestor_range.end.row.min(full_ancestor_range.start.row + MAX_ANCESTOR_LINES);
187 let end_col = snapshot.line_len(end_row);
188 let capped_ancestor_range = Point::new(full_ancestor_range.start.row, 0)..Point::new(end_row, end_col);
189
190 if capped_ancestor_range.contains_inclusive(&full_lines) {
191 return (capped_ancestor_range, Some(full_ancestor_range), symbols)
192 }
193 }
194
195 let mut matched = matched;
196 matched.start.column = 0;
197 matched.start.row =
198 matched.start.row.saturating_sub(CONTEXT_LINES);
199 matched.end.row = cmp::min(
200 snapshot.max_point().row,
201 matched.end.row + CONTEXT_LINES,
202 );
203 matched.end.column = snapshot.line_len(matched.end.row);
204
205 (matched, None, symbols)
206 })
207 .peekable();
208
209 let mut file_header_written = false;
210
211 while let Some((mut range, ancestor_range, parent_symbols)) = ranges.next(){
212 if skips_remaining > 0 {
213 skips_remaining -= 1;
214 continue;
215 }
216
217 // We'd already found a full page of matches, and we just found one more.
218 if matches_found >= RESULTS_PER_PAGE {
219 has_more_matches = true;
220 break 'outer;
221 }
222
223 while let Some((next_range, _, _)) = ranges.peek() {
224 if range.end.row >= next_range.start.row {
225 range.end = next_range.end;
226 ranges.next();
227 } else {
228 break;
229 }
230 }
231
232 if !file_header_written {
233 writeln!(output, "\n## Matches in {}", path.display())?;
234 file_header_written = true;
235 }
236
237 let end_row = range.end.row;
238 output.push_str("\n### ");
239
240 if let Some(parent_symbols) = &parent_symbols {
241 for symbol in parent_symbols {
242 write!(output, "{} › ", symbol.text)?;
243 }
244 }
245
246 if range.start.row == end_row {
247 writeln!(output, "L{}", range.start.row + 1)?;
248 } else {
249 writeln!(output, "L{}-{}", range.start.row + 1, end_row + 1)?;
250 }
251
252 output.push_str("```\n");
253 output.extend(snapshot.text_for_range(range));
254 output.push_str("\n```\n");
255
256 if let Some(ancestor_range) = ancestor_range {
257 if end_row < ancestor_range.end.row {
258 let remaining_lines = ancestor_range.end.row - end_row;
259 writeln!(output, "\n{} lines remaining in ancestor node. Read the file to see all.", remaining_lines)?;
260 }
261 }
262
263 matches_found += 1;
264 }
265 }
266
267 if matches_found == 0 {
268 Ok("No matches found".to_string().into())
269 } else if has_more_matches {
270 Ok(format!(
271 "Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}",
272 input.offset + 1,
273 input.offset + matches_found,
274 input.offset + RESULTS_PER_PAGE,
275 ).into())
276 } else {
277 Ok(format!("Found {matches_found} matches:\n{output}").into())
278 }
279 }).into()
280 }
281}
282
#[cfg(test)]
mod tests {
    use super::*;
    use assistant_tool::Tool;
    use gpui::{AppContext, TestAppContext};
    use language::{Language, LanguageConfig, LanguageMatcher};
    use language_model::fake_provider::FakeLanguageModel;
    use project::{FakeFs, Project};
    use settings::SettingsStore;
    use unindent::Unindent;
    use util::path;

    #[gpui::test]
    async fn test_grep_tool_with_include_pattern(cx: &mut TestAppContext) {
        init_test(cx);
        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor().clone());
        // Insert the tree at the same `path!` root that `Project::test` opens below,
        // so the fake files and the worktree agree on every platform.
        fs.insert_tree(
            path!("/root"),
            serde_json::json!({
                "src": {
                    "main.rs": "fn main() {\n    println!(\"Hello, world!\");\n}",
                    "utils": {
                        "helper.rs": "fn helper() {\n    println!(\"I'm a helper!\");\n}",
                    },
                },
                "tests": {
                    "test_main.rs": "fn test_main() {\n    assert!(true);\n}",
                }
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        // Test with include pattern for Rust files inside the root of the project
        let input = serde_json::to_value(GrepToolInput {
            regex: "println".to_string(),
            include_pattern: Some("root/**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(result.contains("main.rs"), "Should find matches in main.rs");
        assert!(
            result.contains("helper.rs"),
            "Should find matches in helper.rs"
        );
        assert!(
            !result.contains("test_main.rs"),
            "Should not include test_main.rs even though it's a .rs file (because it doesn't contain the regex)"
        );

        // Test with include pattern for src directory only
        let input = serde_json::to_value(GrepToolInput {
            regex: "fn".to_string(),
            include_pattern: Some("root/**/src/**".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            result.contains("main.rs"),
            "Should find matches in src/main.rs"
        );
        assert!(
            result.contains("helper.rs"),
            "Should find matches in src/utils/helper.rs"
        );
        assert!(
            !result.contains("test_main.rs"),
            "Should not include test_main.rs as it's not in src directory"
        );

        // Test with empty include pattern (should default to all files)
        let input = serde_json::to_value(GrepToolInput {
            regex: "fn".to_string(),
            include_pattern: None,
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(result.contains("main.rs"), "Should find matches in main.rs");
        assert!(
            result.contains("helper.rs"),
            "Should find matches in helper.rs"
        );
        assert!(
            result.contains("test_main.rs"),
            "Should include test_main.rs"
        );
    }

    #[gpui::test]
    async fn test_grep_tool_with_case_sensitivity(cx: &mut TestAppContext) {
        init_test(cx);
        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor().clone());
        fs.insert_tree(
            path!("/root"),
            serde_json::json!({
                "case_test.txt": "This file has UPPERCASE and lowercase text.\nUPPERCASE patterns should match only with case_sensitive: true",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        // Test case-insensitive search (default)
        let input = serde_json::to_value(GrepToolInput {
            regex: "uppercase".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            result.contains("UPPERCASE"),
            "Case-insensitive search should match uppercase"
        );

        // Test case-sensitive search with a lowercase pattern that only occurs in uppercase
        let input = serde_json::to_value(GrepToolInput {
            regex: "uppercase".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: true,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            !result.contains("UPPERCASE"),
            "Case-sensitive search should not match uppercase"
        );

        // Test case-sensitive search with an uppercase pattern that only occurs in lowercase
        let input = serde_json::to_value(GrepToolInput {
            regex: "LOWERCASE".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: true,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;

        assert!(
            !result.contains("lowercase"),
            "Case-sensitive search should not match lowercase"
        );

        // Test case-sensitive search for lowercase pattern
        let input = serde_json::to_value(GrepToolInput {
            regex: "lowercase".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: true,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            result.contains("lowercase"),
            "Case-sensitive search should match lowercase text"
        );
    }

    /// Helper function to set up a syntax test environment: a single Rust file
    /// under `root/` plus a Rust language with an outline query, so results
    /// carry syntax ancestors and symbol breadcrumbs.
    async fn setup_syntax_test(cx: &mut TestAppContext) -> Entity<Project> {
        init_test(cx);
        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor().clone());

        // Create test file with syntax structures
        fs.insert_tree(
            path!("/root"),
            serde_json::json!({
                "test_syntax.rs": r#"
                    fn top_level_function() {
                        println!("This is at the top level");
                    }

                    mod feature_module {
                        pub mod nested_module {
                            pub fn nested_function(
                                first_arg: String,
                                second_arg: i32,
                            ) {
                                println!("Function in nested module");
                                println!("{first_arg}");
                                println!("{second_arg}");
                            }
                        }
                    }

                    struct MyStruct {
                        field1: String,
                        field2: i32,
                    }

                    impl MyStruct {
                        fn method_with_block() {
                            let condition = true;
                            if condition {
                                println!("Inside if block");
                            }
                        }

                        fn long_function() {
                            println!("Line 1");
                            println!("Line 2");
                            println!("Line 3");
                            println!("Line 4");
                            println!("Line 5");
                            println!("Line 6");
                            println!("Line 7");
                            println!("Line 8");
                            println!("Line 9");
                            println!("Line 10");
                            println!("Line 11");
                            println!("Line 12");
                        }
                    }

                    trait Processor {
                        fn process(&self, input: &str) -> String;
                    }

                    impl Processor for MyStruct {
                        fn process(&self, input: &str) -> String {
                            format!("Processed: {}", input)
                        }
                    }
                "#.unindent().trim(),
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        project.update(cx, |project, _cx| {
            project.languages().add(rust_lang().into())
        });

        project
    }

    #[gpui::test]
    async fn test_grep_top_level_function(cx: &mut TestAppContext) {
        let project = setup_syntax_test(cx).await;

        // Test: Line at the top level of the file
        let input = serde_json::to_value(GrepToolInput {
            regex: "This is at the top level".to_string(),
            include_pattern: Some("**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        let expected = r#"
            Found 1 matches:

            ## Matches in root/test_syntax.rs

            ### fn top_level_function › L1-3
            ```
            fn top_level_function() {
                println!("This is at the top level");
            }
            ```
            "#
        .unindent();
        assert_eq!(result, expected);
    }

    #[gpui::test]
    async fn test_grep_function_body(cx: &mut TestAppContext) {
        let project = setup_syntax_test(cx).await;

        // Test: Line inside a function body
        let input = serde_json::to_value(GrepToolInput {
            regex: "Function in nested module".to_string(),
            include_pattern: Some("**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        let expected = r#"
            Found 1 matches:

            ## Matches in root/test_syntax.rs

            ### mod feature_module › pub mod nested_module › pub fn nested_function › L10-14
            ```
                    ) {
                        println!("Function in nested module");
                        println!("{first_arg}");
                        println!("{second_arg}");
                    }
            ```
            "#
        .unindent();
        assert_eq!(result, expected);
    }

    #[gpui::test]
    async fn test_grep_function_args_and_body(cx: &mut TestAppContext) {
        let project = setup_syntax_test(cx).await;

        // Test: Line with a function argument (the two matches get merged)
        let input = serde_json::to_value(GrepToolInput {
            regex: "second_arg".to_string(),
            include_pattern: Some("**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        let expected = r#"
            Found 1 matches:

            ## Matches in root/test_syntax.rs

            ### mod feature_module › pub mod nested_module › pub fn nested_function › L7-14
            ```
                    pub fn nested_function(
                        first_arg: String,
                        second_arg: i32,
                    ) {
                        println!("Function in nested module");
                        println!("{first_arg}");
                        println!("{second_arg}");
                    }
            ```
            "#
        .unindent();
        assert_eq!(result, expected);
    }

    #[gpui::test]
    async fn test_grep_if_block(cx: &mut TestAppContext) {
        let project = setup_syntax_test(cx).await;

        // Test: Line inside an if block
        let input = serde_json::to_value(GrepToolInput {
            regex: "Inside if block".to_string(),
            include_pattern: Some("**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        let expected = r#"
            Found 1 matches:

            ## Matches in root/test_syntax.rs

            ### impl MyStruct › fn method_with_block › L26-28
            ```
                    if condition {
                        println!("Inside if block");
                    }
            ```
            "#
        .unindent();
        assert_eq!(result, expected);
    }

    #[gpui::test]
    async fn test_grep_long_function_top(cx: &mut TestAppContext) {
        let project = setup_syntax_test(cx).await;

        // Test: Line in the middle of a long function - should show message about remaining lines
        let input = serde_json::to_value(GrepToolInput {
            regex: "Line 5".to_string(),
            include_pattern: Some("**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        let expected = r#"
            Found 1 matches:

            ## Matches in root/test_syntax.rs

            ### impl MyStruct › fn long_function › L31-41
            ```
                fn long_function() {
                    println!("Line 1");
                    println!("Line 2");
                    println!("Line 3");
                    println!("Line 4");
                    println!("Line 5");
                    println!("Line 6");
                    println!("Line 7");
                    println!("Line 8");
                    println!("Line 9");
                    println!("Line 10");
            ```

            3 lines remaining in ancestor node. Read the file to see all.
            "#
        .unindent();
        assert_eq!(result, expected);
    }

    #[gpui::test]
    async fn test_grep_long_function_bottom(cx: &mut TestAppContext) {
        let project = setup_syntax_test(cx).await;

        // Test: Line past the ancestor cap - falls back to context lines
        let input = serde_json::to_value(GrepToolInput {
            regex: "Line 12".to_string(),
            include_pattern: Some("**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        let expected = r#"
            Found 1 matches:

            ## Matches in root/test_syntax.rs

            ### impl MyStruct › fn long_function › L41-45
            ```
                    println!("Line 10");
                    println!("Line 11");
                    println!("Line 12");
                }
            }
            ```
            "#
        .unindent();
        assert_eq!(result, expected);
    }

    /// Runs the grep tool against `project` and returns its text output,
    /// normalizing Windows path separators after the `root` directory.
    async fn run_grep_tool(
        input: serde_json::Value,
        project: Entity<Project>,
        cx: &mut TestAppContext,
    ) -> String {
        let tool = Arc::new(GrepTool);
        let action_log = cx.new(|_cx| ActionLog::new(project.clone()));
        let model = Arc::new(FakeLanguageModel::default());
        let task =
            cx.update(|cx| tool.run(input, Arc::default(), project, action_log, model, None, cx));

        match task.output.await {
            Ok(result) => {
                if cfg!(windows) {
                    result.content.as_str().unwrap().replace("root\\", "root/")
                } else {
                    result.content.as_str().unwrap().to_string()
                }
            }
            Err(e) => panic!("Failed to run grep tool: {}", e),
        }
    }

    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });
    }

    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
        .unwrap()
    }
}