Refactor semantic searching of modified buffers

Antonio Scandurra created

Change summary

Cargo.lock                                  |   1 
crates/semantic_index/Cargo.toml            |   1 
crates/semantic_index/src/db.rs             |  13 
crates/semantic_index/src/embedding.rs      |  11 
crates/semantic_index/src/semantic_index.rs | 415 +++++++++++-----------
5 files changed, 214 insertions(+), 227 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -6739,6 +6739,7 @@ dependencies = [
  "lazy_static",
  "log",
  "matrixmultiply",
+ "ordered-float",
  "parking_lot 0.11.2",
  "parse_duration",
  "picker",

crates/semantic_index/Cargo.toml 🔗

@@ -23,6 +23,7 @@ settings = { path = "../settings" }
 anyhow.workspace = true
 postage.workspace = true
 futures.workspace = true
+ordered-float.workspace = true
 smol.workspace = true
 rusqlite = { version = "0.27.0", features = ["blob", "array", "modern_sqlite"] }
 isahc.workspace = true

crates/semantic_index/src/db.rs 🔗

@@ -7,12 +7,13 @@ use anyhow::{anyhow, Context, Result};
 use collections::HashMap;
 use futures::channel::oneshot;
 use gpui::executor;
+use ordered_float::OrderedFloat;
 use project::{search::PathMatcher, Fs};
 use rpc::proto::Timestamp;
 use rusqlite::params;
 use rusqlite::types::Value;
 use std::{
-    cmp::Ordering,
+    cmp::Reverse,
     future::Future,
     ops::Range,
     path::{Path, PathBuf},
@@ -407,16 +408,16 @@ impl VectorDatabase {
         query_embedding: &Embedding,
         limit: usize,
         file_ids: &[i64],
-    ) -> impl Future<Output = Result<Vec<(i64, f32)>>> {
+    ) -> impl Future<Output = Result<Vec<(i64, OrderedFloat<f32>)>>> {
         let query_embedding = query_embedding.clone();
         let file_ids = file_ids.to_vec();
         self.transact(move |db| {
-            let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
+            let mut results = Vec::<(i64, OrderedFloat<f32>)>::with_capacity(limit + 1);
             Self::for_each_span(db, &file_ids, |id, embedding| {
                 let similarity = embedding.similarity(&query_embedding);
-                let ix = match results.binary_search_by(|(_, s)| {
-                    similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
-                }) {
+                let ix = match results
+                    .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
+                {
                     Ok(ix) => ix,
                     Err(ix) => ix,
                 };

crates/semantic_index/src/embedding.rs 🔗

@@ -7,6 +7,7 @@ use isahc::http::StatusCode;
 use isahc::prelude::Configurable;
 use isahc::{AsyncBody, Response};
 use lazy_static::lazy_static;
+use ordered_float::OrderedFloat;
 use parking_lot::Mutex;
 use parse_duration::parse;
 use postage::watch;
@@ -35,7 +36,7 @@ impl From<Vec<f32>> for Embedding {
 }
 
 impl Embedding {
-    pub fn similarity(&self, other: &Self) -> f32 {
+    pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
         let len = self.0.len();
         assert_eq!(len, other.0.len());
 
@@ -58,7 +59,7 @@ impl Embedding {
                 1,
             );
         }
-        result
+        OrderedFloat(result)
     }
 }
 
@@ -379,13 +380,13 @@ mod tests {
             );
         }
 
-        fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
+        fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
             let factor = (10.0 as f32).powi(decimal_places);
             (n * factor).round() / factor
         }
 
-        fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
-            a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
+        fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
+            OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
         }
     }
 }

crates/semantic_index/src/semantic_index.rs 🔗

@@ -16,13 +16,14 @@ use embedding_queue::{EmbeddingQueue, FileToEmbed};
 use futures::{future, FutureExt, StreamExt};
 use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
 use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
+use ordered_float::OrderedFloat;
 use parking_lot::Mutex;
-use parsing::{CodeContextRetriever, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
+use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
 use postage::watch;
 use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
 use smol::channel;
 use std::{
-    cmp::Ordering,
+    cmp::Reverse,
     future::Future,
     mem,
     ops::Range,
@@ -267,7 +268,7 @@ pub struct PendingFile {
 pub struct SearchResult {
     pub buffer: ModelHandle<Buffer>,
     pub range: Range<Anchor>,
-    pub similarity: f32,
+    pub similarity: OrderedFloat<f32>,
 }
 
 impl SemanticIndex {
@@ -690,38 +691,70 @@ impl SemanticIndex {
     pub fn search_project(
         &mut self,
         project: ModelHandle<Project>,
-        phrase: String,
+        query: String,
         limit: usize,
         includes: Vec<PathMatcher>,
-        mut excludes: Vec<PathMatcher>,
+        excludes: Vec<PathMatcher>,
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<Vec<SearchResult>>> {
+        if query.is_empty() {
+            return Task::ready(Ok(Vec::new()));
+        }
+
         let index = self.index_project(project.clone(), cx);
         let embedding_provider = self.embedding_provider.clone();
-        let db_path = self.db.path().clone();
-        let fs = self.fs.clone();
+
         cx.spawn(|this, mut cx| async move {
+            let query = embedding_provider
+                .embed_batch(vec![query])
+                .await?
+                .pop()
+                .ok_or_else(|| anyhow!("could not embed query"))?;
             index.await?;
 
-            let t0 = Instant::now();
-            let database =
-                VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?;
+            let search_start = Instant::now();
+            let modified_buffer_results = this.update(&mut cx, |this, cx| {
+                this.search_modified_buffers(&project, query.clone(), limit, &excludes, cx)
+            });
+            let file_results = this.update(&mut cx, |this, cx| {
+                this.search_files(project, query, limit, includes, excludes, cx)
+            });
+            let (modified_buffer_results, file_results) =
+                futures::join!(modified_buffer_results, file_results);
 
-            if phrase.len() == 0 {
-                return Ok(Vec::new());
+            // Weave together the results from modified buffers and files.
+            let mut results = Vec::new();
+            let mut modified_buffers = HashSet::default();
+            for result in modified_buffer_results.log_err().unwrap_or_default() {
+                modified_buffers.insert(result.buffer.clone());
+                results.push(result);
             }
+            for result in file_results.log_err().unwrap_or_default() {
+                if !modified_buffers.contains(&result.buffer) {
+                    results.push(result);
+                }
+            }
+            results.sort_by_key(|result| Reverse(result.similarity));
+            results.truncate(limit);
+            log::trace!("Semantic search took {:?}", search_start.elapsed());
+            Ok(results)
+        })
+    }
 
-            let phrase_embedding = embedding_provider
-                .embed_batch(vec![phrase])
-                .await?
-                .into_iter()
-                .next()
-                .unwrap();
-
-            log::trace!(
-                "Embedding search phrase took: {:?} milliseconds",
-                t0.elapsed().as_millis()
-            );
+    pub fn search_files(
+        &mut self,
+        project: ModelHandle<Project>,
+        query: Embedding,
+        limit: usize,
+        includes: Vec<PathMatcher>,
+        excludes: Vec<PathMatcher>,
+        cx: &mut ModelContext<Self>,
+    ) -> Task<Result<Vec<SearchResult>>> {
+        let db_path = self.db.path().clone();
+        let fs = self.fs.clone();
+        cx.spawn(|this, mut cx| async move {
+            let database =
+                VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?;
 
             let worktree_db_ids = this.read_with(&cx, |this, _| {
                 let project_state = this
@@ -742,42 +775,6 @@ impl SemanticIndex {
                 anyhow::Ok(worktree_db_ids)
             })?;
 
-            let (dirty_buffers, dirty_paths) = project.read_with(&cx, |project, cx| {
-                let mut dirty_paths = Vec::new();
-                let dirty_buffers = project
-                    .opened_buffers(cx)
-                    .into_iter()
-                    .filter_map(|buffer_handle| {
-                        let buffer = buffer_handle.read(cx);
-                        if buffer.is_dirty() {
-                            let snapshot = buffer.snapshot();
-                            if let Some(file_pathbuf) = snapshot.resolve_file_path(cx, false) {
-                                let file_path = file_pathbuf.as_path();
-
-                                if excludes.iter().any(|glob| glob.is_match(file_path)) {
-                                    return None;
-                                }
-
-                                file_pathbuf
-                                    .to_str()
-                                    .and_then(|path| PathMatcher::new(path).log_err())
-                                    .and_then(|path_matcher| {
-                                        dirty_paths.push(path_matcher);
-                                        Some(())
-                                    });
-                            }
-                            // TOOD: @as-cii I removed the downgrade for now to fix the compiler - @kcaverly
-                            Some((buffer_handle, buffer.snapshot()))
-                        } else {
-                            None
-                        }
-                    })
-                    .collect::<HashMap<_, _>>();
-
-                (dirty_buffers, dirty_paths)
-            });
-
-            excludes.extend(dirty_paths);
             let file_ids = database
                 .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
                 .await?;
@@ -796,155 +793,26 @@ impl SemanticIndex {
                 let limit = limit.clone();
                 let fs = fs.clone();
                 let db_path = db_path.clone();
-                let phrase_embedding = phrase_embedding.clone();
+                let query = query.clone();
                 if let Some(db) = VectorDatabase::new(fs, db_path.clone(), cx.background())
                     .await
                     .log_err()
                 {
                     batch_results.push(async move {
-                        db.top_k_search(&phrase_embedding, limit, batch.as_slice())
-                            .await
+                        db.top_k_search(&query, limit, batch.as_slice()).await
                     });
                 }
             }
 
-            let buffer_results = if let Some(db) =
-                VectorDatabase::new(fs, db_path.clone(), cx.background())
-                    .await
-                    .log_err()
-            {
-                cx.background()
-                    .spawn({
-                        let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
-                        let embedding_provider = embedding_provider.clone();
-                        let phrase_embedding = phrase_embedding.clone();
-                        async move {
-                            let mut results = Vec::<SearchResult>::new();
-                            'buffers: for (buffer_handle, buffer_snapshot) in dirty_buffers {
-                                let language = buffer_snapshot
-                                    .language_at(0)
-                                    .cloned()
-                                    .unwrap_or_else(|| language::PLAIN_TEXT.clone());
-                                if let Some(spans) = retriever
-                                    .parse_file_with_template(
-                                        None,
-                                        &buffer_snapshot.text(),
-                                        language,
-                                    )
-                                    .log_err()
-                                {
-                                    let mut batch = Vec::new();
-                                    let mut batch_tokens = 0;
-                                    let mut embeddings = Vec::new();
-
-                                    let digests = spans
-                                        .iter()
-                                        .map(|span| span.digest.clone())
-                                        .collect::<Vec<_>>();
-                                    let embeddings_for_digests = db
-                                        .embeddings_for_digests(digests)
-                                        .await
-                                        .map_or(Default::default(), |m| m);
-
-                                    for span in &spans {
-                                        if embeddings_for_digests.contains_key(&span.digest) {
-                                            continue;
-                                        };
-
-                                        if batch_tokens + span.token_count
-                                            > embedding_provider.max_tokens_per_batch()
-                                        {
-                                            if let Some(batch_embeddings) = embedding_provider
-                                                .embed_batch(mem::take(&mut batch))
-                                                .await
-                                                .log_err()
-                                            {
-                                                embeddings.extend(batch_embeddings);
-                                                batch_tokens = 0;
-                                            } else {
-                                                continue 'buffers;
-                                            }
-                                        }
-
-                                        batch_tokens += span.token_count;
-                                        batch.push(span.content.clone());
-                                    }
-
-                                    if let Some(batch_embeddings) = embedding_provider
-                                        .embed_batch(mem::take(&mut batch))
-                                        .await
-                                        .log_err()
-                                    {
-                                        embeddings.extend(batch_embeddings);
-                                    } else {
-                                        continue 'buffers;
-                                    }
-
-                                    let mut embeddings = embeddings.into_iter();
-                                    for span in spans {
-                                        let embedding = if let Some(embedding) =
-                                            embeddings_for_digests.get(&span.digest)
-                                        {
-                                            Some(embedding.clone())
-                                        } else {
-                                            embeddings.next()
-                                        };
-
-                                        if let Some(embedding) = embedding {
-                                            let similarity =
-                                                embedding.similarity(&phrase_embedding);
-
-                                            let ix = match results.binary_search_by(|s| {
-                                                similarity
-                                                    .partial_cmp(&s.similarity)
-                                                    .unwrap_or(Ordering::Equal)
-                                            }) {
-                                                Ok(ix) => ix,
-                                                Err(ix) => ix,
-                                            };
-
-                                            let range = {
-                                                let start = buffer_snapshot
-                                                    .clip_offset(span.range.start, Bias::Left);
-                                                let end = buffer_snapshot
-                                                    .clip_offset(span.range.end, Bias::Right);
-                                                buffer_snapshot.anchor_before(start)
-                                                    ..buffer_snapshot.anchor_after(end)
-                                            };
-
-                                            results.insert(
-                                                ix,
-                                                SearchResult {
-                                                    buffer: buffer_handle.clone(),
-                                                    range,
-                                                    similarity,
-                                                },
-                                            );
-                                            results.truncate(limit);
-                                        } else {
-                                            log::error!("failed to embed span");
-                                            continue 'buffers;
-                                        }
-                                    }
-                                }
-                            }
-                            anyhow::Ok(results)
-                        }
-                    })
-                    .await
-            } else {
-                Ok(Vec::new())
-            };
-
             let batch_results = futures::future::join_all(batch_results).await;
 
             let mut results = Vec::new();
             for batch_result in batch_results {
                 if batch_result.is_ok() {
                     for (id, similarity) in batch_result.unwrap() {
-                        let ix = match results.binary_search_by(|(_, s)| {
-                            similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
-                        }) {
+                        let ix = match results
+                            .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
+                        {
                             Ok(ix) => ix,
                             Err(ix) => ix,
                         };
@@ -958,7 +826,7 @@ impl SemanticIndex {
             let scores = results
                 .into_iter()
                 .map(|(_, score)| score)
-                .collect::<Vec<f32>>();
+                .collect::<Vec<_>>();
             let spans = database.spans_for_ids(ids.as_slice()).await?;
 
             let mut tasks = Vec::new();
@@ -983,12 +851,7 @@ impl SemanticIndex {
 
             let buffers = futures::future::join_all(tasks).await;
 
-            log::trace!(
-                "Semantic Searching took: {:?} milliseconds in total",
-                t0.elapsed().as_millis()
-            );
-
-            let mut database_results = buffers
+            Ok(buffers
                 .into_iter()
                 .zip(ranges)
                 .zip(scores)
@@ -1005,26 +868,89 @@ impl SemanticIndex {
                         similarity,
                     })
                 })
-                .collect::<Vec<_>>();
+                .collect())
+        })
+    }
 
-            // Stitch Together Database Results & Buffer Results
-            if let Ok(buffer_results) = buffer_results {
-                for buffer_result in buffer_results {
-                    let ix = match database_results.binary_search_by(|search_result| {
-                        buffer_result
-                            .similarity
-                            .partial_cmp(&search_result.similarity)
-                            .unwrap_or(Ordering::Equal)
-                    }) {
-                        Ok(ix) => ix,
-                        Err(ix) => ix,
-                    };
-                    database_results.insert(ix, buffer_result);
-                    database_results.truncate(limit);
+    fn search_modified_buffers(
+        &self,
+        project: &ModelHandle<Project>,
+        query: Embedding,
+        limit: usize,
+        excludes: &[PathMatcher],
+        cx: &mut ModelContext<Self>,
+    ) -> Task<Result<Vec<SearchResult>>> {
+        let modified_buffers = project
+            .read(cx)
+            .opened_buffers(cx)
+            .into_iter()
+            .filter_map(|buffer_handle| {
+                let buffer = buffer_handle.read(cx);
+                let snapshot = buffer.snapshot();
+                let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
+                    excludes.iter().any(|matcher| matcher.is_match(&path))
+                });
+                if buffer.is_dirty() && !excluded {
+                    Some((buffer_handle, snapshot))
+                } else {
+                    None
+                }
+            })
+            .collect::<HashMap<_, _>>();
+
+        let embedding_provider = self.embedding_provider.clone();
+        let fs = self.fs.clone();
+        let db_path = self.db.path().clone();
+        let background = cx.background().clone();
+        cx.background().spawn(async move {
+            let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
+            let mut results = Vec::<SearchResult>::new();
+
+            let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
+            for (buffer, snapshot) in modified_buffers {
+                let language = snapshot
+                    .language_at(0)
+                    .cloned()
+                    .unwrap_or_else(|| language::PLAIN_TEXT.clone());
+                let mut spans = retriever
+                    .parse_file_with_template(None, &snapshot.text(), language)
+                    .log_err()
+                    .unwrap_or_default();
+                if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
+                    .await
+                    .log_err()
+                    .is_some()
+                {
+                    for span in spans {
+                        let similarity = span.embedding.unwrap().similarity(&query);
+                        let ix = match results
+                            .binary_search_by_key(&Reverse(similarity), |result| {
+                                Reverse(result.similarity)
+                            }) {
+                            Ok(ix) => ix,
+                            Err(ix) => ix,
+                        };
+
+                        let range = {
+                            let start = snapshot.clip_offset(span.range.start, Bias::Left);
+                            let end = snapshot.clip_offset(span.range.end, Bias::Right);
+                            snapshot.anchor_before(start)..snapshot.anchor_after(end)
+                        };
+
+                        results.insert(
+                            ix,
+                            SearchResult {
+                                buffer: buffer.clone(),
+                                range,
+                                similarity,
+                            },
+                        );
+                        results.truncate(limit);
+                    }
                 }
             }
 
-            Ok(database_results)
+            Ok(results)
         })
     }
 
@@ -1208,6 +1134,63 @@ impl SemanticIndex {
             Ok(())
         })
     }
+
+    async fn embed_spans(
+        spans: &mut [Span],
+        embedding_provider: &dyn EmbeddingProvider,
+        db: &VectorDatabase,
+    ) -> Result<()> {
+        let mut batch = Vec::new();
+        let mut batch_tokens = 0;
+        let mut embeddings = Vec::new();
+
+        let digests = spans
+            .iter()
+            .map(|span| span.digest.clone())
+            .collect::<Vec<_>>();
+        let embeddings_for_digests = db
+            .embeddings_for_digests(digests)
+            .await
+            .log_err()
+            .unwrap_or_default();
+
+        for span in &*spans {
+            if embeddings_for_digests.contains_key(&span.digest) {
+                continue;
+            };
+
+            if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
+                let batch_embeddings = embedding_provider
+                    .embed_batch(mem::take(&mut batch))
+                    .await?;
+                embeddings.extend(batch_embeddings);
+                batch_tokens = 0;
+            }
+
+            batch_tokens += span.token_count;
+            batch.push(span.content.clone());
+        }
+
+        if !batch.is_empty() {
+            let batch_embeddings = embedding_provider
+                .embed_batch(mem::take(&mut batch))
+                .await?;
+
+            embeddings.extend(batch_embeddings);
+        }
+
+        let mut embeddings = embeddings.into_iter();
+        for span in spans {
+            let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) {
+                Some(embedding.clone())
+            } else {
+                embeddings.next()
+            };
+            let embedding = embedding.ok_or_else(|| anyhow!("failed to embed spans"))?;
+            span.embedding = Some(embedding);
+        }
+        Ok(())
+    }
 }
 
 impl Entity for SemanticIndex {