WIP: Got the streaming matrix multiplication working, and started work on file hashing.

Created by KCaverly and maxbrunsfeld

Co-authored-by: maxbrunsfeld <max@zed.dev>

Change summary

Cargo.lock                                    |   5 
crates/vector_store/Cargo.toml                |   5 
crates/vector_store/src/db.rs                 |  84 +++++-
crates/vector_store/src/embedding.rs          |   2 
crates/vector_store/src/search.rs             |  18 -
crates/vector_store/src/vector_store.rs       | 247 +++++++++++++++-----
crates/vector_store/src/vector_store_tests.rs | 136 +++++++++++
7 files changed, 398 insertions(+), 99 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -7958,13 +7958,18 @@ dependencies = [
  "language",
  "lazy_static",
  "log",
+ "matrixmultiply",
  "ndarray",
  "project",
+ "rand 0.8.5",
  "rusqlite",
  "serde",
  "serde_json",
+ "sha-1 0.10.1",
  "smol",
  "tree-sitter",
+ "tree-sitter-rust",
+ "unindent",
  "util",
  "workspace",
 ]

crates/vector_store/Cargo.toml 🔗

@@ -27,9 +27,14 @@ serde_json.workspace = true
 async-trait.workspace = true
 bincode = "1.3.3"
 ndarray = "0.15.6"
+sha-1 = "0.10.1"
+matrixmultiply = "0.3.7"
 
 [dev-dependencies]
 gpui = { path = "../gpui", features = ["test-support"] }
 language = { path = "../language", features = ["test-support"] }
 project = { path = "../project", features = ["test-support"] }
 workspace = { path = "../workspace", features = ["test-support"] }
+tree-sitter-rust = "*"
+rand.workspace = true
+unindent.workspace = true

crates/vector_store/src/db.rs 🔗

@@ -1,4 +1,7 @@
-use std::{collections::HashMap, path::PathBuf};
+use std::{
+    collections::HashMap,
+    path::{Path, PathBuf},
+};
 
 use anyhow::{anyhow, Result};
 
@@ -13,7 +16,7 @@ use crate::IndexedFile;
 // This is saving to a local database store within the users dev zed path
 // Where do we want this to sit?
 // Assuming near where the workspace DB sits.
-const VECTOR_DB_URL: &str = "embeddings_db";
+pub const VECTOR_DB_URL: &str = "embeddings_db";
 
 // Note this is not an appropriate document
 #[derive(Debug)]
