db.rs

  1use crate::{parsing::Document, SEMANTIC_INDEX_VERSION};
  2use anyhow::{anyhow, Context, Result};
  3use project::Fs;
  4use rpc::proto::Timestamp;
  5use rusqlite::{
  6    params,
  7    types::{FromSql, FromSqlResult, ValueRef},
  8};
  9use std::{
 10    cmp::Ordering,
 11    collections::HashMap,
 12    ops::Range,
 13    path::{Path, PathBuf},
 14    rc::Rc,
 15    sync::Arc,
 16    time::SystemTime,
 17};
 18
 19#[derive(Debug)]
 20pub struct FileRecord {
 21    pub id: usize,
 22    pub relative_path: String,
 23    pub mtime: Timestamp,
 24}
 25
 26#[derive(Debug)]
 27struct Embedding(pub Vec<f32>);
 28
 29impl FromSql for Embedding {
 30    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
 31        let bytes = value.as_blob()?;
 32        let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
 33        if embedding.is_err() {
 34            return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
 35        }
 36        return Ok(Embedding(embedding.unwrap()));
 37    }
 38}
 39
 40pub struct VectorDatabase {
 41    db: rusqlite::Connection,
 42}
 43
 44impl VectorDatabase {
 45    pub async fn new(fs: Arc<dyn Fs>, path: Arc<PathBuf>) -> Result<Self> {
 46        if let Some(db_directory) = path.parent() {
 47            fs.create_dir(db_directory).await?;
 48        }
 49
 50        let this = Self {
 51            db: rusqlite::Connection::open(path.as_path())?,
 52        };
 53        this.initialize_database()?;
 54        Ok(this)
 55    }
 56
 57    fn get_existing_version(&self) -> Result<i64> {
 58        let mut version_query = self
 59            .db
 60            .prepare("SELECT version from semantic_index_config")?;
 61        version_query
 62            .query_row([], |row| Ok(row.get::<_, i64>(0)?))
 63            .map_err(|err| anyhow!("version query failed: {err}"))
 64    }
 65
 66    fn initialize_database(&self) -> Result<()> {
 67        rusqlite::vtab::array::load_module(&self.db)?;
 68
 69        // Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped
 70        if self
 71            .get_existing_version()
 72            .map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64)
 73        {
 74            log::trace!("vector database schema up to date");
 75            return Ok(());
 76        }
 77
 78        log::trace!("vector database schema out of date. updating...");
 79        self.db
 80            .execute("DROP TABLE IF EXISTS documents", [])
 81            .context("failed to drop 'documents' table")?;
 82        self.db
 83            .execute("DROP TABLE IF EXISTS files", [])
 84            .context("failed to drop 'files' table")?;
 85        self.db
 86            .execute("DROP TABLE IF EXISTS worktrees", [])
 87            .context("failed to drop 'worktrees' table")?;
 88        self.db
 89            .execute("DROP TABLE IF EXISTS semantic_index_config", [])
 90            .context("failed to drop 'semantic_index_config' table")?;
 91
 92        // Initialize Vector Databasing Tables
 93        self.db.execute(
 94            "CREATE TABLE semantic_index_config (
 95                version INTEGER NOT NULL
 96            )",
 97            [],
 98        )?;
 99
