retrieval_search.rs

  1use std::ops::Range;
  2
  3use anyhow::Result;
  4use cloud_zeta2_prompt::retrieval_prompt::SearchToolQuery;
  5use collections::HashMap;
  6use futures::{
  7    StreamExt,
  8    channel::mpsc::{self, UnboundedSender},
  9};
 10use gpui::{AppContext, AsyncApp, Entity};
 11use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt, Point, ToOffset, ToPoint};
 12use project::{
 13    Project, WorktreeSettings,
 14    search::{SearchQuery, SearchResult},
 15};
 16use smol::channel;
 17use util::{
 18    ResultExt as _,
 19    paths::{PathMatcher, PathStyle},
 20};
 21use workspace::item::Settings as _;
 22
 23pub async fn run_retrieval_searches(
 24    project: Entity<Project>,
 25    queries: Vec<SearchToolQuery>,
 26    cx: &mut AsyncApp,
 27) -> Result<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>> {
 28    let (exclude_matcher, path_style) = project.update(cx, |project, cx| {
 29        let global_settings = WorktreeSettings::get_global(cx);
 30        let exclude_patterns = global_settings
 31            .file_scan_exclusions
 32            .sources()
 33            .iter()
 34            .chain(global_settings.private_files.sources().iter());
 35        let path_style = project.path_style(cx);
 36        anyhow::Ok((PathMatcher::new(exclude_patterns, path_style)?, path_style))
 37    })??;
 38
 39    let (results_tx, mut results_rx) = mpsc::unbounded();
 40
 41    for query in queries {
 42        let exclude_matcher = exclude_matcher.clone();
 43        let results_tx = results_tx.clone();
 44        let project = project.clone();
 45        cx.spawn(async move |cx| {
 46            run_query(
 47                query,
 48                results_tx.clone(),
 49                path_style,
 50                exclude_matcher,
 51                &project,
 52                cx,
 53            )
 54            .await
 55            .log_err();
 56        })
 57        .detach()
 58    }
 59    drop(results_tx);
 60
 61    cx.background_spawn(async move {
 62        let mut results: HashMap<Entity<Buffer>, Vec<Range<Anchor>>> = HashMap::default();
 63        let mut snapshots = HashMap::default();
 64
 65        let mut total_bytes = 0;
 66        'outer: while let Some((buffer, snapshot, excerpts)) = results_rx.next().await {
 67            snapshots.insert(buffer.entity_id(), snapshot);
 68            let existing = results.entry(buffer).or_default();
 69            existing.reserve(excerpts.len());
 70
 71            for (range, size) in excerpts {
 72                // Blunt trimming of the results until we have a proper algorithmic filtering step
 73                if (total_bytes + size) > MAX_RESULTS_LEN {
 74                    log::trace!("Combined results reached limit of {MAX_RESULTS_LEN}B");
 75                    break 'outer;
 76                }
 77                total_bytes += size;
 78                existing.push(range);
 79            }
 80        }
 81
 82        for (buffer, ranges) in results.iter_mut() {
 83            if let Some(snapshot) = snapshots.get(&buffer.entity_id()) {
 84                merge_anchor_ranges(ranges, snapshot);
 85            }
 86        }
 87
 88        Ok(results)
 89    })
 90    .await
 91}
 92
 93fn merge_anchor_ranges(ranges: &mut Vec<Range<Anchor>>, snapshot: &BufferSnapshot) {
 94    ranges.sort_unstable_by(|a, b| {
 95        a.start
 96            .cmp(&b.start, snapshot)
 97            .then(b.end.cmp(&b.end, snapshot))
 98    });
 99
100    let mut index = 1;
101    while index < ranges.len() {
102        if ranges[index - 1]
103            .end
104            .cmp(&ranges[index].start, snapshot)
105            .is_ge()
106        {
107            let removed = ranges.remove(index);
108            ranges[index - 1].end = removed.end;
109        } else {
110            index += 1;
111        }
112    }
113}
114
/// Nominal size, in bytes, of one excerpt. Only used to size [`MAX_RESULTS_LEN`];
/// individual excerpts are not truncated to this length.
const MAX_EXCERPT_LEN: usize = 768;
/// Combined byte budget for all excerpts collected by `run_retrieval_searches`
/// (room for five nominal excerpts).
const MAX_RESULTS_LEN: usize = 5 * MAX_EXCERPT_LEN;
117
118struct SearchJob {
119    buffer: Entity<Buffer>,
120    snapshot: BufferSnapshot,
121    ranges: Vec<Range<usize>>,
122    query_ix: usize,
123    jobs_tx: channel::Sender<SearchJob>,
124}
125
126async fn run_query(
127    input_query: SearchToolQuery,
128    results_tx: UnboundedSender<(Entity<Buffer>, BufferSnapshot, Vec<(Range<Anchor>, usize)>)>,
129    path_style: PathStyle,
130    exclude_matcher: PathMatcher,
131    project: &Entity<Project>,
132    cx: &mut AsyncApp,
133) -> Result<()> {
134    let include_matcher = PathMatcher::new(vec![input_query.glob], path_style)?;
135
136    let make_search = |regex: &str| -> Result<SearchQuery> {
137        SearchQuery::regex(
138            regex,
139            false,
140            true,
141            false,
142            true,
143            include_matcher.clone(),
144            exclude_matcher.clone(),
145            true,
146            None,
147        )
148    };
149
150    if let Some(outer_syntax_regex) = input_query.syntax_node.first() {
151        let outer_syntax_query = make_search(outer_syntax_regex)?;
152        let nested_syntax_queries = input_query
153            .syntax_node
154            .into_iter()
155            .skip(1)
156            .map(|query| make_search(&query))
157            .collect::<Result<Vec<_>>>()?;
158        let content_query = input_query
159            .content
160            .map(|regex| make_search(&regex))
161            .transpose()?;
162
163        let (jobs_tx, jobs_rx) = channel::unbounded();
164
165        let outer_search_results_rx =
166            project.update(cx, |project, cx| project.search(outer_syntax_query, cx))?;
167
168        let outer_search_task = cx.spawn(async move |cx| {
169            futures::pin_mut!(outer_search_results_rx);
170            while let Some(SearchResult::Buffer { buffer, ranges }) =
171                outer_search_results_rx.next().await
172            {
173                buffer
174                    .read_with(cx, |buffer, _| buffer.parsing_idle())?
175                    .await;
176                let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
177                let expanded_ranges: Vec<_> = ranges
178                    .into_iter()
179                    .filter_map(|range| expand_to_parent_range(&range, &snapshot))
180                    .collect();
181                jobs_tx
182                    .send(SearchJob {
183                        buffer,
184                        snapshot,
185                        ranges: expanded_ranges,
186                        query_ix: 0,
187                        jobs_tx: jobs_tx.clone(),
188                    })
189                    .await?;
190            }
191            anyhow::Ok(())
192        });
193
194        let n_workers = cx.background_executor().num_cpus();
195        let search_job_task = cx.background_executor().scoped(|scope| {
196            for _ in 0..n_workers {
197                scope.spawn(async {
198                    while let Ok(job) = jobs_rx.recv().await {
199                        process_nested_search_job(
200                            &results_tx,
201                            &nested_syntax_queries,
202                            &content_query,
203                            job,
204                        )
205                        .await;
206                    }
207                });
208            }
209        });
210
211        search_job_task.await;
212        outer_search_task.await?;
213    } else if let Some(content_regex) = &input_query.content {
214        let search_query = make_search(&content_regex)?;
215
216        let results_rx = project.update(cx, |project, cx| project.search(search_query, cx))?;
217        futures::pin_mut!(results_rx);
218
219        while let Some(SearchResult::Buffer { buffer, ranges }) = results_rx.next().await {
220            let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
221
222            let ranges = ranges
223                .into_iter()
224                .map(|range| {
225                    let range = range.to_offset(&snapshot);
226                    let range = expand_to_entire_lines(range, &snapshot);
227                    let size = range.len();
228                    let range =
229                        snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end);
230                    (range, size)
231                })
232                .collect();
233
234            let send_result = results_tx.unbounded_send((buffer.clone(), snapshot.clone(), ranges));
235
236            if let Err(err) = send_result
237                && !err.is_disconnected()
238            {
239                log::error!("{err}");
240            }
241        }
242    } else {
243        log::warn!("Context gathering model produced a glob-only search");
244    }
245
246    anyhow::Ok(())
247}
248
249async fn process_nested_search_job(
250    results_tx: &UnboundedSender<(Entity<Buffer>, BufferSnapshot, Vec<(Range<Anchor>, usize)>)>,
251    queries: &Vec<SearchQuery>,
252    content_query: &Option<SearchQuery>,
253    job: SearchJob,
254) {
255    if let Some(search_query) = queries.get(job.query_ix) {
256        let mut subranges = Vec::new();
257        for range in job.ranges {
258            let start = range.start;
259            let search_results = search_query.search(&job.snapshot, Some(range)).await;
260            for subrange in search_results {
261                let subrange = start + subrange.start..start + subrange.end;
262                subranges.extend(expand_to_parent_range(&subrange, &job.snapshot));
263            }
264        }
265        job.jobs_tx
266            .send(SearchJob {
267                buffer: job.buffer,
268                snapshot: job.snapshot,
269                ranges: subranges,
270                query_ix: job.query_ix + 1,
271                jobs_tx: job.jobs_tx.clone(),
272            })
273            .await
274            .ok();
275    } else {
276        let ranges = if let Some(content_query) = content_query {
277            let mut subranges = Vec::new();
278            for range in job.ranges {
279                let start = range.start;
280                let search_results = content_query.search(&job.snapshot, Some(range)).await;
281                for subrange in search_results {
282                    let subrange = start + subrange.start..start + subrange.end;
283                    subranges.push(subrange);
284                }
285            }
286            subranges
287        } else {
288            job.ranges
289        };
290
291        let matches = ranges
292            .into_iter()
293            .map(|range| {
294                let snapshot = &job.snapshot;
295                let range = expand_to_entire_lines(range, snapshot);
296                let size = range.len();
297                let range = snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end);
298                (range, size)
299            })
300            .collect();
301
302        let send_result = results_tx.unbounded_send((job.buffer, job.snapshot, matches));
303
304        if let Err(err) = send_result
305            && !err.is_disconnected()
306        {
307            log::error!("{err}");
308        }
309    }
310}
311
312fn expand_to_entire_lines(range: Range<usize>, snapshot: &BufferSnapshot) -> Range<usize> {
313    let mut point_range = range.to_point(snapshot);
314    point_range.start.column = 0;
315    if point_range.end.column > 0 {
316        point_range.end = snapshot.max_point().min(point_range.end + Point::new(1, 0));
317    }
318    point_range.to_offset(snapshot)
319}
320
321fn expand_to_parent_range<T: ToPoint + ToOffset>(
322    range: &Range<T>,
323    snapshot: &BufferSnapshot,
324) -> Option<Range<usize>> {
325    let mut line_range = range.to_point(&snapshot);
326    line_range.start.column = snapshot.indent_size_for_line(line_range.start.row).len;
327    line_range.end.column = snapshot.line_len(line_range.end.row);
328    // TODO skip result if matched line isn't the first node line?
329
330    let node = snapshot.syntax_ancestor(line_range)?;
331    Some(node.byte_range())
332}
333
#[cfg(test)]
mod tests {
    use super::*;
    use crate::merge_excerpts::merge_excerpts;
    use cloud_zeta2_prompt::write_codeblock;
    use edit_prediction_context::Line;
    use gpui::TestAppContext;
    use indoc::indoc;
    use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
    use pretty_assertions::assert_eq;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use std::path::Path;
    use util::path;