@@ -28,7 +31,7 @@ pub struct DocumentRecord {
 #[derive(Debug)]
 pub struct FileRecord {
     pub id: usize,
-    pub path: String,
+    pub relative_path: String,
     pub sha1: String,
 }
 
@@ -51,9 +54,9 @@ pub struct VectorDatabase {
 }
 
 impl VectorDatabase {
-    pub fn new() -> Result<Self> {
+    pub fn new(path: &str) -> Result<Self> {
         let this = Self {
-            db: rusqlite::Connection::open(VECTOR_DB_URL)?,
+            db: rusqlite::Connection::open(path)?,
         };
         this.initialize_database()?;
         Ok(this)
@@ -63,21 +66,23 @@ impl VectorDatabase {
         // This will create the database if it doesnt exist
 
         // Initialize Vector Databasing Tables
-        // self.db.execute(
-        //     "
-        //     CREATE TABLE IF NOT EXISTS projects (
-        //         id INTEGER PRIMARY KEY AUTOINCREMENT,
-        //         path NVARCHAR(100) NOT NULL
-        //     )
-        //     ",
-        //     [],
-        // )?;
+        self.db.execute(
+            "CREATE TABLE IF NOT EXISTS worktrees (
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                absolute_path VARCHAR NOT NULL
+            );
+            CREATE UNIQUE INDEX IF NOT EXISTS worktrees_absolute_path ON worktrees (absolute_path);
+            ",
+            [],
+        )?;
 
         self.db.execute(
             "CREATE TABLE IF NOT EXISTS files (
                 id INTEGER PRIMARY KEY AUTOINCREMENT,
-                path NVARCHAR(100) NOT NULL,
-                sha1 NVARCHAR(40) NOT NULL
+                worktree_id INTEGER NOT NULL,
+                relative_path VARCHAR NOT NULL,
+                sha1 NVARCHAR(40) NOT NULL,
+                FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
             )",
             [],
         )?;
@@ -87,7 +92,7 @@ impl VectorDatabase {
                 id INTEGER PRIMARY KEY AUTOINCREMENT,
                 file_id INTEGER NOT NULL,
                 offset INTEGER NOT NULL,
-                name NVARCHAR(100) NOT NULL,
+                name VARCHAR NOT NULL,
                 embedding BLOB NOT NULL,
                 FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
             )",
@@ -116,7 +121,7 @@ impl VectorDatabase {
     pub fn insert_file(&self, indexed_file: IndexedFile) -> Result<()> {
         // Write to files table, and return generated id.
         let files_insert = self.db.execute(
-            "INSERT INTO files (path, sha1) VALUES (?1, ?2)",
+            "INSERT INTO files (relative_path, sha1) VALUES (?1, ?2)",
             params![indexed_file.path.to_str(), indexed_file.sha1],
         )?;
 
@@ -141,12 +146,38 @@ impl VectorDatabase {
         Ok(())
     }
 
+    pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
+        self.db.execute(
+            "
+            INSERT into worktrees (absolute_path) VALUES (?1)
+            ON CONFLICT DO NOTHING
+            ",
+            params![worktree_root_path.to_string_lossy()],
+        )?;
+        Ok(self.db.last_insert_rowid())
+    }
+
+    pub fn get_file_hashes(&self, worktree_id: i64) -> Result<Vec<(PathBuf, String)>> {
+        let mut statement = self
+            .db
+            .prepare("SELECT relative_path, sha1 FROM files ORDER BY relative_path")?;
+        let mut result = Vec::new();
+        for row in
+            statement.query_map([], |row| Ok((row.get::<_, String>(0)?.into(), row.get(1)?)))?
+        {
+            result.push(row?);
+        }
+        Ok(result)
+    }
+
     pub fn get_files(&self) -> Result<HashMap<usize, FileRecord>> {
-        let mut query_statement = self.db.prepare("SELECT id, path, sha1 FROM files")?;
+        let mut query_statement = self
+            .db
+            .prepare("SELECT id, relative_path, sha1 FROM files")?;
         let result_iter = query_statement.query_map([], |row| {
             Ok(FileRecord {
                 id: row.get(0)?,
-                path: row.get(1)?,
+                relative_path: row.get(1)?,
                 sha1: row.get(2)?,
             })
         })?;
@@ -160,6 +191,19 @@ impl VectorDatabase {
         Ok(pages)
     }
 
+    pub fn for_each_document(
+        &self,
+        worktree_id: i64,
+        mut f: impl FnMut(i64, Embedding),
+    ) -> Result<()> {
+        let mut query_statement = self.db.prepare("SELECT id, embedding FROM documents")?;
+        query_statement
+            .query_map(params![], |row| Ok((row.get(0)?, row.get(1)?)))?
+            .filter_map(|row| row.ok())
+            .for_each(|row| f(row.0, row.1));
+        Ok(())
+    }
+
     pub fn get_documents(&self) -> Result<HashMap<usize, DocumentRecord>> {
         let mut query_statement = self
             .db

crates/vector_store/src/embedding.rs 🔗

@@ -44,7 +44,7 @@ struct OpenAIEmbeddingUsage {
 }
 
 #[async_trait]
-pub trait EmbeddingProvider: Sync {
+pub trait EmbeddingProvider: Sync + Send {
     async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
 }
 

crates/vector_store/src/search.rs 🔗

@@ -1,4 +1,4 @@
-use std::cmp::Ordering;
+use std::{cmp::Ordering, path::PathBuf};
 
 use async_trait::async_trait;
 use ndarray::{Array1, Array2};
@@ -20,7 +20,6 @@ pub struct BruteForceSearch {
 
 impl BruteForceSearch {
     pub fn load(db: &VectorDatabase) -> Result<Self> {
-        // let db = VectorDatabase {};
         let documents = db.get_documents()?;
         let embeddings: Vec<&DocumentRecord> = documents.values().into_iter().collect();
         let mut document_ids = vec![];
@@ -63,20 +62,5 @@ impl VectorSearch for BruteForceSearch {
         with_indices.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
         with_indices.truncate(limit);
         with_indices
-
-        // // extract the sorted indices from the sorted tuple vector
-        // let stored_indices = with_indices
-        //     .into_iter()
-        //     .map(|(index, value)| index)
-        //     .collect::<Vec<>>();
-
-        // let sorted_indices: Vec<usize> = stored_indices.into_iter().rev().collect();
-
-        // let mut results = vec![];
-        // for idx in sorted_indices[0..limit].to_vec() {
-        //     results.push((self.document_ids[idx], 1.0 - similarities[idx]));
-        // }
-
-        // return results;
     }
 }

crates/vector_store/src/vector_store.rs 🔗

@@ -3,16 +3,19 @@ mod embedding;
 mod parsing;
 mod search;
 
+#[cfg(test)]
+mod vector_store_tests;
+
 use anyhow::{anyhow, Result};
-use db::VectorDatabase;
+use db::{VectorDatabase, VECTOR_DB_URL};
 use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings};
-use gpui::{AppContext, Entity, ModelContext, ModelHandle};
+use gpui::{AppContext, Entity, ModelContext, ModelHandle, Task};
 use language::LanguageRegistry;
 use parsing::Document;
 use project::{Fs, Project};
 use search::{BruteForceSearch, VectorSearch};
 use smol::channel;
-use std::{path::PathBuf, sync::Arc, time::Instant};
+use std::{cmp::Ordering, path::PathBuf, sync::Arc, time::Instant};
 use tree_sitter::{Parser, QueryCursor};
 use util::{http::HttpClient, ResultExt, TryFutureExt};
 use workspace::WorkspaceCreated;
@@ -23,7 +26,16 @@ pub fn init(
     language_registry: Arc<LanguageRegistry>,
     cx: &mut AppContext,
 ) {
-    let vector_store = cx.add_model(|cx| VectorStore::new(fs, http_client, language_registry));
+    let vector_store = cx.add_model(|cx| {
+        VectorStore::new(
+            fs,
+            VECTOR_DB_URL.to_string(),
+            Arc::new(OpenAIEmbeddings {
+                client: http_client,
+            }),
+            language_registry,
+        )
+    });
 
     cx.subscribe_global::<WorkspaceCreated, _>({
         let vector_store = vector_store.clone();
@@ -49,28 +61,36 @@ pub struct IndexedFile {
     documents: Vec<Document>,
 }
 
-struct SearchResult {
-    path: PathBuf,
-    offset: usize,
-    name: String,
-    distance: f32,
-}
-
+// struct SearchResult {
+//     path: PathBuf,
+//     offset: usize,
+//     name: String,
+//     distance: f32,
+// }
 struct VectorStore {
     fs: Arc<dyn Fs>,
-    http_client: Arc<dyn HttpClient>,
+    database_url: Arc<str>,
+    embedding_provider: Arc<dyn EmbeddingProvider>,
     language_registry: Arc<LanguageRegistry>,
 }
 
+pub struct SearchResult {
+    pub name: String,
+    pub offset: usize,
+    pub file_path: PathBuf,
+}
+
 impl VectorStore {
     fn new(
         fs: Arc<dyn Fs>,
-        http_client: Arc<dyn HttpClient>,
+        database_url: String,
+        embedding_provider: Arc<dyn EmbeddingProvider>,
         language_registry: Arc<LanguageRegistry>,
     ) -> Self {
         Self {
             fs,
-            http_client,
+            database_url: database_url.into(),
+            embedding_provider,
             language_registry,
         }
     }
@@ -79,10 +99,12 @@ impl VectorStore {
         cursor: &mut QueryCursor,
         parser: &mut Parser,
         embedding_provider: &dyn EmbeddingProvider,
-        fs: &Arc<dyn Fs>,
         language_registry: &Arc<LanguageRegistry>,
         file_path: PathBuf,
+        content: String,
     ) -> Result<IndexedFile> {
+        dbg!(&file_path, &content);
+
         let language = language_registry
             .language_for_file(&file_path, None)
             .await?;
@@ -97,7 +119,6 @@ impl VectorStore {
             .as_ref()
             .ok_or_else(|| anyhow!("no outline query"))?;
 
-        let content = fs.load(&file_path).await?;
         parser.set_language(grammar.ts_language).unwrap();
         let tree = parser
             .parse(&content, None)
@@ -142,7 +163,11 @@ impl VectorStore {
         });
     }
 
-    fn add_project(&mut self, project: ModelHandle<Project>, cx: &mut ModelContext<Self>) {
+    fn add_project(
+        &mut self,
+        project: ModelHandle<Project>,
+        cx: &mut ModelContext<Self>,
+    ) -> Task<Result<()>> {
         let worktree_scans_complete = project
             .read(cx)
             .worktrees(cx)
@@ -151,7 +176,8 @@ impl VectorStore {
 
         let fs = self.fs.clone();
         let language_registry = self.language_registry.clone();
-        let client = self.http_client.clone();
+        let embedding_provider = self.embedding_provider.clone();
+        let database_url = self.database_url.clone();
 
         cx.spawn(|_, cx| async move {
             futures::future::join_all(worktree_scans_complete).await;
@@ -163,24 +189,47 @@ impl VectorStore {
                     .collect::<Vec<_>>()
             });
 
-            let (paths_tx, paths_rx) = channel::unbounded::<PathBuf>();
-            let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<IndexedFile>();
-            cx.background()
+            let db = VectorDatabase::new(&database_url)?;
+            let worktree_root_paths = worktrees
+                .iter()
+                .map(|worktree| worktree.abs_path().clone())
+                .collect::<Vec<_>>();
+            let (db, file_hashes) = cx
+                .background()
                 .spawn(async move {
-                    for worktree in worktrees {
-                        for file in worktree.files(false, 0) {
-                            paths_tx.try_send(worktree.absolutize(&file.path)).unwrap();
-                        }
+                    let mut hashes = Vec::new();
+                    for worktree_root_path in worktree_root_paths {
+                        let worktree_id =
+                            db.find_or_create_worktree(worktree_root_path.as_ref())?;
+                        hashes.push((worktree_id, db.get_file_hashes(worktree_id)?));
                     }
+                    anyhow::Ok((db, hashes))
                 })
-                .detach();
+                .await?;
 
+            let (paths_tx, paths_rx) = channel::unbounded::<(i64, PathBuf, String)>();
+            let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<IndexedFile>();
             cx.background()
                 .spawn({
-                    let client = client.clone();
+                    let fs = fs.clone();
                     async move {
+                        for worktree in worktrees.into_iter() {
+                            for file in worktree.files(false, 0) {
+                                let absolute_path = worktree.absolutize(&file.path);
+                                dbg!(&absolute_path);
+                                if let Some(content) = fs.load(&absolute_path).await.log_err() {
+                                    dbg!(&content);
+                                    paths_tx.try_send((0, absolute_path, content)).unwrap();
+                                }
+                            }
+                        }
+                    }
+                })
+                .detach();
+
+            let db_write_task = cx.background().spawn(
+                async move {
                     // Initialize Database, creates database and tables if not exists
-                    let db = VectorDatabase::new()?;
                     while let Ok(indexed_file) = indexed_files_rx.recv().await {
                         db.insert_file(indexed_file).log_err();
                     }
@@ -188,39 +237,39 @@ impl VectorStore {
                     // ALL OF THE BELOW IS FOR TESTING,
                     // This should be removed as we find and appropriate place for evaluate our search.
 
-                    let embedding_provider = OpenAIEmbeddings{ client };
-                    let queries = vec![
-                        "compute embeddings for all of the symbols in the codebase, and write them to a database",
-                            "compute an outline view of all of the symbols in a buffer",
-                            "scan a directory on the file system and load all of its children into an in-memory snapshot",
-                    ];
-                    let embeddings = embedding_provider.embed_batch(queries.clone()).await?;
-
-                    let t2 = Instant::now();
-                    let documents = db.get_documents().unwrap();
-                    let files = db.get_files().unwrap();
-                    println!("Retrieving all documents from Database: {}", t2.elapsed().as_millis());
-
-                    let t1 = Instant::now();
-                    let mut bfs = BruteForceSearch::load(&db).unwrap();
-                    println!("Loading BFS to Memory: {:?}", t1.elapsed().as_millis());
-                    for (idx, embed) in embeddings.into_iter().enumerate() {
-                        let t0 = Instant::now();
-                        println!("\nQuery: {:?}", queries[idx]);
-                        let results = bfs.top_k_search(&embed, 5).await;
-                        println!("Search Elapsed: {}", t0.elapsed().as_millis());
-                        for (id, distance) in results {
-                            println!("");
-                            println!("   distance: {:?}", distance);
-                            println!("   document: {:?}", documents[&id].name);
-                            println!("   path:     {:?}", files[&documents[&id].file_id].path);
-                        }
+                    // let queries = vec![
+                    //     "compute embeddings for all of the symbols in the codebase, and write them to a database",
+                    //         "compute an outline view of all of the symbols in a buffer",
+                    //         "scan a directory on the file system and load all of its children into an in-memory snapshot",
+                    // ];
+                    // let embeddings = embedding_provider.embed_batch(queries.clone()).await?;
 
-                    }
+                    // let t2 = Instant::now();
+                    // let documents = db.get_documents().unwrap();
+                    // let files = db.get_files().unwrap();
+                    // println!("Retrieving all documents from Database: {}", t2.elapsed().as_millis());
+
+                    // let t1 = Instant::now();
+                    // let mut bfs = BruteForceSearch::load(&db).unwrap();
+                    // println!("Loading BFS to Memory: {:?}", t1.elapsed().as_millis());
+                    // for (idx, embed) in embeddings.into_iter().enumerate() {
+                    //     let t0 = Instant::now();
+                    //     println!("\nQuery: {:?}", queries[idx]);
+                    //     let results = bfs.top_k_search(&embed, 5).await;
+                    //     println!("Search Elapsed: {}", t0.elapsed().as_millis());
+                    //     for (id, distance) in results {
+                    //         println!("");
+                    //         println!("   distance: {:?}", distance);
+                    //         println!("   document: {:?}", documents[&id].name);
+                    //         println!("   path:     {:?}", files[&documents[&id].file_id].relative_path);
+                    //     }
+
+                    // }
 
                     anyhow::Ok(())
-                }}.log_err())
-                .detach();
+                }
+                .log_err(),
+            );
 
             let provider = DummyEmbeddings {};
             // let provider = OpenAIEmbeddings { client };
@@ -231,14 +280,15 @@ impl VectorStore {
                         scope.spawn(async {
                             let mut parser = Parser::new();
                             let mut cursor = QueryCursor::new();
-                            while let Ok(file_path) = paths_rx.recv().await {
+                            while let Ok((worktree_id, file_path, content)) = paths_rx.recv().await
+                            {
                                 if let Some(indexed_file) = Self::index_file(
                                     &mut cursor,
                                     &mut parser,
                                     &provider,
-                                    &fs,
                                     &language_registry,
                                     file_path,
+                                    content,
                                 )
                                 .await
                                 .log_err()
@@ -250,11 +300,86 @@ impl VectorStore {
                     }
                 })
                 .await;
+            drop(indexed_files_tx);
+
+            db_write_task.await;
+            anyhow::Ok(())
+        })
+    }
+
+    pub fn search(
+        &mut self,
+        phrase: String,
+        limit: usize,
+        cx: &mut ModelContext<Self>,
+    ) -> Task<Result<Vec<SearchResult>>> {
+        let embedding_provider = self.embedding_provider.clone();
+        let database_url = self.database_url.clone();
+        cx.spawn(|this, cx| async move {
+            let database = VectorDatabase::new(database_url.as_ref())?;
+
+            // let embedding = embedding_provider.embed_batch(vec![&phrase]).await?;
+            //
+            let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
+
+            database.for_each_document(0, |id, embedding| {
+                dbg!(id, &embedding);
+
+                let similarity = dot(&embedding.0, &embedding.0);
+                let ix = match results.binary_search_by(|(_, s)| {
+                    s.partial_cmp(&similarity).unwrap_or(Ordering::Equal)
+                }) {
+                    Ok(ix) => ix,
+                    Err(ix) => ix,
+                };
+
+                results.insert(ix, (id, similarity));
+                results.truncate(limit);
+            })?;
+
+            dbg!(&results);
+
+            let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
+            // let documents = database.get_documents_by_ids(ids)?;
+
+            // let search_provider = cx
+            //     .background()
+            //     .spawn(async move { BruteForceSearch::load(&database) })
+            //     .await?;
+
+            // let results = search_provider.top_k_search(&embedding, limit))
+
+            anyhow::Ok(vec![])
         })
-        .detach();
     }
 }
 
 impl Entity for VectorStore {
     type Event = ();
 }
+
+fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
+    let len = vec_a.len();
+    assert_eq!(len, vec_b.len());
+
+    let mut result = 0.0;
+    unsafe {
+        matrixmultiply::sgemm(
+            1,
+            len,
+            1,
+            1.0,
+            vec_a.as_ptr(),
+            len as isize,
+            1,
+            vec_b.as_ptr(),
+            1,
+            len as isize,
+            0.0,
+            &mut result as *mut f32,
+            1,
+            1,
+        );
+    }
+    result
+}

crates/vector_store/src/vector_store_tests.rs 🔗

@@ -0,0 +1,136 @@
+use std::sync::Arc;
+
+use crate::{dot, embedding::EmbeddingProvider, VectorStore};
+use anyhow::Result;
+use async_trait::async_trait;
+use gpui::{Task, TestAppContext};
+use language::{Language, LanguageConfig, LanguageRegistry};
+use project::{FakeFs, Project};
+use rand::Rng;
+use serde_json::json;
+use unindent::Unindent;
+
+#[gpui::test]
+async fn test_vector_store(cx: &mut TestAppContext) {
+    let fs = FakeFs::new(cx.background());
+    fs.insert_tree(
+        "/the-root",
+        json!({
+            "src": {
+                "file1.rs": "
+                    fn aaa() {
+                        println!(\"aaaa!\");
+                    }
+
+                    fn zzzzzzzzz() {
+                        println!(\"SLEEPING\");
+                    }
+                ".unindent(),
+                "file2.rs": "
+                    fn bbb() {
+                        println!(\"bbbb!\");
+                    }
+                ".unindent(),
+            }
+        }),
+    )
+    .await;
+
+    let languages = Arc::new(LanguageRegistry::new(Task::ready(())));
+    let rust_language = Arc::new(
+        Language::new(
+            LanguageConfig {
+                name: "Rust".into(),
+                path_suffixes: vec!["rs".into()],
+                ..Default::default()
+            },
+            Some(tree_sitter_rust::language()),
+        )
+        .with_outline_query(
+            r#"
+            (function_item
+                name: (identifier) @name
+                body: (block)) @item
+            "#,
+        )
+        .unwrap(),
+    );
+    languages.add(rust_language);
+
+    let store = cx.add_model(|_| {
+        VectorStore::new(
+            fs.clone(),
+            "foo".to_string(),
+            Arc::new(FakeEmbeddingProvider),
+            languages,
+        )
+    });
+
+    let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
+    store
+        .update(cx, |store, cx| store.add_project(project, cx))
+        .await
+        .unwrap();
+
+    let search_results = store
+        .update(cx, |store, cx| store.search("aaaa".to_string(), 5, cx))
+        .await
+        .unwrap();
+
+    assert_eq!(search_results[0].offset, 0);
+    assert_eq!(search_results[1].name, "aaa");
+}
+
+#[test]
+fn test_dot_product() {
+    assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
+    assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
+
+    for _ in 0..100 {
+        let mut rng = rand::thread_rng();
+        let a: [f32; 32] = rng.gen();
+        let b: [f32; 32] = rng.gen();
+        assert_eq!(
+            round_to_decimals(dot(&a, &b), 3),
+            round_to_decimals(reference_dot(&a, &b), 3)
+        );
+    }
+
+    fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
+        let factor = (10.0 as f32).powi(decimal_places);
+        (n * factor).round() / factor
+    }
+
+    fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
+        a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
+    }
+}
+
+struct FakeEmbeddingProvider;
+
+#[async_trait]
+impl EmbeddingProvider for FakeEmbeddingProvider {
+    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
+        Ok(spans
+            .iter()
+            .map(|span| {
+                let mut result = vec![0.0; 26];
+                for letter in span.chars() {
+                    if letter as u32 > 'a' as u32 {
+                        let ix = (letter as u32) - ('a' as u32);
+                        if ix < 26 {
+                            result[ix as usize] += 1.0;
+                        }
+                    }
+                }
+
+                let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
+                for x in &mut result {
+                    *x /= norm;
+                }
+
+                result
+            })
+            .collect())
+    }
+}