1use std::ops::Range;
2
3use anyhow::Result;
4use cloud_zeta2_prompt::retrieval_prompt::SearchToolQuery;
5use collections::HashMap;
6use futures::{
7 StreamExt,
8 channel::mpsc::{self, UnboundedSender},
9};
10use gpui::{AppContext, AsyncApp, Entity};
11use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt, Point, ToOffset, ToPoint};
12use project::{
13 Project, WorktreeSettings,
14 search::{SearchQuery, SearchResult},
15};
16use smol::channel;
17use util::{
18 ResultExt as _,
19 paths::{PathMatcher, PathStyle},
20};
21use workspace::item::Settings as _;
22
23pub async fn run_retrieval_searches(
24 project: Entity<Project>,
25 queries: Vec<SearchToolQuery>,
26 cx: &mut AsyncApp,
27) -> Result<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>> {
28 let (exclude_matcher, path_style) = project.update(cx, |project, cx| {
29 let global_settings = WorktreeSettings::get_global(cx);
30 let exclude_patterns = global_settings
31 .file_scan_exclusions
32 .sources()
33 .iter()
34 .chain(global_settings.private_files.sources().iter());
35 let path_style = project.path_style(cx);
36 anyhow::Ok((PathMatcher::new(exclude_patterns, path_style)?, path_style))
37 })??;
38
39 let (results_tx, mut results_rx) = mpsc::unbounded();
40
41 for query in queries {
42 let exclude_matcher = exclude_matcher.clone();
43 let results_tx = results_tx.clone();
44 let project = project.clone();
45 cx.spawn(async move |cx| {
46 run_query(
47 query,
48 results_tx.clone(),
49 path_style,
50 exclude_matcher,
51 &project,
52 cx,
53 )
54 .await
55 .log_err();
56 })
57 .detach()
58 }
59 drop(results_tx);
60
61 cx.background_spawn(async move {
62 let mut results: HashMap<Entity<Buffer>, Vec<Range<Anchor>>> = HashMap::default();
63 let mut snapshots = HashMap::default();
64
65 let mut total_bytes = 0;
66 'outer: while let Some((buffer, snapshot, excerpts)) = results_rx.next().await {
67 snapshots.insert(buffer.entity_id(), snapshot);
68 let existing = results.entry(buffer).or_default();
69 existing.reserve(excerpts.len());
70
71 for (range, size) in excerpts {
72 // Blunt trimming of the results until we have a proper algorithmic filtering step
73 if (total_bytes + size) > MAX_RESULTS_LEN {
74 log::trace!("Combined results reached limit of {MAX_RESULTS_LEN}B");
75 break 'outer;
76 }
77 total_bytes += size;
78 existing.push(range);
79 }
80 }
81
82 for (buffer, ranges) in results.iter_mut() {
83 if let Some(snapshot) = snapshots.get(&buffer.entity_id()) {
84 ranges.sort_unstable_by(|a, b| {
85 a.start
86 .cmp(&b.start, snapshot)
87 .then(b.end.cmp(&b.end, snapshot))
88 });
89
90 let mut index = 1;
91 while index < ranges.len() {
92 if ranges[index - 1]
93 .end
94 .cmp(&ranges[index].start, snapshot)
95 .is_gt()
96 {
97 let removed = ranges.remove(index);
98 ranges[index - 1].end = removed.end;
99 } else {
100 index += 1;
101 }
102 }
103 }
104 }
105
106 Ok(results)
107 })
108 .await
109}
110
/// Base unit, in bytes, from which `MAX_RESULTS_LEN` is derived.
// NOTE(review): presumably the intended size budget of a single excerpt; in this file it is
// only used to compute `MAX_RESULTS_LEN` — confirm against other users.
111const MAX_EXCERPT_LEN: usize = 768;
/// Upper bound, in bytes, on the combined size of all excerpts that
/// `run_retrieval_searches` accepts before it stops collecting results.
112const MAX_RESULTS_LEN: usize = MAX_EXCERPT_LEN * 5;
113
/// One unit of work for the nested syntax-node search: a set of byte ranges within a buffer
/// snapshot that still have to be narrowed by the query at `query_ix`.
114struct SearchJob {
    /// Buffer the ranges belong to; forwarded with the final results.
 115 buffer: Entity<Buffer>,
    /// Snapshot of `buffer` that `ranges` are offsets into and that searches run against.
 116 snapshot: BufferSnapshot,
    /// Byte ranges to search within at this nesting level.
 117 ranges: Vec<Range<usize>>,
    /// Index of the next nested syntax query to apply; once it is past the last query the
    /// job is finalized (optionally filtered by the content query) and sent as a result.
 118 query_ix: usize,
    /// Sender for the job queue, used to enqueue the follow-up job for the next level.
 119 jobs_tx: channel::Sender<SearchJob>,
 120}
