regex_search_tool.rs

  1use crate::schema::json_schema_for;
  2use anyhow::{Result, anyhow};
  3use assistant_tool::{ActionLog, Tool};
  4use futures::StreamExt;
  5use gpui::{App, Entity, Task};
  6use language::OffsetRangeExt;
  7use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
  8use project::{
  9    Project,
 10    search::{SearchQuery, SearchResult},
 11};
 12use schemars::JsonSchema;
 13use serde::{Deserialize, Serialize};
 14use std::{cmp, fmt::Write, sync::Arc};
 15use ui::IconName;
 16use util::markdown::MarkdownString;
 17use util::paths::PathMatcher;
 18
 19#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 20pub struct RegexSearchToolInput {
 21    /// A regex pattern to search for in the entire project. Note that the regex
 22    /// will be parsed by the Rust `regex` crate.
 23    pub regex: String,
 24
 25    /// Optional starting position for paginated results (0-based).
 26    /// When not provided, starts from the beginning.
 27    #[serde(default)]
 28    pub offset: u32,
 29
 30    /// Whether the regex is case-sensitive. Defaults to false (case-insensitive).
 31    #[serde(default)]
 32    pub case_sensitive: bool,
 33}
 34
 35impl RegexSearchToolInput {
 36    /// Which page of search results this is.
 37    pub fn page(&self) -> u32 {
 38        1 + (self.offset / RESULTS_PER_PAGE)
 39    }
 40}
 41
 42const RESULTS_PER_PAGE: u32 = 20;
 43
 44pub struct RegexSearchTool;
 45
 46impl Tool for RegexSearchTool {
 47    fn name(&self) -> String {
 48        "regex_search".into()
 49    }
 50
 51    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 52        false
 53    }
 54
 55    fn description(&self) -> String {
 56        include_str!("./regex_search_tool/description.md").into()
 57    }
 58
 59    fn icon(&self) -> IconName {
 60        IconName::Regex
 61    }
 62
 63    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
 64        json_schema_for::<RegexSearchToolInput>(format)
 65    }
 66
 67    fn ui_text(&self, input: &serde_json::Value) -> String {
 68        match serde_json::from_value::<RegexSearchToolInput>(input.clone()) {
 69            Ok(input) => {
 70                let page = input.page();
 71                let regex_str = MarkdownString::inline_code(&input.regex);
 72                let case_info = if input.case_sensitive {
 73                    " (case-sensitive)"
 74                } else {
 75                    ""
 76                };
 77
 78                if page > 1 {
 79                    format!("Get page {page} of search results for regex {regex_str}{case_info}")
 80                } else {
 81                    format!("Search files for regex {regex_str}{case_info}")
 82                }
 83            }
 84            Err(_) => "Search with regex".to_string(),
 85        }
 86    }
 87
 88    fn run(
 89        self: Arc<Self>,
 90        input: serde_json::Value,
 91        _messages: &[LanguageModelRequestMessage],
 92        project: Entity<Project>,
 93        _action_log: Entity<ActionLog>,
 94        cx: &mut App,
 95    ) -> Task<Result<String>> {
 96        const CONTEXT_LINES: u32 = 2;
 97
 98        let (offset, regex, case_sensitive) =
 99            match serde_json::from_value::<RegexSearchToolInput>(input) {
100                Ok(input) => (input.offset, input.regex, input.case_sensitive),
101                Err(err) => return Task::ready(Err(anyhow!(err))),
102            };
103
104        let query = match SearchQuery::regex(
105            &regex,
106            false,
107            case_sensitive,
108            false,
109            PathMatcher::default(),
110            PathMatcher::default(),
111            None,
112        ) {
113            Ok(query) => query,
114            Err(error) => return Task::ready(Err(error)),
115        };
116
117        let results = project.update(cx, |project, cx| project.search(query, cx));
118
119        cx.spawn(async move|cx|  {
120            futures::pin_mut!(results);
121
122            let mut output = String::new();
123            let mut skips_remaining = offset;
124            let mut matches_found = 0;
125            let mut has_more_matches = false;
126
127            while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
128                if ranges.is_empty() {
129                    continue;
130                }
131
132                buffer.read_with(cx, |buffer, cx| -> Result<(), anyhow::Error> {
133                    if let Some(path) = buffer.file().map(|file| file.full_path(cx)) {
134                        let mut file_header_written = false;
135                        let mut ranges = ranges
136                            .into_iter()
137                            .map(|range| {
138                                let mut point_range = range.to_point(buffer);
139                                point_range.start.row =
140                                    point_range.start.row.saturating_sub(CONTEXT_LINES);
141                                point_range.start.column = 0;
142                                point_range.end.row = cmp::min(
143                                    buffer.max_point().row,
144                                    point_range.end.row + CONTEXT_LINES,
145                                );
146                                point_range.end.column = buffer.line_len(point_range.end.row);
147                                point_range
148                            })
149                            .peekable();
150
151                        while let Some(mut range) = ranges.next() {
152                            if skips_remaining > 0 {
153                                skips_remaining -= 1;
154                                continue;
155                            }
156
157                            // We'd already found a full page of matches, and we just found one more.
158                            if matches_found >= RESULTS_PER_PAGE {
159                                has_more_matches = true;
160                                return Ok(());
161                            }
162
163                            while let Some(next_range) = ranges.peek() {
164                                if range.end.row >= next_range.start.row {
165                                    range.end = next_range.end;
166                                    ranges.next();
167                                } else {
168                                    break;
169                                }
170                            }
171
172                            if !file_header_written {
173                                writeln!(output, "\n## Matches in {}", path.display())?;
174                                file_header_written = true;
175                            }
176
177                            let start_line = range.start.row + 1;
178                            let end_line = range.end.row + 1;
179                            writeln!(output, "\n### Lines {start_line}-{end_line}\n```")?;
180                            output.extend(buffer.text_for_range(range));
181                            output.push_str("\n```\n");
182
183                            matches_found += 1;
184                        }
185                    }
186
187                    Ok(())
188                })??;
189            }
190
191            if matches_found == 0 {
192                Ok("No matches found".to_string())
193            } else if has_more_matches {
194                Ok(format!(
195                    "Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}",
196                    offset + 1,
197                    offset + matches_found,
198                    offset + RESULTS_PER_PAGE,
199                ))
200            } else {
201                Ok(format!("Found {matches_found} matches:\n{output}"))
202            }
203        })
204    }
205}