Abstract away how database transactions are executed

Antonio Scandurra and Kyle Caverly created

Co-Authored-By: Kyle Caverly <kyle@zed.dev>

Change summary

crates/semantic_index/src/db.rs             | 614 ++++++++++++----------
crates/semantic_index/src/semantic_index.rs | 199 +-----
2 files changed, 389 insertions(+), 424 deletions(-)

Detailed changes

crates/semantic_index/src/db.rs 🔗

@@ -1,5 +1,7 @@
 use crate::{parsing::Document, SEMANTIC_INDEX_VERSION};
 use anyhow::{anyhow, Context, Result};
+use futures::channel::oneshot;
+use gpui::executor;
 use project::{search::PathMatcher, Fs};
 use rpc::proto::Timestamp;
 use rusqlite::{
@@ -9,12 +11,14 @@ use rusqlite::{
 use std::{
     cmp::Ordering,
     collections::HashMap,
+    future::Future,
     ops::Range,
     path::{Path, PathBuf},
     rc::Rc,
     sync::Arc,
     time::SystemTime,
 };
+use util::TryFutureExt;
 
 #[derive(Debug)]
 pub struct FileRecord {
@@ -51,117 +55,161 @@ impl FromSql for Sha1 {
     }
 }
 
+#[derive(Clone)]
 pub struct VectorDatabase {
-    db: rusqlite::Connection,
+    path: Arc<Path>,
+    transactions: smol::channel::Sender<Box<dyn 'static + Send + FnOnce(&rusqlite::Connection)>>,
 }
 
 impl VectorDatabase {
-    pub async fn new(fs: Arc<dyn Fs>, path: Arc<PathBuf>) -> Result<Self> {
+    pub async fn new(
+        fs: Arc<dyn Fs>,
+        path: Arc<Path>,
+        executor: Arc<executor::Background>,
+    ) -> Result<Self> {
         if let Some(db_directory) = path.parent() {
             fs.create_dir(db_directory).await?;
         }
 
+        let (transactions_tx, transactions_rx) =
+            smol::channel::unbounded::<Box<dyn 'static + Send + FnOnce(&rusqlite::Connection)>>();
+        executor
+            .spawn({
+                let path = path.clone();
+                async move {
+                    let connection = rusqlite::Connection::open(&path)?;
+                    while let Ok(transaction) = transactions_rx.recv().await {
+                        transaction(&connection);
+                    }
+
+                    anyhow::Ok(())
+                }
+                .log_err()
+            })
+            .detach();
         let this = Self {
-            db: rusqlite::Connection::open(path.as_path())?,
+            transactions: transactions_tx,
+            path,
         };
-        this.initialize_database()?;
+        this.initialize_database().await?;
         Ok(this)
     }
 
-    fn get_existing_version(&self) -> Result<i64> {
-        let mut version_query = self
-            .db
-            .prepare("SELECT version from semantic_index_config")?;
-        version_query
-            .query_row([], |row| Ok(row.get::<_, i64>(0)?))
-            .map_err(|err| anyhow!("version query failed: {err}"))
+    pub fn path(&self) -> &Arc<Path> {
+        &self.path
     }
 
-    fn initialize_database(&self) -> Result<()> {
-        rusqlite::vtab::array::load_module(&self.db)?;
-
-        // Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped
-        if self
-            .get_existing_version()
-            .map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64)
-        {
-            log::trace!("vector database schema up to date");
-            return Ok(());
+    fn transact<F, T>(&self, transaction: F) -> impl Future<Output = Result<T>>
+    where
+        F: 'static + Send + FnOnce(&rusqlite::Connection) -> Result<T>,
+        T: 'static + Send,
+    {
+        let (tx, rx) = oneshot::channel();
+        let transactions = self.transactions.clone();
+        async move {
+            if transactions
+                .send(Box::new(|connection| {
+                    let result = transaction(connection);
+                    let _ = tx.send(result);
+                }))
+                .await
+                .is_err()
+            {
+                return Err(anyhow!("connection was dropped"))?;
+            }
+            rx.await?
         }
+    }
 
-        log::trace!("vector database schema out of date. updating...");
-        self.db
-            .execute("DROP TABLE IF EXISTS documents", [])
-            .context("failed to drop 'documents' table")?;
-        self.db
-            .execute("DROP TABLE IF EXISTS files", [])
-            .context("failed to drop 'files' table")?;
-        self.db
-            .execute("DROP TABLE IF EXISTS worktrees", [])
-            .context("failed to drop 'worktrees' table")?;
-        self.db
-            .execute("DROP TABLE IF EXISTS semantic_index_config", [])
-            .context("failed to drop 'semantic_index_config' table")?;
-
-        // Initialize Vector Databasing Tables
-        self.db.execute(
-            "CREATE TABLE semantic_index_config (
-                version INTEGER NOT NULL
-            )",
-            [],
-        )?;
+    fn initialize_database(&self) -> impl Future<Output = Result<()>> {
+        self.transact(|db| {
+            rusqlite::vtab::array::load_module(&db)?;
+
+            // Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped
+            let version_query = db.prepare("SELECT version from semantic_index_config");
+            let version = version_query
+                .and_then(|mut query| query.query_row([], |row| Ok(row.get::<_, i64>(0)?)));
+            if version.map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64) {
+                log::trace!("vector database schema up to date");
+                return Ok(());
+            }
 
-        self.db.execute(
-            "INSERT INTO semantic_index_config (version) VALUES (?1)",
-            params![SEMANTIC_INDEX_VERSION],
-        )?;
+            log::trace!("vector database schema out of date. updating...");
+            db.execute("DROP TABLE IF EXISTS documents", [])
+                .context("failed to drop 'documents' table")?;
+            db.execute("DROP TABLE IF EXISTS files", [])
+                .context("failed to drop 'files' table")?;
+            db.execute("DROP TABLE IF EXISTS worktrees", [])
+                .context("failed to drop 'worktrees' table")?;
+            db.execute("DROP TABLE IF EXISTS semantic_index_config", [])
+                .context("failed to drop 'semantic_index_config' table")?;
+
+            // Initialize Vector Databasing Tables
+            db.execute(
+                "CREATE TABLE semantic_index_config (
+                    version INTEGER NOT NULL
+                )",
+                [],
+            )?;
 
-        self.db.execute(
-            "CREATE TABLE worktrees (
-                id INTEGER PRIMARY KEY AUTOINCREMENT,
-                absolute_path VARCHAR NOT NULL
-            );
-            CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
-            ",
-            [],
-        )?;
+            db.execute(
+                "INSERT INTO semantic_index_config (version) VALUES (?1)",
+                params![SEMANTIC_INDEX_VERSION],
+            )?;
 
-        self.db.execute(
-            "CREATE TABLE files (
-                id INTEGER PRIMARY KEY AUTOINCREMENT,
-                worktree_id INTEGER NOT NULL,
-                relative_path VARCHAR NOT NULL,
-                mtime_seconds INTEGER NOT NULL,
-                mtime_nanos INTEGER NOT NULL,
-                FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
-            )",
-            [],
-        )?;
+            db.execute(
+                "CREATE TABLE worktrees (
+                    id INTEGER PRIMARY KEY AUTOINCREMENT,
+                    absolute_path VARCHAR NOT NULL
+                );
+                CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
+                ",
+                [],
+            )?;
 
-        self.db.execute(
-            "CREATE TABLE documents (
-                id INTEGER PRIMARY KEY AUTOINCREMENT,
-                file_id INTEGER NOT NULL,
-                start_byte INTEGER NOT NULL,
-                end_byte INTEGER NOT NULL,
-                name VARCHAR NOT NULL,
-                embedding BLOB NOT NULL,
-                sha1 BLOB NOT NULL,
-                FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
-            )",
-            [],
-        )?;
+            db.execute(
+                "CREATE TABLE files (
+                    id INTEGER PRIMARY KEY AUTOINCREMENT,
+                    worktree_id INTEGER NOT NULL,
+                    relative_path VARCHAR NOT NULL,
+                    mtime_seconds INTEGER NOT NULL,
+                    mtime_nanos INTEGER NOT NULL,
+                    FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
+                )",
+                [],
+            )?;
 
-        log::trace!("vector database initialized with updated schema.");
-        Ok(())
+            db.execute(
+                "CREATE TABLE documents (
+                    id INTEGER PRIMARY KEY AUTOINCREMENT,
+                    file_id INTEGER NOT NULL,
+                    start_byte INTEGER NOT NULL,
+                    end_byte INTEGER NOT NULL,
+                    name VARCHAR NOT NULL,
+                    embedding BLOB NOT NULL,
+                    sha1 BLOB NOT NULL,
+                    FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
+                )",
+                [],
+            )?;
+
+            log::trace!("vector database initialized with updated schema.");
+            Ok(())
+        })
     }
 
