db.rs

use crate::{parsing::Document, SEMANTIC_INDEX_VERSION};
use anyhow::{anyhow, Context, Result};
use project::{search::PathMatcher, Fs};
use rpc::proto::Timestamp;
use rusqlite::{
    params,
    types::{FromSql, FromSqlResult, ValueRef},
    OptionalExtension,
};
use std::{
    cmp::Ordering,
    collections::HashMap,
    ops::Range,
    path::{Path, PathBuf},
    rc::Rc,
    sync::Arc,
    time::SystemTime,
};
 18
 19#[derive(Debug)]
 20pub struct FileRecord {
 21    pub id: usize,
 22    pub relative_path: String,
 23    pub mtime: Timestamp,
 24}
 25
 26#[derive(Debug)]
 27struct Embedding(pub Vec<f32>);
 28
 29#[derive(Debug)]
 30struct Sha1(pub Vec<u8>);
 31
 32impl FromSql for Embedding {
 33    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
 34        let bytes = value.as_blob()?;
 35        let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
 36        if embedding.is_err() {
 37            return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
 38        }
 39        return Ok(Embedding(embedding.unwrap()));
 40    }
 41}
 42
 43impl FromSql for Sha1 {
 44    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
 45        let bytes = value.as_blob()?;
 46        let sha1: Result<Vec<u8>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
 47        if sha1.is_err() {
 48            return Err(rusqlite::types::FromSqlError::Other(sha1.unwrap_err()));
 49        }
 50        return Ok(Sha1(sha1.unwrap()));
 51    }
 52}
 53
 54pub struct VectorDatabase {
 55    db: rusqlite::Connection,
 56}
 57
 58impl VectorDatabase {
 59    pub async fn new(fs: Arc<dyn Fs>, path: Arc<PathBuf>) -> Result<Self> {
 60        if let Some(db_directory) = path.parent() {
 61            fs.create_dir(db_directory).await?;
 62        }
 63
 64        let this = Self {
 65            db: rusqlite::Connection::open(path.as_path())?,
 66        };
 67        this.initialize_database()?;
 68        Ok(this)
 69    }
 70
 71    fn get_existing_version(&self) -> Result<i64> {
 72        let mut version_query = self
 73            .db
 74            .prepare("SELECT version from semantic_index_config")?;
 75        version_query
 76            .query_row([], |row| Ok(row.get::<_, i64>(0)?))
 77            .map_err(|err| anyhow!("version query failed: {err}"))
 78    }
 79
 80    fn initialize_database(&self) -> Result<()> {
 81        rusqlite::vtab::array::load_module(&self.db)?;
 82
 83        // Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped
 84        if self
 85            .get_existing_version()
 86            .map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64)
 87        {
 88            log::trace!("vector database schema up to date");
 89            return Ok(());
 90        }
 91
 92        log::trace!("vector database schema out of date. updating...");
 93        self.db
 94            .execute("DROP TABLE IF EXISTS documents", [])
 95            .context("failed to drop 'documents' table")?;
 96        self.db
 97            .execute("DROP TABLE IF EXISTS files", [])
 98            .context("failed to drop 'files' table")?;
 99        self.db
100            .execute("DROP TABLE IF EXISTS worktrees", [])
101            .context("failed to drop 'worktrees' table")?;
102        self.db
103            .execute("DROP TABLE IF EXISTS semantic_index_config", [])
104            .context("failed to drop 'semantic_index_config' table")?;
105
106        // Initialize Vector Databasing Tables
107        self.db.execute(
108            "CREATE TABLE semantic_index_config (
109                version INTEGER NOT NULL
110            )",
111            [],
112        )?;
113
114        self.db.execute(
115            "INSERT INTO semantic_index_config (version) VALUES (?1)",
116            params![SEMANTIC_INDEX_VERSION],
117        )?;
118
119        self.db.execute(
120            "CREATE TABLE worktrees (
121                id INTEGER PRIMARY KEY AUTOINCREMENT,
122                absolute_path VARCHAR NOT NULL
123            );
124            CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
125            ",
126            [],
127        )?;
128
129        self.db.execute(
130            "CREATE TABLE files (
131                id INTEGER PRIMARY KEY AUTOINCREMENT,
132                worktree_id INTEGER NOT NULL,
133                relative_path VARCHAR NOT NULL,
134                mtime_seconds INTEGER NOT NULL,
135                mtime_nanos INTEGER NOT NULL,
136                FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
137            )",
138            [],
139        )?;
140
141        self.db.execute(
142            "CREATE TABLE documents (
143                id INTEGER PRIMARY KEY AUTOINCREMENT,
144                file_id INTEGER NOT NULL,
145                start_byte INTEGER NOT NULL,
146                end_byte INTEGER NOT NULL,
147                name VARCHAR NOT NULL,
148                embedding BLOB NOT NULL,
149                sha1 BLOB NOT NULL,
150                FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
151            )",
152            [],
153        )?;
154
155        log::trace!("vector database initialized with updated schema.");
156        Ok(())
157    }
158
159    pub fn delete_file(&self, worktree_id: i64, delete_path: PathBuf) -> Result<()> {
160        self.db.execute(
161            "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
162            params![worktree_id, delete_path.to_str()],
163        )?;
164        Ok(())
165    }
166
167    pub fn insert_file(
168        &self,
169        worktree_id: i64,
170        path: PathBuf,
171        mtime: SystemTime,
172        documents: Vec<Document>,
173    ) -> Result<()> {
174        // Return the existing ID, if both the file and mtime match
175        let mtime = Timestamp::from(mtime);
176        let mut existing_id_query = self.db.prepare("SELECT id FROM files WHERE worktree_id = ?1 AND relative_path = ?2 AND mtime_seconds = ?3 AND mtime_nanos = ?4")?;
177        let existing_id = existing_id_query
178            .query_row(
179                params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
180                |row| Ok(row.get::<_, i64>(0)?),
181            )
182            .map_err(|err| anyhow!(err));
183        let file_id = if existing_id.is_ok() {
184            // If already exists, just return the existing id
185            existing_id.unwrap()
186        } else {
187            // Delete Existing Row
188            self.db.execute(
189                "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;",
190                params![worktree_id, path.to_str()],
191            )?;
192            self.db.execute("INSERT INTO files (worktree_id, relative_path, mtime_seconds, mtime_nanos) VALUES (?1, ?2, ?3, ?4);", params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos])?;
193            self.db.last_insert_rowid()
194        };
195
196        // Currently inserting at approximately 3400 documents a second
197        // I imagine we can speed this up with a bulk insert of some kind.
198        for document in documents {
199            let embedding_blob = bincode::serialize(&document.embedding)?;
200            let sha_blob = bincode::serialize(&document.sha1)?;
201
202            self.db.execute(
203                "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding, sha1) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
204                params![
205                    file_id,
206                    document.range.start.to_string(),
207                    document.range.end.to_string(),
208                    document.name,
209                    embedding_blob,
210                    sha_blob
211                ],
212            )?;
213        }
214
215        Ok(())
216    }
217
218    pub fn worktree_previously_indexed(&self, worktree_root_path: &Path) -> Result<bool> {
219        let mut worktree_query = self
220            .db
221            .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
222        let worktree_id = worktree_query
223            .query_row(params![worktree_root_path.to_string_lossy()], |row| {
224                Ok(row.get::<_, i64>(0)?)
225            })
226            .map_err(|err| anyhow!(err));
227
228        if worktree_id.is_ok() {
229            return Ok(true);
230        } else {
231            return Ok(false);
232        }
233    }
234
235    pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
236        // Check that the absolute path doesnt exist
237        let mut worktree_query = self
238            .db
239            .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
240
241        let worktree_id = worktree_query
242            .query_row(params![worktree_root_path.to_string_lossy()], |row| {
243                Ok(row.get::<_, i64>(0)?)
244            })
245            .map_err(|err| anyhow!(err));
246
247        if worktree_id.is_ok() {
248            return worktree_id;
249        }
250
251        // If worktree_id is Err, insert new worktree
252        self.db.execute(
253            "
254            INSERT into worktrees (absolute_path) VALUES (?1)
255            ",
256            params![worktree_root_path.to_string_lossy()],
257        )?;
258        Ok(self.db.last_insert_rowid())
259    }
260
261    pub fn get_file_mtimes(&self, worktree_id: i64) -> Result<HashMap<PathBuf, SystemTime>> {
262        let mut statement = self.db.prepare(
263            "
264            SELECT relative_path, mtime_seconds, mtime_nanos
265            FROM files
266            WHERE worktree_id = ?1
267            ORDER BY relative_path",
268        )?;
269        let mut result: HashMap<PathBuf, SystemTime> = HashMap::new();
270        for row in statement.query_map(params![worktree_id], |row| {
271            Ok((
272                row.get::<_, String>(0)?.into(),
273                Timestamp {
274                    seconds: row.get(1)?,
275                    nanos: row.get(2)?,
276                }
277                .into(),
278            ))
279        })? {
280            let row = row?;
281            result.insert(row.0, row.1);
282        }
283        Ok(result)
284    }
285
286    pub fn top_k_search(
287        &self,
288        query_embedding: &Vec<f32>,
289        limit: usize,
290        file_ids: &[i64],
291    ) -> Result<Vec<(i64, f32)>> {
292        let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
293        self.for_each_document(file_ids, |id, embedding| {
294            let similarity = dot(&embedding, &query_embedding);
295            let ix = match results
296                .binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal))
297            {
298                Ok(ix) => ix,
299                Err(ix) => ix,
300            };
301            results.insert(ix, (id, similarity));
302            results.truncate(limit);
303        })?;
304
305        Ok(results)
306    }
307
308    pub fn retrieve_included_file_ids(
309        &self,
310        worktree_ids: &[i64],
311        includes: &[PathMatcher],
312        excludes: &[PathMatcher],
313    ) -> Result<Vec<i64>> {
314        let mut file_query = self.db.prepare(
315            "
316            SELECT
317                id, relative_path
318            FROM
319                files
320            WHERE
321                worktree_id IN rarray(?)
322            ",
323        )?;
324
325        let mut file_ids = Vec::<i64>::new();
326        let mut rows = file_query.query([ids_to_sql(worktree_ids)])?;
327
328        while let Some(row) = rows.next()? {
329            let file_id = row.get(0)?;
330            let relative_path = row.get_ref(1)?.as_str()?;
331            let included =
332                includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path));
333            let excluded = excludes.iter().any(|glob| glob.is_match(relative_path));
334            if included && !excluded {
335                file_ids.push(file_id);
336            }
337        }
338
339        Ok(file_ids)
340    }
341
342    fn for_each_document(&self, file_ids: &[i64], mut f: impl FnMut(i64, Vec<f32>)) -> Result<()> {
343        let mut query_statement = self.db.prepare(
344            "
345            SELECT
346                id, embedding
347            FROM
348                documents
349            WHERE
350                file_id IN rarray(?)
351            ",
352        )?;
353
354        query_statement
355            .query_map(params![ids_to_sql(&file_ids)], |row| {
356                Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
357            })?
358            .filter_map(|row| row.ok())
359            .for_each(|(id, embedding)| f(id, embedding.0));
360        Ok(())
361    }
362
363    pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
364        let mut statement = self.db.prepare(
365            "
366                SELECT
367                    documents.id,
368                    files.worktree_id,
369                    files.relative_path,
370                    documents.start_byte,
371                    documents.end_byte
372                FROM
373                    documents, files
374                WHERE
375                    documents.file_id = files.id AND
376                    documents.id in rarray(?)
377            ",
378        )?;
379
380        let result_iter = statement.query_map(params![ids_to_sql(ids)], |row| {
381            Ok((
382                row.get::<_, i64>(0)?,
383                row.get::<_, i64>(1)?,
384                row.get::<_, String>(2)?.into(),
385                row.get(3)?..row.get(4)?,
386            ))
387        })?;
388
389        let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
390        for row in result_iter {
391            let (id, worktree_id, path, range) = row?;
392            values_by_id.insert(id, (worktree_id, path, range));
393        }
394
395        let mut results = Vec::with_capacity(ids.len());
396        for id in ids {
397            let value = values_by_id
398                .remove(id)
399                .ok_or(anyhow!("missing document id {}", id))?;
400            results.push(value);
401        }
402
403        Ok(results)
404    }
405}
406
407fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
408    Rc::new(
409        ids.iter()
410            .copied()
411            .map(|v| rusqlite::types::Value::from(v))
412            .collect::<Vec<_>>(),
413    )
414}
415
416pub(crate) fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
417    let len = vec_a.len();
418    assert_eq!(len, vec_b.len());
419
420    let mut result = 0.0;
421    unsafe {
422        matrixmultiply::sgemm(
423            1,
424            len,
425            1,
426            1.0,
427            vec_a.as_ptr(),
428            len as isize,
429            1,
430            vec_b.as_ptr(),
431            1,
432            len as isize,
433            0.0,
434            &mut result as *mut f32,
435            1,
436            1,
437        );
438    }
439    result
440}