1use anyhow::Result;
2use cloud_zeta2_prompt::retrieval_prompt::SearchToolQuery;
3use collections::HashMap;
4use futures::{
5 StreamExt,
6 channel::mpsc::{self, UnboundedSender},
7};
8use gpui::{AppContext, AsyncApp, Entity};
9use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt, Point, ToOffset, ToPoint};
10use project::{
11 Project, WorktreeSettings,
12 search::{SearchQuery, SearchResult},
13};
14use smol::channel;
15use std::ops::Range;
16use util::{
17 ResultExt as _,
18 paths::{PathMatcher, PathStyle},
19};
20use workspace::item::Settings as _;
21
/// On-disk shape of cached eval search results: sorted byte-offset ranges
/// keyed by file path (BTreeMap keeps the serialized form deterministic).
#[cfg(feature = "eval-support")]
type CachedSearchResults = std::collections::BTreeMap<std::path::PathBuf, Vec<Range<usize>>>;
24
/// Runs all retrieval `queries` against `project` concurrently and returns
/// the matched ranges grouped by buffer, coalesced per buffer and trimmed to
/// a global byte budget (`MAX_RESULTS_LEN`).
///
/// Each query runs on its own detached task (see [`run_query`]); results are
/// funneled through an unbounded channel and accumulated on a background task.
pub async fn run_retrieval_searches(
    queries: Vec<SearchToolQuery>,
    project: Entity<Project>,
    #[cfg(feature = "eval-support")] eval_cache: Option<std::sync::Arc<dyn crate::EvalCache>>,
    cx: &mut AsyncApp,
) -> Result<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>> {
    // In eval runs, search results are cached keyed by a hash of the single
    // worktree's absolute path plus the queries, so identical searches are
    // served from the cache instead of re-running the project search.
    #[cfg(feature = "eval-support")]
    let cache = if let Some(eval_cache) = eval_cache {
        use crate::EvalCacheEntryKind;
        use anyhow::Context;
        use collections::FxHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = FxHasher::default();
        project.read_with(cx, |project, cx| {
            // Eval projects are required to have exactly one worktree; the
            // cache key is derived from its absolute path.
            let mut worktrees = project.worktrees(cx);
            let Some(worktree) = worktrees.next() else {
                panic!("Expected a single worktree in eval project. Found none.");
            };
            assert!(
                worktrees.next().is_none(),
                "Expected a single worktree in eval project. Found more than one."
            );
            worktree.read(cx).abs_path().hash(&mut hasher);
        })?;

        queries.hash(&mut hasher);
        let key = (EvalCacheEntryKind::Search, hasher.finish());

        // Cache hit: rehydrate the stored offset ranges into anchors on
        // freshly opened buffers and return immediately.
        if let Some(cached_results) = eval_cache.read(key) {
            let file_results = serde_json::from_str::<CachedSearchResults>(&cached_results)
                .context("Failed to deserialize cached search results")?;
            let mut results = HashMap::default();

            for (path, ranges) in file_results {
                let buffer = project
                    .update(cx, |project, cx| {
                        let project_path = project.find_project_path(path, cx).unwrap();
                        project.open_buffer(project_path, cx)
                    })?
                    .await?;
                let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
                let mut ranges: Vec<_> = ranges
                    .into_iter()
                    .map(|range| {
                        snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end)
                    })
                    .collect();
                merge_anchor_ranges(&mut ranges, &snapshot);
                results.insert(buffer, ranges);
            }

            return Ok(results);
        }

        // Cache miss: remember what to write once the search completes.
        Some((eval_cache, serde_json::to_string_pretty(&queries)?, key))
    } else {
        None
    };

    // One matcher for globally excluded paths (file-scan exclusions plus
    // private files), shared by every query.
    let (exclude_matcher, path_style) = project.update(cx, |project, cx| {
        let global_settings = WorktreeSettings::get_global(cx);
        let exclude_patterns = global_settings
            .file_scan_exclusions
            .sources()
            .iter()
            .chain(global_settings.private_files.sources().iter());
        let path_style = project.path_style(cx);
        anyhow::Ok((PathMatcher::new(exclude_patterns, path_style)?, path_style))
    })??;

    let (results_tx, mut results_rx) = mpsc::unbounded();

    // Spawn every query on its own detached task; individual failures are
    // logged rather than aborting the other queries.
    for query in queries {
        let exclude_matcher = exclude_matcher.clone();
        let results_tx = results_tx.clone();
        let project = project.clone();
        cx.spawn(async move |cx| {
            run_query(
                query,
                results_tx.clone(),
                path_style,
                exclude_matcher,
                &project,
                cx,
            )
            .await
            .log_err();
        })
        .detach()
    }
    // Drop the original sender so `results_rx` terminates once every query
    // task has finished and released its clone.
    drop(results_tx);

    // NOTE(review): this clone looks redundant — `cache` is moved into the
    // task below and never used again afterwards.
    #[cfg(feature = "eval-support")]
    let cache = cache.clone();
    cx.background_spawn(async move {
        let mut results: HashMap<Entity<Buffer>, Vec<Range<Anchor>>> = HashMap::default();
        let mut snapshots = HashMap::default();

        let mut total_bytes = 0;
        'outer: while let Some((buffer, snapshot, excerpts)) = results_rx.next().await {
            snapshots.insert(buffer.entity_id(), snapshot);
            let existing = results.entry(buffer).or_default();
            existing.reserve(excerpts.len());

            for (range, size) in excerpts {
                // Blunt trimming of the results until we have a proper algorithmic filtering step
                if (total_bytes + size) > MAX_RESULTS_LEN {
                    log::trace!("Combined results reached limit of {MAX_RESULTS_LEN}B");
                    break 'outer;
                }
                total_bytes += size;
                existing.push(range);
            }
        }

        // Persist the gathered results (as sorted byte-offset ranges per
        // path) for future eval runs.
        #[cfg(feature = "eval-support")]
        if let Some((cache, queries, key)) = cache {
            let cached_results: CachedSearchResults = results
                .iter()
                .filter_map(|(buffer, ranges)| {
                    let snapshot = snapshots.get(&buffer.entity_id())?;
                    let path = snapshot.file().map(|f| f.path());
                    let mut ranges = ranges
                        .iter()
                        .map(|range| range.to_offset(&snapshot))
                        .collect::<Vec<_>>();
                    ranges.sort_unstable_by_key(|range| (range.start, range.end));

                    Some((path?.as_std_path().to_path_buf(), ranges))
                })
                .collect();
            cache.write(
                key,
                &queries,
                &serde_json::to_string_pretty(&cached_results)?,
            );
        }

        // Coalesce overlapping ranges within each buffer before returning.
        for (buffer, ranges) in results.iter_mut() {
            if let Some(snapshot) = snapshots.get(&buffer.entity_id()) {
                merge_anchor_ranges(ranges, snapshot);
            }
        }

        Ok(results)
    })
    .await
}
174
175pub(crate) fn merge_anchor_ranges(ranges: &mut Vec<Range<Anchor>>, snapshot: &BufferSnapshot) {
176 ranges.sort_unstable_by(|a, b| {
177 a.start
178 .cmp(&b.start, snapshot)
179 .then(b.end.cmp(&a.end, snapshot))
180 });
181
182 let mut index = 1;
183 while index < ranges.len() {
184 if ranges[index - 1]
185 .end
186 .cmp(&ranges[index].start, snapshot)
187 .is_ge()
188 {
189 let removed = ranges.remove(index);
190 if removed.end.cmp(&ranges[index - 1].end, snapshot).is_gt() {
191 ranges[index - 1].end = removed.end;
192 }
193 } else {
194 index += 1;
195 }
196 }
197}
198
/// Nominal per-excerpt size in bytes; within this file it only serves to
/// derive `MAX_RESULTS_LEN`.
const MAX_EXCERPT_LEN: usize = 768;
/// Total byte budget across all excerpts gathered by `run_retrieval_searches`.
const MAX_RESULTS_LEN: usize = MAX_EXCERPT_LEN * 5;
201
/// A unit of work for the nested syntax-search worker pool: a set of
/// candidate ranges within one buffer, to be refined by the next query level.
struct SearchJob {
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    /// Byte ranges still in play for this buffer at the current query level.
    ranges: Vec<Range<usize>>,
    /// Index of the next nested syntax query to apply; past the end, only
    /// the optional content query remains.
    query_ix: usize,
    /// Sender used to re-queue this job for the next query level.
    jobs_tx: channel::Sender<SearchJob>,
}
209
210async fn run_query(
211 input_query: SearchToolQuery,
212 results_tx: UnboundedSender<(Entity<Buffer>, BufferSnapshot, Vec<(Range<Anchor>, usize)>)>,
213 path_style: PathStyle,
214 exclude_matcher: PathMatcher,
215 project: &Entity<Project>,
216 cx: &mut AsyncApp,
217) -> Result<()> {
218 let include_matcher = PathMatcher::new(vec![input_query.glob], path_style)?;
219
220 let make_search = |regex: &str| -> Result<SearchQuery> {
221 SearchQuery::regex(
222 regex,
223 false,
224 true,
225 false,
226 true,
227 include_matcher.clone(),
228 exclude_matcher.clone(),
229 true,
230 None,
231 )
232 };
233
234 if let Some(outer_syntax_regex) = input_query.syntax_node.first() {
235 let outer_syntax_query = make_search(outer_syntax_regex)?;
236 let nested_syntax_queries = input_query
237 .syntax_node
238 .into_iter()
239 .skip(1)
240 .map(|query| make_search(&query))
241 .collect::<Result<Vec<_>>>()?;
242 let content_query = input_query
243 .content
244 .map(|regex| make_search(®ex))
245 .transpose()?;
246
247 let (jobs_tx, jobs_rx) = channel::unbounded();
248
249 let outer_search_results_rx =
250 project.update(cx, |project, cx| project.search(outer_syntax_query, cx))?;
251
252 let outer_search_task = cx.spawn(async move |cx| {
253 futures::pin_mut!(outer_search_results_rx);
254 while let Some(SearchResult::Buffer { buffer, ranges }) =
255 outer_search_results_rx.next().await
256 {
257 buffer
258 .read_with(cx, |buffer, _| buffer.parsing_idle())?
259 .await;
260 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
261 let expanded_ranges: Vec<_> = ranges
262 .into_iter()
263 .filter_map(|range| expand_to_parent_range(&range, &snapshot))
264 .collect();
265 jobs_tx
266 .send(SearchJob {
267 buffer,
268 snapshot,
269 ranges: expanded_ranges,
270 query_ix: 0,
271 jobs_tx: jobs_tx.clone(),
272 })
273 .await?;
274 }
275 anyhow::Ok(())
276 });
277
278 let n_workers = cx.background_executor().num_cpus();
279 let search_job_task = cx.background_executor().scoped(|scope| {
280 for _ in 0..n_workers {
281 scope.spawn(async {
282 while let Ok(job) = jobs_rx.recv().await {
283 process_nested_search_job(
284 &results_tx,
285 &nested_syntax_queries,
286 &content_query,
287 job,
288 )
289 .await;
290 }
291 });
292 }
293 });
294
295 search_job_task.await;
296 outer_search_task.await?;
297 } else if let Some(content_regex) = &input_query.content {
298 let search_query = make_search(&content_regex)?;
299
300 let results_rx = project.update(cx, |project, cx| project.search(search_query, cx))?;
301 futures::pin_mut!(results_rx);
302
303 while let Some(SearchResult::Buffer { buffer, ranges }) = results_rx.next().await {
304 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
305
306 let ranges = ranges
307 .into_iter()
308 .map(|range| {
309 let range = range.to_offset(&snapshot);
310 let range = expand_to_entire_lines(range, &snapshot);
311 let size = range.len();
312 let range =
313 snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end);
314 (range, size)
315 })
316 .collect();
317
318 let send_result = results_tx.unbounded_send((buffer.clone(), snapshot.clone(), ranges));
319
320 if let Err(err) = send_result
321 && !err.is_disconnected()
322 {
323 log::error!("{err}");
324 }
325 }
326 } else {
327 log::warn!("Context gathering model produced a glob-only search");
328 }
329
330 anyhow::Ok(())
331}
332
333async fn process_nested_search_job(
334 results_tx: &UnboundedSender<(Entity<Buffer>, BufferSnapshot, Vec<(Range<Anchor>, usize)>)>,
335 queries: &Vec<SearchQuery>,
336 content_query: &Option<SearchQuery>,
337 job: SearchJob,
338) {
339 if let Some(search_query) = queries.get(job.query_ix) {
340 let mut subranges = Vec::new();
341 for range in job.ranges {
342 let start = range.start;
343 let search_results = search_query.search(&job.snapshot, Some(range)).await;
344 for subrange in search_results {
345 let subrange = start + subrange.start..start + subrange.end;
346 subranges.extend(expand_to_parent_range(&subrange, &job.snapshot));
347 }
348 }
349 job.jobs_tx
350 .send(SearchJob {
351 buffer: job.buffer,
352 snapshot: job.snapshot,
353 ranges: subranges,
354 query_ix: job.query_ix + 1,
355 jobs_tx: job.jobs_tx.clone(),
356 })
357 .await
358 .ok();
359 } else {
360 let ranges = if let Some(content_query) = content_query {
361 let mut subranges = Vec::new();
362 for range in job.ranges {
363 let start = range.start;
364 let search_results = content_query.search(&job.snapshot, Some(range)).await;
365 for subrange in search_results {
366 let subrange = start + subrange.start..start + subrange.end;
367 subranges.push(subrange);
368 }
369 }
370 subranges
371 } else {
372 job.ranges
373 };
374
375 let matches = ranges
376 .into_iter()
377 .map(|range| {
378 let snapshot = &job.snapshot;
379 let range = expand_to_entire_lines(range, snapshot);
380 let size = range.len();
381 let range = snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end);
382 (range, size)
383 })
384 .collect();
385
386 let send_result = results_tx.unbounded_send((job.buffer, job.snapshot, matches));
387
388 if let Err(err) = send_result
389 && !err.is_disconnected()
390 {
391 log::error!("{err}");
392 }
393 }
394}
395
396fn expand_to_entire_lines(range: Range<usize>, snapshot: &BufferSnapshot) -> Range<usize> {
397 let mut point_range = range.to_point(snapshot);
398 point_range.start.column = 0;
399 if point_range.end.column > 0 {
400 point_range.end = snapshot.max_point().min(point_range.end + Point::new(1, 0));
401 }
402 point_range.to_offset(snapshot)
403}
404
405fn expand_to_parent_range<T: ToPoint + ToOffset>(
406 range: &Range<T>,
407 snapshot: &BufferSnapshot,
408) -> Option<Range<usize>> {
409 let mut line_range = range.to_point(&snapshot);
410 line_range.start.column = snapshot.indent_size_for_line(line_range.start.row).len;
411 line_range.end.column = snapshot.line_len(line_range.end.row);
412 // TODO skip result if matched line isn't the first node line?
413
414 let node = snapshot.syntax_ancestor(line_range)?;
415 Some(node.byte_range())
416}
417
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assemble_excerpts::assemble_excerpts;
    use cloud_zeta2_prompt::write_codeblock;
    use edit_prediction_context::Line;
    use gpui::TestAppContext;
    use indoc::indoc;
    use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
    use pretty_assertions::assert_eq;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;
    use util::path;

    // End-to-end coverage of the three query modes: nested syntax queries,
    // syntax + content refinement, and content-only search across files.
    #[gpui::test]
    async fn test_retrieval(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "user.rs": indoc!{"
                    pub struct Organization {
                        owner: Arc<User>,
                    }

                    pub struct User {
                        first_name: String,
                        last_name: String,
                    }

                    impl Organization {
                        pub fn owner(&self) -> Arc<User> {
                            self.owner.clone()
                        }
                    }

                    impl User {
                        pub fn new(first_name: String, last_name: String) -> Self {
                            Self {
                                first_name,
                                last_name
                            }
                        }

                        pub fn first_name(&self) -> String {
                            self.first_name.clone()
                        }

                        pub fn last_name(&self) -> String {
                            self.last_name.clone()
                        }
                    }
                "},
                "main.rs": indoc!{r#"
                    fn main() {
                        let user = User::new(FIRST_NAME.clone(), "doe".into());
                        println!("user {:?}", user);
                    }
                "#},
            }),
        )
        .await;

        let project = Project::test(fs, vec![Path::new(path!("/root"))], cx).await;
        project.update(cx, |project, _cx| {
            project.languages().add(rust_lang().into())
        });

        // Two nested syntax queries: match the method node inside `impl User`.
        assert_results(
            &project,
            SearchToolQuery {
                glob: "user.rs".into(),
                syntax_node: vec!["impl\\s+User".into(), "pub\\s+fn\\s+first_name".into()],
                content: None,
            },
            indoc! {r#"
                `````root/user.rs
                …
                impl User {
                …
                    pub fn first_name(&self) -> String {
                        self.first_name.clone()
                    }
                …
                `````
            "#},
            cx,
        )
        .await;

        // Syntax query narrowed by a content regex within the matched node.
        assert_results(
            &project,
            SearchToolQuery {
                glob: "user.rs".into(),
                syntax_node: vec!["impl\\s+User".into()],
                content: Some("\\.clone".into()),
            },
            indoc! {r#"
                `````root/user.rs
                …
                impl User {
                …
                    pub fn first_name(&self) -> String {
                        self.first_name.clone()
                …
                    pub fn last_name(&self) -> String {
                        self.last_name.clone()
                …
                `````
            "#},
            cx,
        )
        .await;

        // Content-only search across every `.rs` file in the project.
        assert_results(
            &project,
            SearchToolQuery {
                glob: "*.rs".into(),
                syntax_node: vec![],
                content: Some("\\.clone".into()),
            },
            indoc! {r#"
                `````root/main.rs
                fn main() {
                    let user = User::new(FIRST_NAME.clone(), "doe".into());
                …
                `````

                `````root/user.rs
                …
                impl Organization {
                    pub fn owner(&self) -> Arc<User> {
                        self.owner.clone()
                …
                impl User {
                …
                    pub fn first_name(&self) -> String {
                        self.first_name.clone()
                …
                    pub fn last_name(&self) -> String {
                        self.last_name.clone()
                …
                `````
            "#},
            cx,
        )
        .await;
    }

    // Runs a single query end to end and compares the rendered excerpts
    // (code blocks with `…` gap markers) against `expected_output`.
    async fn assert_results(
        project: &Entity<Project>,
        query: SearchToolQuery,
        expected_output: &str,
        cx: &mut TestAppContext,
    ) {
        let results = run_retrieval_searches(
            vec![query],
            project.clone(),
            #[cfg(feature = "eval-support")]
            None,
            &mut cx.to_async(),
        )
        .await
        .unwrap();

        // Sort buffers by path so the rendered output is deterministic.
        let mut results = results.into_iter().collect::<Vec<_>>();
        results.sort_by_key(|results| {
            results
                .0
                .read_with(cx, |buffer, _| buffer.file().unwrap().path().clone())
        });

        let mut output = String::new();
        for (buffer, ranges) in results {
            buffer.read_with(cx, |buffer, cx| {
                // Convert anchor ranges to line ranges, rounding a mid-line
                // end up to include that whole line.
                let excerpts = ranges.into_iter().map(|range| {
                    let point_range = range.to_point(buffer);
                    if point_range.end.column > 0 {
                        Line(point_range.start.row)..Line(point_range.end.row + 1)
                    } else {
                        Line(point_range.start.row)..Line(point_range.end.row)
                    }
                });

                write_codeblock(
                    &buffer.file().unwrap().full_path(cx),
                    assemble_excerpts(&buffer.snapshot(), excerpts).iter(),
                    &[],
                    Line(buffer.max_point().row),
                    false,
                    &mut output,
                );
            });
        }
        // Drop the trailing newline added by the last code block.
        output.pop();

        assert_eq!(output, expected_output);
    }

    // Minimal Rust language definition with an outline query, enough for
    // syntax-node expansion to work in tests.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
        .unwrap()
    }

    // Installs the global settings store and test logging required by the
    // project/search machinery.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();
        });
    }
}