db.rs

  1use crate::{parsing::Document, SEMANTIC_INDEX_VERSION};
  2use anyhow::{anyhow, Context, Result};
  3use globset::GlobMatcher;
  4use project::Fs;
  5use rpc::proto::Timestamp;
  6use rusqlite::{
  7    params,
  8    types::{FromSql, FromSqlResult, ValueRef},
  9};
 10use std::{
 11    cmp::Ordering,
 12    collections::HashMap,
 13    ops::Range,
 14    path::{Path, PathBuf},
 15    rc::Rc,
 16    sync::Arc,
 17    time::SystemTime,
 18};
 19
/// One row of the `files` table: a single indexed file within a worktree.
#[derive(Debug)]
pub struct FileRecord {
    // Database rowid of the file.
    pub id: usize,
    // Path relative to the owning worktree's root (see `worktree_id` column).
    pub relative_path: String,
    // Modification time recorded when the file was indexed
    // (stored as the `mtime_seconds` / `mtime_nanos` columns).
    pub mtime: Timestamp,
}
 26
/// Newtype over a raw embedding vector so it can implement [`FromSql`]
/// and be read straight out of the `documents.embedding` BLOB column.
#[derive(Debug)]
struct Embedding(pub Vec<f32>);
 29
 30impl FromSql for Embedding {
 31    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
 32        let bytes = value.as_blob()?;
 33        let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
 34        if embedding.is_err() {
 35            return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
 36        }
 37        return Ok(Embedding(embedding.unwrap()));
 38    }
 39}
 40
/// SQLite-backed store for the semantic index: worktrees, their files,
/// and the embedded documents parsed out of those files.
pub struct VectorDatabase {
    db: rusqlite::Connection,
}
 44
 45impl VectorDatabase {
 46    pub async fn new(fs: Arc<dyn Fs>, path: Arc<PathBuf>) -> Result<Self> {
 47        if let Some(db_directory) = path.parent() {
 48            fs.create_dir(db_directory).await?;
 49        }
 50
 51        let this = Self {
 52            db: rusqlite::Connection::open(path.as_path())?,
 53        };
 54        this.initialize_database()?;
 55        Ok(this)
 56    }
 57
 58    fn get_existing_version(&self) -> Result<i64> {
 59        let mut version_query = self
 60            .db
 61            .prepare("SELECT version from semantic_index_config")?;
 62        version_query
 63            .query_row([], |row| Ok(row.get::<_, i64>(0)?))
 64            .map_err(|err| anyhow!("version query failed: {err}"))
 65    }
 66
 67    fn initialize_database(&self) -> Result<()> {
 68        rusqlite::vtab::array::load_module(&self.db)?;
 69
 70        // Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped
 71        if self
 72            .get_existing_version()
 73            .map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64)
 74        {
 75            log::trace!("vector database schema up to date");
 76            return Ok(());
 77        }
 78
 79        log::trace!("vector database schema out of date. updating...");
 80        self.db
 81            .execute("DROP TABLE IF EXISTS documents", [])
 82            .context("failed to drop 'documents' table")?;
 83        self.db
 84            .execute("DROP TABLE IF EXISTS files", [])
 85            .context("failed to drop 'files' table")?;
 86        self.db
 87            .execute("DROP TABLE IF EXISTS worktrees", [])
 88            .context("failed to drop 'worktrees' table")?;
 89        self.db
 90            .execute("DROP TABLE IF EXISTS semantic_index_config", [])
 91            .context("failed to drop 'semantic_index_config' table")?;
 92
 93        // Initialize Vector Databasing Tables
 94        self.db.execute(
 95            "CREATE TABLE semantic_index_config (
 96                version INTEGER NOT NULL
 97            )",
 98            [],
 99        )?;
