regex_search_tool.rs

  1use anyhow::{anyhow, Result};
  2use assistant_tool::{ActionLog, Tool};
  3use futures::StreamExt;
  4use gpui::{App, Entity, Task};
  5use language::OffsetRangeExt;
  6use language_model::LanguageModelRequestMessage;
  7use project::{
  8    search::{SearchQuery, SearchResult},
  9    Project,
 10};
 11use schemars::JsonSchema;
 12use serde::{Deserialize, Serialize};
 13use std::{cmp, fmt::Write, sync::Arc};
 14use util::paths::PathMatcher;
 15
 16#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 17pub struct RegexSearchToolInput {
 18    /// A regex pattern to search for in the entire project. Note that the regex
 19    /// will be parsed by the Rust `regex` crate.
 20    pub regex: String,
 21
 22    /// Optional starting position for paginated results (0-based).
 23    /// When not provided, starts from the beginning.
 24    #[serde(default)]
 25    pub offset: Option<u32>,
 26}
 27
 28impl RegexSearchToolInput {
 29    /// Which page of search results this is.
 30    pub fn page(&self) -> u32 {
 31        1 + (self.offset.unwrap_or(0) / RESULTS_PER_PAGE)
 32    }
 33}
 34
 35const RESULTS_PER_PAGE: u32 = 20;
 36
 37pub struct RegexSearchTool;
 38
 39impl Tool for RegexSearchTool {
 40    fn name(&self) -> String {
 41        "regex-search".into()
 42    }
 43
 44    fn description(&self) -> String {
 45        include_str!("./regex_search_tool/description.md").into()
 46    }
 47
 48    fn input_schema(&self) -> serde_json::Value {
 49        let schema = schemars::schema_for!(RegexSearchToolInput);
 50        serde_json::to_value(&schema).unwrap()
 51    }
 52
 53    fn ui_text(&self, input: &serde_json::Value) -> String {
 54        match serde_json::from_value::<RegexSearchToolInput>(input.clone()) {
 55            Ok(input) => {
 56                let page = input.page();
 57
 58                if page > 1 {
 59                    format!(
 60                        "Get page {page} of search results for regex “`{}`”",
 61                        input.regex
 62                    )
 63                } else {
 64                    format!("Search files for regex “`{}`”", input.regex)
 65                }
 66            }
 67            Err(_) => "Search with regex".to_string(),
 68        }
 69    }
 70
 71    fn run(
 72        self: Arc<Self>,
 73        input: serde_json::Value,
 74        _messages: &[LanguageModelRequestMessage],
 75        project: Entity<Project>,
 76        _action_log: Entity<ActionLog>,
 77        cx: &mut App,
 78    ) -> Task<Result<String>> {
 79        const CONTEXT_LINES: u32 = 2;
 80
 81        let (offset, regex) = match serde_json::from_value::<RegexSearchToolInput>(input) {
 82            Ok(input) => (input.offset.unwrap_or(0), input.regex),
 83            Err(err) => return Task::ready(Err(anyhow!(err))),
 84        };
 85
 86        let query = match SearchQuery::regex(
 87            &regex,
 88            false,
 89            false,
 90            false,
 91            PathMatcher::default(),
 92            PathMatcher::default(),
 93            None,
 94        ) {
 95            Ok(query) => query,
 96            Err(error) => return Task::ready(Err(error)),
 97        };
 98
 99        let results = project.update(cx, |project, cx| project.search(query, cx));
100
101        cx.spawn(async move|cx|  {
102            futures::pin_mut!(results);
103
104            let mut output = String::new();
105            let mut skips_remaining = offset;
106            let mut matches_found = 0;
107            let mut has_more_matches = false;
108
109            while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
110                if ranges.is_empty() {
111                    continue;
112                }
113
114                buffer.read_with(cx, |buffer, cx| -> Result<(), anyhow::Error> {
115                    if let Some(path) = buffer.file().map(|file| file.full_path(cx)) {
116                        let mut file_header_written = false;
117                        let mut ranges = ranges
118                            .into_iter()
119                            .map(|range| {
120                                let mut point_range = range.to_point(buffer);
121                                point_range.start.row =
122                                    point_range.start.row.saturating_sub(CONTEXT_LINES);
123                                point_range.start.column = 0;
124                                point_range.end.row = cmp::min(
125                                    buffer.max_point().row,
126                                    point_range.end.row + CONTEXT_LINES,
127                                );
128                                point_range.end.column = buffer.line_len(point_range.end.row);
129                                point_range
130                            })
131                            .peekable();
132
133                        while let Some(mut range) = ranges.next() {
134                            if skips_remaining > 0 {
135                                skips_remaining -= 1;
136                                continue;
137                            }
138
139                            // We'd already found a full page of matches, and we just found one more.
140                            if matches_found >= RESULTS_PER_PAGE {
141                                has_more_matches = true;
142                                return Ok(());
143                            }
144
145                            while let Some(next_range) = ranges.peek() {
146                                if range.end.row >= next_range.start.row {
147                                    range.end = next_range.end;
148                                    ranges.next();
149                                } else {
150                                    break;
151                                }
152                            }
153
154                            if !file_header_written {
155                                writeln!(output, "\n## Matches in {}", path.display())?;
156                                file_header_written = true;
157                            }
158
159                            let start_line = range.start.row + 1;
160                            let end_line = range.end.row + 1;
161                            writeln!(output, "\n### Lines {start_line}-{end_line}\n```")?;
162                            output.extend(buffer.text_for_range(range));
163                            output.push_str("\n```\n");
164
165                            matches_found += 1;
166                        }
167                    }
168
169                    Ok(())
170                })??;
171            }
172
173            if matches_found == 0 {
174                Ok("No matches found".to_string())
175            } else if has_more_matches {
176                Ok(format!(
177                    "Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}",
178                    offset + 1,
179                    offset + matches_found,
180                    offset + RESULTS_PER_PAGE,
181                ))
182            } else {
183                Ok(format!("Found {matches_found} matches:\n{output}"))
184            }
185        })
186    }
187}