// regex_search_tool.rs

  1use anyhow::{anyhow, Result};
  2use assistant_tool::{ActionLog, Tool};
  3use futures::StreamExt;
  4use gpui::{App, Entity, Task};
  5use language::OffsetRangeExt;
  6use language_model::LanguageModelRequestMessage;
  7use project::{
  8    search::{SearchQuery, SearchResult},
  9    Project,
 10};
 11use schemars::JsonSchema;
 12use serde::{Deserialize, Serialize};
 13use std::{cmp, fmt::Write, sync::Arc};
 14use util::paths::PathMatcher;
 15
 16#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 17pub struct RegexSearchToolInput {
 18    /// A regex pattern to search for in the entire project. Note that the regex
 19    /// will be parsed by the Rust `regex` crate.
 20    pub regex: String,
 21
 22    /// Optional starting position for paginated results (0-based).
 23    /// When not provided, starts from the beginning.
 24    #[serde(default)]
 25    pub offset: Option<usize>,
 26}
 27
 28const RESULTS_PER_PAGE: usize = 20;
 29
 30pub struct RegexSearchTool;
 31
 32impl Tool for RegexSearchTool {
 33    fn name(&self) -> String {
 34        "regex-search".into()
 35    }
 36
 37    fn description(&self) -> String {
 38        include_str!("./regex_search_tool/description.md").into()
 39    }
 40
 41    fn input_schema(&self) -> serde_json::Value {
 42        let schema = schemars::schema_for!(RegexSearchToolInput);
 43        serde_json::to_value(&schema).unwrap()
 44    }
 45
 46    fn run(
 47        self: Arc<Self>,
 48        input: serde_json::Value,
 49        _messages: &[LanguageModelRequestMessage],
 50        project: Entity<Project>,
 51        _action_log: Entity<ActionLog>,
 52        cx: &mut App,
 53    ) -> Task<Result<String>> {
 54        const CONTEXT_LINES: u32 = 2;
 55
 56        let (offset, regex) = match serde_json::from_value::<RegexSearchToolInput>(input) {
 57            Ok(input) => (input.offset.unwrap_or(0), input.regex),
 58            Err(err) => return Task::ready(Err(anyhow!(err))),
 59        };
 60
 61        let query = match SearchQuery::regex(
 62            &regex,
 63            false,
 64            false,
 65            false,
 66            PathMatcher::default(),
 67            PathMatcher::default(),
 68            None,
 69        ) {
 70            Ok(query) => query,
 71            Err(error) => return Task::ready(Err(error)),
 72        };
 73
 74        let results = project.update(cx, |project, cx| project.search(query, cx));
 75
 76        cx.spawn(|cx| async move {
 77            futures::pin_mut!(results);
 78
 79            let mut output = String::new();
 80            let mut skips_remaining = offset;
 81            let mut matches_found = 0;
 82            let mut has_more_matches = false;
 83
 84            while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
 85                if ranges.is_empty() {
 86                    continue;
 87                }
 88
 89                buffer.read_with(&cx, |buffer, cx| -> Result<(), anyhow::Error> {
 90                    if let Some(path) = buffer.file().map(|file| file.full_path(cx)) {
 91                        let mut file_header_written = false;
 92                        let mut ranges = ranges
 93                            .into_iter()
 94                            .map(|range| {
 95                                let mut point_range = range.to_point(buffer);
 96                                point_range.start.row =
 97                                    point_range.start.row.saturating_sub(CONTEXT_LINES);
 98                                point_range.start.column = 0;
 99                                point_range.end.row = cmp::min(
100                                    buffer.max_point().row,
101                                    point_range.end.row + CONTEXT_LINES,
102                                );
103                                point_range.end.column = buffer.line_len(point_range.end.row);
104                                point_range
105                            })
106                            .peekable();
107
108                        while let Some(mut range) = ranges.next() {
109                            if skips_remaining > 0 {
110                                skips_remaining -= 1;
111                                continue;
112                            }
113
114                            // We'd already found a full page of matches, and we just found one more.
115                            if matches_found >= RESULTS_PER_PAGE {
116                                has_more_matches = true;
117                                return Ok(());
118                            }
119
120                            while let Some(next_range) = ranges.peek() {
121                                if range.end.row >= next_range.start.row {
122                                    range.end = next_range.end;
123                                    ranges.next();
124                                } else {
125                                    break;
126                                }
127                            }
128
129                            if !file_header_written {
130                                writeln!(output, "\n## Matches in {}", path.display())?;
131                                file_header_written = true;
132                            }
133
134                            let start_line = range.start.row + 1;
135                            let end_line = range.end.row + 1;
136                            writeln!(output, "\n### Lines {start_line}-{end_line}\n```")?;
137                            output.extend(buffer.text_for_range(range));
138                            output.push_str("\n```\n");
139
140                            matches_found += 1;
141                        }
142                    }
143
144                    Ok(())
145                })??;
146            }
147
148            if matches_found == 0 {
149                Ok("No matches found".to_string())
150            } else if has_more_matches {
151                Ok(format!(
152                    "Showing matches {}-{} (there were more matches found; use offset: {} to see next page):\n{output}",
153                    offset + 1,
154                    offset + matches_found,
155                    offset + RESULTS_PER_PAGE,
156                ))
157          } else {
158                Ok(format!("Found {matches_found} matches:\n{output}"))
159            }
160        })
161    }
162}