// db.rs

  1use crate::{
  2    parsing::{Span, SpanDigest},
  3    SEMANTIC_INDEX_VERSION,
  4};
  5use ai::embedding::Embedding;
  6use anyhow::{anyhow, Context, Result};
  7use collections::HashMap;
  8use futures::channel::oneshot;
  9use gpui::executor;
 10use ndarray::{Array1, Array2};
 11use ordered_float::OrderedFloat;
 12use project::{search::PathMatcher, Fs};
 13use rpc::proto::Timestamp;
 14use rusqlite::params;
 15use rusqlite::types::Value;
 16use std::{
 17    future::Future,
 18    ops::Range,
 19    path::{Path, PathBuf},
 20    rc::Rc,
 21    sync::Arc,
 22    time::SystemTime,
 23};
 24use util::TryFutureExt;
 25
 26pub fn argsort<T: Ord>(data: &[T]) -> Vec<usize> {
 27    let mut indices = (0..data.len()).collect::<Vec<_>>();
 28    indices.sort_by_key(|&i| &data[i]);
 29    indices.reverse();
 30    indices
 31}
 32
/// A row from the `files` table.
#[derive(Debug)]
pub struct FileRecord {
    /// Primary key (`files.id`).
    pub id: usize,
    /// Path of the file relative to its worktree root.
    pub relative_path: String,
    /// Modification time recorded when the file was indexed; used to detect
    /// stale index entries.
    pub mtime: Timestamp,
}
 39
