Populate project search results multi-buffer from semantic search

Max Brunsfeld and Kyle created

Co-authored-by: Kyle <kyle@zed.dev>

Change summary

crates/search/src/project_search.rs               | 73 +++++++++++-----
crates/semantic_index/src/db.rs                   | 16 +--
crates/semantic_index/src/embedding.rs            |  5 -
crates/semantic_index/src/semantic_index.rs       | 68 +++++++++------
crates/semantic_index/src/semantic_index_tests.rs | 15 +-
5 files changed, 104 insertions(+), 73 deletions(-)

Detailed changes

crates/search/src/project_search.rs 🔗

@@ -2,7 +2,7 @@ use crate::{
     SearchOption, SelectNextMatch, SelectPrevMatch, ToggleCaseSensitive, ToggleRegex,
     ToggleWholeWord,
 };
-use anyhow::{Context, Result};
+use anyhow::Result;
 use collections::HashMap;
 use editor::{
     items::active_match_index, scroll::autoscroll::Autoscroll, Anchor, Editor, MultiBuffer,
@@ -187,6 +187,53 @@ impl ProjectSearch {
         }));
         cx.notify();
     }
+
+    fn semantic_search(&mut self, query: String, cx: &mut ModelContext<Self>) -> Option<()> {
+        let project = self.project.clone();
+        let semantic_index = SemanticIndex::global(cx)?;
+        let search_task = semantic_index.update(cx, |semantic_index, cx| {
+            semantic_index.search_project(project, query.clone(), 10, cx)
+        });
+
+        self.search_id += 1;
+        // self.active_query = Some(query);
+        self.match_ranges.clear();
+        self.pending_search = Some(cx.spawn(|this, mut cx| async move {
+            let results = search_task.await.log_err()?;
+
+            let (_task, mut match_ranges) = this.update(&mut cx, |this, cx| {
+                this.excerpts.update(cx, |excerpts, cx| {
+                    excerpts.clear(cx);
+
+                    let matches = results
+                        .into_iter()
+                        .map(|result| (result.buffer, vec![result.range]))
+                        .collect();
+
+                    excerpts.stream_excerpts_with_context_lines(matches, 3, cx)
+                })
+            });
+
+            while let Some(match_range) = match_ranges.next().await {
+                this.update(&mut cx, |this, cx| {
+                    this.match_ranges.push(match_range);
+                    while let Ok(Some(match_range)) = match_ranges.try_next() {
+                        this.match_ranges.push(match_range);
+                    }
+                    cx.notify();
+                });
+            }
+
+            this.update(&mut cx, |this, cx| {
+                this.pending_search.take();
+                cx.notify();
+            });
+
+            None
+        }));
+
+        Some(())
+    }
 }
 
 pub enum ViewEvent {
@@ -595,27 +642,9 @@ impl ProjectSearchView {
                 return;
             }
 
-            let search_phrase = self.query_editor.read(cx).text(cx);
-            let project = self.model.read(cx).project.clone();
-            if let Some(semantic_index) = SemanticIndex::global(cx) {
-                let search_task = semantic_index.update(cx, |semantic_index, cx| {
-                    semantic_index.search_project(project, search_phrase, 10, cx)
-                });
-                semantic.search_task = Some(cx.spawn(|this, mut cx| async move {
-                    let results = search_task.await.context("search task")?;
-
-                    this.update(&mut cx, |this, cx| {
-                        dbg!(&results);
-                        // TODO: Update results
-
-                        if let Some(semantic) = &mut this.semantic {
-                            semantic.search_task = None;
-                        }
-                    })?;
-
-                    anyhow::Ok(())
-                }));
-            }
+            let query = self.query_editor.read(cx).text(cx);
+            self.model
+                .update(cx, |model, cx| model.semantic_search(query, cx));
             return;
         }
 

crates/semantic_index/src/db.rs 🔗

