// db.rs

  1use crate::{
  2    parsing::{Span, SpanDigest},
  3    SEMANTIC_INDEX_VERSION,
  4};
  5use ai::embedding::Embedding;
  6use anyhow::{anyhow, Context, Result};
  7use collections::HashMap;
  8use futures::channel::oneshot;
  9use gpui::BackgroundExecutor;
 10use ndarray::{Array1, Array2};
 11use ordered_float::OrderedFloat;
 12use project::Fs;
 13use rpc::proto::Timestamp;
 14use rusqlite::params;
 15use rusqlite::types::Value;
 16use std::{
 17    future::Future,
 18    ops::Range,
 19    path::{Path, PathBuf},
 20    rc::Rc,
 21    sync::Arc,
 22    time::SystemTime,
 23};
 24use util::{paths::PathMatcher, TryFutureExt};
 25
 26pub fn argsort<T: Ord>(data: &[T]) -> Vec<usize> {
 27    let mut indices = (0..data.len()).collect::<Vec<_>>();
 28    indices.sort_by_key(|&i| &data[i]);
 29    indices.reverse();
 30    indices
 31}
 32
/// A row from the `files` table.
#[derive(Debug)]
pub struct FileRecord {
    // Database rowid of the file (`files.id`).
    pub id: usize,
    // Path of the file relative to its worktree root.
    pub relative_path: String,
    // Modification time recorded when the file was last indexed.
    pub mtime: Timestamp,
}
 39
