// db.rs

use std::{
    collections::HashMap,
    path::{Path, PathBuf},
};

use anyhow::{anyhow, Result};

use rusqlite::{
    params,
    types::{FromSql, FromSqlResult, ValueRef},
    Connection, OptionalExtension,
};

use crate::IndexedFile;
 15
 16// This is saving to a local database store within the users dev zed path
 17// Where do we want this to sit?
 18// Assuming near where the workspace DB sits.
 19pub const VECTOR_DB_URL: &str = "embeddings_db";
 20
 21// Note this is not an appropriate document
 22#[derive(Debug)]
 23pub struct DocumentRecord {
 24    pub id: usize,
 25    pub file_id: usize,
 26    pub offset: usize,
 27    pub name: String,
 28    pub embedding: Embedding,
 29}
 30
 31#[derive(Debug)]
 32pub struct FileRecord {
 33    pub id: usize,
 34    pub relative_path: String,
 35    pub sha1: String,
 36}
 37
 38#[derive(Debug)]
 39pub struct Embedding(pub Vec<f32>);
 40
 41impl FromSql for Embedding {
 42    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
 43        let bytes = value.as_blob()?;
 44        let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
 45        if embedding.is_err() {
 46            return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
 47        }
 48        return Ok(Embedding(embedding.unwrap()));
 49    }
 50}
 51
 52pub struct VectorDatabase {
 53    db: rusqlite::Connection,
 54}
 55
 56impl VectorDatabase {
 57    pub fn new(path: &str) -> Result<Self> {
 58        let this = Self {
 59            db: rusqlite::Connection::open(path)?,
 60        };
 61        this.initialize_database()?;
 62        Ok(this)
 63    }
 64
 65    fn initialize_database(&self) -> Result<()> {
 66        // This will create the database if it doesnt exist
 67
 68        // Initialize Vector Databasing Tables
 69        self.db.execute(
 70            "CREATE TABLE IF NOT EXISTS worktrees (
 71                id INTEGER PRIMARY KEY AUTOINCREMENT,
 72                absolute_path VARCHAR NOT NULL
 73            );
 74            CREATE UNIQUE INDEX IF NOT EXISTS worktrees_absolute_path ON worktrees (absolute_path);
 75            ",
 76            [],
 77        )?;
 78
 79        self.db.execute(
 80            "CREATE TABLE IF NOT EXISTS files (
 81                id INTEGER PRIMARY KEY AUTOINCREMENT,
 82                worktree_id INTEGER NOT NULL,
 83                relative_path VARCHAR NOT NULL,
 84                sha1 NVARCHAR(40) NOT NULL,
 85                FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
 86            )",
 87            [],
 88        )?;
 89
 90        self.db.execute(
 91            "CREATE TABLE IF NOT EXISTS documents (
 92                id INTEGER PRIMARY KEY AUTOINCREMENT,
 93                file_id INTEGER NOT NULL,
 94                offset INTEGER NOT NULL,
 95                name VARCHAR NOT NULL,
 96                embedding BLOB NOT NULL,
 97                FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
 98            )",
 99            [],
100        )?;
101
102        Ok(())
103    }
104
105    // pub async fn get_or_create_project(project_path: PathBuf) -> Result<usize> {
106    //     // Check if we have the project, if we do, return the ID
107    //     // If we do not have the project, insert the project and return the ID
108
109    //     let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
110
111    //     let projects_query = db.prepare(&format!(
112    //         "SELECT id FROM projects WHERE path = {}",
113    //         project_path.to_str().unwrap() // This is unsafe
114    //     ))?;
115
116    //     let project_id = db.last_insert_rowid();
117
118    //     return Ok(project_id as usize);
119    // }
120
121    pub fn insert_file(&self, indexed_file: IndexedFile) -> Result<()> {
122        // Write to files table, and return generated id.
123        let files_insert = self.db.execute(
124            "INSERT INTO files (relative_path, sha1) VALUES (?1, ?2)",
125            params![indexed_file.path.to_str(), indexed_file.sha1],
126        )?;
127
128        let inserted_id = self.db.last_insert_rowid();
129
130        // Currently inserting at approximately 3400 documents a second
131        // I imagine we can speed this up with a bulk insert of some kind.
132        for document in indexed_file.documents {
133            let embedding_blob = bincode::serialize(&document.embedding)?;
134
135            self.db.execute(
136                "INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
137                params![
138                    inserted_id,
139                    document.offset.to_string(),
140                    document.name,
141                    embedding_blob
142                ],
143            )?;
144        }
145
146        Ok(())
147    }
148
149    pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
150        self.db.execute(
151            "
152            INSERT into worktrees (absolute_path) VALUES (?1)
153            ON CONFLICT DO NOTHING
154            ",
155            params![worktree_root_path.to_string_lossy()],
156        )?;
157        Ok(self.db.last_insert_rowid())
158    }
159
160    pub fn get_file_hashes(&self, worktree_id: i64) -> Result<Vec<(PathBuf, String)>> {
161        let mut statement = self
162            .db
163            .prepare("SELECT relative_path, sha1 FROM files ORDER BY relative_path")?;
164        let mut result = Vec::new();
165        for row in
166            statement.query_map([], |row| Ok((row.get::<_, String>(0)?.into(), row.get(1)?)))?
167        {
168            result.push(row?);
169        }
170        Ok(result)
171    }
172
173    pub fn get_files(&self) -> Result<HashMap<usize, FileRecord>> {
174        let mut query_statement = self
175            .db
176            .prepare("SELECT id, relative_path, sha1 FROM files")?;
177        let result_iter = query_statement.query_map([], |row| {
178            Ok(FileRecord {
179                id: row.get(0)?,
180                relative_path: row.get(1)?,
181                sha1: row.get(2)?,
182            })
183        })?;
184
185        let mut pages: HashMap<usize, FileRecord> = HashMap::new();
186        for result in result_iter {
187            let result = result?;
188            pages.insert(result.id, result);
189        }
190
191        Ok(pages)
192    }
193
194    pub fn for_each_document(
195        &self,
196        worktree_id: i64,
197        mut f: impl FnMut(i64, Embedding),
198    ) -> Result<()> {
199        let mut query_statement = self.db.prepare("SELECT id, embedding FROM documents")?;
200        query_statement
201            .query_map(params![], |row| Ok((row.get(0)?, row.get(1)?)))?
202            .filter_map(|row| row.ok())
203            .for_each(|row| f(row.0, row.1));
204        Ok(())
205    }
206
207    pub fn get_documents(&self) -> Result<HashMap<usize, DocumentRecord>> {
208        let mut query_statement = self
209            .db
210            .prepare("SELECT id, file_id, offset, name, embedding FROM documents")?;
211        let result_iter = query_statement.query_map([], |row| {
212            Ok(DocumentRecord {
213                id: row.get(0)?,
214                file_id: row.get(1)?,
215                offset: row.get(2)?,
216                name: row.get(3)?,
217                embedding: row.get(4)?,
218            })
219        })?;
220
221        let mut documents: HashMap<usize, DocumentRecord> = HashMap::new();
222        for result in result_iter {
223            let result = result?;
224            documents.insert(result.id, result);
225        }
226
227        return Ok(documents);
228    }
229}