-    pub fn delete_file(&self, worktree_id: i64, delete_path: PathBuf) -> Result<()> {
-        self.db.execute(
-            "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
-            params![worktree_id, delete_path.to_str()],
-        )?;
-        Ok(())
+    pub fn delete_file(
+        &self,
+        worktree_id: i64,
+        delete_path: PathBuf,
+    ) -> impl Future<Output = Result<()>> {
+        self.transact(move |db| {
+            db.execute(
+                "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
+                params![worktree_id, delete_path.to_str()],
+            )?;
+            Ok(())
+        })
     }
 
     pub fn insert_file(
@@ -170,117 +218,126 @@ impl VectorDatabase {
         path: PathBuf,
         mtime: SystemTime,
         documents: Vec<Document>,
-    ) -> Result<()> {
-        // Return the existing ID, if both the file and mtime match
-        let mtime = Timestamp::from(mtime);
-        let mut existing_id_query = self.db.prepare("SELECT id FROM files WHERE worktree_id = ?1 AND relative_path = ?2 AND mtime_seconds = ?3 AND mtime_nanos = ?4")?;
-        let existing_id = existing_id_query
-            .query_row(
-                params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
-                |row| Ok(row.get::<_, i64>(0)?),
-            )
-            .map_err(|err| anyhow!(err));
-        let file_id = if existing_id.is_ok() {
-            // If already exists, just return the existing id
-            existing_id.unwrap()
-        } else {
-            // Delete Existing Row
-            self.db.execute(
-                "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;",
-                params![worktree_id, path.to_str()],
-            )?;
-            self.db.execute("INSERT INTO files (worktree_id, relative_path, mtime_seconds, mtime_nanos) VALUES (?1, ?2, ?3, ?4);", params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos])?;
-            self.db.last_insert_rowid()
-        };
-
-        // Currently inserting at approximately 3400 documents a second
-        // I imagine we can speed this up with a bulk insert of some kind.
-        for document in documents {
-            let embedding_blob = bincode::serialize(&document.embedding)?;
-            let sha_blob = bincode::serialize(&document.sha1)?;
-
-            self.db.execute(
-                "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding, sha1) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
-                params![
-                    file_id,
-                    document.range.start.to_string(),
-                    document.range.end.to_string(),
-                    document.name,
-                    embedding_blob,
-                    sha_blob
-                ],
-            )?;
-        }
+    ) -> impl Future<Output = Result<()>> {
+        self.transact(move |db| {
+            // Return the existing ID, if both the file and mtime match
+            let mtime = Timestamp::from(mtime);
+
+            let mut existing_id_query = db.prepare("SELECT id FROM files WHERE worktree_id = ?1 AND relative_path = ?2 AND mtime_seconds = ?3 AND mtime_nanos = ?4")?;
+            let existing_id = existing_id_query
+                .query_row(
+                    params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
+                    |row| Ok(row.get::<_, i64>(0)?),
+                );
+
+            let file_id = if existing_id.is_ok() {
+                // If already exists, just return the existing id
+                existing_id?
+            } else {
+                // Delete Existing Row
+                db.execute(
+                    "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;",
+                    params![worktree_id, path.to_str()],
+                )?;
+                db.execute("INSERT INTO files (worktree_id, relative_path, mtime_seconds, mtime_nanos) VALUES (?1, ?2, ?3, ?4);", params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos])?;
+                db.last_insert_rowid()
+            };
 
-        Ok(())
+            // Currently inserting at approximately 3400 documents a second
+            // I imagine we can speed this up with a bulk insert of some kind.
+            for document in documents {
+                let embedding_blob = bincode::serialize(&document.embedding)?;
+                let sha_blob = bincode::serialize(&document.sha1)?;
+
+                db.execute(
+                    "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding, sha1) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
+                    params![
+                        file_id,
+                        document.range.start.to_string(),
+                        document.range.end.to_string(),
+                        document.name,
+                        embedding_blob,
+                        sha_blob
+                    ],
+                )?;
+           }
+
+           Ok(())
+        })
     }
 
