db.rs

  1use std::{
  2    cmp::Ordering,
  3    collections::HashMap,
  4    path::{Path, PathBuf},
  5    rc::Rc,
  6    time::SystemTime,
  7};
  8
  9use anyhow::{anyhow, Result};
 10
 11use crate::parsing::ParsedFile;
 12use crate::VECTOR_STORE_VERSION;
 13use rpc::proto::Timestamp;
 14use rusqlite::{
 15    params,
 16    types::{FromSql, FromSqlResult, ValueRef},
 17};
 18
 19#[derive(Debug)]
 20pub struct FileRecord {
 21    pub id: usize,
 22    pub relative_path: String,
 23    pub mtime: Timestamp,
 24}
 25
 26#[derive(Debug)]
 27struct Embedding(pub Vec<f32>);
 28
 29impl FromSql for Embedding {
 30    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
 31        let bytes = value.as_blob()?;
 32        let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
 33        if embedding.is_err() {
 34            return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
 35        }
 36        return Ok(Embedding(embedding.unwrap()));
 37    }
 38}
 39
 40pub struct VectorDatabase {
 41    db: rusqlite::Connection,
 42}
 43
 44impl VectorDatabase {
 45    pub fn new(path: String) -> Result<Self> {
 46        let this = Self {
 47            db: rusqlite::Connection::open(path)?,
 48        };
 49        this.initialize_database()?;
 50        Ok(this)
 51    }
 52
 53    fn initialize_database(&self) -> Result<()> {
 54        rusqlite::vtab::array::load_module(&self.db)?;
 55
 56        // This will create the database if it doesnt exist
 57
 58        // Initialize Vector Databasing Tables
 59        self.db.execute(
 60            "CREATE TABLE IF NOT EXISTS worktrees (
 61                id INTEGER PRIMARY KEY AUTOINCREMENT,
 62                absolute_path VARCHAR NOT NULL
 63            );
 64            CREATE UNIQUE INDEX IF NOT EXISTS worktrees_absolute_path ON worktrees (absolute_path);
 65            ",
 66            [],
 67        )?;
 68
 69        self.db.execute(
 70            "CREATE TABLE IF NOT EXISTS files (
 71                id INTEGER PRIMARY KEY AUTOINCREMENT,
 72                worktree_id INTEGER NOT NULL,
 73                relative_path VARCHAR NOT NULL,
 74                mtime_seconds INTEGER NOT NULL,
 75                mtime_nanos INTEGER NOT NULL,
 76                vector_store_version INTEGER NOT NULL,
 77                FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
 78            )",
 79            [],
 80        )?;
 81
 82        self.db.execute(
 83            "CREATE TABLE IF NOT EXISTS documents (
 84                id INTEGER PRIMARY KEY AUTOINCREMENT,
 85                file_id INTEGER NOT NULL,
 86                offset INTEGER NOT NULL,
 87                name VARCHAR NOT NULL,
 88                embedding BLOB NOT NULL,
 89                FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
 90            )",
 91            [],
 92        )?;
 93
 94        Ok(())
 95    }
 96
 97    pub fn delete_file(&self, worktree_id: i64, delete_path: PathBuf) -> Result<()> {
 98        self.db.execute(
 99            "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
100            params![worktree_id, delete_path.to_str()],
101        )?;
102        Ok(())
103    }
104
105    pub fn insert_file(&self, worktree_id: i64, indexed_file: ParsedFile) -> Result<()> {
106        // Write to files table, and return generated id.
107        self.db.execute(
108            "
109            DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;
110            ",
111            params![worktree_id, indexed_file.path.to_str()],
112        )?;
113        let mtime = Timestamp::from(indexed_file.mtime);
114        self.db.execute(
115            "
116            INSERT INTO files
117            (worktree_id, relative_path, mtime_seconds, mtime_nanos, vector_store_version)
118            VALUES
119            (?1, ?2, $3, $4, $5);
120            ",
121            params![
122                worktree_id,
123                indexed_file.path.to_str(),
124                mtime.seconds,
125                mtime.nanos,
126                VECTOR_STORE_VERSION
127            ],
128        )?;
129
130        let file_id = self.db.last_insert_rowid();
131
132        // Currently inserting at approximately 3400 documents a second
133        // I imagine we can speed this up with a bulk insert of some kind.
134        for document in indexed_file.documents {
135            let embedding_blob = bincode::serialize(&document.embedding)?;
136
137            self.db.execute(
138                "INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
139                params![
140                    file_id,
141                    document.offset.to_string(),
142                    document.name,
143                    embedding_blob
144                ],
145            )?;
146        }
147
148        Ok(())
149    }
150
151    pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
152        // Check that the absolute path doesnt exist
153        let mut worktree_query = self
154            .db
155            .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
156
157        let worktree_id = worktree_query
158            .query_row(params![worktree_root_path.to_string_lossy()], |row| {
159                Ok(row.get::<_, i64>(0)?)
160            })
161            .map_err(|err| anyhow!(err));
162
163        if worktree_id.is_ok() {
164            return worktree_id;
165        }
166
167        // If worktree_id is Err, insert new worktree
168        self.db.execute(
169            "
170            INSERT into worktrees (absolute_path) VALUES (?1)
171            ",
172            params![worktree_root_path.to_string_lossy()],
173        )?;
174        Ok(self.db.last_insert_rowid())
175    }
176
177    pub fn get_file_mtimes(&self, worktree_id: i64) -> Result<HashMap<PathBuf, SystemTime>> {
178        let mut statement = self.db.prepare(
179            "
180            SELECT relative_path, mtime_seconds, mtime_nanos
181            FROM files
182            WHERE worktree_id = ?1
183            ORDER BY relative_path",
184        )?;
185        let mut result: HashMap<PathBuf, SystemTime> = HashMap::new();
186        for row in statement.query_map(params![worktree_id], |row| {
187            Ok((
188                row.get::<_, String>(0)?.into(),
189                Timestamp {
190                    seconds: row.get(1)?,
191                    nanos: row.get(2)?,
192                }
193                .into(),
194            ))
195        })? {
196            let row = row?;
197            result.insert(row.0, row.1);
198        }
199        Ok(result)
200    }
201
202    pub fn top_k_search(
203        &self,
204        worktree_ids: &[i64],
205        query_embedding: &Vec<f32>,
206        limit: usize,
207    ) -> Result<Vec<(i64, PathBuf, usize, String)>> {
208        let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
209        self.for_each_document(&worktree_ids, |id, embedding| {
210            let similarity = dot(&embedding, &query_embedding);
211            let ix = match results
212                .binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal))
213            {
214                Ok(ix) => ix,
215                Err(ix) => ix,
216            };
217            results.insert(ix, (id, similarity));
218            results.truncate(limit);
219        })?;
220
221        let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
222        self.get_documents_by_ids(&ids)
223    }
224
225    fn for_each_document(
226        &self,
227        worktree_ids: &[i64],
228        mut f: impl FnMut(i64, Vec<f32>),
229    ) -> Result<()> {
230        let mut query_statement = self.db.prepare(
231            "
232            SELECT
233                documents.id, documents.embedding
234            FROM
235                documents, files
236            WHERE
237                documents.file_id = files.id AND
238                files.worktree_id IN rarray(?)
239            ",
240        )?;
241
242        query_statement
243            .query_map(params![ids_to_sql(worktree_ids)], |row| {
244                Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
245            })?
246            .filter_map(|row| row.ok())
247            .for_each(|(id, embedding)| f(id, embedding.0));
248        Ok(())
249    }
250
251    fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, usize, String)>> {
252        let mut statement = self.db.prepare(
253            "
254                SELECT
255                    documents.id, files.worktree_id, files.relative_path, documents.offset, documents.name
256                FROM
257                    documents, files
258                WHERE
259                    documents.file_id = files.id AND
260                    documents.id in rarray(?)
261            ",
262        )?;
263
264        let result_iter = statement.query_map(params![ids_to_sql(ids)], |row| {
265            Ok((
266                row.get::<_, i64>(0)?,
267                row.get::<_, i64>(1)?,
268                row.get::<_, String>(2)?.into(),
269                row.get(3)?,
270                row.get(4)?,
271            ))
272        })?;
273
274        let mut values_by_id = HashMap::<i64, (i64, PathBuf, usize, String)>::default();
275        for row in result_iter {
276            let (id, worktree_id, path, offset, name) = row?;
277            values_by_id.insert(id, (worktree_id, path, offset, name));
278        }
279
280        let mut results = Vec::with_capacity(ids.len());
281        for id in ids {
282            let value = values_by_id
283                .remove(id)
284                .ok_or(anyhow!("missing document id {}", id))?;
285            results.push(value);
286        }
287
288        Ok(results)
289    }
290}
291
292fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
293    Rc::new(
294        ids.iter()
295            .copied()
296            .map(|v| rusqlite::types::Value::from(v))
297            .collect::<Vec<_>>(),
298    )
299}
300
301pub(crate) fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
302    let len = vec_a.len();
303    assert_eq!(len, vec_b.len());
304
305    let mut result = 0.0;
306    unsafe {
307        matrixmultiply::sgemm(
308            1,
309            len,
310            1,
311            1.0,
312            vec_a.as_ptr(),
313            len as isize,
314            1,
315            vec_b.as_ptr(),
316            1,
317            len as isize,
318            0.0,
319            &mut result as *mut f32,
320            1,
321            1,
322        );
323    }
324    result
325}