1use std::{
2 collections::HashMap,
3 path::{Path, PathBuf},
4 rc::Rc,
5};
6
7use anyhow::{anyhow, Result};
8
9use rusqlite::{
10 params,
11 types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
12 ToSql,
13};
14use sha1::{Digest, Sha1};
15
16use crate::IndexedFile;
17
// NOTE(review): per the original note, "document" is not an appropriate name —
// a row here is a named, offset-addressed span within a file (see the
// `documents` table), not a whole document. Confirm intended terminology
// before renaming.
#[derive(Debug)]
pub struct DocumentRecord {
    /// Row id in the `documents` table.
    pub id: usize,
    /// Foreign key into the `files` table.
    pub file_id: usize,
    /// Offset of this span within its file — unit (bytes vs. characters) is
    /// not established here; confirm against the indexer that produces it.
    pub offset: usize,
    /// Name of the span (stored in the `name` column).
    pub name: String,
    /// Embedding vector, persisted as a bincode-encoded BLOB.
    pub embedding: Embedding,
}
27
/// A row of the `files` table: a file within a worktree plus the SHA-1 of
/// its contents at index time.
#[derive(Debug)]
pub struct FileRecord {
    /// Row id in the `files` table.
    pub id: usize,
    /// Path of the file relative to its worktree root.
    pub relative_path: String,
    /// SHA-1 digest of the file contents — presumably used to detect changed
    /// files on re-index (see `get_file_hashes`); confirm with callers.
    pub sha1: FileSha1,
}
34
/// SHA-1 digest of a file's contents (20 raw bytes), stored as a BLOB in the
/// `files.sha1` column.
#[derive(Debug)]
pub struct FileSha1(pub Vec<u8>);
37
38impl FileSha1 {
39 pub fn from_str(content: String) -> Self {
40 let mut hasher = Sha1::new();
41 hasher.update(content);
42 let sha1 = hasher.finalize()[..]
43 .into_iter()
44 .map(|val| val.to_owned())
45 .collect::<Vec<u8>>();
46 return FileSha1(sha1);
47 }
48
49 pub fn equals(&self, content: &String) -> bool {
50 let mut hasher = Sha1::new();
51 hasher.update(content);
52 let sha1 = hasher.finalize()[..]
53 .into_iter()
54 .map(|val| val.to_owned())
55 .collect::<Vec<u8>>();
56
57 let equal = self
58 .0
59 .clone()
60 .into_iter()
61 .zip(sha1)
62 .filter(|&(a, b)| a == b)
63 .count()
64 == self.0.len();
65
66 equal
67 }
68}
69
70impl ToSql for FileSha1 {
71 fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
72 return self.0.to_sql();
73 }
74}
75
76impl FromSql for FileSha1 {
77 fn column_result(value: ValueRef) -> FromSqlResult<Self> {
78 let bytes = value.as_blob()?;
79 Ok(FileSha1(
80 bytes
81 .into_iter()
82 .map(|val| val.to_owned())
83 .collect::<Vec<u8>>(),
84 ))
85 }
86}
87
/// An embedding vector of `f32` components, persisted in SQLite as a
/// bincode-encoded BLOB (see `insert_file` and the `FromSql` impl below).
#[derive(Debug)]
pub struct Embedding(pub Vec<f32>);
90
91impl FromSql for Embedding {
92 fn column_result(value: ValueRef) -> FromSqlResult<Self> {
93 let bytes = value.as_blob()?;
94 let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
95 if embedding.is_err() {
96 return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
97 }
98 return Ok(Embedding(embedding.unwrap()));
99 }
100}
101
/// SQLite-backed store for worktrees, their indexed files, and the per-span
/// document embeddings used for vector search.
pub struct VectorDatabase {
    // Owned connection; the `rarray` virtual-table module is loaded on it
    // during initialization.
    db: rusqlite::Connection,
}
105
106impl VectorDatabase {
107 pub fn new(path: String) -> Result<Self> {
108 let this = Self {
109 db: rusqlite::Connection::open(path)?,
110 };
111 this.initialize_database()?;
112 Ok(this)
113 }
114
115 fn initialize_database(&self) -> Result<()> {
116 rusqlite::vtab::array::load_module(&self.db)?;
117
118 // This will create the database if it doesnt exist
119
120 // Initialize Vector Databasing Tables
121 self.db.execute(
122 "CREATE TABLE IF NOT EXISTS worktrees (
123 id INTEGER PRIMARY KEY AUTOINCREMENT,
124 absolute_path VARCHAR NOT NULL
125 );
126 CREATE UNIQUE INDEX IF NOT EXISTS worktrees_absolute_path ON worktrees (absolute_path);
127 ",
128 [],
129 )?;
130
131 self.db.execute(
132 "CREATE TABLE IF NOT EXISTS files (
133 id INTEGER PRIMARY KEY AUTOINCREMENT,
134 worktree_id INTEGER NOT NULL,
135 relative_path VARCHAR NOT NULL,
136 sha1 BLOB NOT NULL,
137 FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
138 )",
139 [],
140 )?;
141
142 self.db.execute(
143 "CREATE TABLE IF NOT EXISTS documents (
144 id INTEGER PRIMARY KEY AUTOINCREMENT,
145 file_id INTEGER NOT NULL,
146 offset INTEGER NOT NULL,
147 name VARCHAR NOT NULL,
148 embedding BLOB NOT NULL,
149 FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
150 )",
151 [],
152 )?;
153
154 Ok(())
155 }
156
157 pub fn insert_file(&self, worktree_id: i64, indexed_file: IndexedFile) -> Result<()> {
158 // Write to files table, and return generated id.
159 log::info!("Inserting File!");
160 self.db.execute(
161 "
162 DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;
163 ",
164 params![worktree_id, indexed_file.path.to_str()],
165 )?;
166 self.db.execute(
167 "
168 INSERT INTO files (worktree_id, relative_path, sha1) VALUES (?1, ?2, $3);
169 ",
170 params![worktree_id, indexed_file.path.to_str(), indexed_file.sha1],
171 )?;
172
173 let file_id = self.db.last_insert_rowid();
174
175 // Currently inserting at approximately 3400 documents a second
176 // I imagine we can speed this up with a bulk insert of some kind.
177 for document in indexed_file.documents {
178 let embedding_blob = bincode::serialize(&document.embedding)?;
179
180 self.db.execute(
181 "INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
182 params![
183 file_id,
184 document.offset.to_string(),
185 document.name,
186 embedding_blob
187 ],
188 )?;
189 }
190
191 Ok(())
192 }
193
194 pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
195 // Check that the absolute path doesnt exist
196 let mut worktree_query = self
197 .db
198 .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
199
200 let worktree_id = worktree_query
201 .query_row(params![worktree_root_path.to_string_lossy()], |row| {
202 Ok(row.get::<_, i64>(0)?)
203 })
204 .map_err(|err| anyhow!(err));
205
206 if worktree_id.is_ok() {
207 return worktree_id;
208 }
209
210 // If worktree_id is Err, insert new worktree
211 self.db.execute(
212 "
213 INSERT into worktrees (absolute_path) VALUES (?1)
214 ",
215 params![worktree_root_path.to_string_lossy()],
216 )?;
217 Ok(self.db.last_insert_rowid())
218 }
219
220 pub fn get_file_hashes(&self, worktree_id: i64) -> Result<HashMap<PathBuf, FileSha1>> {
221 let mut statement = self.db.prepare(
222 "SELECT relative_path, sha1 FROM files WHERE worktree_id = ?1 ORDER BY relative_path",
223 )?;
224 let mut result: HashMap<PathBuf, FileSha1> = HashMap::new();
225 for row in statement.query_map(params![worktree_id], |row| {
226 Ok((row.get::<_, String>(0)?.into(), row.get(1)?))
227 })? {
228 let row = row?;
229 result.insert(row.0, row.1);
230 }
231 Ok(result)
232 }
233
234 pub fn for_each_document(
235 &self,
236 worktree_ids: &[i64],
237 mut f: impl FnMut(i64, Embedding),
238 ) -> Result<()> {
239 let mut query_statement = self.db.prepare(
240 "
241 SELECT
242 documents.id, documents.embedding
243 FROM
244 documents, files
245 WHERE
246 documents.file_id = files.id AND
247 files.worktree_id IN rarray(?)
248 ",
249 )?;
250 query_statement
251 .query_map(params![ids_to_sql(worktree_ids)], |row| {
252 Ok((row.get(0)?, row.get(1)?))
253 })?
254 .filter_map(|row| row.ok())
255 .for_each(|row| f(row.0, row.1));
256 Ok(())
257 }
258
259 pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, usize, String)>> {
260 let mut statement = self.db.prepare(
261 "
262 SELECT
263 documents.id, files.worktree_id, files.relative_path, documents.offset, documents.name
264 FROM
265 documents, files
266 WHERE
267 documents.file_id = files.id AND
268 documents.id in rarray(?)
269 ",
270 )?;
271
272 let result_iter = statement.query_map(params![ids_to_sql(ids)], |row| {
273 Ok((
274 row.get::<_, i64>(0)?,
275 row.get::<_, i64>(1)?,
276 row.get::<_, String>(2)?.into(),
277 row.get(3)?,
278 row.get(4)?,
279 ))
280 })?;
281
282 let mut values_by_id = HashMap::<i64, (i64, PathBuf, usize, String)>::default();
283 for row in result_iter {
284 let (id, worktree_id, path, offset, name) = row?;
285 values_by_id.insert(id, (worktree_id, path, offset, name));
286 }
287
288 let mut results = Vec::with_capacity(ids.len());
289 for id in ids {
290 let value = values_by_id
291 .remove(id)
292 .ok_or(anyhow!("missing document id {}", id))?;
293 results.push(value);
294 }
295
296 Ok(results)
297 }
298}
299
300fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
301 Rc::new(
302 ids.iter()
303 .copied()
304 .map(|v| rusqlite::types::Value::from(v))
305 .collect::<Vec<_>>(),
306 )
307}