use crate::schema::json_schema_for;
use anyhow::{Result, anyhow};
use assistant_tool::{ActionLog, Tool, ToolResult};
use futures::StreamExt;
use gpui::{App, Entity, Task};
use language::OffsetRangeExt;
use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
use project::{
    Project,
    search::{SearchQuery, SearchResult},
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::{cmp, fmt::Write, sync::Arc};
use ui::IconName;
use util::markdown::MarkdownString;
use util::paths::PathMatcher;

#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RegexSearchToolInput {
    /// A regex pattern to search for in the entire project. Note that the regex
    /// will be parsed by the Rust `regex` crate.
    pub regex: String,

    /// Optional starting position for paginated results (0-based).
    /// When not provided, starts from the beginning.
    #[serde(default)]
    pub offset: u32,

    /// Whether the regex is case-sensitive. Defaults to false (case-insensitive).
    #[serde(default)]
    pub case_sensitive: bool,
}

impl RegexSearchToolInput {
    /// Which page of search results this is.
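    ///
    /// For example, with 20 results per page, an `offset` of 40 corresponds to page 3.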
    pub fn page(&self) -> u32 {
        1 + (self.offset / RESULTS_PER_PAGE)
    }
}

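/// The maximum number of search matches included in a single page of results.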
const RESULTS_PER_PAGE: u32 = 20;

pub struct RegexSearchTool;

impl Tool for RegexSearchTool {
    fn name(&self) -> String {
        "regex_search".into()
    }

    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
        false
    }

    fn description(&self) -> String {
        include_str!("./regex_search_tool/description.md").into()
    }

    fn icon(&self) -> IconName {
        IconName::Regex
    }

    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
        json_schema_for::<RegexSearchToolInput>(format)
    }

    fn ui_text(&self, input: &serde_json::Value) -> String {
        match serde_json::from_value::<RegexSearchToolInput>(input.clone()) {
            Ok(input) => {
                let page = input.page();
                let regex_str = MarkdownString::inline_code(&input.regex);
                let case_info = if input.case_sensitive {
                    " (case-sensitive)"
                } else {
                    ""
                };

                if page > 1 {
                    format!("Get page {page} of search results for regex {regex_str}{case_info}")
                } else {
                    format!("Search files for regex {regex_str}{case_info}")
                }
            }
            Err(_) => "Search with regex".to_string(),
        }
    }

    fn run(
        self: Arc<Self>,
        input: serde_json::Value,
        _messages: &[LanguageModelRequestMessage],
        project: Entity<Project>,
        _action_log: Entity<ActionLog>,
        cx: &mut App,
    ) -> ToolResult {
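        // Number of context lines included above and below each match in the output.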
        const CONTEXT_LINES: u32 = 2;

        let (offset, regex, case_sensitive) =
            match serde_json::from_value::<RegexSearchToolInput>(input) {
                Ok(input) => (input.offset, input.regex, input.case_sensitive),
                Err(err) => return Task::ready(Err(anyhow!(err))).into(),
            };

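        // Build the project-wide regex search query from the provided pattern.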
        let query = match SearchQuery::regex(
            &regex,
            false,
            case_sensitive,
            false,
            false,
            PathMatcher::default(),
            PathMatcher::default(),
            None,
        ) {
            Ok(query) => query,
            Err(error) => return Task::ready(Err(error)).into(),
        };

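        // Kick off the search; matching buffers and their ranges arrive as a stream of results.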
        let results = project.update(cx, |project, cx| project.search(query, cx));

        cx.spawn(async move |cx| {
            futures::pin_mut!(results);

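            // Pagination state: skip the first `offset` matches, then collect up to RESULTS_PER_PAGE.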
            let mut output = String::new();
            let mut skips_remaining = offset;
            let mut matches_found = 0;
            let mut has_more_matches = false;

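            // Process search results buffer by buffer, building up this page of output.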
            while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
                if ranges.is_empty() {
                    continue;
                }

                buffer.read_with(cx, |buffer, cx| -> Result<(), anyhow::Error> {
                    if let Some(path) = buffer.file().map(|file| file.full_path(cx)) {
                        let mut file_header_written = false;
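                        // Expand each match by CONTEXT_LINES rows above and below, snapping to whole lines.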
                        let mut ranges = ranges
                            .into_iter()
                            .map(|range| {
                                let mut point_range = range.to_point(buffer);
                                point_range.start.row =
                                    point_range.start.row.saturating_sub(CONTEXT_LINES);
                                point_range.start.column = 0;
                                point_range.end.row = cmp::min(
                                    buffer.max_point().row,
                                    point_range.end.row + CONTEXT_LINES,
                                );
                                point_range.end.column = buffer.line_len(point_range.end.row);
                                point_range
                            })
                            .peekable();

                        while let Some(mut range) = ranges.next() {
                            if skips_remaining > 0 {
                                skips_remaining -= 1;
                                continue;
                            }

                            // We'd already found a full page of matches, and we just found one more.
                            if matches_found >= RESULTS_PER_PAGE {
                                has_more_matches = true;
                                return Ok(());
                            }

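                            // Merge this range with any following ranges whose context overlaps it, so no lines repeat.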
                            while let Some(next_range) = ranges.peek() {
                                if range.end.row >= next_range.start.row {
                                    range.end = next_range.end;
                                    ranges.next();
                                } else {
                                    break;
                                }
                            }

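                            // Write the file path header once per buffer, the first time it contributes a match.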
                            if !file_header_written {
                                writeln!(output, "\n## Matches in {}", path.display())?;
                                file_header_written = true;
                            }

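                            // Emit the matched range as a fenced code block labeled with 1-based line numbers.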
                            let start_line = range.start.row + 1;
                            let end_line = range.end.row + 1;
                            writeln!(output, "\n### Lines {start_line}-{end_line}\n```")?;
                            output.extend(buffer.text_for_range(range));
                            output.push_str("\n```\n");

                            matches_found += 1;
                        }
                    }

                    Ok(())
                })??;
            }

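            // Report the outcome: no matches, a page with more remaining (and the next offset to use), or the final page.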
            if matches_found == 0 {
                Ok("No matches found".to_string())
            } else if has_more_matches {
                Ok(format!(
                    "Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}",
                    offset + 1,
                    offset + matches_found,
                    offset + RESULTS_PER_PAGE,
                ))
            } else {
                Ok(format!("Found {matches_found} matches:\n{output}"))
            }
        }).into()
    }
}