regex_search_tool.rs

  1use anyhow::{anyhow, Result};
  2use assistant_tool::{ActionLog, Tool};
  3use futures::StreamExt;
  4use gpui::{App, Entity, Task};
  5use language::OffsetRangeExt;
  6use language_model::LanguageModelRequestMessage;
  7use project::{
  8    search::{SearchQuery, SearchResult},
  9    Project,
 10};
 11use schemars::JsonSchema;
 12use serde::{Deserialize, Serialize};
 13use std::{cmp, fmt::Write, sync::Arc};
 14use ui::IconName;
 15use util::paths::PathMatcher;
 16
 17#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 18pub struct RegexSearchToolInput {
 19    /// A regex pattern to search for in the entire project. Note that the regex
 20    /// will be parsed by the Rust `regex` crate.
 21    pub regex: String,
 22
 23    /// Optional starting position for paginated results (0-based).
 24    /// When not provided, starts from the beginning.
 25    #[serde(default)]
 26    pub offset: Option<u32>,
 27}
 28
 29impl RegexSearchToolInput {
 30    /// Which page of search results this is.
 31    pub fn page(&self) -> u32 {
 32        1 + (self.offset.unwrap_or(0) / RESULTS_PER_PAGE)
 33    }
 34}
 35
/// Maximum number of matches reported in a single tool response; also the divisor
/// `RegexSearchToolInput::page` uses to turn an offset into a page number.
const RESULTS_PER_PAGE: u32 = 20;
 37
/// Stateless tool that runs a regex search over the project and reports the matches,
/// one page at a time, as text excerpts grouped by file.
pub struct RegexSearchTool;
 39
 40impl Tool for RegexSearchTool {
 41    fn name(&self) -> String {
 42        "regex-search".into()
 43    }
 44
 45    fn needs_confirmation(&self) -> bool {
 46        false
 47    }
 48
 49    fn description(&self) -> String {
 50        include_str!("./regex_search_tool/description.md").into()
 51    }
 52
 53    fn icon(&self) -> IconName {
 54        IconName::Regex
 55    }
 56
 57    fn input_schema(&self) -> serde_json::Value {
 58        let schema = schemars::schema_for!(RegexSearchToolInput);
 59        serde_json::to_value(&schema).unwrap()
 60    }
 61
 62    fn ui_text(&self, input: &serde_json::Value) -> String {
 63        match serde_json::from_value::<RegexSearchToolInput>(input.clone()) {
 64            Ok(input) => {
 65                let page = input.page();
 66
 67                if page > 1 {
 68                    format!(
 69                        "Get page {page} of search results for regex “`{}`”",
 70                        input.regex
 71                    )
 72                } else {
 73                    format!("Search files for regex “`{}`”", input.regex)
 74                }
 75            }
 76            Err(_) => "Search with regex".to_string(),
 77        }
 78    }
 79
 80    fn run(
 81        self: Arc<Self>,
 82        input: serde_json::Value,
 83        _messages: &[LanguageModelRequestMessage],
 84        project: Entity<Project>,
 85        _action_log: Entity<ActionLog>,
 86        cx: &mut App,
 87    ) -> Task<Result<String>> {
 88        const CONTEXT_LINES: u32 = 2;
 89
 90        let (offset, regex) = match serde_json::from_value::<RegexSearchToolInput>(input) {
 91            Ok(input) => (input.offset.unwrap_or(0), input.regex),
 92            Err(err) => return Task::ready(Err(anyhow!(err))),
 93        };
 94
 95        let query = match SearchQuery::regex(
 96            &regex,
 97            false,
 98            false,
 99            false,
100            PathMatcher::default(),
101            PathMatcher::default(),
102            None,
103        ) {
104            Ok(query) => query,
105            Err(error) => return Task::ready(Err(error)),
106        };
107
108        let results = project.update(cx, |project, cx| project.search(query, cx));
109
110        cx.spawn(async move|cx|  {
111            futures::pin_mut!(results);
112
113            let mut output = String::new();
114            let mut skips_remaining = offset;
115            let mut matches_found = 0;
116            let mut has_more_matches = false;
117
118            while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
119                if ranges.is_empty() {
120                    continue;
121                }
122
123                buffer.read_with(cx, |buffer, cx| -> Result<(), anyhow::Error> {
124                    if let Some(path) = buffer.file().map(|file| file.full_path(cx)) {
125                        let mut file_header_written = false;
126                        let mut ranges = ranges
127                            .into_iter()
128                            .map(|range| {
129                                let mut point_range = range.to_point(buffer);
130                                point_range.start.row =
131                                    point_range.start.row.saturating_sub(CONTEXT_LINES);
132                                point_range.start.column = 0;
133                                point_range.end.row = cmp::min(
134                                    buffer.max_point().row,
135                                    point_range.end.row + CONTEXT_LINES,
136                                );
137                                point_range.end.column = buffer.line_len(point_range.end.row);
138                                point_range
139                            })
140                            .peekable();
141
142                        while let Some(mut range) = ranges.next() {
143                            if skips_remaining > 0 {
144                                skips_remaining -= 1;
145                                continue;
146                            }
147
148                            // We'd already found a full page of matches, and we just found one more.
149                            if matches_found >= RESULTS_PER_PAGE {
150                                has_more_matches = true;
151                                return Ok(());
152                            }
153
154                            while let Some(next_range) = ranges.peek() {
155                                if range.end.row >= next_range.start.row {
156                                    range.end = next_range.end;
157                                    ranges.next();
158                                } else {
159                                    break;
160                                }
161                            }
162
163                            if !file_header_written {
164                                writeln!(output, "\n## Matches in {}", path.display())?;
165                                file_header_written = true;
166                            }
167
168                            let start_line = range.start.row + 1;
169                            let end_line = range.end.row + 1;
170                            writeln!(output, "\n### Lines {start_line}-{end_line}\n```")?;
171                            output.extend(buffer.text_for_range(range));
172                            output.push_str("\n```\n");
173
174                            matches_found += 1;
175                        }
176                    }
177
178                    Ok(())
179                })??;
180            }
181
182            if matches_found == 0 {
183                Ok("No matches found".to_string())
184            } else if has_more_matches {
185                Ok(format!(
186                    "Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}",
187                    offset + 1,
188                    offset + matches_found,
189                    offset + RESULTS_PER_PAGE,
190                ))
191            } else {
192                Ok(format!("Found {matches_found} matches:\n{output}"))
193            }
194        })
195    }
196}