121
122async fn run_query(
123 input_query: SearchToolQuery,
124 results_tx: UnboundedSender<(Entity<Buffer>, BufferSnapshot, Vec<(Range<Anchor>, usize)>)>,
125 path_style: PathStyle,
126 exclude_matcher: PathMatcher,
127 project: &Entity<Project>,
128 cx: &mut AsyncApp,
129) -> Result<()> {
130 let include_matcher = PathMatcher::new(vec![input_query.glob], path_style)?;
131
132 let make_search = |regex: &str| -> Result<SearchQuery> {
133 SearchQuery::regex(
134 regex,
135 false,
136 true,
137 false,
138 true,
139 include_matcher.clone(),
140 exclude_matcher.clone(),
141 true,
142 None,
143 )
144 };
145
146 if let Some(outer_syntax_regex) = input_query.syntax_node.first() {
147 let outer_syntax_query = make_search(outer_syntax_regex)?;
148 let nested_syntax_queries = input_query
149 .syntax_node
150 .into_iter()
151 .skip(1)
152 .map(|query| make_search(&query))
153 .collect::<Result<Vec<_>>>()?;
154 let content_query = input_query
155 .content
156 .map(|regex| make_search(®ex))
157 .transpose()?;
158
159 let (jobs_tx, jobs_rx) = channel::unbounded();
160
161 let outer_search_results_rx =
162 project.update(cx, |project, cx| project.search(outer_syntax_query, cx))?;
163
164 let outer_search_task = cx.spawn(async move |cx| {
165 futures::pin_mut!(outer_search_results_rx);
166 while let Some(SearchResult::Buffer { buffer, ranges }) =
167 outer_search_results_rx.next().await
168 {
169 buffer
170 .read_with(cx, |buffer, _| buffer.parsing_idle())?
171 .await;
172 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
173 let expanded_ranges: Vec<_> = ranges
174 .into_iter()
175 .filter_map(|range| expand_to_parent_range(&range, &snapshot))
176 .collect();
177 jobs_tx
178 .send(SearchJob {
179 buffer,
180 snapshot,
181 ranges: expanded_ranges,
182 query_ix: 0,
183 jobs_tx: jobs_tx.clone(),
184 })
185 .await?;
186 }
187 anyhow::Ok(())
188 });
189
190 let n_workers = cx.background_executor().num_cpus();
191 let search_job_task = cx.background_executor().scoped(|scope| {
192 for _ in 0..n_workers {
193 scope.spawn(async {
194 while let Ok(job) = jobs_rx.recv().await {
195 process_nested_search_job(
196 &results_tx,
197 &nested_syntax_queries,
198 &content_query,
199 job,
200 )
201 .await;
202 }
203 });
204 }
205 });
206
207 search_job_task.await;
208 outer_search_task.await?;
209 } else if let Some(content_regex) = &input_query.content {
210 let search_query = make_search(&content_regex)?;
211
212 let results_rx = project.update(cx, |project, cx| project.search(search_query, cx))?;
213 futures::pin_mut!(results_rx);
214
215 while let Some(SearchResult::Buffer { buffer, ranges }) = results_rx.next().await {
216 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
217
218 let ranges = ranges
219 .into_iter()
220 .map(|range| {
221 let range = range.to_offset(&snapshot);
222 let range = expand_to_entire_lines(range, &snapshot);
223 let size = range.len();
224 let range =
225 snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end);
226 (range, size)
227 })
228 .collect();
229
230 let send_result = results_tx.unbounded_send((buffer.clone(), snapshot.clone(), ranges));
231
232 if let Err(err) = send_result
233 && !err.is_disconnected()
234 {
235 log::error!("{err}");
236 }
237 }
238 } else {
239 log::warn!("Context gathering model produced a glob-only search");
240 }
241
242 anyhow::Ok(())
243}
244
245async fn process_nested_search_job(
246 results_tx: &UnboundedSender<(Entity<Buffer>, BufferSnapshot, Vec<(Range<Anchor>, usize)>)>,
247 queries: &Vec<SearchQuery>,
248 content_query: &Option<SearchQuery>,
249 job: SearchJob,
250) {
251 if let Some(search_query) = queries.get(job.query_ix) {
252 let mut subranges = Vec::new();
253 for range in job.ranges {
254 let start = range.start;
255 let search_results = search_query.search(&job.snapshot, Some(range)).await;
256 for subrange in search_results {
257 let subrange = start + subrange.start..start + subrange.end;
258 subranges.extend(expand_to_parent_range(&subrange, &job.snapshot));
259 }
260 }
261 job.jobs_tx
262 .send(SearchJob {
263 buffer: job.buffer,
264 snapshot: job.snapshot,
265 ranges: subranges,
266 query_ix: job.query_ix + 1,
267 jobs_tx: job.jobs_tx.clone(),
268 })
269 .await
270 .ok();
271 } else {
272 let ranges = if let Some(content_query) = content_query {
273 let mut subranges = Vec::new();
274 for range in job.ranges {
275 let start = range.start;
276 let search_results = content_query.search(&job.snapshot, Some(range)).await;
277 for subrange in search_results {
278 let subrange = start + subrange.start..start + subrange.end;
279 subranges.push(subrange);
280 }
281 }
282 subranges
283 } else {
284 job.ranges
285 };
286
287 let matches = ranges
288 .into_iter()
289 .map(|range| {
290 let snapshot = &job.snapshot;
291 let range = expand_to_entire_lines(range, snapshot);
292 let size = range.len();
293 let range = snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end);
294 (range, size)
295 })
296 .collect();
297
298 let send_result = results_tx.unbounded_send((job.buffer, job.snapshot, matches));
299
300 if let Err(err) = send_result
301 && !err.is_disconnected()
302 {
303 log::error!("{err}");
304 }
305 }
306}
307
308fn expand_to_entire_lines(range: Range<usize>, snapshot: &BufferSnapshot) -> Range<usize> {
309 let mut point_range = range.to_point(snapshot);
310 point_range.start.column = 0;
311 if point_range.end.column > 0 {
312 point_range.end = snapshot.max_point().min(point_range.end + Point::new(1, 0));
313 }
314 point_range.to_offset(snapshot)
315}
316
317fn expand_to_parent_range<T: ToPoint + ToOffset>(
318 range: &Range<T>,
319 snapshot: &BufferSnapshot,
320) -> Option<Range<usize>> {
321 let mut line_range = range.to_point(&snapshot);
322 line_range.start.column = snapshot.indent_size_for_line(line_range.start.row).len;
323 line_range.end.column = snapshot.line_len(line_range.end.row);
324 // TODO skip result if matched line isn't the first node line?
325
326 let node = snapshot.syntax_ancestor(line_range)?;
327 Some(node.byte_range())
328}
329
330#[cfg(test)]
331mod tests {
332 use super::*;
333 use crate::merge_excerpts::merge_excerpts;
334 use cloud_zeta2_prompt::write_codeblock;
335 use edit_prediction_context::Line;
336 use gpui::TestAppContext;
337 use indoc::indoc;
338 use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
339 use pretty_assertions::assert_eq;
340 use project::FakeFs;
341 use serde_json::json;
342 use settings::SettingsStore;
343 use std::path::Path;
344 use util::path;
345
/// End-to-end check of `run_retrieval_searches` against a fake two-file Rust project,
/// covering nested syntax-node queries, a syntax-node query combined with a content regex,
/// and a content-only query across several files.
346 #[gpui::test]
347 async fn test_retrieval(cx: &mut TestAppContext) {
348 init_test(cx);
349 let fs = FakeFs::new(cx.executor());
350 fs.insert_tree(
351 path!("/root"),
352 json!({
353 "user.rs": indoc!{"
354 pub struct Organization {
355 owner: Arc<User>,
356 }
357
358 pub struct User {
359 first_name: String,
360 last_name: String,
361 }
362
363 impl Organization {
364 pub fn owner(&self) -> Arc<User> {
365 self.owner.clone()
366 }
367 }
368
369 impl User {
370 pub fn new(first_name: String, last_name: String) -> Self {
371 Self {
372 first_name,
373 last_name
374 }
375 }
376
377 pub fn first_name(&self) -> String {
378 self.first_name.clone()
379 }
380
381 pub fn last_name(&self) -> String {
382 self.last_name.clone()
383 }
384 }
385 "},
386 "main.rs": indoc!{r#"
387 fn main() {
388 let user = User::new(FIRST_NAME.clone(), "doe".into());
389 println!("user {:?}", user);
390 }
391 "#},
392 }),
393 )
394 .await;
395
396 let project = Project::test(fs, vec![Path::new(path!("/root"))], cx).await;
    // Register the Rust language with the project's language registry.
397 project.update(cx, |project, _cx| {
398 project.languages().add(rust_lang().into())
399 });
400
    // Two nested syntax-node regexes, no content regex: the `first_name` method inside
    // `impl User` is expected.
401 assert_results(
402 &project,
403 SearchToolQuery {
404 glob: "user.rs".into(),
405 syntax_node: vec!["impl\\s+User".into(), "pub\\s+fn\\s+first_name".into()],
406 content: None,
407 },
408 indoc! {r#"
409 `````root/user.rs
410 …
411 impl User {
412 …
413 pub fn first_name(&self) -> String {
414 self.first_name.clone()
415 }
416 …
417 `````
418 "#},
419 cx,
420 )
421 .await;
422
    // One syntax-node regex plus a content regex: only the `.clone` lines found within
    // `impl User` are expected.
423 assert_results(
424 &project,
425 SearchToolQuery {
426 glob: "user.rs".into(),
427 syntax_node: vec!["impl\\s+User".into()],
428 content: Some("\\.clone".into()),
429 },
430 indoc! {r#"
431 `````root/user.rs
432 …
433 impl User {
434 …
435 pub fn first_name(&self) -> String {
436 self.first_name.clone()
437 …
438 pub fn last_name(&self) -> String {
439 self.last_name.clone()
440 …
441 `````
442 "#},
443 cx,
444 )
445 .await;
446
    // Content regex only, with a glob matching both files: every `.clone` line in
    // `main.rs` and `user.rs` is expected.
447 assert_results(
448 &project,
449 SearchToolQuery {
450 glob: "*.rs".into(),
451 syntax_node: vec![],
452 content: Some("\\.clone".into()),
453 },
454 indoc! {r#"
455 `````root/main.rs
456 fn main() {
457 let user = User::new(FIRST_NAME.clone(), "doe".into());
458 …
459 `````
460
461 `````root/user.rs
462 …
463 impl Organization {
464 pub fn owner(&self) -> Arc<User> {
465 self.owner.clone()
466 …
467 impl User {
468 …
469 pub fn first_name(&self) -> String {
470 self.first_name.clone()
471 …
472 pub fn last_name(&self) -> String {
473 self.last_name.clone()
474 …
475 `````
476 "#},
477 cx,
478 )
479 .await;
480 }
481
482 async fn assert_results(
483 project: &Entity<Project>,
484 query: SearchToolQuery,
485 expected_output: &str,
486 cx: &mut TestAppContext,
487 ) {
488 let results = run_retrieval_searches(project.clone(), vec![query], &mut cx.to_async())
489 .await
490 .unwrap();
491
492 let mut results = results.into_iter().collect::<Vec<_>>();
493 results.sort_by_key(|results| {
494 results
495 .0
496 .read_with(cx, |buffer, _| buffer.file().unwrap().path().clone())
497 });
498
499 let mut output = String::new();
500 for (buffer, ranges) in results {
501 buffer.read_with(cx, |buffer, cx| {
502 let excerpts = ranges.into_iter().map(|range| {
503 let point_range = range.to_point(buffer);
504 if point_range.end.column > 0 {
505 Line(point_range.start.row)..Line(point_range.end.row + 1)
506 } else {
507 Line(point_range.start.row)..Line(point_range.end.row)
508 }
509 });
510
511 write_codeblock(
512 &buffer.file().unwrap().full_path(cx),
513 merge_excerpts(&buffer.snapshot(), excerpts).iter(),
514 &[],
515 Line(buffer.max_point().row),
516 false,
517 &mut output,
518 );
519 });
520 }
521 output.pop();
522
523 assert_eq!(output, expected_output);
524 }
525
526 fn rust_lang() -> Language {
527 Language::new(
528 LanguageConfig {
529 name: "Rust".into(),
530 matcher: LanguageMatcher {
531 path_suffixes: vec!["rs".to_string()],
532 ..Default::default()
533 },
534 ..Default::default()
535 },
536 Some(tree_sitter_rust::LANGUAGE.into()),
537 )
538 .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
539 .unwrap()
540 }
541
542 fn init_test(cx: &mut TestAppContext) {
543 cx.update(move |cx| {
544 let settings_store = SettingsStore::test(cx);
545 cx.set_global(settings_store);
546 zlog::init_test();
547 });
548 }
549}