regex_search_tool.rs

  1use crate::schema::json_schema_for;
  2use anyhow::{Result, anyhow};
  3use assistant_tool::{ActionLog, Tool, ToolResult};
  4use futures::StreamExt;
  5use gpui::{App, Entity, Task};
  6use language::OffsetRangeExt;
  7use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
  8use project::{
  9    Project,
 10    search::{SearchQuery, SearchResult},
 11};
 12use schemars::JsonSchema;
 13use serde::{Deserialize, Serialize};
 14use std::{cmp, fmt::Write, sync::Arc};
 15use ui::IconName;
 16use util::markdown::MarkdownString;
 17use util::paths::PathMatcher;
 18
// Arguments accepted by the `regex_search` tool. Values of this type are
// deserialized from the tool's JSON input (in `ui_text` and `run`), and the
// derived `JsonSchema` is what `input_schema` passes to `json_schema_for`.
//
// NOTE(review): the `///` comments on the fields presumably end up as
// descriptions in the generated schema — confirm before rewording them.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RegexSearchToolInput {
    /// A regex pattern to search for in the entire project. Note that the regex
    /// will be parsed by the Rust `regex` crate.
    // Required: there is no `#[serde(default)]`, so input without it fails to parse.
    pub regex: String,

    /// Optional starting position for paginated results (0-based).
    /// When not provided, starts from the beginning.
    // `#[serde(default)]` makes a missing `offset` deserialize as 0.
    #[serde(default)]
    pub offset: u32,

    /// Whether the regex is case-sensitive. Defaults to false (case-insensitive).
    // `#[serde(default)]` makes a missing flag deserialize as `false`.
    #[serde(default)]
    pub case_sensitive: bool,
}
 34
 35impl RegexSearchToolInput {
 36    /// Which page of search results this is.
 37    pub fn page(&self) -> u32 {
 38        1 + (self.offset / RESULTS_PER_PAGE)
 39    }
 40}
 41
/// Maximum number of match blocks reported by a single `regex_search` call.
/// Also the divisor `RegexSearchToolInput::page` uses to turn an offset into a
/// page number.
const RESULTS_PER_PAGE: u32 = 20;
 43
/// Assistant tool that runs a regex search over the project and reports the
/// matches, with surrounding context lines, one page at a time.
pub struct RegexSearchTool;
 45
 46impl Tool for RegexSearchTool {
 47    fn name(&self) -> String {
 48        "regex_search".into()
 49    }
 50
 51    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 52        false
 53    }
 54
 55    fn description(&self) -> String {
 56        include_str!("./regex_search_tool/description.md").into()
 57    }
 58
 59    fn icon(&self) -> IconName {
 60        IconName::Regex
 61    }
 62
 63    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 64        json_schema_for::<RegexSearchToolInput>(format)
 65    }
 66
 67    fn ui_text(&self, input: &serde_json::Value) -> String {
 68        match serde_json::from_value::<RegexSearchToolInput>(input.clone()) {
 69            Ok(input) => {
 70                let page = input.page();
 71                let regex_str = MarkdownString::inline_code(&input.regex);
 72                let case_info = if input.case_sensitive {
 73                    " (case-sensitive)"
 74                } else {
 75                    ""
 76                };
 77
 78                if page > 1 {
 79                    format!("Get page {page} of search results for regex {regex_str}{case_info}")
 80                } else {
 81                    format!("Search files for regex {regex_str}{case_info}")
 82                }
 83            }
 84            Err(_) => "Search with regex".to_string(),
 85        }
 86    }
 87
 88    fn run(
 89        self: Arc<Self>,
 90        input: serde_json::Value,
 91        _messages: &[LanguageModelRequestMessage],
 92        project: Entity<Project>,
 93        _action_log: Entity<ActionLog>,
 94        cx: &mut App,
 95    ) -> ToolResult {
 96        const CONTEXT_LINES: u32 = 2;
 97
 98        let (offset, regex, case_sensitive) =
 99            match serde_json::from_value::<RegexSearchToolInput>(input) {
100                Ok(input) => (input.offset, input.regex, input.case_sensitive),
101                Err(err) => return Task::ready(Err(anyhow!(err))).into(),
102            };
103
104        let query = match SearchQuery::regex(
105            &regex,
106            false,
107            case_sensitive,
108            false,
109            false,
110            PathMatcher::default(),
111            PathMatcher::default(),
112            None,
113        ) {
114            Ok(query) => query,
115            Err(error) => return Task::ready(Err(error)).into(),
116        };
117
118        let results = project.update(cx, |project, cx| project.search(query, cx));
119
120        cx.spawn(async move|cx|  {
121            futures::pin_mut!(results);
122
123            let mut output = String::new();
124            let mut skips_remaining = offset;
125            let mut matches_found = 0;
126            let mut has_more_matches = false;
127
128            while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
129                if ranges.is_empty() {
130                    continue;
131                }
132
133                buffer.read_with(cx, |buffer, cx| -> Result<(), anyhow::Error> {
134                    if let Some(path) = buffer.file().map(|file| file.full_path(cx)) {
135                        let mut file_header_written = false;
136                        let mut ranges = ranges
137                            .into_iter()
138                            .map(|range| {
139                                let mut point_range = range.to_point(buffer);
140                                point_range.start.row =
141                                    point_range.start.row.saturating_sub(CONTEXT_LINES);
142                                point_range.start.column = 0;
143                                point_range.end.row = cmp::min(
144                                    buffer.max_point().row,
145                                    point_range.end.row + CONTEXT_LINES,
146                                );
147                                point_range.end.column = buffer.line_len(point_range.end.row);
148                                point_range
149                            })
150                            .peekable();
151
152                        while let Some(mut range) = ranges.next() {
153                            if skips_remaining > 0 {
154                                skips_remaining -= 1;
155                                continue;
156                            }
157
158                            // We'd already found a full page of matches, and we just found one more.
159                            if matches_found >= RESULTS_PER_PAGE {
160                                has_more_matches = true;
161                                return Ok(());
162                            }
163
164                            while let Some(next_range) = ranges.peek() {
165                                if range.end.row >= next_range.start.row {
166                                    range.end = next_range.end;
167                                    ranges.next();
168                                } else {
169                                    break;
170                                }
171                            }
172
173                            if !file_header_written {
174                                writeln!(output, "\n## Matches in {}", path.display())?;
175                                file_header_written = true;
176                            }
177
178                            let start_line = range.start.row + 1;
179                            let end_line = range.end.row + 1;
180                            writeln!(output, "\n### Lines {start_line}-{end_line}\n```")?;
181                            output.extend(buffer.text_for_range(range));
182                            output.push_str("\n```\n");
183
184                            matches_found += 1;
185                        }
186                    }
187
188                    Ok(())
189                })??;
190            }
191
192            if matches_found == 0 {
193                Ok("No matches found".to_string())
194            } else if has_more_matches {
195                Ok(format!(
196                    "Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}",
197                    offset + 1,
198                    offset + matches_found,
199                    offset + RESULTS_PER_PAGE,
200                ))
201            } else {
202                Ok(format!("Found {matches_found} matches:\n{output}"))
203            }
204        }).into()
205    }
206}