1use anyhow::{anyhow, Result};
2use assistant_tool::{ActionLog, Tool};
3use futures::StreamExt;
4use gpui::{App, Entity, Task};
5use language::OffsetRangeExt;
6use language_model::LanguageModelRequestMessage;
7use project::{
8 search::{SearchQuery, SearchResult},
9 Project,
10};
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize};
13use std::{cmp, fmt::Write, sync::Arc};
14use ui::IconName;
15use util::paths::PathMatcher;
16
// Input payload for the regex search tool: the pattern to search for plus an
// optional pagination offset. It is deserialized from the tool's JSON input.
//
// NOTE(review): because this type derives `JsonSchema`, the `///` comments on
// the fields below are presumably emitted as descriptions in the generated
// schema, so treat their wording as runtime-visible text — confirm before
// rewording them.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RegexSearchToolInput {
    /// A regex pattern to search for in the entire project. Note that the regex
    /// will be parsed by the Rust `regex` crate.
    pub regex: String,

    /// Optional starting position for paginated results (0-based).
    /// When not provided, starts from the beginning.
    // `#[serde(default)]` lets callers omit the field entirely; it then
    // deserializes as `None`.
    #[serde(default)]
    pub offset: Option<u32>,
}
28
29impl RegexSearchToolInput {
30 /// Which page of search results this is.
31 pub fn page(&self) -> u32 {
32 1 + (self.offset.unwrap_or(0) / RESULTS_PER_PAGE)
33 }
34}
35
/// Maximum number of match blocks reported by a single search invocation.
/// Also used as the page width when deriving a page number from an offset.
const RESULTS_PER_PAGE: u32 = 20;
37
/// Stateless tool that runs a regex search over a project and returns the
/// matches, with surrounding context lines, one page at a time.
pub struct RegexSearchTool;
39
40impl Tool for RegexSearchTool {
41 fn name(&self) -> String {
42 "regex-search".into()
43 }
44
45 fn needs_confirmation(&self) -> bool {
46 false
47 }
48
49 fn description(&self) -> String {
50 include_str!("./regex_search_tool/description.md").into()
51 }
52
53 fn icon(&self) -> IconName {
54 IconName::Regex
55 }
56
57 fn input_schema(&self) -> serde_json::Value {
58 let schema = schemars::schema_for!(RegexSearchToolInput);
59 serde_json::to_value(&schema).unwrap()
60 }
61
62 fn ui_text(&self, input: &serde_json::Value) -> String {
63 match serde_json::from_value::<RegexSearchToolInput>(input.clone()) {
64 Ok(input) => {
65 let page = input.page();
66
67 if page > 1 {
68 format!(
69 "Get page {page} of search results for regex “`{}`”",
70 input.regex
71 )
72 } else {
73 format!("Search files for regex “`{}`”", input.regex)
74 }
75 }
76 Err(_) => "Search with regex".to_string(),
77 }
78 }
79
80 fn run(
81 self: Arc<Self>,
82 input: serde_json::Value,
83 _messages: &[LanguageModelRequestMessage],
84 project: Entity<Project>,
85 _action_log: Entity<ActionLog>,
86 cx: &mut App,
87 ) -> Task<Result<String>> {
88 const CONTEXT_LINES: u32 = 2;
89
90 let (offset, regex) = match serde_json::from_value::<RegexSearchToolInput>(input) {
91 Ok(input) => (input.offset.unwrap_or(0), input.regex),
92 Err(err) => return Task::ready(Err(anyhow!(err))),
93 };
94
95 let query = match SearchQuery::regex(
96 ®ex,
97 false,
98 false,
99 false,
100 PathMatcher::default(),
101 PathMatcher::default(),
102 None,
103 ) {
104 Ok(query) => query,
105 Err(error) => return Task::ready(Err(error)),
106 };
107
108 let results = project.update(cx, |project, cx| project.search(query, cx));
109
110 cx.spawn(async move|cx| {
111 futures::pin_mut!(results);
112
113 let mut output = String::new();
114 let mut skips_remaining = offset;
115 let mut matches_found = 0;
116 let mut has_more_matches = false;
117
118 while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
119 if ranges.is_empty() {
120 continue;
121 }
122
123 buffer.read_with(cx, |buffer, cx| -> Result<(), anyhow::Error> {
124 if let Some(path) = buffer.file().map(|file| file.full_path(cx)) {
125 let mut file_header_written = false;
126 let mut ranges = ranges
127 .into_iter()
128 .map(|range| {
129 let mut point_range = range.to_point(buffer);
130 point_range.start.row =
131 point_range.start.row.saturating_sub(CONTEXT_LINES);
132 point_range.start.column = 0;
133 point_range.end.row = cmp::min(
134 buffer.max_point().row,
135 point_range.end.row + CONTEXT_LINES,
136 );
137 point_range.end.column = buffer.line_len(point_range.end.row);
138 point_range
139 })
140 .peekable();
141
142 while let Some(mut range) = ranges.next() {
143 if skips_remaining > 0 {
144 skips_remaining -= 1;
145 continue;
146 }
147
148 // We'd already found a full page of matches, and we just found one more.
149 if matches_found >= RESULTS_PER_PAGE {
150 has_more_matches = true;
151 return Ok(());
152 }
153
154 while let Some(next_range) = ranges.peek() {
155 if range.end.row >= next_range.start.row {
156 range.end = next_range.end;
157 ranges.next();
158 } else {
159 break;
160 }
161 }
162
163 if !file_header_written {
164 writeln!(output, "\n## Matches in {}", path.display())?;
165 file_header_written = true;
166 }
167
168 let start_line = range.start.row + 1;
169 let end_line = range.end.row + 1;
170 writeln!(output, "\n### Lines {start_line}-{end_line}\n```")?;
171 output.extend(buffer.text_for_range(range));
172 output.push_str("\n```\n");
173
174 matches_found += 1;
175 }
176 }
177
178 Ok(())
179 })??;
180 }
181
182 if matches_found == 0 {
183 Ok("No matches found".to_string())
184 } else if has_more_matches {
185 Ok(format!(
186 "Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}",
187 offset + 1,
188 offset + matches_found,
189 offset + RESULTS_PER_PAGE,
190 ))
191 } else {
192 Ok(format!("Found {matches_found} matches:\n{output}"))
193 }
194 })
195 }
196}