100
101        self.db.execute(
102            "INSERT INTO semantic_index_config (version) VALUES (?1)",
103            params![SEMANTIC_INDEX_VERSION],
104        )?;
105
106        self.db.execute(
107            "CREATE TABLE worktrees (
108                id INTEGER PRIMARY KEY AUTOINCREMENT,
109                absolute_path VARCHAR NOT NULL
110            );
111            CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
112            ",
113            [],
114        )?;
115
116        self.db.execute(
117            "CREATE TABLE files (
118                id INTEGER PRIMARY KEY AUTOINCREMENT,
119                worktree_id INTEGER NOT NULL,
120                relative_path VARCHAR NOT NULL,
121                mtime_seconds INTEGER NOT NULL,
122                mtime_nanos INTEGER NOT NULL,
123                FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
124            )",
125            [],
126        )?;
127
128        self.db.execute(
129            "CREATE TABLE documents (
130                id INTEGER PRIMARY KEY AUTOINCREMENT,
131                file_id INTEGER NOT NULL,
132                start_byte INTEGER NOT NULL,
133                end_byte INTEGER NOT NULL,
134                name VARCHAR NOT NULL,
135                embedding BLOB NOT NULL,
136                FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
137            )",
138            [],
139        )?;
140
141        log::trace!("vector database initialized with updated schema.");
142        Ok(())
143    }
144
145    pub fn delete_file(&self, worktree_id: i64, delete_path: PathBuf) -> Result<()> {
146        self.db.execute(
147            "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
148            params![worktree_id, delete_path.to_str()],
149        )?;
150        Ok(())
151    }
152
153    pub fn insert_file(
154        &self,
155        worktree_id: i64,
156        path: PathBuf,
157        mtime: SystemTime,
158        documents: Vec<Document>,
159    ) -> Result<()> {
160        // Write to files table, and return generated id.
161        self.db.execute(
162            "
163            DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;
164            ",
165            params![worktree_id, path.to_str()],
166        )?;
167        let mtime = Timestamp::from(mtime);
168        self.db.execute(
169            "
170            INSERT INTO files
171            (worktree_id, relative_path, mtime_seconds, mtime_nanos)
172            VALUES
173            (?1, ?2, $3, $4);
174            ",
175            params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
176        )?;
177
178        let file_id = self.db.last_insert_rowid();
179
180        // Currently inserting at approximately 3400 documents a second
181        // I imagine we can speed this up with a bulk insert of some kind.
182        for document in documents {
183            let embedding_blob = bincode::serialize(&document.embedding)?;
184
185            self.db.execute(
186                "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding) VALUES (?1, ?2, ?3, ?4, $5)",
187                params![
188                    file_id,
189                    document.range.start.to_string(),
190                    document.range.end.to_string(),
191                    document.name,
192                    embedding_blob
193                ],
194            )?;
195        }
196
197        Ok(())
198    }
199
200    pub fn worktree_previously_indexed(&self, worktree_root_path: &Path) -> Result<bool> {
201        let mut worktree_query = self
202            .db
203            .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
204        let worktree_id = worktree_query
205            .query_row(params![worktree_root_path.to_string_lossy()], |row| {
206                Ok(row.get::<_, i64>(0)?)
207            })
208            .map_err(|err| anyhow!(err));
209
210        if worktree_id.is_ok() {
211            return Ok(true);
212        } else {
213            return Ok(false);
214        }
215    }
216
217    pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
218        // Check that the absolute path doesnt exist
219        let mut worktree_query = self
220            .db
221            .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
222
223        let worktree_id = worktree_query
224            .query_row(params![worktree_root_path.to_string_lossy()], |row| {
225                Ok(row.get::<_, i64>(0)?)
226            })
227            .map_err(|err| anyhow!(err));
228
229        if worktree_id.is_ok() {
230            return worktree_id;
231        }
232
233        // If worktree_id is Err, insert new worktree
234        self.db.execute(
235            "
236            INSERT into worktrees (absolute_path) VALUES (?1)
237            ",
238            params![worktree_root_path.to_string_lossy()],
239        )?;
240        Ok(self.db.last_insert_rowid())
241    }
242
243    pub fn get_file_mtimes(&self, worktree_id: i64) -> Result<HashMap<PathBuf, SystemTime>> {
244        let mut statement = self.db.prepare(
245            "
246            SELECT relative_path, mtime_seconds, mtime_nanos
247            FROM files
248            WHERE worktree_id = ?1
249            ORDER BY relative_path",
250        )?;
251        let mut result: HashMap<PathBuf, SystemTime> = HashMap::new();
252        for row in statement.query_map(params![worktree_id], |row| {
253            Ok((
254                row.get::<_, String>(0)?.into(),
255                Timestamp {
256                    seconds: row.get(1)?,
257                    nanos: row.get(2)?,
258                }
259                .into(),
260            ))
261        })? {
262            let row = row?;
263            result.insert(row.0, row.1);
264        }
265        Ok(result)
266    }
267
    /// Returns the ids of the `limit` documents (restricted to `file_ids`)
    /// whose embeddings have the highest dot product with `query_embedding`,
    /// as `(document_id, similarity)` pairs in descending-similarity order.
    ///
    /// Assumes every stored embedding has the same length as the query
    /// (the `dot` helper asserts this and will panic otherwise).
    pub fn top_k_search(
        &self,
        query_embedding: &Vec<f32>,
        limit: usize,
        file_ids: &[i64],
    ) -> Result<Vec<(i64, f32)>> {
        // +1 so the insert below never reallocates before truncation.
        let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
        self.for_each_document(file_ids, |id, embedding| {
            let similarity = dot(&embedding, &query_embedding);
            // The comparator is deliberately reversed (new `similarity`
            // compared against the stored score `s`) so `results` stays
            // sorted in *descending* similarity order; NaN comparisons fall
            // back to Equal, which just picks an arbitrary insertion point.
            let ix = match results
                .binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal))
            {
                Ok(ix) => ix,
                Err(ix) => ix,
            };
            results.insert(ix, (id, similarity));
            // Keep only the current top `limit` entries.
            results.truncate(limit);
        })?;

        Ok(results)
    }
