// db.rs

use crate::{parsing::Document, SEMANTIC_INDEX_VERSION};
use anyhow::{anyhow, Context, Result};
use project::{search::PathMatcher, Fs};
use rpc::proto::Timestamp;
use rusqlite::{
    params,
    types::{FromSql, FromSqlResult, ValueRef},
    OptionalExtension,
};
use std::{
    cmp::Ordering,
    collections::HashMap,
    ops::Range,
    path::{Path, PathBuf},
    rc::Rc,
    sync::Arc,
    time::SystemTime,
};

/// An indexed file: an id, a relative path, and a modification time.
///
/// NOTE(review): not constructed anywhere in this file; presumably mirrors a
/// row of the `files` table (`id`, `relative_path`, and the
/// `mtime_seconds`/`mtime_nanos` pair folded into one `Timestamp`) — confirm
/// against callers.
#[derive(Debug)]
pub struct FileRecord {
    pub id: usize,
    pub relative_path: String,
    pub mtime: Timestamp,
}

 26#[derive(Debug)]
 27struct Embedding(pub Vec<f32>);
 28
 29impl FromSql for Embedding {
 30    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
 31        let bytes = value.as_blob()?;
 32        let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
 33        if embedding.is_err() {
 34            return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
 35        }
 36        return Ok(Embedding(embedding.unwrap()));
 37    }
 38}
 39
 40pub struct VectorDatabase {
 41    db: rusqlite::Connection,
 42}
 43
 44impl VectorDatabase {
 45    pub async fn new(fs: Arc<dyn Fs>, path: Arc<PathBuf>) -> Result<Self> {
 46        if let Some(db_directory) = path.parent() {
 47            fs.create_dir(db_directory).await?;
 48        }
 49
 50        let this = Self {
 51            db: rusqlite::Connection::open(path.as_path())?,
 52        };
 53        this.initialize_database()?;
 54        Ok(this)
 55    }
 56
 57    fn get_existing_version(&self) -> Result<i64> {
 58        let mut version_query = self
 59            .db
 60            .prepare("SELECT version from semantic_index_config")?;
 61        version_query
 62            .query_row([], |row| Ok(row.get::<_, i64>(0)?))
 63            .map_err(|err| anyhow!("version query failed: {err}"))
 64    }
 65
 66    fn initialize_database(&self) -> Result<()> {
 67        rusqlite::vtab::array::load_module(&self.db)?;
 68
 69        // Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped
 70        if self
 71            .get_existing_version()
 72            .map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64)
 73        {
 74            log::trace!("vector database schema up to date");
 75            return Ok(());
 76        }
 77
 78        log::trace!("vector database schema out of date. updating...");
 79        self.db
 80            .execute("DROP TABLE IF EXISTS documents", [])
 81            .context("failed to drop 'documents' table")?;
 82        self.db
 83            .execute("DROP TABLE IF EXISTS files", [])
 84            .context("failed to drop 'files' table")?;
 85        self.db
 86            .execute("DROP TABLE IF EXISTS worktrees", [])
 87            .context("failed to drop 'worktrees' table")?;
 88        self.db
 89            .execute("DROP TABLE IF EXISTS semantic_index_config", [])
 90            .context("failed to drop 'semantic_index_config' table")?;
 91
 92        // Initialize Vector Databasing Tables
 93        self.db.execute(
 94            "CREATE TABLE semantic_index_config (
 95                version INTEGER NOT NULL
 96            )",
 97            [],
 98        )?;
 99
