db.rs

use std::{
    collections::HashMap,
    path::{Path, PathBuf},
    rc::Rc,
};

use anyhow::{anyhow, Result};

use rusqlite::{
    params,
    types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
    ToSql,
};
use sha1::{Digest, Sha1};

use crate::IndexedFile;

// Note this is not an appropriate document
#[derive(Debug)]
pub struct DocumentRecord {
    pub id: usize,
    pub file_id: usize,
    pub offset: usize,
    pub name: String,
    pub embedding: Embedding,
}

#[derive(Debug)]
pub struct FileRecord {
    pub id: usize,
    pub relative_path: String,
    pub sha1: FileSha1,
}

#[derive(Debug)]
pub struct FileSha1(pub Vec<u8>);

impl FileSha1 {
    pub fn from_str(content: String) -> Self {
        // Hash the file's contents so later indexing passes can detect changes.
        let mut hasher = Sha1::new();
        hasher.update(content);
        FileSha1(hasher.finalize().to_vec())
    }

    pub fn equals(&self, content: &String) -> bool {
        // Re-hash the provided content and compare it against the stored digest.
        let mut hasher = Sha1::new();
        hasher.update(content);
        self.0 == hasher.finalize().to_vec()
    }
}

impl ToSql for FileSha1 {
    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
        self.0.to_sql()
    }
}

impl FromSql for FileSha1 {
    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
        let bytes = value.as_blob()?;
        Ok(FileSha1(bytes.to_vec()))
    }
}

#[derive(Debug)]
pub struct Embedding(pub Vec<f32>);

impl FromSql for Embedding {
    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
        let bytes = value.as_blob()?;
        // Embeddings are stored as bincode-serialized Vec<f32> blobs.
        let embedding: Vec<f32> = bincode::deserialize(bytes)
            .map_err(|err| rusqlite::types::FromSqlError::Other(err))?;
        Ok(Embedding(embedding))
    }
}

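/// Wrapper around the SQLite connection that stores the worktrees, files, and
/// document-embedding tables backing the vector index.
///
/// A minimal usage sketch (the database path and project root are illustrative
/// assumptions, not values from this crate):
///
/// ```ignore
/// let db = VectorDatabase::new("embeddings.db".to_string())?;
/// let worktree_id = db.find_or_create_worktree(Path::new("/path/to/project"))?;
/// let file_hashes = db.get_file_hashes(worktree_id)?;
/// ```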
pub struct VectorDatabase {
    db: rusqlite::Connection,
}

impl VectorDatabase {
    pub fn new(path: String) -> Result<Self> {
        let this = Self {
            db: rusqlite::Connection::open(path)?,
        };
        this.initialize_database()?;
        Ok(this)
    }

    fn initialize_database(&self) -> Result<()> {
        rusqlite::vtab::array::load_module(&self.db)?;

        // Opening the connection above will have created the database file if it
        // didn't already exist.

        // Initialize the vector database tables. `execute_batch` is used for the
        // worktrees schema because it contains two statements, and `execute` only
        // runs a single statement.
        self.db.execute_batch(
            "CREATE TABLE IF NOT EXISTS worktrees (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                absolute_path VARCHAR NOT NULL
            );
            CREATE UNIQUE INDEX IF NOT EXISTS worktrees_absolute_path ON worktrees (absolute_path);
            ",
        )?;

        self.db.execute(
            "CREATE TABLE IF NOT EXISTS files (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                worktree_id INTEGER NOT NULL,
                relative_path VARCHAR NOT NULL,
                sha1 BLOB NOT NULL,
                FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
            )",
            [],
        )?;

        self.db.execute(
            "CREATE TABLE IF NOT EXISTS documents (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                file_id INTEGER NOT NULL,
                offset INTEGER NOT NULL,
                name VARCHAR NOT NULL,
                embedding BLOB NOT NULL,
                FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
            )",
            [],
        )?;

        Ok(())
    }
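
    // Note (an observation about SQLite defaults, not something the original code
    // configures): foreign key enforcement is off by default, so the ON DELETE
    // CASCADE clauses above only take effect if the connection opts in. A
    // hypothetical helper to enable it could look like this:
    pub fn enable_foreign_keys(&self) -> Result<()> {
        self.db.execute_batch("PRAGMA foreign_keys = ON;")?;
        Ok(())
    }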

    pub fn delete_file(&self, worktree_id: i64, delete_path: PathBuf) -> Result<()> {
        self.db.execute(
            "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
            params![worktree_id, delete_path.to_str()],
        )?;
        Ok(())
    }

    pub fn insert_file(&self, worktree_id: i64, indexed_file: IndexedFile) -> Result<()> {
        // Remove any existing row for this file, then write a fresh entry to the
        // files table; its generated id links the document rows below.
        self.db.execute(
            "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
            params![worktree_id, indexed_file.path.to_str()],
        )?;
        self.db.execute(
            "INSERT INTO files (worktree_id, relative_path, sha1) VALUES (?1, ?2, ?3)",
            params![worktree_id, indexed_file.path.to_str(), indexed_file.sha1],
        )?;

        let file_id = self.db.last_insert_rowid();

        // Currently inserting at approximately 3400 documents a second.
        // I imagine we can speed this up with a bulk insert of some kind; see the
        // hypothetical sketch after this function.
        for document in indexed_file.documents {
            let embedding_blob = bincode::serialize(&document.embedding)?;

            self.db.execute(
                "INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
                params![
                    file_id,
                    document.offset as i64,
                    document.name,
                    embedding_blob
                ],
            )?;
        }

        Ok(())
    }
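
    // Hedged sketch (not part of the original code): the per-document loop in
    // `insert_file` above could likely be sped up by wrapping all of a file's
    // document inserts in one transaction and reusing a single prepared
    // statement. The method name and the `(offset, name, embedding)` tuple
    // parameter are assumptions for illustration only.
    pub fn insert_documents_batched(
        &mut self,
        file_id: i64,
        documents: &[(usize, String, Vec<f32>)],
    ) -> Result<()> {
        // `Connection::transaction` requires `&mut self`; everything inserted
        // below commits atomically.
        let tx = self.db.transaction()?;
        {
            let mut statement = tx.prepare(
                "INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
            )?;
            for (offset, name, embedding) in documents {
                let embedding_blob = bincode::serialize(embedding)?;
                statement.execute(params![file_id, *offset as i64, name, embedding_blob])?;
            }
        }
        tx.commit()?;
        Ok(())
    }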

    pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
        // Check whether a worktree with this absolute path is already registered.
        let mut worktree_query = self
            .db
            .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;

        let worktree_id = worktree_query
            .query_row(params![worktree_root_path.to_string_lossy()], |row| {
                row.get::<_, i64>(0)
            })
            .map_err(|err| anyhow!(err));

        if worktree_id.is_ok() {
            return worktree_id;
        }

        // Otherwise, insert a new worktree and return its generated id.
        self.db.execute(
            "INSERT INTO worktrees (absolute_path) VALUES (?1)",
            params![worktree_root_path.to_string_lossy()],
        )?;
        Ok(self.db.last_insert_rowid())
    }

    pub fn get_file_hashes(&self, worktree_id: i64) -> Result<HashMap<PathBuf, FileSha1>> {
        let mut statement = self.db.prepare(
            "SELECT relative_path, sha1 FROM files WHERE worktree_id = ?1 ORDER BY relative_path",
        )?;
        let mut result: HashMap<PathBuf, FileSha1> = HashMap::new();
        for row in statement.query_map(params![worktree_id], |row| {
            Ok((row.get::<_, String>(0)?.into(), row.get(1)?))
        })? {
            let row = row?;
            result.insert(row.0, row.1);
        }
        Ok(result)
    }

    pub fn for_each_document(
        &self,
        worktree_ids: &[i64],
        mut f: impl FnMut(i64, Embedding),
    ) -> Result<()> {
        let mut query_statement = self.db.prepare(
            "
            SELECT
                documents.id, documents.embedding
            FROM
                documents, files
            WHERE
                documents.file_id = files.id AND
                files.worktree_id IN rarray(?)
            ",
        )?;
        query_statement
            .query_map(params![ids_to_sql(worktree_ids)], |row| {
                Ok((row.get(0)?, row.get(1)?))
            })?
            .filter_map(|row| row.ok())
            .for_each(|row| f(row.0, row.1));
        Ok(())
    }

    pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, usize, String)>> {
        let mut statement = self.db.prepare(
            "
                SELECT
                    documents.id, files.worktree_id, files.relative_path, documents.offset, documents.name
                FROM
                    documents, files
                WHERE
                    documents.file_id = files.id AND
                    documents.id IN rarray(?)
            ",
        )?;

        let result_iter = statement.query_map(params![ids_to_sql(ids)], |row| {
            Ok((
                row.get::<_, i64>(0)?,
                row.get::<_, i64>(1)?,
                row.get::<_, String>(2)?.into(),
                row.get(3)?,
                row.get(4)?,
            ))
        })?;

        // Collect rows keyed by document id, then emit results in the order the
        // ids were requested.
        let mut values_by_id = HashMap::<i64, (i64, PathBuf, usize, String)>::default();
        for row in result_iter {
            let (id, worktree_id, path, offset, name) = row?;
            values_by_id.insert(id, (worktree_id, path, offset, name));
        }

        let mut results = Vec::with_capacity(ids.len());
        for id in ids {
            let value = values_by_id
                .remove(id)
                .ok_or_else(|| anyhow!("missing document id {}", id))?;
            results.push(value);
        }

        Ok(results)
    }
}

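// Converts ids into the Rc<Vec<Value>> form that rusqlite's `rarray()` table-valued
// function expects as a bound parameter (the array module is loaded in
// `initialize_database` via `vtab::array::load_module`).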
fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
    Rc::new(
        ids.iter()
            .copied()
            .map(rusqlite::types::Value::from)
            .collect::<Vec<_>>(),
    )
}