// db.rs

  1extern crate blas_src;
  2
  3use crate::{
  4    parsing::{Span, SpanDigest},
  5    SEMANTIC_INDEX_VERSION,
  6};
  7use ai::embedding::Embedding;
  8use anyhow::{anyhow, Context, Result};
  9use collections::HashMap;
 10use futures::channel::oneshot;
 11use gpui::executor;
 12use ndarray::{Array1, Array2};
 13use ordered_float::OrderedFloat;
 14use project::{search::PathMatcher, Fs};
 15use rpc::proto::Timestamp;
 16use rusqlite::params;
 17use rusqlite::types::Value;
 18use std::{
 19    future::Future,
 20    ops::Range,
 21    path::{Path, PathBuf},
 22    rc::Rc,
 23    sync::Arc,
 24    time::SystemTime,
 25};
 26use util::TryFutureExt;
 27
 28pub fn argsort<T: Ord>(data: &[T]) -> Vec<usize> {
 29    let mut indices = (0..data.len()).collect::<Vec<_>>();
 30    indices.sort_by_key(|&i| &data[i]);
 31    indices.reverse();
 32    indices
 33}
 34
/// A row read back from the `files` table.
#[derive(Debug)]
pub struct FileRecord {
    // Rowid of the file in the `files` table — presumably `files.id`; verify
    // against the query that populates this struct (not visible here).
    pub id: usize,
    // Path of the file relative to its worktree root (`files.relative_path`).
    pub relative_path: String,
    // Modification time recorded when the file was indexed
    // (`files.mtime_seconds` / `files.mtime_nanos`).
    pub mtime: Timestamp,
}
 41
