use anyhow::Result;
use cloud_zeta2_prompt::retrieval_prompt::SearchToolQuery;
use collections::HashMap;
use futures::{
    StreamExt,
    channel::mpsc::{self, UnboundedSender},
};
use gpui::{AppContext, AsyncApp, Entity};
use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt, Point, ToOffset, ToPoint};
use project::{
    Project, WorktreeSettings,
    search::{SearchQuery, SearchResult},
};
use smol::channel;
use std::ops::Range;
use util::{
    ResultExt as _,
    paths::{PathMatcher, PathStyle},
};
use workspace::item::Settings as _;

#[cfg(feature = "eval-support")]
type CachedSearchResults = std::collections::BTreeMap<std::path::PathBuf, Vec<Range<usize>>>;

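/// Runs the given search queries against the project concurrently and returns the matched
/// ranges per buffer, merged per buffer and capped at a combined size of `MAX_RESULTS_LEN` bytes.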
pub async fn run_retrieval_searches(
    queries: Vec<SearchToolQuery>,
    project: Entity<Project>,
    #[cfg(feature = "eval-support")] eval_cache: Option<std::sync::Arc<dyn crate::EvalCache>>,
    cx: &mut AsyncApp,
) -> Result<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>> {
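    // When running evals, results are cached keyed by a hash of the single worktree's absolute
    // path plus the queries. On a cache hit the stored offset ranges are re-anchored and returned
    // without searching; otherwise the key is kept so results can be written once the search ends.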
    #[cfg(feature = "eval-support")]
    let cache = if let Some(eval_cache) = eval_cache {
        use crate::EvalCacheEntryKind;
        use anyhow::Context;
        use collections::FxHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = FxHasher::default();
        project.read_with(cx, |project, cx| {
            let mut worktrees = project.worktrees(cx);
            let Some(worktree) = worktrees.next() else {
                panic!("Expected a single worktree in eval project. Found none.");
            };
            assert!(
                worktrees.next().is_none(),
                "Expected a single worktree in eval project. Found more than one."
            );
            worktree.read(cx).abs_path().hash(&mut hasher);
        })?;

        queries.hash(&mut hasher);
        let key = (EvalCacheEntryKind::Search, hasher.finish());

        if let Some(cached_results) = eval_cache.read(key) {
            let file_results = serde_json::from_str::<CachedSearchResults>(&cached_results)
                .context("Failed to deserialize cached search results")?;
            let mut results = HashMap::default();

            for (path, ranges) in file_results {
                let buffer = project
                    .update(cx, |project, cx| {
                        let project_path = project.find_project_path(path, cx).unwrap();
                        project.open_buffer(project_path, cx)
                    })?
                    .await?;
                let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
                let mut ranges: Vec<_> = ranges
                    .into_iter()
                    .map(|range| {
                        snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end)
                    })
                    .collect();
                merge_anchor_ranges(&mut ranges, &snapshot);
                results.insert(buffer, ranges);
            }

            return Ok(results);
        }

        Some((eval_cache, serde_json::to_string_pretty(&queries)?, key))
    } else {
        None
    };

    let (exclude_matcher, path_style) = project.update(cx, |project, cx| {
        let global_settings = WorktreeSettings::get_global(cx);
        let exclude_patterns = global_settings
            .file_scan_exclusions
            .sources()
            .chain(global_settings.private_files.sources());
        let path_style = project.path_style(cx);
        anyhow::Ok((PathMatcher::new(exclude_patterns, path_style)?, path_style))
    })??;

    let (results_tx, mut results_rx) = mpsc::unbounded();

    for query in queries {
        let exclude_matcher = exclude_matcher.clone();
        let results_tx = results_tx.clone();
        let project = project.clone();
        cx.spawn(async move |cx| {
            run_query(
                query,
                results_tx.clone(),
                path_style,
                exclude_matcher,
                &project,
                cx,
            )
            .await
            .log_err();
        })
        .detach()
    }
    drop(results_tx);

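    // Collect results on a background task, deduplicating snapshots per buffer and stopping once
    // the combined size of all excerpts exceeds `MAX_RESULTS_LEN`.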
    #[cfg(feature = "eval-support")]
    let cache = cache.clone();
    cx.background_spawn(async move {
        let mut results: HashMap<Entity<Buffer>, Vec<Range<Anchor>>> = HashMap::default();
        let mut snapshots = HashMap::default();

        let mut total_bytes = 0;
        'outer: while let Some((buffer, snapshot, excerpts)) = results_rx.next().await {
            snapshots.insert(buffer.entity_id(), snapshot);
            let existing = results.entry(buffer).or_default();
            existing.reserve(excerpts.len());

            for (range, size) in excerpts {
                // Blunt trimming of the results until we have a proper algorithmic filtering step
                if (total_bytes + size) > MAX_RESULTS_LEN {
                    log::trace!("Combined results reached limit of {MAX_RESULTS_LEN}B");
                    break 'outer;
                }
                total_bytes += size;
                existing.push(range);
            }
        }

        #[cfg(feature = "eval-support")]
        if let Some((cache, queries, key)) = cache {
            let cached_results: CachedSearchResults = results
                .iter()
                .filter_map(|(buffer, ranges)| {
                    let snapshot = snapshots.get(&buffer.entity_id())?;
                    let path = snapshot.file().map(|f| f.path());
                    let mut ranges = ranges
                        .iter()
                        .map(|range| range.to_offset(&snapshot))
                        .collect::<Vec<_>>();
                    ranges.sort_unstable_by_key(|range| (range.start, range.end));

                    Some((path?.as_std_path().to_path_buf(), ranges))
                })
                .collect();
            cache.write(
                key,
                &queries,
                &serde_json::to_string_pretty(&cached_results)?,
            );
        }

        for (buffer, ranges) in results.iter_mut() {
            if let Some(snapshot) = snapshots.get(&buffer.entity_id()) {
                merge_anchor_ranges(ranges, snapshot);
            }
        }

        Ok(results)
    })
    .await
}

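/// Sorts `ranges` by position and merges any that overlap or touch, in place.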
pub(crate) fn merge_anchor_ranges(ranges: &mut Vec<Range<Anchor>>, snapshot: &BufferSnapshot) {
    ranges.sort_unstable_by(|a, b| {
        a.start
            .cmp(&b.start, snapshot)
            .then(b.end.cmp(&a.end, snapshot))
    });

    let mut index = 1;
    while index < ranges.len() {
        if ranges[index - 1]
            .end
            .cmp(&ranges[index].start, snapshot)
            .is_ge()
        {
            let removed = ranges.remove(index);
            if removed.end.cmp(&ranges[index - 1].end, snapshot).is_gt() {
                ranges[index - 1].end = removed.end;
            }
        } else {
            index += 1;
        }
    }
}

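// Byte budgets for retrieved context; `MAX_RESULTS_LEN` caps the combined size of all excerpts
// collected in `run_retrieval_searches`.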
const MAX_EXCERPT_LEN: usize = 768;
const MAX_RESULTS_LEN: usize = MAX_EXCERPT_LEN * 5;

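/// A unit of work for the nested syntax-node search: candidate ranges in one buffer, the index
/// of the next nested query to apply, and a sender used to requeue the job for the next level.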
struct SearchJob {
    buffer: Entity<Buffer>,
    snapshot: BufferSnapshot,
    ranges: Vec<Range<usize>>,
    query_ix: usize,
    jobs_tx: channel::Sender<SearchJob>,
}

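/// Runs a single `SearchToolQuery`. When `syntax_node` regexes are present, the first one is
/// searched project-wide and each match is expanded to its enclosing syntax node; the remaining
/// regexes (and the optional `content` regex) are then applied within those nodes. With only a
/// `content` regex, a plain project-wide search is performed. Results are sent on `results_tx`.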
async fn run_query(
    input_query: SearchToolQuery,
    results_tx: UnboundedSender<(Entity<Buffer>, BufferSnapshot, Vec<(Range<Anchor>, usize)>)>,
    path_style: PathStyle,
    exclude_matcher: PathMatcher,
    project: &Entity<Project>,
    cx: &mut AsyncApp,
) -> Result<()> {
    let include_matcher = PathMatcher::new(vec![input_query.glob], path_style)?;

    let make_search = |regex: &str| -> Result<SearchQuery> {
        SearchQuery::regex(
            regex,
            false,
            true,
            false,
            true,
            include_matcher.clone(),
            exclude_matcher.clone(),
            true,
            None,
        )
    };

    if let Some(outer_syntax_regex) = input_query.syntax_node.first() {
        let outer_syntax_query = make_search(outer_syntax_regex)?;
        let nested_syntax_queries = input_query
            .syntax_node
            .into_iter()
            .skip(1)
            .map(|query| make_search(&query))
            .collect::<Result<Vec<_>>>()?;
        let content_query = input_query
            .content
            .map(|regex| make_search(&regex))
            .transpose()?;

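        // Fan out: the outer query's matches are expanded to their enclosing syntax nodes and
        // pushed onto a job channel, which background workers drain to apply the nested queries.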
        let (jobs_tx, jobs_rx) = channel::unbounded();

        let outer_search_results_rx =
            project.update(cx, |project, cx| project.search(outer_syntax_query, cx))?;

        let outer_search_task = cx.spawn(async move |cx| {
            futures::pin_mut!(outer_search_results_rx);
            while let Some(SearchResult::Buffer { buffer, ranges }) =
                outer_search_results_rx.next().await
            {
                buffer
                    .read_with(cx, |buffer, _| buffer.parsing_idle())?
                    .await;
                let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
                let expanded_ranges: Vec<_> = ranges
                    .into_iter()
                    .filter_map(|range| expand_to_parent_range(&range, &snapshot))
                    .collect();
                jobs_tx
                    .send(SearchJob {
                        buffer,
                        snapshot,
                        ranges: expanded_ranges,
                        query_ix: 0,
                        jobs_tx: jobs_tx.clone(),
                    })
                    .await?;
            }
            anyhow::Ok(())
        });

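        // Process jobs on a pool of background workers, one per available CPU.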
        let n_workers = cx.background_executor().num_cpus();
        let search_job_task = cx.background_executor().scoped(|scope| {
            for _ in 0..n_workers {
                scope.spawn(async {
                    while let Ok(job) = jobs_rx.recv().await {
                        process_nested_search_job(
                            &results_tx,
                            &nested_syntax_queries,
                            &content_query,
                            job,
                        )
                        .await;
                    }
                });
            }
        });

        search_job_task.await;
        outer_search_task.await?;
    } else if let Some(content_regex) = &input_query.content {
        let search_query = make_search(&content_regex)?;

        let results_rx = project.update(cx, |project, cx| project.search(search_query, cx))?;
        futures::pin_mut!(results_rx);

        while let Some(SearchResult::Buffer { buffer, ranges }) = results_rx.next().await {
            let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

            let ranges = ranges
                .into_iter()
                .map(|range| {
                    let range = range.to_offset(&snapshot);
                    let range = expand_to_entire_lines(range, &snapshot);
                    let size = range.len();
                    let range =
                        snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end);
                    (range, size)
                })
                .collect();

            let send_result = results_tx.unbounded_send((buffer.clone(), snapshot.clone(), ranges));

            if let Err(err) = send_result
                && !err.is_disconnected()
            {
                log::error!("{err}");
            }
        }
    } else {
        log::warn!("Context gathering model produced a glob-only search");
    }

    anyhow::Ok(())
}

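/// Applies the next nested syntax query (or the final content query) within the job's ranges.
/// Intermediate levels requeue a new job via `job.jobs_tx`; the last level expands matches to
/// whole lines and sends them on `results_tx` along with their byte sizes.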
async fn process_nested_search_job(
    results_tx: &UnboundedSender<(Entity<Buffer>, BufferSnapshot, Vec<(Range<Anchor>, usize)>)>,
    queries: &Vec<SearchQuery>,
    content_query: &Option<SearchQuery>,
    job: SearchJob,
) {
    if let Some(search_query) = queries.get(job.query_ix) {
        let mut subranges = Vec::new();
        for range in job.ranges {
            let start = range.start;
            let search_results = search_query.search(&job.snapshot, Some(range)).await;
            for subrange in search_results {
                let subrange = start + subrange.start..start + subrange.end;
                subranges.extend(expand_to_parent_range(&subrange, &job.snapshot));
            }
        }
        job.jobs_tx
            .send(SearchJob {
                buffer: job.buffer,
                snapshot: job.snapshot,
                ranges: subranges,
                query_ix: job.query_ix + 1,
                jobs_tx: job.jobs_tx.clone(),
            })
            .await
            .ok();
    } else {
        let ranges = if let Some(content_query) = content_query {
            let mut subranges = Vec::new();
            for range in job.ranges {
                let start = range.start;
                let search_results = content_query.search(&job.snapshot, Some(range)).await;
                for subrange in search_results {
                    let subrange = start + subrange.start..start + subrange.end;
                    subranges.push(subrange);
                }
            }
            subranges
        } else {
            job.ranges
        };

        let matches = ranges
            .into_iter()
            .map(|range| {
                let snapshot = &job.snapshot;
                let range = expand_to_entire_lines(range, snapshot);
                let size = range.len();
                let range = snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end);
                (range, size)
            })
            .collect();

        let send_result = results_tx.unbounded_send((job.buffer, job.snapshot, matches));

        if let Err(err) = send_result
            && !err.is_disconnected()
        {
            log::error!("{err}");
        }
    }
}

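/// Expands an offset range so it starts at the beginning of its first line and ends at the start
/// of the line following its last line (clamped to the end of the buffer).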
fn expand_to_entire_lines(range: Range<usize>, snapshot: &BufferSnapshot) -> Range<usize> {
    let mut point_range = range.to_point(snapshot);
    point_range.start.column = 0;
    if point_range.end.column > 0 {
        point_range.end = snapshot.max_point().min(point_range.end + Point::new(1, 0));
    }
    point_range.to_offset(snapshot)
}

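/// Returns the byte range of the smallest syntax node that encloses the match's lines (measured
/// from the first non-whitespace column of the start line to the end of the last line), or `None`
/// if no such node exists.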
fn expand_to_parent_range<T: ToPoint + ToOffset>(
    range: &Range<T>,
    snapshot: &BufferSnapshot,
) -> Option<Range<usize>> {
    let mut line_range = range.to_point(&snapshot);
    line_range.start.column = snapshot.indent_size_for_line(line_range.start.row).len;
    line_range.end.column = snapshot.line_len(line_range.end.row);
    // TODO skip result if matched line isn't the first node line?

    let node = snapshot.syntax_ancestor(line_range)?;
    Some(node.byte_range())
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::assemble_excerpts::assemble_excerpts;
    use cloud_zeta2_prompt::write_codeblock;
    use edit_prediction_context::Line;
    use gpui::TestAppContext;
    use indoc::indoc;
    use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
    use pretty_assertions::assert_eq;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;
    use util::path;

    #[gpui::test]
    async fn test_retrieval(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "user.rs": indoc!{"
                    pub struct Organization {
                        owner: Arc<User>,
                    }

                    pub struct User {
                        first_name: String,
                        last_name: String,
                    }

                    impl Organization {
                        pub fn owner(&self) -> Arc<User> {
                            self.owner.clone()
                        }
                    }

                    impl User {
                        pub fn new(first_name: String, last_name: String) -> Self {
                            Self {
                                first_name,
                                last_name
                            }
                        }

                        pub fn first_name(&self) -> String {
                            self.first_name.clone()
                        }

                        pub fn last_name(&self) -> String {
                            self.last_name.clone()
                        }
                    }
                "},
                "main.rs": indoc!{r#"
                    fn main() {
                        let user = User::new(FIRST_NAME.clone(), "doe".into());
                        println!("user {:?}", user);
                    }
                "#},
            }),
        )
        .await;

        let project = Project::test(fs, vec![Path::new(path!("/root"))], cx).await;
        project.update(cx, |project, _cx| {
            project.languages().add(rust_lang().into())
        });

        assert_results(
            &project,
            SearchToolQuery {
                glob: "user.rs".into(),
                syntax_node: vec!["impl\\s+User".into(), "pub\\s+fn\\s+first_name".into()],
                content: None,
            },
            indoc! {r#"
                `````root/user.rs
                …
                impl User {
                …
                    pub fn first_name(&self) -> String {
                        self.first_name.clone()
                    }
                …
                `````
            "#},
            cx,
        )
        .await;

        assert_results(
            &project,
            SearchToolQuery {
                glob: "user.rs".into(),
                syntax_node: vec!["impl\\s+User".into()],
                content: Some("\\.clone".into()),
            },
            indoc! {r#"
                `````root/user.rs
                …
                impl User {
                …
                    pub fn first_name(&self) -> String {
                        self.first_name.clone()
                …
                    pub fn last_name(&self) -> String {
                        self.last_name.clone()
                …
                `````
            "#},
            cx,
        )
        .await;

        assert_results(
            &project,
            SearchToolQuery {
                glob: "*.rs".into(),
                syntax_node: vec![],
                content: Some("\\.clone".into()),
            },
            indoc! {r#"
                `````root/main.rs
                fn main() {
                    let user = User::new(FIRST_NAME.clone(), "doe".into());
                …
                `````

                `````root/user.rs
                …
                impl Organization {
                    pub fn owner(&self) -> Arc<User> {
                        self.owner.clone()
                …
                impl User {
                …
                    pub fn first_name(&self) -> String {
                        self.first_name.clone()
                …
                    pub fn last_name(&self) -> String {
                        self.last_name.clone()
                …
                `````
            "#},
            cx,
        )
        .await;
    }

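    /// Runs a single query through `run_retrieval_searches` and asserts that the resulting
    /// excerpts, rendered as codeblocks, match `expected_output`.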
    async fn assert_results(
        project: &Entity<Project>,
        query: SearchToolQuery,
        expected_output: &str,
        cx: &mut TestAppContext,
    ) {
        let results = run_retrieval_searches(
            vec![query],
            project.clone(),
            #[cfg(feature = "eval-support")]
            None,
            &mut cx.to_async(),
        )
        .await
        .unwrap();

        let mut results = results.into_iter().collect::<Vec<_>>();
        results.sort_by_key(|results| {
            results
                .0
                .read_with(cx, |buffer, _| buffer.file().unwrap().path().clone())
        });

        let mut output = String::new();
        for (buffer, ranges) in results {
            buffer.read_with(cx, |buffer, cx| {
                let excerpts = ranges.into_iter().map(|range| {
                    let point_range = range.to_point(buffer);
                    if point_range.end.column > 0 {
                        Line(point_range.start.row)..Line(point_range.end.row + 1)
                    } else {
                        Line(point_range.start.row)..Line(point_range.end.row)
                    }
                });

                write_codeblock(
                    &buffer.file().unwrap().full_path(cx),
                    assemble_excerpts(&buffer.snapshot(), excerpts).iter(),
                    &[],
                    Line(buffer.max_point().row),
                    false,
                    &mut output,
                );
            });
        }
        output.pop();

        assert_eq!(output, expected_output);
    }

    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
        .unwrap()
    }

    fn init_test(cx: &mut TestAppContext) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();
        });
    }
}