retrieval_search.rs

  1use anyhow::Result;
  2use cloud_zeta2_prompt::retrieval_prompt::SearchToolQuery;
  3use collections::HashMap;
  4use futures::{
  5    StreamExt,
  6    channel::mpsc::{self, UnboundedSender},
  7};
  8use gpui::{AppContext, AsyncApp, Entity};
  9use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt, Point, ToOffset, ToPoint};
 10use project::{
 11    Project, WorktreeSettings,
 12    search::{SearchQuery, SearchResult},
 13};
 14use smol::channel;
 15use std::ops::Range;
 16use util::{
 17    ResultExt as _,
 18    paths::{PathMatcher, PathStyle},
 19};
 20use workspace::item::Settings as _;
 21
/// Shape of the search results persisted in the eval cache: for every file path, the list of
/// matched offset ranges within that file. It is written and read back as JSON by
/// `run_retrieval_searches`.
#[cfg(feature = "eval-support")]
type CachedSearchResults = std::collections::BTreeMap<std::path::PathBuf, Vec<Range<usize>>>;
 24
 25pub async fn run_retrieval_searches(
 26    queries: Vec<SearchToolQuery>,
 27    project: Entity<Project>,
 28    #[cfg(feature = "eval-support")] eval_cache: Option<std::sync::Arc<dyn crate::EvalCache>>,
 29    cx: &mut AsyncApp,
 30) -> Result<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>> {
 31    #[cfg(feature = "eval-support")]
 32    let cache = if let Some(eval_cache) = eval_cache {
 33        use crate::EvalCacheEntryKind;
 34        use anyhow::Context;
 35        use collections::FxHasher;
 36        use std::hash::{Hash, Hasher};
 37
 38        let mut hasher = FxHasher::default();
 39        project.read_with(cx, |project, cx| {
 40            let mut worktrees = project.worktrees(cx);
 41            let Some(worktree) = worktrees.next() else {
 42                panic!("Expected a single worktree in eval project. Found none.");
 43            };
 44            assert!(
 45                worktrees.next().is_none(),
 46                "Expected a single worktree in eval project. Found more than one."
 47            );
 48            worktree.read(cx).abs_path().hash(&mut hasher);
 49        })?;
 50
 51        queries.hash(&mut hasher);
 52        let key = (EvalCacheEntryKind::Search, hasher.finish());
 53
 54        if let Some(cached_results) = eval_cache.read(key) {
 55            let file_results = serde_json::from_str::<CachedSearchResults>(&cached_results)
 56                .context("Failed to deserialize cached search results")?;
 57            let mut results = HashMap::default();
 58
 59            for (path, ranges) in file_results {
 60                let buffer = project
 61                    .update(cx, |project, cx| {
 62                        let project_path = project.find_project_path(path, cx).unwrap();
 63                        project.open_buffer(project_path, cx)
 64                    })?
 65                    .await?;
 66                let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
 67                let mut ranges: Vec<_> = ranges
 68                    .into_iter()
 69                    .map(|range| {
 70                        snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end)
 71                    })
 72                    .collect();
 73                merge_anchor_ranges(&mut ranges, &snapshot);
 74                results.insert(buffer, ranges);
 75            }
 76
 77            return Ok(results);
 78        }
 79
 80        Some((eval_cache, serde_json::to_string_pretty(&queries)?, key))
 81    } else {
 82        None
 83    };
 84
 85    let (exclude_matcher, path_style) = project.update(cx, |project, cx| {
 86        let global_settings = WorktreeSettings::get_global(cx);
 87        let exclude_patterns = global_settings
 88            .file_scan_exclusions
 89            .sources()
 90            .chain(global_settings.private_files.sources());
 91        let path_style = project.path_style(cx);
 92        anyhow::Ok((PathMatcher::new(exclude_patterns, path_style)?, path_style))
 93    })??;
 94
 95    let (results_tx, mut results_rx) = mpsc::unbounded();
 96
 97    for query in queries {
 98        let exclude_matcher = exclude_matcher.clone();
 99        let results_tx = results_tx.clone();
100        let project = project.clone();
101        cx.spawn(async move |cx| {
102            run_query(
103                query,
104                results_tx.clone(),
105                path_style,
106                exclude_matcher,
107                &project,
108                cx,
109            )
110            .await
111            .log_err();
112        })
113        .detach()
114    }
115    drop(results_tx);
116
117    #[cfg(feature = "eval-support")]
118    let cache = cache.clone();
119    cx.background_spawn(async move {
120        let mut results: HashMap<Entity<Buffer>, Vec<Range<Anchor>>> = HashMap::default();
121        let mut snapshots = HashMap::default();
122
123        let mut total_bytes = 0;
124        'outer: while let Some((buffer, snapshot, excerpts)) = results_rx.next().await {
125            snapshots.insert(buffer.entity_id(), snapshot);
126            let existing = results.entry(buffer).or_default();
127            existing.reserve(excerpts.len());
128
129            for (range, size) in excerpts {
130                // Blunt trimming of the results until we have a proper algorithmic filtering step
131                if (total_bytes + size) > MAX_RESULTS_LEN {
132                    log::trace!("Combined results reached limit of {MAX_RESULTS_LEN}B");
133                    break 'outer;
134                }
135                total_bytes += size;
136                existing.push(range);
137            }
138        }
139
140        #[cfg(feature = "eval-support")]
141        if let Some((cache, queries, key)) = cache {
142            let cached_results: CachedSearchResults = results
143                .iter()
144                .filter_map(|(buffer, ranges)| {
145                    let snapshot = snapshots.get(&buffer.entity_id())?;
146                    let path = snapshot.file().map(|f| f.path());
147                    let mut ranges = ranges
148                        .iter()
149                        .map(|range| range.to_offset(&snapshot))
150                        .collect::<Vec<_>>();
151                    ranges.sort_unstable_by_key(|range| (range.start, range.end));
152
153                    Some((path?.as_std_path().to_path_buf(), ranges))
154                })
155                .collect();
156            cache.write(
157                key,
158                &queries,
159                &serde_json::to_string_pretty(&cached_results)?,
160            );
161        }
162
163        for (buffer, ranges) in results.iter_mut() {
164            if let Some(snapshot) = snapshots.get(&buffer.entity_id()) {
165                merge_anchor_ranges(ranges, snapshot);
166            }
167        }
168
169        Ok(results)
170    })
171    .await
172}
173
174pub(crate) fn merge_anchor_ranges(ranges: &mut Vec<Range<Anchor>>, snapshot: &BufferSnapshot) {
175    ranges.sort_unstable_by(|a, b| {
176        a.start
177            .cmp(&b.start, snapshot)
178            .then(b.end.cmp(&a.end, snapshot))
179    });
180
181    let mut index = 1;
182    while index < ranges.len() {
183        if ranges[index - 1]
184            .end
185            .cmp(&ranges[index].start, snapshot)
186            .is_ge()
187        {
188            let removed = ranges.remove(index);
189            if removed.end.cmp(&ranges[index - 1].end, snapshot).is_gt() {
190                ranges[index - 1].end = removed.end;
191            }
192        } else {
193            index += 1;
194        }
195    }
196}
197
/// Nominal size of one excerpt, in bytes. In the code visible here it is only used to derive
/// `MAX_RESULTS_LEN`.
const MAX_EXCERPT_LEN: usize = 768;
/// Upper bound on the combined size (in bytes) of all excerpts returned by
/// `run_retrieval_searches`: room for five excerpts of `MAX_EXCERPT_LEN`.
const MAX_RESULTS_LEN: usize = MAX_EXCERPT_LEN * 5;
200
/// One unit of work in the nested syntax-node search: a buffer plus the ranges that survived
/// the previous stage and still have to be narrowed by the remaining queries.
struct SearchJob {
    /// Buffer the ranges belong to; forwarded with the final results.
    buffer: Entity<Buffer>,
    /// Snapshot of `buffer` that all offsets in `ranges` refer to.
    snapshot: BufferSnapshot,
    /// Offset ranges to search within at this stage.
    ranges: Vec<Range<usize>>,
    /// Index of the next nested syntax query to apply; once it is past the last query the job
    /// moves on to the content query and result emission.
    query_ix: usize,
    /// Sender used to queue the follow-up job for the next stage.
    jobs_tx: channel::Sender<SearchJob>,
}
208
209async fn run_query(
210    input_query: SearchToolQuery,
211    results_tx: UnboundedSender<(Entity<Buffer>, BufferSnapshot, Vec<(Range<Anchor>, usize)>)>,
212    path_style: PathStyle,
213    exclude_matcher: PathMatcher,
214    project: &Entity<Project>,
215    cx: &mut AsyncApp,
216) -> Result<()> {
217    let include_matcher = PathMatcher::new(vec![input_query.glob], path_style)?;
218
219    let make_search = |regex: &str| -> Result<SearchQuery> {
220        SearchQuery::regex(
221            regex,
222            false,
223            true,
224            false,
225            true,
226            include_matcher.clone(),
227            exclude_matcher.clone(),
228            true,
229            None,
230        )
231    };
232
233    if let Some(outer_syntax_regex) = input_query.syntax_node.first() {
234        let outer_syntax_query = make_search(outer_syntax_regex)?;
235        let nested_syntax_queries = input_query
236            .syntax_node
237            .into_iter()
238            .skip(1)
239            .map(|query| make_search(&query))
240            .collect::<Result<Vec<_>>>()?;
241        let content_query = input_query
242            .content
243            .map(|regex| make_search(&regex))
244            .transpose()?;
245
246        let (jobs_tx, jobs_rx) = channel::unbounded();
247
248        let outer_search_results_rx =
249            project.update(cx, |project, cx| project.search(outer_syntax_query, cx))?;
250
251        let outer_search_task = cx.spawn(async move |cx| {
252            futures::pin_mut!(outer_search_results_rx);
253            while let Some(SearchResult::Buffer { buffer, ranges }) =
254                outer_search_results_rx.next().await
255            {
256                buffer
257                    .read_with(cx, |buffer, _| buffer.parsing_idle())?
258                    .await;
259                let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
260                let expanded_ranges: Vec<_> = ranges
261                    .into_iter()
262                    .filter_map(|range| expand_to_parent_range(&range, &snapshot))
263                    .collect();
264                jobs_tx
265                    .send(SearchJob {
266                        buffer,
267                        snapshot,
268                        ranges: expanded_ranges,
269                        query_ix: 0,
270                        jobs_tx: jobs_tx.clone(),
271                    })
272                    .await?;
273            }
274            anyhow::Ok(())
275        });
276
277        let n_workers = cx.background_executor().num_cpus();
278        let search_job_task = cx.background_executor().scoped(|scope| {
279            for _ in 0..n_workers {
280                scope.spawn(async {
281                    while let Ok(job) = jobs_rx.recv().await {
282                        process_nested_search_job(
283                            &results_tx,
284                            &nested_syntax_queries,
285                            &content_query,
286                            job,
287                        )
288                        .await;
289                    }
290                });
291            }
292        });
293
294        search_job_task.await;
295        outer_search_task.await?;
296    } else if let Some(content_regex) = &input_query.content {
297        let search_query = make_search(&content_regex)?;
298
299        let results_rx = project.update(cx, |project, cx| project.search(search_query, cx))?;
300        futures::pin_mut!(results_rx);
301
302        while let Some(SearchResult::Buffer { buffer, ranges }) = results_rx.next().await {
303            let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
304
305            let ranges = ranges
306                .into_iter()
307                .map(|range| {
308                    let range = range.to_offset(&snapshot);
309                    let range = expand_to_entire_lines(range, &snapshot);
310                    let size = range.len();
311                    let range =
312                        snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end);
313                    (range, size)
314                })
315                .collect();
316
317            let send_result = results_tx.unbounded_send((buffer.clone(), snapshot.clone(), ranges));
318
319            if let Err(err) = send_result
320                && !err.is_disconnected()
321            {
322                log::error!("{err}");
323            }
324        }
325    } else {
326        log::warn!("Context gathering model produced a glob-only search");
327    }
328
329    anyhow::Ok(())
330}
331
332async fn process_nested_search_job(
333    results_tx: &UnboundedSender<(Entity<Buffer>, BufferSnapshot, Vec<(Range<Anchor>, usize)>)>,
334    queries: &Vec<SearchQuery>,
335    content_query: &Option<SearchQuery>,
336    job: SearchJob,
337) {
338    if let Some(search_query) = queries.get(job.query_ix) {
339        let mut subranges = Vec::new();
340        for range in job.ranges {
341            let start = range.start;
342            let search_results = search_query.search(&job.snapshot, Some(range)).await;
343            for subrange in search_results {
344                let subrange = start + subrange.start..start + subrange.end;
345                subranges.extend(expand_to_parent_range(&subrange, &job.snapshot));
346            }
347        }
348        job.jobs_tx
349            .send(SearchJob {
350                buffer: job.buffer,
351                snapshot: job.snapshot,
352                ranges: subranges,
353                query_ix: job.query_ix + 1,
354                jobs_tx: job.jobs_tx.clone(),
355            })
356            .await
357            .ok();
358    } else {
359        let ranges = if let Some(content_query) = content_query {
360            let mut subranges = Vec::new();
361            for range in job.ranges {
362                let start = range.start;
363                let search_results = content_query.search(&job.snapshot, Some(range)).await;
364                for subrange in search_results {
365                    let subrange = start + subrange.start..start + subrange.end;
366                    subranges.push(subrange);
367                }
368            }
369            subranges
370        } else {
371            job.ranges
372        };
373
374        let matches = ranges
375            .into_iter()
376            .map(|range| {
377                let snapshot = &job.snapshot;
378                let range = expand_to_entire_lines(range, snapshot);
379                let size = range.len();
380                let range = snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end);
381                (range, size)
382            })
383            .collect();
384
385        let send_result = results_tx.unbounded_send((job.buffer, job.snapshot, matches));
386
387        if let Err(err) = send_result
388            && !err.is_disconnected()
389        {
390            log::error!("{err}");
391        }
392    }
393}
394
395fn expand_to_entire_lines(range: Range<usize>, snapshot: &BufferSnapshot) -> Range<usize> {
396    let mut point_range = range.to_point(snapshot);
397    point_range.start.column = 0;
398    if point_range.end.column > 0 {
399        point_range.end = snapshot.max_point().min(point_range.end + Point::new(1, 0));
400    }
401    point_range.to_offset(snapshot)
402}
403
404fn expand_to_parent_range<T: ToPoint + ToOffset>(
405    range: &Range<T>,
406    snapshot: &BufferSnapshot,
407) -> Option<Range<usize>> {
408    let mut line_range = range.to_point(&snapshot);
409    line_range.start.column = snapshot.indent_size_for_line(line_range.start.row).len;
410    line_range.end.column = snapshot.line_len(line_range.end.row);
411    // TODO skip result if matched line isn't the first node line?
412
413    let node = snapshot.syntax_ancestor(line_range)?;
414    Some(node.byte_range())
415}
416
/// Tests for `run_retrieval_searches`, run against a small in-memory Rust project.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assemble_excerpts::assemble_excerpts;
    use cloud_zeta2_prompt::write_codeblock;
    use edit_prediction_context::Line;
    use gpui::TestAppContext;
    use indoc::indoc;
    use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
    use pretty_assertions::assert_eq;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;
    use util::path;

    /// Checks the three query shapes: nested syntax nodes, a syntax node combined with a
    /// content regex, and a content regex on its own.
    ///
    /// NOTE(review): the expected blocks below list non-adjacent source lines back to back with
    /// nothing between them; confirm that this matches what `write_codeblock` emits for gaps
    /// between excerpts (separator lines may be missing from these literals).
    #[gpui::test]
    async fn test_retrieval(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "user.rs": indoc!{"
                    pub struct Organization {
                        owner: Arc<User>,
                    }

                    pub struct User {
                        first_name: String,
                        last_name: String,
                    }

                    impl Organization {
                        pub fn owner(&self) -> Arc<User> {
                            self.owner.clone()
                        }
                    }

                    impl User {
                        pub fn new(first_name: String, last_name: String) -> Self {
                            Self {
                                first_name,
                                last_name
                            }
                        }

                        pub fn first_name(&self) -> String {
                            self.first_name.clone()
                        }

                        pub fn last_name(&self) -> String {
                            self.last_name.clone()
                        }
                    }
                "},
                "main.rs": indoc!{r#"
                    fn main() {
                        let user = User::new(FIRST_NAME.clone(), "doe".into());
                        println!("user {:?}", user);
                    }
                "#},
            }),
        )
        .await;

        let project = Project::test(fs, vec![Path::new(path!("/root"))], cx).await;
        project.update(cx, |project, _cx| {
            project.languages().add(rust_lang().into())
        });

        // Two nested syntax nodes: the `impl User` block, then `first_name` inside it.
        assert_results(
            &project,
            SearchToolQuery {
                glob: "user.rs".into(),
                syntax_node: vec!["impl\\s+User".into(), "pub\\s+fn\\s+first_name".into()],
                content: None,
            },
            indoc! {r#"
                `````root/user.rs
                impl User {
                    pub fn first_name(&self) -> String {
                        self.first_name.clone()
                    }
                `````
            "#},
            cx,
        )
        .await;

        // One syntax node narrowed by a content regex.
        assert_results(
            &project,
            SearchToolQuery {
                glob: "user.rs".into(),
                syntax_node: vec!["impl\\s+User".into()],
                content: Some("\\.clone".into()),
            },
            indoc! {r#"
                `````root/user.rs
                impl User {
                    pub fn first_name(&self) -> String {
                        self.first_name.clone()
                    pub fn last_name(&self) -> String {
                        self.last_name.clone()
                `````
            "#},
            cx,
        )
        .await;

        // Content regex only, across every Rust file.
        assert_results(
            &project,
            SearchToolQuery {
                glob: "*.rs".into(),
                syntax_node: vec![],
                content: Some("\\.clone".into()),
            },
            indoc! {r#"
                `````root/main.rs
                fn main() {
                    let user = User::new(FIRST_NAME.clone(), "doe".into());
                `````

                `````root/user.rs
                impl Organization {
                    pub fn owner(&self) -> Arc<User> {
                        self.owner.clone()
                impl User {
                    pub fn first_name(&self) -> String {
                        self.first_name.clone()
                    pub fn last_name(&self) -> String {
                        self.last_name.clone()
                `````
            "#},
            cx,
        )
        .await;
    }

    /// Runs `query` on its own through `run_retrieval_searches`, renders the matched ranges of
    /// every buffer as code blocks (buffers ordered by path), and compares the rendering with
    /// `expected_output`.
    async fn assert_results(
        project: &Entity<Project>,
        query: SearchToolQuery,
        expected_output: &str,
        cx: &mut TestAppContext,
    ) {
        let results = run_retrieval_searches(
            vec![query],
            project.clone(),
            #[cfg(feature = "eval-support")]
            None,
            &mut cx.to_async(),
        )
        .await
        .unwrap();

        // The result map has no stable order; sort by file path for a deterministic rendering.
        let mut results = results.into_iter().collect::<Vec<_>>();
        results.sort_by_key(|results| {
            results
                .0
                .read_with(cx, |buffer, _| buffer.file().unwrap().path().clone())
        });

        let mut output = String::new();
        for (buffer, ranges) in results {
            buffer.read_with(cx, |buffer, cx| {
                // Convert anchor ranges to line ranges; a range ending mid-line includes that
                // line as well.
                let excerpts = ranges.into_iter().map(|range| {
                    let point_range = range.to_point(buffer);
                    if point_range.end.column > 0 {
                        Line(point_range.start.row)..Line(point_range.end.row + 1)
                    } else {
                        Line(point_range.start.row)..Line(point_range.end.row)
                    }
                });

                write_codeblock(
                    &buffer.file().unwrap().full_path(cx),
                    assemble_excerpts(&buffer.snapshot(), excerpts).iter(),
                    &[],
                    Line(buffer.max_point().row),
                    false,
                    &mut output,
                );
            });
        }
        // Drop the final character of the rendering (presumably a trailing newline) before
        // comparing.
        output.pop();

        assert_eq!(output, expected_output);
    }

    /// Builds a Rust language definition (tree-sitter grammar plus outline query) matching
    /// files with the `rs` suffix.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
        .unwrap()
    }

    /// Installs a test settings store as a global and initializes test logging.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();
        });
    }
}