1use crate::{
2 parsing::{Span, SpanDigest},
3 SEMANTIC_INDEX_VERSION,
4};
5use ai::embedding::Embedding;
6use anyhow::{anyhow, Context, Result};
7use collections::HashMap;
8use futures::channel::oneshot;
9use gpui::executor;
10use ndarray::{Array1, Array2};
11use ordered_float::OrderedFloat;
12use project::Fs;
13use rpc::proto::Timestamp;
14use rusqlite::params;
15use rusqlite::types::Value;
16use std::{
17 future::Future,
18 ops::Range,
19 path::{Path, PathBuf},
20 rc::Rc,
21 sync::Arc,
22 time::SystemTime,
23};
24use util::{paths::PathMatcher, TryFutureExt};
25
/// Returns the indices that would sort `data` in descending order.
///
/// A stable ascending sort is performed first and the index list is then
/// reversed, so equal elements appear in reverse index order.
pub fn argsort<T: Ord>(data: &[T]) -> Vec<usize> {
    let mut order: Vec<usize> = (0..data.len()).collect();
    order.sort_by_key(|&ix| &data[ix]);
    order.reverse();
    order
}
32
/// A row of the `files` table: one indexed file within a worktree.
#[derive(Debug)]
pub struct FileRecord {
    /// Database row id (`files.id`).
    pub id: usize,
    /// Path of the file relative to its worktree root.
    pub relative_path: String,
    /// Modification time recorded when the file was indexed.
    pub mtime: Timestamp,
}
39
/// Handle to the on-disk SQLite database that stores span embeddings.
///
/// Cloning is cheap: all clones funnel work through the same `transactions`
/// channel to a single background task that owns the SQLite connection.
#[derive(Clone)]
pub struct VectorDatabase {
    // Location of the SQLite database file on disk.
    path: Arc<Path>,
    // Sends boxed closures to the background task that owns the connection;
    // the connection itself never crosses threads.
    transactions:
        smol::channel::Sender<Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>>,
}
46
47impl VectorDatabase {
    /// Opens (or creates) the vector database at `path` and spawns a detached
    /// background task that owns the SQLite connection.
    ///
    /// All subsequent database work is shipped to that task through the
    /// `transactions` channel; the task exits once every sender is dropped.
    /// The schema is created/migrated before the handle is returned.
    pub async fn new(
        fs: Arc<dyn Fs>,
        path: Arc<Path>,
        executor: Arc<executor::Background>,
    ) -> Result<Self> {
        // Make sure the parent directory exists before SQLite creates the file.
        if let Some(db_directory) = path.parent() {
            fs.create_dir(db_directory).await?;
        }

        let (transactions_tx, transactions_rx) = smol::channel::unbounded::<
            Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>,
        >();
        executor
            .spawn({
                let path = path.clone();
                async move {
                    let mut connection = rusqlite::Connection::open(&path)?;

                    // WAL journaling with `synchronous=normal` favors write
                    // throughput; temp data and a large page cache stay in memory.
                    connection.pragma_update(None, "journal_mode", "wal")?;
                    connection.pragma_update(None, "synchronous", "normal")?;
                    connection.pragma_update(None, "cache_size", 1000000)?;
                    connection.pragma_update(None, "temp_store", "MEMORY")?;

                    // Apply transactions serially until all senders are dropped.
                    while let Ok(transaction) = transactions_rx.recv().await {
                        transaction(&mut connection);
                    }

                    anyhow::Ok(())
                }
                .log_err()
            })
            .detach();
        let this = Self {
            transactions: transactions_tx,
            path,
        };
        this.initialize_database().await?;
        Ok(this)
    }
87
    /// The filesystem path of the underlying SQLite database file.
    pub fn path(&self) -> &Arc<Path> {
        &self.path
    }
91
92 fn transact<F, T>(&self, f: F) -> impl Future<Output = Result<T>>
93 where
94 F: 'static + Send + FnOnce(&rusqlite::Transaction) -> Result<T>,
95 T: 'static + Send,
96 {
97 let (tx, rx) = oneshot::channel();
98 let transactions = self.transactions.clone();
99 async move {
100 if transactions
101 .send(Box::new(|connection| {
102 let result = connection
103 .transaction()
104 .map_err(|err| anyhow!(err))
105 .and_then(|transaction| {
106 let result = f(&transaction)?;
107 transaction.commit()?;
108 Ok(result)
109 });
110 let _ = tx.send(result);
111 }))
112 .await
113 .is_err()
114 {
115 return Err(anyhow!("connection was dropped"))?;
116 }
117 rx.await?
118 }
119 }
120
121 fn initialize_database(&self) -> impl Future<Output = Result<()>> {
122 self.transact(|db| {
123 rusqlite::vtab::array::load_module(&db)?;
124
125 // Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped
126 let version_query = db.prepare("SELECT version from semantic_index_config");
127 let version = version_query
128 .and_then(|mut query| query.query_row([], |row| Ok(row.get::<_, i64>(0)?)));
129 if version.map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64) {
130 log::trace!("vector database schema up to date");
131 return Ok(());
132 }
133
134 log::trace!("vector database schema out of date. updating...");
135 // We renamed the `documents` table to `spans`, so we want to drop
136 // `documents` without recreating it if it exists.
137 db.execute("DROP TABLE IF EXISTS documents", [])
138 .context("failed to drop 'documents' table")?;
139 db.execute("DROP TABLE IF EXISTS spans", [])
140 .context("failed to drop 'spans' table")?;
141 db.execute("DROP TABLE IF EXISTS files", [])
142 .context("failed to drop 'files' table")?;
143 db.execute("DROP TABLE IF EXISTS worktrees", [])
144 .context("failed to drop 'worktrees' table")?;
145 db.execute("DROP TABLE IF EXISTS semantic_index_config", [])
146 .context("failed to drop 'semantic_index_config' table")?;
147
148 // Initialize Vector Databasing Tables
149 db.execute(
150 "CREATE TABLE semantic_index_config (
151 version INTEGER NOT NULL
152 )",
153 [],
154 )?;
155
156 db.execute(
157 "INSERT INTO semantic_index_config (version) VALUES (?1)",
158 params![SEMANTIC_INDEX_VERSION],
159 )?;
160
161 db.execute(
162 "CREATE TABLE worktrees (
163 id INTEGER PRIMARY KEY AUTOINCREMENT,
164 absolute_path VARCHAR NOT NULL
165 );
166 CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
167 ",
168 [],
169 )?;
170
171 db.execute(
172 "CREATE TABLE files (
173 id INTEGER PRIMARY KEY AUTOINCREMENT,
174 worktree_id INTEGER NOT NULL,
175 relative_path VARCHAR NOT NULL,
176 mtime_seconds INTEGER NOT NULL,
177 mtime_nanos INTEGER NOT NULL,
178 FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
179 )",
180 [],
181 )?;
182
183 db.execute(
184 "CREATE UNIQUE INDEX files_worktree_id_and_relative_path ON files (worktree_id, relative_path)",
185 [],
186 )?;
187
188 db.execute(
189 "CREATE TABLE spans (
190 id INTEGER PRIMARY KEY AUTOINCREMENT,
191 file_id INTEGER NOT NULL,
192 start_byte INTEGER NOT NULL,
193 end_byte INTEGER NOT NULL,
194 name VARCHAR NOT NULL,
195 embedding BLOB NOT NULL,
196 digest BLOB NOT NULL,
197 FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
198 )",
199 [],
200 )?;
201 db.execute(
202 "CREATE INDEX spans_digest ON spans (digest)",
203 [],
204 )?;
205
206 log::trace!("vector database initialized with updated schema.");
207 Ok(())
208 })
209 }
210
211 pub fn delete_file(
212 &self,
213 worktree_id: i64,
214 delete_path: Arc<Path>,
215 ) -> impl Future<Output = Result<()>> {
216 self.transact(move |db| {
217 db.execute(
218 "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
219 params![worktree_id, delete_path.to_str()],
220 )?;
221 Ok(())
222 })
223 }
224
225 pub fn insert_file(
226 &self,
227 worktree_id: i64,
228 path: Arc<Path>,
229 mtime: SystemTime,
230 spans: Vec<Span>,
231 ) -> impl Future<Output = Result<()>> {
232 self.transact(move |db| {
233 // Return the existing ID, if both the file and mtime match
234 let mtime = Timestamp::from(mtime);
235
236 db.execute(
237 "
238 REPLACE INTO files
239 (worktree_id, relative_path, mtime_seconds, mtime_nanos)
240 VALUES (?1, ?2, ?3, ?4)
241 ",
242 params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
243 )?;
244
245 let file_id = db.last_insert_rowid();
246
247 let mut query = db.prepare(
248 "
249 INSERT INTO spans
250 (file_id, start_byte, end_byte, name, embedding, digest)
251 VALUES (?1, ?2, ?3, ?4, ?5, ?6)
252 ",
253 )?;
254
255 for span in spans {
256 query.execute(params![
257 file_id,
258 span.range.start.to_string(),
259 span.range.end.to_string(),
260 span.name,
261 span.embedding,
262 span.digest
263 ])?;
264 }
265
266 Ok(())
267 })
268 }
269
270 pub fn worktree_previously_indexed(
271 &self,
272 worktree_root_path: &Path,
273 ) -> impl Future<Output = Result<bool>> {
274 let worktree_root_path = worktree_root_path.to_string_lossy().into_owned();
275 self.transact(move |db| {
276 let mut worktree_query =
277 db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
278 let worktree_id = worktree_query
279 .query_row(params![worktree_root_path], |row| Ok(row.get::<_, i64>(0)?));
280
281 if worktree_id.is_ok() {
282 return Ok(true);
283 } else {
284 return Ok(false);
285 }
286 })
287 }
288
289 pub fn embeddings_for_digests(
290 &self,
291 digests: Vec<SpanDigest>,
292 ) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
293 self.transact(move |db| {
294 let mut query = db.prepare(
295 "
296 SELECT digest, embedding
297 FROM spans
298 WHERE digest IN rarray(?)
299 ",
300 )?;
301 let mut embeddings_by_digest = HashMap::default();
302 let digests = Rc::new(
303 digests
304 .into_iter()
305 .map(|p| Value::Blob(p.0.to_vec()))
306 .collect::<Vec<_>>(),
307 );
308 let rows = query.query_map(params![digests], |row| {
309 Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
310 })?;
311
312 for row in rows {
313 if let Ok(row) = row {
314 embeddings_by_digest.insert(row.0, row.1);
315 }
316 }
317
318 Ok(embeddings_by_digest)
319 })
320 }
321
322 pub fn embeddings_for_files(
323 &self,
324 worktree_id_file_paths: HashMap<i64, Vec<Arc<Path>>>,
325 ) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
326 self.transact(move |db| {
327 let mut query = db.prepare(
328 "
329 SELECT digest, embedding
330 FROM spans
331 LEFT JOIN files ON files.id = spans.file_id
332 WHERE files.worktree_id = ? AND files.relative_path IN rarray(?)
333 ",
334 )?;
335 let mut embeddings_by_digest = HashMap::default();
336 for (worktree_id, file_paths) in worktree_id_file_paths {
337 let file_paths = Rc::new(
338 file_paths
339 .into_iter()
340 .map(|p| Value::Text(p.to_string_lossy().into_owned()))
341 .collect::<Vec<_>>(),
342 );
343 let rows = query.query_map(params![worktree_id, file_paths], |row| {
344 Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
345 })?;
346
347 for row in rows {
348 if let Ok(row) = row {
349 embeddings_by_digest.insert(row.0, row.1);
350 }
351 }
352 }
353
354 Ok(embeddings_by_digest)
355 })
356 }
357
358 pub fn find_or_create_worktree(
359 &self,
360 worktree_root_path: Arc<Path>,
361 ) -> impl Future<Output = Result<i64>> {
362 self.transact(move |db| {
363 let mut worktree_query =
364 db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
365 let worktree_id = worktree_query
366 .query_row(params![worktree_root_path.to_string_lossy()], |row| {
367 Ok(row.get::<_, i64>(0)?)
368 });
369
370 if worktree_id.is_ok() {
371 return Ok(worktree_id?);
372 }
373
374 // If worktree_id is Err, insert new worktree
375 db.execute(
376 "INSERT into worktrees (absolute_path) VALUES (?1)",
377 params![worktree_root_path.to_string_lossy()],
378 )?;
379 Ok(db.last_insert_rowid())
380 })
381 }
382
383 pub fn get_file_mtimes(
384 &self,
385 worktree_id: i64,
386 ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
387 self.transact(move |db| {
388 let mut statement = db.prepare(
389 "
390 SELECT relative_path, mtime_seconds, mtime_nanos
391 FROM files
392 WHERE worktree_id = ?1
393 ORDER BY relative_path",
394 )?;
395 let mut result: HashMap<PathBuf, SystemTime> = HashMap::default();
396 for row in statement.query_map(params![worktree_id], |row| {
397 Ok((
398 row.get::<_, String>(0)?.into(),
399 Timestamp {
400 seconds: row.get(1)?,
401 nanos: row.get(2)?,
402 }
403 .into(),
404 ))
405 })? {
406 let row = row?;
407 result.insert(row.0, row.1);
408 }
409 Ok(result)
410 })
411 }
412
413 pub fn top_k_search(
414 &self,
415 query_embedding: &Embedding,
416 limit: usize,
417 file_ids: &[i64],
418 ) -> impl Future<Output = Result<Vec<(i64, OrderedFloat<f32>)>>> {
419 let file_ids = file_ids.to_vec();
420 let query = query_embedding.clone().0;
421 let query = Array1::from_vec(query);
422 self.transact(move |db| {
423 let mut query_statement = db.prepare(
424 "
425 SELECT
426 id, embedding
427 FROM
428 spans
429 WHERE
430 file_id IN rarray(?)
431 ",
432 )?;
433
434 let deserialized_rows = query_statement
435 .query_map(params![ids_to_sql(&file_ids)], |row| {
436 Ok((row.get::<_, usize>(0)?, row.get::<_, Embedding>(1)?))
437 })?
438 .filter_map(|row| row.ok())
439 .collect::<Vec<(usize, Embedding)>>();
440
441 if deserialized_rows.len() == 0 {
442 return Ok(Vec::new());
443 }
444
445 // Get Length of Embeddings Returned
446 let embedding_len = deserialized_rows[0].1 .0.len();
447
448 let batch_n = 1000;
449 let mut batches = Vec::new();
450 let mut batch_ids = Vec::new();
451 let mut batch_embeddings: Vec<f32> = Vec::new();
452 deserialized_rows.iter().for_each(|(id, embedding)| {
453 batch_ids.push(id);
454 batch_embeddings.extend(&embedding.0);
455
456 if batch_ids.len() == batch_n {
457 let embeddings = std::mem::take(&mut batch_embeddings);
458 let ids = std::mem::take(&mut batch_ids);
459 let array =
460 Array2::from_shape_vec((ids.len(), embedding_len.clone()), embeddings);
461 match array {
462 Ok(array) => {
463 batches.push((ids, array));
464 }
465 Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
466 }
467 }
468 });
469
470 if batch_ids.len() > 0 {
471 let array = Array2::from_shape_vec(
472 (batch_ids.len(), embedding_len),
473 batch_embeddings.clone(),
474 );
475 match array {
476 Ok(array) => {
477 batches.push((batch_ids.clone(), array));
478 }
479 Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
480 }
481 }
482
483 let mut ids: Vec<usize> = Vec::new();
484 let mut results = Vec::new();
485 for (batch_ids, array) in batches {
486 let scores = array
487 .dot(&query.t())
488 .to_vec()
489 .iter()
490 .map(|score| OrderedFloat(*score))
491 .collect::<Vec<OrderedFloat<f32>>>();
492 results.extend(scores);
493 ids.extend(batch_ids);
494 }
495
496 let sorted_idx = argsort(&results);
497 let mut sorted_results = Vec::new();
498 let last_idx = limit.min(sorted_idx.len());
499 for idx in &sorted_idx[0..last_idx] {
500 sorted_results.push((ids[*idx] as i64, results[*idx]))
501 }
502
503 Ok(sorted_results)
504 })
505 }
506
507 pub fn retrieve_included_file_ids(
508 &self,
509 worktree_ids: &[i64],
510 includes: &[PathMatcher],
511 excludes: &[PathMatcher],
512 ) -> impl Future<Output = Result<Vec<i64>>> {
513 let worktree_ids = worktree_ids.to_vec();
514 let includes = includes.to_vec();
515 let excludes = excludes.to_vec();
516 self.transact(move |db| {
517 let mut file_query = db.prepare(
518 "
519 SELECT
520 id, relative_path
521 FROM
522 files
523 WHERE
524 worktree_id IN rarray(?)
525 ",
526 )?;
527
528 let mut file_ids = Vec::<i64>::new();
529 let mut rows = file_query.query([ids_to_sql(&worktree_ids)])?;
530
531 while let Some(row) = rows.next()? {
532 let file_id = row.get(0)?;
533 let relative_path = row.get_ref(1)?.as_str()?;
534 let included =
535 includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path));
536 let excluded = excludes.iter().any(|glob| glob.is_match(relative_path));
537 if included && !excluded {
538 file_ids.push(file_id);
539 }
540 }
541
542 anyhow::Ok(file_ids)
543 })
544 }
545
546 pub fn spans_for_ids(
547 &self,
548 ids: &[i64],
549 ) -> impl Future<Output = Result<Vec<(i64, PathBuf, Range<usize>)>>> {
550 let ids = ids.to_vec();
551 self.transact(move |db| {
552 let mut statement = db.prepare(
553 "
554 SELECT
555 spans.id,
556 files.worktree_id,
557 files.relative_path,
558 spans.start_byte,
559 spans.end_byte
560 FROM
561 spans, files
562 WHERE
563 spans.file_id = files.id AND
564 spans.id in rarray(?)
565 ",
566 )?;
567
568 let result_iter = statement.query_map(params![ids_to_sql(&ids)], |row| {
569 Ok((
570 row.get::<_, i64>(0)?,
571 row.get::<_, i64>(1)?,
572 row.get::<_, String>(2)?.into(),
573 row.get(3)?..row.get(4)?,
574 ))
575 })?;
576
577 let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
578 for row in result_iter {
579 let (id, worktree_id, path, range) = row?;
580 values_by_id.insert(id, (worktree_id, path, range));
581 }
582
583 let mut results = Vec::with_capacity(ids.len());
584 for id in &ids {
585 let value = values_by_id
586 .remove(id)
587 .ok_or(anyhow!("missing span id {}", id))?;
588 results.push(value);
589 }
590
591 Ok(results)
592 })
593 }
594}
595
596fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
597 Rc::new(
598 ids.iter()
599 .copied()
600 .map(|v| rusqlite::types::Value::from(v))
601 .collect::<Vec<_>>(),
602 )
603}