grep_tool.rs

  1use crate::schema::json_schema_for;
  2use anyhow::{Result, anyhow};
  3use assistant_tool::{ActionLog, Tool, ToolResult};
  4use futures::StreamExt;
  5use gpui::{AnyWindowHandle, App, Entity, Task};
  6use language::OffsetRangeExt;
  7use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
  8use project::{
  9    Project,
 10    search::{SearchQuery, SearchResult},
 11};
 12use schemars::JsonSchema;
 13use serde::{Deserialize, Serialize};
 14use std::{cmp, fmt::Write, sync::Arc};
 15use ui::IconName;
 16use util::markdown::MarkdownInlineCode;
 17use util::paths::PathMatcher;
 18
// Arguments accepted by the grep tool. `GrepTool::run` and `GrepTool::ui_text`
// deserialize this struct from the JSON value they are handed.
//
// NOTE(review): `JsonSchema` is derived here, so the `///` doc comments on the
// fields are presumably emitted as the schema's field descriptions (i.e. they
// are model-facing text, not just developer docs) — confirm before rewording.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct GrepToolInput {
    /// A regex pattern to search for in the entire project. Note that the regex
    /// will be parsed by the Rust `regex` crate.
    ///
    /// Do NOT specify a path here! This will only be matched against the code **content**.
    pub regex: String,

    /// A glob pattern for the paths of files to include in the search.
    /// Supports standard glob patterns like "**/*.rs" or "src/**/*.ts".
    /// If omitted, all files in the project will be searched.
    pub include_pattern: Option<String>,

    /// Optional starting position for paginated results (0-based).
    /// When not provided, starts from the beginning.
    // `serde(default)` makes a missing `offset` deserialize as 0.
    #[serde(default)]
    pub offset: u32,

    /// Whether the regex is case-sensitive. Defaults to false (case-insensitive).
    // `serde(default)` makes a missing `case_sensitive` deserialize as `false`.
    #[serde(default)]
    pub case_sensitive: bool,
}
 41
 42impl GrepToolInput {
 43    /// Which page of search results this is.
 44    pub fn page(&self) -> u32 {
 45        1 + (self.offset / RESULTS_PER_PAGE)
 46    }
 47}
 48
/// Maximum number of match excerpts emitted by a single `GrepTool::run` call;
/// also the page size used by `GrepToolInput::page`.
const RESULTS_PER_PAGE: u32 = 20;
 50
