// db.rs

  1use crate::{parsing::Document, SEMANTIC_INDEX_VERSION};
  2use anyhow::{anyhow, Context, Result};
  3use globset::GlobMatcher;
  4use project::Fs;
  5use rpc::proto::Timestamp;
  6use rusqlite::{
  7    params,
  8    types::{FromSql, FromSqlResult, ValueRef},
  9};
 10use std::{
 11    cmp::Ordering,
 12    collections::HashMap,
 13    ops::Range,
 14    path::{Path, PathBuf},
 15    rc::Rc,
 16    sync::Arc,
 17    time::SystemTime,
 18};
 19
// A row of the `files` table: one indexed file within a worktree.
#[derive(Debug)]
pub struct FileRecord {
    // SQLite rowid of this row in the `files` table.
    pub id: usize,
    // Path of the file relative to its worktree root.
    pub relative_path: String,
    // Modification time recorded when the file was indexed, split into
    // `mtime_seconds` / `mtime_nanos` columns in the table.
    pub mtime: Timestamp,
}
 26
// Newtype around a raw embedding vector so it can be read straight out of a
// SQLite BLOB column via a `FromSql` impl.
#[derive(Debug)]
struct Embedding(pub Vec<f32>);
 29
 30impl FromSql for Embedding {
 31    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
 32        let bytes = value.as_blob()?;
 33        let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
 34        if embedding.is_err() {
 35            return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
 36        }
 37        return Ok(Embedding(embedding.unwrap()));
 38    }
 39}
 40
// Handle to the on-disk semantic-index database. All queries go through the
// single wrapped rusqlite connection.
pub struct VectorDatabase {
    db: rusqlite::Connection,
}
 44
 45impl VectorDatabase {
 46    pub async fn new(fs: Arc<dyn Fs>, path: Arc<PathBuf>) -> Result<Self> {
 47        if let Some(db_directory) = path.parent() {
 48            fs.create_dir(db_directory).await?;
 49        }
 50
 51        let this = Self {
 52            db: rusqlite::Connection::open(path.as_path())?,
 53        };
 54        this.initialize_database()?;
 55        Ok(this)
 56    }
 57
 58    fn get_existing_version(&self) -> Result<i64> {
 59        let mut version_query = self
 60            .db
 61            .prepare("SELECT version from semantic_index_config")?;
 62        version_query
 63            .query_row([], |row| Ok(row.get::<_, i64>(0)?))
 64            .map_err(|err| anyhow!("version query failed: {err}"))
 65    }
 66
 67    fn initialize_database(&self) -> Result<()> {
 68        rusqlite::vtab::array::load_module(&self.db)?;
 69
 70        // Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped
 71        if self
 72            .get_existing_version()
 73            .map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64)
 74        {
 75            log::trace!("vector database schema up to date");
 76            return Ok(());
 77        }
 78
 79        log::trace!("vector database schema out of date. updating...");
 80        self.db
 81            .execute("DROP TABLE IF EXISTS documents", [])
 82            .context("failed to drop 'documents' table")?;
 83        self.db
 84            .execute("DROP TABLE IF EXISTS files", [])
 85            .context("failed to drop 'files' table")?;
 86        self.db
 87            .execute("DROP TABLE IF EXISTS worktrees", [])
 88            .context("failed to drop 'worktrees' table")?;
 89        self.db
 90            .execute("DROP TABLE IF EXISTS semantic_index_config", [])
 91            .context("failed to drop 'semantic_index_config' table")?;
 92
 93        // Initialize Vector Databasing Tables
 94        self.db.execute(
 95            "CREATE TABLE semantic_index_config (
 96                version INTEGER NOT NULL
 97            )",
 98            [],
 99        )?;
