db.rs

use crate::{parsing::Document, SEMANTIC_INDEX_VERSION};
use anyhow::{anyhow, Context, Result};
use futures::channel::oneshot;
use gpui::executor;
use project::{search::PathMatcher, Fs};
use rpc::proto::Timestamp;
use rusqlite::{
    params,
    types::{FromSql, FromSqlResult, ValueRef},
};
use std::{
    cmp::Ordering,
    collections::HashMap,
    future::Future,
    ops::Range,
    path::{Path, PathBuf},
    rc::Rc,
    sync::Arc,
    time::SystemTime,
};
use util::TryFutureExt;

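/// A file row as read from the `files` table.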
#[derive(Debug)]
pub struct FileRecord {
    pub id: usize,
    pub relative_path: String,
    pub mtime: Timestamp,
}

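/// An embedding vector, stored in the `documents` table as a bincode-encoded blob.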
#[derive(Debug)]
struct Embedding(pub Vec<f32>);

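/// A SHA-1 digest, stored in the `documents` table as a bincode-encoded blob.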
#[derive(Debug)]
struct Sha1(pub Vec<u8>);

impl FromSql for Embedding {
    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
        let bytes = value.as_blob()?;
        let embedding: Vec<f32> = bincode::deserialize(bytes)
            .map_err(|err| rusqlite::types::FromSqlError::Other(err))?;
        Ok(Embedding(embedding))
    }
}

impl FromSql for Sha1 {
    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
        let bytes = value.as_blob()?;
        let sha1: Vec<u8> = bincode::deserialize(bytes)
            .map_err(|err| rusqlite::types::FromSqlError::Other(err))?;
        Ok(Sha1(sha1))
    }
}

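/// A handle to the SQLite-backed vector store.
///
/// The handle is cheap to clone: all statements run on a single background task that
/// owns the `rusqlite::Connection`. Callers enqueue closures over a channel and receive
/// their results through oneshot channels, which is why the query methods return futures.
///
/// A rough usage sketch (the `fs`, `path`, `executor`, and `root` values are assumed to
/// be supplied by the caller):
///
/// ```ignore
/// let db = VectorDatabase::new(fs, path, executor).await?;
/// let worktree_id = db.find_or_create_worktree(root).await?;
/// let mtimes = db.get_file_mtimes(worktree_id).await?;
/// ```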
#[derive(Clone)]
pub struct VectorDatabase {
    path: Arc<Path>,
    transactions: smol::channel::Sender<Box<dyn 'static + Send + FnOnce(&rusqlite::Connection)>>,
}

impl VectorDatabase {
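    /// Opens (creating it if necessary) the database file at `path` and spawns a
    /// background task on `executor` that owns the connection and runs queued
    /// transactions.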
    pub async fn new(
        fs: Arc<dyn Fs>,
        path: Arc<Path>,
        executor: Arc<executor::Background>,
    ) -> Result<Self> {
        if let Some(db_directory) = path.parent() {
            fs.create_dir(db_directory).await?;
        }

        let (transactions_tx, transactions_rx) =
            smol::channel::unbounded::<Box<dyn 'static + Send + FnOnce(&rusqlite::Connection)>>();
        executor
            .spawn({
                let path = path.clone();
                async move {
                    let connection = rusqlite::Connection::open(&path)?;
                    while let Ok(transaction) = transactions_rx.recv().await {
                        transaction(&connection);
                    }

                    anyhow::Ok(())
                }
                .log_err()
            })
            .detach();
        let this = Self {
            transactions: transactions_tx,
            path,
        };
        this.initialize_database().await?;
        Ok(this)
    }

    pub fn path(&self) -> &Arc<Path> {
        &self.path
    }

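    /// Sends `transaction` to the background connection task and resolves with its
    /// result, or with an error if the background task has shut down.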
    fn transact<F, T>(&self, transaction: F) -> impl Future<Output = Result<T>>
    where
        F: 'static + Send + FnOnce(&rusqlite::Connection) -> Result<T>,
        T: 'static + Send,
    {
        let (tx, rx) = oneshot::channel();
        let transactions = self.transactions.clone();
        async move {
            if transactions
                .send(Box::new(|connection| {
                    let result = transaction(connection);
                    let _ = tx.send(result);
                }))
                .await
                .is_err()
            {
                return Err(anyhow!("connection was dropped"));
            }
            rx.await?
        }
    }

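    /// Creates the schema, dropping and recreating every table whenever the stored
    /// version differs from `SEMANTIC_INDEX_VERSION`.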
    fn initialize_database(&self) -> impl Future<Output = Result<()>> {
        self.transact(|db| {
            rusqlite::vtab::array::load_module(&db)?;

            // Delete the existing tables if SEMANTIC_INDEX_VERSION has been bumped
            let version_query = db.prepare("SELECT version from semantic_index_config");
            let version = version_query
                .and_then(|mut query| query.query_row([], |row| Ok(row.get::<_, i64>(0)?)));
            if version.map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64) {
                log::trace!("vector database schema up to date");
                return Ok(());
            }

            log::trace!("vector database schema out of date. updating...");
            db.execute("DROP TABLE IF EXISTS documents", [])
                .context("failed to drop 'documents' table")?;
            db.execute("DROP TABLE IF EXISTS files", [])
                .context("failed to drop 'files' table")?;
            db.execute("DROP TABLE IF EXISTS worktrees", [])
                .context("failed to drop 'worktrees' table")?;
            db.execute("DROP TABLE IF EXISTS semantic_index_config", [])
                .context("failed to drop 'semantic_index_config' table")?;

            // Initialize the vector database tables
            db.execute(
                "CREATE TABLE semantic_index_config (
                    version INTEGER NOT NULL
                )",
                [],
            )?;

            db.execute(
                "INSERT INTO semantic_index_config (version) VALUES (?1)",
                params![SEMANTIC_INDEX_VERSION],
            )?;

            db.execute(
                "CREATE TABLE worktrees (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    absolute_path VARCHAR NOT NULL
                );
                CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
                ",
                [],
            )?;

            db.execute(
                "CREATE TABLE files (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    worktree_id INTEGER NOT NULL,
                    relative_path VARCHAR NOT NULL,
                    mtime_seconds INTEGER NOT NULL,
                    mtime_nanos INTEGER NOT NULL,
                    FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
                )",
                [],
            )?;

            db.execute(
                "CREATE TABLE documents (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    file_id INTEGER NOT NULL,
                    start_byte INTEGER NOT NULL,
                    end_byte INTEGER NOT NULL,
                    name VARCHAR NOT NULL,
                    embedding BLOB NOT NULL,
                    sha1 BLOB NOT NULL,
                    FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
                )",
                [],
            )?;

            log::trace!("vector database initialized with updated schema.");
            Ok(())
        })
    }

    pub fn delete_file(
        &self,
        worktree_id: i64,
        delete_path: PathBuf,
    ) -> impl Future<Output = Result<()>> {
        self.transact(move |db| {
            db.execute(
                "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
                params![worktree_id, delete_path.to_str()],
            )?;
            Ok(())
        })
    }

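    /// Inserts a row for the file at `path`, reusing the existing row when its mtime is
    /// unchanged and replacing it otherwise, then records each of `documents` against it.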
    pub fn insert_file(
        &self,
        worktree_id: i64,
        path: PathBuf,
        mtime: SystemTime,
        documents: Vec<Document>,
    ) -> impl Future<Output = Result<()>> {
        self.transact(move |db| {
            // Reuse the existing file row if both the path and mtime match
            let mtime = Timestamp::from(mtime);

            let mut existing_id_query = db.prepare("SELECT id FROM files WHERE worktree_id = ?1 AND relative_path = ?2 AND mtime_seconds = ?3 AND mtime_nanos = ?4")?;
            let existing_id = existing_id_query
                .query_row(
                    params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
                    |row| Ok(row.get::<_, i64>(0)?),
                );

            let file_id = if existing_id.is_ok() {
                // The file is already recorded with a matching mtime, so reuse its id
                existing_id?
            } else {
                // Delete any existing row for this path, then insert a fresh one
                db.execute(
                    "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;",
                    params![worktree_id, path.to_str()],
                )?;
                db.execute("INSERT INTO files (worktree_id, relative_path, mtime_seconds, mtime_nanos) VALUES (?1, ?2, ?3, ?4);", params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos])?;
                db.last_insert_rowid()
            };

            // Currently inserting at approximately 3400 documents a second.
            // This could likely be sped up with some form of bulk insert.
            for document in documents {
                let embedding_blob = bincode::serialize(&document.embedding)?;
                let sha_blob = bincode::serialize(&document.sha1)?;

                db.execute(
                    "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding, sha1) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
                    params![
                        file_id,
                        document.range.start.to_string(),
                        document.range.end.to_string(),
                        document.name,
                        embedding_blob,
                        sha_blob
                    ],
                )?;
            }

            Ok(())
        })
    }

    pub fn worktree_previously_indexed(
        &self,
        worktree_root_path: &Path,
    ) -> impl Future<Output = Result<bool>> {
        let worktree_root_path = worktree_root_path.to_string_lossy().into_owned();
        self.transact(move |db| {
            let mut worktree_query =
                db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
            let worktree_id = worktree_query
                .query_row(params![worktree_root_path], |row| Ok(row.get::<_, i64>(0)?));

            Ok(worktree_id.is_ok())
        })
    }

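    /// Returns the id of the worktree with this absolute path, inserting a new row if
    /// one does not exist yet.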
    pub fn find_or_create_worktree(
        &self,
        worktree_root_path: PathBuf,
    ) -> impl Future<Output = Result<i64>> {
        self.transact(move |db| {
            let mut worktree_query =
                db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
            let worktree_id = worktree_query
                .query_row(params![worktree_root_path.to_string_lossy()], |row| {
                    Ok(row.get::<_, i64>(0)?)
                });

            if let Ok(worktree_id) = worktree_id {
                return Ok(worktree_id);
            }

            // The worktree wasn't found, so insert a new row
            db.execute(
                "INSERT into worktrees (absolute_path) VALUES (?1)",
                params![worktree_root_path.to_string_lossy()],
            )?;
            Ok(db.last_insert_rowid())
        })
    }

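    /// Returns the stored mtime of every file in the given worktree, keyed by relative
    /// path.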
    pub fn get_file_mtimes(
        &self,
        worktree_id: i64,
    ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
        self.transact(move |db| {
            let mut statement = db.prepare(
                "
                SELECT relative_path, mtime_seconds, mtime_nanos
                FROM files
                WHERE worktree_id = ?1
                ORDER BY relative_path",
            )?;
            let mut result: HashMap<PathBuf, SystemTime> = HashMap::new();
            for row in statement.query_map(params![worktree_id], |row| {
                Ok((
                    row.get::<_, String>(0)?.into(),
                    Timestamp {
                        seconds: row.get(1)?,
                        nanos: row.get(2)?,
                    }
                    .into(),
                ))
            })? {
                let row = row?;
                result.insert(row.0, row.1);
            }
            Ok(result)
        })
    }

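    /// Scores every document belonging to `file_ids` by dot-product similarity against
    /// `query_embedding` and returns the ids and scores of the `limit` best matches,
    /// ordered from most to least similar.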
    pub fn top_k_search(
        &self,
        query_embedding: &Vec<f32>,
        limit: usize,
        file_ids: &[i64],
    ) -> impl Future<Output = Result<Vec<(i64, f32)>>> {
        let query_embedding = query_embedding.clone();
        let file_ids = file_ids.to_vec();
        self.transact(move |db| {
            let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
            Self::for_each_document(db, &file_ids, |id, embedding| {
                let similarity = dot(&embedding, &query_embedding);
                let ix = match results.binary_search_by(|(_, s)| {
                    similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
                }) {
                    Ok(ix) => ix,
                    Err(ix) => ix,
                };
                results.insert(ix, (id, similarity));
                results.truncate(limit);
            })?;

            anyhow::Ok(results)
        })
    }

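    /// Returns the ids of files in `worktree_ids` whose relative paths match at least
    /// one of `includes` (or all files when `includes` is empty) and none of `excludes`.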
    pub fn retrieve_included_file_ids(
        &self,
        worktree_ids: &[i64],
        includes: &[PathMatcher],
        excludes: &[PathMatcher],
    ) -> impl Future<Output = Result<Vec<i64>>> {
        let worktree_ids = worktree_ids.to_vec();
        let includes = includes.to_vec();
        let excludes = excludes.to_vec();
        self.transact(move |db| {
            let mut file_query = db.prepare(
                "
                SELECT
                    id, relative_path
                FROM
                    files
                WHERE
                    worktree_id IN rarray(?)
                ",
            )?;

            let mut file_ids = Vec::<i64>::new();
            let mut rows = file_query.query([ids_to_sql(&worktree_ids)])?;

            while let Some(row) = rows.next()? {
                let file_id = row.get(0)?;
                let relative_path = row.get_ref(1)?.as_str()?;
                let included =
                    includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path));
                let excluded = excludes.iter().any(|glob| glob.is_match(relative_path));
                if included && !excluded {
                    file_ids.push(file_id);
                }
            }

            anyhow::Ok(file_ids)
        })
    }

    fn for_each_document(
        db: &rusqlite::Connection,
        file_ids: &[i64],
        mut f: impl FnMut(i64, Vec<f32>),
    ) -> Result<()> {
        let mut query_statement = db.prepare(
            "
            SELECT
                id, embedding
            FROM
                documents
            WHERE
                file_id IN rarray(?)
            ",
        )?;

        query_statement
            .query_map(params![ids_to_sql(&file_ids)], |row| {
                Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
            })?
            .filter_map(|row| row.ok())
            .for_each(|(id, embedding)| f(id, embedding.0));
        Ok(())
    }

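    /// Looks up the worktree id, relative path, and byte range of each document in
    /// `ids`, returning them in the same order as `ids` and failing if any id is missing.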
    pub fn get_documents_by_ids(
        &self,
        ids: &[i64],
    ) -> impl Future<Output = Result<Vec<(i64, PathBuf, Range<usize>)>>> {
        let ids = ids.to_vec();
        self.transact(move |db| {
            let mut statement = db.prepare(
                "
                    SELECT
                        documents.id,
                        files.worktree_id,
                        files.relative_path,
                        documents.start_byte,
                        documents.end_byte
                    FROM
                        documents, files
                    WHERE
                        documents.file_id = files.id AND
                        documents.id in rarray(?)
                ",
            )?;

            let result_iter = statement.query_map(params![ids_to_sql(&ids)], |row| {
                Ok((
                    row.get::<_, i64>(0)?,
                    row.get::<_, i64>(1)?,
                    row.get::<_, String>(2)?.into(),
                    row.get(3)?..row.get(4)?,
                ))
            })?;

            let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
            for row in result_iter {
                let (id, worktree_id, path, range) = row?;
                values_by_id.insert(id, (worktree_id, path, range));
            }

            let mut results = Vec::with_capacity(ids.len());
            for id in &ids {
                let value = values_by_id
                    .remove(id)
                    .ok_or(anyhow!("missing document id {}", id))?;
                results.push(value);
            }

            Ok(results)
        })
    }
}

fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
    Rc::new(
        ids.iter()
            .copied()
            .map(|v| rusqlite::types::Value::from(v))
            .collect::<Vec<_>>(),
    )
}

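/// Computes the dot product of two equal-length slices by multiplying them as a 1×len
/// and a len×1 matrix with `matrixmultiply::sgemm`, leaving the scalar result in the
/// 1×1 output (e.g. `dot(&[1.0, 2.0], &[3.0, 4.0])` is `11.0`).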
pub(crate) fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    let len = vec_a.len();
    assert_eq!(len, vec_b.len());

    let mut result = 0.0;
    unsafe {
        matrixmultiply::sgemm(
            1,
            len,
            1,
            1.0,
            vec_a.as_ptr(),
            len as isize,
            1,
            vec_b.as_ptr(),
            1,
            len as isize,
            0.0,
            &mut result as *mut f32,
            1,
            1,
        );
    }
    result
}