1use crate::{parsing::Document, SEMANTIC_INDEX_VERSION};
2use anyhow::{anyhow, Context, Result};
3use project::{search::PathMatcher, Fs};
4use rpc::proto::Timestamp;
5use rusqlite::{
6 params,
7 types::{FromSql, FromSqlResult, ValueRef},
8};
9use std::{
10 cmp::Ordering,
11 collections::HashMap,
12 ops::Range,
13 path::{Path, PathBuf},
14 rc::Rc,
15 sync::Arc,
16 time::SystemTime,
17};
18
/// A row from the `files` table: one tracked file within an indexed worktree.
#[derive(Debug)]
pub struct FileRecord {
    // Primary key of the `files` row.
    pub id: usize,
    // Path of the file relative to its worktree root.
    pub relative_path: String,
    // Modification time recorded when the file was last indexed.
    pub mtime: Timestamp,
}
25
/// Newtype for an embedding vector read from the bincode-serialized
/// `documents.embedding` BLOB column.
#[derive(Debug)]
struct Embedding(pub Vec<f32>);
28
/// Newtype for a SHA-1 digest read from the bincode-serialized
/// `documents.sha1` BLOB column.
#[derive(Debug)]
struct Sha1(pub Vec<u8>);
31
32impl FromSql for Embedding {
33 fn column_result(value: ValueRef) -> FromSqlResult<Self> {
34 let bytes = value.as_blob()?;
35 let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
36 if embedding.is_err() {
37 return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
38 }
39 return Ok(Embedding(embedding.unwrap()));
40 }
41}
42
43impl FromSql for Sha1 {
44 fn column_result(value: ValueRef) -> FromSqlResult<Self> {
45 let bytes = value.as_blob()?;
46 let sha1: Result<Vec<u8>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
47 if sha1.is_err() {
48 return Err(rusqlite::types::FromSqlError::Other(sha1.unwrap_err()));
49 }
50 return Ok(Sha1(sha1.unwrap()));
51 }
52}
53
/// Handle to the on-disk SQLite database that stores worktrees, files,
/// and per-document embeddings for the semantic index.
pub struct VectorDatabase {
    // Owned SQLite connection; all queries in this type go through it.
    db: rusqlite::Connection,
}
57
58impl VectorDatabase {
59 pub async fn new(fs: Arc<dyn Fs>, path: Arc<PathBuf>) -> Result<Self> {
60 if let Some(db_directory) = path.parent() {
61 fs.create_dir(db_directory).await?;
62 }
63
64 let this = Self {
65 db: rusqlite::Connection::open(path.as_path())?,
66 };
67 this.initialize_database()?;
68 Ok(this)
69 }
70
71 fn get_existing_version(&self) -> Result<i64> {
72 let mut version_query = self
73 .db
74 .prepare("SELECT version from semantic_index_config")?;
75 version_query
76 .query_row([], |row| Ok(row.get::<_, i64>(0)?))
77 .map_err(|err| anyhow!("version query failed: {err}"))
78 }
79
80 fn initialize_database(&self) -> Result<()> {
81 rusqlite::vtab::array::load_module(&self.db)?;
82
83 // Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped
84 if self
85 .get_existing_version()
86 .map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64)
87 {
88 log::trace!("vector database schema up to date");
89 return Ok(());
90 }
91
92 log::trace!("vector database schema out of date. updating...");
93 self.db
94 .execute("DROP TABLE IF EXISTS documents", [])
95 .context("failed to drop 'documents' table")?;
96 self.db
97 .execute("DROP TABLE IF EXISTS files", [])
98 .context("failed to drop 'files' table")?;
99 self.db
100 .execute("DROP TABLE IF EXISTS worktrees", [])
101 .context("failed to drop 'worktrees' table")?;
102 self.db
103 .execute("DROP TABLE IF EXISTS semantic_index_config", [])
104 .context("failed to drop 'semantic_index_config' table")?;
105
106 // Initialize Vector Databasing Tables
107 self.db.execute(
108 "CREATE TABLE semantic_index_config (
109 version INTEGER NOT NULL
110 )",
111 [],
112 )?;
113
114 self.db.execute(
115 "INSERT INTO semantic_index_config (version) VALUES (?1)",
116 params![SEMANTIC_INDEX_VERSION],
117 )?;
118
119 self.db.execute(
120 "CREATE TABLE worktrees (
121 id INTEGER PRIMARY KEY AUTOINCREMENT,
122 absolute_path VARCHAR NOT NULL
123 );
124 CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
125 ",
126 [],
127 )?;
128
129 self.db.execute(
130 "CREATE TABLE files (
131 id INTEGER PRIMARY KEY AUTOINCREMENT,
132 worktree_id INTEGER NOT NULL,
133 relative_path VARCHAR NOT NULL,
134 mtime_seconds INTEGER NOT NULL,
135 mtime_nanos INTEGER NOT NULL,
136 FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
137 )",
138 [],
139 )?;
140
141 self.db.execute(
142 "CREATE TABLE documents (
143 id INTEGER PRIMARY KEY AUTOINCREMENT,
144 file_id INTEGER NOT NULL,
145 start_byte INTEGER NOT NULL,
146 end_byte INTEGER NOT NULL,
147 name VARCHAR NOT NULL,
148 embedding BLOB NOT NULL,
149 sha1 BLOB NOT NULL,
150 FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
151 )",
152 [],
153 )?;
154
155 log::trace!("vector database initialized with updated schema.");
156 Ok(())
157 }
158
159 pub fn delete_file(&self, worktree_id: i64, delete_path: PathBuf) -> Result<()> {
160 self.db.execute(
161 "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
162 params![worktree_id, delete_path.to_str()],
163 )?;
164 Ok(())
165 }
166
167 pub fn insert_file(
168 &self,
169 worktree_id: i64,
170 path: PathBuf,
171 mtime: SystemTime,
172 documents: Vec<Document>,
173 ) -> Result<()> {
174 // Return the existing ID, if both the file and mtime match
175 let mtime = Timestamp::from(mtime);
176 let mut existing_id_query = self.db.prepare("SELECT id FROM files WHERE worktree_id = ?1 AND relative_path = ?2 AND mtime_seconds = ?3 AND mtime_nanos = ?4")?;
177 let existing_id = existing_id_query
178 .query_row(
179 params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
180 |row| Ok(row.get::<_, i64>(0)?),
181 )
182 .map_err(|err| anyhow!(err));
183 let file_id = if existing_id.is_ok() {
184 // If already exists, just return the existing id
185 existing_id.unwrap()
186 } else {
187 // Delete Existing Row
188 self.db.execute(
189 "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;",
190 params![worktree_id, path.to_str()],
191 )?;
192 self.db.execute("INSERT INTO files (worktree_id, relative_path, mtime_seconds, mtime_nanos) VALUES (?1, ?2, ?3, ?4);", params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos])?;
193 self.db.last_insert_rowid()
194 };
195
196 // Currently inserting at approximately 3400 documents a second
197 // I imagine we can speed this up with a bulk insert of some kind.
198 for document in documents {
199 let embedding_blob = bincode::serialize(&document.embedding)?;
200 let sha_blob = bincode::serialize(&document.sha1)?;
201
202 self.db.execute(
203 "INSERT INTO documents (file_id, start_byte, end_byte, name, embedding, sha1) VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
204 params![
205 file_id,
206 document.range.start.to_string(),
207 document.range.end.to_string(),
208 document.name,
209 embedding_blob,
210 sha_blob
211 ],
212 )?;
213 }
214
215 Ok(())
216 }
217
218 pub fn worktree_previously_indexed(&self, worktree_root_path: &Path) -> Result<bool> {
219 let mut worktree_query = self
220 .db
221 .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
222 let worktree_id = worktree_query
223 .query_row(params![worktree_root_path.to_string_lossy()], |row| {
224 Ok(row.get::<_, i64>(0)?)
225 })
226 .map_err(|err| anyhow!(err));
227
228 if worktree_id.is_ok() {
229 return Ok(true);
230 } else {
231 return Ok(false);
232 }
233 }
234
235 pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
236 // Check that the absolute path doesnt exist
237 let mut worktree_query = self
238 .db
239 .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
240
241 let worktree_id = worktree_query
242 .query_row(params![worktree_root_path.to_string_lossy()], |row| {
243 Ok(row.get::<_, i64>(0)?)
244 })
245 .map_err(|err| anyhow!(err));
246
247 if worktree_id.is_ok() {
248 return worktree_id;
249 }
250
251 // If worktree_id is Err, insert new worktree
252 self.db.execute(
253 "
254 INSERT into worktrees (absolute_path) VALUES (?1)
255 ",
256 params![worktree_root_path.to_string_lossy()],
257 )?;
258 Ok(self.db.last_insert_rowid())
259 }
260
261 pub fn get_file_mtimes(&self, worktree_id: i64) -> Result<HashMap<PathBuf, SystemTime>> {
262 let mut statement = self.db.prepare(
263 "
264 SELECT relative_path, mtime_seconds, mtime_nanos
265 FROM files
266 WHERE worktree_id = ?1
267 ORDER BY relative_path",
268 )?;
269 let mut result: HashMap<PathBuf, SystemTime> = HashMap::new();
270 for row in statement.query_map(params![worktree_id], |row| {
271 Ok((
272 row.get::<_, String>(0)?.into(),
273 Timestamp {
274 seconds: row.get(1)?,
275 nanos: row.get(2)?,
276 }
277 .into(),
278 ))
279 })? {
280 let row = row?;
281 result.insert(row.0, row.1);
282 }
283 Ok(result)
284 }
285
286 pub fn top_k_search(
287 &self,
288 query_embedding: &Vec<f32>,
289 limit: usize,
290 file_ids: &[i64],
291 ) -> Result<Vec<(i64, f32)>> {
292 let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
293 self.for_each_document(file_ids, |id, embedding| {
294 let similarity = dot(&embedding, &query_embedding);
295 let ix = match results
296 .binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal))
297 {
298 Ok(ix) => ix,
299 Err(ix) => ix,
300 };
301 results.insert(ix, (id, similarity));
302 results.truncate(limit);
303 })?;
304
305 Ok(results)
306 }
307
308 pub fn retrieve_included_file_ids(
309 &self,
310 worktree_ids: &[i64],
311 includes: &[PathMatcher],
312 excludes: &[PathMatcher],
313 ) -> Result<Vec<i64>> {
314 let mut file_query = self.db.prepare(
315 "
316 SELECT
317 id, relative_path
318 FROM
319 files
320 WHERE
321 worktree_id IN rarray(?)
322 ",
323 )?;
324
325 let mut file_ids = Vec::<i64>::new();
326 let mut rows = file_query.query([ids_to_sql(worktree_ids)])?;
327
328 while let Some(row) = rows.next()? {
329 let file_id = row.get(0)?;
330 let relative_path = row.get_ref(1)?.as_str()?;
331 let included =
332 includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path));
333 let excluded = excludes.iter().any(|glob| glob.is_match(relative_path));
334 if included && !excluded {
335 file_ids.push(file_id);
336 }
337 }
338
339 Ok(file_ids)
340 }
341
342 fn for_each_document(&self, file_ids: &[i64], mut f: impl FnMut(i64, Vec<f32>)) -> Result<()> {
343 let mut query_statement = self.db.prepare(
344 "
345 SELECT
346 id, embedding
347 FROM
348 documents
349 WHERE
350 file_id IN rarray(?)
351 ",
352 )?;
353
354 query_statement
355 .query_map(params![ids_to_sql(&file_ids)], |row| {
356 Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
357 })?
358 .filter_map(|row| row.ok())
359 .for_each(|(id, embedding)| f(id, embedding.0));
360 Ok(())
361 }
362
363 pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
364 let mut statement = self.db.prepare(
365 "
366 SELECT
367 documents.id,
368 files.worktree_id,
369 files.relative_path,
370 documents.start_byte,
371 documents.end_byte
372 FROM
373 documents, files
374 WHERE
375 documents.file_id = files.id AND
376 documents.id in rarray(?)
377 ",
378 )?;
379
380 let result_iter = statement.query_map(params![ids_to_sql(ids)], |row| {
381 Ok((
382 row.get::<_, i64>(0)?,
383 row.get::<_, i64>(1)?,
384 row.get::<_, String>(2)?.into(),
385 row.get(3)?..row.get(4)?,
386 ))
387 })?;
388
389 let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
390 for row in result_iter {
391 let (id, worktree_id, path, range) = row?;
392 values_by_id.insert(id, (worktree_id, path, range));
393 }
394
395 let mut results = Vec::with_capacity(ids.len());
396 for id in ids {
397 let value = values_by_id
398 .remove(id)
399 .ok_or(anyhow!("missing document id {}", id))?;
400 results.push(value);
401 }
402
403 Ok(results)
404 }
405}
406
407fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
408 Rc::new(
409 ids.iter()
410 .copied()
411 .map(|v| rusqlite::types::Value::from(v))
412 .collect::<Vec<_>>(),
413 )
414}
415
/// Dot product of two equal-length f32 slices, computed via
/// `matrixmultiply::sgemm` as a (1 x len) * (len x 1) matrix product
/// written into a single scalar.
///
/// # Panics
/// Panics if the slices differ in length.
pub(crate) fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    let len = vec_a.len();
    // Equal lengths are required for the pointer arithmetic below to stay
    // in bounds.
    assert_eq!(len, vec_b.len());

    let mut result = 0.0;
    // SAFETY: `sgemm` reads `m * k = len` elements from `vec_a` and
    // `k * n = len` elements from `vec_b` with the strides given below, and
    // writes exactly one f32 (`m * n = 1`) to `result`. Both slices are
    // asserted to be `len` long above, and `result` is a valid, exclusive
    // out-pointer for one f32.
    unsafe {
        matrixmultiply::sgemm(
            // m, k, n: dimensions of the (1 x len) * (len x 1) product.
            1,
            len,
            1,
            // alpha: scale factor applied to a * b.
            1.0,
            // a with row stride `len`, column stride 1 (one row of `len`).
            vec_a.as_ptr(),
            len as isize,
            1,
            // b with row stride 1, column stride `len` (one column of `len`).
            vec_b.as_ptr(),
            1,
            len as isize,
            // beta = 0.0: the previous contents of `result` are ignored.
            0.0,
            // c: the single output scalar, strides 1/1.
            &mut result as *mut f32,
            1,
            1,
        );
    }
    result
}