1extern crate blas_src;
2
3use crate::{
4 parsing::{Span, SpanDigest},
5 SEMANTIC_INDEX_VERSION,
6};
7use ai::embedding::Embedding;
8use anyhow::{anyhow, Context, Result};
9use collections::HashMap;
10use futures::channel::oneshot;
11use gpui::executor;
12use ndarray::{Array1, Array2};
13use ordered_float::OrderedFloat;
14use project::{search::PathMatcher, Fs};
15use rpc::proto::Timestamp;
16use rusqlite::params;
17use rusqlite::types::Value;
18use std::{
19 future::Future,
20 ops::Range,
21 path::{Path, PathBuf},
22 rc::Rc,
23 sync::Arc,
24 time::SystemTime,
25};
26use util::TryFutureExt;
27
/// Returns the indices that would sort `data` in *descending* order.
///
/// Uses a stable sort on `Reverse`d keys, so equal elements keep their
/// original relative index order (the previous sort-then-reverse approach
/// flipped tie order and paid an extra O(n) reversal pass).
pub fn argsort<T: Ord>(data: &[T]) -> Vec<usize> {
    let mut indices = (0..data.len()).collect::<Vec<_>>();
    indices.sort_by_key(|&i| std::cmp::Reverse(&data[i]));
    indices
}
34
/// A row of the `files` table: one indexed file within a worktree.
#[derive(Debug)]
pub struct FileRecord {
    // Database rowid of this file (`files.id`).
    pub id: usize,
    // Path relative to the owning worktree's root.
    pub relative_path: String,
    // Modification time captured when the file was last indexed, used to
    // detect staleness on re-index.
    pub mtime: Timestamp,
}
41
/// Handle to the on-disk SQLite vector database.
///
/// The actual `rusqlite::Connection` lives on a single background task (see
/// `VectorDatabase::new`); this handle is cheaply cloneable and submits work
/// as boxed closures over the `transactions` channel, which serializes all
/// database access.
#[derive(Clone)]
pub struct VectorDatabase {
    // Filesystem location of the database file.
    path: Arc<Path>,
    // Sends closures to the background task that owns the connection.
    transactions:
        smol::channel::Sender<Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>>,
}
48
49impl VectorDatabase {
50 pub async fn new(
51 fs: Arc<dyn Fs>,
52 path: Arc<Path>,
53 executor: Arc<executor::Background>,
54 ) -> Result<Self> {
55 if let Some(db_directory) = path.parent() {
56 fs.create_dir(db_directory).await?;
57 }
58
59 let (transactions_tx, transactions_rx) = smol::channel::unbounded::<
60 Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>,
61 >();
62 executor
63 .spawn({
64 let path = path.clone();
65 async move {
66 let mut connection = rusqlite::Connection::open(&path)?;
67
68 connection.pragma_update(None, "journal_mode", "wal")?;
69 connection.pragma_update(None, "synchronous", "normal")?;
70 connection.pragma_update(None, "cache_size", 1000000)?;
71 connection.pragma_update(None, "temp_store", "MEMORY")?;
72
73 while let Ok(transaction) = transactions_rx.recv().await {
74 transaction(&mut connection);
75 }
76
77 anyhow::Ok(())
78 }
79 .log_err()
80 })
81 .detach();
82 let this = Self {
83 transactions: transactions_tx,
84 path,
85 };
86 this.initialize_database().await?;
87 Ok(this)
88 }
89
90 pub fn path(&self) -> &Arc<Path> {
91 &self.path
92 }
93
94 fn transact<F, T>(&self, f: F) -> impl Future<Output = Result<T>>
95 where
96 F: 'static + Send + FnOnce(&rusqlite::Transaction) -> Result<T>,
97 T: 'static + Send,
98 {
99 let (tx, rx) = oneshot::channel();
100 let transactions = self.transactions.clone();
101 async move {
102 if transactions
103 .send(Box::new(|connection| {
104 let result = connection
105 .transaction()
106 .map_err(|err| anyhow!(err))
107 .and_then(|transaction| {
108 let result = f(&transaction)?;
109 transaction.commit()?;
110 Ok(result)
111 });
112 let _ = tx.send(result);
113 }))
114 .await
115 .is_err()
116 {
117 return Err(anyhow!("connection was dropped"))?;
118 }
119 rx.await?
120 }
121 }
122
123 fn initialize_database(&self) -> impl Future<Output = Result<()>> {
124 self.transact(|db| {
125 rusqlite::vtab::array::load_module(&db)?;
126
127 // Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped
128 let version_query = db.prepare("SELECT version from semantic_index_config");
129 let version = version_query
130 .and_then(|mut query| query.query_row([], |row| Ok(row.get::<_, i64>(0)?)));
131 if version.map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64) {
132 log::trace!("vector database schema up to date");
133 return Ok(());
134 }
135
136 log::trace!("vector database schema out of date. updating...");
137 // We renamed the `documents` table to `spans`, so we want to drop
138 // `documents` without recreating it if it exists.
139 db.execute("DROP TABLE IF EXISTS documents", [])
140 .context("failed to drop 'documents' table")?;
141 db.execute("DROP TABLE IF EXISTS spans", [])
142 .context("failed to drop 'spans' table")?;
143 db.execute("DROP TABLE IF EXISTS files", [])
144 .context("failed to drop 'files' table")?;
145 db.execute("DROP TABLE IF EXISTS worktrees", [])
146 .context("failed to drop 'worktrees' table")?;
147 db.execute("DROP TABLE IF EXISTS semantic_index_config", [])
148 .context("failed to drop 'semantic_index_config' table")?;
149
150 // Initialize Vector Databasing Tables
151 db.execute(
152 "CREATE TABLE semantic_index_config (
153 version INTEGER NOT NULL
154 )",
155 [],
156 )?;
157
158 db.execute(
159 "INSERT INTO semantic_index_config (version) VALUES (?1)",
160 params![SEMANTIC_INDEX_VERSION],
161 )?;
162
163 db.execute(
164 "CREATE TABLE worktrees (
165 id INTEGER PRIMARY KEY AUTOINCREMENT,
166 absolute_path VARCHAR NOT NULL
167 );
168 CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
169 ",
170 [],
171 )?;
172
173 db.execute(
174 "CREATE TABLE files (
175 id INTEGER PRIMARY KEY AUTOINCREMENT,
176 worktree_id INTEGER NOT NULL,
177 relative_path VARCHAR NOT NULL,
178 mtime_seconds INTEGER NOT NULL,
179 mtime_nanos INTEGER NOT NULL,
180 FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
181 )",
182 [],
183 )?;
184
185 db.execute(
186 "CREATE UNIQUE INDEX files_worktree_id_and_relative_path ON files (worktree_id, relative_path)",
187 [],
188 )?;
189
190 db.execute(
191 "CREATE TABLE spans (
192 id INTEGER PRIMARY KEY AUTOINCREMENT,
193 file_id INTEGER NOT NULL,
194 start_byte INTEGER NOT NULL,
195 end_byte INTEGER NOT NULL,
196 name VARCHAR NOT NULL,
197 embedding BLOB NOT NULL,
198 digest BLOB NOT NULL,
199 FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
200 )",
201 [],
202 )?;
203 db.execute(
204 "CREATE INDEX spans_digest ON spans (digest)",
205 [],
206 )?;
207
208 log::trace!("vector database initialized with updated schema.");
209 Ok(())
210 })
211 }
212
213 pub fn delete_file(
214 &self,
215 worktree_id: i64,
216 delete_path: Arc<Path>,
217 ) -> impl Future<Output = Result<()>> {
218 self.transact(move |db| {
219 db.execute(
220 "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
221 params![worktree_id, delete_path.to_str()],
222 )?;
223 Ok(())
224 })
225 }
226
227 pub fn insert_file(
228 &self,
229 worktree_id: i64,
230 path: Arc<Path>,
231 mtime: SystemTime,
232 spans: Vec<Span>,
233 ) -> impl Future<Output = Result<()>> {
234 self.transact(move |db| {
235 // Return the existing ID, if both the file and mtime match
236 let mtime = Timestamp::from(mtime);
237
238 db.execute(
239 "
240 REPLACE INTO files
241 (worktree_id, relative_path, mtime_seconds, mtime_nanos)
242 VALUES (?1, ?2, ?3, ?4)
243 ",
244 params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
245 )?;
246
247 let file_id = db.last_insert_rowid();
248
249 let mut query = db.prepare(
250 "
251 INSERT INTO spans
252 (file_id, start_byte, end_byte, name, embedding, digest)
253 VALUES (?1, ?2, ?3, ?4, ?5, ?6)
254 ",
255 )?;
256
257 for span in spans {
258 query.execute(params![
259 file_id,
260 span.range.start.to_string(),
261 span.range.end.to_string(),
262 span.name,
263 span.embedding,
264 span.digest
265 ])?;
266 }
267
268 Ok(())
269 })
270 }
271
272 pub fn worktree_previously_indexed(
273 &self,
274 worktree_root_path: &Path,
275 ) -> impl Future<Output = Result<bool>> {
276 let worktree_root_path = worktree_root_path.to_string_lossy().into_owned();
277 self.transact(move |db| {
278 let mut worktree_query =
279 db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
280 let worktree_id = worktree_query
281 .query_row(params![worktree_root_path], |row| Ok(row.get::<_, i64>(0)?));
282
283 if worktree_id.is_ok() {
284 return Ok(true);
285 } else {
286 return Ok(false);
287 }
288 })
289 }
290
291 pub fn embeddings_for_digests(
292 &self,
293 digests: Vec<SpanDigest>,
294 ) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
295 self.transact(move |db| {
296 let mut query = db.prepare(
297 "
298 SELECT digest, embedding
299 FROM spans
300 WHERE digest IN rarray(?)
301 ",
302 )?;
303 let mut embeddings_by_digest = HashMap::default();
304 let digests = Rc::new(
305 digests
306 .into_iter()
307 .map(|p| Value::Blob(p.0.to_vec()))
308 .collect::<Vec<_>>(),
309 );
310 let rows = query.query_map(params![digests], |row| {
311 Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
312 })?;
313
314 for row in rows {
315 if let Ok(row) = row {
316 embeddings_by_digest.insert(row.0, row.1);
317 }
318 }
319
320 Ok(embeddings_by_digest)
321 })
322 }
323
324 pub fn embeddings_for_files(
325 &self,
326 worktree_id_file_paths: HashMap<i64, Vec<Arc<Path>>>,
327 ) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
328 self.transact(move |db| {
329 let mut query = db.prepare(
330 "
331 SELECT digest, embedding
332 FROM spans
333 LEFT JOIN files ON files.id = spans.file_id
334 WHERE files.worktree_id = ? AND files.relative_path IN rarray(?)
335 ",
336 )?;
337 let mut embeddings_by_digest = HashMap::default();
338 for (worktree_id, file_paths) in worktree_id_file_paths {
339 let file_paths = Rc::new(
340 file_paths
341 .into_iter()
342 .map(|p| Value::Text(p.to_string_lossy().into_owned()))
343 .collect::<Vec<_>>(),
344 );
345 let rows = query.query_map(params![worktree_id, file_paths], |row| {
346 Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
347 })?;
348
349 for row in rows {
350 if let Ok(row) = row {
351 embeddings_by_digest.insert(row.0, row.1);
352 }
353 }
354 }
355
356 Ok(embeddings_by_digest)
357 })
358 }
359
360 pub fn find_or_create_worktree(
361 &self,
362 worktree_root_path: Arc<Path>,
363 ) -> impl Future<Output = Result<i64>> {
364 self.transact(move |db| {
365 let mut worktree_query =
366 db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
367 let worktree_id = worktree_query
368 .query_row(params![worktree_root_path.to_string_lossy()], |row| {
369 Ok(row.get::<_, i64>(0)?)
370 });
371
372 if worktree_id.is_ok() {
373 return Ok(worktree_id?);
374 }
375
376 // If worktree_id is Err, insert new worktree
377 db.execute(
378 "INSERT into worktrees (absolute_path) VALUES (?1)",
379 params![worktree_root_path.to_string_lossy()],
380 )?;
381 Ok(db.last_insert_rowid())
382 })
383 }
384
385 pub fn get_file_mtimes(
386 &self,
387 worktree_id: i64,
388 ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
389 self.transact(move |db| {
390 let mut statement = db.prepare(
391 "
392 SELECT relative_path, mtime_seconds, mtime_nanos
393 FROM files
394 WHERE worktree_id = ?1
395 ORDER BY relative_path",
396 )?;
397 let mut result: HashMap<PathBuf, SystemTime> = HashMap::default();
398 for row in statement.query_map(params![worktree_id], |row| {
399 Ok((
400 row.get::<_, String>(0)?.into(),
401 Timestamp {
402 seconds: row.get(1)?,
403 nanos: row.get(2)?,
404 }
405 .into(),
406 ))
407 })? {
408 let row = row?;
409 result.insert(row.0, row.1);
410 }
411 Ok(result)
412 })
413 }
414
415 pub fn top_k_search(
416 &self,
417 query_embedding: &Embedding,
418 limit: usize,
419 file_ids: &[i64],
420 ) -> impl Future<Output = Result<Vec<(i64, OrderedFloat<f32>)>>> {
421 let file_ids = file_ids.to_vec();
422 let query = query_embedding.clone().0;
423 let query = Array1::from_vec(query);
424 self.transact(move |db| {
425 let mut query_statement = db.prepare(
426 "
427 SELECT
428 id, embedding
429 FROM
430 spans
431 WHERE
432 file_id IN rarray(?)
433 ",
434 )?;
435
436 let deserialized_rows = query_statement
437 .query_map(params![ids_to_sql(&file_ids)], |row| {
438 Ok((row.get::<_, usize>(0)?, row.get::<_, Embedding>(1)?))
439 })?
440 .filter_map(|row| row.ok())
441 .collect::<Vec<(usize, Embedding)>>();
442
443 let batch_n = 250;
444 let mut batches = Vec::new();
445 let mut batch_ids = Vec::new();
446 let mut batch_embeddings: Vec<f32> = Vec::new();
447 deserialized_rows.iter().for_each(|(id, embedding)| {
448 batch_ids.push(id);
449 batch_embeddings.extend(&embedding.0);
450 if batch_ids.len() == batch_n {
451 let array =
452 Array2::from_shape_vec((batch_ids.len(), 1536), batch_embeddings.clone());
453 match array {
454 Ok(array) => {
455 batches.push((batch_ids.clone(), array));
456 }
457 Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
458 }
459
460 batch_ids = Vec::new();
461 batch_embeddings = Vec::new();
462 }
463 });
464
465 if batch_ids.len() > 0 {
466 let array =
467 Array2::from_shape_vec((batch_ids.len(), 1536), batch_embeddings.clone());
468 match array {
469 Ok(array) => {
470 batches.push((batch_ids.clone(), array));
471 }
472 Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
473 }
474 }
475
476 let mut ids: Vec<usize> = Vec::new();
477 let mut results = Vec::new();
478 for (batch_ids, array) in batches {
479 let scores = array
480 .dot(&query.t())
481 .to_vec()
482 .iter()
483 .map(|score| OrderedFloat(*score))
484 .collect::<Vec<OrderedFloat<f32>>>();
485 results.extend(scores);
486 ids.extend(batch_ids);
487 }
488
489 let sorted_idx = argsort(&results);
490 let mut sorted_results = Vec::new();
491 let last_idx = limit.min(sorted_idx.len());
492 for idx in &sorted_idx[0..last_idx] {
493 sorted_results.push((ids[*idx] as i64, results[*idx]))
494 }
495
496 Ok(sorted_results)
497 })
498 }
499
500 pub fn retrieve_included_file_ids(
501 &self,
502 worktree_ids: &[i64],
503 includes: &[PathMatcher],
504 excludes: &[PathMatcher],
505 ) -> impl Future<Output = Result<Vec<i64>>> {
506 let worktree_ids = worktree_ids.to_vec();
507 let includes = includes.to_vec();
508 let excludes = excludes.to_vec();
509 self.transact(move |db| {
510 let mut file_query = db.prepare(
511 "
512 SELECT
513 id, relative_path
514 FROM
515 files
516 WHERE
517 worktree_id IN rarray(?)
518 ",
519 )?;
520
521 let mut file_ids = Vec::<i64>::new();
522 let mut rows = file_query.query([ids_to_sql(&worktree_ids)])?;
523
524 while let Some(row) = rows.next()? {
525 let file_id = row.get(0)?;
526 let relative_path = row.get_ref(1)?.as_str()?;
527 let included =
528 includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path));
529 let excluded = excludes.iter().any(|glob| glob.is_match(relative_path));
530 if included && !excluded {
531 file_ids.push(file_id);
532 }
533 }
534
535 anyhow::Ok(file_ids)
536 })
537 }
538
539 pub fn spans_for_ids(
540 &self,
541 ids: &[i64],
542 ) -> impl Future<Output = Result<Vec<(i64, PathBuf, Range<usize>)>>> {
543 let ids = ids.to_vec();
544 self.transact(move |db| {
545 let mut statement = db.prepare(
546 "
547 SELECT
548 spans.id,
549 files.worktree_id,
550 files.relative_path,
551 spans.start_byte,
552 spans.end_byte
553 FROM
554 spans, files
555 WHERE
556 spans.file_id = files.id AND
557 spans.id in rarray(?)
558 ",
559 )?;
560
561 let result_iter = statement.query_map(params![ids_to_sql(&ids)], |row| {
562 Ok((
563 row.get::<_, i64>(0)?,
564 row.get::<_, i64>(1)?,
565 row.get::<_, String>(2)?.into(),
566 row.get(3)?..row.get(4)?,
567 ))
568 })?;
569
570 let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
571 for row in result_iter {
572 let (id, worktree_id, path, range) = row?;
573 values_by_id.insert(id, (worktree_id, path, range));
574 }
575
576 let mut results = Vec::with_capacity(ids.len());
577 for id in &ids {
578 let value = values_by_id
579 .remove(id)
580 .ok_or(anyhow!("missing span id {}", id))?;
581 results.push(value);
582 }
583
584 Ok(results)
585 })
586 }
587}
588
589fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
590 Rc::new(
591 ids.iter()
592 .copied()
593 .map(|v| rusqlite::types::Value::from(v))
594 .collect::<Vec<_>>(),
595 )
596}