@@ -252,7 +252,7 @@ impl VectorDatabase {
         worktree_ids: &[i64],
         query_embedding: &Vec<f32>,
         limit: usize,
-    ) -> Result<Vec<(i64, PathBuf, Range<usize>, String)>> {
+    ) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
         let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
         self.for_each_document(&worktree_ids, |id, embedding| {
             let similarity = dot(&embedding, &query_embedding);
@@ -296,10 +296,7 @@ impl VectorDatabase {
         Ok(())
     }
 
-    fn get_documents_by_ids(
-        &self,
-        ids: &[i64],
-    ) -> Result<Vec<(i64, PathBuf, Range<usize>, String)>> {
+    fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
         let mut statement = self.db.prepare(
             "
                 SELECT
@@ -307,7 +304,7 @@ impl VectorDatabase {
                     files.worktree_id,
                     files.relative_path,
                     documents.start_byte,
-                    documents.end_byte, documents.name
+                    documents.end_byte
                 FROM
                     documents, files
                 WHERE
@@ -322,14 +319,13 @@ impl VectorDatabase {
                 row.get::<_, i64>(1)?,
                 row.get::<_, String>(2)?.into(),
                 row.get(3)?..row.get(4)?,
-                row.get(5)?,
             ))
         })?;
 
-        let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>, String)>::default();
+        let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
         for row in result_iter {
-            let (id, worktree_id, path, range, name) = row?;
-            values_by_id.insert(id, (worktree_id, path, range, name));
+            let (id, worktree_id, path, range) = row?;
+            values_by_id.insert(id, (worktree_id, path, range));
         }
 
         let mut results = Vec::with_capacity(ids.len());

crates/semantic_index/src/embedding.rs 🔗

@@ -70,10 +70,6 @@ impl EmbeddingProvider for DummyEmbeddings {
 const OPENAI_INPUT_LIMIT: usize = 8190;
 
 impl OpenAIEmbeddings {
-    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
-        Self { client, executor }
-    }
-
     fn truncate(span: String) -> String {
         let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span.as_ref());
         if tokens.len() > OPENAI_INPUT_LIMIT {
@@ -81,7 +77,6 @@ impl OpenAIEmbeddings {
             let result = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
             if result.is_ok() {
                 let transformed = result.unwrap();
-                // assert_ne!(transformed, span);
                 return transformed;
             }
         }

crates/semantic_index/src/semantic_index.rs 🔗

@@ -12,7 +12,7 @@ use db::VectorDatabase;
 use embedding::{EmbeddingProvider, OpenAIEmbeddings};
 use futures::{channel::oneshot, Future};
 use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
-use language::{Language, LanguageRegistry};
+use language::{Anchor, Buffer, Language, LanguageRegistry};
 use parking_lot::Mutex;
 use parsing::{CodeContextRetriever, Document, PARSEABLE_ENTIRE_FILE_TYPES};
 use postage::watch;
@@ -93,7 +93,7 @@ pub struct SemanticIndex {
 struct ProjectState {
     worktree_db_ids: Vec<(WorktreeId, i64)>,
     outstanding_job_count_rx: watch::Receiver<usize>,
-    outstanding_job_count_tx: Arc<Mutex<watch::Sender<usize>>>,
+    _outstanding_job_count_tx: Arc<Mutex<watch::Sender<usize>>>,
 }
 
 struct JobHandle {
@@ -135,12 +135,9 @@ pub struct PendingFile {
     job_handle: JobHandle,
 }
 
-#[derive(Debug, Clone)]
 pub struct SearchResult {
-    pub worktree_id: WorktreeId,
-    pub name: String,
-    pub byte_range: Range<usize>,
-    pub file_path: PathBuf,
+    pub buffer: ModelHandle<Buffer>,
+    pub range: Range<Anchor>,
 }
 
 enum DbOperation {
@@ -520,7 +517,7 @@ impl SemanticIndex {
                             .map(|(a, b)| (*a, *b))
                             .collect(),
                         outstanding_job_count_rx: job_count_rx.clone(),
-                        outstanding_job_count_tx: job_count_tx.clone(),
+                        _outstanding_job_count_tx: job_count_tx.clone(),
                     },
                 );
             });
@@ -623,7 +620,7 @@ impl SemanticIndex {
         let embedding_provider = self.embedding_provider.clone();
         let database_url = self.database_url.clone();
         let fs = self.fs.clone();
-        cx.spawn(|this, cx| async move {
+        cx.spawn(|this, mut cx| async move {
             let documents = cx
                 .background()
                 .spawn(async move {
@@ -640,26 +637,39 @@ impl SemanticIndex {
                 })
                 .await?;
 
-            this.read_with(&cx, |this, _| {
-                let project_state = if let Some(state) = this.projects.get(&project.downgrade()) {
-                    state
-                } else {
-                    return Err(anyhow!("project not added"));
-                };
-
-                Ok(documents
-                    .into_iter()
-                    .filter_map(|(worktree_db_id, file_path, byte_range, name)| {
-                        let worktree_id = project_state.worktree_id_for_db_id(worktree_db_id)?;
-                        Some(SearchResult {
-                            worktree_id,
-                            name,
-                            byte_range,
-                            file_path,
-                        })
-                    })
-                    .collect())
-            })
+            let mut tasks = Vec::new();
+            let mut ranges = Vec::new();
+            let weak_project = project.downgrade();
+            project.update(&mut cx, |project, cx| {
+                for (worktree_db_id, file_path, byte_range) in documents {
+                    let project_state =
+                        if let Some(state) = this.read(cx).projects.get(&weak_project) {
+                            state
+                        } else {
+                            return Err(anyhow!("project not added"));
+                        };
+                    if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
+                        tasks.push(project.open_buffer((worktree_id, file_path), cx));
+                        ranges.push(byte_range);
+                    }
+                }
+
+                Ok(())
+            })?;
+
+            let buffers = futures::future::join_all(tasks).await;
+
+            Ok(buffers
+                .into_iter()
+                .zip(ranges)
+                .filter_map(|(buffer, range)| {
+                    let buffer = buffer.log_err()?;
+                    let range = buffer.read_with(&cx, |buffer, _| {
+                        buffer.anchor_before(range.start)..buffer.anchor_after(range.end)
+                    });
+                    Some(SearchResult { buffer, range })
+                })
+                .collect::<Vec<_>>())
         })
     }
 }

crates/semantic_index/src/semantic_index_tests.rs 🔗

@@ -8,7 +8,7 @@ use crate::{
 use anyhow::Result;
 use async_trait::async_trait;
 use gpui::{Task, TestAppContext};
-use language::{Language, LanguageConfig, LanguageRegistry};
+use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
 use project::{project_settings::ProjectSettings, FakeFs, Fs, Project};
 use rand::{rngs::StdRng, Rng};
 use serde_json::json;
@@ -85,9 +85,6 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
     .unwrap();
 
     let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
-    let worktree_id = project.read_with(cx, |project, cx| {
-        project.worktrees(cx).next().unwrap().read(cx).id()
-    });
     let (file_count, outstanding_file_count) = store
         .update(cx, |store, cx| store.index_project(project.clone(), cx))
         .await
@@ -103,9 +100,13 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
         .await
         .unwrap();
 
-    assert_eq!(search_results[0].byte_range.start, 0);
-    assert_eq!(search_results[0].name, "aaa");
-    assert_eq!(search_results[0].worktree_id, worktree_id);
+    search_results[0].buffer.read_with(cx, |buffer, _cx| {
+        assert_eq!(search_results[0].range.start.to_offset(buffer), 0);
+        assert_eq!(
+            buffer.file().unwrap().path().as_ref(),
+            Path::new("file1.rs")
+        );
+    });
 
     fs.save(
         "/the-root/src/file2.rs".as_ref(),