1use crate::schema::json_schema_for;
2use anyhow::{Result, anyhow};
3use assistant_tool::{ActionLog, Tool, ToolResult};
4use futures::StreamExt;
5use gpui::{AnyWindowHandle, App, Entity, Task};
6use language::OffsetRangeExt;
7use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
8use project::{
9 Project,
10 search::{SearchQuery, SearchResult},
11};
12use schemars::JsonSchema;
13use serde::{Deserialize, Serialize};
14use std::{cmp, fmt::Write, sync::Arc};
15use ui::IconName;
16use util::markdown::MarkdownString;
17use util::paths::PathMatcher;
18
// Arguments accepted by the grep tool. Deserialized from the tool-call JSON in
// `GrepTool::run` and `GrepTool::ui_text`, and used to generate the tool's input
// schema via `json_schema_for::<GrepToolInput>`.
//
// NOTE(review): schemars' `JsonSchema` derive normally copies `///` doc comments
// into the generated schema as `description`s, so the field docs below are
// presumably model-facing text rather than plain developer docs — confirm before
// rewording them. (That is also why this note is a `//` comment.)
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct GrepToolInput {
    /// A regex pattern to search for in the entire project. Note that the regex
    /// will be parsed by the Rust `regex` crate.
    pub regex: String,

    /// A glob pattern for the paths of files to include in the search.
    /// Supports standard glob patterns like "**/*.rs" or "src/**/*.ts".
    /// If omitted, all files in the project will be searched.
    pub include_pattern: Option<String>,

    /// Optional starting position for paginated results (0-based).
    /// When not provided, starts from the beginning.
    // `#[serde(default)]`: a missing `offset` deserializes to 0.
    #[serde(default)]
    pub offset: u32,

    /// Whether the regex is case-sensitive. Defaults to false (case-insensitive).
    // `#[serde(default)]`: a missing `case_sensitive` deserializes to `false`.
    #[serde(default)]
    pub case_sensitive: bool,
}
39
40impl GrepToolInput {
41 /// Which page of search results this is.
42 pub fn page(&self) -> u32 {
43 1 + (self.offset / RESULTS_PER_PAGE)
44 }
45}
46
/// Maximum number of match excerpts written into a single tool response; also
/// the page size used by `GrepToolInput::page`.
const RESULTS_PER_PAGE: u32 = 20;
48
/// Tool that runs a regex search over the project and returns the matching
/// excerpts as Markdown text. It is a unit struct; all behavior lives in its
/// `Tool` implementation.
pub struct GrepTool;
50
51impl Tool for GrepTool {
52 fn name(&self) -> String {
53 "grep".into()
54 }
55
56 fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
57 false
58 }
59
60 fn description(&self) -> String {
61 include_str!("./grep_tool/description.md").into()
62 }
63
64 fn icon(&self) -> IconName {
65 IconName::Regex
66 }
67
68 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
69 json_schema_for::<GrepToolInput>(format)
70 }
71
72 fn ui_text(&self, input: &serde_json::Value) -> String {
73 match serde_json::from_value::<GrepToolInput>(input.clone()) {
74 Ok(input) => {
75 let page = input.page();
76 let regex_str = MarkdownString::inline_code(&input.regex);
77 let case_info = if input.case_sensitive {
78 " (case-sensitive)"
79 } else {
80 ""
81 };
82
83 if page > 1 {
84 format!("Get page {page} of search results for regex {regex_str}{case_info}")
85 } else {
86 format!("Search files for regex {regex_str}{case_info}")
87 }
88 }
89 Err(_) => "Search with regex".to_string(),
90 }
91 }
92
93 fn run(
94 self: Arc<Self>,
95 input: serde_json::Value,
96 _messages: &[LanguageModelRequestMessage],
97 project: Entity<Project>,
98 _action_log: Entity<ActionLog>,
99 _window: Option<AnyWindowHandle>,
100 cx: &mut App,
101 ) -> ToolResult {
102 const CONTEXT_LINES: u32 = 2;
103
104 let input = match serde_json::from_value::<GrepToolInput>(input) {
105 Ok(input) => input,
106 Err(error) => {
107 return Task::ready(Err(anyhow!("Failed to parse input: {}", error))).into();
108 }
109 };
110
111 let include_matcher = match PathMatcher::new(
112 input
113 .include_pattern
114 .as_ref()
115 .into_iter()
116 .collect::<Vec<_>>(),
117 ) {
118 Ok(matcher) => matcher,
119 Err(error) => {
120 return Task::ready(Err(anyhow!("invalid include glob pattern: {}", error))).into();
121 }
122 };
123
124 let query = match SearchQuery::regex(
125 &input.regex,
126 false,
127 input.case_sensitive,
128 false,
129 false,
130 include_matcher,
131 PathMatcher::default(), // For now, keep it simple and don't enable an exclude pattern.
132 true, // Always match file include pattern against *full project paths* that start with a project root.
133 None,
134 ) {
135 Ok(query) => query,
136 Err(error) => return Task::ready(Err(error)).into(),
137 };
138
139 let results = project.update(cx, |project, cx| project.search(query, cx));
140
141 cx.spawn(async move|cx| {
142 futures::pin_mut!(results);
143
144 let mut output = String::new();
145 let mut skips_remaining = input.offset;
146 let mut matches_found = 0;
147 let mut has_more_matches = false;
148
149 while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
150 if ranges.is_empty() {
151 continue;
152 }
153
154 buffer.read_with(cx, |buffer, cx| -> Result<(), anyhow::Error> {
155 if let Some(path) = buffer.file().map(|file| file.full_path(cx)) {
156 let mut file_header_written = false;
157 let mut ranges = ranges
158 .into_iter()
159 .map(|range| {
160 let mut point_range = range.to_point(buffer);
161 point_range.start.row =
162 point_range.start.row.saturating_sub(CONTEXT_LINES);
163 point_range.start.column = 0;
164 point_range.end.row = cmp::min(
165 buffer.max_point().row,
166 point_range.end.row + CONTEXT_LINES,
167 );
168 point_range.end.column = buffer.line_len(point_range.end.row);
169 point_range
170 })
171 .peekable();
172
173 while let Some(mut range) = ranges.next() {
174 if skips_remaining > 0 {
175 skips_remaining -= 1;
176 continue;
177 }
178
179 // We'd already found a full page of matches, and we just found one more.
180 if matches_found >= RESULTS_PER_PAGE {
181 has_more_matches = true;
182 return Ok(());
183 }
184
185 while let Some(next_range) = ranges.peek() {
186 if range.end.row >= next_range.start.row {
187 range.end = next_range.end;
188 ranges.next();
189 } else {
190 break;
191 }
192 }
193
194 if !file_header_written {
195 writeln!(output, "\n## Matches in {}", path.display())?;
196 file_header_written = true;
197 }
198
199 let start_line = range.start.row + 1;
200 let end_line = range.end.row + 1;
201 writeln!(output, "\n### Lines {start_line}-{end_line}\n```")?;
202 output.extend(buffer.text_for_range(range));
203 output.push_str("\n```\n");
204
205 matches_found += 1;
206 }
207 }
208
209 Ok(())
210 })??;
211 }
212
213 if matches_found == 0 {
214 Ok("No matches found".to_string())
215 } else if has_more_matches {
216 Ok(format!(
217 "Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}",
218 input.offset + 1,
219 input.offset + matches_found,
220 input.offset + RESULTS_PER_PAGE,
221 ))
222 } else {
223 Ok(format!("Found {matches_found} matches:\n{output}"))
224 }
225 }).into()
226 }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use assistant_tool::Tool;
    use gpui::{AppContext, TestAppContext};
    use project::{FakeFs, Project};
    use settings::SettingsStore;
    use util::path;

    /// `include_pattern` restricts which files are searched; omitting it searches everything.
    #[gpui::test]
    async fn test_grep_tool_with_include_pattern(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor().clone());
        // Use `path!` here as well so the fake tree lives at the same root that
        // `Project::test` opens below on every platform.
        fs.insert_tree(
            path!("/root"),
            serde_json::json!({
                "src": {
                    "main.rs": "fn main() {\n println!(\"Hello, world!\");\n}",
                    "utils": {
                        "helper.rs": "fn helper() {\n println!(\"I'm a helper!\");\n}",
                    },
                },
                "tests": {
                    "test_main.rs": "fn test_main() {\n assert!(true);\n}",
                }
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        // Test with include pattern for Rust files inside the root of the project
        let input = serde_json::to_value(GrepToolInput {
            regex: "println".to_string(),
            include_pattern: Some("root/**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(result.contains("main.rs"), "Should find matches in main.rs");
        assert!(
            result.contains("helper.rs"),
            "Should find matches in helper.rs"
        );
        assert!(
            !result.contains("test_main.rs"),
            "Should not include test_main.rs even though it's a .rs file (because it doesn't have the pattern)"
        );

        // Test with include pattern for src directory only
        let input = serde_json::to_value(GrepToolInput {
            regex: "fn".to_string(),
            include_pattern: Some("root/**/src/**".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            result.contains("main.rs"),
            "Should find matches in src/main.rs"
        );
        assert!(
            result.contains("helper.rs"),
            "Should find matches in src/utils/helper.rs"
        );
        assert!(
            !result.contains("test_main.rs"),
            "Should not include test_main.rs as it's not in src directory"
        );

        // Test with empty include pattern (should default to all files)
        let input = serde_json::to_value(GrepToolInput {
            regex: "fn".to_string(),
            include_pattern: None,
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(result.contains("main.rs"), "Should find matches in main.rs");
        assert!(
            result.contains("helper.rs"),
            "Should find matches in helper.rs"
        );
        assert!(
            result.contains("test_main.rs"),
            "Should include test_main.rs"
        );
    }

    /// `case_sensitive` switches between case-insensitive (default) and exact-case matching.
    #[gpui::test]
    async fn test_grep_tool_with_case_sensitivity(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor().clone());
        // Same root as `Project::test` below, on every platform.
        fs.insert_tree(
            path!("/root"),
            serde_json::json!({
                "case_test.txt": "This file has UPPERCASE and lowercase text.\nUPPERCASE patterns should match only with case_sensitive: true",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        // Test case-insensitive search (default)
        let input = serde_json::to_value(GrepToolInput {
            regex: "uppercase".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            result.contains("UPPERCASE"),
            "Case-insensitive search should match uppercase"
        );

        // Test case-sensitive search
        let input = serde_json::to_value(GrepToolInput {
            regex: "uppercase".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: true,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            !result.contains("UPPERCASE"),
            "Case-sensitive search should not match uppercase"
        );

        // Test case-sensitive search with an all-caps pattern whose text only
        // occurs in lowercase in the file: nothing should be returned.
        let input = serde_json::to_value(GrepToolInput {
            regex: "LOWERCASE".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: true,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;

        assert!(
            !result.contains("lowercase"),
            "Case-sensitive search for LOWERCASE should not match lowercase text"
        );

        // Test case-sensitive search for lowercase pattern
        let input = serde_json::to_value(GrepToolInput {
            regex: "lowercase".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: true,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            result.contains("lowercase"),
            "Case-sensitive search should match lowercase text"
        );
    }

    /// Runs `GrepTool` with `input` against `project` and returns its text
    /// output, panicking if the tool reports an error.
    async fn run_grep_tool(
        input: serde_json::Value,
        project: Entity<Project>,
        cx: &mut TestAppContext,
    ) -> String {
        let tool = Arc::new(GrepTool);
        let action_log = cx.new(|_cx| ActionLog::new(project.clone()));
        let task = cx.update(|cx| tool.run(input, &[], project, action_log, None, cx));

        match task.output.await {
            Ok(result) => result,
            Err(e) => panic!("Failed to run grep tool: {}", e),
        }
    }

    /// Installs the globals the tests rely on: a test `SettingsStore`, plus
    /// `language` and `Project` settings initialization.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });
    }
}