/// Stateless tool (registered under the name "grep") that runs a regex search
/// over a project and returns the matching excerpts as paginated text.
pub struct GrepTool;
 52
 53impl Tool for GrepTool {
 54    fn name(&self) -> String {
 55        "grep".into()
 56    }
 57
 58    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 59        false
 60    }
 61
 62    fn description(&self) -> String {
 63        include_str!("./grep_tool/description.md").into()
 64    }
 65
 66    fn icon(&self) -> IconName {
 67        IconName::Regex
 68    }
 69
 70    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 71        json_schema_for::<GrepToolInput>(format)
 72    }
 73
 74    fn ui_text(&self, input: &serde_json::Value) -> String {
 75        match serde_json::from_value::<GrepToolInput>(input.clone()) {
 76            Ok(input) => {
 77                let page = input.page();
 78                let regex_str = MarkdownInlineCode(&input.regex);
 79                let case_info = if input.case_sensitive {
 80                    " (case-sensitive)"
 81                } else {
 82                    ""
 83                };
 84
 85                if page > 1 {
 86                    format!("Get page {page} of search results for regex {regex_str}{case_info}")
 87                } else {
 88                    format!("Search files for regex {regex_str}{case_info}")
 89                }
 90            }
 91            Err(_) => "Search with regex".to_string(),
 92        }
 93    }
 94
 95    fn run(
 96        self: Arc<Self>,
 97        input: serde_json::Value,
 98        _messages: &[LanguageModelRequestMessage],
 99        project: Entity<Project>,
100        _action_log: Entity<ActionLog>,
101        _window: Option<AnyWindowHandle>,
102        cx: &mut App,
103    ) -> ToolResult {
104        const CONTEXT_LINES: u32 = 2;
105
106        let input = match serde_json::from_value::<GrepToolInput>(input) {
107            Ok(input) => input,
108            Err(error) => {
109                return Task::ready(Err(anyhow!("Failed to parse input: {}", error))).into();
110            }
111        };
112
113        let include_matcher = match PathMatcher::new(
114            input
115                .include_pattern
116                .as_ref()
117                .into_iter()
118                .collect::<Vec<_>>(),
119        ) {
120            Ok(matcher) => matcher,
121            Err(error) => {
122                return Task::ready(Err(anyhow!("invalid include glob pattern: {}", error))).into();
123            }
124        };
125
126        let query = match SearchQuery::regex(
127            &input.regex,
128            false,
129            input.case_sensitive,
130            false,
131            false,
132            include_matcher,
133            PathMatcher::default(), // For now, keep it simple and don't enable an exclude pattern.
134            true, // Always match file include pattern against *full project paths* that start with a project root.
135            None,
136        ) {
137            Ok(query) => query,
138            Err(error) => return Task::ready(Err(error)).into(),
139        };
140
141        let results = project.update(cx, |project, cx| project.search(query, cx));
142
143        cx.spawn(async move|cx|  {
144            futures::pin_mut!(results);
145
146            let mut output = String::new();
147            let mut skips_remaining = input.offset;
148            let mut matches_found = 0;
149            let mut has_more_matches = false;
150
151            while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
152                if ranges.is_empty() {
153                    continue;
154                }
155
156                buffer.read_with(cx, |buffer, cx| -> Result<(), anyhow::Error> {
157                    if let Some(path) = buffer.file().map(|file| file.full_path(cx)) {
158                        let mut file_header_written = false;
159                        let mut ranges = ranges
160                            .into_iter()
161                            .map(|range| {
162                                let mut point_range = range.to_point(buffer);
163                                point_range.start.row =
164                                    point_range.start.row.saturating_sub(CONTEXT_LINES);
165                                point_range.start.column = 0;
166                                point_range.end.row = cmp::min(
167                                    buffer.max_point().row,
168                                    point_range.end.row + CONTEXT_LINES,
169                                );
170                                point_range.end.column = buffer.line_len(point_range.end.row);
171                                point_range
172                            })
173                            .peekable();
174
175                        while let Some(mut range) = ranges.next() {
176                            if skips_remaining > 0 {
177                                skips_remaining -= 1;
178                                continue;
179                            }
180
181                            // We'd already found a full page of matches, and we just found one more.
182                            if matches_found >= RESULTS_PER_PAGE {
183                                has_more_matches = true;
184                                return Ok(());
185                            }
186
187                            while let Some(next_range) = ranges.peek() {
188                                if range.end.row >= next_range.start.row {
189                                    range.end = next_range.end;
190                                    ranges.next();
191                                } else {
192                                    break;
193                                }
194                            }
195
196                            if !file_header_written {
197                                writeln!(output, "\n## Matches in {}", path.display())?;
198                                file_header_written = true;
199                            }
200
201                            let start_line = range.start.row + 1;
202                            let end_line = range.end.row + 1;
203                            writeln!(output, "\n### Lines {start_line}-{end_line}\n```")?;
204                            output.extend(buffer.text_for_range(range));
205                            output.push_str("\n```\n");
206
207                            matches_found += 1;
208                        }
209                    }
210
211                    Ok(())
212                })??;
213            }
214
215            if matches_found == 0 {
216                Ok("No matches found".to_string())
217            } else if has_more_matches {
218                Ok(format!(
219                    "Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}",
220                    input.offset + 1,
221                    input.offset + matches_found,
222                    input.offset + RESULTS_PER_PAGE,
223                ))
224            } else {
225                Ok(format!("Found {matches_found} matches:\n{output}"))
226            }
227        }).into()
228    }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use assistant_tool::Tool;
    use gpui::{AppContext, TestAppContext};
    use project::{FakeFs, Project};
    use settings::SettingsStore;
    use util::path;

    /// Checks that `include_pattern` restricts which files are searched, and that
    /// omitting it searches every file.
    #[gpui::test]
    async fn test_grep_tool_with_include_pattern(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor().clone());
        fs.insert_tree(
            "/root",
            serde_json::json!({
                "src": {
                    "main.rs": "fn main() {\n    println!(\"Hello, world!\");\n}",
                    "utils": {
                        "helper.rs": "fn helper() {\n    println!(\"I'm a helper!\");\n}",
                    },
                },
                "tests": {
                    "test_main.rs": "fn test_main() {\n    assert!(true);\n}",
                }
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        // Test with include pattern for Rust files inside the root of the project
        let input = serde_json::to_value(GrepToolInput {
            regex: "println".to_string(),
            include_pattern: Some("root/**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(result.contains("main.rs"), "Should find matches in main.rs");
        assert!(
            result.contains("helper.rs"),
            "Should find matches in helper.rs"
        );
        assert!(
            !result.contains("test_main.rs"),
            "Should not include test_main.rs even though it's a .rs file (because it doesn't have the pattern)"
        );

        // Test with include pattern for src directory only
        let input = serde_json::to_value(GrepToolInput {
            regex: "fn".to_string(),
            include_pattern: Some("root/**/src/**".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            result.contains("main.rs"),
            "Should find matches in src/main.rs"
        );
        assert!(
            result.contains("helper.rs"),
            "Should find matches in src/utils/helper.rs"
        );
        assert!(
            !result.contains("test_main.rs"),
            "Should not include test_main.rs as it's not in src directory"
        );

        // Test with empty include pattern (should default to all files)
        let input = serde_json::to_value(GrepToolInput {
            regex: "fn".to_string(),
            include_pattern: None,
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(result.contains("main.rs"), "Should find matches in main.rs");
        assert!(
            result.contains("helper.rs"),
            "Should find matches in helper.rs"
        );
        assert!(
            result.contains("test_main.rs"),
            "Should include test_main.rs"
        );
    }

    /// Checks that `case_sensitive` controls whether differently-cased text matches.
    #[gpui::test]
    async fn test_grep_tool_with_case_sensitivity(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor().clone());
        fs.insert_tree(
            "/root",
            serde_json::json!({
                "case_test.txt": "This file has UPPERCASE and lowercase text.\nUPPERCASE patterns should match only with case_sensitive: true",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        // Test case-insensitive search (default)
        let input = serde_json::to_value(GrepToolInput {
            regex: "uppercase".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            result.contains("UPPERCASE"),
            "Case-insensitive search should match uppercase"
        );

        // Test case-sensitive search
        let input = serde_json::to_value(GrepToolInput {
            regex: "uppercase".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: true,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            !result.contains("UPPERCASE"),
            "Case-sensitive search should not match uppercase"
        );

        // Test case-sensitive search with an uppercase pattern against lowercase text
        let input = serde_json::to_value(GrepToolInput {
            regex: "LOWERCASE".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: true,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;

        assert!(
            !result.contains("lowercase"),
            "Case-sensitive search should not match lowercase"
        );

        // Test case-sensitive search for lowercase pattern
        let input = serde_json::to_value(GrepToolInput {
            regex: "lowercase".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: true,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            result.contains("lowercase"),
            "Case-sensitive search should match lowercase text"
        );
    }

    /// Runs `GrepTool` with `input` against `project` and returns its text output.
    ///
    /// Panics if the tool's task resolves to an error.
    async fn run_grep_tool(
        input: serde_json::Value,
        project: Entity<Project>,
        cx: &mut TestAppContext,
    ) -> String {
        let tool = Arc::new(GrepTool);
        let action_log = cx.new(|_cx| ActionLog::new(project.clone()));
        let task = cx.update(|cx| tool.run(input, &[], project, action_log, None, cx));

        match task.output.await {
            Ok(result) => result,
            Err(e) => panic!("Failed to run grep tool: {}", e),
        }
    }

    /// Installs a test `SettingsStore` global and initializes the `language` and
    /// `Project` settings the tests rely on.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });
    }
}