1use anyhow::{anyhow, Result};
2use assistant_tool::{ActionLog, Tool};
3use futures::StreamExt;
4use gpui::{App, Entity, Task};
5use language::OffsetRangeExt;
6use language_model::LanguageModelRequestMessage;
7use project::{search::SearchQuery, Project};
8use schemars::JsonSchema;
9use serde::{Deserialize, Serialize};
10use std::{cmp, fmt::Write, sync::Arc};
11use util::paths::PathMatcher;
12
// Arguments accepted by `RegexSearchTool`. `run` deserializes the tool-call JSON into this
// type, and `input_schema` builds the JSON schema advertised to the model from it with
// `schemars::schema_for!`. The `///` field docs below therefore end up in that schema as
// descriptions; edit them as model-facing text, not just as code comments.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RegexSearchToolInput {
    /// A regex pattern to search for in the entire project. Note that the regex
    /// will be parsed by the Rust `regex` crate.
    pub regex: String,
}
19
/// [`Tool`] implementation that searches every file in the project for a regex and
/// returns the matching lines, with a few lines of surrounding context, as Markdown.
///
/// Stateless: all inputs arrive through [`Tool::run`].
pub struct RegexSearchTool;
21
22impl Tool for RegexSearchTool {
23 fn name(&self) -> String {
24 "regex-search".into()
25 }
26
27 fn description(&self) -> String {
28 include_str!("./regex_search_tool/description.md").into()
29 }
30
31 fn input_schema(&self) -> serde_json::Value {
32 let schema = schemars::schema_for!(RegexSearchToolInput);
33 serde_json::to_value(&schema).unwrap()
34 }
35
36 fn run(
37 self: Arc<Self>,
38 input: serde_json::Value,
39 _messages: &[LanguageModelRequestMessage],
40 project: Entity<Project>,
41 _action_log: Entity<ActionLog>,
42 cx: &mut App,
43 ) -> Task<Result<String>> {
44 const CONTEXT_LINES: u32 = 2;
45
46 let input = match serde_json::from_value::<RegexSearchToolInput>(input) {
47 Ok(input) => input,
48 Err(err) => return Task::ready(Err(anyhow!(err))),
49 };
50
51 let query = match SearchQuery::regex(
52 &input.regex,
53 false,
54 false,
55 false,
56 PathMatcher::default(),
57 PathMatcher::default(),
58 None,
59 ) {
60 Ok(query) => query,
61 Err(error) => return Task::ready(Err(error)),
62 };
63
64 let results = project.update(cx, |project, cx| project.search(query, cx));
65 cx.spawn(|cx| async move {
66 futures::pin_mut!(results);
67
68 let mut output = String::new();
69 while let Some(project::search::SearchResult::Buffer { buffer, ranges }) =
70 results.next().await
71 {
72 if ranges.is_empty() {
73 continue;
74 }
75
76 buffer.read_with(&cx, |buffer, cx| {
77 if let Some(path) = buffer.file().map(|file| file.full_path(cx)) {
78 writeln!(output, "### Found matches in {}:\n", path.display()).unwrap();
79 let mut ranges = ranges
80 .into_iter()
81 .map(|range| {
82 let mut point_range = range.to_point(buffer);
83 point_range.start.row =
84 point_range.start.row.saturating_sub(CONTEXT_LINES);
85 point_range.start.column = 0;
86 point_range.end.row = cmp::min(
87 buffer.max_point().row,
88 point_range.end.row + CONTEXT_LINES,
89 );
90 point_range.end.column = buffer.line_len(point_range.end.row);
91 point_range
92 })
93 .peekable();
94
95 while let Some(mut range) = ranges.next() {
96 while let Some(next_range) = ranges.peek() {
97 if range.end.row >= next_range.start.row {
98 range.end = next_range.end;
99 ranges.next();
100 } else {
101 break;
102 }
103 }
104
105 writeln!(output, "```").unwrap();
106 output.extend(buffer.text_for_range(range));
107 writeln!(output, "\n```\n").unwrap();
108 }
109 }
110 })?;
111 }
112
113 if output.is_empty() {
114 Ok("No matches found".to_string())
115 } else {
116 Ok(output)
117 }
118 })
119 }
120}