1use crate::schema::json_schema_for;
2use anyhow::{Result, anyhow};
3use assistant_tool::{ActionLog, Tool, ToolResult};
4use futures::StreamExt;
5use gpui::{App, Entity, Task};
6use language::OffsetRangeExt;
7use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
8use project::{
9 Project,
10 search::{SearchQuery, SearchResult},
11};
12use schemars::JsonSchema;
13use serde::{Deserialize, Serialize};
14use std::{cmp, fmt::Write, sync::Arc};
15use ui::IconName;
16use util::markdown::MarkdownString;
17use util::paths::PathMatcher;
18
// Arguments accepted by the grep tool. `GrepTool::run` and `GrepTool::ui_text` deserialize
// their JSON input into this struct, and `GrepTool::input_schema` derives its schema from it.
//
// NOTE(review): the `///` comments on the fields below are presumably picked up by the
// `JsonSchema` derive as field descriptions in the generated schema, so rewording them may
// change what the tool advertises -- confirm before editing. This header is deliberately a
// plain `//` comment for the same reason.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct GrepToolInput {
    /// A regex pattern to search for in the entire project. Note that the regex
    /// will be parsed by the Rust `regex` crate.
    pub regex: String,

    /// A glob pattern for the paths of files to include in the search.
    /// Supports standard glob patterns like "**/*.rs" or "src/**/*.ts".
    /// If omitted, all files in the project will be searched.
    pub include_pattern: Option<String>,

    /// Optional starting position for paginated results (0-based).
    /// When not provided, starts from the beginning.
    // `#[serde(default)]` makes a missing `offset` deserialize as 0.
    #[serde(default)]
    pub offset: u32,

    /// Whether the regex is case-sensitive. Defaults to false (case-insensitive).
    // `#[serde(default)]` makes a missing `case_sensitive` deserialize as `false`.
    #[serde(default)]
    pub case_sensitive: bool,
}
39
40impl GrepToolInput {
41 /// Which page of search results this is.
42 pub fn page(&self) -> u32 {
43 1 + (self.offset / RESULTS_PER_PAGE)
44 }
45}
46
/// Maximum number of result excerpts written to the output of a single
/// `GrepTool::run` call; also the page size used by `GrepToolInput::page`.
const RESULTS_PER_PAGE: u32 = 20;
48
/// Stateless tool, registered under the name "grep", that runs a regex search
/// over the project and reports matching excerpts as Markdown text.
pub struct GrepTool;
50
51impl Tool for GrepTool {
52 fn name(&self) -> String {
53 "grep".into()
54 }
55
56 fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
57 false
58 }
59
60 fn description(&self) -> String {
61 include_str!("./grep_tool/description.md").into()
62 }
63
64 fn icon(&self) -> IconName {
65 IconName::Regex
66 }
67
68 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
69 json_schema_for::<GrepToolInput>(format)
70 }
71
72 fn ui_text(&self, input: &serde_json::Value) -> String {
73 match serde_json::from_value::<GrepToolInput>(input.clone()) {
74 Ok(input) => {
75 let page = input.page();
76 let regex_str = MarkdownString::inline_code(&input.regex);
77 let case_info = if input.case_sensitive {
78 " (case-sensitive)"
79 } else {
80 ""
81 };
82
83 if page > 1 {
84 format!("Get page {page} of search results for regex {regex_str}{case_info}")
85 } else {
86 format!("Search files for regex {regex_str}{case_info}")
87 }
88 }
89 Err(_) => "Search with regex".to_string(),
90 }
91 }
92
93 fn run(
94 self: Arc<Self>,
95 input: serde_json::Value,
96 _messages: &[LanguageModelRequestMessage],
97 project: Entity<Project>,
98 _action_log: Entity<ActionLog>,
99 cx: &mut App,
100 ) -> ToolResult {
101 const CONTEXT_LINES: u32 = 2;
102
103 let input = match serde_json::from_value::<GrepToolInput>(input) {
104 Ok(input) => input,
105 Err(error) => {
106 return Task::ready(Err(anyhow!("Failed to parse input: {}", error))).into();
107 }
108 };
109
110 let include_matcher = match PathMatcher::new(
111 input
112 .include_pattern
113 .as_ref()
114 .into_iter()
115 .collect::<Vec<_>>(),
116 ) {
117 Ok(matcher) => matcher,
118 Err(error) => {
119 return Task::ready(Err(anyhow!("invalid include glob pattern: {}", error))).into();
120 }
121 };
122
123 let query = match SearchQuery::regex(
124 &input.regex,
125 false,
126 input.case_sensitive,
127 false,
128 false,
129 include_matcher,
130 PathMatcher::default(), // For now, keep it simple and don't enable an exclude pattern.
131 true, // Always match file include pattern against *full project paths* that start with a project root.
132 None,
133 ) {
134 Ok(query) => query,
135 Err(error) => return Task::ready(Err(error)).into(),
136 };
137
138 let results = project.update(cx, |project, cx| project.search(query, cx));
139
140 cx.spawn(async move|cx| {
141 futures::pin_mut!(results);
142
143 let mut output = String::new();
144 let mut skips_remaining = input.offset;
145 let mut matches_found = 0;
146 let mut has_more_matches = false;
147
148 while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
149 if ranges.is_empty() {
150 continue;
151 }
152
153 buffer.read_with(cx, |buffer, cx| -> Result<(), anyhow::Error> {
154 if let Some(path) = buffer.file().map(|file| file.full_path(cx)) {
155 let mut file_header_written = false;
156 let mut ranges = ranges
157 .into_iter()
158 .map(|range| {
159 let mut point_range = range.to_point(buffer);
160 point_range.start.row =
161 point_range.start.row.saturating_sub(CONTEXT_LINES);
162 point_range.start.column = 0;
163 point_range.end.row = cmp::min(
164 buffer.max_point().row,
165 point_range.end.row + CONTEXT_LINES,
166 );
167 point_range.end.column = buffer.line_len(point_range.end.row);
168 point_range
169 })
170 .peekable();
171
172 while let Some(mut range) = ranges.next() {
173 if skips_remaining > 0 {
174 skips_remaining -= 1;
175 continue;
176 }
177
178 // We'd already found a full page of matches, and we just found one more.
179 if matches_found >= RESULTS_PER_PAGE {
180 has_more_matches = true;
181 return Ok(());
182 }
183
184 while let Some(next_range) = ranges.peek() {
185 if range.end.row >= next_range.start.row {
186 range.end = next_range.end;
187 ranges.next();
188 } else {
189 break;
190 }
191 }
192
193 if !file_header_written {
194 writeln!(output, "\n## Matches in {}", path.display())?;
195 file_header_written = true;
196 }
197
198 let start_line = range.start.row + 1;
199 let end_line = range.end.row + 1;
200 writeln!(output, "\n### Lines {start_line}-{end_line}\n```")?;
201 output.extend(buffer.text_for_range(range));
202 output.push_str("\n```\n");
203
204 matches_found += 1;
205 }
206 }
207
208 Ok(())
209 })??;
210 }
211
212 if matches_found == 0 {
213 Ok("No matches found".to_string())
214 } else if has_more_matches {
215 Ok(format!(
216 "Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}",
217 input.offset + 1,
218 input.offset + matches_found,
219 input.offset + RESULTS_PER_PAGE,
220 ))
221 } else {
222 Ok(format!("Found {matches_found} matches:\n{output}"))
223 }
224 }).into()
225 }
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use assistant_tool::Tool;
    use gpui::{AppContext, TestAppContext};
    use project::{FakeFs, Project};
    use settings::SettingsStore;
    use util::path;

    /// Checks that `include_pattern` restricts which files are searched, and that
    /// omitting it searches every file in the project.
    #[gpui::test]
    async fn test_grep_tool_with_include_pattern(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor().clone());
        fs.insert_tree(
            "/root",
            serde_json::json!({
                "src": {
                    "main.rs": "fn main() {\n println!(\"Hello, world!\");\n}",
                    "utils": {
                        "helper.rs": "fn helper() {\n println!(\"I'm a helper!\");\n}",
                    },
                },
                "tests": {
                    "test_main.rs": "fn test_main() {\n assert!(true);\n}",
                }
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        // Test with include pattern for Rust files inside the root of the project
        let input = serde_json::to_value(GrepToolInput {
            regex: "println".to_string(),
            include_pattern: Some("root/**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(result.contains("main.rs"), "Should find matches in main.rs");
        assert!(
            result.contains("helper.rs"),
            "Should find matches in helper.rs"
        );
        assert!(
            !result.contains("test_main.rs"),
            "Should not include test_main.rs even though it's a .rs file (because it doesn't have the pattern)"
        );

        // Test with include pattern for src directory only
        let input = serde_json::to_value(GrepToolInput {
            regex: "fn".to_string(),
            include_pattern: Some("root/**/src/**".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            result.contains("main.rs"),
            "Should find matches in src/main.rs"
        );
        assert!(
            result.contains("helper.rs"),
            "Should find matches in src/utils/helper.rs"
        );
        assert!(
            !result.contains("test_main.rs"),
            "Should not include test_main.rs as it's not in src directory"
        );

        // Test with empty include pattern (should default to all files)
        let input = serde_json::to_value(GrepToolInput {
            regex: "fn".to_string(),
            include_pattern: None,
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(result.contains("main.rs"), "Should find matches in main.rs");
        assert!(
            result.contains("helper.rs"),
            "Should find matches in helper.rs"
        );
        assert!(
            result.contains("test_main.rs"),
            "Should include test_main.rs"
        );
    }

    /// Checks that `case_sensitive` controls whether a pattern matches text that
    /// differs from it only in letter case.
    #[gpui::test]
    async fn test_grep_tool_with_case_sensitivity(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor().clone());
        fs.insert_tree(
            "/root",
            serde_json::json!({
                "case_test.txt": "This file has UPPERCASE and lowercase text.\nUPPERCASE patterns should match only with case_sensitive: true",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        // Test case-insensitive search (default)
        let input = serde_json::to_value(GrepToolInput {
            regex: "uppercase".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            result.contains("UPPERCASE"),
            "Case-insensitive search should match uppercase"
        );

        // Test case-sensitive search
        let input = serde_json::to_value(GrepToolInput {
            regex: "uppercase".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: true,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            !result.contains("UPPERCASE"),
            "Case-sensitive search should not match uppercase"
        );

        // Test case-sensitive search with an uppercase pattern against lowercase text
        let input = serde_json::to_value(GrepToolInput {
            regex: "LOWERCASE".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: true,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;

        // The assertion is negative, so the failure message must say "should not match".
        assert!(
            !result.contains("lowercase"),
            "Case-sensitive search should not match lowercase"
        );

        // Test case-sensitive search for lowercase pattern
        let input = serde_json::to_value(GrepToolInput {
            regex: "lowercase".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: true,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            result.contains("lowercase"),
            "Case-sensitive search should match lowercase text"
        );
    }

    /// Runs `GrepTool` with `input` against `project` and returns its textual output,
    /// panicking if the tool reports an error.
    async fn run_grep_tool(
        input: serde_json::Value,
        project: Entity<Project>,
        cx: &mut TestAppContext,
    ) -> String {
        let tool = Arc::new(GrepTool);
        let action_log = cx.new(|_cx| ActionLog::new(project.clone()));
        let task = cx.update(|cx| tool.run(input, &[], project, action_log, cx));

        match task.output.await {
            Ok(result) => result,
            Err(e) => panic!("Failed to run grep tool: {}", e),
        }
    }

    /// Installs the test settings store and initializes language and project settings.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });
    }
}