100
101        self.db.execute(
102            "INSERT INTO semantic_index_config (version) VALUES (?1)",
103            params![SEMANTIC_INDEX_VERSION],
104        )?;
105
106        self.db.execute(
107            "CREATE TABLE worktrees (
108                id INTEGER PRIMARY KEY AUTOINCREMENT,
109                absolute_path VARCHAR NOT NULL
110            );
111            CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
112            ",
113            [],
114        )?;
115
116        self.db.execute(
117            "CREATE TABLE files (
118                id INTEGER PRIMARY KEY AUTOINCREMENT,
119                worktree_id INTEGER NOT NULL,
120                relative_path VARCHAR NOT NULL,
121                mtime_seconds INTEGER NOT NULL,
122                mtime_nanos INTEGER NOT NULL,
123                FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
124            )",
125            [],
126        )?;
127
128        self.db.execute(
129            "CREATE TABLE documents (
130                id INTEGER PRIMARY KEY AUTOINCREMENT,
131                file_id INTEGER NOT NULL,
132                start_byte INTEGER NOT NULL,
133                end_byte INTEGER NOT NULL,
134                name VARCHAR NOT NULL,
135                embedding BLOB NOT NULL,
136                FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
137            )",
138            [],
139        )?;
140
141        log::trace!("vector database initialized with updated schema.");
142        Ok(())
143    }
144
145    pub fn delete_file(&self, worktree_id: i64, delete_path: PathBuf) -> Result<()> {
146        self.db.execute(
147            "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
148            params![worktree_id, delete_path.to_str()],
149        )?;
150        Ok(())
151    }
152
153    pub fn insert_file(
154        &self,
155        worktree_id: i64,
156        path: PathBuf,
157        mtime: SystemTime,
158        documents: Vec<Document>,
159    ) -> Result<()> {
160        // Write to files table, and return generated id.
161        self.db.execute(
162            "
163            DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;
164            ",
165            params![worktree_id, path.to_str()],
166        )?;
167        let mtime = Timestamp::from(mtime);
168        self.db.execute(
169            "
170            INSERT INTO files
171            (worktree_id, relative_path, mtime_seconds, mtime_nanos)
172            VALUES
173            (?1, ?2, $3, $4);
174            ",
175            params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
176        )?;
177
178        let file_id = self.db.last_insert_rowid();
179
180        // Currently inserting at approximately 3400 documents a second
181        // I imagine we can speed this up with a bulk insert of some kind.
182        for document in documents {
183            let embedding_blob = bincode::serialize(&document.embedding)?;
184
185            self.db.execute(
186                "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding) VALUES (?1, ?2, ?3, ?4, $5)",
187                params![
188                    file_id,
189                    document.range.start.to_string(),
190                    document.range.end.to_string(),
191                    document.name,
192                    embedding_blob
193                ],
194            )?;
195        }
196
197        Ok(())
198    }
199
200    pub fn worktree_previously_indexed(&self, worktree_root_path: &Path) -> Result<bool> {
201        let mut worktree_query = self
202            .db
203            .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
204        let worktree_id = worktree_query
205            .query_row(params![worktree_root_path.to_string_lossy()], |row| {
206                Ok(row.get::<_, i64>(0)?)
207            })
208            .map_err(|err| anyhow!(err));
209
210        if worktree_id.is_ok() {
211            return Ok(true);
212        } else {
213            return Ok(false);
214        }
215    }
216
217    pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
218        // Check that the absolute path doesnt exist
219        let mut worktree_query = self
220            .db
221            .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
222
223        let worktree_id = worktree_query
224            .query_row(params![worktree_root_path.to_string_lossy()], |row| {
225                Ok(row.get::<_, i64>(0)?)
226            })
227            .map_err(|err| anyhow!(err));
228
229        if worktree_id.is_ok() {
230            return worktree_id;
231        }
232
233        // If worktree_id is Err, insert new worktree
234        self.db.execute(
235            "
236            INSERT into worktrees (absolute_path) VALUES (?1)
237            ",
238            params![worktree_root_path.to_string_lossy()],
239        )?;
240        Ok(self.db.last_insert_rowid())
241    }
242
243    pub fn get_file_mtimes(&self, worktree_id: i64) -> Result<HashMap<PathBuf, SystemTime>> {
244        let mut statement = self.db.prepare(
245            "
246            SELECT relative_path, mtime_seconds, mtime_nanos
247            FROM files
248            WHERE worktree_id = ?1
249            ORDER BY relative_path",
250        )?;
251        let mut result: HashMap<PathBuf, SystemTime> = HashMap::new();
252        for row in statement.query_map(params![worktree_id], |row| {
253            Ok((
254                row.get::<_, String>(0)?.into(),
255                Timestamp {
256                    seconds: row.get(1)?,
257                    nanos: row.get(2)?,
258                }
259                .into(),
260            ))
261        })? {
262            let row = row?;
263            result.insert(row.0, row.1);
264        }
265        Ok(result)
266    }
267
    /// Returns up to `limit` documents most similar to `query_embedding`,
    /// restricted to the given worktrees and include/exclude glob filters.
    ///
    /// Output tuples are `(worktree_id, relative_path, byte_range)`, ordered
    /// from most to least similar.
    pub fn top_k_search(
        &self,
        worktree_ids: &[i64],
        query_embedding: &Vec<f32>,
        limit: usize,
        include_globs: Vec<GlobMatcher>,
        exclude_globs: Vec<GlobMatcher>,
    ) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
        // Running top-`limit` list of (document id, similarity), kept sorted
        // by descending similarity. Capacity is limit + 1 because we insert
        // before truncating.
        let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
        self.for_each_document(
            &worktree_ids,
            include_globs,
            exclude_globs,
            |id, embedding| {
                let similarity = dot(&embedding, &query_embedding);
                // Note the comparator is "query similarity vs element"
                // (reversed from the usual element-vs-target), which makes
                // binary_search treat the vec as sorted DESCENDING. A NaN
                // comparison falls back to Equal.
                let ix = match results.binary_search_by(|(_, s)| {
                    similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
                }) {
                    Ok(ix) => ix,
                    Err(ix) => ix,
                };
                results.insert(ix, (id, similarity));
                // Drop whatever fell past the top `limit`.
                results.truncate(limit);
            },
        )?;

        // Resolve the winning ids back to paths/ranges, preserving the
        // similarity ordering of `results`.
        let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
        self.get_documents_by_ids(&ids)
    }
