db.rs

use crate::{parsing::Document, SEMANTIC_INDEX_VERSION};
use anyhow::{anyhow, Context, Result};
use project::{search::PathMatcher, Fs};
use rpc::proto::Timestamp;
use rusqlite::{
    params,
    types::{FromSql, FromSqlResult, ValueRef},
};
use std::{
    cmp::Ordering,
    collections::HashMap,
    ops::Range,
    path::{Path, PathBuf},
    rc::Rc,
    sync::Arc,
    time::SystemTime,
};

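/// A row from the `files` table: the file's id, its worktree-relative path, and its last-seen modification time.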
#[derive(Debug)]
pub struct FileRecord {
    pub id: usize,
    pub relative_path: String,
    pub mtime: Timestamp,
}

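/// An embedding vector, stored in the `documents` table as a bincode-serialized blob.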
#[derive(Debug)]
struct Embedding(pub Vec<f32>);

impl FromSql for Embedding {
    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
        let bytes = value.as_blob()?;
        let embedding: Vec<f32> = bincode::deserialize(bytes)
            .map_err(|err| rusqlite::types::FromSqlError::Other(err))?;
        Ok(Embedding(embedding))
    }
}

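/// A handle to the SQLite database that backs the semantic index.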
pub struct VectorDatabase {
    db: rusqlite::Connection,
}

impl VectorDatabase {
    pub async fn new(fs: Arc<dyn Fs>, path: Arc<PathBuf>) -> Result<Self> {
        if let Some(db_directory) = path.parent() {
            fs.create_dir(db_directory).await?;
        }

        let this = Self {
            db: rusqlite::Connection::open(path.as_path())?,
        };
        this.initialize_database()?;
        Ok(this)
    }

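    /// Reads the schema version recorded in the `semantic_index_config` table.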
    fn get_existing_version(&self) -> Result<i64> {
        let mut version_query = self
            .db
            .prepare("SELECT version FROM semantic_index_config")?;
        version_query
            .query_row([], |row| row.get::<_, i64>(0))
            .map_err(|err| anyhow!("version query failed: {err}"))
    }

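    /// Loads the `rarray` module and ensures the schema exists, dropping and recreating every table whenever the stored version does not match `SEMANTIC_INDEX_VERSION`.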
    fn initialize_database(&self) -> Result<()> {
        rusqlite::vtab::array::load_module(&self.db)?;

        // Drop and recreate the existing tables if SEMANTIC_INDEX_VERSION has been bumped
        if self
            .get_existing_version()
            .map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64)
        {
            log::trace!("vector database schema up to date");
            return Ok(());
        }

        log::trace!("vector database schema out of date. updating...");
        self.db
            .execute("DROP TABLE IF EXISTS documents", [])
            .context("failed to drop 'documents' table")?;
        self.db
            .execute("DROP TABLE IF EXISTS files", [])
            .context("failed to drop 'files' table")?;
        self.db
            .execute("DROP TABLE IF EXISTS worktrees", [])
            .context("failed to drop 'worktrees' table")?;
        self.db
            .execute("DROP TABLE IF EXISTS semantic_index_config", [])
            .context("failed to drop 'semantic_index_config' table")?;

        // Initialize the vector database tables
        self.db.execute(
            "CREATE TABLE semantic_index_config (
                version INTEGER NOT NULL
            )",
            [],
        )?;

        self.db.execute(
            "INSERT INTO semantic_index_config (version) VALUES (?1)",
            params![SEMANTIC_INDEX_VERSION],
        )?;

        // This SQL string contains two statements, so run it with execute_batch
        self.db.execute_batch(
            "CREATE TABLE worktrees (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                absolute_path VARCHAR NOT NULL
            );
            CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
            ",
        )?;

        self.db.execute(
            "CREATE TABLE files (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                worktree_id INTEGER NOT NULL,
                relative_path VARCHAR NOT NULL,
                mtime_seconds INTEGER NOT NULL,
                mtime_nanos INTEGER NOT NULL,
                FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
            )",
            [],
        )?;

        self.db.execute(
            "CREATE TABLE documents (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                file_id INTEGER NOT NULL,
                start_byte INTEGER NOT NULL,
                end_byte INTEGER NOT NULL,
                name VARCHAR NOT NULL,
                embedding BLOB NOT NULL,
                FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
            )",
            [],
        )?;

        log::trace!("vector database initialized with updated schema.");
        Ok(())
    }

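    /// Deletes the `files` row matching the given worktree id and relative path.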
    pub fn delete_file(&self, worktree_id: i64, delete_path: PathBuf) -> Result<()> {
        self.db.execute(
            "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
            params![worktree_id, delete_path.to_str()],
        )?;
        Ok(())
    }

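    /// Records `path` and its mtime under the worktree, then inserts one `documents` row per document with its embedding serialized via bincode.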
    pub fn insert_file(
        &self,
        worktree_id: i64,
        path: PathBuf,
        mtime: SystemTime,
        documents: Vec<Document>,
    ) -> Result<()> {
        // Reuse the existing row if both the path and the mtime match
        let mtime = Timestamp::from(mtime);
        let mut existing_id_query = self.db.prepare("SELECT id FROM files WHERE worktree_id = ?1 AND relative_path = ?2 AND mtime_seconds = ?3 AND mtime_nanos = ?4")?;
        let existing_id = existing_id_query.query_row(
            params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
            |row| row.get::<_, i64>(0),
        );
        let file_id = match existing_id {
            // The file is already recorded with a matching mtime, so keep its id
            Ok(existing_id) => existing_id,
            // Otherwise replace any stale row for this path with a fresh one
            Err(_) => {
                self.db.execute(
                    "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;",
                    params![worktree_id, path.to_str()],
                )?;
                self.db.execute("INSERT INTO files (worktree_id, relative_path, mtime_seconds, mtime_nanos) VALUES (?1, ?2, ?3, ?4);", params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos])?;
                self.db.last_insert_rowid()
            }
        };


        // Currently inserting at approximately 3400 documents a second
        // I imagine we can speed this up with a bulk insert of some kind.
        for document in documents {
            let embedding_blob = bincode::serialize(&document.embedding)?;

            self.db.execute(
                "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding) VALUES (?1, ?2, ?3, ?4, ?5)",
                params![
                    file_id,
                    document.range.start.to_string(),
                    document.range.end.to_string(),
                    document.name,
                    embedding_blob
                ],
            )?;
        }

        Ok(())
    }

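    /// Returns whether a `worktrees` row already exists for the given root path.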
    pub fn worktree_previously_indexed(&self, worktree_root_path: &Path) -> Result<bool> {
        let mut worktree_query = self
            .db
            .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
        let worktree_id = worktree_query
            .query_row(params![worktree_root_path.to_string_lossy()], |row| {
                row.get::<_, i64>(0)
            });

        Ok(worktree_id.is_ok())
    }

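    /// Returns the id of the `worktrees` row for `worktree_root_path`, inserting a new row if none exists.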
    pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
        // Check whether a row already exists for this absolute path
        let mut worktree_query = self
            .db
            .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;

        let worktree_id = worktree_query
            .query_row(params![worktree_root_path.to_string_lossy()], |row| {
                row.get::<_, i64>(0)
            })
            .map_err(|err| anyhow!(err));

        if worktree_id.is_ok() {
            return worktree_id;
        }

        // If the worktree is not present yet, insert a new row
        self.db.execute(
            "
            INSERT INTO worktrees (absolute_path) VALUES (?1)
            ",
            params![worktree_root_path.to_string_lossy()],
        )?;
        Ok(self.db.last_insert_rowid())
    }

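    /// Returns the stored mtime of every indexed file in the worktree, keyed by relative path.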
    pub fn get_file_mtimes(&self, worktree_id: i64) -> Result<HashMap<PathBuf, SystemTime>> {
        let mut statement = self.db.prepare(
            "
            SELECT relative_path, mtime_seconds, mtime_nanos
            FROM files
            WHERE worktree_id = ?1
            ORDER BY relative_path",
        )?;
        let mut result: HashMap<PathBuf, SystemTime> = HashMap::new();
        for row in statement.query_map(params![worktree_id], |row| {
            Ok((
                row.get::<_, String>(0)?.into(),
                Timestamp {
                    seconds: row.get(1)?,
                    nanos: row.get(2)?,
                }
                .into(),
            ))
        })? {
            let row = row?;
            result.insert(row.0, row.1);
        }
        Ok(result)
    }

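    /// Scans the documents belonging to `file_ids` and returns up to `limit` `(document id, similarity)` pairs, ranked by dot product with `query_embedding`, most similar first.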
    pub fn top_k_search(
        &self,
        query_embedding: &Vec<f32>,
        limit: usize,
        file_ids: &[i64],
    ) -> Result<Vec<(i64, f32)>> {
        let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
        self.for_each_document(file_ids, |id, embedding| {
            let similarity = dot(&embedding, &query_embedding);
            let ix = match results
                .binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal))
            {
                Ok(ix) => ix,
                Err(ix) => ix,
            };
            results.insert(ix, (id, similarity));
            results.truncate(limit);
        })?;

        Ok(results)
    }

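    /// Returns the ids of files in the given worktrees whose relative paths match at least one `includes` glob (or all files when `includes` is empty) and none of the `excludes` globs.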
    pub fn retrieve_included_file_ids(
        &self,
        worktree_ids: &[i64],
        includes: &[PathMatcher],
        excludes: &[PathMatcher],
    ) -> Result<Vec<i64>> {
        let mut file_query = self.db.prepare(
            "
            SELECT
                id, relative_path
            FROM
                files
            WHERE
                worktree_id IN rarray(?)
            ",
        )?;

        let mut file_ids = Vec::<i64>::new();
        let mut rows = file_query.query([ids_to_sql(worktree_ids)])?;

        while let Some(row) = rows.next()? {
            let file_id = row.get(0)?;
            let relative_path = row.get_ref(1)?.as_str()?;
            let included =
                includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path));
            let excluded = excludes.iter().any(|glob| glob.is_match(relative_path));
            if included && !excluded {
                file_ids.push(file_id);
            }
        }

        Ok(file_ids)
    }

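    /// Loads the id and embedding of every document belonging to `file_ids`, calling `f` on each and skipping rows that fail to load.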
    fn for_each_document(&self, file_ids: &[i64], mut f: impl FnMut(i64, Vec<f32>)) -> Result<()> {
        let mut query_statement = self.db.prepare(
            "
            SELECT
                id, embedding
            FROM
                documents
            WHERE
                file_id IN rarray(?)
            ",
        )?;

        query_statement
            .query_map(params![ids_to_sql(file_ids)], |row| {
                Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
            })?
            .filter_map(|row| row.ok())
            .for_each(|(id, embedding)| f(id, embedding.0));
        Ok(())
    }

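    /// Resolves each document id to its worktree id, relative path, and byte range, returning results in the same order as `ids` and failing if any id is missing.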
    pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
        let mut statement = self.db.prepare(
            "
                SELECT
                    documents.id,
                    files.worktree_id,
                    files.relative_path,
                    documents.start_byte,
                    documents.end_byte
                FROM
                    documents, files
                WHERE
                    documents.file_id = files.id AND
                    documents.id IN rarray(?)
            ",
        )?;

        let result_iter = statement.query_map(params![ids_to_sql(ids)], |row| {
            Ok((
                row.get::<_, i64>(0)?,
                row.get::<_, i64>(1)?,
                row.get::<_, String>(2)?.into(),
                row.get(3)?..row.get(4)?,
            ))
        })?;

        let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
        for row in result_iter {
            let (id, worktree_id, path, range) = row?;
            values_by_id.insert(id, (worktree_id, path, range));
        }

        let mut results = Vec::with_capacity(ids.len());
        for id in ids {
            let value = values_by_id
                .remove(id)
                .ok_or(anyhow!("missing document id {}", id))?;
            results.push(value);
        }

        Ok(results)
    }
}

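/// Converts a list of ids into the shared `Value` vector that rusqlite's `rarray()` table-valued function expects.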
fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
    Rc::new(
        ids.iter()
            .copied()
            .map(rusqlite::types::Value::from)
            .collect::<Vec<_>>(),
    )
}

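/// Computes the dot product of two equal-length `f32` slices by driving `matrixmultiply::sgemm` as a 1×len by len×1 multiplication.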
pub(crate) fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    let len = vec_a.len();
    assert_eq!(len, vec_b.len());

    let mut result = 0.0;
    unsafe {
        matrixmultiply::sgemm(
            1,
            len,
            1,
            1.0,
            vec_a.as_ptr(),
            len as isize,
            1,
            vec_b.as_ptr(),
            1,
            len as isize,
            0.0,
            &mut result as *mut f32,
            1,
            1,
        );
    }
    result
}
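
// A minimal sanity check for `dot`, included as a sketch rather than an
// exhaustive test: it assumes only the function above and verifies the
// sgemm-based dot product against a hand-computed value.
#[cfg(test)]
mod tests {
    use super::dot;

    #[test]
    fn dot_matches_hand_computed_value() {
        // 1 * 4 + 2 * 5 + 3 * 6 = 32
        let a = [1.0_f32, 2.0, 3.0];
        let b = [4.0_f32, 5.0, 6.0];
        assert!((dot(&a, &b) - 32.0).abs() < 1e-6);
    }
}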