1use crate::{
2 parsing::{Span, SpanDigest},
3 SEMANTIC_INDEX_VERSION,
4};
5use ai::embedding::Embedding;
6use anyhow::{anyhow, Context, Result};
7use collections::HashMap;
8use futures::channel::oneshot;
9use gpui::BackgroundExecutor;
10use ndarray::{Array1, Array2};
11use ordered_float::OrderedFloat;
12use project::Fs;
13use rpc::proto::Timestamp;
14use rusqlite::params;
15use rusqlite::types::Value;
16use std::{
17 future::Future,
18 ops::Range,
19 path::{Path, PathBuf},
20 rc::Rc,
21 sync::Arc,
22 time::SystemTime,
23};
24use util::{paths::PathMatcher, TryFutureExt};
25
/// Returns the indices that would sort `data` in descending order.
///
/// Implemented as a stable ascending sort of the indices followed by a
/// reversal, so ties come out in reverse of their original order (this
/// matches the long-standing behavior callers such as `top_k_search` see).
pub fn argsort<T: Ord>(data: &[T]) -> Vec<usize> {
    let mut ascending: Vec<usize> = Vec::with_capacity(data.len());
    ascending.extend(0..data.len());
    // Stable sort by the referenced values, smallest first.
    ascending.sort_by_key(|&index| &data[index]);
    // Walk the ascending order backwards to produce descending order.
    ascending.into_iter().rev().collect()
}
32
/// A row from the `files` table: one indexed file within a worktree.
#[derive(Debug)]
pub struct FileRecord {
    /// Rowid of the file in the `files` table.
    pub id: usize,
    /// Path of the file relative to its worktree root.
    pub relative_path: String,
    /// Modification time recorded at indexing (stored as the
    /// `mtime_seconds` / `mtime_nanos` columns).
    pub mtime: Timestamp,
}
39
/// Cloneable handle to the semantic-index SQLite database.
///
/// All clones share a single background task that owns the actual
/// `rusqlite::Connection`; work is submitted as boxed closures over the
/// `transactions` channel and applied serially (see `VectorDatabase::new`).
#[derive(Clone)]
pub struct VectorDatabase {
    // Filesystem location of the SQLite database file.
    path: Arc<Path>,
    // Sender side of the channel feeding the connection-owning task.
    transactions:
        smol::channel::Sender<Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>>,
}
46
47impl VectorDatabase {
48 pub async fn new(
49 fs: Arc<dyn Fs>,
50 path: Arc<Path>,
51 executor: BackgroundExecutor,
52 ) -> Result<Self> {
53 if let Some(db_directory) = path.parent() {
54 fs.create_dir(db_directory).await?;
55 }
56
57 let (transactions_tx, transactions_rx) = smol::channel::unbounded::<
58 Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>,
59 >();
60 executor
61 .spawn({
62 let path = path.clone();
63 async move {
64 let mut connection = rusqlite::Connection::open(&path)?;
65
66 connection.pragma_update(None, "journal_mode", "wal")?;
67 connection.pragma_update(None, "synchronous", "normal")?;
68 connection.pragma_update(None, "cache_size", 1000000)?;
69 connection.pragma_update(None, "temp_store", "MEMORY")?;
70
71 while let Ok(transaction) = transactions_rx.recv().await {
72 transaction(&mut connection);
73 }
74
75 anyhow::Ok(())
76 }
77 .log_err()
78 })
79 .detach();
80 let this = Self {
81 transactions: transactions_tx,
82 path,
83 };
84 this.initialize_database().await?;
85 Ok(this)
86 }
87
88 pub fn path(&self) -> &Arc<Path> {
89 &self.path
90 }
91
92 fn transact<F, T>(&self, f: F) -> impl Future<Output = Result<T>>
93 where
94 F: 'static + Send + FnOnce(&rusqlite::Transaction) -> Result<T>,
95 T: 'static + Send,
96 {
97 let (tx, rx) = oneshot::channel();
98 let transactions = self.transactions.clone();
99 async move {
100 if transactions
101 .send(Box::new(|connection| {
102 let result = connection
103 .transaction()
104 .map_err(|err| anyhow!(err))
105 .and_then(|transaction| {
106 let result = f(&transaction)?;
107 transaction.commit()?;
108 Ok(result)
109 });
110 let _ = tx.send(result);
111 }))
112 .await
113 .is_err()
114 {
115 return Err(anyhow!("connection was dropped"))?;
116 }
117 rx.await?
118 }
119 }
120
121 fn initialize_database(&self) -> impl Future<Output = Result<()>> {
122 self.transact(|db| {
123 rusqlite::vtab::array::load_module(&db)?;
124
125 // Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped
126 let version_query = db.prepare("SELECT version from semantic_index_config");
127 let version = version_query
128 .and_then(|mut query| query.query_row([], |row| Ok(row.get::<_, i64>(0)?)));
129 if version.map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64) {
130 log::trace!("vector database schema up to date");
131 return Ok(());
132 }
133
134 log::trace!("vector database schema out of date. updating...");
135 // We renamed the `documents` table to `spans`, so we want to drop
136 // `documents` without recreating it if it exists.
137 db.execute("DROP TABLE IF EXISTS documents", [])
138 .context("failed to drop 'documents' table")?;
139 db.execute("DROP TABLE IF EXISTS spans", [])
140 .context("failed to drop 'spans' table")?;
141 db.execute("DROP TABLE IF EXISTS files", [])
142 .context("failed to drop 'files' table")?;
143 db.execute("DROP TABLE IF EXISTS worktrees", [])
144 .context("failed to drop 'worktrees' table")?;
145 db.execute("DROP TABLE IF EXISTS semantic_index_config", [])
146 .context("failed to drop 'semantic_index_config' table")?;
147
148 // Initialize Vector Databasing Tables
149 db.execute(
150 "CREATE TABLE semantic_index_config (
151 version INTEGER NOT NULL
152 )",
153 [],
154 )?;
155
156 db.execute(
157 "INSERT INTO semantic_index_config (version) VALUES (?1)",
158 params![SEMANTIC_INDEX_VERSION],
159 )?;
160
161 db.execute(
162 "CREATE TABLE worktrees (
163 id INTEGER PRIMARY KEY AUTOINCREMENT,
164 absolute_path VARCHAR NOT NULL
165 );
166 CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
167 ",
168 [],
169 )?;
170
171 db.execute(
172 "CREATE TABLE files (
173 id INTEGER PRIMARY KEY AUTOINCREMENT,
174 worktree_id INTEGER NOT NULL,
175 relative_path VARCHAR NOT NULL,
176 mtime_seconds INTEGER NOT NULL,
177 mtime_nanos INTEGER NOT NULL,
178 FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
179 )",
180 [],
181 )?;
182
183 db.execute(
184 "CREATE UNIQUE INDEX files_worktree_id_and_relative_path ON files (worktree_id, relative_path)",
185 [],
186 )?;
187
188 db.execute(
189 "CREATE TABLE spans (
190 id INTEGER PRIMARY KEY AUTOINCREMENT,
191 file_id INTEGER NOT NULL,
192 start_byte INTEGER NOT NULL,
193 end_byte INTEGER NOT NULL,
194 name VARCHAR NOT NULL,
195 embedding BLOB NOT NULL,
196 digest BLOB NOT NULL,
197 FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
198 )",
199 [],
200 )?;
201 db.execute(
202 "CREATE INDEX spans_digest ON spans (digest)",
203 [],
204 )?;
205
206 log::trace!("vector database initialized with updated schema.");
207 Ok(())
208 })
209 }
210
211 pub fn delete_file(
212 &self,
213 worktree_id: i64,
214 delete_path: Arc<Path>,
215 ) -> impl Future<Output = Result<()>> {
216 self.transact(move |db| {
217 db.execute(
218 "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
219 params![worktree_id, delete_path.to_str()],
220 )?;
221 Ok(())
222 })
223 }
224
225 pub fn insert_file(
226 &self,
227 worktree_id: i64,
228 path: Arc<Path>,
229 mtime: SystemTime,
230 spans: Vec<Span>,
231 ) -> impl Future<Output = Result<()>> {
232 self.transact(move |db| {
233 // Return the existing ID, if both the file and mtime match
234 let mtime = Timestamp::from(mtime);
235
236 db.execute(
237 "
238 REPLACE INTO files
239 (worktree_id, relative_path, mtime_seconds, mtime_nanos)
240 VALUES (?1, ?2, ?3, ?4)
241 ",
242 params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
243 )?;
244
245 let file_id = db.last_insert_rowid();
246
247 let mut query = db.prepare(
248 "
249 INSERT INTO spans
250 (file_id, start_byte, end_byte, name, embedding, digest)
251 VALUES (?1, ?2, ?3, ?4, ?5, ?6)
252 ",
253 )?;
254
255 for span in spans {
256 query.execute(params![
257 file_id,
258 span.range.start.to_string(),
259 span.range.end.to_string(),
260 span.name,
261 span.embedding,
262 span.digest
263 ])?;
264 }
265
266 Ok(())
267 })
268 }
269
270 pub fn worktree_previously_indexed(
271 &self,
272 worktree_root_path: &Path,
273 ) -> impl Future<Output = Result<bool>> {
274 let worktree_root_path = worktree_root_path.to_string_lossy().into_owned();
275 self.transact(move |db| {
276 let mut worktree_query =
277 db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
278 let worktree_id = worktree_query
279 .query_row(params![worktree_root_path], |row| Ok(row.get::<_, i64>(0)?));
280
281 if worktree_id.is_ok() {
282 return Ok(true);
283 } else {
284 return Ok(false);
285 }
286 })
287 }
288
289 pub fn embeddings_for_digests(
290 &self,
291 digests: Vec<SpanDigest>,
292 ) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
293 self.transact(move |db| {
294 let mut query = db.prepare(
295 "
296 SELECT digest, embedding
297 FROM spans
298 WHERE digest IN rarray(?)
299 ",
300 )?;
301 let mut embeddings_by_digest = HashMap::default();
302 let digests = Rc::new(
303 digests
304 .into_iter()
305 .map(|p| Value::Blob(p.0.to_vec()))
306 .collect::<Vec<_>>(),
307 );
308 let rows = query.query_map(params![digests], |row| {
309 Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
310 })?;
311
312 for row in rows {
313 if let Ok(row) = row {
314 embeddings_by_digest.insert(row.0, row.1);
315 }
316 }
317
318 Ok(embeddings_by_digest)
319 })
320 }
321
322 pub fn embeddings_for_files(
323 &self,
324 worktree_id_file_paths: HashMap<i64, Vec<Arc<Path>>>,
325 ) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
326 self.transact(move |db| {
327 let mut query = db.prepare(
328 "
329 SELECT digest, embedding
330 FROM spans
331 LEFT JOIN files ON files.id = spans.file_id
332 WHERE files.worktree_id = ? AND files.relative_path IN rarray(?)
333 ",
334 )?;
335 let mut embeddings_by_digest = HashMap::default();
336 for (worktree_id, file_paths) in worktree_id_file_paths {
337 let file_paths = Rc::new(
338 file_paths
339 .into_iter()
340 .map(|p| Value::Text(p.to_string_lossy().into_owned()))
341 .collect::<Vec<_>>(),
342 );
343 let rows = query.query_map(params![worktree_id, file_paths], |row| {
344 Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
345 })?;
346
347 for row in rows {
348 if let Ok(row) = row {
349 embeddings_by_digest.insert(row.0, row.1);
350 }
351 }
352 }
353
354 Ok(embeddings_by_digest)
355 })
356 }
357
358 pub fn find_or_create_worktree(
359 &self,
360 worktree_root_path: Arc<Path>,
361 ) -> impl Future<Output = Result<i64>> {
362 self.transact(move |db| {
363 let mut worktree_query =
364 db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
365 let worktree_id = worktree_query
366 .query_row(params![worktree_root_path.to_string_lossy()], |row| {
367 Ok(row.get::<_, i64>(0)?)
368 });
369
370 if worktree_id.is_ok() {
371 return Ok(worktree_id?);
372 }
373
374 // If worktree_id is Err, insert new worktree
375 db.execute(
376 "INSERT into worktrees (absolute_path) VALUES (?1)",
377 params![worktree_root_path.to_string_lossy()],
378 )?;
379 Ok(db.last_insert_rowid())
380 })
381 }
382
383 pub fn get_file_mtimes(
384 &self,
385 worktree_id: i64,
386 ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
387 self.transact(move |db| {
388 let mut statement = db.prepare(
389 "
390 SELECT relative_path, mtime_seconds, mtime_nanos
391 FROM files
392 WHERE worktree_id = ?1
393 ORDER BY relative_path",
394 )?;
395 let mut result: HashMap<PathBuf, SystemTime> = HashMap::default();
396 for row in statement.query_map(params![worktree_id], |row| {
397 Ok((
398 row.get::<_, String>(0)?.into(),
399 Timestamp {
400 seconds: row.get(1)?,
401 nanos: row.get(2)?,
402 }
403 .into(),
404 ))
405 })? {
406 let row = row?;
407 result.insert(row.0, row.1);
408 }
409 Ok(result)
410 })
411 }
412
413 pub fn top_k_search(
414 &self,
415 query_embedding: &Embedding,
416 limit: usize,
417 file_ids: &[i64],
418 ) -> impl Future<Output = Result<Vec<(i64, OrderedFloat<f32>)>>> {
419 let file_ids = file_ids.to_vec();
420 let query = query_embedding.clone().0;
421 let query = Array1::from_vec(query);
422 self.transact(move |db| {
423 let mut query_statement = db.prepare(
424 "
425 SELECT
426 id, embedding
427 FROM
428 spans
429 WHERE
430 file_id IN rarray(?)
431 ",
432 )?;
433
434 let deserialized_rows = query_statement
435 .query_map(params![ids_to_sql(&file_ids)], |row| {
436 Ok((row.get::<_, usize>(0)?, row.get::<_, Embedding>(1)?))
437 })?
438 .filter_map(|row| row.ok())
439 .collect::<Vec<(usize, Embedding)>>();
440
441 if deserialized_rows.len() == 0 {
442 return Ok(Vec::new());
443 }
444
445 // Get Length of Embeddings Returned
446 let embedding_len = deserialized_rows[0].1 .0.len();
447
448 let batch_n = 1000;
449 let mut batches = Vec::new();
450 let mut batch_ids = Vec::new();
451 let mut batch_embeddings: Vec<f32> = Vec::new();
452 deserialized_rows.iter().for_each(|(id, embedding)| {
453 batch_ids.push(id);
454 batch_embeddings.extend(&embedding.0);
455
456 if batch_ids.len() == batch_n {
457 let embeddings = std::mem::take(&mut batch_embeddings);
458 let ids = std::mem::take(&mut batch_ids);
459 let array = Array2::from_shape_vec((ids.len(), embedding_len), embeddings);
460 match array {
461 Ok(array) => {
462 batches.push((ids, array));
463 }
464 Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
465 }
466 }
467 });
468
469 if batch_ids.len() > 0 {
470 let array = Array2::from_shape_vec(
471 (batch_ids.len(), embedding_len),
472 batch_embeddings.clone(),
473 );
474 match array {
475 Ok(array) => {
476 batches.push((batch_ids.clone(), array));
477 }
478 Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
479 }
480 }
481
482 let mut ids: Vec<usize> = Vec::new();
483 let mut results = Vec::new();
484 for (batch_ids, array) in batches {
485 let scores = array
486 .dot(&query.t())
487 .to_vec()
488 .iter()
489 .map(|score| OrderedFloat(*score))
490 .collect::<Vec<OrderedFloat<f32>>>();
491 results.extend(scores);
492 ids.extend(batch_ids);
493 }
494
495 let sorted_idx = argsort(&results);
496 let mut sorted_results = Vec::new();
497 let last_idx = limit.min(sorted_idx.len());
498 for idx in &sorted_idx[0..last_idx] {
499 sorted_results.push((ids[*idx] as i64, results[*idx]))
500 }
501
502 Ok(sorted_results)
503 })
504 }
505
506 pub fn retrieve_included_file_ids(
507 &self,
508 worktree_ids: &[i64],
509 includes: &[PathMatcher],
510 excludes: &[PathMatcher],
511 ) -> impl Future<Output = Result<Vec<i64>>> {
512 let worktree_ids = worktree_ids.to_vec();
513 let includes = includes.to_vec();
514 let excludes = excludes.to_vec();
515 self.transact(move |db| {
516 let mut file_query = db.prepare(
517 "
518 SELECT
519 id, relative_path
520 FROM
521 files
522 WHERE
523 worktree_id IN rarray(?)
524 ",
525 )?;
526
527 let mut file_ids = Vec::<i64>::new();
528 let mut rows = file_query.query([ids_to_sql(&worktree_ids)])?;
529
530 while let Some(row) = rows.next()? {
531 let file_id = row.get(0)?;
532 let relative_path = row.get_ref(1)?.as_str()?;
533 let included =
534 includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path));
535 let excluded = excludes.iter().any(|glob| glob.is_match(relative_path));
536 if included && !excluded {
537 file_ids.push(file_id);
538 }
539 }
540
541 anyhow::Ok(file_ids)
542 })
543 }
544
545 pub fn spans_for_ids(
546 &self,
547 ids: &[i64],
548 ) -> impl Future<Output = Result<Vec<(i64, PathBuf, Range<usize>)>>> {
549 let ids = ids.to_vec();
550 self.transact(move |db| {
551 let mut statement = db.prepare(
552 "
553 SELECT
554 spans.id,
555 files.worktree_id,
556 files.relative_path,
557 spans.start_byte,
558 spans.end_byte
559 FROM
560 spans, files
561 WHERE
562 spans.file_id = files.id AND
563 spans.id in rarray(?)
564 ",
565 )?;
566
567 let result_iter = statement.query_map(params![ids_to_sql(&ids)], |row| {
568 Ok((
569 row.get::<_, i64>(0)?,
570 row.get::<_, i64>(1)?,
571 row.get::<_, String>(2)?.into(),
572 row.get(3)?..row.get(4)?,
573 ))
574 })?;
575
576 let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
577 for row in result_iter {
578 let (id, worktree_id, path, range) = row?;
579 values_by_id.insert(id, (worktree_id, path, range));
580 }
581
582 let mut results = Vec::with_capacity(ids.len());
583 for id in &ids {
584 let value = values_by_id
585 .remove(id)
586 .ok_or(anyhow!("missing span id {}", id))?;
587 results.push(value);
588 }
589
590 Ok(results)
591 })
592 }
593}
594
595fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
596 Rc::new(
597 ids.iter()
598 .copied()
599 .map(|v| rusqlite::types::Value::from(v))
600 .collect::<Vec<_>>(),
601 )
602}