1use anyhow::{anyhow, Result};
2use assistant_tool::Tool;
3use futures::StreamExt;
4use gpui::{App, Entity, Task};
5use language::OffsetRangeExt;
6use language_model::LanguageModelRequestMessage;
7use project::{search::SearchQuery, Project};
8use schemars::JsonSchema;
9use serde::{Deserialize, Serialize};
10use std::{cmp, fmt::Write, sync::Arc};
11use util::paths::PathMatcher;
12
13#[derive(Debug, Serialize, Deserialize, JsonSchema)]
14pub struct RegexSearchToolInput {
15 /// A regex pattern to search for in the entire project. Note that the regex
16 /// will be parsed by the Rust `regex` crate.
17 pub regex: String,
18}
19
/// Assistant [`Tool`] that runs a project-wide regex search (exposed to the
/// model under the name `regex-search`) and returns the matching lines with
/// surrounding context.
pub struct RegexSearchTool;
21
22impl Tool for RegexSearchTool {
23 fn name(&self) -> String {
24 "regex-search".into()
25 }
26
27 fn description(&self) -> String {
28 include_str!("./regex_search_tool/description.md").into()
29 }
30
31 fn input_schema(&self) -> serde_json::Value {
32 let schema = schemars::schema_for!(RegexSearchToolInput);
33 serde_json::to_value(&schema).unwrap()
34 }
35
36 fn run(
37 self: Arc<Self>,
38 input: serde_json::Value,
39 _messages: &[LanguageModelRequestMessage],
40 project: Entity<Project>,
41 cx: &mut App,
42 ) -> Task<Result<String>> {
43 const CONTEXT_LINES: u32 = 2;
44
45 let input = match serde_json::from_value::<RegexSearchToolInput>(input) {
46 Ok(input) => input,
47 Err(err) => return Task::ready(Err(anyhow!(err))),
48 };
49
50 let query = match SearchQuery::regex(
51 &input.regex,
52 false,
53 false,
54 false,
55 PathMatcher::default(),
56 PathMatcher::default(),
57 None,
58 ) {
59 Ok(query) => query,
60 Err(error) => return Task::ready(Err(error)),
61 };
62
63 let results = project.update(cx, |project, cx| project.search(query, cx));
64 cx.spawn(|cx| async move {
65 futures::pin_mut!(results);
66
67 let mut output = String::new();
68 while let Some(project::search::SearchResult::Buffer { buffer, ranges }) =
69 results.next().await
70 {
71 if ranges.is_empty() {
72 continue;
73 }
74
75 buffer.read_with(&cx, |buffer, cx| {
76 if let Some(path) = buffer.file().map(|file| file.full_path(cx)) {
77 writeln!(output, "### Found matches in {}:\n", path.display()).unwrap();
78 let mut ranges = ranges
79 .into_iter()
80 .map(|range| {
81 let mut point_range = range.to_point(buffer);
82 point_range.start.row =
83 point_range.start.row.saturating_sub(CONTEXT_LINES);
84 point_range.start.column = 0;
85 point_range.end.row = cmp::min(
86 buffer.max_point().row,
87 point_range.end.row + CONTEXT_LINES,
88 );
89 point_range.end.column = buffer.line_len(point_range.end.row);
90 point_range
91 })
92 .peekable();
93
94 while let Some(mut range) = ranges.next() {
95 while let Some(next_range) = ranges.peek() {
96 if range.end.row >= next_range.start.row {
97 range.end = next_range.end;
98 ranges.next();
99 } else {
100 break;
101 }
102 }
103
104 writeln!(output, "```").unwrap();
105 output.extend(buffer.text_for_range(range));
106 writeln!(output, "\n```\n").unwrap();
107 }
108 }
109 })?;
110 }
111
112 if output.is_empty() {
113 Ok("No matches found".into())
114 } else {
115 Ok(output)
116 }
117 })
118 }
119}