retrieval_search.rs

  1use std::ops::Range;
  2
  3use anyhow::Result;
  4use cloud_zeta2_prompt::retrieval_prompt::SearchToolQuery;
  5use collections::HashMap;
  6use futures::{
  7    StreamExt,
  8    channel::mpsc::{self, UnboundedSender},
  9};
 10use gpui::{AppContext, AsyncApp, Entity};
 11use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt, Point, ToOffset, ToPoint};
 12use project::{
 13    Project, WorktreeSettings,
 14    search::{SearchQuery, SearchResult},
 15};
 16use smol::channel;
 17use util::{
 18    ResultExt as _,
 19    paths::{PathMatcher, PathStyle},
 20};
 21use workspace::item::Settings as _;
 22
 23pub async fn run_retrieval_searches(
 24    project: Entity<Project>,
 25    queries: Vec<SearchToolQuery>,
 26    cx: &mut AsyncApp,
 27) -> Result<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>> {
 28    let (exclude_matcher, path_style) = project.update(cx, |project, cx| {
 29        let global_settings = WorktreeSettings::get_global(cx);
 30        let exclude_patterns = global_settings
 31            .file_scan_exclusions
 32            .sources()
 33            .iter()
 34            .chain(global_settings.private_files.sources().iter());
 35        let path_style = project.path_style(cx);
 36        anyhow::Ok((PathMatcher::new(exclude_patterns, path_style)?, path_style))
 37    })??;
 38
 39    let (results_tx, mut results_rx) = mpsc::unbounded();
 40
 41    for query in queries {
 42        let exclude_matcher = exclude_matcher.clone();
 43        let results_tx = results_tx.clone();
 44        let project = project.clone();
 45        cx.spawn(async move |cx| {
 46            run_query(
 47                query,
 48                results_tx.clone(),
 49                path_style,
 50                exclude_matcher,
 51                &project,
 52                cx,
 53            )
 54            .await
 55            .log_err();
 56        })
 57        .detach()
 58    }
 59    drop(results_tx);
 60
 61    cx.background_spawn(async move {
 62        let mut results: HashMap<Entity<Buffer>, Vec<Range<Anchor>>> = HashMap::default();
 63        let mut snapshots = HashMap::default();
 64
 65        let mut total_bytes = 0;
 66        'outer: while let Some((buffer, snapshot, excerpts)) = results_rx.next().await {
 67            snapshots.insert(buffer.entity_id(), snapshot);
 68            let existing = results.entry(buffer).or_default();
 69            existing.reserve(excerpts.len());
 70
 71            for (range, size) in excerpts {
 72                // Blunt trimming of the results until we have a proper algorithmic filtering step
 73                if (total_bytes + size) > MAX_RESULTS_LEN {
 74                    log::trace!("Combined results reached limit of {MAX_RESULTS_LEN}B");
 75                    break 'outer;
 76                }
 77                total_bytes += size;
 78                existing.push(range);
 79            }
 80        }
 81
 82        for (buffer, ranges) in results.iter_mut() {
 83            if let Some(snapshot) = snapshots.get(&buffer.entity_id()) {
 84                ranges.sort_unstable_by(|a, b| {
 85                    a.start
 86                        .cmp(&b.start, snapshot)
 87                        .then(b.end.cmp(&b.end, snapshot))
 88                });
 89
 90                let mut index = 1;
 91                while index < ranges.len() {
 92                    if ranges[index - 1]
 93                        .end
 94                        .cmp(&ranges[index].start, snapshot)
 95                        .is_gt()
 96                    {
 97                        let removed = ranges.remove(index);
 98                        ranges[index - 1].end = removed.end;
 99                    } else {
100                        index += 1;
101                    }
102                }
103            }
104        }
105
106        Ok(results)
107    })
108    .await
109}
110
/// Nominal size, in bytes, of a single excerpt. In this file it only serves
/// as the unit for `MAX_RESULTS_LEN`.
const MAX_EXCERPT_LEN: usize = 768;
/// Byte budget for all excerpts combined across every query: five excerpts' worth.
const MAX_RESULTS_LEN: usize = 5 * MAX_EXCERPT_LEN;
113
/// A unit of work for the nested syntax-search workers spawned in `run_query`.
///
/// The outer syntax search creates the first job for each matched buffer
/// with `query_ix: 0`. `process_nested_search_job` then either re-enqueues a
/// job for the next `query_ix` or, once the nested queries are exhausted,
/// sends the final excerpts.
114struct SearchJob {
    /// Buffer whose contents are being searched.
115    buffer: Entity<Buffer>,
    /// Snapshot of `buffer` that the offsets in `ranges` refer to.
116    snapshot: BufferSnapshot,
    /// Byte-offset ranges within `snapshot` that the next query searches inside.
117    ranges: Vec<Range<usize>>,
    /// Index of the next nested syntax query to run. Past the end of the
    /// nested queries, the optional content query is applied and results are sent.
118    query_ix: usize,
    /// Sender into the shared job queue, used to enqueue the follow-up job.
    /// The worker loop in `run_query` ends when `recv` fails. Presumably this
    /// happens once every sender, including those held by queued jobs, has
    /// been dropped (async-channel semantics).
119    jobs_tx: channel::Sender<SearchJob>,
120}
121
122async fn run_query(
123    input_query: SearchToolQuery,
124    results_tx: UnboundedSender<(Entity<Buffer>, BufferSnapshot, Vec<(Range<Anchor>, usize)>)>,
125    path_style: PathStyle,
126    exclude_matcher: PathMatcher,
127    project: &Entity<Project>,
128    cx: &mut AsyncApp,
129) -> Result<()> {
130    let include_matcher = PathMatcher::new(vec![input_query.glob], path_style)?;
131
132    let make_search = |regex: &str| -> Result<SearchQuery> {
133        SearchQuery::regex(
134            regex,
135            false,
136            true,
137            false,
138            true,
139            include_matcher.clone(),
140            exclude_matcher.clone(),
141            true,
142            None,
143        )
144    };
145
146    if let Some(outer_syntax_regex) = input_query.syntax_node.first() {
147        let outer_syntax_query = make_search(outer_syntax_regex)?;
148        let nested_syntax_queries = input_query
149            .syntax_node
150            .into_iter()
151            .skip(1)
152            .map(|query| make_search(&query))
153            .collect::<Result<Vec<_>>>()?;
154        let content_query = input_query
155            .content
156            .map(|regex| make_search(&regex))
157            .transpose()?;
158
159        let (jobs_tx, jobs_rx) = channel::unbounded();
160
161        let outer_search_results_rx =
162            project.update(cx, |project, cx| project.search(outer_syntax_query, cx))?;
163
164        let outer_search_task = cx.spawn(async move |cx| {
165            futures::pin_mut!(outer_search_results_rx);
166            while let Some(SearchResult::Buffer { buffer, ranges }) =
167                outer_search_results_rx.next().await
168            {
169                buffer
170                    .read_with(cx, |buffer, _| buffer.parsing_idle())?
171                    .await;
172                let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
173                let expanded_ranges: Vec<_> = ranges
174                    .into_iter()
175                    .filter_map(|range| expand_to_parent_range(&range, &snapshot))
176                    .collect();
177                jobs_tx
178                    .send(SearchJob {
179                        buffer,
180                        snapshot,
181                        ranges: expanded_ranges,
182                        query_ix: 0,
183                        jobs_tx: jobs_tx.clone(),
184                    })
185                    .await?;
186            }
187            anyhow::Ok(())
188        });
189
190        let n_workers = cx.background_executor().num_cpus();
191        let search_job_task = cx.background_executor().scoped(|scope| {
192            for _ in 0..n_workers {
193                scope.spawn(async {
194                    while let Ok(job) = jobs_rx.recv().await {
195                        process_nested_search_job(
196                            &results_tx,
197                            &nested_syntax_queries,
198                            &content_query,
199                            job,
200                        )
201                        .await;
202                    }
203                });
204            }
205        });
206
207        search_job_task.await;
208        outer_search_task.await?;
209    } else if let Some(content_regex) = &input_query.content {
210        let search_query = make_search(&content_regex)?;
211
212        let results_rx = project.update(cx, |project, cx| project.search(search_query, cx))?;
213        futures::pin_mut!(results_rx);
214
215        while let Some(SearchResult::Buffer { buffer, ranges }) = results_rx.next().await {
216            let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
217
218            let ranges = ranges
219                .into_iter()
220                .map(|range| {
221                    let range = range.to_offset(&snapshot);
222                    let range = expand_to_entire_lines(range, &snapshot);
223                    let size = range.len();
224                    let range =
225                        snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end);
226                    (range, size)
227                })
228                .collect();
229
230            let send_result = results_tx.unbounded_send((buffer.clone(), snapshot.clone(), ranges));
231
232            if let Err(err) = send_result
233                && !err.is_disconnected()
234            {
235                log::error!("{err}");
236            }
237        }
238    } else {
239        log::warn!("Context gathering model produced a glob-only search");
240    }
241
242    anyhow::Ok(())
243}
244
245async fn process_nested_search_job(
246    results_tx: &UnboundedSender<(Entity<Buffer>, BufferSnapshot, Vec<(Range<Anchor>, usize)>)>,
247    queries: &Vec<SearchQuery>,
248    content_query: &Option<SearchQuery>,
249    job: SearchJob,
250) {
251    if let Some(search_query) = queries.get(job.query_ix) {
252        let mut subranges = Vec::new();
253        for range in job.ranges {
254            let start = range.start;
255            let search_results = search_query.search(&job.snapshot, Some(range)).await;
256            for subrange in search_results {
257                let subrange = start + subrange.start..start + subrange.end;
258                subranges.extend(expand_to_parent_range(&subrange, &job.snapshot));
259            }
260        }
261        job.jobs_tx
262            .send(SearchJob {
263                buffer: job.buffer,
264                snapshot: job.snapshot,
265                ranges: subranges,
266                query_ix: job.query_ix + 1,
267                jobs_tx: job.jobs_tx.clone(),
268            })
269            .await
270            .ok();
271    } else {
272        let ranges = if let Some(content_query) = content_query {
273            let mut subranges = Vec::new();
274            for range in job.ranges {
275                let start = range.start;
276                let search_results = content_query.search(&job.snapshot, Some(range)).await;
277                for subrange in search_results {
278                    let subrange = start + subrange.start..start + subrange.end;
279                    subranges.push(subrange);
280                }
281            }
282            subranges
283        } else {
284            job.ranges
285        };
286
287        let matches = ranges
288            .into_iter()
289            .map(|range| {
290                let snapshot = &job.snapshot;
291                let range = expand_to_entire_lines(range, snapshot);
292                let size = range.len();
293                let range = snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end);
294                (range, size)
295            })
296            .collect();
297
298        let send_result = results_tx.unbounded_send((job.buffer, job.snapshot, matches));
299
300        if let Err(err) = send_result
301            && !err.is_disconnected()
302        {
303            log::error!("{err}");
304        }
305    }
306}
307
308fn expand_to_entire_lines(range: Range<usize>, snapshot: &BufferSnapshot) -> Range<usize> {
309    let mut point_range = range.to_point(snapshot);
310    point_range.start.column = 0;
311    if point_range.end.column > 0 {
312        point_range.end = snapshot.max_point().min(point_range.end + Point::new(1, 0));
313    }
314    point_range.to_offset(snapshot)
315}
316
317fn expand_to_parent_range<T: ToPoint + ToOffset>(
318    range: &Range<T>,
319    snapshot: &BufferSnapshot,
320) -> Option<Range<usize>> {
321    let mut line_range = range.to_point(&snapshot);
322    line_range.start.column = snapshot.indent_size_for_line(line_range.start.row).len;
323    line_range.end.column = snapshot.line_len(line_range.end.row);
324    // TODO skip result if matched line isn't the first node line?
325
326    let node = snapshot.syntax_ancestor(line_range)?;
327    Some(node.byte_range())
328}
329
// Tests: run retrieval queries against a fake project containing `user.rs`
// and `main.rs`, then compare the rendered code blocks with expected text.
330#[cfg(test)]
331mod tests {
332    use super::*;
333    use crate::merge_excerpts::merge_excerpts;
334    use cloud_zeta2_prompt::write_codeblock;
335    use edit_prediction_context::Line;
336    use gpui::TestAppContext;
337    use indoc::indoc;
338    use language::{Language, LanguageConfig, LanguageMatcher, tree_sitter_rust};
339    use pretty_assertions::assert_eq;
340    use project::FakeFs;
341    use serde_json::json;
342    use settings::SettingsStore;
343    use std::path::Path;
344    use util::path;
345
    /// Covers three query shapes: nested syntax-node patterns, a syntax node
    /// narrowed by a content pattern, and a content-only search over a glob.
346    #[gpui::test]
347    async fn test_retrieval(cx: &mut TestAppContext) {
348        init_test(cx);
349        let fs = FakeFs::new(cx.executor());
350        fs.insert_tree(
351            path!("/root"),
352            json!({
353                "user.rs": indoc!{"
354                    pub struct Organization {
355                        owner: Arc<User>,
356                    }
357
358                    pub struct User {
359                        first_name: String,
360                        last_name: String,
361                    }
362
363                    impl Organization {
364                        pub fn owner(&self) -> Arc<User> {
365                            self.owner.clone()
366                        }
367                    }
368
369                    impl User {
370                        pub fn new(first_name: String, last_name: String) -> Self {
371                            Self {
372                                first_name,
373                                last_name
374                            }
375                        }
376
377                        pub fn first_name(&self) -> String {
378                            self.first_name.clone()
379                        }
380
381                        pub fn last_name(&self) -> String {
382                            self.last_name.clone()
383                        }
384                    }
385                "},
386                "main.rs": indoc!{r#"
387                    fn main() {
388                        let user = User::new(FIRST_NAME.clone(), "doe".into());
389                        println!("user {:?}", user);
390                    }
391                "#},
392            }),
393        )
394        .await;
395
396        let project = Project::test(fs, vec![Path::new(path!("/root"))], cx).await;
397        project.update(cx, |project, _cx| {
398            project.languages().add(rust_lang().into())
399        });
400
401        assert_results(
402            &project,
403            SearchToolQuery {
404                glob: "user.rs".into(),
405                syntax_node: vec!["impl\\s+User".into(), "pub\\s+fn\\s+first_name".into()],
406                content: None,
407            },
408            indoc! {r#"
409                `````root/user.rs
410…411                impl User {
412…413                    pub fn first_name(&self) -> String {
414                        self.first_name.clone()
415                    }
416…417                `````
418            "#},
419            cx,
420        )
421        .await;
422
423        assert_results(
424            &project,
425            SearchToolQuery {
426                glob: "user.rs".into(),
427                syntax_node: vec!["impl\\s+User".into()],
428                content: Some("\\.clone".into()),
429            },
430            indoc! {r#"
431                `````root/user.rs
432…433                impl User {
434…435                    pub fn first_name(&self) -> String {
436                        self.first_name.clone()
437…438                    pub fn last_name(&self) -> String {
439                        self.last_name.clone()
440…441                `````
442            "#},
443            cx,
444        )
445        .await;
446
447        assert_results(
448            &project,
449            SearchToolQuery {
450                glob: "*.rs".into(),
451                syntax_node: vec![],
452                content: Some("\\.clone".into()),
453            },
454            indoc! {r#"
455                `````root/main.rs
456                fn main() {
457                    let user = User::new(FIRST_NAME.clone(), "doe".into());
458…459                `````
460
461                `````root/user.rs
462…463                impl Organization {
464                    pub fn owner(&self) -> Arc<User> {
465                        self.owner.clone()
466…467                impl User {
468…469                    pub fn first_name(&self) -> String {
470                        self.first_name.clone()
471…472                    pub fn last_name(&self) -> String {
473                        self.last_name.clone()
474…475                `````
476            "#},
477            cx,
478        )
479        .await;
480    }
481
    /// Runs `query` against `project` and renders every result buffer, sorted
    /// by path, with `merge_excerpts` and `write_codeblock`. Asserts that the
    /// rendered text equals `expected_output`.
482    async fn assert_results(
483        project: &Entity<Project>,
484        query: SearchToolQuery,
485        expected_output: &str,
486        cx: &mut TestAppContext,
487    ) {
488        let results = run_retrieval_searches(project.clone(), vec![query], &mut cx.to_async())
489            .await
490            .unwrap();
491
492        let mut results = results.into_iter().collect::<Vec<_>>();
493        results.sort_by_key(|results| {
494            results
495                .0
496                .read_with(cx, |buffer, _| buffer.file().unwrap().path().clone())
497        });
498
499        let mut output = String::new();
500        for (buffer, ranges) in results {
501            buffer.read_with(cx, |buffer, cx| {
                // Map each anchor range to a row range. A range that ends
                // mid-line includes that whole last line.
502                let excerpts = ranges.into_iter().map(|range| {
503                    let point_range = range.to_point(buffer);
504                    if point_range.end.column > 0 {
505                        Line(point_range.start.row)..Line(point_range.end.row + 1)
506                    } else {
507                        Line(point_range.start.row)..Line(point_range.end.row)
508                    }
509                });
510
511                write_codeblock(
512                    &buffer.file().unwrap().full_path(cx),
513                    merge_excerpts(&buffer.snapshot(), excerpts).iter(),
514                    &[],
515                    Line(buffer.max_point().row),
516                    false,
517                    &mut output,
518                );
519            });
520        }
        // Remove the final character of the rendered output (presumably a
        // trailing newline) before comparing.
521        output.pop();
522
523        assert_eq!(output, expected_output);
524    }
525
    /// Builds a Rust `Language` from the tree-sitter grammar plus the outline
    /// query shipped in the languages crate.
526    fn rust_lang() -> Language {
527        Language::new(
528            LanguageConfig {
529                name: "Rust".into(),
530                matcher: LanguageMatcher {
531                    path_suffixes: vec!["rs".to_string()],
532                    ..Default::default()
533                },
534                ..Default::default()
535            },
536            Some(tree_sitter_rust::LANGUAGE.into()),
537        )
538        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
539        .unwrap()
540    }
541
    /// Installs a test `SettingsStore` as a global and initializes test logging.
542    fn init_test(cx: &mut TestAppContext) {
543        cx.update(move |cx| {
544            let settings_store = SettingsStore::test(cx);
545            cx.set_global(settings_store);
546            zlog::init_test();
547        });
548    }
549}