use crate::{parsing::Document, VECTOR_STORE_VERSION};
use anyhow::{anyhow, Result};
use project::Fs;
use rpc::proto::Timestamp;
use rusqlite::{
    params,
    types::{FromSql, FromSqlResult, ValueRef},
};
use std::{
    cmp::Ordering,
    collections::HashMap,
    ops::Range,
    path::{Path, PathBuf},
    rc::Rc,
    sync::Arc,
    time::SystemTime,
};

#[derive(Debug)]
pub struct FileRecord {
    pub id: usize,
    pub relative_path: String,
    pub mtime: Timestamp,
}

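// An embedding is stored as a BLOB containing the bincode-serialized `Vec<f32>`;
// this newtype lets it be read back directly from a row via `FromSql`.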
#[derive(Debug)]
struct Embedding(pub Vec<f32>);

impl FromSql for Embedding {
    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
        let bytes = value.as_blob()?;
        let embedding: Vec<f32> = bincode::deserialize(bytes)
            .map_err(|err| rusqlite::types::FromSqlError::Other(err))?;
        Ok(Embedding(embedding))
    }
}

pub struct VectorDatabase {
    db: rusqlite::Connection,
}

impl VectorDatabase {
    pub async fn new(fs: Arc<dyn Fs>, path: Arc<PathBuf>) -> Result<Self> {
        if let Some(db_directory) = path.parent() {
            fs.create_dir(db_directory).await?;
        }

        let this = Self {
            db: rusqlite::Connection::open(path.as_path())?,
        };
        this.initialize_database()?;
        Ok(this)
    }

    fn get_existing_version(&self) -> Result<i64> {
        let mut version_query = self
            .db
            .prepare("SELECT version FROM vector_store_config")?;
        version_query
            .query_row([], |row| row.get::<_, i64>(0))
            .map_err(|err| anyhow!("version query failed: {err}"))
    }

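    // Sets up the schema, dropping and recreating all tables whenever the stored
    // version does not match VECTOR_STORE_VERSION; previously indexed data is
    // discarded in that case and has to be re-indexed.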
    fn initialize_database(&self) -> Result<()> {
        rusqlite::vtab::array::load_module(&self.db)?;

        if self
            .get_existing_version()
            .map_or(false, |version| version == VECTOR_STORE_VERSION as i64)
        {
            return Ok(());
        }

        // Drop any stale tables; errors are ignored since a fresh database has none of them.
        self.db
            .execute_batch(
                "
                DROP TABLE IF EXISTS vector_store_config;
                DROP TABLE IF EXISTS worktrees;
                DROP TABLE IF EXISTS files;
                DROP TABLE IF EXISTS documents;
                ",
            )
            .ok();

        // Initialize the vector store tables.
        self.db.execute(
            "CREATE TABLE vector_store_config (
                version INTEGER NOT NULL
            )",
            [],
        )?;

        self.db.execute(
            "INSERT INTO vector_store_config (version) VALUES (?1)",
            params![VECTOR_STORE_VERSION],
        )?;

        self.db.execute_batch(
            "
            CREATE TABLE worktrees (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                absolute_path VARCHAR NOT NULL
            );
            CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
            ",
        )?;

        self.db.execute(
            "CREATE TABLE files (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                worktree_id INTEGER NOT NULL,
                relative_path VARCHAR NOT NULL,
                mtime_seconds INTEGER NOT NULL,
                mtime_nanos INTEGER NOT NULL,
                FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
            )",
            [],
        )?;

        self.db.execute(
            "CREATE TABLE documents (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                file_id INTEGER NOT NULL,
                start_byte INTEGER NOT NULL,
                end_byte INTEGER NOT NULL,
                name VARCHAR NOT NULL,
                embedding BLOB NOT NULL,
                FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
            )",
            [],
        )?;

        Ok(())
    }

    pub fn delete_file(&self, worktree_id: i64, delete_path: PathBuf) -> Result<()> {
        self.db.execute(
            "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
            params![worktree_id, delete_path.to_str()],
        )?;
        Ok(())
    }

    pub fn insert_file(
        &self,
        worktree_id: i64,
        path: PathBuf,
        mtime: SystemTime,
        documents: Vec<Document>,
    ) -> Result<()> {
        // Replace any existing row for this file, then write the new row and its documents.
        self.db.execute(
            "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
            params![worktree_id, path.to_str()],
        )?;
        let mtime = Timestamp::from(mtime);
        self.db.execute(
            "
            INSERT INTO files
            (worktree_id, relative_path, mtime_seconds, mtime_nanos)
            VALUES (?1, ?2, ?3, ?4)
            ",
            params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
        )?;

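        // `last_insert_rowid` returns the rowid generated by the INSERT above on this connection.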
        let file_id = self.db.last_insert_rowid();

        // Currently inserting at approximately 3400 documents a second.
        // I imagine we can speed this up with a bulk insert of some kind.
        for document in documents {
            let embedding_blob = bincode::serialize(&document.embedding)?;

            self.db.execute(
                "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding) VALUES (?1, ?2, ?3, ?4, ?5)",
                params![
                    file_id,
                    document.range.start as i64,
                    document.range.end as i64,
                    document.name,
                    embedding_blob
                ],
            )?;
        }

        Ok(())
    }

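    // Returns the id of the worktree row whose absolute path matches
    // `worktree_root_path`, inserting a new row if none exists; the UNIQUE index
    // on `absolute_path` keeps each root mapped to a single id.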
    pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
        // Look for an existing worktree with this absolute path.
        let mut worktree_query = self
            .db
            .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;

        let worktree_id = worktree_query
            .query_row(params![worktree_root_path.to_string_lossy()], |row| {
                row.get::<_, i64>(0)
            })
            .map_err(|err| anyhow!(err));

        if worktree_id.is_ok() {
            return worktree_id;
        }

        // No matching worktree; insert a new one and return its generated id.
        self.db.execute(
            "INSERT INTO worktrees (absolute_path) VALUES (?1)",
            params![worktree_root_path.to_string_lossy()],
        )?;
        Ok(self.db.last_insert_rowid())
    }

    pub fn get_file_mtimes(&self, worktree_id: i64) -> Result<HashMap<PathBuf, SystemTime>> {
        let mut statement = self.db.prepare(
            "
            SELECT relative_path, mtime_seconds, mtime_nanos
            FROM files
            WHERE worktree_id = ?1
            ORDER BY relative_path
            ",
        )?;
        let mut result: HashMap<PathBuf, SystemTime> = HashMap::new();
        for row in statement.query_map(params![worktree_id], |row| {
            Ok((
                row.get::<_, String>(0)?.into(),
                Timestamp {
                    seconds: row.get(1)?,
                    nanos: row.get(2)?,
                }
                .into(),
            ))
        })? {
            let row = row?;
            result.insert(row.0, row.1);
        }
        Ok(result)
    }

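    // Brute-force similarity search: scores every stored embedding in the given
    // worktrees with a dot product against `query_embedding` and keeps the
    // `limit` best document ids, ordered from most to least similar.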
    pub fn top_k_search(
        &self,
        worktree_ids: &[i64],
        query_embedding: &[f32],
        limit: usize,
    ) -> Result<Vec<(i64, PathBuf, Range<usize>, String)>> {
        let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
        self.for_each_document(worktree_ids, |id, embedding| {
            let similarity = dot(&embedding, query_embedding);
            let ix = match results
                .binary_search_by(|(_, s)| similarity.partial_cmp(s).unwrap_or(Ordering::Equal))
            {
                Ok(ix) => ix,
                Err(ix) => ix,
            };
            results.insert(ix, (id, similarity));
            results.truncate(limit);
        })?;

        let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
        self.get_documents_by_ids(&ids)
    }

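    // Streams `(document id, embedding)` pairs for every document belonging to
    // the given worktrees. `rarray(?)` is rusqlite's array virtual table, loaded
    // in `initialize_database`, which lets a whole list of ids bind to a single
    // parameter.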
    fn for_each_document(
        &self,
        worktree_ids: &[i64],
        mut f: impl FnMut(i64, Vec<f32>),
    ) -> Result<()> {
        let mut query_statement = self.db.prepare(
            "
            SELECT
                documents.id, documents.embedding
            FROM
                documents, files
            WHERE
                documents.file_id = files.id AND
                files.worktree_id IN rarray(?)
            ",
        )?;

        query_statement
            .query_map(params![ids_to_sql(worktree_ids)], |row| {
                Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
            })?
            .filter_map(|row| row.ok())
            .for_each(|(id, embedding)| f(id, embedding.0));
        Ok(())
    }

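    // Fetches the worktree id, path, byte range, and name for each document id,
    // returning them in the same order as `ids` so the ranking computed by
    // `top_k_search` is preserved.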
    fn get_documents_by_ids(
        &self,
        ids: &[i64],
    ) -> Result<Vec<(i64, PathBuf, Range<usize>, String)>> {
        let mut statement = self.db.prepare(
            "
            SELECT
                documents.id,
                files.worktree_id,
                files.relative_path,
                documents.start_byte,
                documents.end_byte,
                documents.name
            FROM
                documents, files
            WHERE
                documents.file_id = files.id AND
                documents.id IN rarray(?)
            ",
        )?;

        let result_iter = statement.query_map(params![ids_to_sql(ids)], |row| {
            Ok((
                row.get::<_, i64>(0)?,
                row.get::<_, i64>(1)?,
                row.get::<_, String>(2)?.into(),
                row.get(3)?..row.get(4)?,
                row.get(5)?,
            ))
        })?;

        let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>, String)>::default();
        for row in result_iter {
            let (id, worktree_id, path, range, name) = row?;
            values_by_id.insert(id, (worktree_id, path, range, name));
        }

        let mut results = Vec::with_capacity(ids.len());
        for id in ids {
            let value = values_by_id
                .remove(id)
                .ok_or_else(|| anyhow!("missing document id {}", id))?;
            results.push(value);
        }

        Ok(results)
    }
}

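// Converts a list of ids into the `Rc<Vec<Value>>` form that rusqlite's `rarray()`
// table-valued function expects as a bound parameter.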
fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
    Rc::new(
        ids.iter()
            .copied()
            .map(|v| rusqlite::types::Value::from(v))
            .collect::<Vec<_>>(),
    )
}

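// Dot product of two equal-length vectors, expressed as a 1×len by len×1 matrix
// product so it can be computed by `matrixmultiply::sgemm`.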
pub(crate) fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    let len = vec_a.len();
    assert_eq!(len, vec_b.len());

    let mut result = 0.0;
    unsafe {
        matrixmultiply::sgemm(
            1,                       // m: rows of A (and of C)
            len,                     // k: columns of A / rows of B
            1,                       // n: columns of B (and of C)
            1.0,                     // alpha
            vec_a.as_ptr(),          // A
            len as isize,            // row stride of A
            1,                       // column stride of A
            vec_b.as_ptr(),          // B
            1,                       // row stride of B
            len as isize,            // column stride of B
            0.0,                     // beta (discard the initial value of C)
            &mut result as *mut f32, // C: a single output cell
            1,                       // row stride of C
            1,                       // column stride of C
        );
    }
    result
}
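
// A minimal sanity check for `dot`: a sketch of a unit test comparing the
// sgemm-based product against a naive sum over a small, arbitrary input.
#[cfg(test)]
mod tests {
    use super::dot;

    #[test]
    fn dot_matches_naive_sum() {
        let a = vec![1.0_f32, 2.0, 3.0, 4.0];
        let b = vec![0.5_f32, -1.0, 2.0, 0.25];
        let naive: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
        assert!((dot(&a, &b) - naive).abs() < 1e-5);
    }
}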