100        self.db.execute(
101            "INSERT INTO semantic_index_config (version) VALUES (?1)",
102            params![SEMANTIC_INDEX_VERSION],
103        )?;
104
105        self.db.execute(
106            "CREATE TABLE worktrees (
107                id INTEGER PRIMARY KEY AUTOINCREMENT,
108                absolute_path VARCHAR NOT NULL
109            );
110            CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
111            ",
112            [],
113        )?;
114
115        self.db.execute(
116            "CREATE TABLE files (
117                id INTEGER PRIMARY KEY AUTOINCREMENT,
118                worktree_id INTEGER NOT NULL,
119                relative_path VARCHAR NOT NULL,
120                mtime_seconds INTEGER NOT NULL,
121                mtime_nanos INTEGER NOT NULL,
122                FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
123            )",
124            [],
125        )?;
126
127        self.db.execute(
128            "CREATE TABLE documents (
129                id INTEGER PRIMARY KEY AUTOINCREMENT,
130                file_id INTEGER NOT NULL,
131                start_byte INTEGER NOT NULL,
132                end_byte INTEGER NOT NULL,
133                name VARCHAR NOT NULL,
134                embedding BLOB NOT NULL,
135                FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
136            )",
137            [],
138        )?;
139
140        log::trace!("vector database initialized with updated schema.");
141        Ok(())
142    }
143
144    pub fn delete_file(&self, worktree_id: i64, delete_path: PathBuf) -> Result<()> {
145        self.db.execute(
146            "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
147            params![worktree_id, delete_path.to_str()],
148        )?;
149        Ok(())
150    }
151
152    pub fn insert_file(
153        &self,
154        worktree_id: i64,
155        path: PathBuf,
156        mtime: SystemTime,
157        documents: Vec<Document>,
158    ) -> Result<()> {
159        // Write to files table, and return generated id.
160        self.db.execute(
161            "
162            DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;
163            ",
164            params![worktree_id, path.to_str()],
165        )?;
166        let mtime = Timestamp::from(mtime);
167        self.db.execute(
168            "
169            INSERT INTO files
170            (worktree_id, relative_path, mtime_seconds, mtime_nanos)
171            VALUES
172            (?1, ?2, $3, $4);
173            ",
174            params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
175        )?;
176
177        let file_id = self.db.last_insert_rowid();
178
179        // Currently inserting at approximately 3400 documents a second
180        // I imagine we can speed this up with a bulk insert of some kind.
181        for document in documents {
182            let embedding_blob = bincode::serialize(&document.embedding)?;
183
184            self.db.execute(
185                "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding) VALUES (?1, ?2, ?3, ?4, $5)",
186                params![
187                    file_id,
188                    document.range.start.to_string(),
189                    document.range.end.to_string(),
190                    document.name,
191                    embedding_blob
192                ],
193            )?;
194        }
195
196        Ok(())
197    }
198
199    pub fn worktree_previously_indexed(&self, worktree_root_path: &Path) -> Result<bool> {
200        let mut worktree_query = self
201            .db
202            .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
203        let worktree_id = worktree_query
204            .query_row(params![worktree_root_path.to_string_lossy()], |row| {
205                Ok(row.get::<_, i64>(0)?)
206            })
207            .map_err(|err| anyhow!(err));
208
209        if worktree_id.is_ok() {
210            return Ok(true);
211        } else {
212            return Ok(false);
213        }
214    }
215
216    pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
217        // Check that the absolute path doesnt exist
218        let mut worktree_query = self
219            .db
220            .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
221
222        let worktree_id = worktree_query
223            .query_row(params![worktree_root_path.to_string_lossy()], |row| {
224                Ok(row.get::<_, i64>(0)?)
225            })
226            .map_err(|err| anyhow!(err));
227
228        if worktree_id.is_ok() {
229            return worktree_id;
230        }
231
232        // If worktree_id is Err, insert new worktree
233        self.db.execute(
234            "
235            INSERT into worktrees (absolute_path) VALUES (?1)
236            ",
237            params![worktree_root_path.to_string_lossy()],
238        )?;
239        Ok(self.db.last_insert_rowid())
240    }
241
242    pub fn get_file_mtimes(&self, worktree_id: i64) -> Result<HashMap<PathBuf, SystemTime>> {
243        let mut statement = self.db.prepare(
244            "
245            SELECT relative_path, mtime_seconds, mtime_nanos
246            FROM files
247            WHERE worktree_id = ?1
248            ORDER BY relative_path",
249        )?;
250        let mut result: HashMap<PathBuf, SystemTime> = HashMap::new();
251        for row in statement.query_map(params![worktree_id], |row| {
252            Ok((
253                row.get::<_, String>(0)?.into(),
254                Timestamp {
255                    seconds: row.get(1)?,
256                    nanos: row.get(2)?,
257                }
258                .into(),
259            ))
260        })? {
261            let row = row?;
262            result.insert(row.0, row.1);
263        }
264        Ok(result)
265    }
266
267    pub fn top_k_search(
268        &self,
269        query_embedding: &Vec<f32>,
270        limit: usize,
271        file_ids: &[i64],
272    ) -> Result<Vec<(i64, f32)>> {
273        let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
274        self.for_each_document(file_ids, |id, embedding| {
275            let similarity = dot(&embedding, &query_embedding);
276            let ix = match results
277                .binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal))
278            {
279                Ok(ix) => ix,
280                Err(ix) => ix,
281            };
282            results.insert(ix, (id, similarity));
283            results.truncate(limit);
284        })?;
285
286        Ok(results)
287    }
288
289    pub fn retrieve_included_file_ids(
290        &self,
291        worktree_ids: &[i64],
292        includes: &[PathMatcher],
293        excludes: &[PathMatcher],
294    ) -> Result<Vec<i64>> {
295        let mut file_query = self.db.prepare(
296            "
297            SELECT
298                id, relative_path
299            FROM
300                files
301            WHERE
302                worktree_id IN rarray(?)
303            ",
304        )?;
305
306        let mut file_ids = Vec::<i64>::new();
307        let mut rows = file_query.query([ids_to_sql(worktree_ids)])?;
308
309        while let Some(row) = rows.next()? {
310            let file_id = row.get(0)?;
311            let relative_path = row.get_ref(1)?.as_str()?;
312            let included =
313                includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path));
314            let excluded = excludes.iter().any(|glob| glob.is_match(relative_path));
315            if included && !excluded {
316                file_ids.push(file_id);
317            }
318        }
319
320        Ok(file_ids)
321    }
322
323    fn for_each_document(&self, file_ids: &[i64], mut f: impl FnMut(i64, Vec<f32>)) -> Result<()> {
324        let mut query_statement = self.db.prepare(
325            "
326            SELECT
327                id, embedding
328            FROM
329                documents
330            WHERE
331                file_id IN rarray(?)
332            ",
333        )?;
334
335        query_statement
336            .query_map(params![ids_to_sql(&file_ids)], |row| {
337                Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
338            })?
339            .filter_map(|row| row.ok())
340            .for_each(|(id, embedding)| f(id, embedding.0));
341        Ok(())
342    }
343
344    pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
345        let mut statement = self.db.prepare(
346            "
347                SELECT
348                    documents.id,
349                    files.worktree_id,
350                    files.relative_path,
351                    documents.start_byte,
352                    documents.end_byte
353                FROM
354                    documents, files
355                WHERE
356                    documents.file_id = files.id AND
357                    documents.id in rarray(?)
358            ",
359        )?;
360
361        let result_iter = statement.query_map(params![ids_to_sql(ids)], |row| {
362            Ok((
363                row.get::<_, i64>(0)?,
364                row.get::<_, i64>(1)?,
365                row.get::<_, String>(2)?.into(),
366                row.get(3)?..row.get(4)?,
367            ))
368        })?;
369
370        let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
371        for row in result_iter {
372            let (id, worktree_id, path, range) = row?;
373            values_by_id.insert(id, (worktree_id, path, range));
374        }
375
376        let mut results = Vec::with_capacity(ids.len());
377        for id in ids {
378            let value = values_by_id
379                .remove(id)
380                .ok_or(anyhow!("missing document id {}", id))?;
381            results.push(value);
382        }
383
384        Ok(results)
385    }
386}
387
388fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
389    Rc::new(
390        ids.iter()
391            .copied()
392            .map(|v| rusqlite::types::Value::from(v))
393            .collect::<Vec<_>>(),
394    )
395}
396
/// Returns the dot product of `vec_a` and `vec_b`.
///
/// Computed as a (1 × len) · (len × 1) matrix product with
/// `matrixmultiply::sgemm`.
///
/// # Panics
/// Panics if the two slices have different lengths.
pub(crate) fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    let len = vec_a.len();
    assert_eq!(len, vec_b.len());

    let mut result = 0.0;
    // SAFETY: `vec_a` is read as a 1×len matrix with row stride `len` and
    // column stride 1, so element (0, j) is at offset j < len. `vec_b` is read
    // as a len×1 matrix with row stride 1 and column stride `len`, so element
    // (i, 0) is at offset i < len. The assert above guarantees both slices
    // hold `len` elements. `result` is a valid, writable 1×1 output matrix.
    unsafe {
        matrixmultiply::sgemm(
            1,   // m: rows of A / C
            len, // k: shared dimension
            1,   // n: columns of B / C
            1.0, // alpha
            vec_a.as_ptr(),
            len as isize, // row stride of A
            1,            // column stride of A
            vec_b.as_ptr(),
            1,            // row stride of B
            len as isize, // column stride of B
            0.0,          // beta: prior contents of C are not accumulated
            &mut result as *mut f32,
            1, // row stride of C
            1, // column stride of C
        );
    }
    result
}