Reify `Embedding`/`Sha1` structs that can be (de)serialized from SQL

Antonio Scandurra and Kyle Caverly created

Co-Authored-By: Kyle Caverly <kyle@zed.dev>

Change summary

crates/semantic_index/src/db.rs                   |  76 +----------
crates/semantic_index/src/embedding.rs            | 114 ++++++++++++++++
crates/semantic_index/src/embedding_queue.rs      |   2 
crates/semantic_index/src/parsing.rs              |  69 +++++++---
crates/semantic_index/src/semantic_index_tests.rs |  57 ++------
5 files changed, 180 insertions(+), 138 deletions(-)

Detailed changes

crates/semantic_index/src/db.rs 🔗

@@ -1,13 +1,10 @@
-use crate::{parsing::Document, SEMANTIC_INDEX_VERSION};
+use crate::{embedding::Embedding, parsing::Document, SEMANTIC_INDEX_VERSION};
 use anyhow::{anyhow, Context, Result};
 use futures::channel::oneshot;
 use gpui::executor;
 use project::{search::PathMatcher, Fs};
 use rpc::proto::Timestamp;
-use rusqlite::{
-    params,
-    types::{FromSql, FromSqlResult, ValueRef},
-};
+use rusqlite::params;
 use std::{
     cmp::Ordering,
     collections::HashMap,
@@ -27,34 +24,6 @@ pub struct FileRecord {
     pub mtime: Timestamp,
 }
 
-#[derive(Debug)]
-struct Embedding(pub Vec<f32>);
-
-#[derive(Debug)]
-struct Sha1(pub Vec<u8>);
-
-impl FromSql for Embedding {
-    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
-        let bytes = value.as_blob()?;
-        let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
-        if embedding.is_err() {
-            return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
-        }
-        return Ok(Embedding(embedding.unwrap()));
-    }
-}
-
-impl FromSql for Sha1 {
-    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
-        let bytes = value.as_blob()?;
-        let sha1: Result<Vec<u8>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
-        if sha1.is_err() {
-            return Err(rusqlite::types::FromSqlError::Other(sha1.unwrap_err()));
-        }
-        return Ok(Sha1(sha1.unwrap()));
-    }
-}
-
 #[derive(Clone)]
 pub struct VectorDatabase {
     path: Arc<Path>,
@@ -255,9 +224,6 @@ impl VectorDatabase {
             // Currently inserting at approximately 3400 documents a second
             // I imagine we can speed this up with a bulk insert of some kind.
             for document in documents {
-                let embedding_blob = bincode::serialize(&document.embedding)?;
-                let sha_blob = bincode::serialize(&document.sha1)?;
-
                 db.execute(
                     "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding, sha1) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
                     params![
@@ -265,8 +231,8 @@ impl VectorDatabase {
                         document.range.start.to_string(),
                         document.range.end.to_string(),
                         document.name,
-                        embedding_blob,
-                        sha_blob
+                        document.embedding,
+                        document.sha1
                     ],
                 )?;
            }
@@ -351,7 +317,7 @@ impl VectorDatabase {
 
     pub fn top_k_search(
         &self,
-        query_embedding: &Vec<f32>,
+        query_embedding: &Embedding,
         limit: usize,
         file_ids: &[i64],
     ) -> impl Future<Output = Result<Vec<(i64, f32)>>> {
@@ -360,7 +326,7 @@ impl VectorDatabase {
         self.transact(move |db| {
             let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
             Self::for_each_document(db, &file_ids, |id, embedding| {
-                let similarity = dot(&embedding, &query_embedding);
+                let similarity = embedding.similarity(&query_embedding);
                 let ix = match results.binary_search_by(|(_, s)| {
                     similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
                 }) {
@@ -417,7 +383,7 @@ impl VectorDatabase {
     fn for_each_document(
         db: &rusqlite::Connection,
         file_ids: &[i64],
-        mut f: impl FnMut(i64, Vec<f32>),
+        mut f: impl FnMut(i64, Embedding),
     ) -> Result<()> {
         let mut query_statement = db.prepare(
             "
@@ -435,7 +401,7 @@ impl VectorDatabase {
                 Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
             })?
             .filter_map(|row| row.ok())
-            .for_each(|(id, embedding)| f(id, embedding.0));
+            .for_each(|(id, embedding)| f(id, embedding));
         Ok(())
     }
 
@@ -497,29 +463,3 @@ fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
             .collect::<Vec<_>>(),
     )
 }
-
-pub(crate) fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
-    let len = vec_a.len();
-    assert_eq!(len, vec_b.len());
-
-    let mut result = 0.0;
-    unsafe {
-        matrixmultiply::sgemm(
-            1,
-            len,
-            1,
-            1.0,
-            vec_a.as_ptr(),
-            len as isize,
-            1,
-            vec_b.as_ptr(),
-            1,
-            len as isize,
-            0.0,
-            &mut result as *mut f32,
-            1,
-            1,
-        );
-    }
-    result
-}

crates/semantic_index/src/embedding.rs 🔗

@@ -8,6 +8,8 @@ use isahc::prelude::Configurable;
 use isahc::{AsyncBody, Response};
 use lazy_static::lazy_static;
 use parse_duration::parse;
+use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
+use rusqlite::ToSql;
 use serde::{Deserialize, Serialize};
 use std::env;
 use std::sync::Arc;
@@ -20,6 +22,62 @@ lazy_static! {
     static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
 }
 
+#[derive(Debug, PartialEq, Clone)]
+pub struct Embedding(Vec<f32>);
+
+impl From<Vec<f32>> for Embedding {
+    fn from(value: Vec<f32>) -> Self {
+        Embedding(value)
+    }
+}
+
+impl Embedding {
+    pub fn similarity(&self, other: &Self) -> f32 {
+        let len = self.0.len();
+        assert_eq!(len, other.0.len());
+
+        let mut result = 0.0;
+        unsafe {
+            matrixmultiply::sgemm(
+                1,
+                len,
+                1,
+                1.0,
+                self.0.as_ptr(),
+                len as isize,
+                1,
+                other.0.as_ptr(),
+                1,
+                len as isize,
+                0.0,
+                &mut result as *mut f32,
+                1,
+                1,
+            );
+        }
+        result
+    }
+}
+
+impl FromSql for Embedding {
+    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
+        let bytes = value.as_blob()?;
+        let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
+        if embedding.is_err() {
+            return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
+        }
+        Ok(Embedding(embedding.unwrap()))
+    }
+}
+
+impl ToSql for Embedding {
+    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
+        let bytes = bincode::serialize(&self.0)
+            .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
+        Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
+    }
+}
+
 #[derive(Clone)]
 pub struct OpenAIEmbeddings {
     pub client: Arc<dyn HttpClient>,
@@ -53,7 +111,7 @@ struct OpenAIEmbeddingUsage {
 
 #[async_trait]
 pub trait EmbeddingProvider: Sync + Send {
-    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>>;
+    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
     fn max_tokens_per_batch(&self) -> usize;
     fn truncate(&self, span: &str) -> (String, usize);
 }
@@ -62,10 +120,10 @@ pub struct DummyEmbeddings {}
 
 #[async_trait]
 impl EmbeddingProvider for DummyEmbeddings {
-    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>> {
+    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
         // 1024 is the OpenAI Embeddings size for ada models.
         // the model we will likely be starting with.
-        let dummy_vec = vec![0.32 as f32; 1536];
+        let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
         return Ok(vec![dummy_vec; spans.len()]);
     }
 
@@ -137,7 +195,7 @@ impl EmbeddingProvider for OpenAIEmbeddings {
         (output, token_count)
     }
 
-    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>> {
+    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
         const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
         const MAX_RETRIES: usize = 4;
 
@@ -175,7 +233,7 @@ impl EmbeddingProvider for OpenAIEmbeddings {
                     return Ok(response
                         .data
                         .into_iter()
-                        .map(|embedding| embedding.embedding)
+                        .map(|embedding| Embedding::from(embedding.embedding))
                         .collect());
                 }
                 StatusCode::TOO_MANY_REQUESTS => {
@@ -218,3 +276,49 @@ impl EmbeddingProvider for OpenAIEmbeddings {
         Err(anyhow!("openai max retries"))
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use rand::prelude::*;
+
+    #[gpui::test]
+    fn test_similarity(mut rng: StdRng) {
+        assert_eq!(
+            Embedding::from(vec![1., 0., 0., 0., 0.])
+                .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
+            0.
+        );
+        assert_eq!(
+            Embedding::from(vec![2., 0., 0., 0., 0.])
+                .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
+            6.
+        );
+
+        for _ in 0..100 {
+            let size = 1536;
+            let mut a = vec![0.; size];
+            let mut b = vec![0.; size];
+            for (a, b) in a.iter_mut().zip(b.iter_mut()) {
+                *a = rng.gen();
+                *b = rng.gen();
+            }
+            let a = Embedding::from(a);
+            let b = Embedding::from(b);
+
+            assert_eq!(
+                round_to_decimals(a.similarity(&b), 1),
+                round_to_decimals(reference_dot(&a.0, &b.0), 1)
+            );
+        }
+
+        fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
+            let factor = (10.0 as f32).powi(decimal_places);
+            (n * factor).round() / factor
+        }
+
+        fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
+            a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
+        }
+    }
+}

crates/semantic_index/src/embedding_queue.rs 🔗

@@ -121,7 +121,7 @@ impl EmbeddingQueue {
                             &mut fragment.file.lock().documents[fragment.document_range.clone()]
                         {
                             if let Some(embedding) = embeddings.next() {
-                                document.embedding = embedding;
+                                document.embedding = Some(embedding);
                             } else {
                                 //
                                 log::error!("number of embeddings returned different from number of documents");

crates/semantic_index/src/parsing.rs 🔗

@@ -1,7 +1,11 @@
-use crate::embedding::EmbeddingProvider;
-use anyhow::{anyhow, Ok, Result};
+use crate::embedding::{EmbeddingProvider, Embedding};
+use anyhow::{anyhow, Result};
 use language::{Grammar, Language};
-use sha1::{Digest, Sha1};
+use rusqlite::{
+    types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
+    ToSql,
+};
+use sha1::Digest;
 use std::{
     cmp::{self, Reverse},
     collections::HashSet,
@@ -11,13 +15,43 @@ use std::{
 };
 use tree_sitter::{Parser, QueryCursor};
 
+#[derive(Debug, PartialEq, Clone)]
+pub struct Sha1([u8; 20]);
+
+impl FromSql for Sha1 {
+    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
+        let blob = value.as_blob()?;
+        let bytes =
+            blob.try_into()
+                .map_err(|_| rusqlite::types::FromSqlError::InvalidBlobSize {
+                    expected_size: 20,
+                    blob_size: blob.len(),
+                })?;
+        return Ok(Sha1(bytes));
+    }
+}
+
+impl ToSql for Sha1 {
+    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
+        self.0.to_sql()
+    }
+}
+
+impl From<&'_ str> for Sha1 {
+    fn from(value: &'_ str) -> Self {
+        let mut sha1 = sha1::Sha1::new();
+        sha1.update(value);
+        Self(sha1.finalize().into())
+    }
+}
+
 #[derive(Debug, PartialEq, Clone)]
 pub struct Document {
     pub name: String,
     pub range: Range<usize>,
     pub content: String,
-    pub embedding: Vec<f32>,
-    pub sha1: [u8; 20],
+    pub embedding: Option<Embedding>,
+    pub sha1: Sha1,
     pub token_count: usize,
 }
 
@@ -69,17 +103,16 @@ impl CodeContextRetriever {
             .replace("<language>", language_name.as_ref())
             .replace("<item>", &content);
 
-        let mut sha1 = Sha1::new();
-        sha1.update(&document_span);
+        let sha1 = Sha1::from(document_span.as_str());
 
         let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
 
         Ok(vec![Document {
             range: 0..content.len(),
             content: document_span,
-            embedding: Vec::new(),
+            embedding: Default::default(),
             name: language_name.to_string(),
-            sha1: sha1.finalize().into(),
+            sha1,
             token_count,
         }])
     }
@@ -88,18 +121,14 @@ impl CodeContextRetriever {
         let document_span = MARKDOWN_CONTEXT_TEMPLATE
             .replace("<path>", relative_path.to_string_lossy().as_ref())
             .replace("<item>", &content);
-
-        let mut sha1 = Sha1::new();
-        sha1.update(&document_span);
-
+        let sha1 = Sha1::from(document_span.as_str());
         let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
-
         Ok(vec![Document {
             range: 0..content.len(),
             content: document_span,
-            embedding: Vec::new(),
+            embedding: None,
             name: "Markdown".to_string(),
-            sha1: sha1.finalize().into(),
+            sha1,
             token_count,
         }])
     }
@@ -279,15 +308,13 @@ impl CodeContextRetriever {
                 );
             }
 
-            let mut sha1 = Sha1::new();
-            sha1.update(&document_content);
-
+            let sha1 = Sha1::from(document_content.as_str());
             documents.push(Document {
                 name,
                 content: document_content,
                 range: item_range.clone(),
-                embedding: vec![],
-                sha1: sha1.finalize().into(),
+                embedding: None,
+                sha1,
                 token_count: 0,
             })
         }

crates/semantic_index/src/semantic_index_tests.rs 🔗

@@ -1,8 +1,7 @@
 use crate::{
-    db::dot,
-    embedding::{DummyEmbeddings, EmbeddingProvider},
+    embedding::{DummyEmbeddings, Embedding, EmbeddingProvider},
     embedding_queue::EmbeddingQueue,
-    parsing::{subtract_ranges, CodeContextRetriever, Document},
+    parsing::{subtract_ranges, CodeContextRetriever, Document, Sha1},
     semantic_index_settings::SemanticIndexSettings,
     FileToEmbed, JobHandle, SearchResult, SemanticIndex,
 };
@@ -217,15 +216,17 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
             documents: (0..rng.gen_range(4..22))
                 .map(|document_ix| {
                     let content_len = rng.gen_range(10..100);
+                    let content = RandomCharIter::new(&mut rng)
+                        .with_simple_text()
+                        .take(content_len)
+                        .collect::<String>();
+                    let sha1 = Sha1::from(content.as_str());
                     Document {
                         range: 0..10,
-                        embedding: Vec::new(),
+                        embedding: None,
                         name: format!("document {document_ix}"),
-                        content: RandomCharIter::new(&mut rng)
-                            .with_simple_text()
-                            .take(content_len)
-                            .collect(),
-                        sha1: rng.gen(),
+                        content,
+                        sha1,
                         token_count: rng.gen_range(10..30),
                     }
                 })
@@ -254,7 +255,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
         .map(|file| {
             let mut file = file.clone();
             for doc in &mut file.documents {
-                doc.embedding = embedding_provider.embed_sync(doc.content.as_ref());
+                doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref()));
             }
             file
         })
@@ -1242,36 +1243,6 @@ async fn test_code_context_retrieval_php() {
     );
 }
 
-#[gpui::test]
-fn test_dot_product(mut rng: StdRng) {
-    assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
-    assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
-
-    for _ in 0..100 {
-        let size = 1536;
-        let mut a = vec![0.; size];
-        let mut b = vec![0.; size];
-        for (a, b) in a.iter_mut().zip(b.iter_mut()) {
-            *a = rng.gen();
-            *b = rng.gen();
-        }
-
-        assert_eq!(
-            round_to_decimals(dot(&a, &b), 1),
-            round_to_decimals(reference_dot(&a, &b), 1)
-        );
-    }
-
-    fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
-        let factor = (10.0 as f32).powi(decimal_places);
-        (n * factor).round() / factor
-    }
-
-    fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
-        a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
-    }
-}
-
 #[derive(Default)]
 struct FakeEmbeddingProvider {
     embedding_count: AtomicUsize,
@@ -1282,7 +1253,7 @@ impl FakeEmbeddingProvider {
         self.embedding_count.load(atomic::Ordering::SeqCst)
     }
 
-    fn embed_sync(&self, span: &str) -> Vec<f32> {
+    fn embed_sync(&self, span: &str) -> Embedding {
         let mut result = vec![1.0; 26];
         for letter in span.chars() {
             let letter = letter.to_ascii_lowercase();
@@ -1299,7 +1270,7 @@ impl FakeEmbeddingProvider {
             *x /= norm;
         }
 
-        result
+        result.into()
     }
 }
 
@@ -1313,7 +1284,7 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
         200
     }
 
-    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Vec<f32>>> {
+    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
         self.embedding_count
             .fetch_add(spans.len(), atomic::Ordering::SeqCst);
         Ok(spans.iter().map(|span| self.embed_sync(span)).collect())