regex_search_tool.rs

  1use anyhow::{anyhow, Result};
  2use assistant_tool::{ActionLog, Tool};
  3use futures::StreamExt;
  4use gpui::{App, Entity, Task};
  5use language::OffsetRangeExt;
  6use language_model::LanguageModelRequestMessage;
  7use project::{
  8    search::{SearchQuery, SearchResult},
  9    Project,
 10};
 11use schemars::JsonSchema;
 12use serde::{Deserialize, Serialize};
 13use std::{cmp, fmt::Write, sync::Arc};
 14use util::paths::PathMatcher;
 15
// Arguments accepted by the regex search tool. This struct is what `ui_text`
// and `run` deserialize the raw JSON input into, and what `input_schema`
// generates its schema from.
//
// NOTE(review): plain `//` comments are used here on purpose. schemars
// presumably copies `///` doc comments into the generated schema as
// descriptions, so rewording the field docs below (or adding a `///` on the
// struct) would change what `input_schema` returns — confirm before editing.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RegexSearchToolInput {
    /// A regex pattern to search for in the entire project. Note that the regex
    /// will be parsed by the Rust `regex` crate.
    pub regex: String,

    /// Optional starting position for paginated results (0-based).
    /// When not provided, starts from the beginning.
    // A missing `offset` deserializes to `None`; both `page` and `run` treat
    // `None` as 0.
    #[serde(default)]
    pub offset: Option<u32>,
}
 27
 28impl RegexSearchToolInput {
 29    /// Which page of search results this is.
 30    pub fn page(&self) -> u32 {
 31        1 + (self.offset.unwrap_or(0) / RESULTS_PER_PAGE)
 32    }
 33}
 34
/// Maximum number of matches reported in one page of search results; also the
/// divisor `RegexSearchToolInput::page` uses to turn an offset into a page number.
const RESULTS_PER_PAGE: u32 = 20;
 36
/// Tool that searches the project for a regular expression and returns the
/// matching excerpts one page at a time. It is a unit struct with no fields.
pub struct RegexSearchTool;
 38
 39impl Tool for RegexSearchTool {
 40    fn name(&self) -> String {
 41        "regex-search".into()
 42    }
 43
 44    fn needs_confirmation(&self) -> bool {
 45        false
 46    }
 47
 48    fn description(&self) -> String {
 49        include_str!("./regex_search_tool/description.md").into()
 50    }
 51
 52    fn input_schema(&self) -> serde_json::Value {
 53        let schema = schemars::schema_for!(RegexSearchToolInput);
 54        serde_json::to_value(&schema).unwrap()
 55    }
 56
 57    fn ui_text(&self, input: &serde_json::Value) -> String {
 58        match serde_json::from_value::<RegexSearchToolInput>(input.clone()) {
 59            Ok(input) => {
 60                let page = input.page();
 61
 62                if page > 1 {
 63                    format!(
 64                        "Get page {page} of search results for regex “`{}`”",
 65                        input.regex
 66                    )
 67                } else {
 68                    format!("Search files for regex “`{}`”", input.regex)
 69                }
 70            }
 71            Err(_) => "Search with regex".to_string(),
 72        }
 73    }
 74
 75    fn run(
 76        self: Arc<Self>,
 77        input: serde_json::Value,
 78        _messages: &[LanguageModelRequestMessage],
 79        project: Entity<Project>,
 80        _action_log: Entity<ActionLog>,
 81        cx: &mut App,
 82    ) -> Task<Result<String>> {
 83        const CONTEXT_LINES: u32 = 2;
 84
 85        let (offset, regex) = match serde_json::from_value::<RegexSearchToolInput>(input) {
 86            Ok(input) => (input.offset.unwrap_or(0), input.regex),
 87            Err(err) => return Task::ready(Err(anyhow!(err))),
 88        };
 89
 90        let query = match SearchQuery::regex(
 91            &regex,
 92            false,
 93            false,
 94            false,
 95            PathMatcher::default(),
 96            PathMatcher::default(),
 97            None,
 98        ) {
 99            Ok(query) => query,
100            Err(error) => return Task::ready(Err(error)),
101        };
102
103        let results = project.update(cx, |project, cx| project.search(query, cx));
104
105        cx.spawn(async move|cx|  {
106            futures::pin_mut!(results);
107
108            let mut output = String::new();
109            let mut skips_remaining = offset;
110            let mut matches_found = 0;
111            let mut has_more_matches = false;
112
113            while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
114                if ranges.is_empty() {
115                    continue;
116                }
117
118                buffer.read_with(cx, |buffer, cx| -> Result<(), anyhow::Error> {
119                    if let Some(path) = buffer.file().map(|file| file.full_path(cx)) {
120                        let mut file_header_written = false;
121                        let mut ranges = ranges
122                            .into_iter()
123                            .map(|range| {
124                                let mut point_range = range.to_point(buffer);
125                                point_range.start.row =
126                                    point_range.start.row.saturating_sub(CONTEXT_LINES);
127                                point_range.start.column = 0;
128                                point_range.end.row = cmp::min(
129                                    buffer.max_point().row,
130                                    point_range.end.row + CONTEXT_LINES,
131                                );
132                                point_range.end.column = buffer.line_len(point_range.end.row);
133                                point_range
134                            })
135                            .peekable();
136
137                        while let Some(mut range) = ranges.next() {
138                            if skips_remaining > 0 {
139                                skips_remaining -= 1;
140                                continue;
141                            }
142
143                            // We'd already found a full page of matches, and we just found one more.
144                            if matches_found >= RESULTS_PER_PAGE {
145                                has_more_matches = true;
146                                return Ok(());
147                            }
148
149                            while let Some(next_range) = ranges.peek() {
150                                if range.end.row >= next_range.start.row {
151                                    range.end = next_range.end;
152                                    ranges.next();
153                                } else {
154                                    break;
155                                }
156                            }
157
158                            if !file_header_written {
159                                writeln!(output, "\n## Matches in {}", path.display())?;
160                                file_header_written = true;
161                            }
162
163                            let start_line = range.start.row + 1;
164                            let end_line = range.end.row + 1;
165                            writeln!(output, "\n### Lines {start_line}-{end_line}\n```")?;
166                            output.extend(buffer.text_for_range(range));
167                            output.push_str("\n```\n");
168
169                            matches_found += 1;
170                        }
171                    }
172
173                    Ok(())
174                })??;
175            }
176
177            if matches_found == 0 {
178                Ok("No matches found".to_string())
179            } else if has_more_matches {
180                Ok(format!(
181                    "Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}",
182                    offset + 1,
183                    offset + matches_found,
184                    offset + RESULTS_PER_PAGE,
185                ))
186            } else {
187                Ok(format!("Found {matches_found} matches:\n{output}"))
188            }
189        })
190    }
191}