regex_search.rs

  1use anyhow::{anyhow, Result};
  2use assistant_tool::{ActionLog, Tool};
  3use futures::StreamExt;
  4use gpui::{App, Entity, Task};
  5use language::OffsetRangeExt;
  6use language_model::LanguageModelRequestMessage;
  7use project::{search::SearchQuery, Project};
  8use schemars::JsonSchema;
  9use serde::{Deserialize, Serialize};
 10use std::{cmp, fmt::Write, sync::Arc};
 11use util::paths::PathMatcher;
 12
 13#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 14pub struct RegexSearchToolInput {
 15    /// A regex pattern to search for in the entire project. Note that the regex
 16    /// will be parsed by the Rust `regex` crate.
 17    pub regex: String,
 18}
 19
 20pub struct RegexSearchTool;
 21
 22impl Tool for RegexSearchTool {
 23    fn name(&self) -> String {
 24        "regex-search".into()
 25    }
 26
 27    fn description(&self) -> String {
 28        include_str!("./regex_search_tool/description.md").into()
 29    }
 30
 31    fn input_schema(&self) -> serde_json::Value {
 32        let schema = schemars::schema_for!(RegexSearchToolInput);
 33        serde_json::to_value(&schema).unwrap()
 34    }
 35
 36    fn run(
 37        self: Arc<Self>,
 38        input: serde_json::Value,
 39        _messages: &[LanguageModelRequestMessage],
 40        project: Entity<Project>,
 41        _action_log: Entity<ActionLog>,
 42        cx: &mut App,
 43    ) -> Task<Result<String>> {
 44        const CONTEXT_LINES: u32 = 2;
 45
 46        let input = match serde_json::from_value::<RegexSearchToolInput>(input) {
 47            Ok(input) => input,
 48            Err(err) => return Task::ready(Err(anyhow!(err))),
 49        };
 50
 51        let query = match SearchQuery::regex(
 52            &input.regex,
 53            false,
 54            false,
 55            false,
 56            PathMatcher::default(),
 57            PathMatcher::default(),
 58            None,
 59        ) {
 60            Ok(query) => query,
 61            Err(error) => return Task::ready(Err(error)),
 62        };
 63
 64        let results = project.update(cx, |project, cx| project.search(query, cx));
 65        cx.spawn(|cx| async move {
 66            futures::pin_mut!(results);
 67
 68            let mut output = String::new();
 69            while let Some(project::search::SearchResult::Buffer { buffer, ranges }) =
 70                results.next().await
 71            {
 72                if ranges.is_empty() {
 73                    continue;
 74                }
 75
 76                buffer.read_with(&cx, |buffer, cx| {
 77                    if let Some(path) = buffer.file().map(|file| file.full_path(cx)) {
 78                        writeln!(output, "### Found matches in {}:\n", path.display()).unwrap();
 79                        let mut ranges = ranges
 80                            .into_iter()
 81                            .map(|range| {
 82                                let mut point_range = range.to_point(buffer);
 83                                point_range.start.row =
 84                                    point_range.start.row.saturating_sub(CONTEXT_LINES);
 85                                point_range.start.column = 0;
 86                                point_range.end.row = cmp::min(
 87                                    buffer.max_point().row,
 88                                    point_range.end.row + CONTEXT_LINES,
 89                                );
 90                                point_range.end.column = buffer.line_len(point_range.end.row);
 91                                point_range
 92                            })
 93                            .peekable();
 94
 95                        while let Some(mut range) = ranges.next() {
 96                            while let Some(next_range) = ranges.peek() {
 97                                if range.end.row >= next_range.start.row {
 98                                    range.end = next_range.end;
 99                                    ranges.next();
100                                } else {
101                                    break;
102                                }
103                            }
104
105                            writeln!(output, "```").unwrap();
106                            output.extend(buffer.text_for_range(range));
107                            writeln!(output, "\n```\n").unwrap();
108                        }
109                    }
110                })?;
111            }
112
113            if output.is_empty() {
114                Ok("No matches found".to_string())
115            } else {
116                Ok(output)
117            }
118        })
119    }
120}