regex_search_tool.rs

  1use crate::schema::json_schema_for;
  2use anyhow::{Result, anyhow};
  3use assistant_tool::{ActionLog, Tool};
  4use futures::StreamExt;
  5use gpui::{App, Entity, Task};
  6use language::OffsetRangeExt;
  7use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
  8use project::{
  9    Project,
 10    search::{SearchQuery, SearchResult},
 11};
 12use schemars::JsonSchema;
 13use serde::{Deserialize, Serialize};
 14use std::{cmp, fmt::Write, sync::Arc};
 15use ui::IconName;
 16use util::markdown::MarkdownString;
 17use util::paths::PathMatcher;
 18
// Arguments accepted by the regex search tool, deserialized from the JSON tool input.
//
// NOTE(review): the `///` comments on the fields below are visible to
// `#[derive(JsonSchema)]` and are presumably emitted as property descriptions in
// the schema returned by `input_schema`. Treat their wording as part of the
// tool's interface — confirm before rewording them.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RegexSearchToolInput {
    /// A regex pattern to search for in the entire project. Note that the regex
    /// will be parsed by the Rust `regex` crate.
    pub regex: String,

    /// Optional starting position for paginated results (0-based).
    /// When not provided, starts from the beginning.
    // `#[serde(default)]` lets the field be omitted; it then takes `u32::default()`, i.e. 0.
    #[serde(default)]
    pub offset: u32,
}
 30
 31impl RegexSearchToolInput {
 32    /// Which page of search results this is.
 33    pub fn page(&self) -> u32 {
 34        1 + (self.offset / RESULTS_PER_PAGE)
 35    }
 36}
 37
/// Maximum number of match blocks written into a single `RegexSearchTool::run`
/// response; also the page size `RegexSearchToolInput::page` divides the offset by.
const RESULTS_PER_PAGE: u32 = 20;
 39
/// Tool that runs a regex search over the project and reports the matching
/// excerpts one page at a time. It carries no state; all behavior lives in its
/// `Tool` implementation.
pub struct RegexSearchTool;
 41
 42impl Tool for RegexSearchTool {
 43    fn name(&self) -> String {
 44        "regex_search".into()
 45    }
 46
 47    fn needs_confirmation(&self) -> bool {
 48        false
 49    }
 50
 51    fn description(&self) -> String {
 52        include_str!("./regex_search_tool/description.md").into()
 53    }
 54
 55    fn icon(&self) -> IconName {
 56        IconName::Regex
 57    }
 58
 59    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> serde_json::Value {
 60        json_schema_for::<RegexSearchToolInput>(format)
 61    }
 62
 63    fn ui_text(&self, input: &serde_json::Value) -> String {
 64        match serde_json::from_value::<RegexSearchToolInput>(input.clone()) {
 65            Ok(input) => {
 66                let page = input.page();
 67                let regex = MarkdownString::inline_code(&input.regex);
 68
 69                if page > 1 {
 70                    format!("Get page {page} of search results for regex “{regex}")
 71                } else {
 72                    format!("Search files for regex “{regex}")
 73                }
 74            }
 75            Err(_) => "Search with regex".to_string(),
 76        }
 77    }
 78
 79    fn run(
 80        self: Arc<Self>,
 81        input: serde_json::Value,
 82        _messages: &[LanguageModelRequestMessage],
 83        project: Entity<Project>,
 84        _action_log: Entity<ActionLog>,
 85        cx: &mut App,
 86    ) -> Task<Result<String>> {
 87        const CONTEXT_LINES: u32 = 2;
 88
 89        let (offset, regex) = match serde_json::from_value::<RegexSearchToolInput>(input) {
 90            Ok(input) => (input.offset, input.regex),
 91            Err(err) => return Task::ready(Err(anyhow!(err))),
 92        };
 93
 94        let query = match SearchQuery::regex(
 95            &regex,
 96            false,
 97            false,
 98            false,
 99            PathMatcher::default(),
100            PathMatcher::default(),
101            None,
102        ) {
103            Ok(query) => query,
104            Err(error) => return Task::ready(Err(error)),
105        };
106
107        let results = project.update(cx, |project, cx| project.search(query, cx));
108
109        cx.spawn(async move|cx|  {
110            futures::pin_mut!(results);
111
112            let mut output = String::new();
113            let mut skips_remaining = offset;
114            let mut matches_found = 0;
115            let mut has_more_matches = false;
116
117            while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
118                if ranges.is_empty() {
119                    continue;
120                }
121
122                buffer.read_with(cx, |buffer, cx| -> Result<(), anyhow::Error> {
123                    if let Some(path) = buffer.file().map(|file| file.full_path(cx)) {
124                        let mut file_header_written = false;
125                        let mut ranges = ranges
126                            .into_iter()
127                            .map(|range| {
128                                let mut point_range = range.to_point(buffer);
129                                point_range.start.row =
130                                    point_range.start.row.saturating_sub(CONTEXT_LINES);
131                                point_range.start.column = 0;
132                                point_range.end.row = cmp::min(
133                                    buffer.max_point().row,
134                                    point_range.end.row + CONTEXT_LINES,
135                                );
136                                point_range.end.column = buffer.line_len(point_range.end.row);
137                                point_range
138                            })
139                            .peekable();
140
141                        while let Some(mut range) = ranges.next() {
142                            if skips_remaining > 0 {
143                                skips_remaining -= 1;
144                                continue;
145                            }
146
147                            // We'd already found a full page of matches, and we just found one more.
148                            if matches_found >= RESULTS_PER_PAGE {
149                                has_more_matches = true;
150                                return Ok(());
151                            }
152
153                            while let Some(next_range) = ranges.peek() {
154                                if range.end.row >= next_range.start.row {
155                                    range.end = next_range.end;
156                                    ranges.next();
157                                } else {
158                                    break;
159                                }
160                            }
161
162                            if !file_header_written {
163                                writeln!(output, "\n## Matches in {}", path.display())?;
164                                file_header_written = true;
165                            }
166
167                            let start_line = range.start.row + 1;
168                            let end_line = range.end.row + 1;
169                            writeln!(output, "\n### Lines {start_line}-{end_line}\n```")?;
170                            output.extend(buffer.text_for_range(range));
171                            output.push_str("\n```\n");
172
173                            matches_found += 1;
174                        }
175                    }
176
177                    Ok(())
178                })??;
179            }
180
181            if matches_found == 0 {
182                Ok("No matches found".to_string())
183            } else if has_more_matches {
184                Ok(format!(
185                    "Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}",
186                    offset + 1,
187                    offset + matches_found,
188                    offset + RESULTS_PER_PAGE,
189                ))
190            } else {
191                Ok(format!("Found {matches_found} matches:\n{output}"))
192            }
193        })
194    }
195}