    #[gpui::test]
    async fn test_retrieval(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "user.rs": indoc!{"
                    pub struct Organization {
                        owner: Arc<User>,
                    }

                    pub struct User {
                        first_name: String,
                        last_name: String,
                    }

                    impl Organization {
                        pub fn owner(&self) -> Arc<User> {
                            self.owner.clone()
                        }
                    }

                    impl User {
                        pub fn new(first_name: String, last_name: String) -> Self {
                            Self {
                                first_name,
                                last_name
                            }
                        }

                        pub fn first_name(&self) -> String {
                            self.first_name.clone()
                        }

                        pub fn last_name(&self) -> String {
                            self.last_name.clone()
                        }
                    }
                "},
                "main.rs": indoc!{r#"
                    fn main() {
                        let user = User::new(FIRST_NAME.clone(), "doe".into());
                        println!("user {:?}", user);
                    }
                "#},
            }),
        )
        .await;

        let project = Project::test(fs, vec![Path::new(path!("/root"))], cx).await;
        project.update(cx, |project, _cx| {
            project.languages().add(rust_lang().into())
        });

        // Nested syntax search: `impl User`, then `pub fn first_name` inside it.
        assert_results(
            &project,
            SearchToolQuery {
                glob: "user.rs".into(),
                syntax_node: vec!["impl\\s+User".into(), "pub\\s+fn\\s+first_name".into()],
                content: None,
            },
            indoc! {r#"
                `````root/user.rs
                …
                impl User {
                …
                    pub fn first_name(&self) -> String {
                        self.first_name.clone()
                    }
                …
                `````
            "#},
            cx,
        )
        .await;

        // Syntax search narrowed by a content regex.
        assert_results(
            &project,
            SearchToolQuery {
                glob: "user.rs".into(),
                syntax_node: vec!["impl\\s+User".into()],
                content: Some("\\.clone".into()),
            },
            indoc! {r#"
                `````root/user.rs
                …
                impl User {
                …
                    pub fn first_name(&self) -> String {
                        self.first_name.clone()
                …
                    pub fn last_name(&self) -> String {
                        self.last_name.clone()
                …
                `````
            "#},
            cx,
        )
        .await;

        // Content-only search across every `.rs` file.
        assert_results(
            &project,
            SearchToolQuery {
                glob: "*.rs".into(),
                syntax_node: vec![],
                content: Some("\\.clone".into()),
            },
            indoc! {r#"
                `````root/main.rs
                fn main() {
                    let user = User::new(FIRST_NAME.clone(), "doe".into());
                …
                `````

                `````root/user.rs
                …
                impl Organization {
                    pub fn owner(&self) -> Arc<User> {
                        self.owner.clone()
                …
                impl User {
                …
                    pub fn first_name(&self) -> String {
                        self.first_name.clone()
                …
                    pub fn last_name(&self) -> String {
                        self.last_name.clone()
                …
                `````
            "#},
            cx,
        )
        .await;
    }

    /// Runs `query` against `project` and asserts that the results, rendered
    /// with `write_codeblock` in path order, equal `expected_output`.
    async fn assert_results(
        project: &Entity<Project>,
        query: SearchToolQuery,
        expected_output: &str,
        cx: &mut TestAppContext,
    ) {
        let results = run_retrieval_searches(project.clone(), vec![query], &mut cx.to_async())
            .await
            .unwrap();

        let mut results = results.into_iter().collect::<Vec<_>>();
        results.sort_by_key(|results| {
            results
                .0
                .read_with(cx, |buffer, _| buffer.file().unwrap().path().clone())
        });

        let mut output = String::new();
        for (buffer, ranges) in results {
            buffer.read_with(cx, |buffer, cx| {
                let excerpts = ranges.into_iter().map(|range| {
                    let point_range = range.to_point(buffer);
                    if point_range.end.column > 0 {
                        Line(point_range.start.row)..Line(point_range.end.row + 1)
                    } else {
                        Line(point_range.start.row)..Line(point_range.end.row)
                    }
                });

                write_codeblock(
                    &buffer.file().unwrap().full_path(cx),
                    merge_excerpts(&buffer.snapshot(), excerpts).iter(),
                    &[],
                    Line(buffer.max_point().row),
                    false,
                    &mut output,
                );
            });
        }
        // Drop the trailing newline so the output matches the `indoc!` expectation.
        output.pop();

        assert_eq!(output, expected_output);
    }

    /// Rust language definition with the real outline query, used by the fake project.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
        .unwrap()
    }

    /// Installs a test settings store and test logging.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            zlog::init_test();
        });
    }
}