1use anyhow::{anyhow, Result};
2use assistant_tool::{ActionLog, Tool};
3use futures::StreamExt;
4use gpui::{App, Entity, Task};
5use language::OffsetRangeExt;
6use language_model::LanguageModelRequestMessage;
7use project::{
8 search::{SearchQuery, SearchResult},
9 Project,
10};
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize};
13use std::{cmp, fmt::Write, sync::Arc};
14use ui::IconName;
15use util::markdown::MarkdownString;
16use util::paths::PathMatcher;
17
// Input accepted by the regex search tool, deserialized from the JSON the
// model sends (see `RegexSearchTool::run` / `ui_text`).
// NOTE(review): the `JsonSchema` derive presumably turns the `///` field docs
// below into schema descriptions (the schema is produced in `input_schema`),
// so they are model-facing text — change them deliberately.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RegexSearchToolInput {
    /// A regex pattern to search for in the entire project. Note that the regex
    /// will be parsed by the Rust `regex` crate.
    pub regex: String,

    /// Optional starting position for paginated results (0-based).
    /// When not provided, starts from the beginning.
    // A missing `offset` deserializes to `None`; `run` treats that as 0.
    #[serde(default)]
    pub offset: Option<u32>,
}
29
30impl RegexSearchToolInput {
31 /// Which page of search results this is.
32 pub fn page(&self) -> u32 {
33 1 + (self.offset.unwrap_or(0) / RESULTS_PER_PAGE)
34 }
35}
36
/// Maximum number of matches reported in one page of tool output; also used by
/// `RegexSearchToolInput::page` to map an offset to a page number.
const RESULTS_PER_PAGE: u32 = 20;

/// Assistant tool that searches the files of the current project for a regex
/// and returns the matching lines (with surrounding context), one page at a time.
pub struct RegexSearchTool;
40
41impl Tool for RegexSearchTool {
42 fn name(&self) -> String {
43 "regex-search".into()
44 }
45
46 fn needs_confirmation(&self) -> bool {
47 false
48 }
49
50 fn description(&self) -> String {
51 include_str!("./regex_search_tool/description.md").into()
52 }
53
54 fn icon(&self) -> IconName {
55 IconName::Regex
56 }
57
58 fn input_schema(&self) -> serde_json::Value {
59 let schema = schemars::schema_for!(RegexSearchToolInput);
60 serde_json::to_value(&schema).unwrap()
61 }
62
63 fn ui_text(&self, input: &serde_json::Value) -> String {
64 match serde_json::from_value::<RegexSearchToolInput>(input.clone()) {
65 Ok(input) => {
66 let page = input.page();
67 let regex = MarkdownString::escape(&input.regex);
68
69 if page > 1 {
70 format!("Get page {page} of search results for regex “`{regex}`”")
71 } else {
72 format!("Search files for regex “`{regex}`”")
73 }
74 }
75 Err(_) => "Search with regex".to_string(),
76 }
77 }
78
79 fn run(
80 self: Arc<Self>,
81 input: serde_json::Value,
82 _messages: &[LanguageModelRequestMessage],
83 project: Entity<Project>,
84 _action_log: Entity<ActionLog>,
85 cx: &mut App,
86 ) -> Task<Result<String>> {
87 const CONTEXT_LINES: u32 = 2;
88
89 let (offset, regex) = match serde_json::from_value::<RegexSearchToolInput>(input) {
90 Ok(input) => (input.offset.unwrap_or(0), input.regex),
91 Err(err) => return Task::ready(Err(anyhow!(err))),
92 };
93
94 let query = match SearchQuery::regex(
95 ®ex,
96 false,
97 false,
98 false,
99 PathMatcher::default(),
100 PathMatcher::default(),
101 None,
102 ) {
103 Ok(query) => query,
104 Err(error) => return Task::ready(Err(error)),
105 };
106
107 let results = project.update(cx, |project, cx| project.search(query, cx));
108
109 cx.spawn(async move|cx| {
110 futures::pin_mut!(results);
111
112 let mut output = String::new();
113 let mut skips_remaining = offset;
114 let mut matches_found = 0;
115 let mut has_more_matches = false;
116
117 while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
118 if ranges.is_empty() {
119 continue;
120 }
121
122 buffer.read_with(cx, |buffer, cx| -> Result<(), anyhow::Error> {
123 if let Some(path) = buffer.file().map(|file| file.full_path(cx)) {
124 let mut file_header_written = false;
125 let mut ranges = ranges
126 .into_iter()
127 .map(|range| {
128 let mut point_range = range.to_point(buffer);
129 point_range.start.row =
130 point_range.start.row.saturating_sub(CONTEXT_LINES);
131 point_range.start.column = 0;
132 point_range.end.row = cmp::min(
133 buffer.max_point().row,
134 point_range.end.row + CONTEXT_LINES,
135 );
136 point_range.end.column = buffer.line_len(point_range.end.row);
137 point_range
138 })
139 .peekable();
140
141 while let Some(mut range) = ranges.next() {
142 if skips_remaining > 0 {
143 skips_remaining -= 1;
144 continue;
145 }
146
147 // We'd already found a full page of matches, and we just found one more.
148 if matches_found >= RESULTS_PER_PAGE {
149 has_more_matches = true;
150 return Ok(());
151 }
152
153 while let Some(next_range) = ranges.peek() {
154 if range.end.row >= next_range.start.row {
155 range.end = next_range.end;
156 ranges.next();
157 } else {
158 break;
159 }
160 }
161
162 if !file_header_written {
163 writeln!(output, "\n## Matches in {}", path.display())?;
164 file_header_written = true;
165 }
166
167 let start_line = range.start.row + 1;
168 let end_line = range.end.row + 1;
169 writeln!(output, "\n### Lines {start_line}-{end_line}\n```")?;
170 output.extend(buffer.text_for_range(range));
171 output.push_str("\n```\n");
172
173 matches_found += 1;
174 }
175 }
176
177 Ok(())
178 })??;
179 }
180
181 if matches_found == 0 {
182 Ok("No matches found".to_string())
183 } else if has_more_matches {
184 Ok(format!(
185 "Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}",
186 offset + 1,
187 offset + matches_found,
188 offset + RESULTS_PER_PAGE,
189 ))
190 } else {
191 Ok(format!("Found {matches_found} matches:\n{output}"))
192 }
193 })
194 }
195}