db.rs

  1use std::{
  2    collections::HashMap,
  3    path::{Path, PathBuf},
  4};
  5
  6use anyhow::{anyhow, Result};
  7
  8use rusqlite::{
  9    params,
 10    types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
 11    ToSql,
 12};
 13use sha1::{Digest, Sha1};
 14
 15use crate::IndexedFile;
 16
 17// This is saving to a local database store within the users dev zed path
 18// Where do we want this to sit?
 19// Assuming near where the workspace DB sits.
 20pub const VECTOR_DB_URL: &str = "embeddings_db";
 21
 22// Note this is not an appropriate document
 23#[derive(Debug)]
 24pub struct DocumentRecord {
 25    pub id: usize,
 26    pub file_id: usize,
 27    pub offset: usize,
 28    pub name: String,
 29    pub embedding: Embedding,
 30}
 31
 32#[derive(Debug)]
 33pub struct FileRecord {
 34    pub id: usize,
 35    pub relative_path: String,
 36    pub sha1: FileSha1,
 37}
 38
 39#[derive(Debug)]
 40pub struct FileSha1(pub Vec<u8>);
 41
 42impl FileSha1 {
 43    pub fn from_str(content: String) -> Self {
 44        let mut hasher = Sha1::new();
 45        hasher.update(content);
 46        let sha1 = hasher.finalize()[..]
 47            .into_iter()
 48            .map(|val| val.to_owned())
 49            .collect::<Vec<u8>>();
 50        return FileSha1(sha1);
 51    }
 52
 53    pub fn equals(&self, content: &String) -> bool {
 54        let mut hasher = Sha1::new();
 55        hasher.update(content);
 56        let sha1 = hasher.finalize()[..]
 57            .into_iter()
 58            .map(|val| val.to_owned())
 59            .collect::<Vec<u8>>();
 60
 61        let equal = self
 62            .0
 63            .clone()
 64            .into_iter()
 65            .zip(sha1)
 66            .filter(|&(a, b)| a == b)
 67            .count()
 68            == self.0.len();
 69
 70        equal
 71    }
 72}
 73
 74impl ToSql for FileSha1 {
 75    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
 76        return self.0.to_sql();
 77    }
 78}
 79
 80impl FromSql for FileSha1 {
 81    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
 82        let bytes = value.as_blob()?;
 83        Ok(FileSha1(
 84            bytes
 85                .into_iter()
 86                .map(|val| val.to_owned())
 87                .collect::<Vec<u8>>(),
 88        ))
 89    }
 90}
 91
 92#[derive(Debug)]
 93pub struct Embedding(pub Vec<f32>);
 94
 95impl FromSql for Embedding {
 96    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
 97        let bytes = value.as_blob()?;
 98        let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
 99        if embedding.is_err() {
100            return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
101        }
102        return Ok(Embedding(embedding.unwrap()));
103    }
104}
105
106pub struct VectorDatabase {
107    db: rusqlite::Connection,
108}
109
110impl VectorDatabase {
111    pub fn new(path: &str) -> Result<Self> {
112        let this = Self {
113            db: rusqlite::Connection::open(path)?,
114        };
115        this.initialize_database()?;
116        Ok(this)
117    }
118
119    fn initialize_database(&self) -> Result<()> {
120        rusqlite::vtab::array::load_module(&self.db)?;
121
122        // This will create the database if it doesnt exist
123
124        // Initialize Vector Databasing Tables
125        self.db.execute(
126            "CREATE TABLE IF NOT EXISTS worktrees (
127                id INTEGER PRIMARY KEY AUTOINCREMENT,
128                absolute_path VARCHAR NOT NULL
129            );
130            CREATE UNIQUE INDEX IF NOT EXISTS worktrees_absolute_path ON worktrees (absolute_path);
131            ",
132            [],
133        )?;
134
135        self.db.execute(
136            "CREATE TABLE IF NOT EXISTS files (
137                id INTEGER PRIMARY KEY AUTOINCREMENT,
138                worktree_id INTEGER NOT NULL,
139                relative_path VARCHAR NOT NULL,
140                sha1 BLOB NOT NULL,
141                FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
142            )",
143            [],
144        )?;
145
146        self.db.execute(
147            "CREATE TABLE IF NOT EXISTS documents (
148                id INTEGER PRIMARY KEY AUTOINCREMENT,
149                file_id INTEGER NOT NULL,
150                offset INTEGER NOT NULL,
151                name VARCHAR NOT NULL,
152                embedding BLOB NOT NULL,
153                FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
154            )",
155            [],
156        )?;
157
158        Ok(())
159    }
160
161    pub fn insert_file(&self, worktree_id: i64, indexed_file: IndexedFile) -> Result<()> {
162        // Write to files table, and return generated id.
163        log::info!("Inserting File!");
164        self.db.execute(
165            "
166            DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;
167            ",
168            params![worktree_id, indexed_file.path.to_str()],
169        )?;
170        self.db.execute(
171            "
172            INSERT INTO files (worktree_id, relative_path, sha1) VALUES (?1, ?2, $3);
173            ",
174            params![worktree_id, indexed_file.path.to_str(), indexed_file.sha1],
175        )?;
176
177        let file_id = self.db.last_insert_rowid();
178
179        // Currently inserting at approximately 3400 documents a second
180        // I imagine we can speed this up with a bulk insert of some kind.
181        for document in indexed_file.documents {
182            let embedding_blob = bincode::serialize(&document.embedding)?;
183
184            self.db.execute(
185                "INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
186                params![
187                    file_id,
188                    document.offset.to_string(),
189                    document.name,
190                    embedding_blob
191                ],
192            )?;
193        }
194
195        Ok(())
196    }
197
198    pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
199        // Check that the absolute path doesnt exist
200        let mut worktree_query = self
201            .db
202            .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
203
204        let worktree_id = worktree_query
205            .query_row(params![worktree_root_path.to_string_lossy()], |row| {
206                Ok(row.get::<_, i64>(0)?)
207            })
208            .map_err(|err| anyhow!(err));
209
210        if worktree_id.is_ok() {
211            return worktree_id;
212        }
213
214        // If worktree_id is Err, insert new worktree
215        self.db.execute(
216            "
217            INSERT into worktrees (absolute_path) VALUES (?1)
218            ",
219            params![worktree_root_path.to_string_lossy()],
220        )?;
221        Ok(self.db.last_insert_rowid())
222    }
223
224    pub fn get_file_hashes(&self, worktree_id: i64) -> Result<HashMap<PathBuf, FileSha1>> {
225        let mut statement = self.db.prepare(
226            "SELECT relative_path, sha1 FROM files WHERE worktree_id = ?1 ORDER BY relative_path",
227        )?;
228        let mut result: HashMap<PathBuf, FileSha1> = HashMap::new();
229        for row in statement.query_map(params![worktree_id], |row| {
230            Ok((row.get::<_, String>(0)?.into(), row.get(1)?))
231        })? {
232            let row = row?;
233            result.insert(row.0, row.1);
234        }
235        Ok(result)
236    }
237
238    pub fn get_files(&self) -> Result<HashMap<usize, FileRecord>> {
239        let mut query_statement = self
240            .db
241            .prepare("SELECT id, relative_path, sha1 FROM files")?;
242        let result_iter = query_statement.query_map([], |row| {
243            Ok(FileRecord {
244                id: row.get(0)?,
245                relative_path: row.get(1)?,
246                sha1: row.get(2)?,
247            })
248        })?;
249
250        let mut pages: HashMap<usize, FileRecord> = HashMap::new();
251        for result in result_iter {
252            let result = result?;
253            pages.insert(result.id, result);
254        }
255
256        Ok(pages)
257    }
258
259    pub fn for_each_document(
260        &self,
261        worktree_id: i64,
262        mut f: impl FnMut(i64, Embedding),
263    ) -> Result<()> {
264        let mut query_statement = self.db.prepare("SELECT id, embedding FROM documents")?;
265        query_statement
266            .query_map(params![], |row| Ok((row.get(0)?, row.get(1)?)))?
267            .filter_map(|row| row.ok())
268            .for_each(|row| f(row.0, row.1));
269        Ok(())
270    }
271
272    pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(PathBuf, usize, String)>> {
273        let mut statement = self.db.prepare(
274            "
275                SELECT
276                    documents.id, files.relative_path, documents.offset, documents.name
277                FROM
278                    documents, files
279                WHERE
280                    documents.file_id = files.id AND
281                    documents.id in rarray(?)
282            ",
283        )?;
284
285        let result_iter = statement.query_map(
286            params![std::rc::Rc::new(
287                ids.iter()
288                    .copied()
289                    .map(|v| rusqlite::types::Value::from(v))
290                    .collect::<Vec<_>>()
291            )],
292            |row| {
293                Ok((
294                    row.get::<_, i64>(0)?,
295                    row.get::<_, String>(1)?.into(),
296                    row.get(2)?,
297                    row.get(3)?,
298                ))
299            },
300        )?;
301
302        let mut values_by_id = HashMap::<i64, (PathBuf, usize, String)>::default();
303        for row in result_iter {
304            let (id, path, offset, name) = row?;
305            values_by_id.insert(id, (path, offset, name));
306        }
307
308        let mut results = Vec::with_capacity(ids.len());
309        for id in ids {
310            let (path, offset, name) = values_by_id
311                .remove(id)
312                .ok_or(anyhow!("missing document id {}", id))?;
313            results.push((path, offset, name));
314        }
315
316        Ok(results)
317    }
318
319    pub fn get_documents(&self) -> Result<HashMap<usize, DocumentRecord>> {
320        let mut query_statement = self
321            .db
322            .prepare("SELECT id, file_id, offset, name, embedding FROM documents")?;
323        let result_iter = query_statement.query_map([], |row| {
324            Ok(DocumentRecord {
325                id: row.get(0)?,
326                file_id: row.get(1)?,
327                offset: row.get(2)?,
328                name: row.get(3)?,
329                embedding: row.get(4)?,
330            })
331        })?;
332
333        let mut documents: HashMap<usize, DocumentRecord> = HashMap::new();
334        for result in result_iter {
335            let result = result?;
336            documents.insert(result.id, result);
337        }
338
339        return Ok(documents);
340    }
341}