regex_search.rs

  1use anyhow::{anyhow, Result};
  2use assistant_tool::Tool;
  3use futures::StreamExt;
  4use gpui::{App, Entity, Task};
  5use language::OffsetRangeExt;
  6use language_model::LanguageModelRequestMessage;
  7use project::{search::SearchQuery, Project};
  8use schemars::JsonSchema;
  9use serde::{Deserialize, Serialize};
 10use std::{cmp, fmt::Write, sync::Arc};
 11use util::paths::PathMatcher;
 12
// Arguments accepted by `RegexSearchTool`. `run` deserializes them from the
// tool-call JSON, and the `JsonSchema` derive is what `input_schema` reports.
// These are plain `//` comments on purpose: schemars uses `///` doc comments
// as schema descriptions, so editing the field doc below changes the schema
// the model sees.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RegexSearchToolInput {
    /// A regex pattern to search for in the entire project. Note that the regex
    /// will be parsed by the Rust `regex` crate.
    pub regex: String,
}
 19
/// Tool that runs a regex search over the whole project and reports each
/// match together with a couple of surrounding lines of context.
pub struct RegexSearchTool;
 21
 22impl Tool for RegexSearchTool {
 23    fn name(&self) -> String {
 24        "regex-search".into()
 25    }
 26
 27    fn description(&self) -> String {
 28        include_str!("./regex_search_tool/description.md").into()
 29    }
 30
 31    fn input_schema(&self) -> serde_json::Value {
 32        let schema = schemars::schema_for!(RegexSearchToolInput);
 33        serde_json::to_value(&schema).unwrap()
 34    }
 35
 36    fn run(
 37        self: Arc<Self>,
 38        input: serde_json::Value,
 39        _messages: &[LanguageModelRequestMessage],
 40        project: Entity<Project>,
 41        cx: &mut App,
 42    ) -> Task<Result<String>> {
 43        const CONTEXT_LINES: u32 = 2;
 44
 45        let input = match serde_json::from_value::<RegexSearchToolInput>(input) {
 46            Ok(input) => input,
 47            Err(err) => return Task::ready(Err(anyhow!(err))),
 48        };
 49
 50        let query = match SearchQuery::regex(
 51            &input.regex,
 52            false,
 53            false,
 54            false,
 55            PathMatcher::default(),
 56            PathMatcher::default(),
 57            None,
 58        ) {
 59            Ok(query) => query,
 60            Err(error) => return Task::ready(Err(error)),
 61        };
 62
 63        let results = project.update(cx, |project, cx| project.search(query, cx));
 64        cx.spawn(|cx| async move {
 65            futures::pin_mut!(results);
 66
 67            let mut output = String::new();
 68            while let Some(project::search::SearchResult::Buffer { buffer, ranges }) =
 69                results.next().await
 70            {
 71                if ranges.is_empty() {
 72                    continue;
 73                }
 74
 75                buffer.read_with(&cx, |buffer, cx| {
 76                    if let Some(path) = buffer.file().map(|file| file.full_path(cx)) {
 77                        writeln!(output, "### Found matches in {}:\n", path.display()).unwrap();
 78                        let mut ranges = ranges
 79                            .into_iter()
 80                            .map(|range| {
 81                                let mut point_range = range.to_point(buffer);
 82                                point_range.start.row =
 83                                    point_range.start.row.saturating_sub(CONTEXT_LINES);
 84                                point_range.start.column = 0;
 85                                point_range.end.row = cmp::min(
 86                                    buffer.max_point().row,
 87                                    point_range.end.row + CONTEXT_LINES,
 88                                );
 89                                point_range.end.column = buffer.line_len(point_range.end.row);
 90                                point_range
 91                            })
 92                            .peekable();
 93
 94                        while let Some(mut range) = ranges.next() {
 95                            while let Some(next_range) = ranges.peek() {
 96                                if range.end.row >= next_range.start.row {
 97                                    range.end = next_range.end;
 98                                    ranges.next();
 99                                } else {
100                                    break;
101                                }
102                            }
103
104                            writeln!(output, "```").unwrap();
105                            output.extend(buffer.text_for_range(range));
106                            writeln!(output, "\n```\n").unwrap();
107                        }
108                    }
109                })?;
110            }
111
112            if output.is_empty() {
113                Ok("No matches found".into())
114            } else {
115                Ok(output)
116            }
117        })
118    }
119}