1use anyhow::{anyhow, Result};
2use assistant_tool::{ActionLog, Tool};
3use futures::StreamExt;
4use gpui::{App, Entity, Task};
5use language::OffsetRangeExt;
6use language_model::LanguageModelRequestMessage;
7use project::{
8 search::{SearchQuery, SearchResult},
9 Project,
10};
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize};
13use std::{cmp, fmt::Write, sync::Arc};
14use util::paths::PathMatcher;
15
// Arguments for the regex search tool. Both `Tool::run` and `Tool::ui_text`
// deserialize this type from the incoming `serde_json::Value`.
//
// NOTE(review): this type derives `JsonSchema` and is passed to
// `schemars::schema_for!` in `input_schema`. schemars is documented to copy
// `///` comments into the generated schema's descriptions, so the field docs
// below are presumably part of the emitted schema - confirm before rewording
// them. These notes are plain `//` comments for that reason.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RegexSearchToolInput {
    /// A regex pattern to search for in the entire project. Note that the regex
    /// will be parsed by the Rust `regex` crate.
    pub regex: String,

    /// Optional starting position for paginated results (0-based).
    /// When not provided, starts from the beginning.
    #[serde(default)]
    // Callers in this file treat a missing offset as 0 (`unwrap_or(0)`).
    pub offset: Option<u32>,
}
27
28impl RegexSearchToolInput {
29 /// Which page of search results this is.
30 pub fn page(&self) -> u32 {
31 1 + (self.offset.unwrap_or(0) / RESULTS_PER_PAGE)
32 }
33}
34
/// Number of result excerpts returned per page of search output; also the
/// divisor used to turn an offset into a page number.
const RESULTS_PER_PAGE: u32 = 20;
36
/// Stateless tool that runs a regex search over a project and returns the
/// matching excerpts as text (see its `Tool` implementation).
pub struct RegexSearchTool;
38
39impl Tool for RegexSearchTool {
40 fn name(&self) -> String {
41 "regex-search".into()
42 }
43
44 fn description(&self) -> String {
45 include_str!("./regex_search_tool/description.md").into()
46 }
47
48 fn input_schema(&self) -> serde_json::Value {
49 let schema = schemars::schema_for!(RegexSearchToolInput);
50 serde_json::to_value(&schema).unwrap()
51 }
52
53 fn ui_text(&self, input: &serde_json::Value) -> String {
54 match serde_json::from_value::<RegexSearchToolInput>(input.clone()) {
55 Ok(input) => {
56 let page = input.page();
57
58 if page > 1 {
59 format!(
60 "Get page {page} of search results for regex “`{}`”",
61 input.regex
62 )
63 } else {
64 format!("Search files for regex “`{}`”", input.regex)
65 }
66 }
67 Err(_) => "Search with regex".to_string(),
68 }
69 }
70
71 fn run(
72 self: Arc<Self>,
73 input: serde_json::Value,
74 _messages: &[LanguageModelRequestMessage],
75 project: Entity<Project>,
76 _action_log: Entity<ActionLog>,
77 cx: &mut App,
78 ) -> Task<Result<String>> {
79 const CONTEXT_LINES: u32 = 2;
80
81 let (offset, regex) = match serde_json::from_value::<RegexSearchToolInput>(input) {
82 Ok(input) => (input.offset.unwrap_or(0), input.regex),
83 Err(err) => return Task::ready(Err(anyhow!(err))),
84 };
85
86 let query = match SearchQuery::regex(
87 ®ex,
88 false,
89 false,
90 false,
91 PathMatcher::default(),
92 PathMatcher::default(),
93 None,
94 ) {
95 Ok(query) => query,
96 Err(error) => return Task::ready(Err(error)),
97 };
98
99 let results = project.update(cx, |project, cx| project.search(query, cx));
100
101 cx.spawn(async move|cx| {
102 futures::pin_mut!(results);
103
104 let mut output = String::new();
105 let mut skips_remaining = offset;
106 let mut matches_found = 0;
107 let mut has_more_matches = false;
108
109 while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
110 if ranges.is_empty() {
111 continue;
112 }
113
114 buffer.read_with(cx, |buffer, cx| -> Result<(), anyhow::Error> {
115 if let Some(path) = buffer.file().map(|file| file.full_path(cx)) {
116 let mut file_header_written = false;
117 let mut ranges = ranges
118 .into_iter()
119 .map(|range| {
120 let mut point_range = range.to_point(buffer);
121 point_range.start.row =
122 point_range.start.row.saturating_sub(CONTEXT_LINES);
123 point_range.start.column = 0;
124 point_range.end.row = cmp::min(
125 buffer.max_point().row,
126 point_range.end.row + CONTEXT_LINES,
127 );
128 point_range.end.column = buffer.line_len(point_range.end.row);
129 point_range
130 })
131 .peekable();
132
133 while let Some(mut range) = ranges.next() {
134 if skips_remaining > 0 {
135 skips_remaining -= 1;
136 continue;
137 }
138
139 // We'd already found a full page of matches, and we just found one more.
140 if matches_found >= RESULTS_PER_PAGE {
141 has_more_matches = true;
142 return Ok(());
143 }
144
145 while let Some(next_range) = ranges.peek() {
146 if range.end.row >= next_range.start.row {
147 range.end = next_range.end;
148 ranges.next();
149 } else {
150 break;
151 }
152 }
153
154 if !file_header_written {
155 writeln!(output, "\n## Matches in {}", path.display())?;
156 file_header_written = true;
157 }
158
159 let start_line = range.start.row + 1;
160 let end_line = range.end.row + 1;
161 writeln!(output, "\n### Lines {start_line}-{end_line}\n```")?;
162 output.extend(buffer.text_for_range(range));
163 output.push_str("\n```\n");
164
165 matches_found += 1;
166 }
167 }
168
169 Ok(())
170 })??;
171 }
172
173 if matches_found == 0 {
174 Ok("No matches found".to_string())
175 } else if has_more_matches {
176 Ok(format!(
177 "Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}",
178 offset + 1,
179 offset + matches_found,
180 offset + RESULTS_PER_PAGE,
181 ))
182 } else {
183 Ok(format!("Found {matches_found} matches:\n{output}"))
184 }
185 })
186 }
187}