db.rs

  1use std::collections::HashMap;
  2
  3use anyhow::{anyhow, Result};
  4
  5use rusqlite::{
  6    params,
  7    types::{FromSql, FromSqlResult, ValueRef},
  8    Connection,
  9};
 10
 11use crate::IndexedFile;
 12
 13// This is saving to a local database store within the users dev zed path
 14// Where do we want this to sit?
 15// Assuming near where the workspace DB sits.
 16const VECTOR_DB_URL: &str = "embeddings_db";
 17
 18// Note this is not an appropriate document
 19#[derive(Debug)]
 20pub struct DocumentRecord {
 21    pub id: usize,
 22    pub file_id: usize,
 23    pub offset: usize,
 24    pub name: String,
 25    pub embedding: Embedding,
 26}
 27
 28#[derive(Debug)]
 29pub struct FileRecord {
 30    pub id: usize,
 31    pub path: String,
 32    pub sha1: String,
 33}
 34
 35#[derive(Debug)]
 36pub struct Embedding(pub Vec<f32>);
 37
 38impl FromSql for Embedding {
 39    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
 40        let bytes = value.as_blob()?;
 41        let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
 42        if embedding.is_err() {
 43            return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
 44        }
 45        return Ok(Embedding(embedding.unwrap()));
 46    }
 47}
 48
 49pub struct VectorDatabase {}
 50
 51impl VectorDatabase {
 52    pub async fn initialize_database() -> Result<()> {
 53        // This will create the database if it doesnt exist
 54        let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
 55
 56        // Initialize Vector Databasing Tables
 57        db.execute(
 58            "CREATE TABLE IF NOT EXISTS files (
 59        id INTEGER PRIMARY KEY AUTOINCREMENT,
 60        path NVARCHAR(100) NOT NULL,
 61        sha1 NVARCHAR(40) NOT NULL
 62        )",
 63            [],
 64        )?;
 65
 66        db.execute(
 67            "CREATE TABLE IF NOT EXISTS documents (
 68            id INTEGER PRIMARY KEY AUTOINCREMENT,
 69            file_id INTEGER NOT NULL,
 70            offset INTEGER NOT NULL,
 71            name NVARCHAR(100) NOT NULL,
 72            embedding BLOB NOT NULL,
 73            FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
 74            )",
 75            [],
 76        )?;
 77
 78        Ok(())
 79    }
 80
 81    pub async fn insert_file(indexed_file: IndexedFile) -> Result<()> {
 82        // Write to files table, and return generated id.
 83        let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
 84
 85        let files_insert = db.execute(
 86            "INSERT INTO files (path, sha1) VALUES (?1, ?2)",
 87            params![indexed_file.path.to_str(), indexed_file.sha1],
 88        )?;
 89
 90        let inserted_id = db.last_insert_rowid();
 91
 92        // Currently inserting at approximately 3400 documents a second
 93        // I imagine we can speed this up with a bulk insert of some kind.
 94        for document in indexed_file.documents {
 95            let embedding_blob = bincode::serialize(&document.embedding)?;
 96
 97            db.execute(
 98                "INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
 99                params![
100                    inserted_id,
101                    document.offset.to_string(),
102                    document.name,
103                    embedding_blob
104                ],
105            )?;
106        }
107
108        Ok(())
109    }
110
111    pub fn get_files(&self) -> Result<HashMap<usize, FileRecord>> {
112        let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
113
114        fn query(db: Connection) -> rusqlite::Result<Vec<FileRecord>> {
115            let mut query_statement = db.prepare("SELECT id, path, sha1 FROM files")?;
116            let result_iter = query_statement.query_map([], |row| {
117                Ok(FileRecord {
118                    id: row.get(0)?,
119                    path: row.get(1)?,
120                    sha1: row.get(2)?,
121                })
122            })?;
123
124            let mut results = vec![];
125            for result in result_iter {
126                results.push(result?);
127            }
128
129            return Ok(results);
130        }
131
132        let mut pages: HashMap<usize, FileRecord> = HashMap::new();
133        let result_iter = query(db);
134        if result_iter.is_ok() {
135            for result in result_iter.unwrap() {
136                pages.insert(result.id, result);
137            }
138        }
139
140        return Ok(pages);
141    }
142
143    pub fn get_documents(&self) -> Result<HashMap<usize, DocumentRecord>> {
144        // Should return a HashMap in which the key is the id, and the value is the finished document
145
146        // Get Data from Database
147        let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
148
149        fn query(db: Connection) -> rusqlite::Result<Vec<DocumentRecord>> {
150            let mut query_statement =
151                db.prepare("SELECT id, file_id, offset, name, embedding FROM documents")?;
152            let result_iter = query_statement.query_map([], |row| {
153                Ok(DocumentRecord {
154                    id: row.get(0)?,
155                    file_id: row.get(1)?,
156                    offset: row.get(2)?,
157                    name: row.get(3)?,
158                    embedding: row.get(4)?,
159                })
160            })?;
161
162            let mut results = vec![];
163            for result in result_iter {
164                results.push(result?);
165            }
166
167            return Ok(results);
168        }
169
170        let mut documents: HashMap<usize, DocumentRecord> = HashMap::new();
171        let result_iter = query(db);
172        if result_iter.is_ok() {
173            for result in result_iter.unwrap() {
174                documents.insert(result.id, result);
175            }
176        }
177
178        return Ok(documents);
179    }
180}