/// Cloneable handle to the semantic-index sqlite database.
///
/// All clones share a single background task (spawned in [`VectorDatabase::new`])
/// that owns the one `rusqlite::Connection`; database work is shipped to that
/// task as boxed closures over the `transactions` channel, serializing all
/// access to the connection.
#[derive(Clone)]
pub struct VectorDatabase {
    // Filesystem path of the sqlite database file.
    path: Arc<Path>,
    // Sender half of the channel feeding closures to the connection task.
    transactions:
        smol::channel::Sender<Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>>,
}
 48
 49impl VectorDatabase {
 50    pub async fn new(
 51        fs: Arc<dyn Fs>,
 52        path: Arc<Path>,
 53        executor: Arc<executor::Background>,
 54    ) -> Result<Self> {
 55        if let Some(db_directory) = path.parent() {
 56            fs.create_dir(db_directory).await?;
 57        }
 58
 59        let (transactions_tx, transactions_rx) = smol::channel::unbounded::<
 60            Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>,
 61        >();
 62        executor
 63            .spawn({
 64                let path = path.clone();
 65                async move {
 66                    let mut connection = rusqlite::Connection::open(&path)?;
 67
 68                    connection.pragma_update(None, "journal_mode", "wal")?;
 69                    connection.pragma_update(None, "synchronous", "normal")?;
 70                    connection.pragma_update(None, "cache_size", 1000000)?;
 71                    connection.pragma_update(None, "temp_store", "MEMORY")?;
 72
 73                    while let Ok(transaction) = transactions_rx.recv().await {
 74                        transaction(&mut connection);
 75                    }
 76
 77                    anyhow::Ok(())
 78                }
 79                .log_err()
 80            })
 81            .detach();
 82        let this = Self {
 83            transactions: transactions_tx,
 84            path,
 85        };
 86        this.initialize_database().await?;
 87        Ok(this)
 88    }
 89
 90    pub fn path(&self) -> &Arc<Path> {
 91        &self.path
 92    }
 93
 94    fn transact<F, T>(&self, f: F) -> impl Future<Output = Result<T>>
 95    where
 96        F: 'static + Send + FnOnce(&rusqlite::Transaction) -> Result<T>,
 97        T: 'static + Send,
 98    {
 99        let (tx, rx) = oneshot::channel();
100        let transactions = self.transactions.clone();
101        async move {
102            if transactions
103                .send(Box::new(|connection| {
104                    let result = connection
105                        .transaction()
106                        .map_err(|err| anyhow!(err))
107                        .and_then(|transaction| {
108                            let result = f(&transaction)?;
109                            transaction.commit()?;
110                            Ok(result)
111                        });
112                    let _ = tx.send(result);
113                }))
114                .await
115                .is_err()
116            {
117                return Err(anyhow!("connection was dropped"))?;
118            }
119            rx.await?
120        }
121    }
122
123    fn initialize_database(&self) -> impl Future<Output = Result<()>> {
124        self.transact(|db| {
125            rusqlite::vtab::array::load_module(&db)?;
126
127            // Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped
128            let version_query = db.prepare("SELECT version from semantic_index_config");
129            let version = version_query
130                .and_then(|mut query| query.query_row([], |row| Ok(row.get::<_, i64>(0)?)));
131            if version.map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64) {
132                log::trace!("vector database schema up to date");
133                return Ok(());
134            }
135
136            log::trace!("vector database schema out of date. updating...");
137            // We renamed the `documents` table to `spans`, so we want to drop
138            // `documents` without recreating it if it exists.
139            db.execute("DROP TABLE IF EXISTS documents", [])
140                .context("failed to drop 'documents' table")?;
141            db.execute("DROP TABLE IF EXISTS spans", [])
142                .context("failed to drop 'spans' table")?;
143            db.execute("DROP TABLE IF EXISTS files", [])
144                .context("failed to drop 'files' table")?;
145            db.execute("DROP TABLE IF EXISTS worktrees", [])
146                .context("failed to drop 'worktrees' table")?;
147            db.execute("DROP TABLE IF EXISTS semantic_index_config", [])
148                .context("failed to drop 'semantic_index_config' table")?;
149
150            // Initialize Vector Databasing Tables
151            db.execute(
152                "CREATE TABLE semantic_index_config (
153                    version INTEGER NOT NULL
154                )",
155                [],
156            )?;
157
158            db.execute(
159                "INSERT INTO semantic_index_config (version) VALUES (?1)",
160                params![SEMANTIC_INDEX_VERSION],
161            )?;
162
163            db.execute(
164                "CREATE TABLE worktrees (
165                    id INTEGER PRIMARY KEY AUTOINCREMENT,
166                    absolute_path VARCHAR NOT NULL
167                );
168                CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
169                ",
170                [],
171            )?;
172
173            db.execute(
174                "CREATE TABLE files (
175                    id INTEGER PRIMARY KEY AUTOINCREMENT,
176                    worktree_id INTEGER NOT NULL,
177                    relative_path VARCHAR NOT NULL,
178                    mtime_seconds INTEGER NOT NULL,
179                    mtime_nanos INTEGER NOT NULL,
180                    FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
181                )",
182                [],
183            )?;
184
185            db.execute(
186                "CREATE UNIQUE INDEX files_worktree_id_and_relative_path ON files (worktree_id, relative_path)",
187                [],
188            )?;
189
190            db.execute(
191                "CREATE TABLE spans (
192                    id INTEGER PRIMARY KEY AUTOINCREMENT,
193                    file_id INTEGER NOT NULL,
194                    start_byte INTEGER NOT NULL,
195                    end_byte INTEGER NOT NULL,
196                    name VARCHAR NOT NULL,
197                    embedding BLOB NOT NULL,
198                    digest BLOB NOT NULL,
199                    FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
200                )",
201                [],
202            )?;
203            db.execute(
204                "CREATE INDEX spans_digest ON spans (digest)",
205                [],
206            )?;
207
208            log::trace!("vector database initialized with updated schema.");
209            Ok(())
210        })
211    }
212
213    pub fn delete_file(
214        &self,
215        worktree_id: i64,
216        delete_path: Arc<Path>,
217    ) -> impl Future<Output = Result<()>> {
218        self.transact(move |db| {
219            db.execute(
220                "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
221                params![worktree_id, delete_path.to_str()],
222            )?;
223            Ok(())
224        })
225    }
226
227    pub fn insert_file(
228        &self,
229        worktree_id: i64,
230        path: Arc<Path>,
231        mtime: SystemTime,
232        spans: Vec<Span>,
233    ) -> impl Future<Output = Result<()>> {
234        self.transact(move |db| {
235            // Return the existing ID, if both the file and mtime match
236            let mtime = Timestamp::from(mtime);
237
238            db.execute(
239                "
240                REPLACE INTO files
241                (worktree_id, relative_path, mtime_seconds, mtime_nanos)
242                VALUES (?1, ?2, ?3, ?4)
243                ",
244                params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
245            )?;
246
247            let file_id = db.last_insert_rowid();
248
249            let mut query = db.prepare(
250                "
251                INSERT INTO spans
252                (file_id, start_byte, end_byte, name, embedding, digest)
253                VALUES (?1, ?2, ?3, ?4, ?5, ?6)
254                ",
255            )?;
256
257            for span in spans {
258                query.execute(params![
259                    file_id,
260                    span.range.start.to_string(),
261                    span.range.end.to_string(),
262                    span.name,
263                    span.embedding,
264                    span.digest
265                ])?;
266            }
267
268            Ok(())
269        })
270    }
271
272    pub fn worktree_previously_indexed(
273        &self,
274        worktree_root_path: &Path,
275    ) -> impl Future<Output = Result<bool>> {
276        let worktree_root_path = worktree_root_path.to_string_lossy().into_owned();
277        self.transact(move |db| {
278            let mut worktree_query =
279                db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
280            let worktree_id = worktree_query
281                .query_row(params![worktree_root_path], |row| Ok(row.get::<_, i64>(0)?));
282
283            if worktree_id.is_ok() {
284                return Ok(true);
285            } else {
286                return Ok(false);
287            }
288        })
289    }
290
291    pub fn embeddings_for_digests(
292        &self,
293        digests: Vec<SpanDigest>,
294    ) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
295        self.transact(move |db| {
296            let mut query = db.prepare(
297                "
298                SELECT digest, embedding
299                FROM spans
300                WHERE digest IN rarray(?)
301                ",
302            )?;
303            let mut embeddings_by_digest = HashMap::default();
304            let digests = Rc::new(
305                digests
306                    .into_iter()
307                    .map(|p| Value::Blob(p.0.to_vec()))
308                    .collect::<Vec<_>>(),
309            );
310            let rows = query.query_map(params![digests], |row| {
311                Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
312            })?;
313
314            for row in rows {
315                if let Ok(row) = row {
316                    embeddings_by_digest.insert(row.0, row.1);
317                }
318            }
319
320            Ok(embeddings_by_digest)
321        })
322    }
323
324    pub fn embeddings_for_files(
325        &self,
326        worktree_id_file_paths: HashMap<i64, Vec<Arc<Path>>>,
327    ) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
328        self.transact(move |db| {
329            let mut query = db.prepare(
330                "
331                SELECT digest, embedding
332                FROM spans
333                LEFT JOIN files ON files.id = spans.file_id
334                WHERE files.worktree_id = ? AND files.relative_path IN rarray(?)
335            ",
336            )?;
337            let mut embeddings_by_digest = HashMap::default();
338            for (worktree_id, file_paths) in worktree_id_file_paths {
339                let file_paths = Rc::new(
340                    file_paths
341                        .into_iter()
342                        .map(|p| Value::Text(p.to_string_lossy().into_owned()))
343                        .collect::<Vec<_>>(),
344                );
345                let rows = query.query_map(params![worktree_id, file_paths], |row| {
346                    Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
347                })?;
348
349                for row in rows {
350                    if let Ok(row) = row {
351                        embeddings_by_digest.insert(row.0, row.1);
352                    }
353                }
354            }
355
356            Ok(embeddings_by_digest)
357        })
358    }
359
360    pub fn find_or_create_worktree(
361        &self,
362        worktree_root_path: Arc<Path>,
363    ) -> impl Future<Output = Result<i64>> {
364        self.transact(move |db| {
365            let mut worktree_query =
366                db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
367            let worktree_id = worktree_query
368                .query_row(params![worktree_root_path.to_string_lossy()], |row| {
369                    Ok(row.get::<_, i64>(0)?)
370                });
371
372            if worktree_id.is_ok() {
373                return Ok(worktree_id?);
374            }
375
376            // If worktree_id is Err, insert new worktree
377            db.execute(
378                "INSERT into worktrees (absolute_path) VALUES (?1)",
379                params![worktree_root_path.to_string_lossy()],
380            )?;
381            Ok(db.last_insert_rowid())
382        })
383    }
384
385    pub fn get_file_mtimes(
386        &self,
387        worktree_id: i64,
388    ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
389        self.transact(move |db| {
390            let mut statement = db.prepare(
391                "
392                SELECT relative_path, mtime_seconds, mtime_nanos
393                FROM files
394                WHERE worktree_id = ?1
395                ORDER BY relative_path",
396            )?;
397            let mut result: HashMap<PathBuf, SystemTime> = HashMap::default();
398            for row in statement.query_map(params![worktree_id], |row| {
399                Ok((
400                    row.get::<_, String>(0)?.into(),
401                    Timestamp {
402                        seconds: row.get(1)?,
403                        nanos: row.get(2)?,
404                    }
405                    .into(),
406                ))
407            })? {
408                let row = row?;
409                result.insert(row.0, row.1);
410            }
411            Ok(result)
412        })
413    }
414
415    pub fn top_k_search(
416        &self,
417        query_embedding: &Embedding,
418        limit: usize,
419        file_ids: &[i64],
420    ) -> impl Future<Output = Result<Vec<(i64, OrderedFloat<f32>)>>> {
421        let file_ids = file_ids.to_vec();
422        let query = query_embedding.clone().0;
423        let query = Array1::from_vec(query);
424        self.transact(move |db| {
425            let mut query_statement = db.prepare(
426                "
427                    SELECT
428                        id, embedding
429                    FROM
430                        spans
431                    WHERE
432                        file_id IN rarray(?)
433                    ",
434            )?;
435
436            let deserialized_rows = query_statement
437                .query_map(params![ids_to_sql(&file_ids)], |row| {
438                    Ok((row.get::<_, usize>(0)?, row.get::<_, Embedding>(1)?))
439                })?
440                .filter_map(|row| row.ok())
441                .collect::<Vec<(usize, Embedding)>>();
442
443            let batch_n = 250;
444            let mut batches = Vec::new();
445            let mut batch_ids = Vec::new();
446            let mut batch_embeddings: Vec<f32> = Vec::new();
447            deserialized_rows.iter().for_each(|(id, embedding)| {
448                batch_ids.push(id);
449                batch_embeddings.extend(&embedding.0);
450                if batch_ids.len() == batch_n {
451                    let array =
452                        Array2::from_shape_vec((batch_ids.len(), 1536), batch_embeddings.clone());
453                    match array {
454                        Ok(array) => {
455                            batches.push((batch_ids.clone(), array));
456                        }
457                        Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
458                    }
459
460                    batch_ids = Vec::new();
461                    batch_embeddings = Vec::new();
462                }
463            });
464
465            if batch_ids.len() > 0 {
466                let array =
467                    Array2::from_shape_vec((batch_ids.len(), 1536), batch_embeddings.clone());
468                match array {
469                    Ok(array) => {
470                        batches.push((batch_ids.clone(), array));
471                    }
472                    Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
473                }
474            }
475
476            let mut ids: Vec<usize> = Vec::new();
477            let mut results = Vec::new();
478            for (batch_ids, array) in batches {
479                let scores = array
480                    .dot(&query.t())
481                    .to_vec()
482                    .iter()
483                    .map(|score| OrderedFloat(*score))
484                    .collect::<Vec<OrderedFloat<f32>>>();
485                results.extend(scores);
486                ids.extend(batch_ids);
487            }
488
489            let sorted_idx = argsort(&results);
490            let mut sorted_results = Vec::new();
491            let last_idx = limit.min(sorted_idx.len());
492            for idx in &sorted_idx[0..last_idx] {
493                sorted_results.push((ids[*idx] as i64, results[*idx]))
494            }
495
496            Ok(sorted_results)
497        })
498    }
499
500    pub fn retrieve_included_file_ids(
501        &self,
502        worktree_ids: &[i64],
503        includes: &[PathMatcher],
504        excludes: &[PathMatcher],
505    ) -> impl Future<Output = Result<Vec<i64>>> {
506        let worktree_ids = worktree_ids.to_vec();
507        let includes = includes.to_vec();
508        let excludes = excludes.to_vec();
509        self.transact(move |db| {
510            let mut file_query = db.prepare(
511                "
512                SELECT
513                    id, relative_path
514                FROM
515                    files
516                WHERE
517                    worktree_id IN rarray(?)
518                ",
519            )?;
520
521            let mut file_ids = Vec::<i64>::new();
522            let mut rows = file_query.query([ids_to_sql(&worktree_ids)])?;
523
524            while let Some(row) = rows.next()? {
525                let file_id = row.get(0)?;
526                let relative_path = row.get_ref(1)?.as_str()?;
527                let included =
528                    includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path));
529                let excluded = excludes.iter().any(|glob| glob.is_match(relative_path));
530                if included && !excluded {
531                    file_ids.push(file_id);
532                }
533            }
534
535            anyhow::Ok(file_ids)
536        })
537    }
538
539    pub fn spans_for_ids(
540        &self,
541        ids: &[i64],
542    ) -> impl Future<Output = Result<Vec<(i64, PathBuf, Range<usize>)>>> {
543        let ids = ids.to_vec();
544        self.transact(move |db| {
545            let mut statement = db.prepare(
546                "
547                    SELECT
548                        spans.id,
549                        files.worktree_id,
550                        files.relative_path,
551                        spans.start_byte,
552                        spans.end_byte
553                    FROM
554                        spans, files
555                    WHERE
556                        spans.file_id = files.id AND
557                        spans.id in rarray(?)
558                ",
559            )?;
560
561            let result_iter = statement.query_map(params![ids_to_sql(&ids)], |row| {
562                Ok((
563                    row.get::<_, i64>(0)?,
564                    row.get::<_, i64>(1)?,
565                    row.get::<_, String>(2)?.into(),
566                    row.get(3)?..row.get(4)?,
567                ))
568            })?;
569
570            let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
571            for row in result_iter {
572                let (id, worktree_id, path, range) = row?;
573                values_by_id.insert(id, (worktree_id, path, range));
574            }
575
576            let mut results = Vec::with_capacity(ids.len());
577            for id in &ids {
578                let value = values_by_id
579                    .remove(id)
580                    .ok_or(anyhow!("missing span id {}", id))?;
581                results.push(value);
582            }
583
584            Ok(results)
585        })
586    }
587}
588
589fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
590    Rc::new(
591        ids.iter()
592            .copied()
593            .map(|v| rusqlite::types::Value::from(v))
594            .collect::<Vec<_>>(),
595    )
596}