Eager background indexing (#2928)

Kyle Caverly created

This PR ships a series of optimizations for the semantic search engine.
Mostly focused on removing invalid states, optimizing requests to
OpenAI, and reducing token usage.

Release Notes (Preview-Only):

- Added eager incremental indexing in the background on a debounce.
- Added a local embeddings cache for reducing redundant calls to OpenAI.
- Moved to an Embeddings Queue model which ensures optimal batch sizes
at the token level, and atomic file & document writes.
- Adjusted OpenAI Embedding API requests to use provided backoff delays
during Rate Limiting.
- Removed flush races between parsing files step and embedding queue
steps.
- Moved truncation to parsing step reducing the probability that OpenAI
encounters bad data.

Change summary

Cargo.lock                                        |  65 +
crates/semantic_index/Cargo.toml                  |   1 
crates/semantic_index/src/db.rs                   | 721 ++++++++------
crates/semantic_index/src/embedding.rs            | 252 ++++-
crates/semantic_index/src/embedding_queue.rs      | 173 +++
crates/semantic_index/src/parsing.rs              |  85 +
crates/semantic_index/src/semantic_index.rs       | 788 ++++------------
crates/semantic_index/src/semantic_index_tests.rs | 224 +++-
crates/util/src/util.rs                           |  35 
9 files changed, 1,287 insertions(+), 1,057 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -3539,7 +3539,7 @@ dependencies = [
  "gif",
  "jpeg-decoder",
  "num-iter",
- "num-rational",
+ "num-rational 0.3.2",
  "num-traits",
  "png",
  "scoped_threadpool",
@@ -4631,6 +4631,31 @@ dependencies = [
  "winapi 0.3.9",
 ]
 
+[[package]]
+name = "num"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b8536030f9fea7127f841b45bb6243b27255787fb4eb83958aa1ef9d2fdc0c36"
+dependencies = [
+ "num-bigint 0.2.6",
+ "num-complex",
+ "num-integer",
+ "num-iter",
+ "num-rational 0.2.4",
+ "num-traits",
+]
+
+[[package]]
+name = "num-bigint"
+version = "0.2.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "090c7f9998ee0ff65aa5b723e4009f7b217707f1fb5ea551329cc4d6231fb304"
+dependencies = [
+ "autocfg",
+ "num-integer",
+ "num-traits",
+]
+
 [[package]]
 name = "num-bigint"
 version = "0.4.4"
@@ -4659,6 +4684,16 @@ dependencies = [
  "zeroize",
 ]
 
+[[package]]
+name = "num-complex"
+version = "0.2.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b6b19411a9719e753aff12e5187b74d60d3dc449ec3f4dc21e3989c3f554bc95"
+dependencies = [
+ "autocfg",
+ "num-traits",
+]
+
 [[package]]
 name = "num-derive"
 version = "0.3.3"
@@ -4691,6 +4726,18 @@ dependencies = [
  "num-traits",
 ]
 
+[[package]]
+name = "num-rational"
+version = "0.2.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5c000134b5dbf44adc5cb772486d335293351644b801551abe8f75c84cfa4aef"
+dependencies = [
+ "autocfg",
+ "num-bigint 0.2.6",
+ "num-integer",
+ "num-traits",
+]
+
 [[package]]
 name = "num-rational"
 version = "0.3.2"
@@ -5007,6 +5054,17 @@ dependencies = [
  "windows-targets 0.48.5",
 ]
 
+[[package]]
+name = "parse_duration"
+version = "2.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7037e5e93e0172a5a96874380bf73bc6ecef022e26fa25f2be26864d6b3ba95d"
+dependencies = [
+ "lazy_static",
+ "num",
+ "regex",
+]
+
 [[package]]
 name = "password-hash"
 version = "0.2.3"
@@ -6674,6 +6732,7 @@ dependencies = [
  "log",
  "matrixmultiply",
  "parking_lot 0.11.2",
+ "parse_duration",
  "picker",
  "postage",
  "pretty_assertions",
@@ -7005,7 +7064,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8eb4ea60fb301dc81dfc113df680571045d375ab7345d171c5dc7d7e13107a80"
 dependencies = [
  "chrono",
- "num-bigint",
+ "num-bigint 0.4.4",
  "num-traits",
  "thiserror",
 ]
@@ -7237,7 +7296,7 @@ dependencies = [
  "log",
  "md-5",
  "memchr",
- "num-bigint",
+ "num-bigint 0.4.4",
  "once_cell",
  "paste",
  "percent-encoding",

crates/semantic_index/Cargo.toml 🔗

@@ -39,6 +39,7 @@ rand.workspace = true
 schemars.workspace = true
 globset.workspace = true
 sha1 = "0.10.5"
+parse_duration = "2.1.1"
 
 [dev-dependencies]
 gpui = { path = "../gpui", features = ["test-support"] }

crates/semantic_index/src/db.rs 🔗

@@ -1,20 +1,26 @@
-use crate::{parsing::Document, SEMANTIC_INDEX_VERSION};
+use crate::{
+    embedding::Embedding,
+    parsing::{Document, DocumentDigest},
+    SEMANTIC_INDEX_VERSION,
+};
 use anyhow::{anyhow, Context, Result};
+use futures::channel::oneshot;
+use gpui::executor;
 use project::{search::PathMatcher, Fs};
 use rpc::proto::Timestamp;
-use rusqlite::{
-    params,
-    types::{FromSql, FromSqlResult, ValueRef},
-};
+use rusqlite::params;
+use rusqlite::types::Value;
 use std::{
     cmp::Ordering,
     collections::HashMap,
+    future::Future,
     ops::Range,
     path::{Path, PathBuf},
     rc::Rc,
     sync::Arc,
-    time::SystemTime,
+    time::{Instant, SystemTime},
 };
+use util::TryFutureExt;
 
 #[derive(Debug)]
 pub struct FileRecord {
@@ -23,145 +29,181 @@ pub struct FileRecord {
     pub mtime: Timestamp,
 }
 
-#[derive(Debug)]
-struct Embedding(pub Vec<f32>);
-
-#[derive(Debug)]
-struct Sha1(pub Vec<u8>);
-
-impl FromSql for Embedding {
-    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
-        let bytes = value.as_blob()?;
-        let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
-        if embedding.is_err() {
-            return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
-        }
-        return Ok(Embedding(embedding.unwrap()));
-    }
-}
-
-impl FromSql for Sha1 {
-    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
-        let bytes = value.as_blob()?;
-        let sha1: Result<Vec<u8>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
-        if sha1.is_err() {
-            return Err(rusqlite::types::FromSqlError::Other(sha1.unwrap_err()));
-        }
-        return Ok(Sha1(sha1.unwrap()));
-    }
-}
-
+#[derive(Clone)]
 pub struct VectorDatabase {
-    db: rusqlite::Connection,
+    path: Arc<Path>,
+    transactions:
+        smol::channel::Sender<Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>>,
 }
 
 impl VectorDatabase {
-    pub async fn new(fs: Arc<dyn Fs>, path: Arc<PathBuf>) -> Result<Self> {
+    pub async fn new(
+        fs: Arc<dyn Fs>,
+        path: Arc<Path>,
+        executor: Arc<executor::Background>,
+    ) -> Result<Self> {
         if let Some(db_directory) = path.parent() {
             fs.create_dir(db_directory).await?;
         }
 
+        let (transactions_tx, transactions_rx) = smol::channel::unbounded::<
+            Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>,
+        >();
+        executor
+            .spawn({
+                let path = path.clone();
+                async move {
+                    let mut connection = rusqlite::Connection::open(&path)?;
+
+                    connection.pragma_update(None, "journal_mode", "wal")?;
+                    connection.pragma_update(None, "synchronous", "normal")?;
+                    connection.pragma_update(None, "cache_size", 1000000)?;
+                    connection.pragma_update(None, "temp_store", "MEMORY")?;
+
+                    while let Ok(transaction) = transactions_rx.recv().await {
+                        transaction(&mut connection);
+                    }
+
+                    anyhow::Ok(())
+                }
+                .log_err()
+            })
+            .detach();
         let this = Self {
-            db: rusqlite::Connection::open(path.as_path())?,
+            transactions: transactions_tx,
+            path,
         };
-        this.initialize_database()?;
+        this.initialize_database().await?;
         Ok(this)
     }
 
-    fn get_existing_version(&self) -> Result<i64> {
-        let mut version_query = self
-            .db
-            .prepare("SELECT version from semantic_index_config")?;
-        version_query
-            .query_row([], |row| Ok(row.get::<_, i64>(0)?))
-            .map_err(|err| anyhow!("version query failed: {err}"))
+    pub fn path(&self) -> &Arc<Path> {
+        &self.path
     }
 
-    fn initialize_database(&self) -> Result<()> {
-        rusqlite::vtab::array::load_module(&self.db)?;
-
-        // Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped
-        if self
-            .get_existing_version()
-            .map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64)
-        {
-            log::trace!("vector database schema up to date");
-            return Ok(());
+    fn transact<F, T>(&self, f: F) -> impl Future<Output = Result<T>>
+    where
+        F: 'static + Send + FnOnce(&rusqlite::Transaction) -> Result<T>,
+        T: 'static + Send,
+    {
+        let (tx, rx) = oneshot::channel();
+        let transactions = self.transactions.clone();
+        async move {
+            if transactions
+                .send(Box::new(|connection| {
+                    let result = connection
+                        .transaction()
+                        .map_err(|err| anyhow!(err))
+                        .and_then(|transaction| {
+                            let result = f(&transaction)?;
+                            transaction.commit()?;
+                            Ok(result)
+                        });
+                    let _ = tx.send(result);
+                }))
+                .await
+                .is_err()
+            {
+                return Err(anyhow!("connection was dropped"))?;
+            }
+            rx.await?
         }
+    }
 
-        log::trace!("vector database schema out of date. updating...");
-        self.db
-            .execute("DROP TABLE IF EXISTS documents", [])
-            .context("failed to drop 'documents' table")?;
-        self.db
-            .execute("DROP TABLE IF EXISTS files", [])
-            .context("failed to drop 'files' table")?;
-        self.db
-            .execute("DROP TABLE IF EXISTS worktrees", [])
-            .context("failed to drop 'worktrees' table")?;
-        self.db
-            .execute("DROP TABLE IF EXISTS semantic_index_config", [])
-            .context("failed to drop 'semantic_index_config' table")?;
-
-        // Initialize Vector Databasing Tables
-        self.db.execute(
-            "CREATE TABLE semantic_index_config (
-                version INTEGER NOT NULL
-            )",
-            [],
-        )?;
+    fn initialize_database(&self) -> impl Future<Output = Result<()>> {
+        self.transact(|db| {
+            rusqlite::vtab::array::load_module(&db)?;
+
+            // Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped
+            let version_query = db.prepare("SELECT version from semantic_index_config");
+            let version = version_query
+                .and_then(|mut query| query.query_row([], |row| Ok(row.get::<_, i64>(0)?)));
+            if version.map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64) {
+                log::trace!("vector database schema up to date");
+                return Ok(());
+            }
 
-        self.db.execute(
-            "INSERT INTO semantic_index_config (version) VALUES (?1)",
-            params![SEMANTIC_INDEX_VERSION],
-        )?;
+            log::trace!("vector database schema out of date. updating...");
+            db.execute("DROP TABLE IF EXISTS documents", [])
+                .context("failed to drop 'documents' table")?;
+            db.execute("DROP TABLE IF EXISTS files", [])
+                .context("failed to drop 'files' table")?;
+            db.execute("DROP TABLE IF EXISTS worktrees", [])
+                .context("failed to drop 'worktrees' table")?;
+            db.execute("DROP TABLE IF EXISTS semantic_index_config", [])
+                .context("failed to drop 'semantic_index_config' table")?;
+
+            // Initialize Vector Databasing Tables
+            db.execute(
+                "CREATE TABLE semantic_index_config (
+                    version INTEGER NOT NULL
+                )",
+                [],
+            )?;
 
-        self.db.execute(
-            "CREATE TABLE worktrees (
-                id INTEGER PRIMARY KEY AUTOINCREMENT,
-                absolute_path VARCHAR NOT NULL
-            );
-            CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
-            ",
-            [],
-        )?;
+            db.execute(
+                "INSERT INTO semantic_index_config (version) VALUES (?1)",
+                params![SEMANTIC_INDEX_VERSION],
+            )?;
 
-        self.db.execute(
-            "CREATE TABLE files (
-                id INTEGER PRIMARY KEY AUTOINCREMENT,
-                worktree_id INTEGER NOT NULL,
-                relative_path VARCHAR NOT NULL,
-                mtime_seconds INTEGER NOT NULL,
-                mtime_nanos INTEGER NOT NULL,
-                FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
-            )",
-            [],
-        )?;
+            db.execute(
+                "CREATE TABLE worktrees (
+                    id INTEGER PRIMARY KEY AUTOINCREMENT,
+                    absolute_path VARCHAR NOT NULL
+                );
+                CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
+                ",
+                [],
+            )?;
 
-        self.db.execute(
-            "CREATE TABLE documents (
-                id INTEGER PRIMARY KEY AUTOINCREMENT,
-                file_id INTEGER NOT NULL,
-                start_byte INTEGER NOT NULL,
-                end_byte INTEGER NOT NULL,
-                name VARCHAR NOT NULL,
-                embedding BLOB NOT NULL,
-                sha1 BLOB NOT NULL,
-                FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
-            )",
-            [],
-        )?;
+            db.execute(
+                "CREATE TABLE files (
+                    id INTEGER PRIMARY KEY AUTOINCREMENT,
+                    worktree_id INTEGER NOT NULL,
+                    relative_path VARCHAR NOT NULL,
+                    mtime_seconds INTEGER NOT NULL,
+                    mtime_nanos INTEGER NOT NULL,
+                    FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
+                )",
+                [],
+            )?;
 
-        log::trace!("vector database initialized with updated schema.");
-        Ok(())
+            db.execute(
+                "CREATE UNIQUE INDEX files_worktree_id_and_relative_path ON files (worktree_id, relative_path)",
+                [],
+            )?;
+
+            db.execute(
+                "CREATE TABLE documents (
+                    id INTEGER PRIMARY KEY AUTOINCREMENT,
+                    file_id INTEGER NOT NULL,
+                    start_byte INTEGER NOT NULL,
+                    end_byte INTEGER NOT NULL,
+                    name VARCHAR NOT NULL,
+                    embedding BLOB NOT NULL,
+                    digest BLOB NOT NULL,
+                    FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
+                )",
+                [],
+            )?;
+
+            log::trace!("vector database initialized with updated schema.");
+            Ok(())
+        })
     }
 