/// Handle to the on-disk SQLite vector database.
///
/// All database work happens on a single background task that owns the
/// `rusqlite::Connection`; clones of this struct share that task through the
/// `transactions` channel.
#[derive(Clone)]
pub struct VectorDatabase {
    // Filesystem location of the SQLite database file.
    path: Arc<Path>,
    // Sends closures to the background task owning the connection; each
    // closure runs with exclusive access to the connection.
    transactions:
        smol::channel::Sender<Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>>,
}
 46
 47impl VectorDatabase {
    /// Opens (or creates) the database at `path` and spawns the background
    /// task that owns the sqlite connection, then initializes the schema.
    ///
    /// NOTE(review): `PRAGMA foreign_keys` is never enabled here, and sqlite
    /// leaves foreign-key enforcement off by default — so the schema's
    /// `ON DELETE CASCADE` clauses may not actually fire. TODO confirm.
    pub async fn new(
        fs: Arc<dyn Fs>,
        path: Arc<Path>,
        executor: BackgroundExecutor,
    ) -> Result<Self> {
        // Ensure the parent directory exists before sqlite tries to create
        // the database file inside it.
        if let Some(db_directory) = path.parent() {
            fs.create_dir(db_directory).await?;
        }

        let (transactions_tx, transactions_rx) = smol::channel::unbounded::<
            Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>,
        >();
        executor
            .spawn({
                let path = path.clone();
                async move {
                    let mut connection = rusqlite::Connection::open(&path)?;

                    // Pragmas tuned for write-heavy indexing: WAL journaling,
                    // relaxed fsync, large page cache, in-memory temp tables.
                    connection.pragma_update(None, "journal_mode", "wal")?;
                    connection.pragma_update(None, "synchronous", "normal")?;
                    connection.pragma_update(None, "cache_size", 1000000)?;
                    connection.pragma_update(None, "temp_store", "MEMORY")?;

                    // Serially apply queued transactions until the channel
                    // closes (i.e. every `VectorDatabase` clone is dropped).
                    while let Ok(transaction) = transactions_rx.recv().await {
                        transaction(&mut connection);
                    }

                    anyhow::Ok(())
                }
                .log_err()
            })
            .detach();
        let this = Self {
            transactions: transactions_tx,
            path,
        };
        // Create or migrate the schema before handing the database out.
        this.initialize_database().await?;
        Ok(this)
    }
 87
    /// The filesystem path of the underlying sqlite database file.
    pub fn path(&self) -> &Arc<Path> {
        &self.path
    }
 91
    /// Queues `f` to run inside a sqlite transaction on the background
    /// connection task, returning a future that resolves to its result.
    ///
    /// The transaction is committed only when `f` returns `Ok`; an
    /// uncommitted transaction rolls back on drop. The future fails if the
    /// background task (and thus its channel) has shut down.
    fn transact<F, T>(&self, f: F) -> impl Future<Output = Result<T>>
    where
        F: 'static + Send + FnOnce(&rusqlite::Transaction) -> Result<T>,
        T: 'static + Send,
    {
        let (tx, rx) = oneshot::channel();
        let transactions = self.transactions.clone();
        async move {
            if transactions
                .send(Box::new(|connection| {
                    let result = connection
                        .transaction()
                        .map_err(|err| anyhow!(err))
                        .and_then(|transaction| {
                            let result = f(&transaction)?;
                            // Explicit commit; early `?` returns above leave
                            // the transaction to roll back on drop.
                            transaction.commit()?;
                            Ok(result)
                        });
                    // The caller may have dropped the receiver; that's fine.
                    let _ = tx.send(result);
                }))
                .await
                .is_err()
            {
                return Err(anyhow!("connection was dropped"))?;
            }
            rx.await?
        }
    }
120
121    fn initialize_database(&self) -> impl Future<Output = Result<()>> {
122        self.transact(|db| {
123            rusqlite::vtab::array::load_module(&db)?;
124
125            // Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped
126            let version_query = db.prepare("SELECT version from semantic_index_config");
127            let version = version_query
128                .and_then(|mut query| query.query_row([], |row| row.get::<_, i64>(0)));
129            if version.map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64) {
130                log::trace!("vector database schema up to date");
131                return Ok(());
132            }
133
134            log::trace!("vector database schema out of date. updating...");
135            // We renamed the `documents` table to `spans`, so we want to drop
136            // `documents` without recreating it if it exists.
137            db.execute("DROP TABLE IF EXISTS documents", [])
138                .context("failed to drop 'documents' table")?;
139            db.execute("DROP TABLE IF EXISTS spans", [])
140                .context("failed to drop 'spans' table")?;
141            db.execute("DROP TABLE IF EXISTS files", [])
142                .context("failed to drop 'files' table")?;
143            db.execute("DROP TABLE IF EXISTS worktrees", [])
144                .context("failed to drop 'worktrees' table")?;
145            db.execute("DROP TABLE IF EXISTS semantic_index_config", [])
146                .context("failed to drop 'semantic_index_config' table")?;
147
148            // Initialize Vector Databasing Tables
149            db.execute(
150                "CREATE TABLE semantic_index_config (
151                    version INTEGER NOT NULL
152                )",
153                [],
154            )?;
155
156            db.execute(
157                "INSERT INTO semantic_index_config (version) VALUES (?1)",
158                params![SEMANTIC_INDEX_VERSION],
159            )?;
160
161            db.execute(
162                "CREATE TABLE worktrees (
163                    id INTEGER PRIMARY KEY AUTOINCREMENT,
164                    absolute_path VARCHAR NOT NULL
165                );
166                CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
167                ",
168                [],
169            )?;
170
171            db.execute(
172                "CREATE TABLE files (
173                    id INTEGER PRIMARY KEY AUTOINCREMENT,
174                    worktree_id INTEGER NOT NULL,
175                    relative_path VARCHAR NOT NULL,
176                    mtime_seconds INTEGER NOT NULL,
177                    mtime_nanos INTEGER NOT NULL,
178                    FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
179                )",
180                [],
181            )?;
182
183            db.execute(
184                "CREATE UNIQUE INDEX files_worktree_id_and_relative_path ON files (worktree_id, relative_path)",
185                [],
186            )?;
187
188            db.execute(
189                "CREATE TABLE spans (
190                    id INTEGER PRIMARY KEY AUTOINCREMENT,
191                    file_id INTEGER NOT NULL,
192                    start_byte INTEGER NOT NULL,
193                    end_byte INTEGER NOT NULL,
194                    name VARCHAR NOT NULL,
195                    embedding BLOB NOT NULL,
196                    digest BLOB NOT NULL,
197                    FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
198                )",
199                [],
200            )?;
201            db.execute(
202                "CREATE INDEX spans_digest ON spans (digest)",
203                [],
204            )?;
205
206            log::trace!("vector database initialized with updated schema.");
207            Ok(())
208        })
209    }
210
211    pub fn delete_file(
212        &self,
213        worktree_id: i64,
214        delete_path: Arc<Path>,
215    ) -> impl Future<Output = Result<()>> {
216        self.transact(move |db| {
217            db.execute(
218                "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
219                params![worktree_id, delete_path.to_str()],
220            )?;
221            Ok(())
222        })
223    }
224
    /// Upserts a file row (keyed by worktree id + relative path) and inserts
    /// all of its spans.
    pub fn insert_file(
        &self,
        worktree_id: i64,
        path: Arc<Path>,
        mtime: SystemTime,
        spans: Vec<Span>,
    ) -> impl Future<Output = Result<()>> {
        self.transact(move |db| {
            // Return the existing ID, if both the file and mtime match
            let mtime = Timestamp::from(mtime);

            // REPLACE deletes any row that collides on the unique
            // (worktree_id, relative_path) index and inserts a fresh one.
            db.execute(
                "
                REPLACE INTO files
                (worktree_id, relative_path, mtime_seconds, mtime_nanos)
                VALUES (?1, ?2, ?3, ?4)
                ",
                params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
            )?;

            // REPLACE always performs an insert, so this rowid is the row
            // written above.
            let file_id = db.last_insert_rowid();

            let mut query = db.prepare(
                "
                INSERT INTO spans
                (file_id, start_byte, end_byte, name, embedding, digest)
                VALUES (?1, ?2, ?3, ?4, ?5, ?6)
                ",
            )?;

            for span in spans {
                // Byte offsets are bound as strings; the INTEGER column
                // affinity converts them back to integers on storage.
                query.execute(params![
                    file_id,
                    span.range.start.to_string(),
                    span.range.end.to_string(),
                    span.name,
                    span.embedding,
                    span.digest
                ])?;
            }

            Ok(())
        })
    }
