1use anyhow::{anyhow, Result};
2use assistant_tool::{ActionLog, Tool};
3use futures::StreamExt;
4use gpui::{App, Entity, Task};
5use language::OffsetRangeExt;
6use language_model::LanguageModelRequestMessage;
7use project::{
8 search::{SearchQuery, SearchResult},
9 Project,
10};
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize};
13use std::{cmp, fmt::Write, sync::Arc};
14use util::paths::PathMatcher;
15
16#[derive(Debug, Serialize, Deserialize, JsonSchema)]
17pub struct RegexSearchToolInput {
18 /// A regex pattern to search for in the entire project. Note that the regex
19 /// will be parsed by the Rust `regex` crate.
20 pub regex: String,
21
22 /// Optional starting position for paginated results (0-based).
23 /// When not provided, starts from the beginning.
24 #[serde(default)]
25 pub offset: Option<u32>,
26}
27
28impl RegexSearchToolInput {
29 /// Which page of search results this is.
30 pub fn page(&self) -> u32 {
31 1 + (self.offset.unwrap_or(0) / RESULTS_PER_PAGE)
32 }
33}
34
35const RESULTS_PER_PAGE: u32 = 20;
36
37pub struct RegexSearchTool;
38
39impl Tool for RegexSearchTool {
40 fn name(&self) -> String {
41 "regex-search".into()
42 }
43
44 fn needs_confirmation(&self) -> bool {
45 false
46 }
47
48 fn description(&self) -> String {
49 include_str!("./regex_search_tool/description.md").into()
50 }
51
52 fn input_schema(&self) -> serde_json::Value {
53 let schema = schemars::schema_for!(RegexSearchToolInput);
54 serde_json::to_value(&schema).unwrap()
55 }
56
57 fn ui_text(&self, input: &serde_json::Value) -> String {
58 match serde_json::from_value::<RegexSearchToolInput>(input.clone()) {
59 Ok(input) => {
60 let page = input.page();
61
62 if page > 1 {
63 format!(
64 "Get page {page} of search results for regex “`{}`”",
65 input.regex
66 )
67 } else {
68 format!("Search files for regex “`{}`”", input.regex)
69 }
70 }
71 Err(_) => "Search with regex".to_string(),
72 }
73 }
74
75 fn run(
76 self: Arc<Self>,
77 input: serde_json::Value,
78 _messages: &[LanguageModelRequestMessage],
79 project: Entity<Project>,
80 _action_log: Entity<ActionLog>,
81 cx: &mut App,
82 ) -> Task<Result<String>> {
83 const CONTEXT_LINES: u32 = 2;
84
85 let (offset, regex) = match serde_json::from_value::<RegexSearchToolInput>(input) {
86 Ok(input) => (input.offset.unwrap_or(0), input.regex),
87 Err(err) => return Task::ready(Err(anyhow!(err))),
88 };
89
90 let query = match SearchQuery::regex(
91 ®ex,
92 false,
93 false,
94 false,
95 PathMatcher::default(),
96 PathMatcher::default(),
97 None,
98 ) {
99 Ok(query) => query,
100 Err(error) => return Task::ready(Err(error)),
101 };
102
103 let results = project.update(cx, |project, cx| project.search(query, cx));
104
105 cx.spawn(async move|cx| {
106 futures::pin_mut!(results);
107
108 let mut output = String::new();
109 let mut skips_remaining = offset;
110 let mut matches_found = 0;
111 let mut has_more_matches = false;
112
113 while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
114 if ranges.is_empty() {
115 continue;
116 }
117
118 buffer.read_with(cx, |buffer, cx| -> Result<(), anyhow::Error> {
119 if let Some(path) = buffer.file().map(|file| file.full_path(cx)) {
120 let mut file_header_written = false;
121 let mut ranges = ranges
122 .into_iter()
123 .map(|range| {
124 let mut point_range = range.to_point(buffer);
125 point_range.start.row =
126 point_range.start.row.saturating_sub(CONTEXT_LINES);
127 point_range.start.column = 0;
128 point_range.end.row = cmp::min(
129 buffer.max_point().row,
130 point_range.end.row + CONTEXT_LINES,
131 );
132 point_range.end.column = buffer.line_len(point_range.end.row);
133 point_range
134 })
135 .peekable();
136
137 while let Some(mut range) = ranges.next() {
138 if skips_remaining > 0 {
139 skips_remaining -= 1;
140 continue;
141 }
142
143 // We'd already found a full page of matches, and we just found one more.
144 if matches_found >= RESULTS_PER_PAGE {
145 has_more_matches = true;
146 return Ok(());
147 }
148
149 while let Some(next_range) = ranges.peek() {
150 if range.end.row >= next_range.start.row {
151 range.end = next_range.end;
152 ranges.next();
153 } else {
154 break;
155 }
156 }
157
158 if !file_header_written {
159 writeln!(output, "\n## Matches in {}", path.display())?;
160 file_header_written = true;
161 }
162
163 let start_line = range.start.row + 1;
164 let end_line = range.end.row + 1;
165 writeln!(output, "\n### Lines {start_line}-{end_line}\n```")?;
166 output.extend(buffer.text_for_range(range));
167 output.push_str("\n```\n");
168
169 matches_found += 1;
170 }
171 }
172
173 Ok(())
174 })??;
175 }
176
177 if matches_found == 0 {
178 Ok("No matches found".to_string())
179 } else if has_more_matches {
180 Ok(format!(
181 "Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}",
182 offset + 1,
183 offset + matches_found,
184 offset + RESULTS_PER_PAGE,
185 ))
186 } else {
187 Ok(format!("Found {matches_found} matches:\n{output}"))
188 }
189 })
190 }
191}