-    pub fn worktree_previously_indexed(&self, worktree_root_path: &Path) -> Result<bool> {
-        let mut worktree_query = self
-            .db
-            .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
-        let worktree_id = worktree_query
-            .query_row(params![worktree_root_path.to_string_lossy()], |row| {
-                Ok(row.get::<_, i64>(0)?)
-            })
-            .map_err(|err| anyhow!(err));
-
-        if worktree_id.is_ok() {
-            return Ok(true);
-        } else {
-            return Ok(false);
-        }
+    pub fn worktree_previously_indexed(
+        &self,
+        worktree_root_path: &Path,
+    ) -> impl Future<Output = Result<bool>> {
+        let worktree_root_path = worktree_root_path.to_string_lossy().into_owned();
+        self.transact(move |db| {
+            let mut worktree_query =
+                db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
+            let worktree_id = worktree_query
+                .query_row(params![worktree_root_path], |row| Ok(row.get::<_, i64>(0)?));
+
+            if worktree_id.is_ok() {
+                return Ok(true);
+            } else {
+                return Ok(false);
+            }
+        })
     }
 
-    pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
-        // Check that the absolute path doesnt exist
-        let mut worktree_query = self
-            .db
-            .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
-
-        let worktree_id = worktree_query
-            .query_row(params![worktree_root_path.to_string_lossy()], |row| {
-                Ok(row.get::<_, i64>(0)?)
-            })
-            .map_err(|err| anyhow!(err));
-
-        if worktree_id.is_ok() {
-            return worktree_id;
-        }
+    pub fn find_or_create_worktree(
+        &self,
+        worktree_root_path: PathBuf,
+    ) -> impl Future<Output = Result<i64>> {
+        self.transact(move |db| {
+            let mut worktree_query =
+                db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
+            let worktree_id = worktree_query
+                .query_row(params![worktree_root_path.to_string_lossy()], |row| {
+                    Ok(row.get::<_, i64>(0)?)
+                });
+
+            if worktree_id.is_ok() {
+                return Ok(worktree_id?);
+            }
 
-        // If worktree_id is Err, insert new worktree
-        self.db.execute(
-            "
-            INSERT into worktrees (absolute_path) VALUES (?1)
-            ",
-            params![worktree_root_path.to_string_lossy()],
-        )?;
-        Ok(self.db.last_insert_rowid())
+            // If worktree_id is Err, insert new worktree
+            db.execute(
+                "INSERT into worktrees (absolute_path) VALUES (?1)",
+                params![worktree_root_path.to_string_lossy()],
+            )?;
+            Ok(db.last_insert_rowid())
+        })
     }
 
-    pub fn get_file_mtimes(&self, worktree_id: i64) -> Result<HashMap<PathBuf, SystemTime>> {
-        let mut statement = self.db.prepare(
-            "
-            SELECT relative_path, mtime_seconds, mtime_nanos
-            FROM files
-            WHERE worktree_id = ?1
-            ORDER BY relative_path",
-        )?;
-        let mut result: HashMap<PathBuf, SystemTime> = HashMap::new();
-        for row in statement.query_map(params![worktree_id], |row| {
-            Ok((
-                row.get::<_, String>(0)?.into(),
-                Timestamp {
-                    seconds: row.get(1)?,
-                    nanos: row.get(2)?,
-                }
-                .into(),
-            ))
-        })? {
-            let row = row?;
-            result.insert(row.0, row.1);
-        }
-        Ok(result)
+    pub fn get_file_mtimes(
+        &self,
+        worktree_id: i64,
+    ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
+        self.transact(move |db| {
+            let mut statement = db.prepare(
+                "
+                SELECT relative_path, mtime_seconds, mtime_nanos
+                FROM files
+                WHERE worktree_id = ?1
+                ORDER BY relative_path",
+            )?;
+            let mut result: HashMap<PathBuf, SystemTime> = HashMap::new();
+            for row in statement.query_map(params![worktree_id], |row| {
+                Ok((
+                    row.get::<_, String>(0)?.into(),
+                    Timestamp {
+                        seconds: row.get(1)?,
+                        nanos: row.get(2)?,
+                    }
+                    .into(),
+                ))
+            })? {
+                let row = row?;
+                result.insert(row.0, row.1);
+            }
+            Ok(result)
+        })
     }
 
     pub fn top_k_search(
@@ -288,21 +345,25 @@ impl VectorDatabase {
         query_embedding: &Vec<f32>,
         limit: usize,
         file_ids: &[i64],
-    ) -> Result<Vec<(i64, f32)>> {
-        let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
-        self.for_each_document(file_ids, |id, embedding| {
-            let similarity = dot(&embedding, &query_embedding);
-            let ix = match results
-                .binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal))
-            {
-                Ok(ix) => ix,
-                Err(ix) => ix,
-            };
-            results.insert(ix, (id, similarity));
-            results.truncate(limit);
-        })?;
-
-        Ok(results)
+    ) -> impl Future<Output = Result<Vec<(i64, f32)>>> {
+        let query_embedding = query_embedding.clone();
+        let file_ids = file_ids.to_vec();
+        self.transact(move |db| {
+            let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
+            Self::for_each_document(db, &file_ids, |id, embedding| {
+                let similarity = dot(&embedding, &query_embedding);
+                let ix = match results.binary_search_by(|(_, s)| {
+                    similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
+                }) {
+                    Ok(ix) => ix,
+                    Err(ix) => ix,
+                };
+                results.insert(ix, (id, similarity));
+                results.truncate(limit);
+            })?;
+
+            anyhow::Ok(results)
+        })
     }
 
     pub fn retrieve_included_file_ids(
@@ -310,37 +371,46 @@ impl VectorDatabase {
         worktree_ids: &[i64],
         includes: &[PathMatcher],
         excludes: &[PathMatcher],
-    ) -> Result<Vec<i64>> {
-        let mut file_query = self.db.prepare(
-            "
-            SELECT
-                id, relative_path
-            FROM
-                files
-            WHERE
-                worktree_id IN rarray(?)
-            ",
-        )?;
+    ) -> impl Future<Output = Result<Vec<i64>>> {
+        let worktree_ids = worktree_ids.to_vec();
+        let includes = includes.to_vec();
+        let excludes = excludes.to_vec();
+        self.transact(move |db| {
+            let mut file_query = db.prepare(
+                "
+                SELECT
+                    id, relative_path
+                FROM
+                    files
+                WHERE
+                    worktree_id IN rarray(?)
+                ",
+            )?;
 
-        let mut file_ids = Vec::<i64>::new();
-        let mut rows = file_query.query([ids_to_sql(worktree_ids)])?;
-
-        while let Some(row) = rows.next()? {
-            let file_id = row.get(0)?;
-            let relative_path = row.get_ref(1)?.as_str()?;
-            let included =
-                includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path));
-            let excluded = excludes.iter().any(|glob| glob.is_match(relative_path));
-            if included && !excluded {
-                file_ids.push(file_id);
+            let mut file_ids = Vec::<i64>::new();
+            let mut rows = file_query.query([ids_to_sql(&worktree_ids)])?;
+
+            while let Some(row) = rows.next()? {
+                let file_id = row.get(0)?;
+                let relative_path = row.get_ref(1)?.as_str()?;
+                let included =
+                    includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path));
+                let excluded = excludes.iter().any(|glob| glob.is_match(relative_path));
+                if included && !excluded {
+                    file_ids.push(file_id);
+                }
             }
-        }
 
-        Ok(file_ids)
+            anyhow::Ok(file_ids)
+        })
     }
 
