1use crate::schema::json_schema_for;
2use anyhow::{Result, anyhow};
3use assistant_tool::{ActionLog, Tool, ToolResult};
4use futures::StreamExt;
5use gpui::{AnyWindowHandle, App, Entity, Task};
6use language::{OffsetRangeExt, ParseStatus, Point};
7use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
8use project::{
9 Project,
10 search::{SearchQuery, SearchResult},
11};
12use schemars::JsonSchema;
13use serde::{Deserialize, Serialize};
14use std::{cmp, fmt::Write, sync::Arc};
15use ui::IconName;
16use util::RangeExt;
17use util::markdown::MarkdownInlineCode;
18use util::paths::PathMatcher;
19
// Arguments accepted by the `grep` tool, deserialized from the model's JSON tool call.
//
// NOTE: the `///` comments on the fields below are not only documentation: the `JsonSchema`
// derive (schemars) uses them as the field descriptions of the schema returned by
// `input_schema`, so editing them changes what the model is told about each argument.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct GrepToolInput {
    /// A regex pattern to search for in the entire project. Note that the regex
    /// will be parsed by the Rust `regex` crate.
    ///
    /// Do NOT specify a path here! This will only be matched against the code **content**.
    pub regex: String,

    /// A glob pattern for the paths of files to include in the search.
    /// Supports standard glob patterns like "**/*.rs" or "src/**/*.ts".
    /// If omitted, all files in the project will be searched.
    // `None` produces an empty include matcher in `run`.
    pub include_pattern: Option<String>,

    /// Optional starting position for paginated results (0-based).
    /// When not provided, starts from the beginning.
    // A missing value deserializes to 0 via `#[serde(default)]`.
    #[serde(default)]
    pub offset: u32,

    /// Whether the regex is case-sensitive. Defaults to false (case-insensitive).
    // A missing value deserializes to `false` via `#[serde(default)]`.
    #[serde(default)]
    pub case_sensitive: bool,
}
42
43impl GrepToolInput {
44 /// Which page of search results this is.
45 pub fn page(&self) -> u32 {
46 1 + (self.offset / RESULTS_PER_PAGE)
47 }
48}
49
/// Number of results rendered per page of grep output. Successive pages are this far apart
/// in `GrepToolInput::offset`, and `GrepToolInput::page` divides by it.
const RESULTS_PER_PAGE: u32 = 20;
51
/// The `grep` tool: regex-searches the contents of the project's files and reports each match
/// with surrounding context, one page of results at a time.
pub struct GrepTool;
53
54impl Tool for GrepTool {
55 fn name(&self) -> String {
56 "grep".into()
57 }
58
59 fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
60 false
61 }
62
63 fn description(&self) -> String {
64 include_str!("./grep_tool/description.md").into()
65 }
66
67 fn icon(&self) -> IconName {
68 IconName::Regex
69 }
70
71 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
72 json_schema_for::<GrepToolInput>(format)
73 }
74
75 fn ui_text(&self, input: &serde_json::Value) -> String {
76 match serde_json::from_value::<GrepToolInput>(input.clone()) {
77 Ok(input) => {
78 let page = input.page();
79 let regex_str = MarkdownInlineCode(&input.regex);
80 let case_info = if input.case_sensitive {
81 " (case-sensitive)"
82 } else {
83 ""
84 };
85
86 if page > 1 {
87 format!("Get page {page} of search results for regex {regex_str}{case_info}")
88 } else {
89 format!("Search files for regex {regex_str}{case_info}")
90 }
91 }
92 Err(_) => "Search with regex".to_string(),
93 }
94 }
95
96 fn run(
97 self: Arc<Self>,
98 input: serde_json::Value,
99 _messages: &[LanguageModelRequestMessage],
100 project: Entity<Project>,
101 _action_log: Entity<ActionLog>,
102 _window: Option<AnyWindowHandle>,
103 cx: &mut App,
104 ) -> ToolResult {
105 const CONTEXT_LINES: u32 = 2;
106 const MAX_ANCESTOR_LINES: u32 = 10;
107
108 let input = match serde_json::from_value::<GrepToolInput>(input) {
109 Ok(input) => input,
110 Err(error) => {
111 return Task::ready(Err(anyhow!("Failed to parse input: {}", error))).into();
112 }
113 };
114
115 let include_matcher = match PathMatcher::new(
116 input
117 .include_pattern
118 .as_ref()
119 .into_iter()
120 .collect::<Vec<_>>(),
121 ) {
122 Ok(matcher) => matcher,
123 Err(error) => {
124 return Task::ready(Err(anyhow!("invalid include glob pattern: {}", error))).into();
125 }
126 };
127
128 let query = match SearchQuery::regex(
129 &input.regex,
130 false,
131 input.case_sensitive,
132 false,
133 false,
134 include_matcher,
135 PathMatcher::default(), // For now, keep it simple and don't enable an exclude pattern.
136 true, // Always match file include pattern against *full project paths* that start with a project root.
137 None,
138 ) {
139 Ok(query) => query,
140 Err(error) => return Task::ready(Err(error)).into(),
141 };
142
143 let results = project.update(cx, |project, cx| project.search(query, cx));
144
145 cx.spawn(async move |cx| {
146 futures::pin_mut!(results);
147
148 let mut output = String::new();
149 let mut skips_remaining = input.offset;
150 let mut matches_found = 0;
151 let mut has_more_matches = false;
152
153 'outer: while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
154 if ranges.is_empty() {
155 continue;
156 }
157
158 let (Some(path), mut parse_status) = buffer.read_with(cx, |buffer, cx| {
159 (buffer.file().map(|file| file.full_path(cx)), buffer.parse_status())
160 })? else {
161 continue;
162 };
163
164
165 while *parse_status.borrow() != ParseStatus::Idle {
166 parse_status.changed().await?;
167 }
168
169 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
170
171 let mut ranges = ranges
172 .into_iter()
173 .map(|range| {
174 let matched = range.to_point(&snapshot);
175 let matched_end_line_len = snapshot.line_len(matched.end.row);
176 let full_lines = Point::new(matched.start.row, 0)..Point::new(matched.end.row, matched_end_line_len);
177 let symbols = snapshot.symbols_containing(matched.start, None);
178
179 if let Some(ancestor_node) = snapshot.syntax_ancestor(full_lines.clone()) {
180 let full_ancestor_range = ancestor_node.byte_range().to_point(&snapshot);
181 let end_row = full_ancestor_range.end.row.min(full_ancestor_range.start.row + MAX_ANCESTOR_LINES);
182 let end_col = snapshot.line_len(end_row);
183 let capped_ancestor_range = Point::new(full_ancestor_range.start.row, 0)..Point::new(end_row, end_col);
184
185 if capped_ancestor_range.contains_inclusive(&full_lines) {
186 return (capped_ancestor_range, Some(full_ancestor_range), symbols)
187 }
188 }
189
190 let mut matched = matched;
191 matched.start.column = 0;
192 matched.start.row =
193 matched.start.row.saturating_sub(CONTEXT_LINES);
194 matched.end.row = cmp::min(
195 snapshot.max_point().row,
196 matched.end.row + CONTEXT_LINES,
197 );
198 matched.end.column = snapshot.line_len(matched.end.row);
199
200 (matched, None, symbols)
201 })
202 .peekable();
203
204 let mut file_header_written = false;
205
206 while let Some((mut range, ancestor_range, parent_symbols)) = ranges.next(){
207 if skips_remaining > 0 {
208 skips_remaining -= 1;
209 continue;
210 }
211
212 // We'd already found a full page of matches, and we just found one more.
213 if matches_found >= RESULTS_PER_PAGE {
214 has_more_matches = true;
215 break 'outer;
216 }
217
218 while let Some((next_range, _, _)) = ranges.peek() {
219 if range.end.row >= next_range.start.row {
220 range.end = next_range.end;
221 ranges.next();
222 } else {
223 break;
224 }
225 }
226
227 if !file_header_written {
228 writeln!(output, "\n## Matches in {}", path.display())?;
229 file_header_written = true;
230 }
231
232 let end_row = range.end.row;
233 output.push_str("\n### ");
234
235 if let Some(parent_symbols) = &parent_symbols {
236 for symbol in parent_symbols {
237 write!(output, "{} › ", symbol.text)?;
238 }
239 }
240
241 if range.start.row == end_row {
242 writeln!(output, "L{}", range.start.row + 1)?;
243 } else {
244 writeln!(output, "L{}-{}", range.start.row + 1, end_row + 1)?;
245 }
246
247 output.push_str("```\n");
248 output.extend(snapshot.text_for_range(range));
249 output.push_str("\n```\n");
250
251 if let Some(ancestor_range) = ancestor_range {
252 if end_row < ancestor_range.end.row {
253 let remaining_lines = ancestor_range.end.row - end_row;
254 writeln!(output, "\n{} lines remaining in ancestor node. Read the file to see all.", remaining_lines)?;
255 }
256 }
257
258 matches_found += 1;
259 }
260 }
261
262 if matches_found == 0 {
263 Ok("No matches found".to_string().into())
264 } else if has_more_matches {
265 Ok(format!(
266 "Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}",
267 input.offset + 1,
268 input.offset + matches_found,
269 input.offset + RESULTS_PER_PAGE,
270 ).into())
271 } else {
272 Ok(format!("Found {matches_found} matches:\n{output}").into())
273 }
274 }).into()
275 }
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use assistant_tool::Tool;
    use gpui::{AppContext, TestAppContext};
    use language::{Language, LanguageConfig, LanguageMatcher};
    use project::{FakeFs, Project};
    use settings::SettingsStore;
    use unindent::Unindent;
    use util::path;

    #[gpui::test]
    async fn test_grep_tool_with_include_pattern(cx: &mut TestAppContext) {
        init_test(cx);
        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor().clone());
        // Use `path!` so the tree is created at the same root the project is opened at on
        // every platform.
        fs.insert_tree(
            path!("/root"),
            serde_json::json!({
                "src": {
                    "main.rs": "fn main() {\n    println!(\"Hello, world!\");\n}",
                    "utils": {
                        "helper.rs": "fn helper() {\n    println!(\"I'm a helper!\");\n}",
                    },
                },
                "tests": {
                    "test_main.rs": "fn test_main() {\n    assert!(true);\n}",
                }
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        // Test with include pattern for Rust files inside the root of the project
        let input = serde_json::to_value(GrepToolInput {
            regex: "println".to_string(),
            include_pattern: Some("root/**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(result.contains("main.rs"), "Should find matches in main.rs");
        assert!(
            result.contains("helper.rs"),
            "Should find matches in helper.rs"
        );
        assert!(
            !result.contains("test_main.rs"),
            "Should not include test_main.rs even though it's a .rs file (because it doesn't have the pattern)"
        );

        // Test with include pattern for src directory only
        let input = serde_json::to_value(GrepToolInput {
            regex: "fn".to_string(),
            include_pattern: Some("root/**/src/**".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            result.contains("main.rs"),
            "Should find matches in src/main.rs"
        );
        assert!(
            result.contains("helper.rs"),
            "Should find matches in src/utils/helper.rs"
        );
        assert!(
            !result.contains("test_main.rs"),
            "Should not include test_main.rs as it's not in src directory"
        );

        // Test with empty include pattern (should default to all files)
        let input = serde_json::to_value(GrepToolInput {
            regex: "fn".to_string(),
            include_pattern: None,
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(result.contains("main.rs"), "Should find matches in main.rs");
        assert!(
            result.contains("helper.rs"),
            "Should find matches in helper.rs"
        );
        assert!(
            result.contains("test_main.rs"),
            "Should include test_main.rs"
        );
    }

    #[gpui::test]
    async fn test_grep_tool_with_case_sensitivity(cx: &mut TestAppContext) {
        init_test(cx);
        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor().clone());
        fs.insert_tree(
            path!("/root"),
            serde_json::json!({
                "case_test.txt": "This file has UPPERCASE and lowercase text.\nUPPERCASE patterns should match only with case_sensitive: true",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        // Test case-insensitive search (default)
        let input = serde_json::to_value(GrepToolInput {
            regex: "uppercase".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            result.contains("UPPERCASE"),
            "Case-insensitive search should match uppercase"
        );

        // Test case-sensitive search: a lowercase pattern must not match uppercase text
        let input = serde_json::to_value(GrepToolInput {
            regex: "uppercase".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: true,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            !result.contains("UPPERCASE"),
            "Case-sensitive search should not match uppercase"
        );

        // Test case-sensitive search: an uppercase pattern must not match lowercase text
        let input = serde_json::to_value(GrepToolInput {
            regex: "LOWERCASE".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: true,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;

        assert!(
            !result.contains("lowercase"),
            "Case-sensitive search for LOWERCASE should not match lowercase text"
        );

        // Test case-sensitive search for lowercase pattern
        let input = serde_json::to_value(GrepToolInput {
            regex: "lowercase".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: true,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            result.contains("lowercase"),
            "Case-sensitive search should match lowercase text"
        );
    }

    /// Helper function to set up a syntax test environment
    async fn setup_syntax_test(cx: &mut TestAppContext) -> Entity<Project> {
        init_test(cx);
        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor().clone());

        // Create test file with syntax structures
        fs.insert_tree(
            path!("/root"),
            serde_json::json!({
                "test_syntax.rs": r#"
                fn top_level_function() {
                    println!("This is at the top level");
                }

                mod feature_module {
                    pub mod nested_module {
                        pub fn nested_function(
                            first_arg: String,
                            second_arg: i32,
                        ) {
                            println!("Function in nested module");
                            println!("{first_arg}");
                            println!("{second_arg}");
                        }
                    }
                }

                struct MyStruct {
                    field1: String,
                    field2: i32,
                }

                impl MyStruct {
                    fn method_with_block() {
                        let condition = true;
                        if condition {
                            println!("Inside if block");
                        }
                    }

                    fn long_function() {
                        println!("Line 1");
                        println!("Line 2");
                        println!("Line 3");
                        println!("Line 4");
                        println!("Line 5");
                        println!("Line 6");
                        println!("Line 7");
                        println!("Line 8");
                        println!("Line 9");
                        println!("Line 10");
                        println!("Line 11");
                        println!("Line 12");
                    }
                }

                trait Processor {
                    fn process(&self, input: &str) -> String;
                }

                impl Processor for MyStruct {
                    fn process(&self, input: &str) -> String {
                        format!("Processed: {}", input)
                    }
                }
                "#.unindent().trim(),
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        project.update(cx, |project, _cx| {
            project.languages().add(rust_lang().into())
        });

        project
    }

    #[gpui::test]
    async fn test_grep_top_level_function(cx: &mut TestAppContext) {
        let project = setup_syntax_test(cx).await;

        // Test: Line at the top level of the file
        let input = serde_json::to_value(GrepToolInput {
            regex: "This is at the top level".to_string(),
            include_pattern: Some("**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        let expected = r#"
            Found 1 matches:

            ## Matches in root/test_syntax.rs

            ### fn top_level_function › L1-3
            ```
            fn top_level_function() {
                println!("This is at the top level");
            }
            ```
            "#
        .unindent();
        assert_eq!(result, expected);
    }

    #[gpui::test]
    async fn test_grep_function_body(cx: &mut TestAppContext) {
        let project = setup_syntax_test(cx).await;

        // Test: Line inside a function body
        let input = serde_json::to_value(GrepToolInput {
            regex: "Function in nested module".to_string(),
            include_pattern: Some("**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        let expected = r#"
            Found 1 matches:

            ## Matches in root/test_syntax.rs

            ### mod feature_module › pub mod nested_module › pub fn nested_function › L10-14
            ```
                    ) {
                        println!("Function in nested module");
                        println!("{first_arg}");
                        println!("{second_arg}");
                    }
            ```
            "#
        .unindent();
        assert_eq!(result, expected);
    }

    #[gpui::test]
    async fn test_grep_function_args_and_body(cx: &mut TestAppContext) {
        let project = setup_syntax_test(cx).await;

        // Test: Line with a function argument
        let input = serde_json::to_value(GrepToolInput {
            regex: "second_arg".to_string(),
            include_pattern: Some("**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        let expected = r#"
            Found 1 matches:

            ## Matches in root/test_syntax.rs

            ### mod feature_module › pub mod nested_module › pub fn nested_function › L7-14
            ```
                    pub fn nested_function(
                        first_arg: String,
                        second_arg: i32,
                    ) {
                        println!("Function in nested module");
                        println!("{first_arg}");
                        println!("{second_arg}");
                    }
            ```
            "#
        .unindent();
        assert_eq!(result, expected);
    }

    #[gpui::test]
    async fn test_grep_if_block(cx: &mut TestAppContext) {
        let project = setup_syntax_test(cx).await;

        // Test: Line inside an if block
        let input = serde_json::to_value(GrepToolInput {
            regex: "Inside if block".to_string(),
            include_pattern: Some("**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        let expected = r#"
            Found 1 matches:

            ## Matches in root/test_syntax.rs

            ### impl MyStruct › fn method_with_block › L26-28
            ```
                    if condition {
                        println!("Inside if block");
                    }
            ```
            "#
        .unindent();
        assert_eq!(result, expected);
    }

    #[gpui::test]
    async fn test_grep_long_function_top(cx: &mut TestAppContext) {
        let project = setup_syntax_test(cx).await;

        // Test: Line in the middle of a long function - should show message about remaining lines
        let input = serde_json::to_value(GrepToolInput {
            regex: "Line 5".to_string(),
            include_pattern: Some("**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        let expected = r#"
            Found 1 matches:

            ## Matches in root/test_syntax.rs

            ### impl MyStruct › fn long_function › L31-41
            ```
                fn long_function() {
                    println!("Line 1");
                    println!("Line 2");
                    println!("Line 3");
                    println!("Line 4");
                    println!("Line 5");
                    println!("Line 6");
                    println!("Line 7");
                    println!("Line 8");
                    println!("Line 9");
                    println!("Line 10");
            ```

            3 lines remaining in ancestor node. Read the file to see all.
            "#
        .unindent();
        assert_eq!(result, expected);
    }

    #[gpui::test]
    async fn test_grep_long_function_bottom(cx: &mut TestAppContext) {
        let project = setup_syntax_test(cx).await;

        // Test: Line in the long function
        let input = serde_json::to_value(GrepToolInput {
            regex: "Line 12".to_string(),
            include_pattern: Some("**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        let expected = r#"
            Found 1 matches:

            ## Matches in root/test_syntax.rs

            ### impl MyStruct › fn long_function › L41-45
            ```
                    println!("Line 10");
                    println!("Line 11");
                    println!("Line 12");
                }
            }
            ```
            "#
        .unindent();
        assert_eq!(result, expected);
    }

    #[gpui::test]
    async fn test_grep_pagination_counts_merged_groups(cx: &mut TestAppContext) {
        init_test(cx);
        cx.executor().allow_parking();

        // 22 groups of two adjacent matching lines. The two matches in a group have
        // overlapping context, so each group is rendered (and counted) as one result.
        // Groups are separated by enough filler lines that their contexts never touch.
        let mut content = String::new();
        for group in 0..22 {
            content.push_str(&format!("hit {group:02} a\n"));
            content.push_str(&format!("hit {group:02} b\n"));
            for _ in 0..10 {
                content.push_str("filler\n");
            }
        }

        let fs = FakeFs::new(cx.executor().clone());
        fs.insert_tree(path!("/root"), serde_json::json!({ "pages.txt": content }))
            .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        let first_page = run_grep_tool(
            serde_json::to_value(GrepToolInput {
                regex: "hit".to_string(),
                include_pattern: Some("**/*.txt".to_string()),
                offset: 0,
                case_sensitive: false,
            })
            .unwrap(),
            project.clone(),
            cx,
        )
        .await;
        assert!(
            first_page.starts_with(
                "Showing matches 1-20 (there were more matches found; use offset: 20 to see next page):"
            ),
            "unexpected first page: {first_page}"
        );
        assert!(first_page.contains("hit 19 b"));
        assert!(!first_page.contains("hit 20 a"));

        // Following the suggested offset must resume right after the first page.
        let second_page = run_grep_tool(
            serde_json::to_value(GrepToolInput {
                regex: "hit".to_string(),
                include_pattern: Some("**/*.txt".to_string()),
                offset: 20,
                case_sensitive: false,
            })
            .unwrap(),
            project.clone(),
            cx,
        )
        .await;
        assert!(
            second_page.starts_with("Found 2 matches:"),
            "unexpected second page: {second_page}"
        );
        assert!(second_page.contains("hit 20 a"));
        assert!(second_page.contains("hit 21 b"));
        assert!(
            !second_page.contains("hit 19"),
            "second page must not repeat results from the first page: {second_page}"
        );
    }

    async fn run_grep_tool(
        input: serde_json::Value,
        project: Entity<Project>,
        cx: &mut TestAppContext,
    ) -> String {
        let tool = Arc::new(GrepTool);
        let action_log = cx.new(|_cx| ActionLog::new(project.clone()));
        let task = cx.update(|cx| tool.run(input, &[], project, action_log, None, cx));

        match task.output.await {
            Ok(result) => {
                if cfg!(windows) {
                    result.content.replace("root\\", "root/")
                } else {
                    result.content
                }
            }
            Err(e) => panic!("Failed to run grep tool: {}", e),
        }
    }

    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });
    }

    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
        .unwrap()
    }
}