-    pub fn delete_file(&self, worktree_id: i64, delete_path: PathBuf) -> Result<()> {
-        self.db.execute(
-            "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
-            params![worktree_id, delete_path.to_str()],
-        )?;
-        Ok(())
+    pub fn delete_file(
+        &self,
+        worktree_id: i64,
+        delete_path: PathBuf,
+    ) -> impl Future<Output = Result<()>> {
+        self.transact(move |db| {
+            db.execute(
+                "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
+                params![worktree_id, delete_path.to_str()],
+            )?;
+            Ok(())
+        })
     }
 
     pub fn insert_file(
@@ -170,139 +212,187 @@ impl VectorDatabase {
         path: PathBuf,
         mtime: SystemTime,
         documents: Vec<Document>,
-    ) -> Result<()> {
-        // Return the existing ID, if both the file and mtime match
-        let mtime = Timestamp::from(mtime);
-        let mut existing_id_query = self.db.prepare("SELECT id FROM files WHERE worktree_id = ?1 AND relative_path = ?2 AND mtime_seconds = ?3 AND mtime_nanos = ?4")?;
-        let existing_id = existing_id_query
-            .query_row(
+    ) -> impl Future<Output = Result<()>> {
+        self.transact(move |db| {
+            // Return the existing ID, if both the file and mtime match
+            let mtime = Timestamp::from(mtime);
+
+            db.execute(
+                "
+                REPLACE INTO files
+                (worktree_id, relative_path, mtime_seconds, mtime_nanos)
+                VALUES (?1, ?2, ?3, ?4)
+                ",
                 params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
-                |row| Ok(row.get::<_, i64>(0)?),
-            )
-            .map_err(|err| anyhow!(err));
-        let file_id = if existing_id.is_ok() {
-            // If already exists, just return the existing id
-            existing_id.unwrap()
-        } else {
-            // Delete Existing Row
-            self.db.execute(
-                "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;",
-                params![worktree_id, path.to_str()],
             )?;
-            self.db.execute("INSERT INTO files (worktree_id, relative_path, mtime_seconds, mtime_nanos) VALUES (?1, ?2, ?3, ?4);", params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos])?;
-            self.db.last_insert_rowid()
-        };
 
-        // Currently inserting at approximately 3400 documents a second
-        // I imagine we can speed this up with a bulk insert of some kind.
-        for document in documents {
-            let embedding_blob = bincode::serialize(&document.embedding)?;
-            let sha_blob = bincode::serialize(&document.sha1)?;
+            let file_id = db.last_insert_rowid();
 
-            self.db.execute(
-                "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding, sha1) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
-                params![
+            let t0 = Instant::now();
+            let mut query = db.prepare(
+                "
+                INSERT INTO documents
+                (file_id, start_byte, end_byte, name, embedding, digest)
+                VALUES (?1, ?2, ?3, ?4, ?5, ?6)
+                ",
+            )?;
+            log::trace!(
+                "Preparing Query Took: {:?} milliseconds",
+                t0.elapsed().as_millis()
+            );
+
+            for document in documents {
+                query.execute(params![
                     file_id,
                     document.range.start.to_string(),
                     document.range.end.to_string(),
                     document.name,
-                    embedding_blob,
-                    sha_blob
-                ],
-            )?;
-        }
+                    document.embedding,
+                    document.digest
+                ])?;
+            }
 
-        Ok(())
+            Ok(())
+        })
     }
 
-    pub fn worktree_previously_indexed(&self, worktree_root_path: &Path) -> Result<bool> {
-        let mut worktree_query = self
-            .db
-            .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
-        let worktree_id = worktree_query
-            .query_row(params![worktree_root_path.to_string_lossy()], |row| {
-                Ok(row.get::<_, i64>(0)?)
-            })
-            .map_err(|err| anyhow!(err));
-
-        if worktree_id.is_ok() {
-            return Ok(true);
-        } else {
-            return Ok(false);
-        }
+    pub fn worktree_previously_indexed(
+        &self,
+        worktree_root_path: &Path,
+    ) -> impl Future<Output = Result<bool>> {
+        let worktree_root_path = worktree_root_path.to_string_lossy().into_owned();
+        self.transact(move |db| {
+            let mut worktree_query =
+                db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
+            let worktree_id = worktree_query
+                .query_row(params![worktree_root_path], |row| Ok(row.get::<_, i64>(0)?));
+
+            if worktree_id.is_ok() {
+                return Ok(true);
+            } else {
+                return Ok(false);
+            }
+        })
     }
 
-    pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
-        // Check that the absolute path doesnt exist
-        let mut worktree_query = self
-            .db
-            .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
+    pub fn embeddings_for_files(
+        &self,
+        worktree_id_file_paths: HashMap<i64, Vec<Arc<Path>>>,
+    ) -> impl Future<Output = Result<HashMap<DocumentDigest, Embedding>>> {
+        self.transact(move |db| {
+            let mut query = db.prepare(
+                "
+                SELECT digest, embedding
+                FROM documents
+                LEFT JOIN files ON files.id = documents.file_id
+                WHERE files.worktree_id = ? AND files.relative_path IN rarray(?)
+            ",
+            )?;
+            let mut embeddings_by_digest = HashMap::new();
+            for (worktree_id, file_paths) in worktree_id_file_paths {
+                let file_paths = Rc::new(
+                    file_paths
+                        .into_iter()
+                        .map(|p| Value::Text(p.to_string_lossy().into_owned()))
+                        .collect::<Vec<_>>(),
+                );
+                let rows = query.query_map(params![worktree_id, file_paths], |row| {
+                    Ok((
+                        row.get::<_, DocumentDigest>(0)?,
+                        row.get::<_, Embedding>(1)?,
+                    ))
+                })?;
+
+                for row in rows {
+                    if let Ok(row) = row {
+                        embeddings_by_digest.insert(row.0, row.1);
+                    }
+                }
+            }
 
-        let worktree_id = worktree_query
-            .query_row(params![worktree_root_path.to_string_lossy()], |row| {
-                Ok(row.get::<_, i64>(0)?)
-            })
-            .map_err(|err| anyhow!(err));
+            Ok(embeddings_by_digest)
+        })
+    }
 
-        if worktree_id.is_ok() {
-            return worktree_id;
-        }
+    pub fn find_or_create_worktree(
+        &self,
+        worktree_root_path: PathBuf,
+    ) -> impl Future<Output = Result<i64>> {
+        self.transact(move |db| {
+            let mut worktree_query =
+                db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
+            let worktree_id = worktree_query
+                .query_row(params![worktree_root_path.to_string_lossy()], |row| {
+                    Ok(row.get::<_, i64>(0)?)
+                });
+
+            if worktree_id.is_ok() {
+                return Ok(worktree_id?);
+            }
 
-        // If worktree_id is Err, insert new worktree
-        self.db.execute(
-            "
-            INSERT into worktrees (absolute_path) VALUES (?1)
-            ",
-            params![worktree_root_path.to_string_lossy()],
-        )?;
-        Ok(self.db.last_insert_rowid())
+            // If worktree_id is Err, insert new worktree
+            db.execute(
+                "INSERT into worktrees (absolute_path) VALUES (?1)",
+                params![worktree_root_path.to_string_lossy()],
+            )?;
+            Ok(db.last_insert_rowid())
+        })
     }
 
-    pub fn get_file_mtimes(&self, worktree_id: i64) -> Result<HashMap<PathBuf, SystemTime>> {
-        let mut statement = self.db.prepare(
-            "
-            SELECT relative_path, mtime_seconds, mtime_nanos
-            FROM files
-            WHERE worktree_id = ?1
-            ORDER BY relative_path",
-        )?;
-        let mut result: HashMap<PathBuf, SystemTime> = HashMap::new();
-        for row in statement.query_map(params![worktree_id], |row| {
-            Ok((
-                row.get::<_, String>(0)?.into(),
-                Timestamp {
-                    seconds: row.get(1)?,
-                    nanos: row.get(2)?,
-                }
-                .into(),
-            ))
-        })? {
-            let row = row?;
-            result.insert(row.0, row.1);
-        }
-        Ok(result)
+    pub fn get_file_mtimes(
+        &self,
+        worktree_id: i64,
+    ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
+        self.transact(move |db| {
+            let mut statement = db.prepare(
+                "
+                SELECT relative_path, mtime_seconds, mtime_nanos
+                FROM files
+                WHERE worktree_id = ?1
+                ORDER BY relative_path",
+            )?;
+            let mut result: HashMap<PathBuf, SystemTime> = HashMap::new();
+            for row in statement.query_map(params![worktree_id], |row| {
+                Ok((
+                    row.get::<_, String>(0)?.into(),
+                    Timestamp {
+                        seconds: row.get(1)?,
+                        nanos: row.get(2)?,
+                    }
+                    .into(),
+                ))
+            })? {
+                let row = row?;
+                result.insert(row.0, row.1);
+            }
+            Ok(result)
+        })
     }
 
     pub fn top_k_search(
         &self,
-        query_embedding: &Vec<f32>,
+        query_embedding: &Embedding,
         limit: usize,
         file_ids: &[i64],
-    ) -> Result<Vec<(i64, f32)>> {
-        let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
-        self.for_each_document(file_ids, |id, embedding| {
-            let similarity = dot(&embedding, &query_embedding);
-            let ix = match results
-                .binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal))
-            {
-                Ok(ix) => ix,
-                Err(ix) => ix,
-            };
-            results.insert(ix, (id, similarity));
-            results.truncate(limit);
-        })?;
-
-        Ok(results)
+    ) -> impl Future<Output = Result<Vec<(i64, f32)>>> {
+        let query_embedding = query_embedding.clone();
+        let file_ids = file_ids.to_vec();
+        self.transact(move |db| {
+            let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
+            Self::for_each_document(db, &file_ids, |id, embedding| {
+                let similarity = embedding.similarity(&query_embedding);
+                let ix = match results.binary_search_by(|(_, s)| {
+                    similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
+                }) {
+                    Ok(ix) => ix,
+                    Err(ix) => ix,
+                };
+                results.insert(ix, (id, similarity));
+                results.truncate(limit);
+            })?;
+
+            anyhow::Ok(results)
+        })
     }
 
     pub fn retrieve_included_file_ids(
@@ -310,37 +400,46 @@ impl VectorDatabase {
         worktree_ids: &[i64],
         includes: &[PathMatcher],
         excludes: &[PathMatcher],
-    ) -> Result<Vec<i64>> {
-        let mut file_query = self.db.prepare(
-            "
-            SELECT
-                id, relative_path
-            FROM
-                files
-            WHERE
-                worktree_id IN rarray(?)
-            ",
-        )?;
+    ) -> impl Future<Output = Result<Vec<i64>>> {
+        let worktree_ids = worktree_ids.to_vec();
+        let includes = includes.to_vec();
+        let excludes = excludes.to_vec();
+        self.transact(move |db| {
+            let mut file_query = db.prepare(
+                "
+                SELECT
+                    id, relative_path
+                FROM
+                    files
+                WHERE
+                    worktree_id IN rarray(?)
+                ",
+            )?;
 
-        let mut file_ids = Vec::<i64>::new();
-        let mut rows = file_query.query([ids_to_sql(worktree_ids)])?;
-
-        while let Some(row) = rows.next()? {
-            let file_id = row.get(0)?;
-            let relative_path = row.get_ref(1)?.as_str()?;
-            let included =
-                includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path));
-            let excluded = excludes.iter().any(|glob| glob.is_match(relative_path));
-            if included && !excluded {
-                file_ids.push(file_id);
+            let mut file_ids = Vec::<i64>::new();
+            let mut rows = file_query.query([ids_to_sql(&worktree_ids)])?;
+
+            while let Some(row) = rows.next()? {
+                let file_id = row.get(0)?;
+                let relative_path = row.get_ref(1)?.as_str()?;
+                let included =
+                    includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path));
+                let excluded = excludes.iter().any(|glob| glob.is_match(relative_path));
+                if included && !excluded {
+                    file_ids.push(file_id);
+                }
             }
-        }
 
-        Ok(file_ids)
+            anyhow::Ok(file_ids)
+        })
     }
 
