db.rs

use crate::{
    embedding::Embedding,
    parsing::{Span, SpanDigest},
    SEMANTIC_INDEX_VERSION,
};
use anyhow::{anyhow, Context, Result};
use collections::HashMap;
use futures::channel::oneshot;
use gpui::executor;
use project::{search::PathMatcher, Fs};
use rpc::proto::Timestamp;
use rusqlite::params;
use rusqlite::types::Value;
use std::{
    cmp::Ordering,
    future::Future,
    ops::Range,
    path::{Path, PathBuf},
    rc::Rc,
    sync::Arc,
    time::SystemTime,
};
use util::TryFutureExt;

#[derive(Debug)]
pub struct FileRecord {
    pub id: usize,
    pub relative_path: String,
    pub mtime: Timestamp,
}

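// A cloneable handle to the semantic index's SQLite database. All access goes
// through a single background connection; callers enqueue closures over the
// `transactions` channel and the connection task runs them one at a time.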
#[derive(Clone)]
pub struct VectorDatabase {
    path: Arc<Path>,
    transactions:
        smol::channel::Sender<Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>>,
}

impl VectorDatabase {
    pub async fn new(
        fs: Arc<dyn Fs>,
        path: Arc<Path>,
        executor: Arc<executor::Background>,
    ) -> Result<Self> {
        if let Some(db_directory) = path.parent() {
            fs.create_dir(db_directory).await?;
        }

        let (transactions_tx, transactions_rx) = smol::channel::unbounded::<
            Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>,
        >();
        executor
            .spawn({
                let path = path.clone();
                async move {
                    let mut connection = rusqlite::Connection::open(&path)?;

                    connection.pragma_update(None, "journal_mode", "wal")?;
                    connection.pragma_update(None, "synchronous", "normal")?;
                    connection.pragma_update(None, "cache_size", 1000000)?;
                    connection.pragma_update(None, "temp_store", "MEMORY")?;
                    // SQLite leaves foreign key enforcement off by default; enable it so
                    // the schema's ON DELETE CASCADE clauses remove dependent rows.
                    connection.pragma_update(None, "foreign_keys", true)?;

                    while let Ok(transaction) = transactions_rx.recv().await {
                        transaction(&mut connection);
                    }

                    anyhow::Ok(())
                }
                .log_err()
            })
            .detach();
        let this = Self {
            transactions: transactions_tx,
            path,
        };
        this.initialize_database().await?;
        Ok(this)
    }

    pub fn path(&self) -> &Arc<Path> {
        &self.path
    }

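    // Sends `f` to the background connection task, runs it inside a SQLite
    // transaction, and commits if it succeeds. The result is delivered back to
    // the caller over a oneshot channel.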
    fn transact<F, T>(&self, f: F) -> impl Future<Output = Result<T>>
    where
        F: 'static + Send + FnOnce(&rusqlite::Transaction) -> Result<T>,
        T: 'static + Send,
    {
        let (tx, rx) = oneshot::channel();
        let transactions = self.transactions.clone();
        async move {
            if transactions
                .send(Box::new(|connection| {
                    let result = connection
                        .transaction()
                        .map_err(|err| anyhow!(err))
                        .and_then(|transaction| {
                            let result = f(&transaction)?;
                            transaction.commit()?;
                            Ok(result)
                        });
                    let _ = tx.send(result);
                }))
                .await
                .is_err()
            {
                return Err(anyhow!("connection was dropped"));
            }
            rx.await?
        }
    }

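    // Registers the `rarray` table-valued function and recreates the schema
    // from scratch whenever the stored version differs from SEMANTIC_INDEX_VERSION.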
    fn initialize_database(&self) -> impl Future<Output = Result<()>> {
        self.transact(|db| {
            rusqlite::vtab::array::load_module(&db)?;

            // Drop and recreate the tables if SEMANTIC_INDEX_VERSION has been bumped.
            let version_query = db.prepare("SELECT version FROM semantic_index_config");
            let version = version_query
                .and_then(|mut query| query.query_row([], |row| Ok(row.get::<_, i64>(0)?)));
            if version.map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64) {
                log::trace!("vector database schema up to date");
                return Ok(());
            }

            log::trace!("vector database schema out of date. updating...");
            // We renamed the `documents` table to `spans`, so we want to drop
            // `documents` without recreating it if it exists.
            db.execute("DROP TABLE IF EXISTS documents", [])
                .context("failed to drop 'documents' table")?;
            db.execute("DROP TABLE IF EXISTS spans", [])
                .context("failed to drop 'spans' table")?;
            db.execute("DROP TABLE IF EXISTS files", [])
                .context("failed to drop 'files' table")?;
            db.execute("DROP TABLE IF EXISTS worktrees", [])
                .context("failed to drop 'worktrees' table")?;
            db.execute("DROP TABLE IF EXISTS semantic_index_config", [])
                .context("failed to drop 'semantic_index_config' table")?;

            // Initialize the vector database tables.
            db.execute(
                "CREATE TABLE semantic_index_config (
                    version INTEGER NOT NULL
                )",
                [],
            )?;

            db.execute(
                "INSERT INTO semantic_index_config (version) VALUES (?1)",
                params![SEMANTIC_INDEX_VERSION],
            )?;

            // `Connection::execute` expects a single statement, so use `execute_batch`
            // to create the worktrees table and its unique index together.
            db.execute_batch(
                "CREATE TABLE worktrees (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    absolute_path VARCHAR NOT NULL
                );
                CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
                ",
            )?;

            db.execute(
                "CREATE TABLE files (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    worktree_id INTEGER NOT NULL,
                    relative_path VARCHAR NOT NULL,
                    mtime_seconds INTEGER NOT NULL,
                    mtime_nanos INTEGER NOT NULL,
                    FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
                )",
                [],
            )?;

            db.execute(
                "CREATE UNIQUE INDEX files_worktree_id_and_relative_path ON files (worktree_id, relative_path)",
                [],
            )?;

            db.execute(
                "CREATE TABLE spans (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    file_id INTEGER NOT NULL,
                    start_byte INTEGER NOT NULL,
                    end_byte INTEGER NOT NULL,
                    name VARCHAR NOT NULL,
                    embedding BLOB NOT NULL,
                    digest BLOB NOT NULL,
                    FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
                )",
                [],
            )?;

            log::trace!("vector database initialized with updated schema.");
            Ok(())
        })
    }

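    // Deletes the row for a file within a worktree. The `spans` table
    // references `files.id` with `ON DELETE CASCADE`.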
    pub fn delete_file(
        &self,
        worktree_id: i64,
        delete_path: Arc<Path>,
    ) -> impl Future<Output = Result<()>> {
        self.transact(move |db| {
            db.execute(
                "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
                params![worktree_id, delete_path.to_str()],
            )?;
            Ok(())
        })
    }

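    // Replaces the file's row (keyed by worktree id and relative path) and
    // stores one row per span, including its embedding and digest.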
    pub fn insert_file(
        &self,
        worktree_id: i64,
        path: Arc<Path>,
        mtime: SystemTime,
        spans: Vec<Span>,
    ) -> impl Future<Output = Result<()>> {
        self.transact(move |db| {
            // Replace any existing row for this worktree id and relative path, then
            // insert the file's spans under the newly inserted row's id.
            let mtime = Timestamp::from(mtime);

            db.execute(
                "
                REPLACE INTO files
                (worktree_id, relative_path, mtime_seconds, mtime_nanos)
                VALUES (?1, ?2, ?3, ?4)
                ",
                params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
            )?;

            let file_id = db.last_insert_rowid();

            let mut query = db.prepare(
                "
                INSERT INTO spans
                (file_id, start_byte, end_byte, name, embedding, digest)
                VALUES (?1, ?2, ?3, ?4, ?5, ?6)
                ",
            )?;

            for span in spans {
                query.execute(params![
                    file_id,
                    span.range.start.to_string(),
                    span.range.end.to_string(),
                    span.name,
                    span.embedding,
                    span.digest
                ])?;
            }

            Ok(())
        })
    }

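    // Reports whether a worktree with this absolute path already has a row in
    // the database, i.e. whether it has been indexed before.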
    pub fn worktree_previously_indexed(
        &self,
        worktree_root_path: &Path,
    ) -> impl Future<Output = Result<bool>> {
        let worktree_root_path = worktree_root_path.to_string_lossy().into_owned();
        self.transact(move |db| {
            let mut worktree_query =
                db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
            let worktree_id = worktree_query
                .query_row(params![worktree_root_path], |row| Ok(row.get::<_, i64>(0)?));

            Ok(worktree_id.is_ok())
        })
    }

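    // Fetches previously computed embeddings, keyed by span digest, for the
    // given files. `rarray` lets each batch of paths be bound as one parameter.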
    pub fn embeddings_for_files(
        &self,
        worktree_id_file_paths: HashMap<i64, Vec<Arc<Path>>>,
    ) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
        self.transact(move |db| {
            let mut query = db.prepare(
                "
                SELECT digest, embedding
                FROM spans
                LEFT JOIN files ON files.id = spans.file_id
                WHERE files.worktree_id = ? AND files.relative_path IN rarray(?)
            ",
            )?;
            let mut embeddings_by_digest = HashMap::default();
            for (worktree_id, file_paths) in worktree_id_file_paths {
                let file_paths = Rc::new(
                    file_paths
                        .into_iter()
                        .map(|p| Value::Text(p.to_string_lossy().into_owned()))
                        .collect::<Vec<_>>(),
                );
                let rows = query.query_map(params![worktree_id, file_paths], |row| {
                    Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
                })?;

                for row in rows {
                    if let Ok((digest, embedding)) = row {
                        embeddings_by_digest.insert(digest, embedding);
                    }
                }
            }

            Ok(embeddings_by_digest)
        })
    }

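    // Looks up the worktree row by absolute path, inserting one if it doesn't
    // exist yet, and returns its id.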
    pub fn find_or_create_worktree(
        &self,
        worktree_root_path: Arc<Path>,
    ) -> impl Future<Output = Result<i64>> {
        self.transact(move |db| {
            let mut worktree_query =
                db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
            let worktree_id = worktree_query
                .query_row(params![worktree_root_path.to_string_lossy()], |row| {
                    Ok(row.get::<_, i64>(0)?)
                });

            if let Ok(worktree_id) = worktree_id {
                return Ok(worktree_id);
            }

            // If worktree_id is Err, insert new worktree
            db.execute(
                "INSERT INTO worktrees (absolute_path) VALUES (?1)",
                params![worktree_root_path.to_string_lossy()],
            )?;
            Ok(db.last_insert_rowid())
        })
    }

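    // Returns the stored modification time of every file in the worktree,
    // keyed by relative path.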
    pub fn get_file_mtimes(
        &self,
        worktree_id: i64,
    ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
        self.transact(move |db| {
            let mut statement = db.prepare(
                "
                SELECT relative_path, mtime_seconds, mtime_nanos
                FROM files
                WHERE worktree_id = ?1
                ORDER BY relative_path",
            )?;
            let mut result: HashMap<PathBuf, SystemTime> = HashMap::default();
            for row in statement.query_map(params![worktree_id], |row| {
                Ok((
                    row.get::<_, String>(0)?.into(),
                    Timestamp {
                        seconds: row.get(1)?,
                        nanos: row.get(2)?,
                    }
                    .into(),
                ))
            })? {
                let row = row?;
                result.insert(row.0, row.1);
            }
            Ok(result)
        })
    }

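    // Scans the spans of the given files and keeps the `limit` best matches.
    // `results` stays sorted by descending similarity, so the binary-search
    // comparator is reversed relative to the natural ordering.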
    pub fn top_k_search(
        &self,
        query_embedding: &Embedding,
        limit: usize,
        file_ids: &[i64],
    ) -> impl Future<Output = Result<Vec<(i64, f32)>>> {
        let query_embedding = query_embedding.clone();
        let file_ids = file_ids.to_vec();
        self.transact(move |db| {
            let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
            Self::for_each_span(db, &file_ids, |id, embedding| {
                let similarity = embedding.similarity(&query_embedding);
                let ix = match results.binary_search_by(|(_, s)| {
                    similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
                }) {
                    Ok(ix) => ix,
                    Err(ix) => ix,
                };
                results.insert(ix, (id, similarity));
                results.truncate(limit);
            })?;

            anyhow::Ok(results)
        })
    }

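    // Lists the ids of files in the given worktrees whose relative paths match
    // the include globs (if any) and none of the exclude globs.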
    pub fn retrieve_included_file_ids(
        &self,
        worktree_ids: &[i64],
        includes: &[PathMatcher],
        excludes: &[PathMatcher],
    ) -> impl Future<Output = Result<Vec<i64>>> {
        let worktree_ids = worktree_ids.to_vec();
        let includes = includes.to_vec();
        let excludes = excludes.to_vec();
        self.transact(move |db| {
            let mut file_query = db.prepare(
                "
                SELECT
                    id, relative_path
                FROM
                    files
                WHERE
                    worktree_id IN rarray(?)
                ",
            )?;

            let mut file_ids = Vec::<i64>::new();
            let mut rows = file_query.query([ids_to_sql(&worktree_ids)])?;

            while let Some(row) = rows.next()? {
                let file_id = row.get(0)?;
                let relative_path = row.get_ref(1)?.as_str()?;
                let included =
                    includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path));
                let excluded = excludes.iter().any(|glob| glob.is_match(relative_path));
                if included && !excluded {
                    file_ids.push(file_id);
                }
            }

            anyhow::Ok(file_ids)
        })
    }

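    // Invokes `f` with the id and embedding of every span belonging to the
    // given files.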
    fn for_each_span(
        db: &rusqlite::Connection,
        file_ids: &[i64],
        mut f: impl FnMut(i64, Embedding),
    ) -> Result<()> {
        let mut query_statement = db.prepare(
            "
            SELECT
                id, embedding
            FROM
                spans
            WHERE
                file_id IN rarray(?)
            ",
        )?;

        query_statement
            .query_map(params![ids_to_sql(&file_ids)], |row| {
                Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
            })?
            .filter_map(|row| row.ok())
            .for_each(|(id, embedding)| f(id, embedding));
        Ok(())
    }

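    // Resolves span ids back to their worktree id, file path, and byte range,
    // returning results in the same order as the requested ids.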
    pub fn spans_for_ids(
        &self,
        ids: &[i64],
    ) -> impl Future<Output = Result<Vec<(i64, PathBuf, Range<usize>)>>> {
        let ids = ids.to_vec();
        self.transact(move |db| {
            let mut statement = db.prepare(
                "
                    SELECT
                        spans.id,
                        files.worktree_id,
                        files.relative_path,
                        spans.start_byte,
                        spans.end_byte
                    FROM
                        spans, files
                    WHERE
                        spans.file_id = files.id AND
                        spans.id IN rarray(?)
                ",
            )?;

            let result_iter = statement.query_map(params![ids_to_sql(&ids)], |row| {
                Ok((
                    row.get::<_, i64>(0)?,
                    row.get::<_, i64>(1)?,
                    row.get::<_, String>(2)?.into(),
                    row.get(3)?..row.get(4)?,
                ))
            })?;

            let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
            for row in result_iter {
                let (id, worktree_id, path, range) = row?;
                values_by_id.insert(id, (worktree_id, path, range));
            }

            let mut results = Vec::with_capacity(ids.len());
            for id in &ids {
                let value = values_by_id
                    .remove(id)
                    .ok_or_else(|| anyhow!("missing span id {}", id))?;
                results.push(value);
            }

            Ok(results)
        })
    }
}

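// Converts ids into the `Rc<Vec<Value>>` that rusqlite's `rarray()` table-valued
// function expects as a bound parameter.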
fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
    Rc::new(
        ids.iter()
            .copied()
            .map(rusqlite::types::Value::from)
            .collect::<Vec<_>>(),
    )
}