269
270    pub fn worktree_previously_indexed(
271        &self,
272        worktree_root_path: &Path,
273    ) -> impl Future<Output = Result<bool>> {
274        let worktree_root_path = worktree_root_path.to_string_lossy().into_owned();
275        self.transact(move |db| {
276            let mut worktree_query =
277                db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
278            let worktree_id =
279                worktree_query.query_row(params![worktree_root_path], |row| row.get::<_, i64>(0));
280
281            Ok(worktree_id.is_ok())
282        })
283    }
284
285    pub fn embeddings_for_digests(
286        &self,
287        digests: Vec<SpanDigest>,
288    ) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
289        self.transact(move |db| {
290            let mut query = db.prepare(
291                "
292                SELECT digest, embedding
293                FROM spans
294                WHERE digest IN rarray(?)
295                ",
296            )?;
297            let mut embeddings_by_digest = HashMap::default();
298            let digests = Rc::new(
299                digests
300                    .into_iter()
301                    .map(|digest| Value::Blob(digest.0.to_vec()))
302                    .collect::<Vec<_>>(),
303            );
304            let rows = query.query_map(params![digests], |row| {
305                Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
306            })?;
307
308            for (digest, embedding) in rows.flatten() {
309                embeddings_by_digest.insert(digest, embedding);
310            }
311
312            Ok(embeddings_by_digest)
313        })
314    }
315
    /// Fetches cached (digest, embedding) pairs for all spans belonging to
    /// the given files, grouped per worktree.
    pub fn embeddings_for_files(
        &self,
        worktree_id_file_paths: HashMap<i64, Vec<Arc<Path>>>,
    ) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
        self.transact(move |db| {
            let mut query = db.prepare(
                "
                SELECT digest, embedding
                FROM spans
                LEFT JOIN files ON files.id = spans.file_id
                WHERE files.worktree_id = ? AND files.relative_path IN rarray(?)
            ",
            )?;
            let mut embeddings_by_digest = HashMap::default();
            // Reuse the one prepared statement, binding each worktree's paths
            // in turn as a text array for `rarray`.
            for (worktree_id, file_paths) in worktree_id_file_paths {
                let file_paths = Rc::new(
                    file_paths
                        .into_iter()
                        .map(|p| Value::Text(p.to_string_lossy().into_owned()))
                        .collect::<Vec<_>>(),
                );
                let rows = query.query_map(params![worktree_id, file_paths], |row| {
                    Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
                })?;

                // `flatten` silently skips rows that fail to deserialize.
                for (digest, embedding) in rows.flatten() {
                    embeddings_by_digest.insert(digest, embedding);
                }
            }

            Ok(embeddings_by_digest)
        })
    }
