grep_tool.rs

  1use crate::schema::json_schema_for;
  2use anyhow::{Result, anyhow};
  3use assistant_tool::{ActionLog, Tool, ToolResult};
  4use futures::StreamExt;
  5use gpui::{AnyWindowHandle, App, Entity, Task};
  6use language::OffsetRangeExt;
  7use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
  8use project::{
  9    Project,
 10    search::{SearchQuery, SearchResult},
 11};
 12use schemars::JsonSchema;
 13use serde::{Deserialize, Serialize};
 14use std::{cmp, fmt::Write, sync::Arc};
 15use ui::IconName;
 16use util::markdown::MarkdownString;
 17use util::paths::PathMatcher;
 18
 19#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 20pub struct GrepToolInput {
 21    /// A regex pattern to search for in the entire project. Note that the regex
 22    /// will be parsed by the Rust `regex` crate.
 23    pub regex: String,
 24
 25    /// A glob pattern for the paths of files to include in the search.
 26    /// Supports standard glob patterns like "**/*.rs" or "src/**/*.ts".
 27    /// If omitted, all files in the project will be searched.
 28    pub include_pattern: Option<String>,
 29
 30    /// Optional starting position for paginated results (0-based).
 31    /// When not provided, starts from the beginning.
 32    #[serde(default)]
 33    pub offset: u32,
 34
 35    /// Whether the regex is case-sensitive. Defaults to false (case-insensitive).
 36    #[serde(default)]
 37    pub case_sensitive: bool,
 38}
 39
 40impl GrepToolInput {
 41    /// Which page of search results this is.
 42    pub fn page(&self) -> u32 {
 43        1 + (self.offset / RESULTS_PER_PAGE)
 44    }
 45}
 46
 47const RESULTS_PER_PAGE: u32 = 20;
 48
 49pub struct GrepTool;
 50
 51impl Tool for GrepTool {
 52    fn name(&self) -> String {
 53        "grep".into()
 54    }
 55
 56    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 57        false
 58    }
 59
 60    fn description(&self) -> String {
 61        include_str!("./grep_tool/description.md").into()
 62    }
 63
 64    fn icon(&self) -> IconName {
 65        IconName::Regex
 66    }
 67
 68    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 69        json_schema_for::<GrepToolInput>(format)
 70    }
 71
 72    fn ui_text(&self, input: &serde_json::Value) -> String {
 73        match serde_json::from_value::<GrepToolInput>(input.clone()) {
 74            Ok(input) => {
 75                let page = input.page();
 76                let regex_str = MarkdownString::inline_code(&input.regex);
 77                let case_info = if input.case_sensitive {
 78                    " (case-sensitive)"
 79                } else {
 80                    ""
 81                };
 82
 83                if page > 1 {
 84                    format!("Get page {page} of search results for regex {regex_str}{case_info}")
 85                } else {
 86                    format!("Search files for regex {regex_str}{case_info}")
 87                }
 88            }
 89            Err(_) => "Search with regex".to_string(),
 90        }
 91    }
 92
 93    fn run(
 94        self: Arc<Self>,
 95        input: serde_json::Value,
 96        _messages: &[LanguageModelRequestMessage],
 97        project: Entity<Project>,
 98        _action_log: Entity<ActionLog>,
 99        _window: Option<AnyWindowHandle>,
100        cx: &mut App,
101    ) -> ToolResult {
102        const CONTEXT_LINES: u32 = 2;
103
104        let input = match serde_json::from_value::<GrepToolInput>(input) {
105            Ok(input) => input,
106            Err(error) => {
107                return Task::ready(Err(anyhow!("Failed to parse input: {}", error))).into();
108            }
109        };
110
111        let include_matcher = match PathMatcher::new(
112            input
113                .include_pattern
114                .as_ref()
115                .into_iter()
116                .collect::<Vec<_>>(),
117        ) {
118            Ok(matcher) => matcher,
119            Err(error) => {
120                return Task::ready(Err(anyhow!("invalid include glob pattern: {}", error))).into();
121            }
122        };
123
124        let query = match SearchQuery::regex(
125            &input.regex,
126            false,
127            input.case_sensitive,
128            false,
129            false,
130            include_matcher,
131            PathMatcher::default(), // For now, keep it simple and don't enable an exclude pattern.
132            true, // Always match file include pattern against *full project paths* that start with a project root.
133            None,
134        ) {
135            Ok(query) => query,
136            Err(error) => return Task::ready(Err(error)).into(),
137        };
138
139        let results = project.update(cx, |project, cx| project.search(query, cx));
140
141        cx.spawn(async move|cx|  {
142            futures::pin_mut!(results);
143
144            let mut output = String::new();
145            let mut skips_remaining = input.offset;
146            let mut matches_found = 0;
147            let mut has_more_matches = false;
148
149            while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
150                if ranges.is_empty() {
151                    continue;
152                }
153
154                buffer.read_with(cx, |buffer, cx| -> Result<(), anyhow::Error> {
155                    if let Some(path) = buffer.file().map(|file| file.full_path(cx)) {
156                        let mut file_header_written = false;
157                        let mut ranges = ranges
158                            .into_iter()
159                            .map(|range| {
160                                let mut point_range = range.to_point(buffer);
161                                point_range.start.row =
162                                    point_range.start.row.saturating_sub(CONTEXT_LINES);
163                                point_range.start.column = 0;
164                                point_range.end.row = cmp::min(
165                                    buffer.max_point().row,
166                                    point_range.end.row + CONTEXT_LINES,
167                                );
168                                point_range.end.column = buffer.line_len(point_range.end.row);
169                                point_range
170                            })
171                            .peekable();
172
173                        while let Some(mut range) = ranges.next() {
174                            if skips_remaining > 0 {
175                                skips_remaining -= 1;
176                                continue;
177                            }
178
179                            // We'd already found a full page of matches, and we just found one more.
180                            if matches_found >= RESULTS_PER_PAGE {
181                                has_more_matches = true;
182                                return Ok(());
183                            }
184
185                            while let Some(next_range) = ranges.peek() {
186                                if range.end.row >= next_range.start.row {
187                                    range.end = next_range.end;
188                                    ranges.next();
189                                } else {
190                                    break;
191                                }
192                            }
193
194                            if !file_header_written {
195                                writeln!(output, "\n## Matches in {}", path.display())?;
196                                file_header_written = true;
197                            }
198
199                            let start_line = range.start.row + 1;
200                            let end_line = range.end.row + 1;
201                            writeln!(output, "\n### Lines {start_line}-{end_line}\n```")?;
202                            output.extend(buffer.text_for_range(range));
203                            output.push_str("\n```\n");
204
205                            matches_found += 1;
206                        }
207                    }
208
209                    Ok(())
210                })??;
211            }
212
213            if matches_found == 0 {
214                Ok("No matches found".to_string())
215            } else if has_more_matches {
216                Ok(format!(
217                    "Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}",
218                    input.offset + 1,
219                    input.offset + matches_found,
220                    input.offset + RESULTS_PER_PAGE,
221                ))
222            } else {
223                Ok(format!("Found {matches_found} matches:\n{output}"))
224            }
225        }).into()
226    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use assistant_tool::Tool;
    use gpui::{AppContext, TestAppContext};
    use project::{FakeFs, Project};
    use settings::SettingsStore;
    use util::path;

    #[gpui::test]
    async fn test_grep_tool_with_include_pattern(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor().clone());
        // Use `path!` here as well, so the tree is inserted at exactly the root
        // the project is opened at below.
        fs.insert_tree(
            path!("/root"),
            serde_json::json!({
                "src": {
                    "main.rs": "fn main() {\n    println!(\"Hello, world!\");\n}",
                    "utils": {
                        "helper.rs": "fn helper() {\n    println!(\"I'm a helper!\");\n}",
                    },
                },
                "tests": {
                    "test_main.rs": "fn test_main() {\n    assert!(true);\n}",
                }
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        // Test with include pattern for Rust files inside the root of the project
        let input = serde_json::to_value(GrepToolInput {
            regex: "println".to_string(),
            include_pattern: Some("root/**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(result.contains("main.rs"), "Should find matches in main.rs");
        assert!(
            result.contains("helper.rs"),
            "Should find matches in helper.rs"
        );
        assert!(
            !result.contains("test_main.rs"),
            "Should not include test_main.rs even though it's a .rs file (because it doesn't contain the regex)"
        );

        // Test with include pattern for src directory only
        let input = serde_json::to_value(GrepToolInput {
            regex: "fn".to_string(),
            include_pattern: Some("root/**/src/**".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            result.contains("main.rs"),
            "Should find matches in src/main.rs"
        );
        assert!(
            result.contains("helper.rs"),
            "Should find matches in src/utils/helper.rs"
        );
        assert!(
            !result.contains("test_main.rs"),
            "Should not include test_main.rs as it's not in src directory"
        );

        // Test with no include pattern (should default to all files)
        let input = serde_json::to_value(GrepToolInput {
            regex: "fn".to_string(),
            include_pattern: None,
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(result.contains("main.rs"), "Should find matches in main.rs");
        assert!(
            result.contains("helper.rs"),
            "Should find matches in helper.rs"
        );
        assert!(
            result.contains("test_main.rs"),
            "Should include test_main.rs"
        );
    }

    #[gpui::test]
    async fn test_grep_tool_with_case_sensitivity(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor().clone());
        // Use `path!` here as well, so the tree is inserted at exactly the root
        // the project is opened at below.
        fs.insert_tree(
            path!("/root"),
            serde_json::json!({
                "case_test.txt": "This file has UPPERCASE and lowercase text.\nUPPERCASE patterns should match only with case_sensitive: true",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        // Test case-insensitive search (default)
        let input = serde_json::to_value(GrepToolInput {
            regex: "uppercase".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            result.contains("UPPERCASE"),
            "Case-insensitive search should match uppercase"
        );

        // Test case-sensitive search: a lowercase pattern must not match uppercase text
        let input = serde_json::to_value(GrepToolInput {
            regex: "uppercase".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: true,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            !result.contains("UPPERCASE"),
            "Case-sensitive search should not match uppercase"
        );

        // Test case-sensitive search: an uppercase pattern must not match lowercase text
        let input = serde_json::to_value(GrepToolInput {
            regex: "LOWERCASE".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: true,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;

        assert!(
            !result.contains("lowercase"),
            "Case-sensitive search for LOWERCASE should not match lowercase text"
        );

        // Test case-sensitive search for lowercase pattern
        let input = serde_json::to_value(GrepToolInput {
            regex: "lowercase".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: true,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            result.contains("lowercase"),
            "Case-sensitive search should match lowercase text"
        );
    }

    /// Runs the grep tool against `project` and returns its textual output,
    /// panicking if the tool reports an error.
    async fn run_grep_tool(
        input: serde_json::Value,
        project: Entity<Project>,
        cx: &mut TestAppContext,
    ) -> String {
        let tool = Arc::new(GrepTool);
        let action_log = cx.new(|_cx| ActionLog::new(project.clone()));
        let task = cx.update(|cx| tool.run(input, &[], project, action_log, None, cx));

        match task.output.await {
            Ok(result) => result,
            Err(e) => panic!("Failed to run grep tool: {}", e),
        }
    }

    /// Installs the test settings store and initializes language/project settings.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });
    }
}