-    fn for_each_document(&self, file_ids: &[i64], mut f: impl FnMut(i64, Vec<f32>)) -> Result<()> {
-        let mut query_statement = self.db.prepare(
+    fn for_each_document(
+        db: &rusqlite::Connection,
+        file_ids: &[i64],
+        mut f: impl FnMut(i64, Embedding),
+    ) -> Result<()> {
+        let mut query_statement = db.prepare(
             "
             SELECT
                 id, embedding
@@ -356,51 +455,57 @@ impl VectorDatabase {
                 Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
             })?
             .filter_map(|row| row.ok())
-            .for_each(|(id, embedding)| f(id, embedding.0));
+            .for_each(|(id, embedding)| f(id, embedding));
         Ok(())
     }
 
-    pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
-        let mut statement = self.db.prepare(
-            "
-                SELECT
-                    documents.id,
-                    files.worktree_id,
-                    files.relative_path,
-                    documents.start_byte,
-                    documents.end_byte
-                FROM
-                    documents, files
-                WHERE
-                    documents.file_id = files.id AND
-                    documents.id in rarray(?)
-            ",
-        )?;
+    pub fn get_documents_by_ids(
+        &self,
+        ids: &[i64],
+    ) -> impl Future<Output = Result<Vec<(i64, PathBuf, Range<usize>)>>> {
+        let ids = ids.to_vec();
+        self.transact(move |db| {
+            let mut statement = db.prepare(
+                "
+                    SELECT
+                        documents.id,
+                        files.worktree_id,
+                        files.relative_path,
+                        documents.start_byte,
+                        documents.end_byte
+                    FROM
+                        documents, files
+                    WHERE
+                        documents.file_id = files.id AND
+                        documents.id in rarray(?)
+                ",
+            )?;
 
-        let result_iter = statement.query_map(params![ids_to_sql(ids)], |row| {
-            Ok((
-                row.get::<_, i64>(0)?,
-                row.get::<_, i64>(1)?,
-                row.get::<_, String>(2)?.into(),
-                row.get(3)?..row.get(4)?,
-            ))
-        })?;
-
-        let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
-        for row in result_iter {
-            let (id, worktree_id, path, range) = row?;
-            values_by_id.insert(id, (worktree_id, path, range));
-        }
+            let result_iter = statement.query_map(params![ids_to_sql(&ids)], |row| {
+                Ok((
+                    row.get::<_, i64>(0)?,
+                    row.get::<_, i64>(1)?,
+                    row.get::<_, String>(2)?.into(),
+                    row.get(3)?..row.get(4)?,
+                ))
+            })?;
+
+            let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
+            for row in result_iter {
+                let (id, worktree_id, path, range) = row?;
+                values_by_id.insert(id, (worktree_id, path, range));
+            }
 
-        let mut results = Vec::with_capacity(ids.len());
-        for id in ids {
-            let value = values_by_id
-                .remove(id)
-                .ok_or(anyhow!("missing document id {}", id))?;
-            results.push(value);
-        }
+            let mut results = Vec::with_capacity(ids.len());
+            for id in &ids {
+                let value = values_by_id
+                    .remove(id)
+                    .ok_or(anyhow!("missing document id {}", id))?;
+                results.push(value);
+            }
 
-        Ok(results)
+            Ok(results)
+        })
     }
 }
 
@@ -412,29 +517,3 @@ fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
             .collect::<Vec<_>>(),
     )
 }
-
-pub(crate) fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
-    let len = vec_a.len();
-    assert_eq!(len, vec_b.len());
-
-    let mut result = 0.0;
-    unsafe {
-        matrixmultiply::sgemm(
-            1,
-            len,
-            1,
-            1.0,
-            vec_a.as_ptr(),
-            len as isize,
-            1,
-            vec_b.as_ptr(),
-            1,
-            len as isize,
-            0.0,
-            &mut result as *mut f32,
-            1,
-            1,
-        );
-    }
-    result
-}

crates/semantic_index/src/embedding.rs 🔗

@@ -7,6 +7,9 @@ use isahc::http::StatusCode;
 use isahc::prelude::Configurable;
 use isahc::{AsyncBody, Response};
 use lazy_static::lazy_static;
+use parse_duration::parse;
+use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
+use rusqlite::ToSql;
 use serde::{Deserialize, Serialize};
 use std::env;
 use std::sync::Arc;
@@ -19,6 +22,62 @@ lazy_static! {
     static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
 }
 
