add initial search inside modified buffers

KCaverly created

Change summary

crates/semantic_index/src/db.rs             |  33 +++
crates/semantic_index/src/parsing.rs        |   2 
crates/semantic_index/src/semantic_index.rs | 248 ++++++++++++++++------
3 files changed, 217 insertions(+), 66 deletions(-)

Detailed changes

crates/semantic_index/src/db.rs 🔗

@@ -278,6 +278,39 @@ impl VectorDatabase {
         })
     }
 
+    pub fn embeddings_for_digests(
+        &self,
+        digests: Vec<SpanDigest>,
+    ) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
+        self.transact(move |db| {
+            let mut query = db.prepare(
+                "
+                SELECT digest, embedding
+                FROM spans
+                WHERE digest IN rarray(?)
+                ",
+            )?;
+            let mut embeddings_by_digest = HashMap::default();
+            let digests = Rc::new(
+                digests
+                    .into_iter()
+                    .map(|p| Value::Blob(p.0.to_vec()))
+                    .collect::<Vec<_>>(),
+            );
+            let rows = query.query_map(params![digests], |row| {
+                Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
+            })?;
+
+            for row in rows {
+                if let Ok(row) = row {
+                    embeddings_by_digest.insert(row.0, row.1);
+                }
+            }
+
+            Ok(embeddings_by_digest)
+        })
+    }
+
     pub fn embeddings_for_files(
         &self,
         worktree_id_file_paths: HashMap<i64, Vec<Arc<Path>>>,

crates/semantic_index/src/parsing.rs 🔗

@@ -17,7 +17,7 @@ use std::{
 use tree_sitter::{Parser, QueryCursor};
 
 #[derive(Debug, PartialEq, Eq, Clone, Hash)]
-pub struct SpanDigest([u8; 20]);
+pub struct SpanDigest(pub [u8; 20]);
 
 impl FromSql for SpanDigest {
     fn column_result(value: ValueRef) -> FromSqlResult<Self> {

crates/semantic_index/src/semantic_index.rs 🔗

@@ -263,9 +263,11 @@ pub struct PendingFile {
     job_handle: JobHandle,
 }
 
+#[derive(Clone)]
 pub struct SearchResult {
     pub buffer: ModelHandle<Buffer>,
     pub range: Range<Anchor>,
+    pub similarity: f32,
 }
 
 impl SemanticIndex {
@@ -775,7 +777,8 @@ impl SemanticIndex {
                     .filter_map(|buffer_handle| {
                         let buffer = buffer_handle.read(cx);
                         if buffer.is_dirty() {
-                            Some((buffer_handle.downgrade(), buffer.snapshot()))
+                            // TOOD: @as-cii I removed the downgrade for now to fix the compiler - @kcaverly
+                            Some((buffer_handle, buffer.snapshot()))
                         } else {
                             None
                         }
@@ -783,77 +786,133 @@ impl SemanticIndex {
                     .collect::<HashMap<_, _>>()
             });
 
-            cx.background()
-                .spawn({
-                    let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
-                    let embedding_provider = embedding_provider.clone();
-                    let phrase_embedding = phrase_embedding.clone();
-                    async move {
-                        let mut results = Vec::new();
-                        'buffers: for (buffer_handle, buffer_snapshot) in dirty_buffers {
-                            let language = buffer_snapshot
-                                .language_at(0)
-                                .cloned()
-                                .unwrap_or_else(|| language::PLAIN_TEXT.clone());
-                            if let Some(spans) = retriever
-                                .parse_file_with_template(None, &buffer_snapshot.text(), language)
-                                .log_err()
-                            {
-                                let mut batch = Vec::new();
-                                let mut batch_tokens = 0;
-                                let mut embeddings = Vec::new();
-
-                                // TODO: query span digests in the database to avoid embedding them again.
+            let buffer_results = if let Some(db) =
+                VectorDatabase::new(fs, db_path.clone(), cx.background())
+                    .await
+                    .log_err()
+            {
+                cx.background()
+                    .spawn({
+                        let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
+                        let embedding_provider = embedding_provider.clone();
+                        let phrase_embedding = phrase_embedding.clone();
+                        async move {
+                            let mut results = Vec::<SearchResult>::new();
+                            'buffers: for (buffer_handle, buffer_snapshot) in dirty_buffers {
+                                let language = buffer_snapshot
+                                    .language_at(0)
+                                    .cloned()
+                                    .unwrap_or_else(|| language::PLAIN_TEXT.clone());
+                                if let Some(spans) = retriever
+                                    .parse_file_with_template(
+                                        None,
+                                        &buffer_snapshot.text(),
+                                        language,
+                                    )
+                                    .log_err()
+                                {
+                                    let mut batch = Vec::new();
+                                    let mut batch_tokens = 0;
+                                    let mut embeddings = Vec::new();
+
+                                    let digests = spans
+                                        .iter()
+                                        .map(|span| span.digest.clone())
+                                        .collect::<Vec<_>>();
+                                    let embeddings_for_digests = db
+                                        .embeddings_for_digests(digests)
+                                        .await
+                                        .map_or(Default::default(), |m| m);
+
+                                    for span in &spans {
+                                        if embeddings_for_digests.contains_key(&span.digest) {
+                                            continue;
+                                        };
+
+                                        if batch_tokens + span.token_count
+                                            > embedding_provider.max_tokens_per_batch()
+                                        {
+                                            if let Some(batch_embeddings) = embedding_provider
+                                                .embed_batch(mem::take(&mut batch))
+                                                .await
+                                                .log_err()
+                                            {
+                                                embeddings.extend(batch_embeddings);
+                                                batch_tokens = 0;
+                                            } else {
+                                                continue 'buffers;
+                                            }
+                                        }
 
-                                for span in &spans {
-                                    if span.embedding.is_some() {
-                                        continue;
+                                        batch_tokens += span.token_count;
+                                        batch.push(span.content.clone());
                                     }
 
-                                    if batch_tokens + span.token_count
-                                        > embedding_provider.max_tokens_per_batch()
+                                    if let Some(batch_embeddings) = embedding_provider
+                                        .embed_batch(mem::take(&mut batch))
+                                        .await
+                                        .log_err()
                                     {
-                                        if let Some(batch_embeddings) = embedding_provider
-                                            .embed_batch(mem::take(&mut batch))
-                                            .await
-                                            .log_err()
+                                        embeddings.extend(batch_embeddings);
+                                    } else {
+                                        continue 'buffers;
+                                    }
+
+                                    let mut embeddings = embeddings.into_iter();
+                                    for span in spans {
+                                        let embedding = if let Some(embedding) =
+                                            embeddings_for_digests.get(&span.digest)
                                         {
-                                            embeddings.extend(batch_embeddings);
-                                            batch_tokens = 0;
+                                            Some(embedding.clone())
                                         } else {
+                                            embeddings.next()
+                                        };
+
+                                        if let Some(embedding) = embedding {
+                                            let similarity =
+                                                embedding.similarity(&phrase_embedding);
+
+                                            let ix = match results.binary_search_by(|s| {
+                                                similarity
+                                                    .partial_cmp(&s.similarity)
+                                                    .unwrap_or(Ordering::Equal)
+                                            }) {
+                                                Ok(ix) => ix,
+                                                Err(ix) => ix,
+                                            };
+
+                                            let range = {
+                                                let start = buffer_snapshot
+                                                    .clip_offset(span.range.start, Bias::Left);
+                                                let end = buffer_snapshot
+                                                    .clip_offset(span.range.end, Bias::Right);
+                                                buffer_snapshot.anchor_before(start)
+                                                    ..buffer_snapshot.anchor_after(end)
+                                            };
+
+                                            results.insert(
+                                                ix,
+                                                SearchResult {
+                                                    buffer: buffer_handle.clone(),
+                                                    range,
+                                                    similarity,
+                                                },
+                                            );
+                                            results.truncate(limit);
+                                        } else {
+                                            log::error!("failed to embed span");
                                             continue 'buffers;
                                         }
                                     }
-
-                                    batch_tokens += span.token_count;
-                                    batch.push(span.content.clone());
-                                }
-
-                                if let Some(batch_embeddings) = embedding_provider
-                                    .embed_batch(mem::take(&mut batch))
-                                    .await
-                                    .log_err()
-                                {
-                                    embeddings.extend(batch_embeddings);
-                                } else {
-                                    continue 'buffers;
-                                }
-
-                                let mut embeddings = embeddings.into_iter();
-                                for span in spans {
-                                    let embedding = span.embedding.or_else(|| embeddings.next());
-                                    if let Some(embedding) = embedding {
-                                        todo!()
-                                    } else {
-                                        log::error!("failed to embed span");
-                                        continue 'buffers;
-                                    }
                                 }
                             }
+                            anyhow::Ok(results)
                         }
-                    }
-                })
-                .await;
+                    })
+                    .await
+            } else {
+                Ok(Vec::new())
+            };
 
             let batch_results = futures::future::join_all(batch_results).await;
 
@@ -873,7 +932,11 @@ impl SemanticIndex {
                 }
             }
 
-            let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<i64>>();
+            let ids = results.iter().map(|(id, _)| *id).collect::<Vec<i64>>();
+            let scores = results
+                .into_iter()
+                .map(|(_, score)| score)
+                .collect::<Vec<f32>>();
             let spans = database.spans_for_ids(ids.as_slice()).await?;
 
             let mut tasks = Vec::new();
@@ -903,19 +966,74 @@ impl SemanticIndex {
                 t0.elapsed().as_millis()
             );
 
-            Ok(buffers
+            let database_results = buffers
                 .into_iter()
                 .zip(ranges)
-                .filter_map(|(buffer, range)| {
+                .zip(scores)
+                .filter_map(|((buffer, range), similarity)| {
                     let buffer = buffer.log_err()?;
                     let range = buffer.read_with(&cx, |buffer, _| {
                         let start = buffer.clip_offset(range.start, Bias::Left);
                         let end = buffer.clip_offset(range.end, Bias::Right);
                         buffer.anchor_before(start)..buffer.anchor_after(end)
                     });
-                    Some(SearchResult { buffer, range })
+                    Some(SearchResult {
+                        buffer,
+                        range,
+                        similarity,
+                    })
                 })
-                .collect::<Vec<_>>())
+                .collect::<Vec<_>>();
+
+            // Stitch Together Database Results & Buffer Results
+            if let Ok(buffer_results) = buffer_results {
+                let mut buffer_map = HashMap::default();
+                for buffer_result in buffer_results {
+                    buffer_map
+                        .entry(buffer_result.clone().buffer)
+                        .or_insert(Vec::new())
+                        .push(buffer_result);
+                }
+
+                for db_result in database_results {
+                    if !buffer_map.contains_key(&db_result.buffer) {
+                        buffer_map
+                            .entry(db_result.clone().buffer)
+                            .or_insert(Vec::new())
+                            .push(db_result);
+                    }
+                }
+
+                let mut full_results = Vec::<SearchResult>::new();
+
+                for (_, results) in buffer_map {
+                    for res in results.into_iter() {
+                        let ix = match full_results.binary_search_by(|search_result| {
+                            res.similarity
+                                .partial_cmp(&search_result.similarity)
+                                .unwrap_or(Ordering::Equal)
+                        }) {
+                            Ok(ix) => ix,
+                            Err(ix) => ix,
+                        };
+                        full_results.insert(ix, res);
+                        full_results.truncate(limit);
+                    }
+                }
+
+                return Ok(full_results);
+            } else {
+                return Ok(database_results);
+            }
+
+            // let ix = match results.binary_search_by(|(_, s)| {
+            //     similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
+            // }) {
+            //     Ok(ix) => ix,
+            //     Err(ix) => ix,
+            // };
+            // results.insert(ix, (id, similarity));
+            // results.truncate(limit);
         })
     }