100        self.db.execute(
101            "INSERT INTO semantic_index_config (version) VALUES (?1)",
102            params![SEMANTIC_INDEX_VERSION],
103        )?;
104
105        self.db.execute(
106            "CREATE TABLE worktrees (
107                id INTEGER PRIMARY KEY AUTOINCREMENT,
108                absolute_path VARCHAR NOT NULL
109            );
110            CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
111            ",
112            [],
113        )?;
114
115        self.db.execute(
116            "CREATE TABLE files (
117                id INTEGER PRIMARY KEY AUTOINCREMENT,
118                worktree_id INTEGER NOT NULL,
119                relative_path VARCHAR NOT NULL,
120                mtime_seconds INTEGER NOT NULL,
121                mtime_nanos INTEGER NOT NULL,
122                FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
123            )",
124            [],
125        )?;
126
127        self.db.execute(
128            "CREATE TABLE documents (
129                id INTEGER PRIMARY KEY AUTOINCREMENT,
130                file_id INTEGER NOT NULL,
131                start_byte INTEGER NOT NULL,
132                end_byte INTEGER NOT NULL,
133                name VARCHAR NOT NULL,
134                embedding BLOB NOT NULL,
135                FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
136            )",
137            [],
138        )?;
139
140        log::trace!("vector database initialized with updated schema.");
141        Ok(())
142    }
143
144    pub fn delete_file(&self, worktree_id: i64, delete_path: PathBuf) -> Result<()> {
145        self.db.execute(
146            "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
147            params![worktree_id, delete_path.to_str()],
148        )?;
149        Ok(())
150    }
151
152    pub fn insert_file(
153        &self,
154        worktree_id: i64,
155        path: PathBuf,
156        mtime: SystemTime,
157        documents: Vec<Document>,
158    ) -> Result<()> {
159        // Write to files table, and return generated id.
160        self.db.execute(
161            "
162            DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;
163            ",
164            params![worktree_id, path.to_str()],
165        )?;
166        let mtime = Timestamp::from(mtime);
167        self.db.execute(
168            "
169            INSERT INTO files
170            (worktree_id, relative_path, mtime_seconds, mtime_nanos)
171            VALUES
172            (?1, ?2, $3, $4);
173            ",
174            params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
175        )?;
176
177        let file_id = self.db.last_insert_rowid();
178
179        // Currently inserting at approximately 3400 documents a second
180        // I imagine we can speed this up with a bulk insert of some kind.
181        for document in documents {
182            let embedding_blob = bincode::serialize(&document.embedding)?;
183
184            self.db.execute(
185                "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding) VALUES (?1, ?2, ?3, ?4, $5)",
186                params![
187                    file_id,
188                    document.range.start.to_string(),
189                    document.range.end.to_string(),
190                    document.name,
191                    embedding_blob
192                ],
193            )?;
194        }
195
196        Ok(())
197    }
198
199    pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
200        // Check that the absolute path doesnt exist
201        let mut worktree_query = self
202            .db
203            .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
204
205        let worktree_id = worktree_query
206            .query_row(params![worktree_root_path.to_string_lossy()], |row| {
207                Ok(row.get::<_, i64>(0)?)
208            })
209            .map_err(|err| anyhow!(err));
210
211        if worktree_id.is_ok() {
212            return worktree_id;
213        }
214
215        // If worktree_id is Err, insert new worktree
216        self.db.execute(
217            "
218            INSERT into worktrees (absolute_path) VALUES (?1)
219            ",
220            params![worktree_root_path.to_string_lossy()],
221        )?;
222        Ok(self.db.last_insert_rowid())
223    }
224
225    pub fn get_file_mtimes(&self, worktree_id: i64) -> Result<HashMap<PathBuf, SystemTime>> {
226        let mut statement = self.db.prepare(
227            "
228            SELECT relative_path, mtime_seconds, mtime_nanos
229            FROM files
230            WHERE worktree_id = ?1
231            ORDER BY relative_path",
232        )?;
233        let mut result: HashMap<PathBuf, SystemTime> = HashMap::new();
234        for row in statement.query_map(params![worktree_id], |row| {
235            Ok((
236                row.get::<_, String>(0)?.into(),
237                Timestamp {
238                    seconds: row.get(1)?,
239                    nanos: row.get(2)?,
240                }
241                .into(),
242            ))
243        })? {
244            let row = row?;
245            result.insert(row.0, row.1);
246        }
247        Ok(result)
248    }
249
250    pub fn top_k_search(
251        &self,
252        worktree_ids: &[i64],
253        query_embedding: &Vec<f32>,
254        limit: usize,
255    ) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
256        let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
257        self.for_each_document(&worktree_ids, |id, embedding| {
258            let similarity = dot(&embedding, &query_embedding);
259            let ix = match results
260                .binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal))
261            {
262                Ok(ix) => ix,
263                Err(ix) => ix,
264            };
265            results.insert(ix, (id, similarity));
266            results.truncate(limit);
267        })?;
268
269        let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
270        self.get_documents_by_ids(&ids)
271    }
272
273    fn for_each_document(
274        &self,
275        worktree_ids: &[i64],
276        mut f: impl FnMut(i64, Vec<f32>),
277    ) -> Result<()> {
278        let mut query_statement = self.db.prepare(
279            "
280            SELECT
281                documents.id, documents.embedding
282            FROM
283                documents, files
284            WHERE
285                documents.file_id = files.id AND
286                files.worktree_id IN rarray(?)
287            ",
288        )?;
289
290        query_statement
291            .query_map(params![ids_to_sql(worktree_ids)], |row| {
292                Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
293            })?
294            .filter_map(|row| row.ok())
295            .for_each(|(id, embedding)| f(id, embedding.0));
296        Ok(())
297    }
298
299    fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
300        let mut statement = self.db.prepare(
301            "
302                SELECT
303                    documents.id,
304                    files.worktree_id,
305                    files.relative_path,
306                    documents.start_byte,
307                    documents.end_byte
308                FROM
309                    documents, files
310                WHERE
311                    documents.file_id = files.id AND
312                    documents.id in rarray(?)
313            ",
314        )?;
315
316        let result_iter = statement.query_map(params![ids_to_sql(ids)], |row| {
317            Ok((
318                row.get::<_, i64>(0)?,
319                row.get::<_, i64>(1)?,
320                row.get::<_, String>(2)?.into(),
321                row.get(3)?..row.get(4)?,
322            ))
323        })?;
324
325        let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
326        for row in result_iter {
327            let (id, worktree_id, path, range) = row?;
328            values_by_id.insert(id, (worktree_id, path, range));
329        }
330
331        let mut results = Vec::with_capacity(ids.len());
332        for id in ids {
333            let value = values_by_id
334                .remove(id)
335                .ok_or(anyhow!("missing document id {}", id))?;
336            results.push(value);
337        }
338
339        Ok(results)
340    }
341}
342
343fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
344    Rc::new(
345        ids.iter()
346            .copied()
347            .map(|v| rusqlite::types::Value::from(v))
348            .collect::<Vec<_>>(),
349    )
350}
351
352pub(crate) fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
353    let len = vec_a.len();
354    assert_eq!(len, vec_b.len());
355
356    let mut result = 0.0;
357    unsafe {
358        matrixmultiply::sgemm(
359            1,
360            len,
361            1,
362            1.0,
363            vec_a.as_ptr(),
364            len as isize,
365            1,
366            vec_b.as_ptr(),
367            1,
368            len as isize,
369            0.0,
370            &mut result as *mut f32,
371            1,
372            1,
373        );
374    }
375    result
376}