retrieval_search.rs

  1use anyhow::Result;
  2use cloud_zeta2_prompt::retrieval_prompt::SearchToolQuery;
  3use collections::HashMap;
  4use futures::{
  5    StreamExt,
  6    channel::mpsc::{self, UnboundedSender},
  7};
  8use gpui::{AppContext, AsyncApp, Entity};
  9use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt, Point, ToOffset, ToPoint};
 10use project::{
 11    Project, WorktreeSettings,
 12    search::{SearchQuery, SearchResult},
 13};
 14use smol::channel;
 15use std::ops::Range;
 16use util::{
 17    ResultExt as _,
 18    paths::{PathMatcher, PathStyle},
 19};
 20use workspace::item::Settings as _;
 21
// Shape of the search results stored in the eval cache: for each file path, the matched
// ranges as plain `usize` offsets (anchors are converted with `to_offset` before writing and
// rebuilt from the buffer snapshot when reading).
#[cfg(feature = "eval-support")]
type CachedSearchResults = std::collections::BTreeMap<std::path::PathBuf, Vec<Range<usize>>>;
 24
 25pub async fn run_retrieval_searches(
 26    queries: Vec<SearchToolQuery>,
 27    project: Entity<Project>,
 28    #[cfg(feature = "eval-support")] eval_cache: Option<std::sync::Arc<dyn crate::EvalCache>>,
 29    cx: &mut AsyncApp,
 30) -> Result<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>> {
 31    #[cfg(feature = "eval-support")]
 32    let cache = if let Some(eval_cache) = eval_cache {
 33        use crate::EvalCacheEntryKind;
 34        use anyhow::Context;
 35        use collections::FxHasher;
 36        use std::hash::{Hash, Hasher};
 37
 38        let mut hasher = FxHasher::default();
 39        project.read_with(cx, |project, cx| {
 40            let mut worktrees = project.worktrees(cx);
 41            let Some(worktree) = worktrees.next() else {
 42                panic!("Expected a single worktree in eval project. Found none.");
 43            };
 44            assert!(
 45                worktrees.next().is_none(),
 46                "Expected a single worktree in eval project. Found more than one."
 47            );
 48            worktree.read(cx).abs_path().hash(&mut hasher);
 49        })?;
 50
 51        queries.hash(&mut hasher);
 52        let key = (EvalCacheEntryKind::Search, hasher.finish());
 53
 54        if let Some(cached_results) = eval_cache.read(key) {
 55            let file_results = serde_json::from_str::<CachedSearchResults>(&cached_results)
 56                .context("Failed to deserialize cached search results")?;
 57            let mut results = HashMap::default();
 58
 59            for (path, ranges) in file_results {
 60                let buffer = project
 61                    .update(cx, |project, cx| {
 62                        let project_path = project.find_project_path(path, cx).unwrap();
 63                        project.open_buffer(project_path, cx)
 64                    })?
 65                    .await?;
 66                let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
 67                let mut ranges: Vec<_> = ranges
 68                    .into_iter()
 69                    .map(|range| {
 70                        snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end)
 71                    })
 72                    .collect();
 73                merge_anchor_ranges(&mut ranges, &snapshot);
 74                results.insert(buffer, ranges);
 75            }
 76
 77            return Ok(results);
 78        }
 79
 80        Some((eval_cache, serde_json::to_string_pretty(&queries)?, key))
 81    } else {
 82        None
 83    };
 84
 85    let (exclude_matcher, path_style) = project.update(cx, |project, cx| {
 86        let global_settings = WorktreeSettings::get_global(cx);
 87        let exclude_patterns = global_settings
 88            .file_scan_exclusions
 89            .sources()
 90            .iter()
 91            .chain(global_settings.private_files.sources().iter());
 92        let path_style = project.path_style(cx);
 93        anyhow::Ok((PathMatcher::new(exclude_patterns, path_style)?, path_style))
 94    })??;
 95
 96    let (results_tx, mut results_rx) = mpsc::unbounded();
 97
 98    for query in queries {
 99        let exclude_matcher = exclude_matcher.clone();
100        let results_tx = results_tx.clone();
101        let project = project.clone();
102        cx.spawn(async move |cx| {
103            run_query(
104                query,
105                results_tx.clone(),
106                path_style,
107                exclude_matcher,
108                &project,
109                cx,
110            )
111            .await
112            .log_err();
113        })
114        .detach()
115    }
116    drop(results_tx);
117
118    #[cfg(feature = "eval-support")]
119    let cache = cache.clone();
120    cx.background_spawn(async move {
121        let mut results: HashMap<Entity<Buffer>, Vec<Range<Anchor>>> = HashMap::default();
122        let mut snapshots = HashMap::default();
123
124        let mut total_bytes = 0;
125        'outer: while let Some((buffer, snapshot, excerpts)) = results_rx.next().await {
126            snapshots.insert(buffer.entity_id(), snapshot);
127            let existing = results.entry(buffer).or_default();
128            existing.reserve(excerpts.len());
129
130            for (range, size) in excerpts {
131                // Blunt trimming of the results until we have a proper algorithmic filtering step
132                if (total_bytes + size) > MAX_RESULTS_LEN {
133                    log::trace!("Combined results reached limit of {MAX_RESULTS_LEN}B");
134                    break 'outer;
135                }
136                total_bytes += size;
137                existing.push(range);
138            }
139        }
140
141        #[cfg(feature = "eval-support")]
142        if let Some((cache, queries, key)) = cache {
143            let cached_results: CachedSearchResults = results
144                .iter()
145                .filter_map(|(buffer, ranges)| {
146                    let snapshot = snapshots.get(&buffer.entity_id())?;
147                    let path = snapshot.file().map(|f| f.path());
148                    let mut ranges = ranges
149                        .iter()
150                        .map(|range| range.to_offset(&snapshot))
151                        .collect::<Vec<_>>();
152                    ranges.sort_unstable_by_key(|range| (range.start, range.end));
153
154                    Some((path?.as_std_path().to_path_buf(), ranges))
155                })
156                .collect();
157            cache.write(
158                key,
159                &queries,
160                &serde_json::to_string_pretty(&cached_results)?,
161            );
162        }
163
164        for (buffer, ranges) in results.iter_mut() {
165            if let Some(snapshot) = snapshots.get(&buffer.entity_id()) {
166                merge_anchor_ranges(ranges, snapshot);
167            }
168        }
169
170        Ok(results)
171    })
172    .await
173}
174
175pub(crate) fn merge_anchor_ranges(ranges: &mut Vec<Range<Anchor>>, snapshot: &BufferSnapshot) {
176    ranges.sort_unstable_by(|a, b| {
177        a.start
178            .cmp(&b.start, snapshot)
179            .then(b.end.cmp(&a.end, snapshot))
180    });
181
182    let mut index = 1;
183    while index < ranges.len() {
184        if ranges[index - 1]
185            .end
186            .cmp(&ranges[index].start, snapshot)
187            .is_ge()
188        {
189            let removed = ranges.remove(index);
190            if removed.end.cmp(&ranges[index - 1].end, snapshot).is_gt() {
191                ranges[index - 1].end = removed.end;
192            }
193        } else {
194            index += 1;
195        }
196    }
197}
198
// Base size from which the combined results budget below is derived.
// NOTE(review): not enforced per excerpt anywhere in this file — confirm whether it is meant to be.
const MAX_EXCERPT_LEN: usize = 768;
// Upper bound on the summed size of all excerpts collected by `run_retrieval_searches`;
// results that would exceed it are dropped.
const MAX_RESULTS_LEN: usize = MAX_EXCERPT_LEN * 5;
201
/// A unit of work for the nested syntax search: the ranges of one buffer that still have to be
/// searched with the nested query at `query_ix`.
struct SearchJob {
    /// Buffer the ranges belong to; forwarded together with the final results.
    buffer: Entity<Buffer>,
    /// Snapshot of `buffer` that searches run against and that `ranges` refer to.
    snapshot: BufferSnapshot,
    /// Offset ranges within `snapshot` to search inside.
    ranges: Vec<Range<usize>>,
    /// Index of the nested syntax query to apply next; once it is past the last query the job
    /// is finalized instead of being re-enqueued.
    query_ix: usize,
    /// Sender used to enqueue the follow-up job for the next query.
    jobs_tx: channel::Sender<SearchJob>,
}
209
210async fn run_query(
211    input_query: SearchToolQuery,
212    results_tx: UnboundedSender<(Entity<Buffer>, BufferSnapshot, Vec<(Range<Anchor>, usize)>)>,
213    path_style: PathStyle,
214    exclude_matcher: PathMatcher,
215    project: &Entity<Project>,
216    cx: &mut AsyncApp,
217) -> Result<()> {
218    let include_matcher = PathMatcher::new(vec![input_query.glob], path_style)?;
219
220    let make_search = |regex: &str| -> Result<SearchQuery> {
221        SearchQuery::regex(
222            regex,
223            false,
224            true,
225            false,
226            true,
227            include_matcher.clone(),
228            exclude_matcher.clone(),
229            true,
230            None,
231        )
232    };
233
234    if let Some(outer_syntax_regex) = input_query.syntax_node.first() {
235        let outer_syntax_query = make_search(outer_syntax_regex)?;
236        let nested_syntax_queries = input_query
237            .syntax_node
238            .into_iter()
239            .skip(1)
240            .map(|query| make_search(&query))
241            .collect::<Result<Vec<_>>>()?;
242        let content_query = input_query
243            .content
244            .map(|regex| make_search(&regex))
245            .transpose()?;
246
247        let (jobs_tx, jobs_rx) = channel::unbounded();
248
249        let outer_search_results_rx =
250            project.update(cx, |project, cx| project.search(outer_syntax_query, cx))?;
251
252        let outer_search_task = cx.spawn(async move |cx| {
253            futures::pin_mut!(outer_search_results_rx);
254            while let Some(SearchResult::Buffer { buffer, ranges }) =
255                outer_search_results_rx.next().await
256            {
257                buffer
258                    .read_with(cx, |buffer, _| buffer.parsing_idle())?
259                    .await;
260                let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
261                let expanded_ranges: Vec<_> = ranges
262                    .into_iter()
263                    .filter_map(|range| expand_to_parent_range(&range, &snapshot))
264                    .collect();
265                jobs_tx
266                    .send(SearchJob {
267                        buffer,
268                        snapshot,
269                        ranges: expanded_ranges,
270                        query_ix: 0,
271                        jobs_tx: jobs_tx.clone(),
272                    })
273                    .await?;
274            }
275            anyhow::Ok(())
276        });
277
278        let n_workers = cx.background_executor().num_cpus();
279        let search_job_task = cx.background_executor().scoped(|scope| {
280            for _ in 0..n_workers {
281                scope.spawn(async {
282                    while let Ok(job) = jobs_rx.recv().await {
283                        process_nested_search_job(
284                            &results_tx,
285                            &nested_syntax_queries,
286                            &content_query,
287                            job,
288                        )
289                        .await;
290                    }
291                });
292            }
293        });
294
295        search_job_task.await;
296        outer_search_task.await?;
297    } else if let Some(content_regex) = &input_query.content {
298        let search_query = make_search(&content_regex)?;
299
300        let results_rx = project.update(cx, |project, cx| project.search(search_query, cx))?;
301        futures::pin_mut!(results_rx);
302
303        while let Some(SearchResult::Buffer { buffer, ranges }) = results_rx.next().await {
304            let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
305
306            let ranges = ranges
307                .into_iter()
308                .map(|range| {
309                    let range = range.to_offset(&snapshot);
310                    let range = expand_to_entire_lines(range, &snapshot);
311                    let size = range.len();
312                    let range =
313                        snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end);
314                    (range, size)
315                })
316                .collect();
317
318            let send_result = results_tx.unbounded_send((buffer.clone(), snapshot.clone(), ranges));
319
320            if let Err(err) = send_result
321                && !err.is_disconnected()
322            {
323                log::error!("{err}");
324            }
325        }
326    } else {
327        log::warn!("Context gathering model produced a glob-only search");
328    }
329
330    anyhow::Ok(())
331}
332
333async fn process_nested_search_job(
334    results_tx: &UnboundedSender<(Entity<Buffer>, BufferSnapshot, Vec<(Range<Anchor>, usize)>)>,
335    queries: &Vec<SearchQuery>,
336    content_query: &Option<SearchQuery>,
337    job: SearchJob,
338) {
339    if let Some(search_query) = queries.get(job.query_ix) {
340        let mut subranges = Vec::new();
341        for range in job.ranges {
342            let start = range.start;
343            let search_results = search_query.search(&job.snapshot, Some(range)).await;
344            for subrange in search_results {
345                let subrange = start + subrange.start..start + subrange.end;
346                subranges.extend(expand_to_parent_range(&subrange, &job.snapshot));
347            }
348        }
349        job.jobs_tx
350            .send(SearchJob {
351                buffer: job.buffer,
352                snapshot: job.snapshot,
353                ranges: subranges,
354                query_ix: job.query_ix + 1,
355                jobs_tx: job.jobs_tx.clone(),
356            })
357            .await
358            .ok();
359    } else {
360        let ranges = if let Some(content_query) = content_query {
361            let mut subranges = Vec::new();
362            for range in job.ranges {
363                let start = range.start;
364                let search_results = content_query.search(&job.snapshot, Some(range)).await;
365                for subrange in search_results {
366                    let subrange = start + subrange.start..start + subrange.end;
367                    subranges.push(subrange);
368                }
369            }
370            subranges
371        } else {
372            job.ranges
373        };
374
375        let matches = ranges
376            .into_iter()
377            .map(|range| {
378                let snapshot = &job.snapshot;
379                let range = expand_to_entire_lines(range, snapshot);
380                let size = range.len();
381                let range = snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end);
382                (range, size)
383            })
384            .collect();
385
386        let send_result = results_tx.unbounded_send((job.buffer, job.snapshot, matches));
387
388        if let Err(err) = send_result
389            && !err.is_disconnected()
390        {
391            log::error!("{err}");
392        }
393    }
394}
395
396fn expand_to_entire_lines(range: Range<usize>, snapshot: &BufferSnapshot) -> Range<usize> {
397    let mut point_range = range.to_point(snapshot);
398    point_range.start.column = 0;
399    if point_range.end.column > 0 {
400        point_range.end = snapshot.max_point().min(point_range.end + Point::new(1, 0));
401    }
402    point_range.to_offset(snapshot)
403}
404
405fn expand_to_parent_range<T: ToPoint + ToOffset>(
406    range: &Range<T>,
407    snapshot: &BufferSnapshot,
408) -> Option<Range<usize>> {
409    let mut line_range = range.to_point(&snapshot);
410    line_range.start.column = snapshot.indent_size_for_line(line_range.start.row).len;
411    line_range.end.column = snapshot.line_len(line_range.end.row);
412    // TODO skip result if matched line isn't the first node line?
413
414    let node = snapshot.syntax_ancestor(line_range)?;
415    Some(node.byte_range())
416}
417
418#[cfg(test)]
419mod tests {
420    use super::*;
421    use crate::assemble_excerpts::assemble_excerpts;
422    use cloud_zeta2_prompt::write_codeblock;
423    use edit_prediction_context::Line;
424    use gpui::TestAppContext;
425    use indoc::indoc;
426    use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
427    use pretty_assertions::assert_eq;
428    use project::FakeFs;
429    use serde_json::json;
430    use settings::SettingsStore;
431    use std::path::Path;
432    use util::path;
433
434    #[gpui::test]
435    async fn test_retrieval(cx: &mut TestAppContext) {
436        init_test(cx);
437        let fs = FakeFs::new(cx.executor());
438        fs.insert_tree(
439            path!("/root"),
440            json!({
441                "user.rs": indoc!{"
442                    pub struct Organization {
443                        owner: Arc<User>,
444                    }
445
446                    pub struct User {
447                        first_name: String,
448                        last_name: String,
449                    }
450
451                    impl Organization {
452                        pub fn owner(&self) -> Arc<User> {
453                            self.owner.clone()
454                        }
455                    }
456
457                    impl User {
458                        pub fn new(first_name: String, last_name: String) -> Self {
459                            Self {
460                                first_name,
461                                last_name
462                            }
463                        }
464
465                        pub fn first_name(&self) -> String {
466                            self.first_name.clone()
467                        }
468
469                        pub fn last_name(&self) -> String {
470                            self.last_name.clone()
471                        }
472                    }
473                "},
474                "main.rs": indoc!{r#"
475                    fn main() {
476                        let user = User::new(FIRST_NAME.clone(), "doe".into());
477                        println!("user {:?}", user);
478                    }
479                "#},
480            }),
481        )
482        .await;
483
484        let project = Project::test(fs, vec![Path::new(path!("/root"))], cx).await;
485        project.update(cx, |project, _cx| {
486            project.languages().add(rust_lang().into())
487        });
488
489        assert_results(
490            &project,
491            SearchToolQuery {
492                glob: "user.rs".into(),
493                syntax_node: vec!["impl\\s+User".into(), "pub\\s+fn\\s+first_name".into()],
494                content: None,
495            },
496            indoc! {r#"
497                `````root/user.rs
498499                impl User {
500501                    pub fn first_name(&self) -> String {
502                        self.first_name.clone()
503                    }
504505                `````
506            "#},
507            cx,
508        )
509        .await;
510
511        assert_results(
512            &project,
513            SearchToolQuery {
514                glob: "user.rs".into(),
515                syntax_node: vec!["impl\\s+User".into()],
516                content: Some("\\.clone".into()),
517            },
518            indoc! {r#"
519                `````root/user.rs
520521                impl User {
522523                    pub fn first_name(&self) -> String {
524                        self.first_name.clone()
525526                    pub fn last_name(&self) -> String {
527                        self.last_name.clone()
528529                `````
530            "#},
531            cx,
532        )
533        .await;
534
535        assert_results(
536            &project,
537            SearchToolQuery {
538                glob: "*.rs".into(),
539                syntax_node: vec![],
540                content: Some("\\.clone".into()),
541            },
542            indoc! {r#"
543                `````root/main.rs
544                fn main() {
545                    let user = User::new(FIRST_NAME.clone(), "doe".into());
546547                `````
548
549                `````root/user.rs
550551                impl Organization {
552                    pub fn owner(&self) -> Arc<User> {
553                        self.owner.clone()
554555                impl User {
556557                    pub fn first_name(&self) -> String {
558                        self.first_name.clone()
559560                    pub fn last_name(&self) -> String {
561                        self.last_name.clone()
562563                `````
564            "#},
565            cx,
566        )
567        .await;
568    }
569
570    async fn assert_results(
571        project: &Entity<Project>,
572        query: SearchToolQuery,
573        expected_output: &str,
574        cx: &mut TestAppContext,
575    ) {
576        let results = run_retrieval_searches(
577            vec![query],
578            project.clone(),
579            #[cfg(feature = "eval-support")]
580            None,
581            &mut cx.to_async(),
582        )
583        .await
584        .unwrap();
585
586        let mut results = results.into_iter().collect::<Vec<_>>();
587        results.sort_by_key(|results| {
588            results
589                .0
590                .read_with(cx, |buffer, _| buffer.file().unwrap().path().clone())
591        });
592
593        let mut output = String::new();
594        for (buffer, ranges) in results {
595            buffer.read_with(cx, |buffer, cx| {
596                let excerpts = ranges.into_iter().map(|range| {
597                    let point_range = range.to_point(buffer);
598                    if point_range.end.column > 0 {
599                        Line(point_range.start.row)..Line(point_range.end.row + 1)
600                    } else {
601                        Line(point_range.start.row)..Line(point_range.end.row)
602                    }
603                });
604
605                write_codeblock(
606                    &buffer.file().unwrap().full_path(cx),
607                    assemble_excerpts(&buffer.snapshot(), excerpts).iter(),
608                    &[],
609                    Line(buffer.max_point().row),
610                    false,
611                    &mut output,
612                );
613            });
614        }
615        output.pop();
616
617        assert_eq!(output, expected_output);
618    }
619
620    fn rust_lang() -> Language {
621        Language::new(
622            LanguageConfig {
623                name: "Rust".into(),
624                matcher: LanguageMatcher {
625                    path_suffixes: vec!["rs".to_string()],
626                    ..Default::default()
627                },
628                ..Default::default()
629            },
630            Some(tree_sitter_rust::LANGUAGE.into()),
631        )
632        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
633        .unwrap()
634    }
635
    /// Prepares the test app: installs a test `SettingsStore` as a global and initializes
    /// test logging.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(move |cx| {
            // Build the store first; it needs `cx` itself, so it cannot be created inline in
            // the `set_global` call.
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();
        });
    }
643}