db.rs

use crate::{parsing::Document, VECTOR_STORE_VERSION};
use anyhow::{anyhow, Result};
use project::Fs;
use rpc::proto::Timestamp;
use rusqlite::{
    params,
    types::{FromSql, FromSqlResult, ValueRef},
};
use std::{
    cmp::Ordering,
    collections::HashMap,
    ops::Range,
    path::{Path, PathBuf},
    rc::Rc,
    sync::Arc,
    time::SystemTime,
};

#[derive(Debug)]
pub struct FileRecord {
    pub id: usize,
    pub relative_path: String,
    pub mtime: Timestamp,
}

#[derive(Debug)]
struct Embedding(pub Vec<f32>);

impl FromSql for Embedding {
    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
        let bytes = value.as_blob()?;
        let embedding: Vec<f32> = bincode::deserialize(bytes)
            .map_err(|err| rusqlite::types::FromSqlError::Other(err))?;
        Ok(Embedding(embedding))
    }
}

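/// SQLite-backed storage for worktrees, their files, and the embedded documents
/// extracted from those files.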
pub struct VectorDatabase {
    db: rusqlite::Connection,
}

impl VectorDatabase {
    pub async fn new(fs: Arc<dyn Fs>, path: Arc<PathBuf>) -> Result<Self> {
        if let Some(db_directory) = path.parent() {
            fs.create_dir(db_directory).await?;
        }

        let this = Self {
            db: rusqlite::Connection::open(path.as_path())?,
        };
        this.initialize_database()?;
        Ok(this)
    }

    fn get_existing_version(&self) -> Result<i64> {
        let mut version_query = self.db.prepare("SELECT version FROM vector_store_config")?;
        version_query
            .query_row([], |row| row.get::<_, i64>(0))
            .map_err(|err| anyhow!("version query failed: {err}"))
    }

    fn initialize_database(&self) -> Result<()> {
        rusqlite::vtab::array::load_module(&self.db)?;

        if self
            .get_existing_version()
            .map_or(false, |version| version == VECTOR_STORE_VERSION as i64)
        {
            return Ok(());
        }

        self.db
            .execute_batch(
                "
                    DROP TABLE IF EXISTS vector_store_config;
                    DROP TABLE IF EXISTS worktrees;
                    DROP TABLE IF EXISTS files;
                    DROP TABLE IF EXISTS documents;
                ",
            )
            .ok();

        // Initialize the vector store tables.
        self.db.execute(
            "CREATE TABLE vector_store_config (
                version INTEGER NOT NULL
            )",
            [],
        )?;

        self.db.execute(
            "INSERT INTO vector_store_config (version) VALUES (?1)",
            params![VECTOR_STORE_VERSION],
        )?;

        self.db.execute_batch(
            "CREATE TABLE worktrees (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                absolute_path VARCHAR NOT NULL
            );
            CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
            ",
        )?;

        self.db.execute(
            "CREATE TABLE files (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                worktree_id INTEGER NOT NULL,
                relative_path VARCHAR NOT NULL,
                mtime_seconds INTEGER NOT NULL,
                mtime_nanos INTEGER NOT NULL,
                FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
            )",
            [],
        )?;

        self.db.execute(
            "CREATE TABLE documents (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                file_id INTEGER NOT NULL,
                start_byte INTEGER NOT NULL,
                end_byte INTEGER NOT NULL,
                name VARCHAR NOT NULL,
                embedding BLOB NOT NULL,
                FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
            )",
            [],
        )?;

        Ok(())
    }

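    /// Delete the file row for `delete_path` in the given worktree. Note that the
    /// `ON DELETE CASCADE` on `documents` only removes the associated documents when
    /// SQLite foreign key enforcement is enabled (`PRAGMA foreign_keys = ON`), which
    /// is off by default.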
    pub fn delete_file(&self, worktree_id: i64, delete_path: PathBuf) -> Result<()> {
        self.db.execute(
            "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
            params![worktree_id, delete_path.to_str()],
        )?;
        Ok(())
    }

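    /// Replace any previously indexed copy of `path` in `worktree_id`, then store the
    /// file's mtime and its documents along with their embeddings.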
    pub fn insert_file(
        &self,
        worktree_id: i64,
        path: PathBuf,
        mtime: SystemTime,
        documents: Vec<Document>,
    ) -> Result<()> {
        // Write to the files table and grab the generated row id.
        self.db.execute(
            "
            DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;
            ",
            params![worktree_id, path.to_str()],
        )?;
        let mtime = Timestamp::from(mtime);
        self.db.execute(
            "
            INSERT INTO files
            (worktree_id, relative_path, mtime_seconds, mtime_nanos)
            VALUES
            (?1, ?2, ?3, ?4);
            ",
            params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
        )?;

        let file_id = self.db.last_insert_rowid();

        // Currently inserting at approximately 3,400 documents per second.
        // This could likely be sped up with a bulk insert of some kind (see the sketch below).
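        // A possible speedup (untested sketch, not wired in): reuse a cached prepared
        // statement inside a single transaction so SQLite commits once instead of per row.
        //
        //     let tx = self.db.unchecked_transaction()?;
        //     {
        //         let mut insert = tx.prepare_cached(
        //             "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding)
        //              VALUES (?1, ?2, ?3, ?4, ?5)",
        //         )?;
        //         for document in &documents {
        //             let embedding_blob = bincode::serialize(&document.embedding)?;
        //             insert.execute(params![
        //                 file_id,
        //                 document.range.start.to_string(),
        //                 document.range.end.to_string(),
        //                 document.name,
        //                 embedding_blob
        //             ])?;
        //         }
        //     }
        //     tx.commit()?;
        //
        // `unchecked_transaction` and `prepare_cached` are existing rusqlite APIs, but
        // this exact structure is only an illustration.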
        for document in documents {
            let embedding_blob = bincode::serialize(&document.embedding)?;

            self.db.execute(
                "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding) VALUES (?1, ?2, ?3, ?4, ?5)",
                params![
                    file_id,
                    document.range.start.to_string(),
                    document.range.end.to_string(),
                    document.name,
                    embedding_blob
                ],
            )?;
        }

        Ok(())
    }

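    /// Return the id of the worktree rooted at `worktree_root_path`, inserting a new
    /// row if one doesn't exist yet.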
    pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
        // Check whether a worktree with this absolute path already exists.
        let mut worktree_query = self
            .db
            .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;

        let worktree_id = worktree_query
            .query_row(params![worktree_root_path.to_string_lossy()], |row| {
                row.get::<_, i64>(0)
            })
            .map_err(|err| anyhow!(err));

        if worktree_id.is_ok() {
            return worktree_id;
        }

        // Otherwise, insert a new worktree and return its generated id.
        self.db.execute(
            "
            INSERT INTO worktrees (absolute_path) VALUES (?1)
            ",
            params![worktree_root_path.to_string_lossy()],
        )?;
        Ok(self.db.last_insert_rowid())
    }

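    /// Return the stored modification time of every file in the given worktree, keyed
    /// by its relative path.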
    pub fn get_file_mtimes(&self, worktree_id: i64) -> Result<HashMap<PathBuf, SystemTime>> {
        let mut statement = self.db.prepare(
            "
            SELECT relative_path, mtime_seconds, mtime_nanos
            FROM files
            WHERE worktree_id = ?1
            ORDER BY relative_path",
        )?;
        let mut result: HashMap<PathBuf, SystemTime> = HashMap::new();
        for row in statement.query_map(params![worktree_id], |row| {
            Ok((
                row.get::<_, String>(0)?.into(),
                Timestamp {
                    seconds: row.get(1)?,
                    nanos: row.get(2)?,
                }
                .into(),
            ))
        })? {
            let row = row?;
            result.insert(row.0, row.1);
        }
        Ok(result)
    }

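    /// Brute-force similarity search: score every stored document embedding in the
    /// given worktrees against `query_embedding` with a dot product, and return the
    /// `limit` best matches as (worktree id, relative path, byte range, name).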
    pub fn top_k_search(
        &self,
        worktree_ids: &[i64],
        query_embedding: &[f32],
        limit: usize,
    ) -> Result<Vec<(i64, PathBuf, Range<usize>, String)>> {
        let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
        self.for_each_document(worktree_ids, |id, embedding| {
            let similarity = dot(&embedding, query_embedding);
            let ix = match results
                .binary_search_by(|(_, s)| similarity.partial_cmp(s).unwrap_or(Ordering::Equal))
            {
                Ok(ix) => ix,
                Err(ix) => ix,
            };
            results.insert(ix, (id, similarity));
            results.truncate(limit);
        })?;

        let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
        self.get_documents_by_ids(&ids)
    }

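    /// Call `f` with the id and decoded embedding of every document belonging to the
    /// given worktrees.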
    fn for_each_document(
        &self,
        worktree_ids: &[i64],
        mut f: impl FnMut(i64, Vec<f32>),
    ) -> Result<()> {
        let mut query_statement = self.db.prepare(
            "
            SELECT
                documents.id, documents.embedding
            FROM
                documents, files
            WHERE
                documents.file_id = files.id AND
                files.worktree_id IN rarray(?)
            ",
        )?;

        query_statement
            .query_map(params![ids_to_sql(worktree_ids)], |row| {
                Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
            })?
            .filter_map(|row| row.ok())
            .for_each(|(id, embedding)| f(id, embedding.0));
        Ok(())
    }

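    /// Fetch (worktree id, relative path, byte range, name) for each document id,
    /// returned in the same order as `ids`.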
    fn get_documents_by_ids(
        &self,
        ids: &[i64],
    ) -> Result<Vec<(i64, PathBuf, Range<usize>, String)>> {
        let mut statement = self.db.prepare(
            "
                SELECT
                    documents.id,
                    files.worktree_id,
                    files.relative_path,
                    documents.start_byte,
                    documents.end_byte,
                    documents.name
                FROM
                    documents, files
                WHERE
                    documents.file_id = files.id AND
                    documents.id IN rarray(?)
            ",
        )?;

        let result_iter = statement.query_map(params![ids_to_sql(ids)], |row| {
            Ok((
                row.get::<_, i64>(0)?,
                row.get::<_, i64>(1)?,
                row.get::<_, String>(2)?.into(),
                row.get(3)?..row.get(4)?,
                row.get(5)?,
            ))
        })?;

        let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>, String)>::default();
        for row in result_iter {
            let (id, worktree_id, path, range, name) = row?;
            values_by_id.insert(id, (worktree_id, path, range, name));
        }

        let mut results = Vec::with_capacity(ids.len());
        for id in ids {
            let value = values_by_id
                .remove(id)
                .ok_or_else(|| anyhow!("missing document id {}", id))?;
            results.push(value);
        }

        Ok(results)
    }
}

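/// Convert a slice of ids into the `Rc<Vec<Value>>` shape that rusqlite's `rarray`
/// table-valued function expects as a bound parameter.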
fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
    Rc::new(
        ids.iter()
            .copied()
            .map(rusqlite::types::Value::from)
            .collect::<Vec<_>>(),
    )
}

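/// Dot product of two equal-length vectors, computed as a 1 x len by len x 1 matrix
/// product via `matrixmultiply::sgemm`.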
pub(crate) fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    let len = vec_a.len();
    assert_eq!(len, vec_b.len());

    let mut result = 0.0;
    unsafe {
        matrixmultiply::sgemm(
            1,
            len,
            1,
            1.0,
            vec_a.as_ptr(),
            len as isize,
            1,
            vec_b.as_ptr(),
            1,
            len as isize,
            0.0,
            &mut result as *mut f32,
            1,
            1,
        );
    }
    result
}
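
// A minimal sanity check for `dot`; the inputs and expected value below are chosen
// for this sketch and aren't taken from the surrounding codebase.
#[cfg(test)]
mod tests {
    use super::dot;

    #[test]
    fn dot_matches_naive_sum_of_products() {
        let a = vec![1.0_f32, 2.0, 3.0, 4.0];
        let b = vec![4.0_f32, 3.0, 2.0, 1.0];

        // Naive reference implementation: sum of elementwise products.
        let expected: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();

        assert!((dot(&a, &b) - expected).abs() < 1e-6);
    }
}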