use std::ops::Range;

use anyhow::Result;
use cloud_zeta2_prompt::retrieval_prompt::SearchToolQuery;
use collections::HashMap;
use futures::{
    StreamExt,
    channel::mpsc::{self, UnboundedSender},
};
use gpui::{AppContext, AsyncApp, Entity};
use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt, Point, ToOffset, ToPoint};
use project::{
    Project, WorktreeSettings,
    search::{SearchQuery, SearchResult},
};
use smol::channel;
use util::{
    ResultExt as _,
    paths::{PathMatcher, PathStyle},
};
use workspace::item::Settings as _;

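/// Runs the given retrieval search queries concurrently against the project and
/// collects the matching excerpts per buffer. Excerpts are dropped once their
/// combined size exceeds `MAX_RESULTS_LEN`, and overlapping ranges within a
/// buffer are merged before being returned.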
pub async fn run_retrieval_searches(
    project: Entity<Project>,
    queries: Vec<SearchToolQuery>,
    cx: &mut AsyncApp,
) -> Result<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>> {
    let (exclude_matcher, path_style) = project.update(cx, |project, cx| {
        let global_settings = WorktreeSettings::get_global(cx);
        let exclude_patterns = global_settings
            .file_scan_exclusions
            .sources()
            .iter()
            .chain(global_settings.private_files.sources().iter());
        let path_style = project.path_style(cx);
        anyhow::Ok((PathMatcher::new(exclude_patterns, path_style)?, path_style))
    })??;

    let (results_tx, mut results_rx) = mpsc::unbounded();

    for query in queries {
        let exclude_matcher = exclude_matcher.clone();
        let results_tx = results_tx.clone();
        let project = project.clone();
        cx.spawn(async move |cx| {
            run_query(
                query,
                results_tx.clone(),
                path_style,
                exclude_matcher,
                &project,
                cx,
            )
            .await
            .log_err();
        })
        .detach()
    }
    drop(results_tx);

    cx.background_spawn(async move {
        let mut results: HashMap<Entity<Buffer>, Vec<Range<Anchor>>> = HashMap::default();
        let mut snapshots = HashMap::default();

        let mut total_bytes = 0;
        'outer: while let Some((buffer, snapshot, excerpts)) = results_rx.next().await {
            snapshots.insert(buffer.entity_id(), snapshot);
            let existing = results.entry(buffer).or_default();
            existing.reserve(excerpts.len());

            for (range, size) in excerpts {
                // Blunt trimming of the results until we have a proper algorithmic filtering step
                if (total_bytes + size) > MAX_RESULTS_LEN {
                    log::trace!("Combined results reached limit of {MAX_RESULTS_LEN}B");
                    break 'outer;
                }
                total_bytes += size;
                existing.push(range);
            }
        }

        for (buffer, ranges) in results.iter_mut() {
            if let Some(snapshot) = snapshots.get(&buffer.entity_id()) {
                merge_anchor_ranges(ranges, snapshot);
            }
        }

        Ok(results)
    })
    .await
}

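/// Sorts the ranges by position and coalesces any that overlap or touch, leaving
/// a minimal set of non-overlapping excerpt ranges for the buffer.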
fn merge_anchor_ranges(ranges: &mut Vec<Range<Anchor>>, snapshot: &BufferSnapshot) {
    ranges.sort_unstable_by(|a, b| {
        a.start
            .cmp(&b.start, snapshot)
            .then(a.end.cmp(&b.end, snapshot))
    });

    let mut index = 1;
    while index < ranges.len() {
        if ranges[index - 1]
            .end
            .cmp(&ranges[index].start, snapshot)
            .is_ge()
        {
            // Merge overlapping or adjacent ranges, keeping whichever end is
            // later so a fully-contained range can't shrink the merged one.
            let removed = ranges.remove(index);
            if removed.end.cmp(&ranges[index - 1].end, snapshot).is_gt() {
                ranges[index - 1].end = removed.end;
            }
        } else {
            index += 1;
        }
    }
}

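// Byte budgets for trimming: `MAX_RESULTS_LEN` caps the combined size of all
// excerpts collected by `run_retrieval_searches`.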
const MAX_EXCERPT_LEN: usize = 768;
const MAX_RESULTS_LEN: usize = MAX_EXCERPT_LEN * 5;

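/// A unit of work for the nested-search workers: a buffer, its snapshot, and the
/// ranges matched so far, to be narrowed by the query at `query_ix`. The job
/// carries the channel sender so a worker can enqueue the follow-up job itself.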
struct SearchJob {
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    ranges: Vec<Range<usize>>,
    query_ix: usize,
    jobs_tx: channel::Sender<SearchJob>,
}

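/// Executes a single `SearchToolQuery`. If syntax-node regexes are present, the
/// first one is searched project-wide and each match is expanded to its enclosing
/// syntax node; subsequent syntax regexes (and the optional content regex) are
/// then applied within those nodes on background workers. Without syntax-node
/// regexes, a plain content search is run instead. Results are sent to
/// `results_tx` as anchor ranges paired with their byte length.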
async fn run_query(
    input_query: SearchToolQuery,
    results_tx: UnboundedSender<(Entity<Buffer>, BufferSnapshot, Vec<(Range<Anchor>, usize)>)>,
    path_style: PathStyle,
    exclude_matcher: PathMatcher,
    project: &Entity<Project>,
    cx: &mut AsyncApp,
) -> Result<()> {
    let include_matcher = PathMatcher::new(vec![input_query.glob], path_style)?;

    let make_search = |regex: &str| -> Result<SearchQuery> {
        SearchQuery::regex(
            regex,
            false,
            true,
            false,
            true,
            include_matcher.clone(),
            exclude_matcher.clone(),
            true,
            None,
        )
    };

    if let Some(outer_syntax_regex) = input_query.syntax_node.first() {
        let outer_syntax_query = make_search(outer_syntax_regex)?;
        let nested_syntax_queries = input_query
            .syntax_node
            .into_iter()
            .skip(1)
            .map(|query| make_search(&query))
            .collect::<Result<Vec<_>>>()?;
        let content_query = input_query
            .content
            .map(|regex| make_search(&regex))
            .transpose()?;

        let (jobs_tx, jobs_rx) = channel::unbounded();

        let outer_search_results_rx =
            project.update(cx, |project, cx| project.search(outer_syntax_query, cx))?;

        let outer_search_task = cx.spawn(async move |cx| {
            futures::pin_mut!(outer_search_results_rx);
            while let Some(SearchResult::Buffer { buffer, ranges }) =
                outer_search_results_rx.next().await
            {
                buffer
                    .read_with(cx, |buffer, _| buffer.parsing_idle())?
                    .await;
                let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
                let expanded_ranges: Vec<_> = ranges
                    .into_iter()
                    .filter_map(|range| expand_to_parent_range(&range, &snapshot))
                    .collect();
                jobs_tx
                    .send(SearchJob {
                        buffer,
                        snapshot,
                        ranges: expanded_ranges,
                        query_ix: 0,
                        jobs_tx: jobs_tx.clone(),
                    })
                    .await?;
            }
            anyhow::Ok(())
        });

        let n_workers = cx.background_executor().num_cpus();
        let search_job_task = cx.background_executor().scoped(|scope| {
            for _ in 0..n_workers {
                scope.spawn(async {
                    while let Ok(job) = jobs_rx.recv().await {
                        process_nested_search_job(
                            &results_tx,
                            &nested_syntax_queries,
                            &content_query,
                            job,
                        )
                        .await;
                    }
                });
            }
        });

        search_job_task.await;
        outer_search_task.await?;
    } else if let Some(content_regex) = &input_query.content {
        let search_query = make_search(&content_regex)?;

        let results_rx = project.update(cx, |project, cx| project.search(search_query, cx))?;
        futures::pin_mut!(results_rx);

        while let Some(SearchResult::Buffer { buffer, ranges }) = results_rx.next().await {
            let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

            let ranges = ranges
                .into_iter()
                .map(|range| {
                    let range = range.to_offset(&snapshot);
                    let range = expand_to_entire_lines(range, &snapshot);
                    let size = range.len();
                    let range =
                        snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end);
                    (range, size)
                })
                .collect();

            let send_result =
                results_tx.unbounded_send((buffer.clone(), snapshot.clone(), ranges));

            if let Err(err) = send_result
                && !err.is_disconnected()
            {
                log::error!("{err}");
            }
        }
    } else {
        log::warn!("Context gathering model produced a glob-only search");
    }

    anyhow::Ok(())
}

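/// Applies the next nested syntax query to the job's ranges and enqueues a
/// follow-up job, or, once the nested queries are exhausted, applies the optional
/// content query and emits the final line-expanded excerpts for the buffer.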
async fn process_nested_search_job(
    results_tx: &UnboundedSender<(Entity<Buffer>, BufferSnapshot, Vec<(Range<Anchor>, usize)>)>,
    queries: &Vec<SearchQuery>,
    content_query: &Option<SearchQuery>,
    job: SearchJob,
) {
    if let Some(search_query) = queries.get(job.query_ix) {
        let mut subranges = Vec::new();
        for range in job.ranges {
            let start = range.start;
            let search_results = search_query.search(&job.snapshot, Some(range)).await;
            for subrange in search_results {
                let subrange = start + subrange.start..start + subrange.end;
                subranges.extend(expand_to_parent_range(&subrange, &job.snapshot));
            }
        }
        job.jobs_tx
            .send(SearchJob {
                buffer: job.buffer,
                snapshot: job.snapshot,
                ranges: subranges,
                query_ix: job.query_ix + 1,
                jobs_tx: job.jobs_tx.clone(),
            })
            .await
            .ok();
    } else {
        let ranges = if let Some(content_query) = content_query {
            let mut subranges = Vec::new();
            for range in job.ranges {
                let start = range.start;
                let search_results = content_query.search(&job.snapshot, Some(range)).await;
                for subrange in search_results {
                    let subrange = start + subrange.start..start + subrange.end;
                    subranges.push(subrange);
                }
            }
            subranges
        } else {
            job.ranges
        };

        let matches = ranges
            .into_iter()
            .map(|range| {
                let snapshot = &job.snapshot;
                let range = expand_to_entire_lines(range, snapshot);
                let size = range.len();
                let range = snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end);
                (range, size)
            })
            .collect();

        let send_result = results_tx.unbounded_send((job.buffer, job.snapshot, matches));

        if let Err(err) = send_result
            && !err.is_disconnected()
        {
            log::error!("{err}");
        }
    }
}

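/// Expands an offset range to whole lines: the start moves to column 0, and an
/// end that falls mid-line moves to the start of the following line (clamped to
/// the end of the buffer).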
fn expand_to_entire_lines(range: Range<usize>, snapshot: &BufferSnapshot) -> Range<usize> {
    let mut point_range = range.to_point(snapshot);
    point_range.start.column = 0;
    if point_range.end.column > 0 {
        point_range.end = snapshot.max_point().min(point_range.end + Point::new(1, 0));
    }
    point_range.to_offset(snapshot)
}

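/// Returns the byte range of the nearest syntax node that encloses the matched
/// lines, measuring from the first non-whitespace character of the first line to
/// the end of the last line. Returns `None` if there is no such ancestor.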
fn expand_to_parent_range<T: ToPoint + ToOffset>(
    range: &Range<T>,
    snapshot: &BufferSnapshot,
) -> Option<Range<usize>> {
    let mut line_range = range.to_point(&snapshot);
    line_range.start.column = snapshot.indent_size_for_line(line_range.start.row).len;
    line_range.end.column = snapshot.line_len(line_range.end.row);
    // TODO skip result if matched line isn't the first node line?

    let node = snapshot.syntax_ancestor(line_range)?;
    Some(node.byte_range())
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::merge_excerpts::merge_excerpts;
    use cloud_zeta2_prompt::write_codeblock;
    use edit_prediction_context::Line;
    use gpui::TestAppContext;
    use indoc::indoc;
    use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
    use pretty_assertions::assert_eq;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;
    use util::path;

    #[gpui::test]
    async fn test_retrieval(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "user.rs": indoc!{"
                    pub struct Organization {
                        owner: Arc<User>,
                    }

                    pub struct User {
                        first_name: String,
                        last_name: String,
                    }

                    impl Organization {
                        pub fn owner(&self) -> Arc<User> {
                            self.owner.clone()
                        }
                    }

                    impl User {
                        pub fn new(first_name: String, last_name: String) -> Self {
                            Self {
                                first_name,
                                last_name
                            }
                        }

                        pub fn first_name(&self) -> String {
                            self.first_name.clone()
                        }

                        pub fn last_name(&self) -> String {
                            self.last_name.clone()
                        }
                    }
                "},
                "main.rs": indoc!{r#"
                    fn main() {
                        let user = User::new(FIRST_NAME.clone(), "doe".into());
                        println!("user {:?}", user);
                    }
                "#},
            }),
        )
        .await;

        let project = Project::test(fs, vec![Path::new(path!("/root"))], cx).await;
        project.update(cx, |project, _cx| {
            project.languages().add(rust_lang().into())
        });

        assert_results(
            &project,
            SearchToolQuery {
                glob: "user.rs".into(),
                syntax_node: vec!["impl\\s+User".into(), "pub\\s+fn\\s+first_name".into()],
                content: None,
            },
            indoc! {r#"
                `````root/user.rs
                …
                impl User {
                …
                    pub fn first_name(&self) -> String {
                        self.first_name.clone()
                    }
                …
                `````
            "#},
            cx,
        )
        .await;

        assert_results(
            &project,
            SearchToolQuery {
                glob: "user.rs".into(),
                syntax_node: vec!["impl\\s+User".into()],
                content: Some("\\.clone".into()),
            },
            indoc! {r#"
                `````root/user.rs
                …
                impl User {
                …
                    pub fn first_name(&self) -> String {
                        self.first_name.clone()
                …
                    pub fn last_name(&self) -> String {
                        self.last_name.clone()
                …
                `````
            "#},
            cx,
        )
        .await;

        assert_results(
            &project,
            SearchToolQuery {
                glob: "*.rs".into(),
                syntax_node: vec![],
                content: Some("\\.clone".into()),
            },
            indoc! {r#"
                `````root/main.rs
                fn main() {
                    let user = User::new(FIRST_NAME.clone(), "doe".into());
                …
                `````

                `````root/user.rs
                …
                impl Organization {
                    pub fn owner(&self) -> Arc<User> {
                        self.owner.clone()
                …
                impl User {
                …
                    pub fn first_name(&self) -> String {
                        self.first_name.clone()
                …
                    pub fn last_name(&self) -> String {
                        self.last_name.clone()
                …
                `````
            "#},
            cx,
        )
        .await;
    }

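    /// Runs `run_retrieval_searches` for a single query and asserts that the
    /// merged excerpts, rendered via `write_codeblock`, match `expected_output`.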
    async fn assert_results(
        project: &Entity<Project>,
        query: SearchToolQuery,
        expected_output: &str,
        cx: &mut TestAppContext,
    ) {
        let results = run_retrieval_searches(project.clone(), vec![query], &mut cx.to_async())
            .await
            .unwrap();

        let mut results = results.into_iter().collect::<Vec<_>>();
        results.sort_by_key(|results| {
            results
                .0
                .read_with(cx, |buffer, _| buffer.file().unwrap().path().clone())
        });

        let mut output = String::new();
        for (buffer, ranges) in results {
            buffer.read_with(cx, |buffer, cx| {
                let excerpts = ranges.into_iter().map(|range| {
                    let point_range = range.to_point(buffer);
                    if point_range.end.column > 0 {
                        Line(point_range.start.row)..Line(point_range.end.row + 1)
                    } else {
                        Line(point_range.start.row)..Line(point_range.end.row)
                    }
                });

                write_codeblock(
                    &buffer.file().unwrap().full_path(cx),
                    merge_excerpts(&buffer.snapshot(), excerpts).iter(),
                    &[],
                    Line(buffer.max_point().row),
                    false,
                    &mut output,
                );
            });
        }
        output.pop();

        assert_eq!(output, expected_output);
    }

    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
        .unwrap()
    }

    fn init_test(cx: &mut TestAppContext) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();
        });
    }
}