-    fn for_each_document(&self, file_ids: &[i64], mut f: impl FnMut(i64, Vec<f32>)) -> Result<()> {
-        let mut query_statement = self.db.prepare(
+    fn for_each_document(
+        db: &rusqlite::Connection,
+        file_ids: &[i64],
+        mut f: impl FnMut(i64, Vec<f32>),
+    ) -> Result<()> {
+        let mut query_statement = db.prepare(
             "
             SELECT
                 id, embedding
@@ -360,47 +430,53 @@ impl VectorDatabase {
         Ok(())
     }
 
-    pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
-        let mut statement = self.db.prepare(
-            "
-                SELECT
-                    documents.id,
-                    files.worktree_id,
-                    files.relative_path,
-                    documents.start_byte,
-                    documents.end_byte
-                FROM
-                    documents, files
-                WHERE
-                    documents.file_id = files.id AND
-                    documents.id in rarray(?)
-            ",
-        )?;
+    pub fn get_documents_by_ids(
+        &self,
+        ids: &[i64],
+    ) -> impl Future<Output = Result<Vec<(i64, PathBuf, Range<usize>)>>> {
+        let ids = ids.to_vec();
+        self.transact(move |db| {
+            let mut statement = db.prepare(
+                "
+                    SELECT
+                        documents.id,
+                        files.worktree_id,
+                        files.relative_path,
+                        documents.start_byte,
+                        documents.end_byte
+                    FROM
+                        documents, files
+                    WHERE
+                        documents.file_id = files.id AND
+                        documents.id in rarray(?)
+                ",
+            )?;
 
-        let result_iter = statement.query_map(params![ids_to_sql(ids)], |row| {
-            Ok((
-                row.get::<_, i64>(0)?,
-                row.get::<_, i64>(1)?,
-                row.get::<_, String>(2)?.into(),
-                row.get(3)?..row.get(4)?,
-            ))
-        })?;
-
-        let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
-        for row in result_iter {
-            let (id, worktree_id, path, range) = row?;
-            values_by_id.insert(id, (worktree_id, path, range));
-        }
+            let result_iter = statement.query_map(params![ids_to_sql(&ids)], |row| {
+                Ok((
+                    row.get::<_, i64>(0)?,
+                    row.get::<_, i64>(1)?,
+                    row.get::<_, String>(2)?.into(),
+                    row.get(3)?..row.get(4)?,
+                ))
+            })?;
+
+            let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
+            for row in result_iter {
+                let (id, worktree_id, path, range) = row?;
+                values_by_id.insert(id, (worktree_id, path, range));
+            }
 
-        let mut results = Vec::with_capacity(ids.len());
-        for id in ids {
-            let value = values_by_id
-                .remove(id)
-                .ok_or(anyhow!("missing document id {}", id))?;
-            results.push(value);
-        }
+            let mut results = Vec::with_capacity(ids.len());
+            for id in &ids {
+                let value = values_by_id
+                    .remove(id)
+                    .ok_or(anyhow!("missing document id {}", id))?;
+                results.push(value);
+            }
 
-        Ok(results)
+            Ok(results)
+        })
     }
 }
 

crates/semantic_index/src/semantic_index.rs 🔗

@@ -12,11 +12,10 @@ use anyhow::{anyhow, Result};
 use db::VectorDatabase;
 use embedding::{EmbeddingProvider, OpenAIEmbeddings};
 use embedding_queue::{EmbeddingQueue, FileToEmbed};
-use futures::{channel::oneshot, Future};
 use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
 use language::{Anchor, Buffer, Language, LanguageRegistry};
 use parking_lot::Mutex;
-use parsing::{CodeContextRetriever, Document, PARSEABLE_ENTIRE_FILE_TYPES};
+use parsing::{CodeContextRetriever, PARSEABLE_ENTIRE_FILE_TYPES};
 use postage::watch;
 use project::{
     search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, ProjectPath, Worktree, WorktreeId,
@@ -101,13 +100,11 @@ pub fn init(
 
 pub struct SemanticIndex {
     fs: Arc<dyn Fs>,
-    database_url: Arc<PathBuf>,
+    db: VectorDatabase,
     embedding_provider: Arc<dyn EmbeddingProvider>,
     language_registry: Arc<LanguageRegistry>,
-    db_update_tx: channel::Sender<DbOperation>,
     parsing_files_tx: channel::Sender<PendingFile>,
     _embedding_task: Task<()>,
-    _db_update_task: Task<()>,
     _parsing_files_tasks: Vec<Task<()>>,
     projects: HashMap<WeakModelHandle<Project>, ProjectState>,
 }
@@ -203,32 +200,6 @@ pub struct SearchResult {
     pub range: Range<Anchor>,
 }
 
-enum DbOperation {
-    InsertFile {
-        worktree_id: i64,
-        documents: Vec<Document>,
-        path: PathBuf,
-        mtime: SystemTime,
-        job_handle: JobHandle,
-    },
-    Delete {
-        worktree_id: i64,
-        path: PathBuf,
-    },
-    FindOrCreateWorktree {
-        path: PathBuf,
-        sender: oneshot::Sender<Result<i64>>,
-    },
-    FileMTimes {
-        worktree_id: i64,
-        sender: oneshot::Sender<Result<HashMap<PathBuf, SystemTime>>>,
-    },
-    WorktreePreviouslyIndexed {
-        path: Arc<Path>,
-        sender: oneshot::Sender<Result<bool>>,
-    },
-}
-
 impl SemanticIndex {
     pub fn global(cx: &AppContext) -> Option<ModelHandle<SemanticIndex>> {
         if cx.has_global::<ModelHandle<Self>>() {
@@ -245,18 +216,14 @@ impl SemanticIndex {
 
     async fn new(
         fs: Arc<dyn Fs>,
-        database_url: PathBuf,
+        database_path: PathBuf,
         embedding_provider: Arc<dyn EmbeddingProvider>,
         language_registry: Arc<LanguageRegistry>,
         mut cx: AsyncAppContext,
     ) -> Result<ModelHandle<Self>> {
         let t0 = Instant::now();
-        let database_url = Arc::new(database_url);
-
-        let db = cx
-            .background()
-            .spawn(VectorDatabase::new(fs.clone(), database_url.clone()))
-            .await?;
+        let database_path = Arc::from(database_path);
+        let db = VectorDatabase::new(fs.clone(), database_path, cx.background()).await?;
 
         log::trace!(
             "db initialization took {:?} milliseconds",
@@ -265,32 +232,16 @@ impl SemanticIndex {
 
         Ok(cx.add_model(|cx| {
             let t0 = Instant::now();
-            // Perform database operations
-            let (db_update_tx, db_update_rx) = channel::unbounded();
-            let _db_update_task = cx.background().spawn({
-                async move {
-                    while let Ok(job) = db_update_rx.recv().await {
-                        Self::run_db_operation(&db, job)
-                    }
-                }
-            });
-
             let embedding_queue =
                 EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
             let _embedding_task = cx.background().spawn({
                 let embedded_files = embedding_queue.finished_files();
-                let db_update_tx = db_update_tx.clone();
+                let db = db.clone();
                 async move {
                     while let Ok(file) = embedded_files.recv().await {
-                        db_update_tx
-                            .try_send(DbOperation::InsertFile {
-                                worktree_id: file.worktree_id,
-                                documents: file.documents,
-                                path: file.path,
-                                mtime: file.mtime,
-                                job_handle: file.job_handle,
-                            })
-                            .ok();
+                        db.insert_file(file.worktree_id, file.path, file.mtime, file.documents)
+                            .await
+                            .log_err();
                     }
                 }
             });
@@ -325,12 +276,10 @@ impl SemanticIndex {
             );
             Self {
                 fs,
-                database_url,
+                db,
                 embedding_provider,
                 language_registry,
-                db_update_tx,
                 parsing_files_tx,
-                _db_update_task,
                 _embedding_task,
                 _parsing_files_tasks,
                 projects: HashMap::new(),
@@ -338,40 +287,6 @@ impl SemanticIndex {
         }))
     }
 
-    fn run_db_operation(db: &VectorDatabase, job: DbOperation) {
-        match job {
-            DbOperation::InsertFile {
-                worktree_id,
-                documents,
-                path,
-                mtime,
-                job_handle,
-            } => {
-                db.insert_file(worktree_id, path, mtime, documents)
-                    .log_err();
-                drop(job_handle)
-            }
-            DbOperation::Delete { worktree_id, path } => {
-                db.delete_file(worktree_id, path).log_err();
-            }
-            DbOperation::FindOrCreateWorktree { path, sender } => {
-                let id = db.find_or_create_worktree(&path);
-                sender.send(id).ok();
-            }
-            DbOperation::FileMTimes {
-                worktree_id: worktree_db_id,
-                sender,
-            } => {
-                let file_mtimes = db.get_file_mtimes(worktree_db_id);
-                sender.send(file_mtimes).ok();
-            }
-            DbOperation::WorktreePreviouslyIndexed { path, sender } => {
-                let worktree_indexed = db.worktree_previously_indexed(path.as_ref());
-                sender.send(worktree_indexed).ok();
-            }
-        }
-    }
-
     async fn parse_file(
         fs: &Arc<dyn Fs>,
         pending_file: PendingFile,
@@ -409,36 +324,6 @@ impl SemanticIndex {
         }
     }
 
-    fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
-        let (tx, rx) = oneshot::channel();
-        self.db_update_tx
-            .try_send(DbOperation::FindOrCreateWorktree { path, sender: tx })
-            .unwrap();
-        async move { rx.await? }
-    }
-
-    fn get_file_mtimes(
-        &self,
-        worktree_id: i64,
-    ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
-        let (tx, rx) = oneshot::channel();
-        self.db_update_tx
-            .try_send(DbOperation::FileMTimes {
-                worktree_id,
-                sender: tx,
-            })
-            .unwrap();
-        async move { rx.await? }
-    }
-
-    fn worktree_previously_indexed(&self, path: Arc<Path>) -> impl Future<Output = Result<bool>> {
-        let (tx, rx) = oneshot::channel();
-        self.db_update_tx
-            .try_send(DbOperation::WorktreePreviouslyIndexed { path, sender: tx })
-            .unwrap();
-        async move { rx.await? }
-    }
-
     pub fn project_previously_indexed(
         &mut self,
         project: ModelHandle<Project>,
@@ -447,7 +332,10 @@ impl SemanticIndex {
         let worktrees_indexed_previously = project
             .read(cx)
             .worktrees(cx)
-            .map(|worktree| self.worktree_previously_indexed(worktree.read(cx).abs_path()))
+            .map(|worktree| {
+                self.db
+                    .worktree_previously_indexed(&worktree.read(cx).abs_path())
+            })
             .collect::<Vec<_>>();
         cx.spawn(|_, _cx| async move {
             let worktree_indexed_previously =
@@ -528,7 +416,8 @@ impl SemanticIndex {
             .read(cx)
             .worktrees(cx)
             .map(|worktree| {
-                self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
+                self.db
+                    .find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
             })
             .collect::<Vec<_>>();
 
@@ -559,7 +448,7 @@ impl SemanticIndex {
                 db_ids_by_worktree_id.insert(worktree.id(), db_id);
                 worktree_file_mtimes.insert(
                     worktree.id(),
-                    this.read_with(&cx, |this, _| this.get_file_mtimes(db_id))
+                    this.read_with(&cx, |this, _| this.db.get_file_mtimes(db_id))
                         .await?,
                 );
             }
@@ -704,11 +593,12 @@ impl SemanticIndex {
             .collect::<Vec<_>>();
 
         let embedding_provider = self.embedding_provider.clone();
-        let database_url = self.database_url.clone();
+        let db_path = self.db.path().clone();
         let fs = self.fs.clone();
         cx.spawn(|this, mut cx| async move {
             let t0 = Instant::now();
-            let database = VectorDatabase::new(fs.clone(), database_url.clone()).await?;
+            let database =
+                VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?;
 
             let phrase_embedding = embedding_provider
                 .embed_batch(vec![phrase])
@@ -722,8 +612,9 @@ impl SemanticIndex {
                 t0.elapsed().as_millis()
             );
 
-            let file_ids =
-                database.retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)?;
+            let file_ids = database
+                .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
+                .await?;
 
             let batch_n = cx.background().num_cpus();
             let ids_len = file_ids.clone().len();
@@ -733,27 +624,24 @@ impl SemanticIndex {
                 ids_len / batch_n
             };
 
-            let mut result_tasks = Vec::new();
+            let mut batch_results = Vec::new();
             for batch in file_ids.chunks(batch_size) {
                 let batch = batch.into_iter().map(|v| *v).collect::<Vec<i64>>();
                 let limit = limit.clone();
                 let fs = fs.clone();
-                let database_url = database_url.clone();
+                let db_path = db_path.clone();
                 let phrase_embedding = phrase_embedding.clone();
-                let task = cx.background().spawn(async move {
-                    let database = VectorDatabase::new(fs, database_url).await.log_err();
-                    if database.is_none() {
-                        return Err(anyhow!("failed to acquire database connection"));
-                    } else {
-                        database
-                            .unwrap()
-                            .top_k_search(&phrase_embedding, limit, batch.as_slice())
-                    }
-                });
-                result_tasks.push(task);
+                if let Some(db) = VectorDatabase::new(fs, db_path.clone(), cx.background())
+                    .await
+                    .log_err()
+                {
+                    batch_results.push(async move {
+                        db.top_k_search(&phrase_embedding, limit, batch.as_slice())
+                            .await
+                    });
+                }
             }
-
-            let batch_results = futures::future::join_all(result_tasks).await;
+            let batch_results = futures::future::join_all(batch_results).await;
 
             let mut results = Vec::new();
             for batch_result in batch_results {
@@ -772,7 +660,7 @@ impl SemanticIndex {
             }
 
             let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<i64>>();
-            let documents = database.get_documents_by_ids(ids.as_slice())?;
+            let documents = database.get_documents_by_ids(ids.as_slice()).await?;
 
             let mut tasks = Vec::new();
             let mut ranges = Vec::new();
@@ -822,7 +710,8 @@ impl SemanticIndex {
         cx: &mut AsyncAppContext,
     ) {
         let mut pending_files = Vec::new();
-        let (language_registry, parsing_files_tx) = this.update(cx, |this, cx| {
+        let mut files_to_delete = Vec::new();
+        let (db, language_registry, parsing_files_tx) = this.update(cx, |this, cx| {
             if let Some(project_state) = this.projects.get_mut(&project.downgrade()) {
                 let outstanding_job_count_tx = &project_state.outstanding_job_count_tx;
                 let db_ids = &project_state.worktree_db_ids;
@@ -853,12 +742,7 @@ impl SemanticIndex {
                     };
 
                     if info.is_deleted {
-                        this.db_update_tx
-                            .try_send(DbOperation::Delete {
-                                worktree_id: worktree_db_id,
-                                path: path.path.to_path_buf(),
-                            })
-                            .ok();
+                        files_to_delete.push((worktree_db_id, path.path.to_path_buf()));
                     } else {
                         let absolute_path = worktree.read(cx).absolutize(&path.path);
                         let job_handle = JobHandle::new(&outstanding_job_count_tx);
@@ -877,11 +761,16 @@ impl SemanticIndex {
             }
 
             (
+                this.db.clone(),
                 this.language_registry.clone(),
                 this.parsing_files_tx.clone(),
             )
         });
 
+        for (worktree_db_id, path) in files_to_delete {
+            db.delete_file(worktree_db_id, path).await.log_err();
+        }
+
         for mut pending_file in pending_files {
             if let Ok(language) = language_registry
                 .language_for_file(&pending_file.relative_path, None)