db.rs

use crate::{parsing::Document, SEMANTIC_INDEX_VERSION};
use anyhow::{anyhow, Context, Result};
use project::Fs;
use rpc::proto::Timestamp;
use rusqlite::{
    params,
    types::{FromSql, FromSqlResult, ValueRef},
};
use std::{
    cmp::Ordering,
    collections::HashMap,
    ops::Range,
    path::{Path, PathBuf},
    rc::Rc,
    sync::Arc,
    time::SystemTime,
};

#[derive(Debug)]
pub struct FileRecord {
    pub id: usize,
    pub relative_path: String,
    pub mtime: Timestamp,
}

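/// Newtype wrapper so a `bincode`-serialized embedding BLOB can be decoded directly from a query row.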
#[derive(Debug)]
struct Embedding(pub Vec<f32>);

impl FromSql for Embedding {
    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
        let bytes = value.as_blob()?;
        bincode::deserialize::<Vec<f32>>(bytes)
            .map(Embedding)
            .map_err(|err| rusqlite::types::FromSqlError::Other(err))
    }
}

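/// SQLite-backed store for indexed worktrees, their files, and the embeddings of each file's documents.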
pub struct VectorDatabase {
    db: rusqlite::Connection,
}

impl VectorDatabase {
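    /// Opens (or creates) the database at `path`, creating its parent directory if needed,
    /// and ensures the schema matches the current index version.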
    pub async fn new(fs: Arc<dyn Fs>, path: Arc<PathBuf>) -> Result<Self> {
        if let Some(db_directory) = path.parent() {
            fs.create_dir(db_directory).await?;
        }

        let this = Self {
            db: rusqlite::Connection::open(path.as_path())?,
        };
        this.initialize_database()?;
        Ok(this)
    }

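    /// Reads the schema version recorded in `semantic_index_config`, if present.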
    fn get_existing_version(&self) -> Result<i64> {
        let mut version_query = self
            .db
            .prepare("SELECT version FROM semantic_index_config")?;
        version_query
            .query_row([], |row| row.get::<_, i64>(0))
            .map_err(|err| anyhow!("version query failed: {err}"))
    }

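    /// Creates the schema, dropping and rebuilding every table whenever the stored
    /// version does not match `SEMANTIC_INDEX_VERSION`.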
    fn initialize_database(&self) -> Result<()> {
        rusqlite::vtab::array::load_module(&self.db)?;

        if self
            .get_existing_version()
            .map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64)
        {
            return Ok(());
        }

        // `execute` runs only a single statement, so drop the outdated tables in one batch.
        self.db
            .execute_batch(
                "
                DROP TABLE IF EXISTS documents;
                DROP TABLE IF EXISTS files;
                DROP TABLE IF EXISTS worktrees;
                DROP TABLE IF EXISTS semantic_index_config;
                ",
            )
            .context("failed to drop tables")?;

        // Initialize the vector database tables.
        self.db.execute(
            "CREATE TABLE semantic_index_config (
                version INTEGER NOT NULL
            )",
            [],
        )?;

        self.db.execute(
            "INSERT INTO semantic_index_config (version) VALUES (?1)",
            params![SEMANTIC_INDEX_VERSION],
        )?;

        self.db.execute_batch(
            "CREATE TABLE worktrees (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                absolute_path VARCHAR NOT NULL
            );
            CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
            ",
        )?;

        self.db.execute(
            "CREATE TABLE files (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                worktree_id INTEGER NOT NULL,
                relative_path VARCHAR NOT NULL,
                mtime_seconds INTEGER NOT NULL,
                mtime_nanos INTEGER NOT NULL,
                FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
            )",
            [],
        )?;

        self.db.execute(
            "CREATE TABLE documents (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                file_id INTEGER NOT NULL,
                start_byte INTEGER NOT NULL,
                end_byte INTEGER NOT NULL,
                name VARCHAR NOT NULL,
                embedding BLOB NOT NULL,
                FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
            )",
            [],
        )?;

        Ok(())
    }

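    /// Removes a file from the index; its documents are removed via `ON DELETE CASCADE`.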
    pub fn delete_file(&self, worktree_id: i64, delete_path: PathBuf) -> Result<()> {
        self.db.execute(
            "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
            params![worktree_id, delete_path.to_str()],
        )?;
        Ok(())
    }

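    /// Replaces any existing row for `path` in this worktree, records the file's mtime,
    /// and inserts one row per parsed document along with its embedding.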
    pub fn insert_file(
        &self,
        worktree_id: i64,
        path: PathBuf,
        mtime: SystemTime,
        documents: Vec<Document>,
    ) -> Result<()> {
        // Replace any existing row for this file, then capture the generated file id.
        self.db.execute(
            "
            DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;
            ",
            params![worktree_id, path.to_str()],
        )?;
        let mtime = Timestamp::from(mtime);
        self.db.execute(
            "
            INSERT INTO files
            (worktree_id, relative_path, mtime_seconds, mtime_nanos)
            VALUES
            (?1, ?2, ?3, ?4);
            ",
            params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
        )?;

        let file_id = self.db.last_insert_rowid();

        // Currently inserting at approximately 3400 documents a second.
        // I imagine we can speed this up with a bulk insert of some kind.
        for document in documents {
            let embedding_blob = bincode::serialize(&document.embedding)?;

            self.db.execute(
                "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding) VALUES (?1, ?2, ?3, ?4, ?5)",
                params![
                    file_id,
                    document.range.start.to_string(),
                    document.range.end.to_string(),
                    document.name,
                    embedding_blob
                ],
            )?;
        }

        Ok(())
    }

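    /// Returns the id of the worktree rooted at this absolute path, inserting a new row
    /// if the path has not been seen before.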
    pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
        // Check whether a worktree with this absolute path already exists.
        let mut worktree_query = self
            .db
            .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;

        let worktree_id = worktree_query
            .query_row(params![worktree_root_path.to_string_lossy()], |row| {
                Ok(row.get::<_, i64>(0)?)
            })
            .map_err(|err| anyhow!(err));

        if worktree_id.is_ok() {
            return worktree_id;
        }

        // No existing worktree, so insert a new one.
        self.db.execute(
            "
            INSERT INTO worktrees (absolute_path) VALUES (?1)
            ",
            params![worktree_root_path.to_string_lossy()],
        )?;
        Ok(self.db.last_insert_rowid())
    }

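    /// Returns the recorded modification time of every file in the worktree, keyed by relative path.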
    pub fn get_file_mtimes(&self, worktree_id: i64) -> Result<HashMap<PathBuf, SystemTime>> {
        let mut statement = self.db.prepare(
            "
            SELECT relative_path, mtime_seconds, mtime_nanos
            FROM files
            WHERE worktree_id = ?1
            ORDER BY relative_path",
        )?;
        let mut result: HashMap<PathBuf, SystemTime> = HashMap::new();
        for row in statement.query_map(params![worktree_id], |row| {
            Ok((
                row.get::<_, String>(0)?.into(),
                Timestamp {
                    seconds: row.get(1)?,
                    nanos: row.get(2)?,
                }
                .into(),
            ))
        })? {
            let row = row?;
            result.insert(row.0, row.1);
        }
        Ok(result)
    }

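    /// Brute-force nearest-neighbor search: scores every document embedding in the given
    /// worktrees against `query_embedding` with a dot product and returns the `limit` best
    /// matches as `(worktree_id, path, byte range, name)`.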
    pub fn top_k_search(
        &self,
        worktree_ids: &[i64],
        query_embedding: &[f32],
        limit: usize,
    ) -> Result<Vec<(i64, PathBuf, Range<usize>, String)>> {
        let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
        self.for_each_document(worktree_ids, |id, embedding| {
            // Keep `results` sorted by descending similarity and capped at `limit`.
            let similarity = dot(&embedding, query_embedding);
            let ix = match results
                .binary_search_by(|(_, s)| similarity.partial_cmp(s).unwrap_or(Ordering::Equal))
            {
                Ok(ix) => ix,
                Err(ix) => ix,
            };
            results.insert(ix, (id, similarity));
            results.truncate(limit);
        })?;

        let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
        self.get_documents_by_ids(&ids)
    }

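    /// Calls `f` with the id and deserialized embedding of every document in the given worktrees.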
    fn for_each_document(
        &self,
        worktree_ids: &[i64],
        mut f: impl FnMut(i64, Vec<f32>),
    ) -> Result<()> {
        let mut query_statement = self.db.prepare(
            "
            SELECT
                documents.id, documents.embedding
            FROM
                documents, files
            WHERE
                documents.file_id = files.id AND
                files.worktree_id IN rarray(?)
            ",
        )?;

        query_statement
            .query_map(params![ids_to_sql(worktree_ids)], |row| {
                Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
            })?
            .filter_map(|row| row.ok())
            .for_each(|(id, embedding)| f(id, embedding.0));
        Ok(())
    }

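    /// Loads `(worktree_id, path, byte range, name)` for each document id, preserving the order of `ids`.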
    fn get_documents_by_ids(
        &self,
        ids: &[i64],
    ) -> Result<Vec<(i64, PathBuf, Range<usize>, String)>> {
        let mut statement = self.db.prepare(
            "
            SELECT
                documents.id,
                files.worktree_id,
                files.relative_path,
                documents.start_byte,
                documents.end_byte,
                documents.name
            FROM
                documents, files
            WHERE
                documents.file_id = files.id AND
                documents.id IN rarray(?)
            ",
        )?;

        let result_iter = statement.query_map(params![ids_to_sql(ids)], |row| {
            Ok((
                row.get::<_, i64>(0)?,
                row.get::<_, i64>(1)?,
                row.get::<_, String>(2)?.into(),
                row.get(3)?..row.get(4)?,
                row.get(5)?,
            ))
        })?;

        let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>, String)>::default();
        for row in result_iter {
            let (id, worktree_id, path, range, name) = row?;
            values_by_id.insert(id, (worktree_id, path, range, name));
        }

        let mut results = Vec::with_capacity(ids.len());
        for id in ids {
            let value = values_by_id
                .remove(id)
                .ok_or_else(|| anyhow!("missing document id {}", id))?;
            results.push(value);
        }

        Ok(results)
    }
}

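/// Converts the ids into the `Rc<Vec<Value>>` array form expected by SQLite's `rarray()` table-valued function.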
fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
    Rc::new(
        ids.iter()
            .copied()
            .map(rusqlite::types::Value::from)
            .collect::<Vec<_>>(),
    )
}

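/// Dot product of two equal-length vectors, computed as a 1×len by len×1 matrix
/// product via `matrixmultiply::sgemm`.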
pub(crate) fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    let len = vec_a.len();
    assert_eq!(len, vec_b.len());

    // Computes result = 1.0 * A * B + 0.0 * result, where A is treated as a
    // 1×len row vector and B as a len×1 column vector.
    let mut result = 0.0;
    unsafe {
        matrixmultiply::sgemm(
            1,
            len,
            1,
            1.0,
            vec_a.as_ptr(),
            len as isize,
            1,
            vec_b.as_ptr(),
            1,
            len as isize,
            0.0,
            &mut result as *mut f32,
            1,
            1,
        );
    }
    result
}
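
// Illustrative sanity check (a minimal sketch, not exhaustive): `dot` should agree
// with the naive sum of element-wise products.
#[cfg(test)]
mod tests {
    use super::dot;

    #[test]
    fn dot_matches_naive_sum_of_products() {
        let a = [1.0_f32, 2.0, 3.0];
        let b = [4.0_f32, 5.0, 6.0];
        // 1*4 + 2*5 + 3*6 = 32
        assert_eq!(dot(&a, &b), 32.0);
    }
}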