+#[derive(Debug, PartialEq, Clone)]
+pub struct Embedding(Vec<f32>);
+
+impl From<Vec<f32>> for Embedding {
+    fn from(value: Vec<f32>) -> Self {
+        Embedding(value)
+    }
+}
+
+impl Embedding {
+    pub fn similarity(&self, other: &Self) -> f32 {
+        let len = self.0.len();
+        assert_eq!(len, other.0.len());
+
+        let mut result = 0.0;
+        unsafe {
+            matrixmultiply::sgemm(
+                1,
+                len,
+                1,
+                1.0,
+                self.0.as_ptr(),
+                len as isize,
+                1,
+                other.0.as_ptr(),
+                1,
+                len as isize,
+                0.0,
+                &mut result as *mut f32,
+                1,
+                1,
+            );
+        }
+        result
+    }
+}
+
+impl FromSql for Embedding {
+    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
+        let bytes = value.as_blob()?;
+        let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
+        if embedding.is_err() {
+            return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
+        }
+        Ok(Embedding(embedding.unwrap()))
+    }
+}
+
+impl ToSql for Embedding {
+    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
+        let bytes = bincode::serialize(&self.0)
+            .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
+        Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
+    }
+}
+
 #[derive(Clone)]
 pub struct OpenAIEmbeddings {
     pub client: Arc<dyn HttpClient>,
@@ -52,42 +111,53 @@ struct OpenAIEmbeddingUsage {
 
 #[async_trait]
 pub trait EmbeddingProvider: Sync + Send {
-    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
+    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
+    fn max_tokens_per_batch(&self) -> usize;
+    fn truncate(&self, span: &str) -> (String, usize);
 }
 
 pub struct DummyEmbeddings {}
 
 #[async_trait]
 impl EmbeddingProvider for DummyEmbeddings {
-    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
+    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
         // 1024 is the OpenAI Embeddings size for ada models.
         // the model we will likely be starting with.
-        let dummy_vec = vec![0.32 as f32; 1536];
+        let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
         return Ok(vec![dummy_vec; spans.len()]);
     }
-}
 
-const OPENAI_INPUT_LIMIT: usize = 8190;
+    fn max_tokens_per_batch(&self) -> usize {
+        OPENAI_INPUT_LIMIT
+    }
 
-impl OpenAIEmbeddings {
-    fn truncate(span: String) -> String {
-        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span.as_ref());
-        if tokens.len() > OPENAI_INPUT_LIMIT {
+    fn truncate(&self, span: &str) -> (String, usize) {
+        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
+        let token_count = tokens.len();
+        let output = if token_count > OPENAI_INPUT_LIMIT {
             tokens.truncate(OPENAI_INPUT_LIMIT);
-            let result = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
-            if result.is_ok() {
-                let transformed = result.unwrap();
-                return transformed;
-            }
-        }
+            let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
+            new_input.ok().unwrap_or_else(|| span.to_string())
+        } else {
+            span.to_string()
+        };
 
-        span
+        (output, tokens.len())
     }
+}
+
+const OPENAI_INPUT_LIMIT: usize = 8190;
 
-    async fn send_request(&self, api_key: &str, spans: Vec<&str>) -> Result<Response<AsyncBody>> {
+impl OpenAIEmbeddings {
+    async fn send_request(
+        &self,
+        api_key: &str,
+        spans: Vec<&str>,
+        request_timeout: u64,
+    ) -> Result<Response<AsyncBody>> {
         let request = Request::post("https://api.openai.com/v1/embeddings")
             .redirect_policy(isahc::config::RedirectPolicy::Follow)
-            .timeout(Duration::from_secs(4))
+            .timeout(Duration::from_secs(request_timeout))
             .header("Content-Type", "application/json")
             .header("Authorization", format!("Bearer {}", api_key))
             .body(
@@ -105,7 +175,27 @@ impl OpenAIEmbeddings {
 
 #[async_trait]
 impl EmbeddingProvider for OpenAIEmbeddings {
-    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
+    fn max_tokens_per_batch(&self) -> usize {
+        50000
+    }
+
+    fn truncate(&self, span: &str) -> (String, usize) {
+        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
+        let token_count = tokens.len();
+        let output = if token_count > OPENAI_INPUT_LIMIT {
+            tokens.truncate(OPENAI_INPUT_LIMIT);
+            OPENAI_BPE_TOKENIZER
+                .decode(tokens)
+                .ok()
+                .unwrap_or_else(|| span.to_string())
+        } else {
+            span.to_string()
+        };
+
+        (output, token_count)
+    }
+
+    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
         const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
         const MAX_RETRIES: usize = 4;
 
@@ -114,45 +204,21 @@ impl EmbeddingProvider for OpenAIEmbeddings {
             .ok_or_else(|| anyhow!("no api key"))?;
 
         let mut request_number = 0;
-        let mut truncated = false;
+        let mut request_timeout: u64 = 10;
         let mut response: Response<AsyncBody>;
-        let mut spans: Vec<String> = spans.iter().map(|x| x.to_string()).collect();
         while request_number < MAX_RETRIES {
             response = self
-                .send_request(api_key, spans.iter().map(|x| &**x).collect())
+                .send_request(
+                    api_key,
+                    spans.iter().map(|x| &**x).collect(),
+                    request_timeout,
+                )
                 .await?;
             request_number += 1;
 
-            if request_number + 1 == MAX_RETRIES && response.status() != StatusCode::OK {
-                return Err(anyhow!(
-                    "openai max retries, error: {:?}",
-                    &response.status()
-                ));
-            }
-
             match response.status() {
-                StatusCode::TOO_MANY_REQUESTS => {
-                    let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
-                    log::trace!(
-                        "open ai rate limiting, delaying request by {:?} seconds",
-                        delay.as_secs()
-                    );
-                    self.executor.timer(delay).await;
-                }
-                StatusCode::BAD_REQUEST => {
-                    // Only truncate if it hasnt been truncated before
-                    if !truncated {
-                        for span in spans.iter_mut() {
-                            *span = Self::truncate(span.clone());
-                        }
-                        truncated = true;
-                    } else {
-                        // If failing once already truncated, log the error and break the loop
-                        let mut body = String::new();
-                        response.body_mut().read_to_string(&mut body).await?;
-                        log::trace!("open ai bad request: {:?} {:?}", &response.status(), body);
-                        break;
-                    }
+                StatusCode::REQUEST_TIMEOUT => {
+                    request_timeout += 5;
                 }
                 StatusCode::OK => {
                     let mut body = String::new();
@@ -163,18 +229,96 @@ impl EmbeddingProvider for OpenAIEmbeddings {
                         "openai embedding completed. tokens: {:?}",
                         response.usage.total_tokens
                     );
+
                     return Ok(response
                         .data
                         .into_iter()
-                        .map(|embedding| embedding.embedding)
+                        .map(|embedding| Embedding::from(embedding.embedding))
                         .collect());
                 }
+                StatusCode::TOO_MANY_REQUESTS => {
+                    let mut body = String::new();
+                    response.body_mut().read_to_string(&mut body).await?;
+
+                    let delay_duration = {
+                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
+                        if let Some(time_to_reset) =
+                            response.headers().get("x-ratelimit-reset-tokens")
+                        {
+                            if let Ok(time_str) = time_to_reset.to_str() {
+                                parse(time_str).unwrap_or(delay)
+                            } else {
+                                delay
+                            }
+                        } else {
+                            delay
+                        }
+                    };
+
+                    log::trace!(
+                        "openai rate limiting: waiting {:?} until lifted",
+                        &delay_duration
+                    );
+
+                    self.executor.timer(delay_duration).await;
+                }
                 _ => {
-                    return Err(anyhow!("openai embedding failed {}", response.status()));
+                    let mut body = String::new();
+                    response.body_mut().read_to_string(&mut body).await?;
+                    return Err(anyhow!(
+                        "open ai bad request: {:?} {:?}",
+                        &response.status(),
+                        body
+                    ));
                 }
             }
         }
+        Err(anyhow!("openai max retries"))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use rand::prelude::*;
+
+    #[gpui::test]
+    fn test_similarity(mut rng: StdRng) {
+        assert_eq!(
+            Embedding::from(vec![1., 0., 0., 0., 0.])
+                .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
+            0.
+        );
+        assert_eq!(
+            Embedding::from(vec![2., 0., 0., 0., 0.])
+                .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
+            6.
+        );
 
-        Err(anyhow!("openai embedding failed"))
+        for _ in 0..100 {
+            let size = 1536;
+            let mut a = vec![0.; size];
+            let mut b = vec![0.; size];
+            for (a, b) in a.iter_mut().zip(b.iter_mut()) {
+                *a = rng.gen();
+                *b = rng.gen();
+            }
+            let a = Embedding::from(a);
+            let b = Embedding::from(b);
+
+            assert_eq!(
+                round_to_decimals(a.similarity(&b), 1),
+                round_to_decimals(reference_dot(&a.0, &b.0), 1)
+            );
+        }
+
+        fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
+            let factor = (10.0 as f32).powi(decimal_places);
+            (n * factor).round() / factor
+        }
+
+        fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
+            a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
+        }
     }
 }

crates/semantic_index/src/embedding_queue.rs 🔗

@@ -0,0 +1,173 @@
+use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle};
+use gpui::executor::Background;
+use parking_lot::Mutex;
+use smol::channel;
+use std::{mem, ops::Range, path::PathBuf, sync::Arc, time::SystemTime};
+
+#[derive(Clone)]
+pub struct FileToEmbed {
+    pub worktree_id: i64,
+    pub path: PathBuf,
+    pub mtime: SystemTime,
+    pub documents: Vec<Document>,
+    pub job_handle: JobHandle,
+}
+
+impl std::fmt::Debug for FileToEmbed {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("FileToEmbed")
+            .field("worktree_id", &self.worktree_id)
+            .field("path", &self.path)
+            .field("mtime", &self.mtime)
+            .field("document", &self.documents)
+            .finish_non_exhaustive()
+    }
+}
+
+impl PartialEq for FileToEmbed {
+    fn eq(&self, other: &Self) -> bool {
+        self.worktree_id == other.worktree_id
+            && self.path == other.path
+            && self.mtime == other.mtime
+            && self.documents == other.documents
+    }
+}
+
+pub struct EmbeddingQueue {
+    embedding_provider: Arc<dyn EmbeddingProvider>,
+    pending_batch: Vec<FileToEmbedFragment>,
+    executor: Arc<Background>,
+    pending_batch_token_count: usize,
+    finished_files_tx: channel::Sender<FileToEmbed>,
+    finished_files_rx: channel::Receiver<FileToEmbed>,
+}
+
+#[derive(Clone)]
+pub struct FileToEmbedFragment {
+    file: Arc<Mutex<FileToEmbed>>,
+    document_range: Range<usize>,
+}
+
+impl EmbeddingQueue {
+    pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
+        let (finished_files_tx, finished_files_rx) = channel::unbounded();
+        Self {
+            embedding_provider,
+            executor,
+            pending_batch: Vec::new(),
+            pending_batch_token_count: 0,
+            finished_files_tx,
+            finished_files_rx,
+        }
+    }
+
+    pub fn push(&mut self, file: FileToEmbed) {
+        if file.documents.is_empty() {
+            self.finished_files_tx.try_send(file).unwrap();
+            return;
+        }
+
+        let file = Arc::new(Mutex::new(file));
+
+        self.pending_batch.push(FileToEmbedFragment {
+            file: file.clone(),
+            document_range: 0..0,
+        });
+
+        let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range;
+        let mut saved_tokens = 0;
+        for (ix, document) in file.lock().documents.iter().enumerate() {
+            let document_token_count = if document.embedding.is_none() {
+                document.token_count
+            } else {
+                saved_tokens += document.token_count;
+                0
+            };
+
+            let next_token_count = self.pending_batch_token_count + document_token_count;
+            if next_token_count > self.embedding_provider.max_tokens_per_batch() {
+                let range_end = fragment_range.end;
+                self.flush();
+                self.pending_batch.push(FileToEmbedFragment {
+                    file: file.clone(),
+                    document_range: range_end..range_end,
+                });
+                fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range;
+            }
+
+            fragment_range.end = ix + 1;
+            self.pending_batch_token_count += document_token_count;
+        }
+        log::trace!("Saved Tokens: {:?}", saved_tokens);
+    }
+
+    pub fn flush(&mut self) {
+        let batch = mem::take(&mut self.pending_batch);
+        self.pending_batch_token_count = 0;
+        if batch.is_empty() {
+            return;
+        }
+
+        let finished_files_tx = self.finished_files_tx.clone();
+        let embedding_provider = self.embedding_provider.clone();
+
+        self.executor.spawn(async move {
+            let mut spans = Vec::new();
+            let mut document_count = 0;
+            for fragment in &batch {
+                let file = fragment.file.lock();
+                document_count += file.documents[fragment.document_range.clone()].len();
+                spans.extend(
+                    {
+                        file.documents[fragment.document_range.clone()]
+                            .iter().filter(|d| d.embedding.is_none())
+                            .map(|d| d.content.clone())
+                        }
+                );
+            }
+
+            log::trace!("Documents Length: {:?}", document_count);
+            log::trace!("Span Length: {:?}", spans.clone().len());
+
+            // If spans is 0, just send the fragment to the finished files if its the last one.
+            if spans.len() == 0 {
+                for fragment in batch.clone() {
+                    if let Some(file) = Arc::into_inner(fragment.file) {
+                        finished_files_tx.try_send(file.into_inner()).unwrap();
+                    }
+                }
+                return;
+            };
+
+            match embedding_provider.embed_batch(spans).await {
+                Ok(embeddings) => {
+                    let mut embeddings = embeddings.into_iter();
+                    for fragment in batch {
+                        for document in
+                            &mut fragment.file.lock().documents[fragment.document_range.clone()].iter_mut().filter(|d| d.embedding.is_none())
+                        {
+                            if let Some(embedding) = embeddings.next() {
+                                document.embedding = Some(embedding);
+                            } else {
+                                //
+                                log::error!("number of embeddings returned different from number of documents");
+                            }
+                        }
+
+                        if let Some(file) = Arc::into_inner(fragment.file) {
+                            finished_files_tx.try_send(file.into_inner()).unwrap();
+                        }
+                    }
+                }
+                Err(error) => {
+                    log::error!("{:?}", error);
+                }
+            }
+        })
+        .detach();
+    }
+
+    pub fn finished_files(&self) -> channel::Receiver<FileToEmbed> {
+        self.finished_files_rx.clone()
+    }
+}

crates/semantic_index/src/parsing.rs 🔗

@@ -1,5 +1,10 @@
-use anyhow::{anyhow, Ok, Result};
+use crate::embedding::{Embedding, EmbeddingProvider};
+use anyhow::{anyhow, Result};
 use language::{Grammar, Language};
+use rusqlite::{
+    types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
+    ToSql,
+};
 use sha1::{Digest, Sha1};
 use std::{
     cmp::{self, Reverse},
@@ -10,13 +15,44 @@ use std::{
 };
 use tree_sitter::{Parser, QueryCursor};
 
+#[derive(Debug, PartialEq, Eq, Clone, Hash)]
+pub struct DocumentDigest([u8; 20]);
+
+impl FromSql for DocumentDigest {
+    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
+        let blob = value.as_blob()?;
+        let bytes =
+            blob.try_into()
+                .map_err(|_| rusqlite::types::FromSqlError::InvalidBlobSize {
+                    expected_size: 20,
+                    blob_size: blob.len(),
+                })?;
+        return Ok(DocumentDigest(bytes));
+    }
+}
+
+impl ToSql for DocumentDigest {
+    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
+        self.0.to_sql()
+    }
+}
+
+impl From<&'_ str> for DocumentDigest {
+    fn from(value: &'_ str) -> Self {
+        let mut sha1 = Sha1::new();
+        sha1.update(value);
+        Self(sha1.finalize().into())
+    }
+}
+
 #[derive(Debug, PartialEq, Clone)]
 pub struct Document {
     pub name: String,
     pub range: Range<usize>,
     pub content: String,
-    pub embedding: Vec<f32>,
-    pub sha1: [u8; 20],
+    pub embedding: Option<Embedding>,
+    pub digest: DocumentDigest,
+    pub token_count: usize,
 }
 
 const CODE_CONTEXT_TEMPLATE: &str =
@@ -30,6 +66,7 @@ pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] =
 pub struct CodeContextRetriever {
     pub parser: Parser,
     pub cursor: QueryCursor,
+    pub embedding_provider: Arc<dyn EmbeddingProvider>,
 }
 
 // Every match has an item, this represents the fundamental treesitter symbol and anchors the search
@@ -47,10 +84,11 @@ pub struct CodeContextMatch {
 }
 
 impl CodeContextRetriever {
-    pub fn new() -> Self {
+    pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
         Self {
             parser: Parser::new(),
             cursor: QueryCursor::new(),
+            embedding_provider,
         }
     }
 
@@ -64,16 +102,15 @@ impl CodeContextRetriever {
             .replace("<path>", relative_path.to_string_lossy().as_ref())
             .replace("<language>", language_name.as_ref())
             .replace("<item>", &content);
-
-        let mut sha1 = Sha1::new();
-        sha1.update(&document_span);
-
+        let digest = DocumentDigest::from(document_span.as_str());
+        let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
         Ok(vec![Document {
             range: 0..content.len(),
             content: document_span,
-            embedding: Vec::new(),
+            embedding: Default::default(),
             name: language_name.to_string(),
-            sha1: sha1.finalize().into(),
+            digest,
+            token_count,
         }])
     }
 
@@ -81,16 +118,15 @@ impl CodeContextRetriever {
         let document_span = MARKDOWN_CONTEXT_TEMPLATE
             .replace("<path>", relative_path.to_string_lossy().as_ref())
             .replace("<item>", &content);
-
-        let mut sha1 = Sha1::new();
-        sha1.update(&document_span);
-
+        let digest = DocumentDigest::from(document_span.as_str());
+        let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
         Ok(vec![Document {
             range: 0..content.len(),
             content: document_span,
-            embedding: Vec::new(),
+            embedding: None,
             name: "Markdown".to_string(),
-            sha1: sha1.finalize().into(),
+            digest,
+            token_count,
         }])
     }
 
@@ -166,10 +202,16 @@ impl CodeContextRetriever {
 
         let mut documents = self.parse_file(content, language)?;
         for document in &mut documents {
-            document.content = CODE_CONTEXT_TEMPLATE
+            let document_content = CODE_CONTEXT_TEMPLATE
                 .replace("<path>", relative_path.to_string_lossy().as_ref())
                 .replace("<language>", language_name.as_ref())
                 .replace("item", &document.content);
+
+            let (document_content, token_count) =
+                self.embedding_provider.truncate(&document_content);
+
+            document.content = document_content;
+            document.token_count = token_count;
         }
         Ok(documents)
     }
@@ -263,15 +305,14 @@ impl CodeContextRetriever {
                 );
             }
 
