db.rs

use std::{
    cmp::Ordering,
    collections::HashMap,
    path::{Path, PathBuf},
    rc::Rc,
    time::SystemTime,
};

use anyhow::{anyhow, Result};
use rpc::proto::Timestamp;
use rusqlite::{
    params,
    types::{FromSql, FromSqlResult, ValueRef},
    OptionalExtension,
};

use crate::parsing::ParsedFile;
 17
 18#[derive(Debug)]
 19pub struct FileRecord {
 20    pub id: usize,
 21    pub relative_path: String,
 22    pub mtime: Timestamp,
 23}
 24
 25#[derive(Debug)]
 26struct Embedding(pub Vec<f32>);
 27
 28impl FromSql for Embedding {
 29    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
 30        let bytes = value.as_blob()?;
 31        let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
 32        if embedding.is_err() {
 33            return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
 34        }
 35        return Ok(Embedding(embedding.unwrap()));
 36    }
 37}
 38
 39pub struct VectorDatabase {
 40    db: rusqlite::Connection,
 41}
 42
 43impl VectorDatabase {
 44    pub fn new(path: String) -> Result<Self> {
 45        let this = Self {
 46            db: rusqlite::Connection::open(path)?,
 47        };
 48        this.initialize_database()?;
 49        Ok(this)
 50    }
 51
 52    fn initialize_database(&self) -> Result<()> {
 53        rusqlite::vtab::array::load_module(&self.db)?;
 54
 55        // This will create the database if it doesnt exist
 56
 57        // Initialize Vector Databasing Tables
 58        self.db.execute(
 59            "CREATE TABLE IF NOT EXISTS worktrees (
 60                id INTEGER PRIMARY KEY AUTOINCREMENT,
 61                absolute_path VARCHAR NOT NULL
 62            );
 63            CREATE UNIQUE INDEX IF NOT EXISTS worktrees_absolute_path ON worktrees (absolute_path);
 64            ",
 65            [],
 66        )?;
 67
 68        self.db.execute(
 69            "CREATE TABLE IF NOT EXISTS files (
 70                id INTEGER PRIMARY KEY AUTOINCREMENT,
 71                worktree_id INTEGER NOT NULL,
 72                relative_path VARCHAR NOT NULL,
 73                mtime_seconds INTEGER NOT NULL,
 74                mtime_nanos INTEGER NOT NULL,
 75                FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
 76            )",
 77            [],
 78        )?;
 79
 80        self.db.execute(
 81            "CREATE TABLE IF NOT EXISTS documents (
 82                id INTEGER PRIMARY KEY AUTOINCREMENT,
 83                file_id INTEGER NOT NULL,
 84                offset INTEGER NOT NULL,
 85                name VARCHAR NOT NULL,
 86                embedding BLOB NOT NULL,
 87                FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
 88            )",
 89            [],
 90        )?;
 91
 92        Ok(())
 93    }
 94
 95    pub fn delete_file(&self, worktree_id: i64, delete_path: PathBuf) -> Result<()> {
 96        self.db.execute(
 97            "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
 98            params![worktree_id, delete_path.to_str()],
 99        )?;
100        Ok(())
101    }
102
103    pub fn insert_file(&self, worktree_id: i64, indexed_file: ParsedFile) -> Result<()> {
104        // Write to files table, and return generated id.
105        self.db.execute(
106            "
107            DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;
108            ",
109            params![worktree_id, indexed_file.path.to_str()],
110        )?;
111        let mtime = Timestamp::from(indexed_file.mtime);
112        self.db.execute(
113            "
114            INSERT INTO files
115            (worktree_id, relative_path, mtime_seconds, mtime_nanos)
116            VALUES
117            (?1, ?2, $3, $4);
118            ",
119            params![
120                worktree_id,
121                indexed_file.path.to_str(),
122                mtime.seconds,
123                mtime.nanos
124            ],
125        )?;
126
127        let file_id = self.db.last_insert_rowid();
128
129        // Currently inserting at approximately 3400 documents a second
130        // I imagine we can speed this up with a bulk insert of some kind.
131        for document in indexed_file.documents {
132            let embedding_blob = bincode::serialize(&document.embedding)?;
133
134            self.db.execute(
135                "INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
136                params![
137                    file_id,
138                    document.offset.to_string(),
139                    document.name,
140                    embedding_blob
141                ],
142            )?;
143        }
144
145        Ok(())
146    }
147
148    pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
149        // Check that the absolute path doesnt exist
150        let mut worktree_query = self
151            .db
152            .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
153
154        let worktree_id = worktree_query
155            .query_row(params![worktree_root_path.to_string_lossy()], |row| {
156                Ok(row.get::<_, i64>(0)?)
157            })
158            .map_err(|err| anyhow!(err));
159
160        if worktree_id.is_ok() {
161            return worktree_id;
162        }
163
164        // If worktree_id is Err, insert new worktree
165        self.db.execute(
166            "
167            INSERT into worktrees (absolute_path) VALUES (?1)
168            ",
169            params![worktree_root_path.to_string_lossy()],
170        )?;
171        Ok(self.db.last_insert_rowid())
172    }
173
174    pub fn get_file_mtimes(&self, worktree_id: i64) -> Result<HashMap<PathBuf, SystemTime>> {
175        let mut statement = self.db.prepare(
176            "
177            SELECT relative_path, mtime_seconds, mtime_nanos
178            FROM files
179            WHERE worktree_id = ?1
180            ORDER BY relative_path",
181        )?;
182        let mut result: HashMap<PathBuf, SystemTime> = HashMap::new();
183        for row in statement.query_map(params![worktree_id], |row| {
184            Ok((
185                row.get::<_, String>(0)?.into(),
186                Timestamp {
187                    seconds: row.get(1)?,
188                    nanos: row.get(2)?,
189                }
190                .into(),
191            ))
192        })? {
193            let row = row?;
194            result.insert(row.0, row.1);
195        }
196        Ok(result)
197    }
198
199    pub fn top_k_search(
200        &self,
201        worktree_ids: &[i64],
202        query_embedding: &Vec<f32>,
203        limit: usize,
204    ) -> Result<Vec<(i64, PathBuf, usize, String)>> {
205        let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
206        self.for_each_document(&worktree_ids, |id, embedding| {
207            eprintln!("document {id} {embedding:?}");
208
209            let similarity = dot(&embedding, &query_embedding);
210            let ix = match results
211                .binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal))
212            {
213                Ok(ix) => ix,
214                Err(ix) => ix,
215            };
216            results.insert(ix, (id, similarity));
217            results.truncate(limit);
218        })?;
219
220        let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
221        self.get_documents_by_ids(&ids)
222    }
223
224    fn for_each_document(
225        &self,
226        worktree_ids: &[i64],
227        mut f: impl FnMut(i64, Vec<f32>),
228    ) -> Result<()> {
229        let mut query_statement = self.db.prepare(
230            "
231            SELECT
232                documents.id, documents.embedding
233            FROM
234                documents, files
235            WHERE
236                documents.file_id = files.id AND
237                files.worktree_id IN rarray(?)
238            ",
239        )?;
240
241        query_statement
242            .query_map(params![ids_to_sql(worktree_ids)], |row| {
243                Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
244            })?
245            .filter_map(|row| row.ok())
246            .for_each(|(id, embedding)| {
247                dbg!("id");
248                f(id, embedding.0)
249            });
250        Ok(())
251    }
252
253    fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, usize, String)>> {
254        let mut statement = self.db.prepare(
255            "
256                SELECT
257                    documents.id, files.worktree_id, files.relative_path, documents.offset, documents.name
258                FROM
259                    documents, files
260                WHERE
261                    documents.file_id = files.id AND
262                    documents.id in rarray(?)
263            ",
264        )?;
265
266        let result_iter = statement.query_map(params![ids_to_sql(ids)], |row| {
267            Ok((
268                row.get::<_, i64>(0)?,
269                row.get::<_, i64>(1)?,
270                row.get::<_, String>(2)?.into(),
271                row.get(3)?,
272                row.get(4)?,
273            ))
274        })?;
275
276        let mut values_by_id = HashMap::<i64, (i64, PathBuf, usize, String)>::default();
277        for row in result_iter {
278            let (id, worktree_id, path, offset, name) = row?;
279            values_by_id.insert(id, (worktree_id, path, offset, name));
280        }
281
282        let mut results = Vec::with_capacity(ids.len());
283        for id in ids {
284            let value = values_by_id
285                .remove(id)
286                .ok_or(anyhow!("missing document id {}", id))?;
287            results.push(value);
288        }
289
290        Ok(results)
291    }
292}
293
294fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
295    Rc::new(
296        ids.iter()
297            .copied()
298            .map(|v| rusqlite::types::Value::from(v))
299            .collect::<Vec<_>>(),
300    )
301}
302
303pub(crate) fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
304    let len = vec_a.len();
305    assert_eq!(len, vec_b.len());
306
307    let mut result = 0.0;
308    unsafe {
309        matrixmultiply::sgemm(
310            1,
311            len,
312            1,
313            1.0,
314            vec_a.as_ptr(),
315            len as isize,
316            1,
317            vec_b.as_ptr(),
318            1,
319            len as isize,
320            0.0,
321            &mut result as *mut f32,
322            1,
323            1,
324        );
325    }
326    result
327}