349
350    pub fn find_or_create_worktree(
351        &self,
352        worktree_root_path: Arc<Path>,
353    ) -> impl Future<Output = Result<i64>> {
354        self.transact(move |db| {
355            let mut worktree_query =
356                db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
357            let worktree_id = worktree_query
358                .query_row(params![worktree_root_path.to_string_lossy()], |row| {
359                    row.get::<_, i64>(0)
360                });
361
362            if worktree_id.is_ok() {
363                return Ok(worktree_id?);
364            }
365
366            // If worktree_id is Err, insert new worktree
367            db.execute(
368                "INSERT into worktrees (absolute_path) VALUES (?1)",
369                params![worktree_root_path.to_string_lossy()],
370            )?;
371            Ok(db.last_insert_rowid())
372        })
373    }
374
375    pub fn get_file_mtimes(
376        &self,
377        worktree_id: i64,
378    ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
379        self.transact(move |db| {
380            let mut statement = db.prepare(
381                "
382                SELECT relative_path, mtime_seconds, mtime_nanos
383                FROM files
384                WHERE worktree_id = ?1
385                ORDER BY relative_path",
386            )?;
387            let mut result: HashMap<PathBuf, SystemTime> = HashMap::default();
388            for row in statement.query_map(params![worktree_id], |row| {
389                Ok((
390                    row.get::<_, String>(0)?.into(),
391                    Timestamp {
392                        seconds: row.get(1)?,
393                        nanos: row.get(2)?,
394                    }
395                    .into(),
396                ))
397            })? {
398                let row = row?;
399                result.insert(row.0, row.1);
400            }
401            Ok(result)
402        })
403    }
404
405    pub fn top_k_search(
406        &self,
407        query_embedding: &Embedding,
408        limit: usize,
409        file_ids: &[i64],
410    ) -> impl Future<Output = Result<Vec<(i64, OrderedFloat<f32>)>>> {
411        let file_ids = file_ids.to_vec();
412        let query = query_embedding.clone().0;
413        let query = Array1::from_vec(query);
414        self.transact(move |db| {
415            let mut query_statement = db.prepare(
416                "
417                    SELECT
418                        id, embedding
419                    FROM
420                        spans
421                    WHERE
422                        file_id IN rarray(?)
423                    ",
424            )?;
425
426            let deserialized_rows = query_statement
427                .query_map(params![ids_to_sql(&file_ids)], |row| {
428                    Ok((row.get::<_, usize>(0)?, row.get::<_, Embedding>(1)?))
429                })?
430                .filter_map(|row| row.ok())
431                .collect::<Vec<(usize, Embedding)>>();
432
433            if deserialized_rows.len() == 0 {
434                return Ok(Vec::new());
435            }
436
437            // Get Length of Embeddings Returned
438            let embedding_len = deserialized_rows[0].1 .0.len();
439
440            let batch_n = 1000;
441            let mut batches = Vec::new();
442            let mut batch_ids = Vec::new();
443            let mut batch_embeddings: Vec<f32> = Vec::new();
444            deserialized_rows.iter().for_each(|(id, embedding)| {
445                batch_ids.push(id);
446                batch_embeddings.extend(&embedding.0);
447
448                if batch_ids.len() == batch_n {
449                    let embeddings = std::mem::take(&mut batch_embeddings);
450                    let ids = std::mem::take(&mut batch_ids);
451                    let array = Array2::from_shape_vec((ids.len(), embedding_len), embeddings);
452                    match array {
453                        Ok(array) => {
454                            batches.push((ids, array));
455                        }
456                        Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
457                    }
458                }
459            });
460
461            if batch_ids.len() > 0 {
462                let array = Array2::from_shape_vec(
463                    (batch_ids.len(), embedding_len),
464                    batch_embeddings.clone(),
465                );
466                match array {
467                    Ok(array) => {
468                        batches.push((batch_ids.clone(), array));
469                    }
470                    Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
471                }
472            }
473
474            let mut ids: Vec<usize> = Vec::new();
475            let mut results = Vec::new();
476            for (batch_ids, array) in batches {
477                let scores = array
478                    .dot(&query.t())
479                    .to_vec()
480                    .iter()
481                    .map(|score| OrderedFloat(*score))
482                    .collect::<Vec<OrderedFloat<f32>>>();
483                results.extend(scores);
484                ids.extend(batch_ids);
485            }
486
487            let sorted_idx = argsort(&results);
488            let mut sorted_results = Vec::new();
489            let last_idx = limit.min(sorted_idx.len());
490            for idx in &sorted_idx[0..last_idx] {
491                sorted_results.push((ids[*idx] as i64, results[*idx]))
492            }
493
494            Ok(sorted_results)
495        })
496    }
497
498    pub fn retrieve_included_file_ids(
499        &self,
500        worktree_ids: &[i64],
501        includes: &[PathMatcher],
502        excludes: &[PathMatcher],
503    ) -> impl Future<Output = Result<Vec<i64>>> {
504        let worktree_ids = worktree_ids.to_vec();
505        let includes = includes.to_vec();
506        let excludes = excludes.to_vec();
507        self.transact(move |db| {
508            let mut file_query = db.prepare(
509                "
510                SELECT
511                    id, relative_path
512                FROM
513                    files
514                WHERE
515                    worktree_id IN rarray(?)
516                ",
517            )?;
518
519            let mut file_ids = Vec::<i64>::new();
520            let mut rows = file_query.query([ids_to_sql(&worktree_ids)])?;
521
522            while let Some(row) = rows.next()? {
523                let file_id = row.get(0)?;
524                let relative_path = row.get_ref(1)?.as_str()?;
525                let included =
526                    includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path));
527                let excluded = excludes.iter().any(|glob| glob.is_match(relative_path));
528                if included && !excluded {
529                    file_ids.push(file_id);
530                }
531            }
532
533            anyhow::Ok(file_ids)
534        })
535    }
536
    /// Resolves span ids to `(worktree_id, relative_path, byte_range)`
    /// tuples, returned in the same order as `ids`.
    ///
    /// Fails if any requested id has no row in the database.
    pub fn spans_for_ids(
        &self,
        ids: &[i64],
    ) -> impl Future<Output = Result<Vec<(i64, PathBuf, Range<usize>)>>> {
        let ids = ids.to_vec();
        self.transact(move |db| {
            let mut statement = db.prepare(
                "
                    SELECT
                        spans.id,
                        files.worktree_id,
                        files.relative_path,
                        spans.start_byte,
                        spans.end_byte
                    FROM
                        spans, files
                    WHERE
                        spans.file_id = files.id AND
                        spans.id in rarray(?)
                ",
            )?;

            let result_iter = statement.query_map(params![ids_to_sql(&ids)], |row| {
                Ok((
                    row.get::<_, i64>(0)?,
                    row.get::<_, i64>(1)?,
                    row.get::<_, String>(2)?.into(),
                    row.get(3)?..row.get(4)?,
                ))
            })?;

            // Collect into a map first: sqlite returns rows in its own order,
            // not in the order of `ids`.
            let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
            for row in result_iter {
                let (id, worktree_id, path, range) = row?;
                values_by_id.insert(id, (worktree_id, path, range));
            }

            // Re-emit in request order, erroring on any id with no row.
            let mut results = Vec::with_capacity(ids.len());
            for id in &ids {
                let value = values_by_id
                    .remove(id)
                    .ok_or(anyhow!("missing span id {}", id))?;
                results.push(value);
            }

            Ok(results)
        })
    }
585}
586
587fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
588    Rc::new(
589        ids.iter()
590            .copied()
591            .map(|v| rusqlite::types::Value::from(v))
592            .collect::<Vec<_>>(),
593    )
594}