-            let mut sha1 = Sha1::new();
-            sha1.update(&document_content);
-
+            let sha1 = DocumentDigest::from(document_content.as_str());
             documents.push(Document {
                 name,
                 content: document_content,
                 range: item_range.clone(),
-                embedding: vec![],
-                sha1: sha1.finalize().into(),
+                embedding: None,
+                digest: sha1,
+                token_count: 0,
             })
         }
 

crates/semantic_index/src/semantic_index.rs 🔗

@@ -1,5 +1,6 @@
 mod db;
 mod embedding;
+mod embedding_queue;
 mod parsing;
 pub mod semantic_index_settings;
 
@@ -9,23 +10,25 @@ mod semantic_index_tests;
 use crate::semantic_index_settings::SemanticIndexSettings;
 use anyhow::{anyhow, Result};
 use db::VectorDatabase;
-use embedding::{EmbeddingProvider, OpenAIEmbeddings};
-use futures::{channel::oneshot, Future};
+use embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
+use embedding_queue::{EmbeddingQueue, FileToEmbed};
+use futures::{FutureExt, StreamExt};
 use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
 use language::{Anchor, Buffer, Language, LanguageRegistry};
 use parking_lot::Mutex;
-use parsing::{CodeContextRetriever, Document, PARSEABLE_ENTIRE_FILE_TYPES};
+use parsing::{CodeContextRetriever, DocumentDigest, PARSEABLE_ENTIRE_FILE_TYPES};
 use postage::watch;
