1mod db;
2mod embedding;
3mod modal;
4mod parsing;
5mod vector_store_settings;
6
7#[cfg(test)]
8mod vector_store_tests;
9
10use crate::vector_store_settings::VectorStoreSettings;
11use anyhow::{anyhow, Result};
12use db::VectorDatabase;
13use embedding::{EmbeddingProvider, OpenAIEmbeddings};
14use futures::{channel::oneshot, Future};
15use gpui::{
16 AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, ViewContext,
17 WeakModelHandle,
18};
19use language::{Language, LanguageRegistry};
20use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
21use parking_lot::Mutex;
22use parsing::{CodeContextRetriever, Document};
23use project::{Fs, Project, WorktreeId};
24use smol::channel;
25use std::{
26 collections::{HashMap, HashSet},
27 ops::Range,
28 path::{Path, PathBuf},
29 sync::{
30 atomic::{self, AtomicUsize},
31 Arc, Weak,
32 },
33 time::{Instant, SystemTime},
34};
35use util::{
36 channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
37 http::HttpClient,
38 paths::EMBEDDINGS_DIR,
39 ResultExt,
40};
41use workspace::{Workspace, WorkspaceCreated};
42
/// Version number of the vector store.
// NOTE(review): not referenced anywhere in the visible part of this file; presumably a
// schema/format version consumed elsewhere — confirm before changing.
const VECTOR_STORE_VERSION: usize = 1;
/// Number of queued documents at which the batching task hands its queue to the embedding task.
const EMBEDDINGS_BATCH_SIZE: usize = 150;
45
46pub fn init(
47 fs: Arc<dyn Fs>,
48 http_client: Arc<dyn HttpClient>,
49 language_registry: Arc<LanguageRegistry>,
50 cx: &mut AppContext,
51) {
52 settings::register::<VectorStoreSettings>(cx);
53
54 let db_file_path = EMBEDDINGS_DIR
55 .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
56 .join("embeddings_db");
57
58 SemanticSearch::init(cx);
59 cx.add_action(
60 |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
61 if cx.has_global::<ModelHandle<VectorStore>>() {
62 let vector_store = cx.global::<ModelHandle<VectorStore>>().clone();
63 workspace.toggle_modal(cx, |workspace, cx| {
64 let project = workspace.project().clone();
65 let workspace = cx.weak_handle();
66 cx.add_view(|cx| {
67 SemanticSearch::new(
68 SemanticSearchDelegate::new(workspace, project, vector_store),
69 cx,
70 )
71 })
72 });
73 }
74 },
75 );
76
77 if *RELEASE_CHANNEL == ReleaseChannel::Stable
78 || !settings::get::<VectorStoreSettings>(cx).enabled
79 {
80 return;
81 }
82
83 cx.spawn(move |mut cx| async move {
84 let vector_store = VectorStore::new(
85 fs,
86 db_file_path,
87 Arc::new(OpenAIEmbeddings {
88 client: http_client,
89 executor: cx.background(),
90 }),
91 language_registry,
92 cx.clone(),
93 )
94 .await?;
95
96 cx.update(|cx| {
97 cx.set_global(vector_store.clone());
98 cx.subscribe_global::<WorkspaceCreated, _>({
99 let vector_store = vector_store.clone();
100 move |event, cx| {
101 let workspace = &event.0;
102 if let Some(workspace) = workspace.upgrade(cx) {
103 let project = workspace.read(cx).project().clone();
104 if project.read(cx).is_local() {
105 vector_store.update(cx, |store, cx| {
106 store.index_project(project, cx).detach();
107 });
108 }
109 }
110 }
111 })
112 .detach();
113 });
114
115 anyhow::Ok(())
116 })
117 .detach();
118}
119
/// Model that owns the indexing pipeline (parse -> batch -> embed -> store) and the
/// per-project indexing state used for semantic search.
pub struct VectorStore {
    /// File system used to load file contents and to open the database.
    fs: Arc<dyn Fs>,
    /// Location of the embeddings database, passed to `VectorDatabase::new`.
    database_url: Arc<PathBuf>,
    /// Produces embeddings for batches of document text and for search phrases.
    embedding_provider: Arc<dyn EmbeddingProvider>,
    /// Used to determine the language of each file that is considered for indexing.
    language_registry: Arc<LanguageRegistry>,
    /// Sends `DbOperation`s to the task that owns the database connection.
    db_update_tx: channel::Sender<DbOperation>,
    /// Sends files that need (re)parsing to the parsing workers.
    parsing_files_tx: channel::Sender<PendingFile>,
    // The task handles below are stored but never read.
    // NOTE(review): presumably held so the background tasks keep running for as long as the
    // store exists — confirm against gpui's `Task` drop semantics.
    _db_update_task: Task<()>,
    _embed_batch_task: Task<()>,
    _batch_files_task: Task<()>,
    _parsing_files_tasks: Vec<Task<()>>,
    /// Source of ids for indexing jobs; incremented once per file queued for parsing.
    next_job_id: Arc<AtomicUsize>,
    /// Indexing state of every project passed to `index_project`.
    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
}
134
/// Indexing state tracked for a single project.
struct ProjectState {
    /// Pairs of (in-memory worktree id, database id of that worktree).
    worktree_db_ids: Vec<(WorktreeId, i64)>,
    /// Ids of the indexing jobs of this project whose `JobHandle` has not been dropped yet.
    outstanding_jobs: Arc<Mutex<HashSet<JobId>>>,
}
139
/// Identifier of one indexing job (one file queued for parsing).
type JobId = usize;
141
/// Token that travels with a file through the indexing pipeline; when it is dropped, its id
/// is removed from the owning project's set of outstanding jobs (see the `Drop` impl).
struct JobHandle {
    /// Id of the job this handle stands for.
    id: JobId,
    /// Weak reference to the set of outstanding jobs the id was inserted into.
    set: Weak<Mutex<HashSet<JobId>>>,
}
146
147impl ProjectState {
148 fn db_id_for_worktree_id(&self, id: WorktreeId) -> Option<i64> {
149 self.worktree_db_ids
150 .iter()
151 .find_map(|(worktree_id, db_id)| {
152 if *worktree_id == id {
153 Some(*db_id)
154 } else {
155 None
156 }
157 })
158 }
159
160 fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
161 self.worktree_db_ids
162 .iter()
163 .find_map(|(worktree_id, db_id)| {
164 if *db_id == id {
165 Some(*worktree_id)
166 } else {
167 None
168 }
169 })
170 }
171}
172
/// A file that has been selected for (re)indexing and is waiting to be parsed.
pub struct PendingFile {
    /// Database id of the worktree the file belongs to.
    worktree_db_id: i64,
    /// Path of the file as taken from the worktree's file entry; used as the key in the database.
    relative_path: PathBuf,
    /// Absolute path, used to load the file's content.
    absolute_path: PathBuf,
    /// Language the file was detected as; passed to the parser.
    language: Arc<Language>,
    /// Modification time recorded alongside the file's documents.
    modified_time: SystemTime,
    /// Keeps the job listed as outstanding until the file leaves the pipeline.
    job_handle: JobHandle,
}
181
/// One hit returned by `VectorStore::search_project`.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Worktree the matching file belongs to.
    pub worktree_id: WorktreeId,
    /// Name stored with the matching document.
    // NOTE(review): presumably the name of the matched code item — confirm in `parsing`.
    pub name: String,
    /// Byte range stored with the matching document.
    // NOTE(review): presumably the document's extent within the file — confirm in `parsing`.
    pub byte_range: Range<usize>,
    /// Path of the matching file as stored in the database.
    pub file_path: PathBuf,
}
189
/// Requests handled, one at a time, by the task that owns the database connection.
enum DbOperation {
    /// Store the documents of a file; `job_handle` is dropped once the insert was attempted.
    InsertFile {
        worktree_id: i64,
        documents: Vec<Document>,
        path: PathBuf,
        mtime: SystemTime,
        job_handle: JobHandle,
    },
    /// Delete the stored data of the file at `path` in the given worktree.
    Delete {
        worktree_id: i64,
        path: PathBuf,
    },
    /// Look up (or create) the database id of the worktree at `path` and reply on `sender`.
    FindOrCreateWorktree {
        path: PathBuf,
        sender: oneshot::Sender<Result<i64>>,
    },
    /// Fetch the stored modification times of a worktree's files and reply on `sender`.
    FileMTimes {
        worktree_id: i64,
        sender: oneshot::Sender<Result<HashMap<PathBuf, SystemTime>>>,
    },
}
211
/// Messages consumed by the task that groups parsed files into embedding batches.
enum EmbeddingJob {
    /// Add a parsed file to the pending batch; the batch is sent on once it holds at least
    /// `EMBEDDINGS_BATCH_SIZE` documents.
    Enqueue {
        worktree_id: i64,
        path: PathBuf,
        mtime: SystemTime,
        documents: Vec<Document>,
        job_handle: JobHandle,
    },
    /// Send the pending batch on immediately, regardless of its size.
    Flush,
}
222
223impl VectorStore {
224 async fn new(
225 fs: Arc<dyn Fs>,
226 database_url: PathBuf,
227 embedding_provider: Arc<dyn EmbeddingProvider>,
228 language_registry: Arc<LanguageRegistry>,
229 mut cx: AsyncAppContext,
230 ) -> Result<ModelHandle<Self>> {
231 let database_url = Arc::new(database_url);
232
233 let db = cx
234 .background()
235 .spawn(VectorDatabase::new(fs.clone(), database_url.clone()))
236 .await?;
237
238 Ok(cx.add_model(|cx| {
239 // paths_tx -> embeddings_tx -> db_update_tx
240
241 //db_update_tx/rx: Updating Database
242 let (db_update_tx, db_update_rx) = channel::unbounded();
243 let _db_update_task = cx.background().spawn(async move {
244 while let Ok(job) = db_update_rx.recv().await {
245 match job {
246 DbOperation::InsertFile {
247 worktree_id,
248 documents,
249 path,
250 mtime,
251 job_handle,
252 } => {
253 db.insert_file(worktree_id, path, mtime, documents)
254 .log_err();
255 drop(job_handle)
256 }
257 DbOperation::Delete { worktree_id, path } => {
258 db.delete_file(worktree_id, path).log_err();
259 }
260 DbOperation::FindOrCreateWorktree { path, sender } => {
261 let id = db.find_or_create_worktree(&path);
262 sender.send(id).ok();
263 }
264 DbOperation::FileMTimes {
265 worktree_id: worktree_db_id,
266 sender,
267 } => {
268 let file_mtimes = db.get_file_mtimes(worktree_db_id);
269 sender.send(file_mtimes).ok();
270 }
271 }
272 }
273 });
274
275 // embed_tx/rx: Embed Batch and Send to Database
276 let (embed_batch_tx, embed_batch_rx) =
277 channel::unbounded::<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>();
278 let _embed_batch_task = cx.background().spawn({
279 let db_update_tx = db_update_tx.clone();
280 let embedding_provider = embedding_provider.clone();
281 async move {
282 while let Ok(mut embeddings_queue) = embed_batch_rx.recv().await {
283 // Construct Batch
284 let mut batch_documents = vec![];
285 for (_, documents, _, _, _) in embeddings_queue.iter() {
286 batch_documents
287 .extend(documents.iter().map(|document| document.content.as_str()));
288 }
289
290 if let Ok(embeddings) =
291 embedding_provider.embed_batch(batch_documents).await
292 {
293 log::trace!(
294 "created {} embeddings for {} files",
295 embeddings.len(),
296 embeddings_queue.len(),
297 );
298
299 let mut i = 0;
300 let mut j = 0;
301
302 for embedding in embeddings.iter() {
303 while embeddings_queue[i].1.len() == j {
304 i += 1;
305 j = 0;
306 }
307
308 embeddings_queue[i].1[j].embedding = embedding.to_owned();
309 j += 1;
310 }
311
312 for (worktree_id, documents, path, mtime, job_handle) in
313 embeddings_queue.into_iter()
314 {
315 for document in documents.iter() {
316 // TODO: Update this so it doesn't panic
317 assert!(
318 document.embedding.len() > 0,
319 "Document Embedding Not Complete"
320 );
321 }
322
323 db_update_tx
324 .send(DbOperation::InsertFile {
325 worktree_id,
326 documents,
327 path,
328 mtime,
329 job_handle,
330 })
331 .await
332 .unwrap();
333 }
334 }
335 }
336 }
337 });
338
339 // batch_tx/rx: Batch Files to Send for Embeddings
340 let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
341 let _batch_files_task = cx.background().spawn(async move {
342 let mut queue_len = 0;
343 let mut embeddings_queue = vec![];
344
345 while let Ok(job) = batch_files_rx.recv().await {
346 let should_flush = match job {
347 EmbeddingJob::Enqueue {
348 documents,
349 worktree_id,
350 path,
351 mtime,
352 job_handle,
353 } => {
354 queue_len += &documents.len();
355 embeddings_queue.push((
356 worktree_id,
357 documents,
358 path,
359 mtime,
360 job_handle,
361 ));
362 queue_len >= EMBEDDINGS_BATCH_SIZE
363 }
364 EmbeddingJob::Flush => true,
365 };
366
367 if should_flush {
368 embed_batch_tx.try_send(embeddings_queue).unwrap();
369 embeddings_queue = vec![];
370 queue_len = 0;
371 }
372 }
373 });
374
375 // parsing_files_tx/rx: Parsing Files to Embeddable Documents
376 let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();
377
378 let mut _parsing_files_tasks = Vec::new();
379 for _ in 0..cx.background().num_cpus() {
380 let fs = fs.clone();
381 let parsing_files_rx = parsing_files_rx.clone();
382 let batch_files_tx = batch_files_tx.clone();
383 _parsing_files_tasks.push(cx.background().spawn(async move {
384 let mut retriever = CodeContextRetriever::new();
385 while let Ok(pending_file) = parsing_files_rx.recv().await {
386 if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err()
387 {
388 if let Some(documents) = retriever
389 .parse_file(
390 &pending_file.relative_path,
391 &content,
392 pending_file.language,
393 )
394 .log_err()
395 {
396 log::trace!(
397 "parsed path {:?}: {} documents",
398 pending_file.relative_path,
399 documents.len()
400 );
401
402 batch_files_tx
403 .try_send(EmbeddingJob::Enqueue {
404 worktree_id: pending_file.worktree_db_id,
405 path: pending_file.relative_path,
406 mtime: pending_file.modified_time,
407 job_handle: pending_file.job_handle,
408 documents,
409 })
410 .unwrap();
411 }
412 }
413
414 if parsing_files_rx.len() == 0 {
415 batch_files_tx.try_send(EmbeddingJob::Flush).unwrap();
416 }
417 }
418 }));
419 }
420
421 Self {
422 fs,
423 database_url,
424 embedding_provider,
425 language_registry,
426 db_update_tx,
427 next_job_id: Default::default(),
428 parsing_files_tx,
429 _db_update_task,
430 _embed_batch_task,
431 _batch_files_task,
432 _parsing_files_tasks,
433 projects: HashMap::new(),
434 }
435 }))
436 }
437
438 fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
439 let (tx, rx) = oneshot::channel();
440 self.db_update_tx
441 .try_send(DbOperation::FindOrCreateWorktree { path, sender: tx })
442 .unwrap();
443 async move { rx.await? }
444 }
445
446 fn get_file_mtimes(
447 &self,
448 worktree_id: i64,
449 ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
450 let (tx, rx) = oneshot::channel();
451 self.db_update_tx
452 .try_send(DbOperation::FileMTimes {
453 worktree_id,
454 sender: tx,
455 })
456 .unwrap();
457 async move { rx.await? }
458 }
459
460 fn index_project(
461 &mut self,
462 project: ModelHandle<Project>,
463 cx: &mut ModelContext<Self>,
464 ) -> Task<Result<usize>> {
465 let worktree_scans_complete = project
466 .read(cx)
467 .worktrees(cx)
468 .map(|worktree| {
469 let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
470 async move {
471 scan_complete.await;
472 }
473 })
474 .collect::<Vec<_>>();
475 let worktree_db_ids = project
476 .read(cx)
477 .worktrees(cx)
478 .map(|worktree| {
479 self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
480 })
481 .collect::<Vec<_>>();
482
483 let language_registry = self.language_registry.clone();
484 let db_update_tx = self.db_update_tx.clone();
485 let parsing_files_tx = self.parsing_files_tx.clone();
486 let next_job_id = self.next_job_id.clone();
487
488 cx.spawn(|this, mut cx| async move {
489 futures::future::join_all(worktree_scans_complete).await;
490
491 let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;
492
493 let worktrees = project.read_with(&cx, |project, cx| {
494 project
495 .worktrees(cx)
496 .map(|worktree| worktree.read(cx).snapshot())
497 .collect::<Vec<_>>()
498 });
499
500 let mut worktree_file_mtimes = HashMap::new();
501 let mut db_ids_by_worktree_id = HashMap::new();
502 for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
503 let db_id = db_id?;
504 db_ids_by_worktree_id.insert(worktree.id(), db_id);
505 worktree_file_mtimes.insert(
506 worktree.id(),
507 this.read_with(&cx, |this, _| this.get_file_mtimes(db_id))
508 .await?,
509 );
510 }
511
512 // let mut pending_files: Vec<(PathBuf, ((i64, PathBuf, Arc<Language>, SystemTime), SystemTime))> = vec![];
513 let outstanding_jobs = Arc::new(Mutex::new(HashSet::new()));
514 this.update(&mut cx, |this, _| {
515 this.projects.insert(
516 project.downgrade(),
517 ProjectState {
518 worktree_db_ids: db_ids_by_worktree_id
519 .iter()
520 .map(|(a, b)| (*a, *b))
521 .collect(),
522 outstanding_jobs: outstanding_jobs.clone(),
523 },
524 );
525 });
526
527 cx.background()
528 .spawn(async move {
529 let mut count = 0;
530 let t0 = Instant::now();
531 for worktree in worktrees.into_iter() {
532 let mut file_mtimes = worktree_file_mtimes.remove(&worktree.id()).unwrap();
533 for file in worktree.files(false, 0) {
534 let absolute_path = worktree.absolutize(&file.path);
535
536 if let Ok(language) = language_registry
537 .language_for_file(&absolute_path, None)
538 .await
539 {
540 if language
541 .grammar()
542 .and_then(|grammar| grammar.embedding_config.as_ref())
543 .is_none()
544 {
545 continue;
546 }
547
548 let path_buf = file.path.to_path_buf();
549 let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
550 let already_stored = stored_mtime
551 .map_or(false, |existing_mtime| existing_mtime == file.mtime);
552
553 if !already_stored {
554 log::trace!("sending for parsing: {:?}", path_buf);
555 count += 1;
556 let job_id = next_job_id.fetch_add(1, atomic::Ordering::SeqCst);
557 let job_handle = JobHandle {
558 id: job_id,
559 set: Arc::downgrade(&outstanding_jobs),
560 };
561 outstanding_jobs.lock().insert(job_id);
562 parsing_files_tx
563 .try_send(PendingFile {
564 worktree_db_id: db_ids_by_worktree_id[&worktree.id()],
565 relative_path: path_buf,
566 absolute_path,
567 language,
568 job_handle,
569 modified_time: file.mtime,
570 })
571 .unwrap();
572 }
573 }
574 }
575 for file in file_mtimes.keys() {
576 db_update_tx
577 .try_send(DbOperation::Delete {
578 worktree_id: db_ids_by_worktree_id[&worktree.id()],
579 path: file.to_owned(),
580 })
581 .unwrap();
582 }
583 }
584 log::trace!(
585 "parsing worktree completed in {:?}",
586 t0.elapsed().as_millis()
587 );
588
589 Ok(count)
590 })
591 .await
592 })
593 }
594
595 pub fn remaining_files_to_index_for_project(
596 &self,
597 project: &ModelHandle<Project>,
598 ) -> Option<usize> {
599 Some(
600 self.projects
601 .get(&project.downgrade())?
602 .outstanding_jobs
603 .lock()
604 .len(),
605 )
606 }
607
608 pub fn search_project(
609 &mut self,
610 project: ModelHandle<Project>,
611 phrase: String,
612 limit: usize,
613 cx: &mut ModelContext<Self>,
614 ) -> Task<Result<Vec<SearchResult>>> {
615 let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
616 state
617 } else {
618 return Task::ready(Err(anyhow!("project not added")));
619 };
620
621 let worktree_db_ids = project
622 .read(cx)
623 .worktrees(cx)
624 .filter_map(|worktree| {
625 let worktree_id = worktree.read(cx).id();
626 project_state.db_id_for_worktree_id(worktree_id)
627 })
628 .collect::<Vec<_>>();
629
630 let embedding_provider = self.embedding_provider.clone();
631 let database_url = self.database_url.clone();
632 let fs = self.fs.clone();
633 cx.spawn(|this, cx| async move {
634 let documents = cx
635 .background()
636 .spawn(async move {
637 let database = VectorDatabase::new(fs, database_url).await?;
638
639 let phrase_embedding = embedding_provider
640 .embed_batch(vec![&phrase])
641 .await?
642 .into_iter()
643 .next()
644 .unwrap();
645
646 database.top_k_search(&worktree_db_ids, &phrase_embedding, limit)
647 })
648 .await?;
649
650 this.read_with(&cx, |this, _| {
651 let project_state = if let Some(state) = this.projects.get(&project.downgrade()) {
652 state
653 } else {
654 return Err(anyhow!("project not added"));
655 };
656
657 Ok(documents
658 .into_iter()
659 .filter_map(|(worktree_db_id, file_path, byte_range, name)| {
660 let worktree_id = project_state.worktree_id_for_db_id(worktree_db_id)?;
661 Some(SearchResult {
662 worktree_id,
663 name,
664 byte_range,
665 file_path,
666 })
667 })
668 .collect())
669 })
670 })
671 }
672}
673
/// Makes `VectorStore` usable as a gpui model.
impl Entity for VectorStore {
    /// The store emits no events of its own.
    type Event = ();
}
677
678impl Drop for JobHandle {
679 fn drop(&mut self) {
680 if let Some(set) = self.set.upgrade() {
681 set.lock().remove(&self.id);
682 }
683 }
684}