use std::{
    collections::HashMap,
    path::{Path, PathBuf},
    rc::Rc,
};

use anyhow::{anyhow, Result};

use rusqlite::{
    params,
    types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
    ToSql,
};
use sha1::{Digest, Sha1};

use crate::IndexedFile;

// Note: a "document" here is a row of the `documents` table (a named span at an
// offset within a file), which may not match the usual notion of a document.
#[derive(Debug)]
pub struct DocumentRecord {
    pub id: usize,
    pub file_id: usize,
    pub offset: usize,
    pub name: String,
    pub embedding: Embedding,
}

#[derive(Debug)]
pub struct FileRecord {
    pub id: usize,
    pub relative_path: String,
    pub sha1: FileSha1,
}

#[derive(Debug)]
pub struct FileSha1(pub Vec<u8>);

impl FileSha1 {
    /// Computes the SHA-1 digest of `content`.
    pub fn from_str(content: String) -> Self {
        FileSha1(Sha1::digest(content.as_bytes()).to_vec())
    }

    /// Returns true if `content` hashes to this digest.
    pub fn equals(&self, content: &str) -> bool {
        self.0.as_slice() == Sha1::digest(content.as_bytes()).as_slice()
    }
}

impl ToSql for FileSha1 {
    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
        self.0.to_sql()
    }
}

impl FromSql for FileSha1 {
    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
        Ok(FileSha1(value.as_blob()?.to_vec()))
    }
}

#[derive(Debug)]
pub struct Embedding(pub Vec<f32>);

impl FromSql for Embedding {
    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
        // Embeddings are stored as bincode-serialized `Vec<f32>` blobs.
        let bytes = value.as_blob()?;
        let embedding: Vec<f32> =
            bincode::deserialize(bytes).map_err(|err| rusqlite::types::FromSqlError::Other(err))?;
        Ok(Embedding(embedding))
    }
}

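// Write-side counterpart to the FromSql impl above: a sketch, assuming embeddings
// should always round-trip as bincode-serialized `Vec<f32>` blobs. It is not used
// by `insert_file` below, which serializes the embedding inline.
impl ToSql for Embedding {
    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
        let blob = bincode::serialize(&self.0)
            .map_err(|err| rusqlite::Error::ToSqlConversionFailure(err))?;
        Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(blob)))
    }
}
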
pub struct VectorDatabase {
    db: rusqlite::Connection,
}

impl VectorDatabase {
    pub fn new(path: String) -> Result<Self> {
        let this = Self {
            db: rusqlite::Connection::open(path)?,
        };
        this.initialize_database()?;
        Ok(this)
    }

    fn initialize_database(&self) -> Result<()> {
        // The `rarray` table-valued function used in the queries below comes from
        // this virtual table module.
        rusqlite::vtab::array::load_module(&self.db)?;

        // `Connection::open` above creates the database file if it doesn't exist;
        // here we create the tables. SQLite only honors the `ON DELETE CASCADE`
        // clauses below when foreign keys are enabled for the connection.
        self.db.execute_batch("PRAGMA foreign_keys = ON;")?;

        // Initialize the vector database tables. `execute_batch` is used for the
        // worktrees schema because it contains more than one statement.
        self.db.execute_batch(
            "CREATE TABLE IF NOT EXISTS worktrees (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                absolute_path VARCHAR NOT NULL
            );
            CREATE UNIQUE INDEX IF NOT EXISTS worktrees_absolute_path ON worktrees (absolute_path);
            ",
        )?;

        self.db.execute(
            "CREATE TABLE IF NOT EXISTS files (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                worktree_id INTEGER NOT NULL,
                relative_path VARCHAR NOT NULL,
                sha1 BLOB NOT NULL,
                FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
            )",
            [],
        )?;

        self.db.execute(
            "CREATE TABLE IF NOT EXISTS documents (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                file_id INTEGER NOT NULL,
                offset INTEGER NOT NULL,
                name VARCHAR NOT NULL,
                embedding BLOB NOT NULL,
                FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
            )",
            [],
        )?;

        Ok(())
    }

    pub fn delete_file(&self, worktree_id: i64, delete_path: PathBuf) -> Result<()> {
        self.db.execute(
            "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
            params![worktree_id, delete_path.to_str()],
        )?;
        Ok(())
    }

    pub fn insert_file(&self, worktree_id: i64, indexed_file: IndexedFile) -> Result<()> {
        // Replace any existing row for this file, and keep the generated id for
        // the document rows below.
        self.db.execute(
            "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
            params![worktree_id, indexed_file.path.to_str()],
        )?;
        self.db.execute(
            "INSERT INTO files (worktree_id, relative_path, sha1) VALUES (?1, ?2, ?3)",
            params![worktree_id, indexed_file.path.to_str(), indexed_file.sha1],
        )?;

        let file_id = self.db.last_insert_rowid();

        // Currently inserting at approximately 3400 documents a second.
        // This could likely be sped up with a bulk insert of some kind; see the
        // batched sketch below this method.
        for document in indexed_file.documents {
            let embedding_blob = bincode::serialize(&document.embedding)?;

            self.db.execute(
                "INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
                params![
                    file_id,
                    document.offset.to_string(),
                    document.name,
                    embedding_blob
                ],
            )?;
        }

        Ok(())
    }

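    // A minimal sketch of the bulk-insert idea mentioned above: reuse one prepared
    // statement and wrap all of a file's writes in a single transaction. This is an
    // illustrative alternative, not part of the original API; it assumes a rusqlite
    // version that provides `Connection::unchecked_transaction`.
    #[allow(dead_code)]
    pub fn insert_file_batched(&self, worktree_id: i64, indexed_file: IndexedFile) -> Result<()> {
        let tx = self.db.unchecked_transaction()?;
        tx.execute(
            "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
            params![worktree_id, indexed_file.path.to_str()],
        )?;
        tx.execute(
            "INSERT INTO files (worktree_id, relative_path, sha1) VALUES (?1, ?2, ?3)",
            params![worktree_id, indexed_file.path.to_str(), indexed_file.sha1],
        )?;
        let file_id = tx.last_insert_rowid();

        {
            // Preparing the statement once avoids re-parsing the SQL per document.
            let mut insert = tx.prepare(
                "INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
            )?;
            for document in indexed_file.documents {
                let embedding_blob = bincode::serialize(&document.embedding)?;
                insert.execute(params![
                    file_id,
                    document.offset.to_string(),
                    document.name,
                    embedding_blob
                ])?;
            }
        }

        tx.commit()?;
        Ok(())
    }
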
    pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
        // Return the existing id for this absolute path, if there is one.
        let mut worktree_query = self
            .db
            .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;

        let worktree_id = worktree_query
            .query_row(params![worktree_root_path.to_string_lossy()], |row| {
                row.get::<_, i64>(0)
            })
            .map_err(|err| anyhow!(err));

        if worktree_id.is_ok() {
            return worktree_id;
        }

        // Otherwise insert a new worktree and return its generated id.
        self.db.execute(
            "INSERT INTO worktrees (absolute_path) VALUES (?1)",
            params![worktree_root_path.to_string_lossy()],
        )?;
        Ok(self.db.last_insert_rowid())
    }

    pub fn get_file_hashes(&self, worktree_id: i64) -> Result<HashMap<PathBuf, FileSha1>> {
        let mut statement = self.db.prepare(
            "SELECT relative_path, sha1 FROM files WHERE worktree_id = ?1 ORDER BY relative_path",
        )?;
        let mut result: HashMap<PathBuf, FileSha1> = HashMap::new();
        for row in statement.query_map(params![worktree_id], |row| {
            Ok((row.get::<_, String>(0)?.into(), row.get(1)?))
        })? {
            let row = row?;
            result.insert(row.0, row.1);
        }
        Ok(result)
    }

    pub fn for_each_document(
        &self,
        worktree_ids: &[i64],
        mut f: impl FnMut(i64, Embedding),
    ) -> Result<()> {
        // `rarray` comes from the array virtual table module loaded in
        // `initialize_database`, and lets a whole list of ids be bound as one parameter.
        let mut query_statement = self.db.prepare(
            "
            SELECT
                documents.id, documents.embedding
            FROM
                documents, files
            WHERE
                documents.file_id = files.id AND
                files.worktree_id IN rarray(?)
            ",
        )?;
        query_statement
            .query_map(params![ids_to_sql(worktree_ids)], |row| {
                Ok((row.get(0)?, row.get(1)?))
            })?
            // Rows that fail to read or deserialize are skipped rather than surfaced.
            .filter_map(|row| row.ok())
            .for_each(|row| f(row.0, row.1));
        Ok(())
    }

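    // A sketch of how a caller might consume `for_each_document`, e.g. to rank
    // documents against a query embedding. The method name and the dot-product
    // scoring are illustrative assumptions, not part of the original API.
    #[allow(dead_code)]
    pub fn top_document_ids(
        &self,
        worktree_ids: &[i64],
        query: &[f32],
        limit: usize,
    ) -> Result<Vec<i64>> {
        let mut scored: Vec<(i64, f32)> = Vec::new();
        self.for_each_document(worktree_ids, |id, embedding| {
            // Dot product of the stored embedding with the query embedding.
            let score: f32 = embedding.0.iter().zip(query).map(|(a, b)| a * b).sum();
            scored.push((id, score));
        })?;
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        Ok(scored.into_iter().take(limit).map(|(id, _)| id).collect())
    }
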
    pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, usize, String)>> {
        let mut statement = self.db.prepare(
            "
            SELECT
                documents.id, files.worktree_id, files.relative_path, documents.offset, documents.name
            FROM
                documents, files
            WHERE
                documents.file_id = files.id AND
                documents.id IN rarray(?)
            ",
        )?;

        let result_iter = statement.query_map(params![ids_to_sql(ids)], |row| {
            Ok((
                row.get::<_, i64>(0)?,
                row.get::<_, i64>(1)?,
                row.get::<_, String>(2)?.into(),
                row.get(3)?,
                row.get(4)?,
            ))
        })?;

        let mut values_by_id = HashMap::<i64, (i64, PathBuf, usize, String)>::default();
        for row in result_iter {
            let (id, worktree_id, path, offset, name) = row?;
            values_by_id.insert(id, (worktree_id, path, offset, name));
        }

        // Return the results in the same order as the requested `ids`.
        let mut results = Vec::with_capacity(ids.len());
        for id in ids {
            let value = values_by_id
                .remove(id)
                .ok_or(anyhow!("missing document id {}", id))?;
            results.push(value);
        }

        Ok(results)
    }
}

fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
    // `rarray` expects the id list to be bound as a single `Rc<Vec<Value>>` parameter.
    Rc::new(
        ids.iter()
            .copied()
            .map(rusqlite::types::Value::from)
            .collect::<Vec<_>>(),
    )
}
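
// A minimal usage sketch, assuming an in-memory SQLite database (rusqlite's
// ":memory:" path) behaves like the on-disk one. These tests are illustrative
// additions, not part of the original file.
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use std::path::Path;

    #[test]
    fn file_sha1_matches_identical_content() {
        let content = "fn main() {}".to_string();
        let sha = FileSha1::from_str(content.clone());
        assert!(sha.equals(&content));
        assert!(!sha.equals(&"fn main() { }".to_string()));
    }

    #[test]
    fn worktree_ids_are_reused_for_the_same_path() -> Result<()> {
        let db = VectorDatabase::new(":memory:".to_string())?;
        let first = db.find_or_create_worktree(Path::new("/tmp/project"))?;
        let second = db.find_or_create_worktree(Path::new("/tmp/project"))?;
        assert_eq!(first, second);
        Ok(())
    }
}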