1use crate::{
2 parsing::{Span, SpanDigest},
3 SEMANTIC_INDEX_VERSION,
4};
5use ai::embedding::Embedding;
6use anyhow::{anyhow, Context, Result};
7use collections::HashMap;
8use futures::channel::oneshot;
9use gpui::BackgroundExecutor;
10use ndarray::{Array1, Array2};
11use ordered_float::OrderedFloat;
12use project::Fs;
13use rpc::proto::Timestamp;
14use rusqlite::params;
15use rusqlite::types::Value;
16use std::{
17 future::Future,
18 ops::Range,
19 path::{Path, PathBuf},
20 rc::Rc,
21 sync::Arc,
22 time::SystemTime,
23};
24use util::{paths::PathMatcher, TryFutureExt};
25
/// Returns the indices that would sort `data` in descending order.
///
/// Implemented as a stable ascending argsort followed by a reversal, so
/// runs of equal elements come out in descending index order.
pub fn argsort<T: Ord>(data: &[T]) -> Vec<usize> {
    let mut order: Vec<usize> = (0..data.len()).collect();
    order.sort_by(|&a, &b| data[a].cmp(&data[b]));
    order.reverse();
    order
}
32
/// A row of the `files` table.
#[derive(Debug)]
pub struct FileRecord {
    /// SQLite rowid of the file.
    pub id: usize,
    /// Path of the file relative to its worktree root.
    pub relative_path: String,
    /// Modification time recorded when the file was last indexed
    /// (stored as `mtime_seconds` / `mtime_nanos` columns).
    pub mtime: Timestamp,
}
39
/// Cloneable handle to the on-disk SQLite vector database.
///
/// A single background task owns the actual connection; clones of this
/// handle enqueue work onto it through the `transactions` channel, so all
/// database access is serialized.
#[derive(Clone)]
pub struct VectorDatabase {
    /// Filesystem path of the SQLite database file.
    path: Arc<Path>,
    /// Sender feeding closures to the background task that owns the
    /// `rusqlite::Connection`.
    transactions:
        smol::channel::Sender<Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>>,
}
46
47impl VectorDatabase {
48 pub async fn new(
49 fs: Arc<dyn Fs>,
50 path: Arc<Path>,
51 executor: BackgroundExecutor,
52 ) -> Result<Self> {
53 if let Some(db_directory) = path.parent() {
54 fs.create_dir(db_directory).await?;
55 }
56
57 let (transactions_tx, transactions_rx) = smol::channel::unbounded::<
58 Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>,
59 >();
60 executor
61 .spawn({
62 let path = path.clone();
63 async move {
64 let mut connection = rusqlite::Connection::open(&path)?;
65
66 connection.pragma_update(None, "journal_mode", "wal")?;
67 connection.pragma_update(None, "synchronous", "normal")?;
68 connection.pragma_update(None, "cache_size", 1000000)?;
69 connection.pragma_update(None, "temp_store", "MEMORY")?;
70
71 while let Ok(transaction) = transactions_rx.recv().await {
72 transaction(&mut connection);
73 }
74
75 anyhow::Ok(())
76 }
77 .log_err()
78 })
79 .detach();
80 let this = Self {
81 transactions: transactions_tx,
82 path,
83 };
84 this.initialize_database().await?;
85 Ok(this)
86 }
87
88 pub fn path(&self) -> &Arc<Path> {
89 &self.path
90 }
91
92 fn transact<F, T>(&self, f: F) -> impl Future<Output = Result<T>>
93 where
94 F: 'static + Send + FnOnce(&rusqlite::Transaction) -> Result<T>,
95 T: 'static + Send,
96 {
97 let (tx, rx) = oneshot::channel();
98 let transactions = self.transactions.clone();
99 async move {
100 if transactions
101 .send(Box::new(|connection| {
102 let result = connection
103 .transaction()
104 .map_err(|err| anyhow!(err))
105 .and_then(|transaction| {
106 let result = f(&transaction)?;
107 transaction.commit()?;
108 Ok(result)
109 });
110 let _ = tx.send(result);
111 }))
112 .await
113 .is_err()
114 {
115 return Err(anyhow!("connection was dropped"))?;
116 }
117 rx.await?
118 }
119 }
120
121 fn initialize_database(&self) -> impl Future<Output = Result<()>> {
122 self.transact(|db| {
123 rusqlite::vtab::array::load_module(&db)?;
124
125 // Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped
126 let version_query = db.prepare("SELECT version from semantic_index_config");
127 let version = version_query
128 .and_then(|mut query| query.query_row([], |row| row.get::<_, i64>(0)));
129 if version.map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64) {
130 log::trace!("vector database schema up to date");
131 return Ok(());
132 }
133
134 log::trace!("vector database schema out of date. updating...");
135 // We renamed the `documents` table to `spans`, so we want to drop
136 // `documents` without recreating it if it exists.
137 db.execute("DROP TABLE IF EXISTS documents", [])
138 .context("failed to drop 'documents' table")?;
139 db.execute("DROP TABLE IF EXISTS spans", [])
140 .context("failed to drop 'spans' table")?;
141 db.execute("DROP TABLE IF EXISTS files", [])
142 .context("failed to drop 'files' table")?;
143 db.execute("DROP TABLE IF EXISTS worktrees", [])
144 .context("failed to drop 'worktrees' table")?;
145 db.execute("DROP TABLE IF EXISTS semantic_index_config", [])
146 .context("failed to drop 'semantic_index_config' table")?;
147
148 // Initialize Vector Databasing Tables
149 db.execute(
150 "CREATE TABLE semantic_index_config (
151 version INTEGER NOT NULL
152 )",
153 [],
154 )?;
155
156 db.execute(
157 "INSERT INTO semantic_index_config (version) VALUES (?1)",
158 params![SEMANTIC_INDEX_VERSION],
159 )?;
160
161 db.execute(
162 "CREATE TABLE worktrees (
163 id INTEGER PRIMARY KEY AUTOINCREMENT,
164 absolute_path VARCHAR NOT NULL
165 );
166 CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
167 ",
168 [],
169 )?;
170
171 db.execute(
172 "CREATE TABLE files (
173 id INTEGER PRIMARY KEY AUTOINCREMENT,
174 worktree_id INTEGER NOT NULL,
175 relative_path VARCHAR NOT NULL,
176 mtime_seconds INTEGER NOT NULL,
177 mtime_nanos INTEGER NOT NULL,
178 FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
179 )",
180 [],
181 )?;
182
183 db.execute(
184 "CREATE UNIQUE INDEX files_worktree_id_and_relative_path ON files (worktree_id, relative_path)",
185 [],
186 )?;
187
188 db.execute(
189 "CREATE TABLE spans (
190 id INTEGER PRIMARY KEY AUTOINCREMENT,
191 file_id INTEGER NOT NULL,
192 start_byte INTEGER NOT NULL,
193 end_byte INTEGER NOT NULL,
194 name VARCHAR NOT NULL,
195 embedding BLOB NOT NULL,
196 digest BLOB NOT NULL,
197 FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
198 )",
199 [],
200 )?;
201 db.execute(
202 "CREATE INDEX spans_digest ON spans (digest)",
203 [],
204 )?;
205
206 log::trace!("vector database initialized with updated schema.");
207 Ok(())
208 })
209 }
210
211 pub fn delete_file(
212 &self,
213 worktree_id: i64,
214 delete_path: Arc<Path>,
215 ) -> impl Future<Output = Result<()>> {
216 self.transact(move |db| {
217 db.execute(
218 "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
219 params![worktree_id, delete_path.to_str()],
220 )?;
221 Ok(())
222 })
223 }
224
225 pub fn insert_file(
226 &self,
227 worktree_id: i64,
228 path: Arc<Path>,
229 mtime: SystemTime,
230 spans: Vec<Span>,
231 ) -> impl Future<Output = Result<()>> {
232 self.transact(move |db| {
233 // Return the existing ID, if both the file and mtime match
234 let mtime = Timestamp::from(mtime);
235
236 db.execute(
237 "
238 REPLACE INTO files
239 (worktree_id, relative_path, mtime_seconds, mtime_nanos)
240 VALUES (?1, ?2, ?3, ?4)
241 ",
242 params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
243 )?;
244
245 let file_id = db.last_insert_rowid();
246
247 let mut query = db.prepare(
248 "
249 INSERT INTO spans
250 (file_id, start_byte, end_byte, name, embedding, digest)
251 VALUES (?1, ?2, ?3, ?4, ?5, ?6)
252 ",
253 )?;
254
255 for span in spans {
256 query.execute(params![
257 file_id,
258 span.range.start.to_string(),
259 span.range.end.to_string(),
260 span.name,
261 span.embedding,
262 span.digest
263 ])?;
264 }
265
266 Ok(())
267 })
268 }
269
270 pub fn worktree_previously_indexed(
271 &self,
272 worktree_root_path: &Path,
273 ) -> impl Future<Output = Result<bool>> {
274 let worktree_root_path = worktree_root_path.to_string_lossy().into_owned();
275 self.transact(move |db| {
276 let mut worktree_query =
277 db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
278 let worktree_id =
279 worktree_query.query_row(params![worktree_root_path], |row| row.get::<_, i64>(0));
280
281 Ok(worktree_id.is_ok())
282 })
283 }
284
285 pub fn embeddings_for_digests(
286 &self,
287 digests: Vec<SpanDigest>,
288 ) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
289 self.transact(move |db| {
290 let mut query = db.prepare(
291 "
292 SELECT digest, embedding
293 FROM spans
294 WHERE digest IN rarray(?)
295 ",
296 )?;
297 let mut embeddings_by_digest = HashMap::default();
298 let digests = Rc::new(
299 digests
300 .into_iter()
301 .map(|digest| Value::Blob(digest.0.to_vec()))
302 .collect::<Vec<_>>(),
303 );
304 let rows = query.query_map(params![digests], |row| {
305 Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
306 })?;
307
308 for (digest, embedding) in rows.flatten() {
309 embeddings_by_digest.insert(digest, embedding);
310 }
311
312 Ok(embeddings_by_digest)
313 })
314 }
315
316 pub fn embeddings_for_files(
317 &self,
318 worktree_id_file_paths: HashMap<i64, Vec<Arc<Path>>>,
319 ) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
320 self.transact(move |db| {
321 let mut query = db.prepare(
322 "
323 SELECT digest, embedding
324 FROM spans
325 LEFT JOIN files ON files.id = spans.file_id
326 WHERE files.worktree_id = ? AND files.relative_path IN rarray(?)
327 ",
328 )?;
329 let mut embeddings_by_digest = HashMap::default();
330 for (worktree_id, file_paths) in worktree_id_file_paths {
331 let file_paths = Rc::new(
332 file_paths
333 .into_iter()
334 .map(|p| Value::Text(p.to_string_lossy().into_owned()))
335 .collect::<Vec<_>>(),
336 );
337 let rows = query.query_map(params![worktree_id, file_paths], |row| {
338 Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
339 })?;
340
341 for (digest, embedding) in rows.flatten() {
342 embeddings_by_digest.insert(digest, embedding);
343 }
344 }
345
346 Ok(embeddings_by_digest)
347 })
348 }
349
350 pub fn find_or_create_worktree(
351 &self,
352 worktree_root_path: Arc<Path>,
353 ) -> impl Future<Output = Result<i64>> {
354 self.transact(move |db| {
355 let mut worktree_query =
356 db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
357 let worktree_id = worktree_query
358 .query_row(params![worktree_root_path.to_string_lossy()], |row| {
359 row.get::<_, i64>(0)
360 });
361
362 if worktree_id.is_ok() {
363 return Ok(worktree_id?);
364 }
365
366 // If worktree_id is Err, insert new worktree
367 db.execute(
368 "INSERT into worktrees (absolute_path) VALUES (?1)",
369 params![worktree_root_path.to_string_lossy()],
370 )?;
371 Ok(db.last_insert_rowid())
372 })
373 }
374
375 pub fn get_file_mtimes(
376 &self,
377 worktree_id: i64,
378 ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
379 self.transact(move |db| {
380 let mut statement = db.prepare(
381 "
382 SELECT relative_path, mtime_seconds, mtime_nanos
383 FROM files
384 WHERE worktree_id = ?1
385 ORDER BY relative_path",
386 )?;
387 let mut result: HashMap<PathBuf, SystemTime> = HashMap::default();
388 for row in statement.query_map(params![worktree_id], |row| {
389 Ok((
390 row.get::<_, String>(0)?.into(),
391 Timestamp {
392 seconds: row.get(1)?,
393 nanos: row.get(2)?,
394 }
395 .into(),
396 ))
397 })? {
398 let row = row?;
399 result.insert(row.0, row.1);
400 }
401 Ok(result)
402 })
403 }
404
405 pub fn top_k_search(
406 &self,
407 query_embedding: &Embedding,
408 limit: usize,
409 file_ids: &[i64],
410 ) -> impl Future<Output = Result<Vec<(i64, OrderedFloat<f32>)>>> {
411 let file_ids = file_ids.to_vec();
412 let query = query_embedding.clone().0;
413 let query = Array1::from_vec(query);
414 self.transact(move |db| {
415 let mut query_statement = db.prepare(
416 "
417 SELECT
418 id, embedding
419 FROM
420 spans
421 WHERE
422 file_id IN rarray(?)
423 ",
424 )?;
425
426 let deserialized_rows = query_statement
427 .query_map(params![ids_to_sql(&file_ids)], |row| {
428 Ok((row.get::<_, usize>(0)?, row.get::<_, Embedding>(1)?))
429 })?
430 .filter_map(|row| row.ok())
431 .collect::<Vec<(usize, Embedding)>>();
432
433 if deserialized_rows.len() == 0 {
434 return Ok(Vec::new());
435 }
436
437 // Get Length of Embeddings Returned
438 let embedding_len = deserialized_rows[0].1 .0.len();
439
440 let batch_n = 1000;
441 let mut batches = Vec::new();
442 let mut batch_ids = Vec::new();
443 let mut batch_embeddings: Vec<f32> = Vec::new();
444 deserialized_rows.iter().for_each(|(id, embedding)| {
445 batch_ids.push(id);
446 batch_embeddings.extend(&embedding.0);
447
448 if batch_ids.len() == batch_n {
449 let embeddings = std::mem::take(&mut batch_embeddings);
450 let ids = std::mem::take(&mut batch_ids);
451 let array = Array2::from_shape_vec((ids.len(), embedding_len), embeddings);
452 match array {
453 Ok(array) => {
454 batches.push((ids, array));
455 }
456 Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
457 }
458 }
459 });
460
461 if batch_ids.len() > 0 {
462 let array = Array2::from_shape_vec(
463 (batch_ids.len(), embedding_len),
464 batch_embeddings.clone(),
465 );
466 match array {
467 Ok(array) => {
468 batches.push((batch_ids.clone(), array));
469 }
470 Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
471 }
472 }
473
474 let mut ids: Vec<usize> = Vec::new();
475 let mut results = Vec::new();
476 for (batch_ids, array) in batches {
477 let scores = array
478 .dot(&query.t())
479 .to_vec()
480 .iter()
481 .map(|score| OrderedFloat(*score))
482 .collect::<Vec<OrderedFloat<f32>>>();
483 results.extend(scores);
484 ids.extend(batch_ids);
485 }
486
487 let sorted_idx = argsort(&results);
488 let mut sorted_results = Vec::new();
489 let last_idx = limit.min(sorted_idx.len());
490 for idx in &sorted_idx[0..last_idx] {
491 sorted_results.push((ids[*idx] as i64, results[*idx]))
492 }
493
494 Ok(sorted_results)
495 })
496 }
497
498 pub fn retrieve_included_file_ids(
499 &self,
500 worktree_ids: &[i64],
501 includes: &[PathMatcher],
502 excludes: &[PathMatcher],
503 ) -> impl Future<Output = Result<Vec<i64>>> {
504 let worktree_ids = worktree_ids.to_vec();
505 let includes = includes.to_vec();
506 let excludes = excludes.to_vec();
507 self.transact(move |db| {
508 let mut file_query = db.prepare(
509 "
510 SELECT
511 id, relative_path
512 FROM
513 files
514 WHERE
515 worktree_id IN rarray(?)
516 ",
517 )?;
518
519 let mut file_ids = Vec::<i64>::new();
520 let mut rows = file_query.query([ids_to_sql(&worktree_ids)])?;
521
522 while let Some(row) = rows.next()? {
523 let file_id = row.get(0)?;
524 let relative_path = row.get_ref(1)?.as_str()?;
525 let included =
526 includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path));
527 let excluded = excludes.iter().any(|glob| glob.is_match(relative_path));
528 if included && !excluded {
529 file_ids.push(file_id);
530 }
531 }
532
533 anyhow::Ok(file_ids)
534 })
535 }
536
537 pub fn spans_for_ids(
538 &self,
539 ids: &[i64],
540 ) -> impl Future<Output = Result<Vec<(i64, PathBuf, Range<usize>)>>> {
541 let ids = ids.to_vec();
542 self.transact(move |db| {
543 let mut statement = db.prepare(
544 "
545 SELECT
546 spans.id,
547 files.worktree_id,
548 files.relative_path,
549 spans.start_byte,
550 spans.end_byte
551 FROM
552 spans, files
553 WHERE
554 spans.file_id = files.id AND
555 spans.id in rarray(?)
556 ",
557 )?;
558
559 let result_iter = statement.query_map(params![ids_to_sql(&ids)], |row| {
560 Ok((
561 row.get::<_, i64>(0)?,
562 row.get::<_, i64>(1)?,
563 row.get::<_, String>(2)?.into(),
564 row.get(3)?..row.get(4)?,
565 ))
566 })?;
567
568 let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
569 for row in result_iter {
570 let (id, worktree_id, path, range) = row?;
571 values_by_id.insert(id, (worktree_id, path, range));
572 }
573
574 let mut results = Vec::with_capacity(ids.len());
575 for id in &ids {
576 let value = values_by_id
577 .remove(id)
578 .ok_or(anyhow!("missing span id {}", id))?;
579 results.push(value);
580 }
581
582 Ok(results)
583 })
584 }
585}
586
587fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
588 Rc::new(
589 ids.iter()
590 .copied()
591 .map(|v| rusqlite::types::Value::from(v))
592 .collect::<Vec<_>>(),
593 )
594}