/// Handle to the on-disk vector (embedding) database.
///
/// All SQL runs on a single background task that owns the one
/// `rusqlite::Connection`; callers enqueue boxed closures over the
/// `transactions` channel and receive results back through oneshot channels
/// (see `transact`). Cloning the handle shares that same connection task.
#[derive(Clone)]
pub struct VectorDatabase {
    /// Filesystem path of the SQLite database file.
    path: Arc<Path>,
    /// Sends closures to the background task that owns the connection.
    transactions:
        smol::channel::Sender<Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>>,
}
 46
 47impl VectorDatabase {
    /// Opens (or creates) the vector database at `path`.
    ///
    /// Spawns a detached background task that owns the single
    /// `rusqlite::Connection` and serially applies the closures received over
    /// the `transactions` channel. The task loop ends once every sender has
    /// been dropped. `initialize_database` runs before this returns, so the
    /// schema is current for the returned handle.
    pub async fn new(
        fs: Arc<dyn Fs>,
        path: Arc<Path>,
        executor: Arc<executor::Background>,
    ) -> Result<Self> {
        // Ensure the parent directory exists before SQLite tries to create
        // the database file.
        if let Some(db_directory) = path.parent() {
            fs.create_dir(db_directory).await?;
        }

        let (transactions_tx, transactions_rx) = smol::channel::unbounded::<
            Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>,
        >();
        executor
            .spawn({
                let path = path.clone();
                async move {
                    let mut connection = rusqlite::Connection::open(&path)?;

                    // Pragmas tuned for a local cache: WAL journaling with
                    // relaxed (WAL-safe) syncing, a large page cache, and
                    // in-memory temp storage.
                    connection.pragma_update(None, "journal_mode", "wal")?;
                    connection.pragma_update(None, "synchronous", "normal")?;
                    connection.pragma_update(None, "cache_size", 1000000)?;
                    connection.pragma_update(None, "temp_store", "MEMORY")?;

                    // Serially run queued transactions until all senders drop.
                    while let Ok(transaction) = transactions_rx.recv().await {
                        transaction(&mut connection);
                    }

                    anyhow::Ok(())
                }
                .log_err()
            })
            .detach();
        let this = Self {
            transactions: transactions_tx,
            path,
        };
        this.initialize_database().await?;
        Ok(this)
    }
 87
    /// The filesystem path of the underlying SQLite database file.
    pub fn path(&self) -> &Arc<Path> {
        &self.path
    }
 91
    /// Runs `f` inside a SQL transaction on the background connection task.
    ///
    /// The closure is boxed and sent over the `transactions` channel; the
    /// connection task wraps it in a `rusqlite` transaction (committing on
    /// success, rolling back on error when the transaction is dropped) and
    /// sends the result back through a oneshot channel. The returned future
    /// resolves with `f`'s result, or with an error if the connection task
    /// has already shut down.
    fn transact<F, T>(&self, f: F) -> impl Future<Output = Result<T>>
    where
        F: 'static + Send + FnOnce(&rusqlite::Transaction) -> Result<T>,
        T: 'static + Send,
    {
        let (tx, rx) = oneshot::channel();
        let transactions = self.transactions.clone();
        async move {
            if transactions
                .send(Box::new(|connection| {
                    let result = connection
                        .transaction()
                        .map_err(|err| anyhow!(err))
                        .and_then(|transaction| {
                            let result = f(&transaction)?;
                            // Explicit commit; if `f` errored, dropping the
                            // transaction rolls back instead.
                            transaction.commit()?;
                            Ok(result)
                        });
                    // The caller may have dropped the receiver; ignore that.
                    let _ = tx.send(result);
                }))
                .await
                .is_err()
            {
                return Err(anyhow!("connection was dropped"))?;
            }
            rx.await?
        }
    }
120
121    fn initialize_database(&self) -> impl Future<Output = Result<()>> {
122        self.transact(|db| {
123            rusqlite::vtab::array::load_module(&db)?;
124
125            // Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped
126            let version_query = db.prepare("SELECT version from semantic_index_config");
127            let version = version_query
128                .and_then(|mut query| query.query_row([], |row| Ok(row.get::<_, i64>(0)?)));
129            if version.map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64) {
130                log::trace!("vector database schema up to date");
131                return Ok(());
132            }
133
134            log::trace!("vector database schema out of date. updating...");
135            // We renamed the `documents` table to `spans`, so we want to drop
136            // `documents` without recreating it if it exists.
137            db.execute("DROP TABLE IF EXISTS documents", [])
138                .context("failed to drop 'documents' table")?;
139            db.execute("DROP TABLE IF EXISTS spans", [])
140                .context("failed to drop 'spans' table")?;
141            db.execute("DROP TABLE IF EXISTS files", [])
142                .context("failed to drop 'files' table")?;
143            db.execute("DROP TABLE IF EXISTS worktrees", [])
144                .context("failed to drop 'worktrees' table")?;
145            db.execute("DROP TABLE IF EXISTS semantic_index_config", [])
146                .context("failed to drop 'semantic_index_config' table")?;
147
148            // Initialize Vector Databasing Tables
149            db.execute(
150                "CREATE TABLE semantic_index_config (
151                    version INTEGER NOT NULL
152                )",
153                [],
154            )?;
155
156            db.execute(
157                "INSERT INTO semantic_index_config (version) VALUES (?1)",
158                params![SEMANTIC_INDEX_VERSION],
159            )?;
160
161            db.execute(
162                "CREATE TABLE worktrees (
163                    id INTEGER PRIMARY KEY AUTOINCREMENT,
164                    absolute_path VARCHAR NOT NULL
165                );
166                CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
167                ",
168                [],
169            )?;
170
171            db.execute(
172                "CREATE TABLE files (
173                    id INTEGER PRIMARY KEY AUTOINCREMENT,
174                    worktree_id INTEGER NOT NULL,
175                    relative_path VARCHAR NOT NULL,
176                    mtime_seconds INTEGER NOT NULL,
177                    mtime_nanos INTEGER NOT NULL,
178                    FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
179                )",
180                [],
181            )?;
182
183            db.execute(
184                "CREATE UNIQUE INDEX files_worktree_id_and_relative_path ON files (worktree_id, relative_path)",
185                [],
186            )?;
187
188            db.execute(
189                "CREATE TABLE spans (
190                    id INTEGER PRIMARY KEY AUTOINCREMENT,
191                    file_id INTEGER NOT NULL,
192                    start_byte INTEGER NOT NULL,
193                    end_byte INTEGER NOT NULL,
194                    name VARCHAR NOT NULL,
195                    embedding BLOB NOT NULL,
196                    digest BLOB NOT NULL,
197                    FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
198                )",
199                [],
200            )?;
201            db.execute(
202                "CREATE INDEX spans_digest ON spans (digest)",
203                [],
204            )?;
205
206            log::trace!("vector database initialized with updated schema.");
207            Ok(())
208        })
209    }
210
211    pub fn delete_file(
212        &self,
213        worktree_id: i64,
214        delete_path: Arc<Path>,
215    ) -> impl Future<Output = Result<()>> {
216        self.transact(move |db| {
217            db.execute(
218                "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
219                params![worktree_id, delete_path.to_str()],
220            )?;
221            Ok(())
222        })
223    }
224
225    pub fn insert_file(
226        &self,
227        worktree_id: i64,
228        path: Arc<Path>,
229        mtime: SystemTime,
230        spans: Vec<Span>,
231    ) -> impl Future<Output = Result<()>> {
232        self.transact(move |db| {
233            // Return the existing ID, if both the file and mtime match
234            let mtime = Timestamp::from(mtime);
235
236            db.execute(
237                "
238                REPLACE INTO files
239                (worktree_id, relative_path, mtime_seconds, mtime_nanos)
240                VALUES (?1, ?2, ?3, ?4)
241                ",
242                params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
243            )?;
244
245            let file_id = db.last_insert_rowid();
246
247            let mut query = db.prepare(
248                "
249                INSERT INTO spans
250                (file_id, start_byte, end_byte, name, embedding, digest)
251                VALUES (?1, ?2, ?3, ?4, ?5, ?6)
252                ",
253            )?;
254
255            for span in spans {
256                query.execute(params![
257                    file_id,
258                    span.range.start.to_string(),
259                    span.range.end.to_string(),
260                    span.name,
261                    span.embedding,
262                    span.digest
263                ])?;
264            }
265
266            Ok(())
267        })
268    }
269
270    pub fn worktree_previously_indexed(
271        &self,
272        worktree_root_path: &Path,
273    ) -> impl Future<Output = Result<bool>> {
274        let worktree_root_path = worktree_root_path.to_string_lossy().into_owned();
275        self.transact(move |db| {
276            let mut worktree_query =
277                db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
278            let worktree_id = worktree_query
279                .query_row(params![worktree_root_path], |row| Ok(row.get::<_, i64>(0)?));
280
281            if worktree_id.is_ok() {
282                return Ok(true);
283            } else {
284                return Ok(false);
285            }
286        })
287    }
288
289    pub fn embeddings_for_digests(
290        &self,
291        digests: Vec<SpanDigest>,
292    ) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
293        self.transact(move |db| {
294            let mut query = db.prepare(
295                "
296                SELECT digest, embedding
297                FROM spans
298                WHERE digest IN rarray(?)
299                ",
300            )?;
301            let mut embeddings_by_digest = HashMap::default();
302            let digests = Rc::new(
303                digests
304                    .into_iter()
305                    .map(|p| Value::Blob(p.0.to_vec()))
306                    .collect::<Vec<_>>(),
307            );
308            let rows = query.query_map(params![digests], |row| {
309                Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
310            })?;
311
312            for row in rows {
313                if let Ok(row) = row {
314                    embeddings_by_digest.insert(row.0, row.1);
315                }
316            }
317
318            Ok(embeddings_by_digest)
319        })
320    }
321
322    pub fn embeddings_for_files(
323        &self,
324        worktree_id_file_paths: HashMap<i64, Vec<Arc<Path>>>,
325    ) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
326        self.transact(move |db| {
327            let mut query = db.prepare(
328                "
329                SELECT digest, embedding
330                FROM spans
331                LEFT JOIN files ON files.id = spans.file_id
332                WHERE files.worktree_id = ? AND files.relative_path IN rarray(?)
333            ",
334            )?;
335            let mut embeddings_by_digest = HashMap::default();
336            for (worktree_id, file_paths) in worktree_id_file_paths {
337                let file_paths = Rc::new(
338                    file_paths
339                        .into_iter()
340                        .map(|p| Value::Text(p.to_string_lossy().into_owned()))
341                        .collect::<Vec<_>>(),
342                );
343                let rows = query.query_map(params![worktree_id, file_paths], |row| {
344                    Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
345                })?;
346
347                for row in rows {
348                    if let Ok(row) = row {
349                        embeddings_by_digest.insert(row.0, row.1);
350                    }
351                }
352            }
353
354            Ok(embeddings_by_digest)
355        })
356    }
357
358    pub fn find_or_create_worktree(
359        &self,
360        worktree_root_path: Arc<Path>,
361    ) -> impl Future<Output = Result<i64>> {
362        self.transact(move |db| {
363            let mut worktree_query =
364                db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
365            let worktree_id = worktree_query
366                .query_row(params![worktree_root_path.to_string_lossy()], |row| {
367                    Ok(row.get::<_, i64>(0)?)
368                });
369
370            if worktree_id.is_ok() {
371                return Ok(worktree_id?);
372            }
373
374            // If worktree_id is Err, insert new worktree
375            db.execute(
376                "INSERT into worktrees (absolute_path) VALUES (?1)",
377                params![worktree_root_path.to_string_lossy()],
378            )?;
379            Ok(db.last_insert_rowid())
380        })
381    }
382
383    pub fn get_file_mtimes(
384        &self,
385        worktree_id: i64,
386    ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
387        self.transact(move |db| {
388            let mut statement = db.prepare(
389                "
390                SELECT relative_path, mtime_seconds, mtime_nanos
391                FROM files
392                WHERE worktree_id = ?1
393                ORDER BY relative_path",
394            )?;
395            let mut result: HashMap<PathBuf, SystemTime> = HashMap::default();
396            for row in statement.query_map(params![worktree_id], |row| {
397                Ok((
398                    row.get::<_, String>(0)?.into(),
399                    Timestamp {
400                        seconds: row.get(1)?,
401                        nanos: row.get(2)?,
402                    }
403                    .into(),
404                ))
405            })? {
406                let row = row?;
407                result.insert(row.0, row.1);
408            }
409            Ok(result)
410        })
411    }
412
413    pub fn top_k_search(
414        &self,
415        query_embedding: &Embedding,
416        limit: usize,
417        file_ids: &[i64],
418    ) -> impl Future<Output = Result<Vec<(i64, OrderedFloat<f32>)>>> {
419        let file_ids = file_ids.to_vec();
420        let query = query_embedding.clone().0;
421        let query = Array1::from_vec(query);
422        self.transact(move |db| {
423            let mut query_statement = db.prepare(
424                "
425                    SELECT
426                        id, embedding
427                    FROM
428                        spans
429                    WHERE
430                        file_id IN rarray(?)
431                    ",
432            )?;
433
434            let deserialized_rows = query_statement
435                .query_map(params![ids_to_sql(&file_ids)], |row| {
436                    Ok((row.get::<_, usize>(0)?, row.get::<_, Embedding>(1)?))
437                })?
438                .filter_map(|row| row.ok())
439                .collect::<Vec<(usize, Embedding)>>();
440
441            if deserialized_rows.len() == 0 {
442                return Ok(Vec::new());
443            }
444
445            // Get Length of Embeddings Returned
446            let embedding_len = deserialized_rows[0].1 .0.len();
447
448            let batch_n = 1000;
449            let mut batches = Vec::new();
450            let mut batch_ids = Vec::new();
451            let mut batch_embeddings: Vec<f32> = Vec::new();
452            deserialized_rows.iter().for_each(|(id, embedding)| {
453                batch_ids.push(id);
454                batch_embeddings.extend(&embedding.0);
455
456                if batch_ids.len() == batch_n {
457                    let embeddings = std::mem::take(&mut batch_embeddings);
458                    let ids = std::mem::take(&mut batch_ids);
459                    let array =
460                        Array2::from_shape_vec((ids.len(), embedding_len.clone()), embeddings);
461                    match array {
462                        Ok(array) => {
463                            batches.push((ids, array));
464                        }
465                        Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
466                    }
467                }
468            });
469
470            if batch_ids.len() > 0 {
471                let array = Array2::from_shape_vec(
472                    (batch_ids.len(), embedding_len),
473                    batch_embeddings.clone(),
474                );
475                match array {
476                    Ok(array) => {
477                        batches.push((batch_ids.clone(), array));
478                    }
479                    Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
480                }
481            }
482
483            let mut ids: Vec<usize> = Vec::new();
484            let mut results = Vec::new();
485            for (batch_ids, array) in batches {
486                let scores = array
487                    .dot(&query.t())
488                    .to_vec()
489                    .iter()
490                    .map(|score| OrderedFloat(*score))
491                    .collect::<Vec<OrderedFloat<f32>>>();
492                results.extend(scores);
493                ids.extend(batch_ids);
494            }
495
496            let sorted_idx = argsort(&results);
497            let mut sorted_results = Vec::new();
498            let last_idx = limit.min(sorted_idx.len());
499            for idx in &sorted_idx[0..last_idx] {
500                sorted_results.push((ids[*idx] as i64, results[*idx]))
501            }
502
503            Ok(sorted_results)
504        })
505    }
506
507    pub fn retrieve_included_file_ids(
508        &self,
509        worktree_ids: &[i64],
510        includes: &[PathMatcher],
511        excludes: &[PathMatcher],
512    ) -> impl Future<Output = Result<Vec<i64>>> {
513        let worktree_ids = worktree_ids.to_vec();
514        let includes = includes.to_vec();
515        let excludes = excludes.to_vec();
516        self.transact(move |db| {
517            let mut file_query = db.prepare(
518                "
519                SELECT
520                    id, relative_path
521                FROM
522                    files
523                WHERE
524                    worktree_id IN rarray(?)
525                ",
526            )?;
527
528            let mut file_ids = Vec::<i64>::new();
529            let mut rows = file_query.query([ids_to_sql(&worktree_ids)])?;
530
531            while let Some(row) = rows.next()? {
532                let file_id = row.get(0)?;
533                let relative_path = row.get_ref(1)?.as_str()?;
534                let included =
535                    includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path));
536                let excluded = excludes.iter().any(|glob| glob.is_match(relative_path));
537                if included && !excluded {
538                    file_ids.push(file_id);
539                }
540            }
541
542            anyhow::Ok(file_ids)
543        })
544    }
545
546    pub fn spans_for_ids(
547        &self,
548        ids: &[i64],
549    ) -> impl Future<Output = Result<Vec<(i64, PathBuf, Range<usize>)>>> {
550        let ids = ids.to_vec();
551        self.transact(move |db| {
552            let mut statement = db.prepare(
553                "
554                    SELECT
555                        spans.id,
556                        files.worktree_id,
557                        files.relative_path,
558                        spans.start_byte,
559                        spans.end_byte
560                    FROM
561                        spans, files
562                    WHERE
563                        spans.file_id = files.id AND
564                        spans.id in rarray(?)
565                ",
566            )?;
567
568            let result_iter = statement.query_map(params![ids_to_sql(&ids)], |row| {
569                Ok((
570                    row.get::<_, i64>(0)?,
571                    row.get::<_, i64>(1)?,
572                    row.get::<_, String>(2)?.into(),
573                    row.get(3)?..row.get(4)?,
574                ))
575            })?;
576
577            let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
578            for row in result_iter {
579                let (id, worktree_id, path, range) = row?;
580                values_by_id.insert(id, (worktree_id, path, range));
581            }
582
583            let mut results = Vec::with_capacity(ids.len());
584            for id in &ids {
585                let value = values_by_id
586                    .remove(id)
587                    .ok_or(anyhow!("missing span id {}", id))?;
588                results.push(value);
589            }
590
591            Ok(results)
592        })
593    }
594}
595
596fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
597    Rc::new(
598        ids.iter()
599            .copied()
600            .map(|v| rusqlite::types::Value::from(v))
601            .collect::<Vec<_>>(),
602    )
603}