289
290    pub fn retrieve_included_file_ids(
291        &self,
292        worktree_ids: &[i64],
293        include_globs: Vec<GlobMatcher>,
294        exclude_globs: Vec<GlobMatcher>,
295    ) -> Result<Vec<i64>> {
296        let mut file_query = self.db.prepare(
297            "
298            SELECT
299                id, relative_path
300            FROM
301                files
302            WHERE
303                worktree_id IN rarray(?)
304            ",
305        )?;
306
307        let mut file_ids = Vec::<i64>::new();
308        let mut rows = file_query.query([ids_to_sql(worktree_ids)])?;
309
310        while let Some(row) = rows.next()? {
311            let file_id = row.get(0)?;
312            let relative_path = row.get_ref(1)?.as_str()?;
313            let included = include_globs.is_empty()
314                || include_globs
315                    .iter()
316                    .any(|glob| glob.is_match(relative_path));
317            let excluded = exclude_globs
318                .iter()
319                .any(|glob| glob.is_match(relative_path));
320            if included && !excluded {
321                file_ids.push(file_id);
322            }
323        }
324
325        Ok(file_ids)
326    }
327
    /// Invokes `f` with the id and deserialized embedding of every document
    /// belonging to one of `file_ids`.
    ///
    /// NOTE(review): rows whose id or embedding fail to decode are silently
    /// skipped by the `filter_map(.. row.ok())` below — presumably a
    /// deliberate best-effort choice for search; confirm before tightening.
    fn for_each_document(&self, file_ids: &[i64], mut f: impl FnMut(i64, Vec<f32>)) -> Result<()> {
        let mut query_statement = self.db.prepare(
            "
            SELECT
                id, embedding
            FROM
                documents
            WHERE
                file_id IN rarray(?)
            ",
        )?;

        query_statement
            .query_map(params![ids_to_sql(&file_ids)], |row| {
                Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
            })?
            .filter_map(|row| row.ok())
            .for_each(|(id, embedding)| f(id, embedding.0));
        Ok(())
    }
348
349    pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
350        let mut statement = self.db.prepare(
351            "
352                SELECT
353                    documents.id,
354                    files.worktree_id,
355                    files.relative_path,
356                    documents.start_byte,
357                    documents.end_byte
358                FROM
359                    documents, files
360                WHERE
361                    documents.file_id = files.id AND
362                    documents.id in rarray(?)
363            ",
364        )?;
365
366        let result_iter = statement.query_map(params![ids_to_sql(ids)], |row| {
367            Ok((
368                row.get::<_, i64>(0)?,
369                row.get::<_, i64>(1)?,
370                row.get::<_, String>(2)?.into(),
371                row.get(3)?..row.get(4)?,
372            ))
373        })?;
374
375        let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
376        for row in result_iter {
377            let (id, worktree_id, path, range) = row?;
378            values_by_id.insert(id, (worktree_id, path, range));
379        }
380
381        let mut results = Vec::with_capacity(ids.len());
382        for id in ids {
383            let value = values_by_id
384                .remove(id)
385                .ok_or(anyhow!("missing document id {}", id))?;
386            results.push(value);
387        }
388
389        Ok(results)
390    }
391}
392
393fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
394    Rc::new(
395        ids.iter()
396            .copied()
397            .map(|v| rusqlite::types::Value::from(v))
398            .collect::<Vec<_>>(),
399    )
400}
401
/// Computes the dot product of two equal-length `f32` slices.
///
/// Implemented as a 1×len by len×1 matrix multiply through
/// `matrixmultiply::sgemm` to benefit from its optimized kernels.
///
/// # Panics
/// Panics if the slices differ in length.
pub(crate) fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    let len = vec_a.len();
    assert_eq!(len, vec_b.len());

    let mut result = 0.0;
    // SAFETY: both input pointers are valid for `len` reads (lengths asserted
    // equal above), `result` is valid for a single f32 write, and the strides
    // describe a 1×len row vector times a len×1 column vector producing a
    // 1×1 output.
    unsafe {
        matrixmultiply::sgemm(
            1,
            len,
            1,
            1.0,
            vec_a.as_ptr(),
            len as isize,
            1,
            vec_b.as_ptr(),
            1,
            len as isize,
            0.0,
            &mut result as *mut f32,
            1,
            1,
        );
    }
    result
}