use std::{
    collections::HashMap,
    path::{Path, PathBuf},
    rc::Rc,
};

use anyhow::{anyhow, Result};
use rusqlite::{
    params,
    types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
    OptionalExtension, ToSql,
};
use sha1::{Digest, Sha1};

use crate::IndexedFile;

// This saves to a local database store within the user's dev zed path.
// Where do we want this to sit?
// Assuming near where the workspace DB sits.
pub const VECTOR_DB_URL: &str = "embeddings_db";

// Note: this is not yet an appropriate document representation.
#[derive(Debug)]
pub struct DocumentRecord {
    pub id: usize,
    pub file_id: usize,
    pub offset: usize,
    pub name: String,
    pub embedding: Embedding,
}

#[derive(Debug)]
pub struct FileRecord {
    pub id: usize,
    pub relative_path: String,
    pub sha1: FileSha1,
}

#[derive(Debug)]
pub struct FileSha1(pub Vec<u8>);

impl FileSha1 {
    pub fn from_str(content: String) -> Self {
        FileSha1(Self::digest(&content))
    }

    pub fn equals(&self, content: &str) -> bool {
        self.0 == Self::digest(content)
    }

    // Hash the given content with SHA-1 and return the raw digest bytes.
    fn digest(content: &str) -> Vec<u8> {
        let mut hasher = Sha1::new();
        hasher.update(content);
        hasher.finalize().to_vec()
    }
}

impl ToSql for FileSha1 {
    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
        self.0.to_sql()
    }
}

impl FromSql for FileSha1 {
    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
        Ok(FileSha1(value.as_blob()?.to_vec()))
    }
}

#[derive(Debug)]
pub struct Embedding(pub Vec<f32>);

impl FromSql for Embedding {
    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
        let bytes = value.as_blob()?;
        let embedding: Vec<f32> = bincode::deserialize(bytes)
            .map_err(|err| rusqlite::types::FromSqlError::Other(err))?;
        Ok(Embedding(embedding))
    }
}

pub struct VectorDatabase {
    db: rusqlite::Connection,
}

impl VectorDatabase {
    pub fn new(path: &str) -> Result<Self> {
        let this = Self {
            db: rusqlite::Connection::open(path)?,
        };
        this.initialize_database()?;
        Ok(this)
    }

    fn initialize_database(&self) -> Result<()> {
        rusqlite::vtab::array::load_module(&self.db)?;

        // Opening the connection in `new` already created the database file if it
        // didn't exist; here we just make sure the tables and indices are in place.

        // Initialize the vector database tables. `execute_batch` is needed for the
        // first block because it contains more than one statement.
        self.db.execute_batch(
            "CREATE TABLE IF NOT EXISTS worktrees (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                absolute_path VARCHAR NOT NULL
            );
            CREATE UNIQUE INDEX IF NOT EXISTS worktrees_absolute_path ON worktrees (absolute_path);
            ",
        )?;

        self.db.execute(
            "CREATE TABLE IF NOT EXISTS files (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                worktree_id INTEGER NOT NULL,
                relative_path VARCHAR NOT NULL,
                sha1 BLOB NOT NULL,
                FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
            )",
            [],
        )?;

        self.db.execute(
            "CREATE TABLE IF NOT EXISTS documents (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                file_id INTEGER NOT NULL,
                offset INTEGER NOT NULL,
                name VARCHAR NOT NULL,
                embedding BLOB NOT NULL,
                FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
            )",
            [],
        )?;

        Ok(())
    }

    pub fn insert_file(&self, worktree_id: i64, indexed_file: IndexedFile) -> Result<()> {
        // Write to the files table, and get back the generated id.
        log::info!("Inserting File!");
        self.db.execute(
            "
            DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;
            ",
            params![worktree_id, indexed_file.path.to_str()],
        )?;
        self.db.execute(
            "
            INSERT INTO files (worktree_id, relative_path, sha1) VALUES (?1, ?2, ?3);
            ",
            params![worktree_id, indexed_file.path.to_str(), indexed_file.sha1],
        )?;

        let file_id = self.db.last_insert_rowid();

        // Currently inserting at approximately 3400 documents a second.
        // We could likely speed this up with a bulk insert of some kind
        // (see the batched sketch below this method).
        for document in indexed_file.documents {
            let embedding_blob = bincode::serialize(&document.embedding)?;

            self.db.execute(
                "INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
                params![
                    file_id,
                    // Passed as a string; SQLite's INTEGER affinity converts it on insert.
                    document.offset.to_string(),
                    document.name,
                    embedding_blob
                ],
            )?;
        }

        Ok(())
    }
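
    // A sketch of the bulk-insert idea mentioned above: wrap all of a file's
    // document inserts in one transaction and reuse a single cached prepared
    // statement, rather than re-preparing the INSERT per document. The method
    // name `insert_documents_batched` and the (offset, name, embedding) tuple
    // shape are illustrative assumptions, not part of the current API;
    // `insert_file` does not call this.
    #[allow(dead_code)]
    fn insert_documents_batched(
        &self,
        file_id: i64,
        documents: &[(usize, String, Vec<f32>)],
    ) -> Result<()> {
        // `unchecked_transaction` is used because we only hold `&self` here.
        let tx = self.db.unchecked_transaction()?;
        {
            let mut statement = tx.prepare_cached(
                "INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
            )?;
            for (offset, name, embedding) in documents {
                let embedding_blob = bincode::serialize(embedding)?;
                statement.execute(params![file_id, *offset as i64, name, embedding_blob])?;
            }
        }
        tx.commit()?;
        Ok(())
    }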

    pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
        // Check whether a worktree with this absolute path already exists.
        let mut worktree_query = self
            .db
            .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;

        let worktree_id: Option<i64> = worktree_query
            .query_row(params![worktree_root_path.to_string_lossy()], |row| {
                row.get(0)
            })
            .optional()?;

        if let Some(worktree_id) = worktree_id {
            return Ok(worktree_id);
        }

        // No existing worktree; insert a new one and return its generated id.
        self.db.execute(
            "
            INSERT INTO worktrees (absolute_path) VALUES (?1)
            ",
            params![worktree_root_path.to_string_lossy()],
        )?;
        Ok(self.db.last_insert_rowid())
    }

    pub fn get_file_hashes(&self, worktree_id: i64) -> Result<HashMap<PathBuf, FileSha1>> {
        let mut statement = self.db.prepare(
            "SELECT relative_path, sha1 FROM files WHERE worktree_id = ?1 ORDER BY relative_path",
        )?;
        let mut result: HashMap<PathBuf, FileSha1> = HashMap::new();
        for row in statement.query_map(params![worktree_id], |row| {
            Ok((row.get::<_, String>(0)?.into(), row.get(1)?))
        })? {
            let row = row?;
            result.insert(row.0, row.1);
        }
        Ok(result)
    }
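
    // A sketch of how the hash map above is intended to be used during indexing:
    // look up the stored SHA-1 for a file and skip re-embedding it when the
    // current contents hash to the same value. The name `file_is_unchanged` and
    // passing contents as a `&str` are illustrative assumptions, not part of the
    // current API.
    #[allow(dead_code)]
    pub fn file_is_unchanged(
        hashes: &HashMap<PathBuf, FileSha1>,
        relative_path: &Path,
        current_contents: &str,
    ) -> bool {
        hashes
            .get(relative_path)
            .map_or(false, |sha1| sha1.equals(current_contents))
    }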

    pub fn for_each_document(
        &self,
        worktree_ids: &[i64],
        mut f: impl FnMut(i64, Embedding),
    ) -> Result<()> {
        let mut query_statement = self.db.prepare(
            "
            SELECT
                documents.id, documents.embedding
            FROM
                documents, files
            WHERE
                documents.file_id = files.id AND
                files.worktree_id IN rarray(?)
            ",
        )?;
        query_statement
            .query_map(params![ids_to_sql(worktree_ids)], |row| {
                Ok((row.get(0)?, row.get(1)?))
            })?
            .filter_map(|row| row.ok())
            .for_each(|row| f(row.0, row.1));
        Ok(())
    }
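
    // A sketch of how `for_each_document` might be consumed for a brute-force
    // similarity search: scan every stored embedding, score it against a query
    // vector with a dot product (a reasonable proxy for cosine similarity when
    // the embeddings are normalized), and keep the best `limit` document ids.
    // The method name `top_k_documents` and the scoring choice are assumptions
    // for illustration, not part of the crate's current API.
    #[allow(dead_code)]
    pub fn top_k_documents(
        &self,
        worktree_ids: &[i64],
        query: &[f32],
        limit: usize,
    ) -> Result<Vec<(i64, f32)>> {
        let mut scored: Vec<(i64, f32)> = Vec::new();
        self.for_each_document(worktree_ids, |id, embedding| {
            let score = embedding
                .0
                .iter()
                .zip(query)
                .map(|(a, b)| a * b)
                .sum::<f32>();
            scored.push((id, score));
        })?;
        // Highest score first; ties and NaNs are left in encounter order.
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scored.truncate(limit);
        Ok(scored)
    }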

    pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, usize, String)>> {
        let mut statement = self.db.prepare(
            "
            SELECT
                documents.id, files.worktree_id, files.relative_path, documents.offset, documents.name
            FROM
                documents, files
            WHERE
                documents.file_id = files.id AND
                documents.id IN rarray(?)
            ",
        )?;

        let result_iter = statement.query_map(params![ids_to_sql(ids)], |row| {
            Ok((
                row.get::<_, i64>(0)?,
                row.get::<_, i64>(1)?,
                row.get::<_, String>(2)?.into(),
                row.get(3)?,
                row.get(4)?,
            ))
        })?;

        let mut values_by_id = HashMap::<i64, (i64, PathBuf, usize, String)>::default();
        for row in result_iter {
            let (id, worktree_id, path, offset, name) = row?;
            values_by_id.insert(id, (worktree_id, path, offset, name));
        }

        // Return results in the same order as the requested ids.
        let mut results = Vec::with_capacity(ids.len());
        for id in ids {
            let value = values_by_id
                .remove(id)
                .ok_or_else(|| anyhow!("missing document id {}", id))?;
            results.push(value);
        }

        Ok(results)
    }
}

fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
    Rc::new(
        ids.iter()
            .copied()
            .map(rusqlite::types::Value::from)
            .collect::<Vec<_>>(),
    )
}
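
// A minimal usage sketch, assuming an in-memory SQLite connection is acceptable
// for tests; it only exercises the public API defined above.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn find_or_create_worktree_is_idempotent() -> Result<()> {
        let db = VectorDatabase::new(":memory:")?;
        let first = db.find_or_create_worktree(Path::new("/fake/worktree"))?;
        let second = db.find_or_create_worktree(Path::new("/fake/worktree"))?;
        assert_eq!(first, second);
        Ok(())
    }
}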