1use crate::schema::json_schema_for;
2use anyhow::{Result, anyhow};
3use assistant_tool::{ActionLog, Tool, ToolResult};
4use futures::StreamExt;
5use gpui::{AnyWindowHandle, App, Entity, Task};
6use language::{OffsetRangeExt, ParseStatus, Point};
7use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
8use project::{
9 Project,
10 search::{SearchQuery, SearchResult},
11};
12use schemars::JsonSchema;
13use serde::{Deserialize, Serialize};
14use std::{cmp, fmt::Write, sync::Arc};
15use ui::IconName;
16use util::RangeExt;
17use util::markdown::MarkdownInlineCode;
18use util::paths::PathMatcher;
19
20#[derive(Debug, Serialize, Deserialize, JsonSchema)]
21pub struct GrepToolInput {
22 /// A regex pattern to search for in the entire project. Note that the regex
23 /// will be parsed by the Rust `regex` crate.
24 ///
25 /// Do NOT specify a path here! This will only be matched against the code **content**.
26 pub regex: String,
27
28 /// A glob pattern for the paths of files to include in the search.
29 /// Supports standard glob patterns like "**/*.rs" or "src/**/*.ts".
30 /// If omitted, all files in the project will be searched.
31 pub include_pattern: Option<String>,
32
33 /// Optional starting position for paginated results (0-based).
34 /// When not provided, starts from the beginning.
35 #[serde(default)]
36 pub offset: u32,
37
38 /// Whether the regex is case-sensitive. Defaults to false (case-insensitive).
39 #[serde(default)]
40 pub case_sensitive: bool,
41}
42
43impl GrepToolInput {
44 /// Which page of search results this is.
45 pub fn page(&self) -> u32 {
46 1 + (self.offset / RESULTS_PER_PAGE)
47 }
48}
49
/// Maximum number of results rendered in one page of grep output. It is also the
/// stride used for pagination: `GrepToolInput::page` divides `offset` by it, and the
/// next-page offset suggested to the model is `offset + RESULTS_PER_PAGE`.
const RESULTS_PER_PAGE: u32 = 20;
51
/// The `grep` assistant tool: searches project file contents with a regex and
/// returns the matches, with surrounding context, as Markdown text.
pub struct GrepTool;
53
54impl Tool for GrepTool {
55 fn name(&self) -> String {
56 "grep".into()
57 }
58
59 fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
60 false
61 }
62
63 fn description(&self) -> String {
64 include_str!("./grep_tool/description.md").into()
65 }
66
67 fn icon(&self) -> IconName {
68 IconName::Regex
69 }
70
71 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
72 json_schema_for::<GrepToolInput>(format)
73 }
74
75 fn ui_text(&self, input: &serde_json::Value) -> String {
76 match serde_json::from_value::<GrepToolInput>(input.clone()) {
77 Ok(input) => {
78 let page = input.page();
79 let regex_str = MarkdownInlineCode(&input.regex);
80 let case_info = if input.case_sensitive {
81 " (case-sensitive)"
82 } else {
83 ""
84 };
85
86 if page > 1 {
87 format!("Get page {page} of search results for regex {regex_str}{case_info}")
88 } else {
89 format!("Search files for regex {regex_str}{case_info}")
90 }
91 }
92 Err(_) => "Search with regex".to_string(),
93 }
94 }
95
96 fn run(
97 self: Arc<Self>,
98 input: serde_json::Value,
99 _request: Arc<LanguageModelRequest>,
100 project: Entity<Project>,
101 _action_log: Entity<ActionLog>,
102 _model: Arc<dyn LanguageModel>,
103 _window: Option<AnyWindowHandle>,
104 cx: &mut App,
105 ) -> ToolResult {
106 const CONTEXT_LINES: u32 = 2;
107 const MAX_ANCESTOR_LINES: u32 = 10;
108
109 let input = match serde_json::from_value::<GrepToolInput>(input) {
110 Ok(input) => input,
111 Err(error) => {
112 return Task::ready(Err(anyhow!("Failed to parse input: {error}"))).into();
113 }
114 };
115
116 let include_matcher = match PathMatcher::new(
117 input
118 .include_pattern
119 .as_ref()
120 .into_iter()
121 .collect::<Vec<_>>(),
122 ) {
123 Ok(matcher) => matcher,
124 Err(error) => {
125 return Task::ready(Err(anyhow!("invalid include glob pattern: {error}"))).into();
126 }
127 };
128
129 let query = match SearchQuery::regex(
130 &input.regex,
131 false,
132 input.case_sensitive,
133 false,
134 false,
135 include_matcher,
136 PathMatcher::default(), // For now, keep it simple and don't enable an exclude pattern.
137 true, // Always match file include pattern against *full project paths* that start with a project root.
138 None,
139 ) {
140 Ok(query) => query,
141 Err(error) => return Task::ready(Err(error)).into(),
142 };
143
144 let results = project.update(cx, |project, cx| project.search(query, cx));
145
146 cx.spawn(async move |cx| {
147 futures::pin_mut!(results);
148
149 let mut output = String::new();
150 let mut skips_remaining = input.offset;
151 let mut matches_found = 0;
152 let mut has_more_matches = false;
153
154 'outer: while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
155 if ranges.is_empty() {
156 continue;
157 }
158
159 let (Some(path), mut parse_status) = buffer.read_with(cx, |buffer, cx| {
160 (buffer.file().map(|file| file.full_path(cx)), buffer.parse_status())
161 })? else {
162 continue;
163 };
164
165
166 while *parse_status.borrow() != ParseStatus::Idle {
167 parse_status.changed().await?;
168 }
169
170 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
171
172 let mut ranges = ranges
173 .into_iter()
174 .map(|range| {
175 let matched = range.to_point(&snapshot);
176 let matched_end_line_len = snapshot.line_len(matched.end.row);
177 let full_lines = Point::new(matched.start.row, 0)..Point::new(matched.end.row, matched_end_line_len);
178 let symbols = snapshot.symbols_containing(matched.start, None);
179
180 if let Some(ancestor_node) = snapshot.syntax_ancestor(full_lines.clone()) {
181 let full_ancestor_range = ancestor_node.byte_range().to_point(&snapshot);
182 let end_row = full_ancestor_range.end.row.min(full_ancestor_range.start.row + MAX_ANCESTOR_LINES);
183 let end_col = snapshot.line_len(end_row);
184 let capped_ancestor_range = Point::new(full_ancestor_range.start.row, 0)..Point::new(end_row, end_col);
185
186 if capped_ancestor_range.contains_inclusive(&full_lines) {
187 return (capped_ancestor_range, Some(full_ancestor_range), symbols)
188 }
189 }
190
191 let mut matched = matched;
192 matched.start.column = 0;
193 matched.start.row =
194 matched.start.row.saturating_sub(CONTEXT_LINES);
195 matched.end.row = cmp::min(
196 snapshot.max_point().row,
197 matched.end.row + CONTEXT_LINES,
198 );
199 matched.end.column = snapshot.line_len(matched.end.row);
200
201 (matched, None, symbols)
202 })
203 .peekable();
204
205 let mut file_header_written = false;
206
207 while let Some((mut range, ancestor_range, parent_symbols)) = ranges.next(){
208 if skips_remaining > 0 {
209 skips_remaining -= 1;
210 continue;
211 }
212
213 // We'd already found a full page of matches, and we just found one more.
214 if matches_found >= RESULTS_PER_PAGE {
215 has_more_matches = true;
216 break 'outer;
217 }
218
219 while let Some((next_range, _, _)) = ranges.peek() {
220 if range.end.row >= next_range.start.row {
221 range.end = next_range.end;
222 ranges.next();
223 } else {
224 break;
225 }
226 }
227
228 if !file_header_written {
229 writeln!(output, "\n## Matches in {}", path.display())?;
230 file_header_written = true;
231 }
232
233 let end_row = range.end.row;
234 output.push_str("\n### ");
235
236 if let Some(parent_symbols) = &parent_symbols {
237 for symbol in parent_symbols {
238 write!(output, "{} › ", symbol.text)?;
239 }
240 }
241
242 if range.start.row == end_row {
243 writeln!(output, "L{}", range.start.row + 1)?;
244 } else {
245 writeln!(output, "L{}-{}", range.start.row + 1, end_row + 1)?;
246 }
247
248 output.push_str("```\n");
249 output.extend(snapshot.text_for_range(range));
250 output.push_str("\n```\n");
251
252 if let Some(ancestor_range) = ancestor_range {
253 if end_row < ancestor_range.end.row {
254 let remaining_lines = ancestor_range.end.row - end_row;
255 writeln!(output, "\n{} lines remaining in ancestor node. Read the file to see all.", remaining_lines)?;
256 }
257 }
258
259 matches_found += 1;
260 }
261 }
262
263 if matches_found == 0 {
264 Ok("No matches found".to_string().into())
265 } else if has_more_matches {
266 Ok(format!(
267 "Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}",
268 input.offset + 1,
269 input.offset + matches_found,
270 input.offset + RESULTS_PER_PAGE,
271 ).into())
272 } else {
273 Ok(format!("Found {matches_found} matches:\n{output}").into())
274 }
275 }).into()
276 }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use assistant_tool::Tool;
    use gpui::{AppContext, TestAppContext};
    use language::{Language, LanguageConfig, LanguageMatcher};
    use language_model::fake_provider::FakeLanguageModel;
    use project::{FakeFs, Project};
    use settings::SettingsStore;
    use unindent::Unindent;
    use util::path;

    #[gpui::test]
    async fn test_grep_tool_with_include_pattern(cx: &mut TestAppContext) {
        init_test(cx);
        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor().clone());
        fs.insert_tree(
            "/root",
            serde_json::json!({
                "src": {
                    "main.rs": "fn main() {\n    println!(\"Hello, world!\");\n}",
                    "utils": {
                        "helper.rs": "fn helper() {\n    println!(\"I'm a helper!\");\n}",
                    },
                },
                "tests": {
                    "test_main.rs": "fn test_main() {\n    assert!(true);\n}",
                }
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        // Include pattern for Rust files anywhere under the project root. Patterns are
        // matched against full paths, which start with the root directory name.
        let input = serde_json::to_value(GrepToolInput {
            regex: "println".to_string(),
            include_pattern: Some("root/**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(result.contains("main.rs"), "Should find matches in main.rs");
        assert!(
            result.contains("helper.rs"),
            "Should find matches in helper.rs"
        );
        assert!(
            !result.contains("test_main.rs"),
            "Should not include test_main.rs: it matches the include pattern but does not contain the regex"
        );

        // Include pattern for the src directory only
        let input = serde_json::to_value(GrepToolInput {
            regex: "fn".to_string(),
            include_pattern: Some("root/**/src/**".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            result.contains("main.rs"),
            "Should find matches in src/main.rs"
        );
        assert!(
            result.contains("helper.rs"),
            "Should find matches in src/utils/helper.rs"
        );
        assert!(
            !result.contains("test_main.rs"),
            "Should not include test_main.rs as it's not in src directory"
        );

        // No include pattern at all (should default to all files)
        let input = serde_json::to_value(GrepToolInput {
            regex: "fn".to_string(),
            include_pattern: None,
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(result.contains("main.rs"), "Should find matches in main.rs");
        assert!(
            result.contains("helper.rs"),
            "Should find matches in helper.rs"
        );
        assert!(
            result.contains("test_main.rs"),
            "Should include test_main.rs"
        );
    }

    #[gpui::test]
    async fn test_grep_tool_with_case_sensitivity(cx: &mut TestAppContext) {
        init_test(cx);
        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor().clone());
        fs.insert_tree(
            "/root",
            serde_json::json!({
                "case_test.txt": "This file has UPPERCASE and lowercase text.\nUPPERCASE patterns should match only with case_sensitive: true",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        // Case-insensitive search (default): a lowercase pattern matches uppercase text
        let input = serde_json::to_value(GrepToolInput {
            regex: "uppercase".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            result.contains("UPPERCASE"),
            "Case-insensitive search should match uppercase"
        );

        // Case-sensitive search: a lowercase pattern must not match uppercase text
        let input = serde_json::to_value(GrepToolInput {
            regex: "uppercase".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: true,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            !result.contains("UPPERCASE"),
            "Case-sensitive search should not match uppercase"
        );

        // Case-sensitive search: an uppercase pattern must not match lowercase text
        let input = serde_json::to_value(GrepToolInput {
            regex: "LOWERCASE".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: true,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;

        assert!(
            !result.contains("lowercase"),
            "Case-sensitive search should not match lowercase"
        );

        // Case-sensitive search for a lowercase pattern that occurs verbatim
        let input = serde_json::to_value(GrepToolInput {
            regex: "lowercase".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: true,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            result.contains("lowercase"),
            "Case-sensitive search should match lowercase text"
        );
    }

    /// Sets up a project containing a single Rust file with nested syntax
    /// structures, and registers the Rust language so it gets parsed.
    async fn setup_syntax_test(cx: &mut TestAppContext) -> Entity<Project> {
        init_test(cx);
        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor().clone());

        // Create test file with syntax structures
        fs.insert_tree(
            "/root",
            serde_json::json!({
                "test_syntax.rs": r#"
                    fn top_level_function() {
                        println!("This is at the top level");
                    }

                    mod feature_module {
                        pub mod nested_module {
                            pub fn nested_function(
                                first_arg: String,
                                second_arg: i32,
                            ) {
                                println!("Function in nested module");
                                println!("{first_arg}");
                                println!("{second_arg}");
                            }
                        }
                    }

                    struct MyStruct {
                        field1: String,
                        field2: i32,
                    }

                    impl MyStruct {
                        fn method_with_block() {
                            let condition = true;
                            if condition {
                                println!("Inside if block");
                            }
                        }

                        fn long_function() {
                            println!("Line 1");
                            println!("Line 2");
                            println!("Line 3");
                            println!("Line 4");
                            println!("Line 5");
                            println!("Line 6");
                            println!("Line 7");
                            println!("Line 8");
                            println!("Line 9");
                            println!("Line 10");
                            println!("Line 11");
                            println!("Line 12");
                        }
                    }

                    trait Processor {
                        fn process(&self, input: &str) -> String;
                    }

                    impl Processor for MyStruct {
                        fn process(&self, input: &str) -> String {
                            format!("Processed: {}", input)
                        }
                    }
                    "#.unindent().trim(),
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        project.update(cx, |project, _cx| {
            project.languages().add(rust_lang().into())
        });

        project
    }

    #[gpui::test]
    async fn test_grep_top_level_function(cx: &mut TestAppContext) {
        let project = setup_syntax_test(cx).await;

        // Test: Line at the top level of the file
        let input = serde_json::to_value(GrepToolInput {
            regex: "This is at the top level".to_string(),
            include_pattern: Some("**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        let expected = r#"
            Found 1 matches:

            ## Matches in root/test_syntax.rs

            ### fn top_level_function › L1-3
            ```
            fn top_level_function() {
                println!("This is at the top level");
            }
            ```
            "#
        .unindent();
        assert_eq!(result, expected);
    }

    #[gpui::test]
    async fn test_grep_function_body(cx: &mut TestAppContext) {
        let project = setup_syntax_test(cx).await;

        // Test: Line inside a function body
        let input = serde_json::to_value(GrepToolInput {
            regex: "Function in nested module".to_string(),
            include_pattern: Some("**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        let expected = r#"
            Found 1 matches:

            ## Matches in root/test_syntax.rs

            ### mod feature_module › pub mod nested_module › pub fn nested_function › L10-14
            ```
                    ) {
                        println!("Function in nested module");
                        println!("{first_arg}");
                        println!("{second_arg}");
                    }
            ```
            "#
        .unindent();
        assert_eq!(result, expected);
    }

    #[gpui::test]
    async fn test_grep_function_args_and_body(cx: &mut TestAppContext) {
        let project = setup_syntax_test(cx).await;

        // Test: Line with a function argument
        let input = serde_json::to_value(GrepToolInput {
            regex: "second_arg".to_string(),
            include_pattern: Some("**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        let expected = r#"
            Found 1 matches:

            ## Matches in root/test_syntax.rs

            ### mod feature_module › pub mod nested_module › pub fn nested_function › L7-14
            ```
                    pub fn nested_function(
                        first_arg: String,
                        second_arg: i32,
                    ) {
                        println!("Function in nested module");
                        println!("{first_arg}");
                        println!("{second_arg}");
                    }
            ```
            "#
        .unindent();
        assert_eq!(result, expected);
    }

    #[gpui::test]
    async fn test_grep_if_block(cx: &mut TestAppContext) {
        let project = setup_syntax_test(cx).await;

        // Test: Line inside an if block
        let input = serde_json::to_value(GrepToolInput {
            regex: "Inside if block".to_string(),
            include_pattern: Some("**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        let expected = r#"
            Found 1 matches:

            ## Matches in root/test_syntax.rs

            ### impl MyStruct › fn method_with_block › L26-28
            ```
                    if condition {
                        println!("Inside if block");
                    }
            ```
            "#
        .unindent();
        assert_eq!(result, expected);
    }

    #[gpui::test]
    async fn test_grep_long_function_top(cx: &mut TestAppContext) {
        let project = setup_syntax_test(cx).await;

        // Test: Line in the middle of a long function - should show message about remaining lines
        let input = serde_json::to_value(GrepToolInput {
            regex: "Line 5".to_string(),
            include_pattern: Some("**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        let expected = r#"
            Found 1 matches:

            ## Matches in root/test_syntax.rs

            ### impl MyStruct › fn long_function › L31-41
            ```
                fn long_function() {
                    println!("Line 1");
                    println!("Line 2");
                    println!("Line 3");
                    println!("Line 4");
                    println!("Line 5");
                    println!("Line 6");
                    println!("Line 7");
                    println!("Line 8");
                    println!("Line 9");
                    println!("Line 10");
            ```

            3 lines remaining in ancestor node. Read the file to see all.
            "#
        .unindent();
        assert_eq!(result, expected);
    }

    #[gpui::test]
    async fn test_grep_long_function_bottom(cx: &mut TestAppContext) {
        let project = setup_syntax_test(cx).await;

        // Test: Line near the end of the long function, past the truncated ancestor
        let input = serde_json::to_value(GrepToolInput {
            regex: "Line 12".to_string(),
            include_pattern: Some("**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        let expected = r#"
            Found 1 matches:

            ## Matches in root/test_syntax.rs

            ### impl MyStruct › fn long_function › L41-45
            ```
                    println!("Line 10");
                    println!("Line 11");
                    println!("Line 12");
                }
            }
            ```
            "#
        .unindent();
        assert_eq!(result, expected);
    }

    /// Runs the grep tool with `input` against `project` and returns its text output,
    /// normalizing Windows path separators after the root directory name.
    async fn run_grep_tool(
        input: serde_json::Value,
        project: Entity<Project>,
        cx: &mut TestAppContext,
    ) -> String {
        let tool = Arc::new(GrepTool);
        let action_log = cx.new(|_cx| ActionLog::new(project.clone()));
        let model = Arc::new(FakeLanguageModel::default());
        let task =
            cx.update(|cx| tool.run(input, Arc::default(), project, action_log, model, None, cx));

        match task.output.await {
            Ok(result) => {
                if cfg!(windows) {
                    result.content.as_str().unwrap().replace("root\\", "root/")
                } else {
                    result.content.as_str().unwrap().to_string()
                }
            }
            Err(e) => panic!("Failed to run grep tool: {}", e),
        }
    }

    /// Installs the global settings store and language/project settings used by the tool.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });
    }

    /// A Rust language with the outline query, so `symbols_containing` yields symbols.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
        .unwrap()
    }
}