1use std::{
2 collections::HashMap,
3 path::{Path, PathBuf},
4};
5
6use anyhow::{anyhow, Result};
7
8use rusqlite::{
9 params,
10 types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
11 ToSql,
12};
13use sha1::{Digest, Sha1};
14
15use crate::IndexedFile;
16
// The embeddings are saved to a local database store inside the user's dev Zed path.
// Open question: where should this database live?
// Current assumption: near where the workspace DB sits.
/// Name of the embeddings database store.
///
/// NOTE(review): this constant is not referenced anywhere in this file; presumably callers
/// combine it with a base directory and pass the result to `VectorDatabase::new` — confirm.
pub const VECTOR_DB_URL: &str = "embeddings_db";
21
// NOTE(review): the original author flagged this as "not an appropriate document";
// the intended replacement is not visible in this file.
/// One row of the `documents` table, as returned by `VectorDatabase::get_documents`.
#[derive(Debug)]
pub struct DocumentRecord {
    /// Primary key of the row (`documents.id`).
    pub id: usize,
    /// Key of the owning row in the `files` table (`documents.file_id`).
    pub file_id: usize,
    /// Value of the `offset` column; presumably the document's position inside its
    /// file — TODO confirm the unit against the indexing code.
    pub offset: usize,
    /// Value of the `name` column.
    pub name: String,
    /// Embedding decoded from the `embedding` BLOB column.
    pub embedding: Embedding,
}
31
/// One row of the `files` table, as returned by `VectorDatabase::get_files`.
///
/// The `worktree_id` column of the table is not carried by this record.
#[derive(Debug)]
pub struct FileRecord {
    /// Primary key of the row (`files.id`).
    pub id: usize,
    /// Value of the `relative_path` column.
    pub relative_path: String,
    /// Digest stored in the `sha1` BLOB column.
    pub sha1: FileSha1,
}
38
/// Raw digest bytes identifying a file's contents.
///
/// Within this file the bytes are produced by hashing text with SHA-1 and are persisted
/// in the `files.sha1` BLOB column. Deriving `Clone`, `PartialEq`, `Eq` and `Hash` lets
/// digests be compared directly and used as map or set keys instead of going through
/// hand-written byte comparisons.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct FileSha1(pub Vec<u8>);
41
42impl FileSha1 {
43 pub fn from_str(content: String) -> Self {
44 let mut hasher = Sha1::new();
45 hasher.update(content);
46 let sha1 = hasher.finalize()[..]
47 .into_iter()
48 .map(|val| val.to_owned())
49 .collect::<Vec<u8>>();
50 return FileSha1(sha1);
51 }
52
53 pub fn equals(&self, content: &String) -> bool {
54 let mut hasher = Sha1::new();
55 hasher.update(content);
56 let sha1 = hasher.finalize()[..]
57 .into_iter()
58 .map(|val| val.to_owned())
59 .collect::<Vec<u8>>();
60
61 let equal = self
62 .0
63 .clone()
64 .into_iter()
65 .zip(sha1)
66 .filter(|&(a, b)| a == b)
67 .count()
68 == self.0.len();
69
70 equal
71 }
72}
73
74impl ToSql for FileSha1 {
75 fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
76 return self.0.to_sql();
77 }
78}
79
80impl FromSql for FileSha1 {
81 fn column_result(value: ValueRef) -> FromSqlResult<Self> {
82 let bytes = value.as_blob()?;
83 Ok(FileSha1(
84 bytes
85 .into_iter()
86 .map(|val| val.to_owned())
87 .collect::<Vec<u8>>(),
88 ))
89 }
90}
91
/// An embedding vector of `f32` components.
///
/// In this file embeddings are read back from the `documents.embedding` BLOB column
/// by decoding the blob with bincode. `Clone` and `PartialEq` are derived so that
/// embeddings can be duplicated and compared without reaching into the inner vector.
#[derive(Debug, Clone, PartialEq)]
pub struct Embedding(pub Vec<f32>);
94
95impl FromSql for Embedding {
96 fn column_result(value: ValueRef) -> FromSqlResult<Self> {
97 let bytes = value.as_blob()?;
98 let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
99 if embedding.is_err() {
100 return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
101 }
102 return Ok(Embedding(embedding.unwrap()));
103 }
104}
105
/// Handle to the SQLite database that stores worktrees, files, and document embeddings.
///
/// Construct it with `VectorDatabase::new`, which opens the connection and creates the
/// tables when they do not exist yet.
pub struct VectorDatabase {
    /// The underlying SQLite connection; every query in this type goes through it.
    db: rusqlite::Connection,
}
109
110impl VectorDatabase {
111 pub fn new(path: &str) -> Result<Self> {
112 let this = Self {
113 db: rusqlite::Connection::open(path)?,
114 };
115 this.initialize_database()?;
116 Ok(this)
117 }
118
119 fn initialize_database(&self) -> Result<()> {
120 rusqlite::vtab::array::load_module(&self.db)?;
121
122 // This will create the database if it doesnt exist
123
124 // Initialize Vector Databasing Tables
125 self.db.execute(
126 "CREATE TABLE IF NOT EXISTS worktrees (
127 id INTEGER PRIMARY KEY AUTOINCREMENT,
128 absolute_path VARCHAR NOT NULL
129 );
130 CREATE UNIQUE INDEX IF NOT EXISTS worktrees_absolute_path ON worktrees (absolute_path);
131 ",
132 [],
133 )?;
134
135 self.db.execute(
136 "CREATE TABLE IF NOT EXISTS files (
137 id INTEGER PRIMARY KEY AUTOINCREMENT,
138 worktree_id INTEGER NOT NULL,
139 relative_path VARCHAR NOT NULL,
140 sha1 BLOB NOT NULL,
141 FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
142 )",
143 [],
144 )?;
145
146 self.db.execute(
147 "CREATE TABLE IF NOT EXISTS documents (
148 id INTEGER PRIMARY KEY AUTOINCREMENT,
149 file_id INTEGER NOT NULL,
150 offset INTEGER NOT NULL,
151 name VARCHAR NOT NULL,
152 embedding BLOB NOT NULL,
153 FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
154 )",
155 [],
156 )?;
157
158 Ok(())
159 }
160
161 pub fn insert_file(&self, worktree_id: i64, indexed_file: IndexedFile) -> Result<()> {
162 // Write to files table, and return generated id.
163 log::info!("Inserting File!");
164 self.db.execute(
165 "
166 DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;
167 ",
168 params![worktree_id, indexed_file.path.to_str()],
169 )?;
170 self.db.execute(
171 "
172 INSERT INTO files (worktree_id, relative_path, sha1) VALUES (?1, ?2, $3);
173 ",
174 params![worktree_id, indexed_file.path.to_str(), indexed_file.sha1],
175 )?;
176
177 let file_id = self.db.last_insert_rowid();
178
179 // Currently inserting at approximately 3400 documents a second
180 // I imagine we can speed this up with a bulk insert of some kind.
181 for document in indexed_file.documents {
182 let embedding_blob = bincode::serialize(&document.embedding)?;
183
184 self.db.execute(
185 "INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
186 params![
187 file_id,
188 document.offset.to_string(),
189 document.name,
190 embedding_blob
191 ],
192 )?;
193 }
194
195 Ok(())
196 }
197
198 pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
199 // Check that the absolute path doesnt exist
200 let mut worktree_query = self
201 .db
202 .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
203
204 let worktree_id = worktree_query
205 .query_row(params![worktree_root_path.to_string_lossy()], |row| {
206 Ok(row.get::<_, i64>(0)?)
207 })
208 .map_err(|err| anyhow!(err));
209
210 if worktree_id.is_ok() {
211 return worktree_id;
212 }
213
214 // If worktree_id is Err, insert new worktree
215 self.db.execute(
216 "
217 INSERT into worktrees (absolute_path) VALUES (?1)
218 ",
219 params![worktree_root_path.to_string_lossy()],
220 )?;
221 Ok(self.db.last_insert_rowid())
222 }
223
224 pub fn get_file_hashes(&self, worktree_id: i64) -> Result<HashMap<PathBuf, FileSha1>> {
225 let mut statement = self.db.prepare(
226 "SELECT relative_path, sha1 FROM files WHERE worktree_id = ?1 ORDER BY relative_path",
227 )?;
228 let mut result: HashMap<PathBuf, FileSha1> = HashMap::new();
229 for row in statement.query_map(params![worktree_id], |row| {
230 Ok((row.get::<_, String>(0)?.into(), row.get(1)?))
231 })? {
232 let row = row?;
233 result.insert(row.0, row.1);
234 }
235 Ok(result)
236 }
237
238 pub fn get_files(&self) -> Result<HashMap<usize, FileRecord>> {
239 let mut query_statement = self
240 .db
241 .prepare("SELECT id, relative_path, sha1 FROM files")?;
242 let result_iter = query_statement.query_map([], |row| {
243 Ok(FileRecord {
244 id: row.get(0)?,
245 relative_path: row.get(1)?,
246 sha1: row.get(2)?,
247 })
248 })?;
249
250 let mut pages: HashMap<usize, FileRecord> = HashMap::new();
251 for result in result_iter {
252 let result = result?;
253 pages.insert(result.id, result);
254 }
255
256 Ok(pages)
257 }
258
259 pub fn for_each_document(
260 &self,
261 worktree_id: i64,
262 mut f: impl FnMut(i64, Embedding),
263 ) -> Result<()> {
264 let mut query_statement = self.db.prepare("SELECT id, embedding FROM documents")?;
265 query_statement
266 .query_map(params![], |row| Ok((row.get(0)?, row.get(1)?)))?
267 .filter_map(|row| row.ok())
268 .for_each(|row| f(row.0, row.1));
269 Ok(())
270 }
271
272 pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(PathBuf, usize, String)>> {
273 let mut statement = self.db.prepare(
274 "
275 SELECT
276 documents.id, files.relative_path, documents.offset, documents.name
277 FROM
278 documents, files
279 WHERE
280 documents.file_id = files.id AND
281 documents.id in rarray(?)
282 ",
283 )?;
284
285 let result_iter = statement.query_map(
286 params![std::rc::Rc::new(
287 ids.iter()
288 .copied()
289 .map(|v| rusqlite::types::Value::from(v))
290 .collect::<Vec<_>>()
291 )],
292 |row| {
293 Ok((
294 row.get::<_, i64>(0)?,
295 row.get::<_, String>(1)?.into(),
296 row.get(2)?,
297 row.get(3)?,
298 ))
299 },
300 )?;
301
302 let mut values_by_id = HashMap::<i64, (PathBuf, usize, String)>::default();
303 for row in result_iter {
304 let (id, path, offset, name) = row?;
305 values_by_id.insert(id, (path, offset, name));
306 }
307
308 let mut results = Vec::with_capacity(ids.len());
309 for id in ids {
310 let (path, offset, name) = values_by_id
311 .remove(id)
312 .ok_or(anyhow!("missing document id {}", id))?;
313 results.push((path, offset, name));
314 }
315
316 Ok(results)
317 }
318
319 pub fn get_documents(&self) -> Result<HashMap<usize, DocumentRecord>> {
320 let mut query_statement = self
321 .db
322 .prepare("SELECT id, file_id, offset, name, embedding FROM documents")?;
323 let result_iter = query_statement.query_map([], |row| {
324 Ok(DocumentRecord {
325 id: row.get(0)?,
326 file_id: row.get(1)?,
327 offset: row.get(2)?,
328 name: row.get(3)?,
329 embedding: row.get(4)?,
330 })
331 })?;
332
333 let mut documents: HashMap<usize, DocumentRecord> = HashMap::new();
334 for result in result_iter {
335 let result = result?;
336 documents.insert(result.id, result);
337 }
338
339 return Ok(documents);
340 }
341}