1use std::collections::HashMap;
2
3use anyhow::{anyhow, Result};
4
5use rusqlite::{
6 params,
7 types::{FromSql, FromSqlResult, ValueRef},
8 Connection,
9};
10
11use crate::IndexedFile;
12
/// Path handed to `rusqlite::Connection::open` for the embeddings database.
///
/// This is a bare relative file name, so the database file is resolved
/// relative to the process's current working directory.
///
/// TODO: decide where this store should live; the current thinking is
/// somewhere near the workspace DB.
const VECTOR_DB_URL: &str = "embeddings_db";
17
18// Note this is not an appropriate document
19#[derive(Debug)]
20pub struct DocumentRecord {
21 pub id: usize,
22 pub file_id: usize,
23 pub offset: usize,
24 pub name: String,
25 pub embedding: Embedding,
26}
27
/// One row of the `files` table, as read back by `VectorDatabase::get_files`.
#[derive(Debug)]
pub struct FileRecord {
    /// Primary key of the row in `files`.
    pub id: usize,
    /// Value of the `path` column.
    pub path: String,
    /// Value of the `sha1` column.
    pub sha1: String,
}
34
/// Newtype around the raw `f32` components of an embedding vector.
///
/// The wrapper exists so the vector can carry its own `FromSql`
/// implementation for reading it out of a blob column.
#[derive(Debug)]
pub struct Embedding(pub Vec<f32>);
37
38impl FromSql for Embedding {
39 fn column_result(value: ValueRef) -> FromSqlResult<Self> {
40 let bytes = value.as_blob()?;
41 let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
42 if embedding.is_err() {
43 return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
44 }
45 return Ok(Embedding(embedding.unwrap()));
46 }
47}
48
/// Handle type whose associated functions read and write the SQLite
/// embeddings database at `VECTOR_DB_URL`.
///
/// It holds no state of its own: each operation opens its own connection.
pub struct VectorDatabase {}
50
51impl VectorDatabase {
52 pub async fn initialize_database() -> Result<()> {
53 // This will create the database if it doesnt exist
54 let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
55
56 // Initialize Vector Databasing Tables
57 db.execute(
58 "CREATE TABLE IF NOT EXISTS files (
59 id INTEGER PRIMARY KEY AUTOINCREMENT,
60 path NVARCHAR(100) NOT NULL,
61 sha1 NVARCHAR(40) NOT NULL
62 )",
63 [],
64 )?;
65
66 db.execute(
67 "CREATE TABLE IF NOT EXISTS documents (
68 id INTEGER PRIMARY KEY AUTOINCREMENT,
69 file_id INTEGER NOT NULL,
70 offset INTEGER NOT NULL,
71 name NVARCHAR(100) NOT NULL,
72 embedding BLOB NOT NULL,
73 FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
74 )",
75 [],
76 )?;
77
78 Ok(())
79 }
80
81 pub async fn insert_file(indexed_file: IndexedFile) -> Result<()> {
82 // Write to files table, and return generated id.
83 let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
84
85 let files_insert = db.execute(
86 "INSERT INTO files (path, sha1) VALUES (?1, ?2)",
87 params![indexed_file.path.to_str(), indexed_file.sha1],
88 )?;
89
90 let inserted_id = db.last_insert_rowid();
91
92 // Currently inserting at approximately 3400 documents a second
93 // I imagine we can speed this up with a bulk insert of some kind.
94 for document in indexed_file.documents {
95 let embedding_blob = bincode::serialize(&document.embedding)?;
96
97 db.execute(
98 "INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
99 params![
100 inserted_id,
101 document.offset.to_string(),
102 document.name,
103 embedding_blob
104 ],
105 )?;
106 }
107
108 Ok(())
109 }
110
111 pub fn get_files(&self) -> Result<HashMap<usize, FileRecord>> {
112 let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
113
114 fn query(db: Connection) -> rusqlite::Result<Vec<FileRecord>> {
115 let mut query_statement = db.prepare("SELECT id, path, sha1 FROM files")?;
116 let result_iter = query_statement.query_map([], |row| {
117 Ok(FileRecord {
118 id: row.get(0)?,
119 path: row.get(1)?,
120 sha1: row.get(2)?,
121 })
122 })?;
123
124 let mut results = vec![];
125 for result in result_iter {
126 results.push(result?);
127 }
128
129 return Ok(results);
130 }
131
132 let mut pages: HashMap<usize, FileRecord> = HashMap::new();
133 let result_iter = query(db);
134 if result_iter.is_ok() {
135 for result in result_iter.unwrap() {
136 pages.insert(result.id, result);
137 }
138 }
139
140 return Ok(pages);
141 }
142
143 pub fn get_documents(&self) -> Result<HashMap<usize, DocumentRecord>> {
144 // Should return a HashMap in which the key is the id, and the value is the finished document
145
146 // Get Data from Database
147 let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
148
149 fn query(db: Connection) -> rusqlite::Result<Vec<DocumentRecord>> {
150 let mut query_statement =
151 db.prepare("SELECT id, file_id, offset, name, embedding FROM documents")?;
152 let result_iter = query_statement.query_map([], |row| {
153 Ok(DocumentRecord {
154 id: row.get(0)?,
155 file_id: row.get(1)?,
156 offset: row.get(2)?,
157 name: row.get(3)?,
158 embedding: row.get(4)?,
159 })
160 })?;
161
162 let mut results = vec![];
163 for result in result_iter {
164 results.push(result?);
165 }
166
167 return Ok(results);
168 }
169
170 let mut documents: HashMap<usize, DocumentRecord> = HashMap::new();
171 let result_iter = query(db);
172 if result_iter.is_ok() {
173 for result in result_iter.unwrap() {
174 documents.insert(result.id, result);
175 }
176 }
177
178 return Ok(documents);
179 }
180}