-use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, WorktreeId};
+use project::{
+    search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, ProjectPath, Worktree, WorktreeId,
+};
 use smol::channel;
 use std::{
     cmp::Ordering,
-    collections::HashMap,
-    mem,
+    collections::{BTreeMap, HashMap},
     ops::Range,
     path::{Path, PathBuf},
     sync::{Arc, Weak},
-    time::{Instant, SystemTime},
+    time::{Duration, Instant, SystemTime},
 };
 use util::{
     channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
@@ -35,8 +38,9 @@ use util::{
 };
 use workspace::WorkspaceCreated;
 
-const SEMANTIC_INDEX_VERSION: usize = 7;
-const EMBEDDINGS_BATCH_SIZE: usize = 80;
+const SEMANTIC_INDEX_VERSION: usize = 9;
+const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
+const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);
 
 pub fn init(
     fs: Arc<dyn Fs>,
@@ -97,14 +101,11 @@ pub fn init(
 
 pub struct SemanticIndex {
     fs: Arc<dyn Fs>,
-    database_url: Arc<PathBuf>,
+    db: VectorDatabase,
     embedding_provider: Arc<dyn EmbeddingProvider>,
     language_registry: Arc<LanguageRegistry>,
-    db_update_tx: channel::Sender<DbOperation>,
-    parsing_files_tx: channel::Sender<PendingFile>,
-    _db_update_task: Task<()>,
-    _embed_batch_tasks: Vec<Task<()>>,
-    _batch_files_task: Task<()>,
+    parsing_files_tx: channel::Sender<(Arc<HashMap<DocumentDigest, Embedding>>, PendingFile)>,
+    _embedding_task: Task<()>,
     _parsing_files_tasks: Vec<Task<()>>,
     projects: HashMap<WeakModelHandle<Project>, ProjectState>,
 }
@@ -113,13 +114,18 @@ struct ProjectState {
     worktree_db_ids: Vec<(WorktreeId, i64)>,
     _subscription: gpui::Subscription,
     outstanding_job_count_rx: watch::Receiver<usize>,
-    _outstanding_job_count_tx: Arc<Mutex<watch::Sender<usize>>>,
-    job_queue_tx: channel::Sender<IndexOperation>,
-    _queue_update_task: Task<()>,
+    outstanding_job_count_tx: Arc<Mutex<watch::Sender<usize>>>,
+    changed_paths: BTreeMap<ProjectPath, ChangedPathInfo>,
+}
+
+struct ChangedPathInfo {
+    changed_at: Instant,
+    mtime: SystemTime,
+    is_deleted: bool,
 }
 
 #[derive(Clone)]
-struct JobHandle {
+pub struct JobHandle {
     /// The outer Arc is here to count the clones of a JobHandle instance;
     /// when the last handle to a given job is dropped, we decrement a counter (just once).
     tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
@@ -133,31 +139,21 @@ impl JobHandle {
         }
     }
 }
+
 impl ProjectState {
     fn new(
-        cx: &mut AppContext,
         subscription: gpui::Subscription,
         worktree_db_ids: Vec<(WorktreeId, i64)>,
-        outstanding_job_count_rx: watch::Receiver<usize>,
-        _outstanding_job_count_tx: Arc<Mutex<watch::Sender<usize>>>,
+        changed_paths: BTreeMap<ProjectPath, ChangedPathInfo>,
     ) -> Self {
-        let (job_queue_tx, job_queue_rx) = channel::unbounded();
-        let _queue_update_task = cx.background().spawn({
-            let mut worktree_queue = HashMap::new();
-            async move {
-                while let Ok(operation) = job_queue_rx.recv().await {
-                    Self::update_queue(&mut worktree_queue, operation);
-                }
-            }
-        });
-
+        let (outstanding_job_count_tx, outstanding_job_count_rx) = watch::channel_with(0);
+        let outstanding_job_count_tx = Arc::new(Mutex::new(outstanding_job_count_tx));
         Self {
             worktree_db_ids,
             outstanding_job_count_rx,
-            _outstanding_job_count_tx,
+            outstanding_job_count_tx,
+            changed_paths,
             _subscription: subscription,
-            _queue_update_task,
-            job_queue_tx,
         }
     }
 
@@ -165,41 +161,6 @@ impl ProjectState {
         self.outstanding_job_count_rx.borrow().clone()
     }
 
-    fn update_queue(queue: &mut HashMap<PathBuf, IndexOperation>, operation: IndexOperation) {
-        match operation {
-            IndexOperation::FlushQueue => {
-                let queue = std::mem::take(queue);
-                for (_, op) in queue {
-                    match op {
-                        IndexOperation::IndexFile {
-                            absolute_path: _,
-                            payload,
-                            tx,
-                        } => {
-                            let _ = tx.try_send(payload);
-                        }
-                        IndexOperation::DeleteFile {
-                            absolute_path: _,
-                            payload,
-                            tx,
-                        } => {
-                            let _ = tx.try_send(payload);
-                        }
-                        _ => {}
-                    }
-                }
-            }
-            IndexOperation::IndexFile {
-                ref absolute_path, ..
-            }
-            | IndexOperation::DeleteFile {
-                ref absolute_path, ..
-            } => {
-                queue.insert(absolute_path.clone(), operation);
-            }
-        }
-    }
-
     fn db_id_for_worktree_id(&self, id: WorktreeId) -> Option<i64> {
         self.worktree_db_ids
             .iter()
@@ -230,66 +191,16 @@ pub struct PendingFile {
     worktree_db_id: i64,
     relative_path: PathBuf,
     absolute_path: PathBuf,
-    language: Arc<Language>,
+    language: Option<Arc<Language>>,
     modified_time: SystemTime,
     job_handle: JobHandle,
 }
-enum IndexOperation {
-    IndexFile {
-        absolute_path: PathBuf,
-        payload: PendingFile,
-        tx: channel::Sender<PendingFile>,
-    },
-    DeleteFile {
-        absolute_path: PathBuf,
-        payload: DbOperation,
-        tx: channel::Sender<DbOperation>,
-    },
-    FlushQueue,
-}
 
 pub struct SearchResult {
     pub buffer: ModelHandle<Buffer>,
     pub range: Range<Anchor>,
 }
 
-enum DbOperation {
-    InsertFile {
-        worktree_id: i64,
-        documents: Vec<Document>,
-        path: PathBuf,
-        mtime: SystemTime,
-        job_handle: JobHandle,
-    },
-    Delete {
-        worktree_id: i64,
-        path: PathBuf,
-    },
-    FindOrCreateWorktree {
-        path: PathBuf,
-        sender: oneshot::Sender<Result<i64>>,
-    },
-    FileMTimes {
-        worktree_id: i64,
-        sender: oneshot::Sender<Result<HashMap<PathBuf, SystemTime>>>,
-    },
-    WorktreePreviouslyIndexed {
-        path: Arc<Path>,
-        sender: oneshot::Sender<Result<bool>>,
-    },
-}
-
-enum EmbeddingJob {
-    Enqueue {
-        worktree_id: i64,
-        path: PathBuf,
-        mtime: SystemTime,
-        documents: Vec<Document>,
-        job_handle: JobHandle,
-    },
-    Flush,
-}
-
 impl SemanticIndex {
     pub fn global(cx: &AppContext) -> Option<ModelHandle<SemanticIndex>> {
         if cx.has_global::<ModelHandle<Self>>() {
@@ -306,18 +217,14 @@ impl SemanticIndex {
 
     async fn new(
         fs: Arc<dyn Fs>,
-        database_url: PathBuf,
+        database_path: PathBuf,
         embedding_provider: Arc<dyn EmbeddingProvider>,
         language_registry: Arc<LanguageRegistry>,
         mut cx: AsyncAppContext,
     ) -> Result<ModelHandle<Self>> {
         let t0 = Instant::now();
-        let database_url = Arc::new(database_url);
-
-        let db = cx
-            .background()
-            .spawn(VectorDatabase::new(fs.clone(), database_url.clone()))
-            .await?;
+        let database_path = Arc::from(database_path);
+        let db = VectorDatabase::new(fs.clone(), database_path, cx.background()).await?;
 
         log::trace!(
             "db initialization took {:?} milliseconds",
@@ -326,73 +233,55 @@ impl SemanticIndex {
 
         Ok(cx.add_model(|cx| {
             let t0 = Instant::now();
-            // Perform database operations
-            let (db_update_tx, db_update_rx) = channel::unbounded();
-            let _db_update_task = cx.background().spawn({
+            let embedding_queue =
+                EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
+            let _embedding_task = cx.background().spawn({
+                let embedded_files = embedding_queue.finished_files();
+                let db = db.clone();
                 async move {
-                    while let Ok(job) = db_update_rx.recv().await {
-                        Self::run_db_operation(&db, job)
-                    }
-                }
-            });
-
-            // Group documents into batches and send them to the embedding provider.
-            let (embed_batch_tx, embed_batch_rx) =
-                channel::unbounded::<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>();
-            let mut _embed_batch_tasks = Vec::new();
-            for _ in 0..cx.background().num_cpus() {
-                let embed_batch_rx = embed_batch_rx.clone();
-                _embed_batch_tasks.push(cx.background().spawn({
-                    let db_update_tx = db_update_tx.clone();
-                    let embedding_provider = embedding_provider.clone();
-                    async move {
-                        while let Ok(embeddings_queue) = embed_batch_rx.recv().await {
-                            Self::compute_embeddings_for_batch(
-                                embeddings_queue,
-                                &embedding_provider,
-                                &db_update_tx,
-                            )
-                            .await;
-                        }
+                    while let Ok(file) = embedded_files.recv().await {
+                        db.insert_file(file.worktree_id, file.path, file.mtime, file.documents)
+                            .await
+                            .log_err();
                     }
-                }));
-            }
-
-            // Group documents into batches and send them to the embedding provider.
-            let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
-            let _batch_files_task = cx.background().spawn(async move {
-                let mut queue_len = 0;
-                let mut embeddings_queue = vec![];
-                while let Ok(job) = batch_files_rx.recv().await {
-                    Self::enqueue_documents_to_embed(
-                        job,
-                        &mut queue_len,
-                        &mut embeddings_queue,
-                        &embed_batch_tx,
-                    );
                 }
             });
 
             // Parse files into embeddable documents.
-            let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();
+            let (parsing_files_tx, parsing_files_rx) =
+                channel::unbounded::<(Arc<HashMap<DocumentDigest, Embedding>>, PendingFile)>();
+            let embedding_queue = Arc::new(Mutex::new(embedding_queue));
             let mut _parsing_files_tasks = Vec::new();
             for _ in 0..cx.background().num_cpus() {
                 let fs = fs.clone();
-                let parsing_files_rx = parsing_files_rx.clone();
-                let batch_files_tx = batch_files_tx.clone();
-                let db_update_tx = db_update_tx.clone();
+                let mut parsing_files_rx = parsing_files_rx.clone();
+                let embedding_provider = embedding_provider.clone();
+                let embedding_queue = embedding_queue.clone();
+                let background = cx.background().clone();
                 _parsing_files_tasks.push(cx.background().spawn(async move {
-                    let mut retriever = CodeContextRetriever::new();
-                    while let Ok(pending_file) = parsing_files_rx.recv().await {
-                        Self::parse_file(
-                            &fs,
-                            pending_file,
-                            &mut retriever,
-                            &batch_files_tx,
-                            &parsing_files_rx,
-                            &db_update_tx,
-                        )
-                        .await;
+                    let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
+                    loop {
+                        let mut timer = background.timer(EMBEDDING_QUEUE_FLUSH_TIMEOUT).fuse();
+                        let mut next_file_to_parse = parsing_files_rx.next().fuse();
+                        futures::select_biased! {
+                            next_file_to_parse = next_file_to_parse => {
+                                if let Some((embeddings_for_digest, pending_file)) = next_file_to_parse {
+                                    Self::parse_file(
+                                        &fs,
+                                        pending_file,
+                                        &mut retriever,
+                                        &embedding_queue,
+                                        &embeddings_for_digest,
+                                    )
+                                    .await
+                                } else {
+                                    break;
+                                }
+                            },
+                            _ = timer => {
+                                embedding_queue.lock().flush();
+                            }
+                        }
                     }
                 }));
             }
@@ -403,192 +292,31 @@ impl SemanticIndex {
             );
             Self {
                 fs,
-                database_url,
+                db,
                 embedding_provider,
                 language_registry,
-                db_update_tx,
                 parsing_files_tx,
-                _db_update_task,
-                _embed_batch_tasks,
-                _batch_files_task,
+                _embedding_task,
                 _parsing_files_tasks,
                 projects: HashMap::new(),
             }
         }))
     }
 
-    fn run_db_operation(db: &VectorDatabase, job: DbOperation) {
-        match job {
-            DbOperation::InsertFile {
-                worktree_id,
-                documents,
-                path,
-                mtime,
-                job_handle,
-            } => {
-                db.insert_file(worktree_id, path, mtime, documents)
-                    .log_err();
-                drop(job_handle)
-            }
-            DbOperation::Delete { worktree_id, path } => {
-                db.delete_file(worktree_id, path).log_err();
-            }
-            DbOperation::FindOrCreateWorktree { path, sender } => {
-                let id = db.find_or_create_worktree(&path);
-                sender.send(id).ok();
-            }
-            DbOperation::FileMTimes {
-                worktree_id: worktree_db_id,
-                sender,
-            } => {
-                let file_mtimes = db.get_file_mtimes(worktree_db_id);
-                sender.send(file_mtimes).ok();
-            }
-            DbOperation::WorktreePreviouslyIndexed { path, sender } => {
-                let worktree_indexed = db.worktree_previously_indexed(path.as_ref());
-                sender.send(worktree_indexed).ok();
-            }
-        }
-    }
-
-    async fn compute_embeddings_for_batch(
-        mut embeddings_queue: Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
-        embedding_provider: &Arc<dyn EmbeddingProvider>,
-        db_update_tx: &channel::Sender<DbOperation>,
-    ) {
-        let mut batch_documents = vec![];
-        for (_, documents, _, _, _) in embeddings_queue.iter() {
-            batch_documents.extend(documents.iter().map(|document| document.content.as_str()));
-        }
-
-        if let Ok(embeddings) = embedding_provider.embed_batch(batch_documents).await {
-            log::trace!(
-                "created {} embeddings for {} files",
-                embeddings.len(),
-                embeddings_queue.len(),
-            );
-
-            let mut i = 0;
-            let mut j = 0;
-
-            for embedding in embeddings.iter() {
-                while embeddings_queue[i].1.len() == j {
-                    i += 1;
-                    j = 0;
-                }
-
-                embeddings_queue[i].1[j].embedding = embedding.to_owned();
-                j += 1;
-            }
-
-            for (worktree_id, documents, path, mtime, job_handle) in embeddings_queue.into_iter() {
-                db_update_tx
-                    .send(DbOperation::InsertFile {
-                        worktree_id,
-                        documents,
-                        path,
-                        mtime,
-                        job_handle,
-                    })
-                    .await
-                    .unwrap();
-            }
-        } else {
-            // Insert the file in spite of failure so that future attempts to index it do not take place (unless the file is changed).
-            for (worktree_id, _, path, mtime, job_handle) in embeddings_queue.into_iter() {
-                db_update_tx
-                    .send(DbOperation::InsertFile {
-                        worktree_id,
-                        documents: vec![],
-                        path,
-                        mtime,
-                        job_handle,
-                    })
-                    .await
-                    .unwrap();
-            }
-        }
-    }
-
-    fn enqueue_documents_to_embed(
-        job: EmbeddingJob,
-        queue_len: &mut usize,
-        embeddings_queue: &mut Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
-        embed_batch_tx: &channel::Sender<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>,
-    ) {
-        // Handle edge case where individual file has more documents than max batch size
-        let should_flush = match job {
-            EmbeddingJob::Enqueue {
-                documents,
-                worktree_id,
-                path,
-                mtime,
-                job_handle,
-            } => {
-                // If documents is greater than embeddings batch size, recursively batch existing rows.
-                if &documents.len() > &EMBEDDINGS_BATCH_SIZE {
-                    let first_job = EmbeddingJob::Enqueue {
-                        documents: documents[..EMBEDDINGS_BATCH_SIZE].to_vec(),
-                        worktree_id,
-                        path: path.clone(),
-                        mtime,
-                        job_handle: job_handle.clone(),
-                    };
-
-                    Self::enqueue_documents_to_embed(
-                        first_job,
-                        queue_len,
-                        embeddings_queue,
-                        embed_batch_tx,
-                    );
-
-                    let second_job = EmbeddingJob::Enqueue {
-                        documents: documents[EMBEDDINGS_BATCH_SIZE..].to_vec(),
-                        worktree_id,
-                        path: path.clone(),
-                        mtime,
-                        job_handle: job_handle.clone(),
-                    };
-
-                    Self::enqueue_documents_to_embed(
-                        second_job,
-                        queue_len,
-                        embeddings_queue,
-                        embed_batch_tx,
-                    );
-                    return;
-                } else {
-                    *queue_len += &documents.len();
-                    embeddings_queue.push((worktree_id, documents, path, mtime, job_handle));
-                    *queue_len >= EMBEDDINGS_BATCH_SIZE
-                }
-            }
-            EmbeddingJob::Flush => true,
-        };
-
-        if should_flush {
-            embed_batch_tx
-                .try_send(mem::take(embeddings_queue))
-                .unwrap();
-            *queue_len = 0;
-        }
-    }
-
     async fn parse_file(
         fs: &Arc<dyn Fs>,
         pending_file: PendingFile,
         retriever: &mut CodeContextRetriever,
-        batch_files_tx: &channel::Sender<EmbeddingJob>,
-        parsing_files_rx: &channel::Receiver<PendingFile>,
-        db_update_tx: &channel::Sender<DbOperation>,
+        embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
+        embeddings_for_digest: &HashMap<DocumentDigest, Embedding>,
     ) {
+        let Some(language) = pending_file.language else {
+            return;
+        };
+
         if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
-            if let Some(documents) = retriever
-                .parse_file_with_template(
-                    &pending_file.relative_path,
-                    &content,
-                    pending_file.language,
-                )
+            if let Some(mut documents) = retriever
+                .parse_file_with_template(&pending_file.relative_path, &content, language)
                 .log_err()
             {
                 log::trace!(
@@ -597,66 +325,23 @@ impl SemanticIndex {
                     documents.len()
                 );
 
-                if documents.len() == 0 {
-                    db_update_tx
-                        .send(DbOperation::InsertFile {
-                            worktree_id: pending_file.worktree_db_id,
-                            documents,
-                            path: pending_file.relative_path,
-                            mtime: pending_file.modified_time,
-                            job_handle: pending_file.job_handle,
-                        })
-                        .await
-                        .unwrap();
-                } else {
-                    batch_files_tx
-                        .try_send(EmbeddingJob::Enqueue {
-                            worktree_id: pending_file.worktree_db_id,
-                            path: pending_file.relative_path,
-                            mtime: pending_file.modified_time,
-                            job_handle: pending_file.job_handle,
-                            documents,
-                        })
-                        .unwrap();
+                for document in documents.iter_mut() {
+                    if let Some(embedding) = embeddings_for_digest.get(&document.digest) {
+                        document.embedding = Some(embedding.to_owned());
+                    }
                 }
-            }
-        }
 
-        if parsing_files_rx.len() == 0 {
-            batch_files_tx.try_send(EmbeddingJob::Flush).unwrap();
+                embedding_queue.lock().push(FileToEmbed {
+                    worktree_id: pending_file.worktree_db_id,
+                    path: pending_file.relative_path,
+                    mtime: pending_file.modified_time,
+                    job_handle: pending_file.job_handle,
+                    documents,
+                });
+            }
         }
     }
 
-    fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
-        let (tx, rx) = oneshot::channel();
-        self.db_update_tx
-            .try_send(DbOperation::FindOrCreateWorktree { path, sender: tx })
-            .unwrap();
-        async move { rx.await? }
-    }
-
-    fn get_file_mtimes(
-        &self,
-        worktree_id: i64,
-    ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
-        let (tx, rx) = oneshot::channel();
-        self.db_update_tx
-            .try_send(DbOperation::FileMTimes {
-                worktree_id,
-                sender: tx,
-            })
-            .unwrap();
-        async move { rx.await? }
-    }
-
-    fn worktree_previously_indexed(&self, path: Arc<Path>) -> impl Future<Output = Result<bool>> {
-        let (tx, rx) = oneshot::channel();
-        self.db_update_tx
-            .try_send(DbOperation::WorktreePreviouslyIndexed { path, sender: tx })
-            .unwrap();
-        async move { rx.await? }
-    }
-
     pub fn project_previously_indexed(
         &mut self,
         project: ModelHandle<Project>,
@@ -665,7 +350,10 @@ impl SemanticIndex {
         let worktrees_indexed_previously = project
             .read(cx)
             .worktrees(cx)
-            .map(|worktree| self.worktree_previously_indexed(worktree.read(cx).abs_path()))
+            .map(|worktree| {
+                self.db
+                    .worktree_previously_indexed(&worktree.read(cx).abs_path())
+            })
             .collect::<Vec<_>>();
         cx.spawn(|_, _cx| async move {
             let worktree_indexed_previously =
@@ -679,103 +367,73 @@ impl SemanticIndex {
     }
 
     fn project_entries_changed(
-        &self,
+        &mut self,
         project: ModelHandle<Project>,
         changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
         cx: &mut ModelContext<'_, SemanticIndex>,
         worktree_id: &WorktreeId,
-    ) -> Result<()> {
-        let parsing_files_tx = self.parsing_files_tx.clone();
-        let db_update_tx = self.db_update_tx.clone();
-        let (job_queue_tx, outstanding_job_tx, worktree_db_id) = {
-            let state = self
-                .projects
-                .get(&project.downgrade())
-                .ok_or(anyhow!("Project not yet initialized"))?;
-            let worktree_db_id = state
-                .db_id_for_worktree_id(*worktree_id)
-                .ok_or(anyhow!("Worktree ID in Database Not Available"))?;
-            (
-                state.job_queue_tx.clone(),
-                state._outstanding_job_count_tx.clone(),
-                worktree_db_id,
-            )
+    ) {
+        let Some(worktree) = project.read(cx).worktree_for_id(worktree_id.clone(), cx) else {
+            return;
+        };
+        let project = project.downgrade();
+        let Some(project_state) = self.projects.get_mut(&project) else {
+            return;
         };
 
-        let language_registry = self.language_registry.clone();
-        let parsing_files_tx = parsing_files_tx.clone();
-        let db_update_tx = db_update_tx.clone();
-
-        let worktree = project
-            .read(cx)
-            .worktree_for_id(worktree_id.clone(), cx)
-            .ok_or(anyhow!("Worktree not available"))?
-            .read(cx)
-            .snapshot();
-        cx.spawn(|_, _| async move {
-            let worktree = worktree.clone();
-            for (path, entry_id, path_change) in changes.iter() {
-                let relative_path = path.to_path_buf();
-                let absolute_path = worktree.absolutize(path);
-
-                let Some(entry) = worktree.entry_for_id(*entry_id) else {
-                    continue;
-                };
-                if entry.is_ignored || entry.is_symlink || entry.is_external {
-                    continue;
+        let embeddings_for_digest = {
+            let mut worktree_id_file_paths = HashMap::new();
+            for (path, _) in &project_state.changed_paths {
+                if let Some(worktree_db_id) = project_state.db_id_for_worktree_id(path.worktree_id)
+                {
+                    worktree_id_file_paths
+                        .entry(worktree_db_id)
+                        .or_insert(Vec::new())
+                        .push(path.path.clone());
                 }
+            }
+            self.db.embeddings_for_files(worktree_id_file_paths)
+        };
 
-                log::trace!("File Event: {:?}, Path: {:?}", &path_change, &path);
-                match path_change {
-                    PathChange::AddedOrUpdated | PathChange::Updated | PathChange::Added => {
-                        if let Ok(language) = language_registry
-                            .language_for_file(&relative_path, None)
-                            .await
-                        {
-                            if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
-                                && &language.name().as_ref() != &"Markdown"
-                                && language
-                                    .grammar()
-                                    .and_then(|grammar| grammar.embedding_config.as_ref())
-                                    .is_none()
-                            {
-                                continue;
-                            }
+        let worktree = worktree.read(cx);
+        let change_time = Instant::now();
+        for (path, entry_id, change) in changes.iter() {
+            let Some(entry) = worktree.entry_for_id(*entry_id) else {
+                continue;
+            };
+            if entry.is_ignored || entry.is_symlink || entry.is_external {
+                continue;
+            }
+            let project_path = ProjectPath {
+                worktree_id: *worktree_id,
+                path: path.clone(),
+            };
+            project_state.changed_paths.insert(
+                project_path,
+                ChangedPathInfo {
+                    changed_at: change_time,
+                    mtime: entry.mtime,
+                    is_deleted: *change == PathChange::Removed,
+                },
+            );
+        }
 
-                            let job_handle = JobHandle::new(&outstanding_job_tx);
-                            let new_operation = IndexOperation::IndexFile {
-                                absolute_path: absolute_path.clone(),
-                                payload: PendingFile {
-                                    worktree_db_id,
-                                    relative_path,
-                                    absolute_path,
-                                    language,
-                                    modified_time: entry.mtime,
-                                    job_handle,
-                                },
-                                tx: parsing_files_tx.clone(),
-                            };
-                            let _ = job_queue_tx.try_send(new_operation);
-                        }
-                    }
-                    PathChange::Removed => {
-                        let new_operation = IndexOperation::DeleteFile {
-                            absolute_path,
-                            payload: DbOperation::Delete {
-                                worktree_id: worktree_db_id,
-                                path: relative_path,
-                            },
-                            tx: db_update_tx.clone(),
-                        };
-                        let _ = job_queue_tx.try_send(new_operation);
-                    }
-                    _ => {}
-                }
+        cx.spawn_weak(|this, mut cx| async move {
+            let embeddings_for_digest = embeddings_for_digest.await.log_err().unwrap_or_default();
+
+            cx.background().timer(BACKGROUND_INDEXING_DELAY).await;
+            if let Some((this, project)) = this.upgrade(&cx).zip(project.upgrade(&cx)) {
+                Self::reindex_changed_paths(
+                    this,
+                    project,
+                    Some(change_time),
+                    &mut cx,
+                    Arc::new(embeddings_for_digest),
+                )
+                .await;
             }
         })
         .detach();
-
-        Ok(())
     }
 
     pub fn initialize_project(
@@ -799,20 +457,18 @@ impl SemanticIndex {
             .read(cx)
             .worktrees(cx)
             .map(|worktree| {
-                self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
+                self.db
+                    .find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
             })
             .collect::<Vec<_>>();
 
         let _subscription = cx.subscribe(&project, |this, project, event, cx| {
             if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event {
-                let _ =
-                    this.project_entries_changed(project.clone(), changes.clone(), cx, worktree_id);
+                this.project_entries_changed(project.clone(), changes.clone(), cx, worktree_id);
             };
         });
 
         let language_registry = self.language_registry.clone();
-        let parsing_files_tx = self.parsing_files_tx.clone();
-        let db_update_tx = self.db_update_tx.clone();
 
         cx.spawn(|this, mut cx| async move {
             futures::future::join_all(worktree_scans_complete).await;
@@ -833,7 +489,7 @@ impl SemanticIndex {
                 db_ids_by_worktree_id.insert(worktree.id(), db_id);
                 worktree_file_mtimes.insert(
                     worktree.id(),
-                    this.read_with(&cx, |this, _| this.get_file_mtimes(db_id))
+                    this.read_with(&cx, |this, _| this.db.get_file_mtimes(db_id))
                         .await?,
                 );
             }
@@ -843,17 +499,13 @@ impl SemanticIndex {
                 .map(|(a, b)| (*a, *b))
                 .collect();
 
-            let (job_count_tx, job_count_rx) = watch::channel_with(0);
-            let job_count_tx = Arc::new(Mutex::new(job_count_tx));
-            let job_count_tx_longlived = job_count_tx.clone();
-
-            let worktree_files = cx
+            let changed_paths = cx
                 .background()
                 .spawn(async move {
-                    let mut worktree_files = Vec::new();
+                    let mut changed_paths = BTreeMap::new();
+                    let now = Instant::now();
                     for worktree in worktrees.into_iter() {
                         let mut file_mtimes = worktree_file_mtimes.remove(&worktree.id()).unwrap();
-                        let worktree_db_id = db_ids_by_worktree_id[&worktree.id()];
                         for file in worktree.files(false, 0) {
                             let absolute_path = worktree.absolutize(&file.path);
 
@@ -876,59 +528,51 @@ impl SemanticIndex {
                                     continue;
                                 }
 
-                                let path_buf = file.path.to_path_buf();
                                 let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
                                 let already_stored = stored_mtime
                                     .map_or(false, |existing_mtime| existing_mtime == file.mtime);
 
                                 if !already_stored {
-                                    let job_handle = JobHandle::new(&job_count_tx);
-                                    worktree_files.push(IndexOperation::IndexFile {
-                                        absolute_path: absolute_path.clone(),
-                                        payload: PendingFile {
-                                            worktree_db_id,
-                                            relative_path: path_buf,
-                                            absolute_path,
-                                            language,
-                                            job_handle,
-                                            modified_time: file.mtime,
+                                    changed_paths.insert(
+                                        ProjectPath {
+                                            worktree_id: worktree.id(),
+                                            path: file.path.clone(),
+                                        },
+                                        ChangedPathInfo {
+                                            changed_at: now,
+                                            mtime: file.mtime,
+                                            is_deleted: false,
                                         },
-                                        tx: parsing_files_tx.clone(),
-                                    });
+                                    );
                                 }
                             }
                         }
+
                         // Clean up entries from database that are no longer in the worktree.
-                        for (path, _) in file_mtimes {
-                            worktree_files.push(IndexOperation::DeleteFile {
-                                absolute_path: worktree.absolutize(path.as_path()),
-                                payload: DbOperation::Delete {
-                                    worktree_id: worktree_db_id,
-                                    path,
+                        for (path, mtime) in file_mtimes {
+                            changed_paths.insert(
+                                ProjectPath {
+                                    worktree_id: worktree.id(),
+                                    path: path.into(),
                                 },
-                                tx: db_update_tx.clone(),
-                            });
+                                ChangedPathInfo {
+                                    changed_at: now,
+                                    mtime,
+                                    is_deleted: true,
+                                },
+                            );
                         }
                     }
 
-                    anyhow::Ok(worktree_files)
+                    anyhow::Ok(changed_paths)
                 })
                 .await?;
 
-            this.update(&mut cx, |this, cx| {
-                let project_state = ProjectState::new(
-                    cx,
-                    _subscription,
-                    worktree_db_ids,
-                    job_count_rx,
-                    job_count_tx_longlived,
+            this.update(&mut cx, |this, _| {
+                this.projects.insert(
+                    project.downgrade(),
+                    ProjectState::new(_subscription, worktree_db_ids, changed_paths),
                 );
-
-                for op in worktree_files {
-                    let _ = project_state.job_queue_tx.try_send(op);
-                }
-
-                this.projects.insert(project.downgrade(), project_state);
             });
             Result::<(), _>::Ok(())
         })
@@ -939,27 +583,45 @@ impl SemanticIndex {
         project: ModelHandle<Project>,
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<(usize, watch::Receiver<usize>)>> {
-        let state = self.projects.get_mut(&project.downgrade());
-        let state = if state.is_none() {
-            return Task::Ready(Some(Err(anyhow!("Project not yet initialized"))));
-        } else {
-            state.unwrap()
-        };
+        cx.spawn(|this, mut cx| async move {
+            let embeddings_for_digest = this.read_with(&cx, |this, _| {
+                if let Some(state) = this.projects.get(&project.downgrade()) {
+                    let mut worktree_id_file_paths = HashMap::default();
+                    for (path, _) in &state.changed_paths {
+                        if let Some(worktree_db_id) = state.db_id_for_worktree_id(path.worktree_id)
+                        {
+                            worktree_id_file_paths
+                                .entry(worktree_db_id)
+                                .or_insert(Vec::new())
+                                .push(path.path.clone());
+                        }
+                    }
+
+                    Ok(this.db.embeddings_for_files(worktree_id_file_paths))
+                } else {
+                    Err(anyhow!("Project not yet initialized"))
+                }
+            })?;
 
-        // let parsing_files_tx = self.parsing_files_tx.clone();
-        // let db_update_tx = self.db_update_tx.clone();
-        let job_count_rx = state.outstanding_job_count_rx.clone();
-        let count = state.get_outstanding_count();
+            let embeddings_for_digest = Arc::new(embeddings_for_digest.await?);
 
-        cx.spawn(|this, mut cx| async move {
-            this.update(&mut cx, |this, _| {
-                let Some(state) = this.projects.get_mut(&project.downgrade()) else {
-                    return;
-                };
-                let _ = state.job_queue_tx.try_send(IndexOperation::FlushQueue);
-            });
+            Self::reindex_changed_paths(
+                this.clone(),
+                project.clone(),
+                None,
+                &mut cx,
+                embeddings_for_digest,
+            )
+            .await;
 
-            Ok((count, job_count_rx))
+            this.update(&mut cx, |this, _cx| {
+                let Some(state) = this.projects.get(&project.downgrade()) else {
+                    return Err(anyhow!("Project not yet initialized"));
+                };
+                let job_count_rx = state.outstanding_job_count_rx.clone();
+                let count = state.get_outstanding_count();
+                Ok((count, job_count_rx))
+            })
         })
     }
 

crates/semantic_index/src/semantic_index_tests.rs 🔗

@@ -1,14 +1,15 @@
 use crate::{
-    db::dot,
-    embedding::EmbeddingProvider,
-    parsing::{subtract_ranges, CodeContextRetriever, Document},
+    embedding::{DummyEmbeddings, Embedding, EmbeddingProvider},
+    embedding_queue::EmbeddingQueue,
+    parsing::{subtract_ranges, CodeContextRetriever, Document, DocumentDigest},
     semantic_index_settings::SemanticIndexSettings,
-    SearchResult, SemanticIndex,
+    FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
 };
 use anyhow::Result;
 use async_trait::async_trait;
-use gpui::{Task, TestAppContext};
+use gpui::{executor::Deterministic, Task, TestAppContext};
 use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
+use parking_lot::Mutex;
 use pretty_assertions::assert_eq;
 use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs, Project};
 use rand::{rngs::StdRng, Rng};
@@ -20,8 +21,10 @@ use std::{
         atomic::{self, AtomicUsize},
         Arc,
     },
+    time::SystemTime,
 };
 use unindent::Unindent;
+use util::RandomCharIter;
 
 #[ctor::ctor]
 fn init_logger() {
@@ -31,12 +34,8 @@ fn init_logger() {
 }
 
 #[gpui::test]
-async fn test_semantic_index(cx: &mut TestAppContext) {
-    cx.update(|cx| {
-        cx.set_global(SettingsStore::test(cx));
-        settings::register::<SemanticIndexSettings>(cx);
-        settings::register::<ProjectSettings>(cx);
-    });
+async fn test_semantic_index(deterministic: Arc<Deterministic>, cx: &mut TestAppContext) {
+    init_test(cx);
 
     let fs = FakeFs::new(cx.background());
     fs.insert_tree(
@@ -56,6 +55,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
                     fn bbb() {
                         println!(\"bbbbbbbbbbbbb!\");
                     }
+                    struct pqpqpqp {}
                 ".unindent(),
                 "file3.toml": "
                     ZZZZZZZZZZZZZZZZZZ = 5
@@ -75,7 +75,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
     let db_path = db_dir.path().join("db.sqlite");
 
     let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
-    let store = SemanticIndex::new(
+    let semantic_index = SemanticIndex::new(
         fs.clone(),
         db_path,
         embedding_provider.clone(),
@@ -87,21 +87,21 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
 
     let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
 
-    let _ = store
+    let _ = semantic_index
         .update(cx, |store, cx| {
             store.initialize_project(project.clone(), cx)
         })
         .await;
 
-    let (file_count, outstanding_file_count) = store
+    let (file_count, outstanding_file_count) = semantic_index
         .update(cx, |store, cx| store.index_project(project.clone(), cx))
         .await
         .unwrap();
     assert_eq!(file_count, 3);
-    cx.foreground().run_until_parked();
+    deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
     assert_eq!(*outstanding_file_count.borrow(), 0);
 
-    let search_results = store
+    let search_results = semantic_index
         .update(cx, |store, cx| {
             store.search_project(
                 project.clone(),
@@ -122,6 +122,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
             (Path::new("src/file2.rs").into(), 0),
             (Path::new("src/file3.toml").into(), 0),
             (Path::new("src/file1.rs").into(), 45),
+            (Path::new("src/file2.rs").into(), 45),
         ],
         cx,
     );
@@ -129,7 +130,7 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
     // Test Include Files Functonality
     let include_files = vec![PathMatcher::new("*.rs").unwrap()];
     let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
-    let rust_only_search_results = store
+    let rust_only_search_results = semantic_index
         .update(cx, |store, cx| {
             store.search_project(
                 project.clone(),
@@ -149,11 +150,12 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
             (Path::new("src/file1.rs").into(), 0),
             (Path::new("src/file2.rs").into(), 0),
             (Path::new("src/file1.rs").into(), 45),
+            (Path::new("src/file2.rs").into(), 45),
         ],
         cx,
     );
 
-    let no_rust_search_results = store
+    let no_rust_search_results = semantic_index
         .update(cx, |store, cx| {
             store.search_project(
                 project.clone(),
@@ -186,24 +188,87 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
     .await
     .unwrap();
 
-    cx.foreground().run_until_parked();
+    deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
 
     let prev_embedding_count = embedding_provider.embedding_count();
-    let (file_count, outstanding_file_count) = store
+    let (file_count, outstanding_file_count) = semantic_index
         .update(cx, |store, cx| store.index_project(project.clone(), cx))
         .await
         .unwrap();
     assert_eq!(file_count, 1);
 
-    cx.foreground().run_until_parked();
+    deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
     assert_eq!(*outstanding_file_count.borrow(), 0);
 
     assert_eq!(
         embedding_provider.embedding_count() - prev_embedding_count,
-        2
+        1
     );
 }
 
+#[gpui::test(iterations = 10)]
+async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
+    let (outstanding_job_count, _) = postage::watch::channel_with(0);
+    let outstanding_job_count = Arc::new(Mutex::new(outstanding_job_count));
+
+    let files = (1..=3)
+        .map(|file_ix| FileToEmbed {
+            worktree_id: 5,
+            path: format!("path-{file_ix}").into(),
+            mtime: SystemTime::now(),
+            documents: (0..rng.gen_range(4..22))
+                .map(|document_ix| {
+                    let content_len = rng.gen_range(10..100);
+                    let content = RandomCharIter::new(&mut rng)
+                        .with_simple_text()
+                        .take(content_len)
+                        .collect::<String>();
+                    let digest = DocumentDigest::from(content.as_str());
+                    Document {
+                        range: 0..10,
+                        embedding: None,
+                        name: format!("document {document_ix}"),
+                        content,
+                        digest,
+                        token_count: rng.gen_range(10..30),
+                    }
+                })
+                .collect(),
+            job_handle: JobHandle::new(&outstanding_job_count),
+        })
+        .collect::<Vec<_>>();
+
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
+
+    let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
+    for file in &files {
+        queue.push(file.clone());
+    }
+    queue.flush();
+
+    cx.foreground().run_until_parked();
+    let finished_files = queue.finished_files();
+    let mut embedded_files: Vec<_> = files
+        .iter()
+        .map(|_| finished_files.try_recv().expect("no finished file"))
+        .collect();
+
+    let expected_files: Vec<_> = files
+        .iter()
+        .map(|file| {
+            let mut file = file.clone();
+            for doc in &mut file.documents {
+                doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref()));
+            }
+            file
+        })
+        .collect();
+
+    embedded_files.sort_by_key(|f| f.path.clone());
+
+    assert_eq!(embedded_files, expected_files);
+}
+
 #[track_caller]
 fn assert_search_results(
     actual: &[SearchResult],
@@ -227,7 +292,8 @@ fn assert_search_results(
 #[gpui::test]
 async fn test_code_context_retrieval_rust() {
     let language = rust_lang();
-    let mut retriever = CodeContextRetriever::new();
+    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = "
         /// A doc comment
@@ -314,7 +380,8 @@ async fn test_code_context_retrieval_rust() {
 #[gpui::test]
 async fn test_code_context_retrieval_json() {
     let language = json_lang();
-    let mut retriever = CodeContextRetriever::new();
+    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
         {
@@ -397,7 +464,8 @@ fn assert_documents_eq(
 #[gpui::test]
 async fn test_code_context_retrieval_javascript() {
     let language = js_lang();
-    let mut retriever = CodeContextRetriever::new();
+    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = "
         /* globals importScripts, backend */
@@ -495,7 +563,8 @@ async fn test_code_context_retrieval_javascript() {
 #[gpui::test]
 async fn test_code_context_retrieval_lua() {
     let language = lua_lang();
-    let mut retriever = CodeContextRetriever::new();
+    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
         -- Creates a new class
@@ -568,7 +637,8 @@ async fn test_code_context_retrieval_lua() {
 #[gpui::test]
 async fn test_code_context_retrieval_elixir() {
     let language = elixir_lang();
-    let mut retriever = CodeContextRetriever::new();
+    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
         defmodule File.Stream do
@@ -684,7 +754,8 @@ async fn test_code_context_retrieval_elixir() {
 #[gpui::test]
 async fn test_code_context_retrieval_cpp() {
     let language = cpp_lang();
-    let mut retriever = CodeContextRetriever::new();
+    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = "
     /**
@@ -836,7 +907,8 @@ async fn test_code_context_retrieval_cpp() {
 #[gpui::test]
 async fn test_code_context_retrieval_ruby() {
     let language = ruby_lang();
-    let mut retriever = CodeContextRetriever::new();
+    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
         # This concern is inspired by "sudo mode" on GitHub. It
@@ -1026,7 +1098,8 @@ async fn test_code_context_retrieval_ruby() {
 #[gpui::test]
 async fn test_code_context_retrieval_php() {
     let language = php_lang();
-    let mut retriever = CodeContextRetriever::new();
+    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
         <?php
@@ -1173,36 +1246,6 @@ async fn test_code_context_retrieval_php() {
     );
 }
 
-#[gpui::test]
-fn test_dot_product(mut rng: StdRng) {
-    assert_eq!(dot(&[1., 0., 0., 0., 0.], &[0., 1., 0., 0., 0.]), 0.);
-    assert_eq!(dot(&[2., 0., 0., 0., 0.], &[3., 1., 0., 0., 0.]), 6.);
-
-    for _ in 0..100 {
-        let size = 1536;
-        let mut a = vec![0.; size];
-        let mut b = vec![0.; size];
-        for (a, b) in a.iter_mut().zip(b.iter_mut()) {
-            *a = rng.gen();
-            *b = rng.gen();
-        }
-
-        assert_eq!(
-            round_to_decimals(dot(&a, &b), 1),
-            round_to_decimals(reference_dot(&a, &b), 1)
-        );
-    }
-
-    fn round_to_decimals(n: f32, decimal_places: i32) -> f32 {
-        let factor = (10.0 as f32).powi(decimal_places);
-        (n * factor).round() / factor
-    }
-
-    fn reference_dot(a: &[f32], b: &[f32]) -> f32 {
-        a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()
-    }
-}
-
 #[derive(Default)]
 struct FakeEmbeddingProvider {
     embedding_count: AtomicUsize,
@@ -1212,35 +1255,42 @@ impl FakeEmbeddingProvider {
     fn embedding_count(&self) -> usize {
         self.embedding_count.load(atomic::Ordering::SeqCst)
     }
+
+    fn embed_sync(&self, span: &str) -> Embedding {
+        let mut result = vec![1.0; 26];
+        for letter in span.chars() {
+            let letter = letter.to_ascii_lowercase();
+            if letter as u32 >= 'a' as u32 {
+                let ix = (letter as u32) - ('a' as u32);
+                if ix < 26 {
+                    result[ix as usize] += 1.0;
+                }
+            }
+        }
+
+        let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
+        for x in &mut result {
+            *x /= norm;
+        }
+
+        result.into()
+    }
 }
 
 #[async_trait]
 impl EmbeddingProvider for FakeEmbeddingProvider {
-    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
-        self.embedding_count
-            .fetch_add(spans.len(), atomic::Ordering::SeqCst);
-        Ok(spans
-            .iter()
-            .map(|span| {
-                let mut result = vec![1.0; 26];
-                for letter in span.chars() {
-                    let letter = letter.to_ascii_lowercase();
-                    if letter as u32 >= 'a' as u32 {
-                        let ix = (letter as u32) - ('a' as u32);
-                        if ix < 26 {
-                            result[ix as usize] += 1.0;
-                        }
-                    }
-                }
+    fn truncate(&self, span: &str) -> (String, usize) {
+        (span.to_string(), 1)
+    }
 
-                let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
-                for x in &mut result {
-                    *x /= norm;
-                }
+    fn max_tokens_per_batch(&self) -> usize {
+        200
+    }
 
-                result
-            })
-            .collect())
+    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
+        self.embedding_count
+            .fetch_add(spans.len(), atomic::Ordering::SeqCst);
+        Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
     }
 }
 
@@ -1684,3 +1734,11 @@ fn test_subtract_ranges() {
 
     assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
 }
+
+fn init_test(cx: &mut TestAppContext) {
+    cx.update(|cx| {
+        cx.set_global(SettingsStore::test(cx));
+        settings::register::<SemanticIndexSettings>(cx);
+        settings::register::<ProjectSettings>(cx);
+    });
+}

crates/util/src/util.rs 🔗

@@ -260,11 +260,22 @@ pub fn defer<F: FnOnce()>(f: F) -> impl Drop {
     Defer(Some(f))
 }
 
-pub struct RandomCharIter<T: Rng>(T);
+pub struct RandomCharIter<T: Rng> {
+    rng: T,
+    simple_text: bool,
+}
 
 impl<T: Rng> RandomCharIter<T> {
     pub fn new(rng: T) -> Self {
-        Self(rng)
+        Self {
+            rng,
+            simple_text: std::env::var("SIMPLE_TEXT").map_or(false, |v| !v.is_empty()),
+        }
+    }
+
+    pub fn with_simple_text(mut self) -> Self {
+        self.simple_text = true;
+        self
     }
 }
 
@@ -272,25 +283,27 @@ impl<T: Rng> Iterator for RandomCharIter<T> {
     type Item = char;
 
     fn next(&mut self) -> Option<Self::Item> {
-        if std::env::var("SIMPLE_TEXT").map_or(false, |v| !v.is_empty()) {
-            return if self.0.gen_range(0..100) < 5 {
+        if self.simple_text {
+            return if self.rng.gen_range(0..100) < 5 {
                 Some('\n')
             } else {
-                Some(self.0.gen_range(b'a'..b'z' + 1).into())
+                Some(self.rng.gen_range(b'a'..b'z' + 1).into())
             };
         }
 
-        match self.0.gen_range(0..100) {
+        match self.rng.gen_range(0..100) {
             // whitespace
-            0..=19 => [' ', '\n', '\r', '\t'].choose(&mut self.0).copied(),
+            0..=19 => [' ', '\n', '\r', '\t'].choose(&mut self.rng).copied(),
             // two-byte greek letters
-            20..=32 => char::from_u32(self.0.gen_range(('α' as u32)..('ω' as u32 + 1))),
+            20..=32 => char::from_u32(self.rng.gen_range(('α' as u32)..('ω' as u32 + 1))),
             // // three-byte characters
-            33..=45 => ['✋', '✅', '❌', '❎', '⭐'].choose(&mut self.0).copied(),
+            33..=45 => ['✋', '✅', '❌', '❎', '⭐']
+                .choose(&mut self.rng)
+                .copied(),
             // // four-byte characters
-            46..=58 => ['🍐', '🏀', '🍗', '🎉'].choose(&mut self.0).copied(),
+            46..=58 => ['🍐', '🏀', '🍗', '🎉'].choose(&mut self.rng).copied(),
             // ascii letters
-            _ => Some(self.0.gen_range(b'a'..b'z' + 1).into()),
+            _ => Some(self.rng.gen_range(b'a'..b'z' + 1).into()),
         }
     }
 }