db.rs

  1use std::{
  2    collections::HashMap,
  3    path::{Path, PathBuf},
  4    rc::Rc,
  5};
  6
  7use anyhow::{anyhow, Result};
  8
  9use rusqlite::{
 10    params,
 11    types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
 12    ToSql,
 13};
 14use sha1::{Digest, Sha1};
 15
 16use crate::IndexedFile;
 17
 18// This is saving to a local database store within the users dev zed path
 19// Where do we want this to sit?
 20// Assuming near where the workspace DB sits.
 21pub const VECTOR_DB_URL: &str = "embeddings_db";
 22
 23// Note this is not an appropriate document
 24#[derive(Debug)]
 25pub struct DocumentRecord {
 26    pub id: usize,
 27    pub file_id: usize,
 28    pub offset: usize,
 29    pub name: String,
 30    pub embedding: Embedding,
 31}
 32
 33#[derive(Debug)]
 34pub struct FileRecord {
 35    pub id: usize,
 36    pub relative_path: String,
 37    pub sha1: FileSha1,
 38}
 39
 40#[derive(Debug)]
 41pub struct FileSha1(pub Vec<u8>);
 42
 43impl FileSha1 {
 44    pub fn from_str(content: String) -> Self {
 45        let mut hasher = Sha1::new();
 46        hasher.update(content);
 47        let sha1 = hasher.finalize()[..]
 48            .into_iter()
 49            .map(|val| val.to_owned())
 50            .collect::<Vec<u8>>();
 51        return FileSha1(sha1);
 52    }
 53
 54    pub fn equals(&self, content: &String) -> bool {
 55        let mut hasher = Sha1::new();
 56        hasher.update(content);
 57        let sha1 = hasher.finalize()[..]
 58            .into_iter()
 59            .map(|val| val.to_owned())
 60            .collect::<Vec<u8>>();
 61
 62        let equal = self
 63            .0
 64            .clone()
 65            .into_iter()
 66            .zip(sha1)
 67            .filter(|&(a, b)| a == b)
 68            .count()
 69            == self.0.len();
 70
 71        equal
 72    }
 73}
 74
 75impl ToSql for FileSha1 {
 76    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
 77        return self.0.to_sql();
 78    }
 79}
 80
 81impl FromSql for FileSha1 {
 82    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
 83        let bytes = value.as_blob()?;
 84        Ok(FileSha1(
 85            bytes
 86                .into_iter()
 87                .map(|val| val.to_owned())
 88                .collect::<Vec<u8>>(),
 89        ))
 90    }
 91}
 92
 93#[derive(Debug)]
 94pub struct Embedding(pub Vec<f32>);
 95
 96impl FromSql for Embedding {
 97    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
 98        let bytes = value.as_blob()?;
 99        let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
100        if embedding.is_err() {
101            return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
102        }
103        return Ok(Embedding(embedding.unwrap()));
104    }
105}
106
107pub struct VectorDatabase {
108    db: rusqlite::Connection,
109}
110
111impl VectorDatabase {
112    pub fn new(path: &str) -> Result<Self> {
113        let this = Self {
114            db: rusqlite::Connection::open(path)?,
115        };
116        this.initialize_database()?;
117        Ok(this)
118    }
119
120    fn initialize_database(&self) -> Result<()> {
121        rusqlite::vtab::array::load_module(&self.db)?;
122
123        // This will create the database if it doesnt exist
124
125        // Initialize Vector Databasing Tables
126        self.db.execute(
127            "CREATE TABLE IF NOT EXISTS worktrees (
128                id INTEGER PRIMARY KEY AUTOINCREMENT,
129                absolute_path VARCHAR NOT NULL
130            );
131            CREATE UNIQUE INDEX IF NOT EXISTS worktrees_absolute_path ON worktrees (absolute_path);
132            ",
133            [],
134        )?;
135
136        self.db.execute(
137            "CREATE TABLE IF NOT EXISTS files (
138                id INTEGER PRIMARY KEY AUTOINCREMENT,
139                worktree_id INTEGER NOT NULL,
140                relative_path VARCHAR NOT NULL,
141                sha1 BLOB NOT NULL,
142                FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
143            )",
144            [],
145        )?;
146
147        self.db.execute(
148            "CREATE TABLE IF NOT EXISTS documents (
149                id INTEGER PRIMARY KEY AUTOINCREMENT,
150                file_id INTEGER NOT NULL,
151                offset INTEGER NOT NULL,
152                name VARCHAR NOT NULL,
153                embedding BLOB NOT NULL,
154                FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
155            )",
156            [],
157        )?;
158
159        Ok(())
160    }
161
162    pub fn insert_file(&self, worktree_id: i64, indexed_file: IndexedFile) -> Result<()> {
163        // Write to files table, and return generated id.
164        log::info!("Inserting File!");
165        self.db.execute(
166            "
167            DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;
168            ",
169            params![worktree_id, indexed_file.path.to_str()],
170        )?;
171        self.db.execute(
172            "
173            INSERT INTO files (worktree_id, relative_path, sha1) VALUES (?1, ?2, $3);
174            ",
175            params![worktree_id, indexed_file.path.to_str(), indexed_file.sha1],
176        )?;
177
178        let file_id = self.db.last_insert_rowid();
179
180        // Currently inserting at approximately 3400 documents a second
181        // I imagine we can speed this up with a bulk insert of some kind.
182        for document in indexed_file.documents {
183            let embedding_blob = bincode::serialize(&document.embedding)?;
184
185            self.db.execute(
186                "INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
187                params![
188                    file_id,
189                    document.offset.to_string(),
190                    document.name,
191                    embedding_blob
192                ],
193            )?;
194        }
195
196        Ok(())
197    }
198
199    pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
200        // Check that the absolute path doesnt exist
201        let mut worktree_query = self
202            .db
203            .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
204
205        let worktree_id = worktree_query
206            .query_row(params![worktree_root_path.to_string_lossy()], |row| {
207                Ok(row.get::<_, i64>(0)?)
208            })
209            .map_err(|err| anyhow!(err));
210
211        if worktree_id.is_ok() {
212            return worktree_id;
213        }
214
215        // If worktree_id is Err, insert new worktree
216        self.db.execute(
217            "
218            INSERT into worktrees (absolute_path) VALUES (?1)
219            ",
220            params![worktree_root_path.to_string_lossy()],
221        )?;
222        Ok(self.db.last_insert_rowid())
223    }
224
225    pub fn get_file_hashes(&self, worktree_id: i64) -> Result<HashMap<PathBuf, FileSha1>> {
226        let mut statement = self.db.prepare(
227            "SELECT relative_path, sha1 FROM files WHERE worktree_id = ?1 ORDER BY relative_path",
228        )?;
229        let mut result: HashMap<PathBuf, FileSha1> = HashMap::new();
230        for row in statement.query_map(params![worktree_id], |row| {
231            Ok((row.get::<_, String>(0)?.into(), row.get(1)?))
232        })? {
233            let row = row?;
234            result.insert(row.0, row.1);
235        }
236        Ok(result)
237    }
238
239    pub fn for_each_document(
240        &self,
241        worktree_ids: &[i64],
242        mut f: impl FnMut(i64, Embedding),
243    ) -> Result<()> {
244        let mut query_statement = self.db.prepare(
245            "
246            SELECT
247                documents.id, documents.embedding
248            FROM
249                documents, files
250            WHERE
251                documents.file_id = files.id AND
252                files.worktree_id IN rarray(?)
253            ",
254        )?;
255        query_statement
256            .query_map(params![ids_to_sql(worktree_ids)], |row| {
257                Ok((row.get(0)?, row.get(1)?))
258            })?
259            .filter_map(|row| row.ok())
260            .for_each(|row| f(row.0, row.1));
261        Ok(())
262    }
263
264    pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, usize, String)>> {
265        let mut statement = self.db.prepare(
266            "
267                SELECT
268                    documents.id, files.worktree_id, files.relative_path, documents.offset, documents.name
269                FROM
270                    documents, files
271                WHERE
272                    documents.file_id = files.id AND
273                    documents.id in rarray(?)
274            ",
275        )?;
276
277        let result_iter = statement.query_map(params![ids_to_sql(ids)], |row| {
278            Ok((
279                row.get::<_, i64>(0)?,
280                row.get::<_, i64>(1)?,
281                row.get::<_, String>(2)?.into(),
282                row.get(3)?,
283                row.get(4)?,
284            ))
285        })?;
286
287        let mut values_by_id = HashMap::<i64, (i64, PathBuf, usize, String)>::default();
288        for row in result_iter {
289            let (id, worktree_id, path, offset, name) = row?;
290            values_by_id.insert(id, (worktree_id, path, offset, name));
291        }
292
293        let mut results = Vec::with_capacity(ids.len());
294        for id in ids {
295            let value = values_by_id
296                .remove(id)
297                .ok_or(anyhow!("missing document id {}", id))?;
298            results.push(value);
299        }
300
301        Ok(results)
302    }
303}
304
305fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
306    Rc::new(
307        ids.iter()
308            .copied()
309            .map(|v| rusqlite::types::Value::from(v))
310            .collect::<Vec<_>>(),
311    )
312}