db.rs

  1use std::collections::HashMap;
  2
  3use anyhow::{anyhow, Result};
  4
  5use rusqlite::{
  6    params,
  7    types::{FromSql, FromSqlResult, ValueRef},
  8    Connection,
  9};
 10use util::ResultExt;
 11
 12use crate::{Document, IndexedFile};
 13
 14// This is saving to a local database store within the users dev zed path
 15// Where do we want this to sit?
 16// Assuming near where the workspace DB sits.
 17const VECTOR_DB_URL: &str = "embeddings_db";
 18
/// One raw row of the `documents` table, as read back by `VectorDatabase::get_documents`.
///
/// This is a storage-level record, not the crate's `Document` type: it also
/// carries the row id, and it keeps the embedding wrapped in `Embedding` so the
/// BLOB column can be decoded through `FromSql`.
#[derive(Debug)]
pub struct DocumentRecord {
    /// Primary key of the row (`documents.id`).
    id: usize,
    /// Value of the `documents.offset` column.
    offset: usize,
    /// Value of the `documents.name` column.
    name: String,
    /// Decoded contents of the `documents.embedding` BLOB column.
    embedding: Embedding,
}
 27
/// An embedding vector as stored in the `documents.embedding` BLOB column.
///
/// Newtype over `Vec<f32>` so that rusqlite's `FromSql` (a foreign trait) can be
/// implemented for it; the orphan rule forbids implementing it on `Vec<f32>` directly.
#[derive(Debug)]
struct Embedding(Vec<f32>);
 30
 31impl FromSql for Embedding {
 32    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
 33        let bytes = value.as_blob()?;
 34        let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
 35        if embedding.is_err() {
 36            return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
 37        }
 38        return Ok(Embedding(embedding.unwrap()));
 39    }
 40}
 41
 42pub struct VectorDatabase {}
 43
 44impl VectorDatabase {
 45    pub async fn initialize_database() -> Result<()> {
 46        // This will create the database if it doesnt exist
 47        let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
 48
 49        // Initialize Vector Databasing Tables
 50        db.execute(
 51            "CREATE TABLE IF NOT EXISTS files (
 52        id INTEGER PRIMARY KEY AUTOINCREMENT,
 53        path NVARCHAR(100) NOT NULL,
 54        sha1 NVARCHAR(40) NOT NULL
 55        )",
 56            [],
 57        )?;
 58
 59        db.execute(
 60            "CREATE TABLE IF NOT EXISTS documents (
 61            id INTEGER PRIMARY KEY AUTOINCREMENT,
 62            file_id INTEGER NOT NULL,
 63            offset INTEGER NOT NULL,
 64            name NVARCHAR(100) NOT NULL,
 65            embedding BLOB NOT NULL,
 66            FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
 67            )",
 68            [],
 69        )?;
 70
 71        Ok(())
 72    }
 73
 74    pub async fn insert_file(indexed_file: IndexedFile) -> Result<()> {
 75        // Write to files table, and return generated id.
 76        let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
 77
 78        let files_insert = db.execute(
 79            "INSERT INTO files (path, sha1) VALUES (?1, ?2)",
 80            params![indexed_file.path.to_str(), indexed_file.sha1],
 81        )?;
 82
 83        let inserted_id = db.last_insert_rowid();
 84
 85        // Currently inserting at approximately 3400 documents a second
 86        // I imagine we can speed this up with a bulk insert of some kind.
 87        for document in indexed_file.documents {
 88            let embedding_blob = bincode::serialize(&document.embedding)?;
 89
 90            db.execute(
 91                "INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
 92                params![
 93                    inserted_id,
 94                    document.offset.to_string(),
 95                    document.name,
 96                    embedding_blob
 97                ],
 98            )?;
 99        }
100
101        Ok(())
102    }
103
104    pub fn get_documents(&self) -> Result<HashMap<usize, Document>> {
105        // Should return a HashMap in which the key is the id, and the value is the finished document
106
107        // Get Data from Database
108        let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
109
110        fn query(db: Connection) -> rusqlite::Result<Vec<DocumentRecord>> {
111            let mut query_statement =
112                db.prepare("SELECT id, offset, name, embedding FROM documents LIMIT 10")?;
113            let result_iter = query_statement.query_map([], |row| {
114                Ok(DocumentRecord {
115                    id: row.get(0)?,
116                    offset: row.get(1)?,
117                    name: row.get(2)?,
118                    embedding: row.get(3)?,
119                })
120            })?;
121
122            let mut results = vec![];
123            for result in result_iter {
124                results.push(result?);
125            }
126
127            return Ok(results);
128        }
129
130        let mut documents: HashMap<usize, Document> = HashMap::new();
131        let result_iter = query(db);
132        if result_iter.is_ok() {
133            for result in result_iter.unwrap() {
134                documents.insert(
135                    result.id,
136                    Document {
137                        offset: result.offset,
138                        name: result.name,
139                        embedding: result.embedding.0,
140                    },
141                );
142            }
143        }
144
145        return Ok(documents);
146    }
147}