297
298    fn for_each_document(
299        &self,
300        worktree_ids: &[i64],
301        include_globs: Vec<GlobMatcher>,
302        exclude_globs: Vec<GlobMatcher>,
303        mut f: impl FnMut(i64, Vec<f32>),
304    ) -> Result<()> {
305        let mut file_query = self.db.prepare(
306            "
307            SELECT
308                id, relative_path
309            FROM
310                files
311            WHERE
312                worktree_id IN rarray(?)
313            ",
314        )?;
315
316        let mut file_ids = Vec::<i64>::new();
317        let mut rows = file_query.query([ids_to_sql(worktree_ids)])?;
318        while let Some(row) = rows.next()? {
319            let file_id = row.get(0)?;
320            let relative_path = row.get_ref(1)?.as_str()?;
321            let included = include_globs.is_empty()
322                || include_globs
323                    .iter()
324                    .any(|glob| glob.is_match(relative_path));
325            let excluded = exclude_globs
326                .iter()
327                .any(|glob| glob.is_match(relative_path));
328            if included && !excluded {
329                file_ids.push(file_id);
330            }
331        }
332
333        let mut query_statement = self.db.prepare(
334            "
335            SELECT
336                id, embedding
337            FROM
338                documents
339            WHERE
340                file_id IN rarray(?)
341            ",
342        )?;
343
344        query_statement
345            .query_map(params![ids_to_sql(&file_ids)], |row| {
346                Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
347            })?
348            .filter_map(|row| row.ok())
349            .for_each(|(id, embedding)| f(id, embedding.0));
350        Ok(())
351    }
352
353    fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
354        let mut statement = self.db.prepare(
355            "
356                SELECT
357                    documents.id,
358                    files.worktree_id,
359                    files.relative_path,
360                    documents.start_byte,
361                    documents.end_byte
362                FROM
363                    documents, files
364                WHERE
365                    documents.file_id = files.id AND
366                    documents.id in rarray(?)
367            ",
368        )?;
369
370        let result_iter = statement.query_map(params![ids_to_sql(ids)], |row| {
371            Ok((
372                row.get::<_, i64>(0)?,
373                row.get::<_, i64>(1)?,
374                row.get::<_, String>(2)?.into(),
375                row.get(3)?..row.get(4)?,
376            ))
377        })?;
378
379        let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
380        for row in result_iter {
381            let (id, worktree_id, path, range) = row?;
382            values_by_id.insert(id, (worktree_id, path, range));
383        }
384
385        let mut results = Vec::with_capacity(ids.len());
386        for id in ids {
387            let value = values_by_id
388                .remove(id)
389                .ok_or(anyhow!("missing document id {}", id))?;
390            results.push(value);
391        }
392
393        Ok(results)
394    }
395}
396
397fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
398    Rc::new(
399        ids.iter()
400            .copied()
401            .map(|v| rusqlite::types::Value::from(v))
402            .collect::<Vec<_>>(),
403    )
404}
405
/// Dot product of two equal-length `f32` slices, computed as a 1×len by
/// len×1 matrix product through `matrixmultiply::sgemm`.
///
/// # Panics
/// Panics if the slices have different lengths.
pub(crate) fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    let len = vec_a.len();
    assert_eq!(len, vec_b.len());

    let mut result = 0.0;
    // SAFETY: sgemm reads m*k = 1*len elements from `vec_a` and
    // k*n = len*1 elements from `vec_b`, and writes the single m*n = 1x1
    // output element through `&mut result`. The assert above guarantees both
    // slices hold `len` elements, so every access is in bounds.
    unsafe {
        matrixmultiply::sgemm(
            1,   // m: rows of A (and C)
            len, // k: cols of A / rows of B
            1,   // n: cols of B (and C)
            1.0, // alpha: scale applied to A*B
            vec_a.as_ptr(),
            len as isize, // row stride of A
            1,            // column stride of A
            vec_b.as_ptr(),
            1,            // row stride of B
            len as isize, // column stride of B
            0.0,          // beta: existing C contents are discarded
            &mut result as *mut f32,
            1, // row stride of C
            1, // column stride of C
        );
    }
    result
}