1use crate::schema::json_schema_for;
2use anyhow::{Result, anyhow};
3use assistant_tool::{ActionLog, Tool, ToolResult};
4use futures::StreamExt;
5use gpui::{AnyWindowHandle, App, Entity, Task};
6use language::OffsetRangeExt;
7use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
8use project::{
9 Project,
10 search::{SearchQuery, SearchResult},
11};
12use schemars::JsonSchema;
13use serde::{Deserialize, Serialize};
14use std::{cmp, fmt::Write, sync::Arc};
15use ui::IconName;
16use util::markdown::MarkdownString;
17use util::paths::PathMatcher;
18
// Arguments accepted by the `grep` tool. Values of this type are produced by
// deserializing the JSON input handed to `ui_text` and `run`.
//
// NOTE(review): the `///` comments on the fields are left verbatim on purpose.
// The struct derives `JsonSchema`, and schemars presumably copies doc comments
// into the generated schema as `description` text, so rewording them would
// change what `input_schema` returns — confirm before editing them.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct GrepToolInput {
    /// A regex pattern to search for in the entire project. Note that the regex
    /// will be parsed by the Rust `regex` crate.
    ///
    /// Do NOT specify a path here! This will only be matched against the code **content**.
    pub regex: String,

    /// A glob pattern for the paths of files to include in the search.
    /// Supports standard glob patterns like "**/*.rs" or "src/**/*.ts".
    /// If omitted, all files in the project will be searched.
    pub include_pattern: Option<String>,

    /// Optional starting position for paginated results (0-based).
    /// When not provided, starts from the beginning.
    // `serde(default)` lets the key be absent; a missing `offset` becomes 0.
    #[serde(default)]
    pub offset: u32,

    /// Whether the regex is case-sensitive. Defaults to false (case-insensitive).
    // `serde(default)` lets the key be absent; a missing flag becomes `false`.
    #[serde(default)]
    pub case_sensitive: bool,
}
41
42impl GrepToolInput {
43 /// Which page of search results this is.
44 pub fn page(&self) -> u32 {
45 1 + (self.offset / RESULTS_PER_PAGE)
46 }
47}
48
/// Maximum number of match excerpts a single `grep` call writes to its output.
/// `GrepToolInput::page` also divides the offset by this value to derive the
/// page number shown in the UI text.
const RESULTS_PER_PAGE: u32 = 20;
50
/// Stateless tool, exposed under the name `"grep"`, that searches the project
/// with a regular expression and returns the matching excerpts one page at a
/// time.
pub struct GrepTool;
52
53impl Tool for GrepTool {
54 fn name(&self) -> String {
55 "grep".into()
56 }
57
58 fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
59 false
60 }
61
62 fn description(&self) -> String {
63 include_str!("./grep_tool/description.md").into()
64 }
65
66 fn icon(&self) -> IconName {
67 IconName::Regex
68 }
69
70 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
71 json_schema_for::<GrepToolInput>(format)
72 }
73
74 fn ui_text(&self, input: &serde_json::Value) -> String {
75 match serde_json::from_value::<GrepToolInput>(input.clone()) {
76 Ok(input) => {
77 let page = input.page();
78 let regex_str = MarkdownString::inline_code(&input.regex);
79 let case_info = if input.case_sensitive {
80 " (case-sensitive)"
81 } else {
82 ""
83 };
84
85 if page > 1 {
86 format!("Get page {page} of search results for regex {regex_str}{case_info}")
87 } else {
88 format!("Search files for regex {regex_str}{case_info}")
89 }
90 }
91 Err(_) => "Search with regex".to_string(),
92 }
93 }
94
95 fn run(
96 self: Arc<Self>,
97 input: serde_json::Value,
98 _messages: &[LanguageModelRequestMessage],
99 project: Entity<Project>,
100 _action_log: Entity<ActionLog>,
101 _window: Option<AnyWindowHandle>,
102 cx: &mut App,
103 ) -> ToolResult {
104 const CONTEXT_LINES: u32 = 2;
105
106 let input = match serde_json::from_value::<GrepToolInput>(input) {
107 Ok(input) => input,
108 Err(error) => {
109 return Task::ready(Err(anyhow!("Failed to parse input: {}", error))).into();
110 }
111 };
112
113 let include_matcher = match PathMatcher::new(
114 input
115 .include_pattern
116 .as_ref()
117 .into_iter()
118 .collect::<Vec<_>>(),
119 ) {
120 Ok(matcher) => matcher,
121 Err(error) => {
122 return Task::ready(Err(anyhow!("invalid include glob pattern: {}", error))).into();
123 }
124 };
125
126 let query = match SearchQuery::regex(
127 &input.regex,
128 false,
129 input.case_sensitive,
130 false,
131 false,
132 include_matcher,
133 PathMatcher::default(), // For now, keep it simple and don't enable an exclude pattern.
134 true, // Always match file include pattern against *full project paths* that start with a project root.
135 None,
136 ) {
137 Ok(query) => query,
138 Err(error) => return Task::ready(Err(error)).into(),
139 };
140
141 let results = project.update(cx, |project, cx| project.search(query, cx));
142
143 cx.spawn(async move|cx| {
144 futures::pin_mut!(results);
145
146 let mut output = String::new();
147 let mut skips_remaining = input.offset;
148 let mut matches_found = 0;
149 let mut has_more_matches = false;
150
151 while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
152 if ranges.is_empty() {
153 continue;
154 }
155
156 buffer.read_with(cx, |buffer, cx| -> Result<(), anyhow::Error> {
157 if let Some(path) = buffer.file().map(|file| file.full_path(cx)) {
158 let mut file_header_written = false;
159 let mut ranges = ranges
160 .into_iter()
161 .map(|range| {
162 let mut point_range = range.to_point(buffer);
163 point_range.start.row =
164 point_range.start.row.saturating_sub(CONTEXT_LINES);
165 point_range.start.column = 0;
166 point_range.end.row = cmp::min(
167 buffer.max_point().row,
168 point_range.end.row + CONTEXT_LINES,
169 );
170 point_range.end.column = buffer.line_len(point_range.end.row);
171 point_range
172 })
173 .peekable();
174
175 while let Some(mut range) = ranges.next() {
176 if skips_remaining > 0 {
177 skips_remaining -= 1;
178 continue;
179 }
180
181 // We'd already found a full page of matches, and we just found one more.
182 if matches_found >= RESULTS_PER_PAGE {
183 has_more_matches = true;
184 return Ok(());
185 }
186
187 while let Some(next_range) = ranges.peek() {
188 if range.end.row >= next_range.start.row {
189 range.end = next_range.end;
190 ranges.next();
191 } else {
192 break;
193 }
194 }
195
196 if !file_header_written {
197 writeln!(output, "\n## Matches in {}", path.display())?;
198 file_header_written = true;
199 }
200
201 let start_line = range.start.row + 1;
202 let end_line = range.end.row + 1;
203 writeln!(output, "\n### Lines {start_line}-{end_line}\n```")?;
204 output.extend(buffer.text_for_range(range));
205 output.push_str("\n```\n");
206
207 matches_found += 1;
208 }
209 }
210
211 Ok(())
212 })??;
213 }
214
215 if matches_found == 0 {
216 Ok("No matches found".to_string())
217 } else if has_more_matches {
218 Ok(format!(
219 "Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}",
220 input.offset + 1,
221 input.offset + matches_found,
222 input.offset + RESULTS_PER_PAGE,
223 ))
224 } else {
225 Ok(format!("Found {matches_found} matches:\n{output}"))
226 }
227 }).into()
228 }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use assistant_tool::Tool;
    use gpui::{AppContext, TestAppContext};
    use project::{FakeFs, Project};
    use settings::SettingsStore;
    use util::path;

    /// The include glob restricts which files are searched; omitting it
    /// searches every file in the project.
    #[gpui::test]
    async fn test_grep_tool_with_include_pattern(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor().clone());
        // Insert the tree at the same `path!` root the project is opened with,
        // so both paths agree on every platform.
        fs.insert_tree(
            path!("/root"),
            serde_json::json!({
                "src": {
                    "main.rs": "fn main() {\n println!(\"Hello, world!\");\n}",
                    "utils": {
                        "helper.rs": "fn helper() {\n println!(\"I'm a helper!\");\n}",
                    },
                },
                "tests": {
                    "test_main.rs": "fn test_main() {\n assert!(true);\n}",
                }
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        // Test with include pattern for Rust files inside the root of the project
        let input = serde_json::to_value(GrepToolInput {
            regex: "println".to_string(),
            include_pattern: Some("root/**/*.rs".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(result.contains("main.rs"), "Should find matches in main.rs");
        assert!(
            result.contains("helper.rs"),
            "Should find matches in helper.rs"
        );
        assert!(
            !result.contains("test_main.rs"),
            "Should not include test_main.rs even though it's a .rs file (because it doesn't have the pattern)"
        );

        // Test with include pattern for src directory only
        let input = serde_json::to_value(GrepToolInput {
            regex: "fn".to_string(),
            include_pattern: Some("root/**/src/**".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            result.contains("main.rs"),
            "Should find matches in src/main.rs"
        );
        assert!(
            result.contains("helper.rs"),
            "Should find matches in src/utils/helper.rs"
        );
        assert!(
            !result.contains("test_main.rs"),
            "Should not include test_main.rs as it's not in src directory"
        );

        // Test with empty include pattern (should default to all files)
        let input = serde_json::to_value(GrepToolInput {
            regex: "fn".to_string(),
            include_pattern: None,
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(result.contains("main.rs"), "Should find matches in main.rs");
        assert!(
            result.contains("helper.rs"),
            "Should find matches in helper.rs"
        );
        assert!(
            result.contains("test_main.rs"),
            "Should include test_main.rs"
        );
    }

    /// `case_sensitive` switches between case-insensitive (the default) and
    /// exact-case matching.
    #[gpui::test]
    async fn test_grep_tool_with_case_sensitivity(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor().clone());
        // Insert the tree at the same `path!` root the project is opened with,
        // so both paths agree on every platform.
        fs.insert_tree(
            path!("/root"),
            serde_json::json!({
                "case_test.txt": "This file has UPPERCASE and lowercase text.\nUPPERCASE patterns should match only with case_sensitive: true",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        // Test case-insensitive search (default)
        let input = serde_json::to_value(GrepToolInput {
            regex: "uppercase".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: false,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            result.contains("UPPERCASE"),
            "Case-insensitive search should match uppercase"
        );

        // Test case-sensitive search
        let input = serde_json::to_value(GrepToolInput {
            regex: "uppercase".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: true,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            !result.contains("UPPERCASE"),
            "Case-sensitive search should not match uppercase"
        );

        // Test case-sensitive search with an all-caps pattern that only occurs
        // in lowercase in the file: nothing should match.
        let input = serde_json::to_value(GrepToolInput {
            regex: "LOWERCASE".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: true,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;

        assert!(
            !result.contains("lowercase"),
            "Case-sensitive search for LOWERCASE should not match lowercase text"
        );

        // Test case-sensitive search for lowercase pattern
        let input = serde_json::to_value(GrepToolInput {
            regex: "lowercase".to_string(),
            include_pattern: Some("**/*.txt".to_string()),
            offset: 0,
            case_sensitive: true,
        })
        .unwrap();

        let result = run_grep_tool(input, project.clone(), cx).await;
        assert!(
            result.contains("lowercase"),
            "Case-sensitive search should match lowercase text"
        );
    }

    /// Runs `GrepTool` against `project` with the given JSON `input` and
    /// returns its textual output, panicking if the tool reports an error.
    async fn run_grep_tool(
        input: serde_json::Value,
        project: Entity<Project>,
        cx: &mut TestAppContext,
    ) -> String {
        let tool = Arc::new(GrepTool);
        let action_log = cx.new(|_cx| ActionLog::new(project.clone()));
        let task = cx.update(|cx| tool.run(input, &[], project, action_log, None, cx));

        match task.output.await {
            Ok(result) => result,
            Err(e) => panic!("Failed to run grep tool: {}", e),
        }
    }

    /// Installs the globals these tests rely on: a test settings store, the
    /// language subsystem, and the project settings.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });
    }
}