mod db;
mod embedding;
mod modal;
mod parsing;
mod vector_store_settings;

#[cfg(test)]
mod vector_store_tests;

use crate::vector_store_settings::VectorStoreSettings;
use anyhow::{anyhow, Result};
use db::VectorDatabase;
use embedding::{EmbeddingProvider, OpenAIEmbeddings};
use futures::{channel::oneshot, Future};
use gpui::{
    AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, ViewContext,
    WeakModelHandle,
};
use language::{Language, LanguageRegistry};
use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
use parsing::{CodeContextRetriever, Document};
use project::{Fs, PathChange, Project, ProjectEntryId, WorktreeId};
use smol::channel;
use std::{
    collections::HashMap,
    ops::Range,
    path::{Path, PathBuf},
    sync::Arc,
    time::{Duration, Instant, SystemTime},
};
use util::{
    channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
    http::HttpClient,
    paths::EMBEDDINGS_DIR,
    ResultExt,
};
use workspace::{Workspace, WorkspaceCreated};

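// `VECTOR_STORE_VERSION` stamps the store's on-disk format; the intent (an
// assumption here, since any version check lives in the `db` module) is to
// invalidate stale databases when the format changes. `EMBEDDINGS_BATCH_SIZE`
// caps how many parsed documents accumulate before a batch embedding request.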
const VECTOR_STORE_VERSION: usize = 1;
const EMBEDDINGS_BATCH_SIZE: usize = 150;

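/// Wires semantic search into the app: registers settings and the `Toggle`
/// action, then (outside the Stable release channel, and only when the feature
/// is enabled) creates the global `VectorStore` and indexes each local project
/// as its workspace is created.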
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    settings::register::<VectorStoreSettings>(cx);

    let db_file_path = EMBEDDINGS_DIR
        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
        .join("embeddings_db");

    SemanticSearch::init(cx);
    cx.add_action(
        |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
            if cx.has_global::<ModelHandle<VectorStore>>() {
                let vector_store = cx.global::<ModelHandle<VectorStore>>().clone();
                workspace.toggle_modal(cx, |workspace, cx| {
                    let project = workspace.project().clone();
                    let workspace = cx.weak_handle();
                    cx.add_view(|cx| {
                        SemanticSearch::new(
                            SemanticSearchDelegate::new(workspace, project, vector_store),
                            cx,
                        )
                    })
                });
            }
        },
    );

    if *RELEASE_CHANNEL == ReleaseChannel::Stable
        || !settings::get::<VectorStoreSettings>(cx).enabled
    {
        return;
    }

    cx.spawn(move |mut cx| async move {
        let vector_store = VectorStore::new(
            fs,
            db_file_path,
            Arc::new(embedding::DummyEmbeddings {}),
            // Arc::new(OpenAIEmbeddings {
            //     client: http_client,
            //     executor: cx.background(),
            // }),
            language_registry,
            cx.clone(),
        )
        .await?;

        cx.update(|cx| {
            cx.set_global(vector_store.clone());
            cx.subscribe_global::<WorkspaceCreated, _>({
                let vector_store = vector_store.clone();
                move |event, cx| {
                    let workspace = &event.0;
                    if let Some(workspace) = workspace.upgrade(cx) {
                        let project = workspace.read(cx).project().clone();
                        if project.read(cx).is_local() {
                            vector_store.update(cx, |store, cx| {
                                store.add_project(project, cx).detach();
                            });
                        }
                    }
                }
            })
            .detach();
        });

        anyhow::Ok(())
    })
    .detach();
}

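/// Maintains a per-project semantic index. Work flows through four background
/// stages connected by channels:
///
///   parsing_files_tx -> batch_files_tx -> embed_batch_tx -> db_update_tx
///
/// Files are parsed into `Document`s, batched, embedded via the
/// `EmbeddingProvider`, and finally written to the on-disk `VectorDatabase`.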
pub struct VectorStore {
    fs: Arc<dyn Fs>,
    database_url: Arc<PathBuf>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    db_update_tx: channel::Sender<DbOperation>,
    parsing_files_tx: channel::Sender<PendingFile>,
    _db_update_task: Task<()>,
    _embed_batch_task: Task<()>,
    _batch_files_task: Task<()>,
    _parsing_files_tasks: Vec<Task<()>>,
    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
}

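/// Indexing state for a single registered project: the mapping from worktrees
/// to their database ids, plus files that have changed on disk and are waiting
/// out the reindexing delay.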
struct ProjectState {
    worktree_db_ids: Vec<(WorktreeId, i64)>,
    pending_files: HashMap<PathBuf, (PendingFile, SystemTime)>,
    _subscription: gpui::Subscription,
}

impl ProjectState {
    fn db_id_for_worktree_id(&self, id: WorktreeId) -> Option<i64> {
        self.worktree_db_ids
            .iter()
            .find_map(|(worktree_id, db_id)| {
                if *worktree_id == id {
                    Some(*db_id)
                } else {
                    None
                }
            })
    }

    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
        self.worktree_db_ids
            .iter()
            .find_map(|(worktree_id, db_id)| {
                if *db_id == id {
                    Some(*worktree_id)
                } else {
                    None
                }
            })
    }

    fn update_pending_files(&mut self, pending_file: PendingFile, indexing_time: SystemTime) {
        // If the pending file already exists, replace it with the new one,
        // but keep the old indexing time.
        if let Some(old_file) = self.pending_files.remove(&pending_file.relative_path) {
            self.pending_files.insert(
                pending_file.relative_path.clone(),
                (pending_file, old_file.1),
            );
        } else {
            self.pending_files.insert(
                pending_file.relative_path.clone(),
                (pending_file, indexing_time),
            );
        }
    }

    fn get_outstanding_files(&mut self) -> Vec<PendingFile> {
        let mut outstanding_files = vec![];
        let mut remove_keys = vec![];
        let now = SystemTime::now();
        for (key, (pending_file, index_time)) in self.pending_files.iter() {
            if *index_time <= now {
                outstanding_files.push(pending_file.clone());
                remove_keys.push(key.clone());
            }
        }

        for key in remove_keys.iter() {
            self.pending_files.remove(key);
        }

        outstanding_files
    }
}

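/// A file that has been scheduled for parsing and embedding.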
#[derive(Clone, Debug)]
pub struct PendingFile {
    worktree_db_id: i64,
    relative_path: PathBuf,
    absolute_path: PathBuf,
    language: Arc<Language>,
    modified_time: SystemTime,
}

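/// A single search hit: the worktree it came from, the document's name, and
/// the byte range of the match within the file.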
#[derive(Debug, Clone)]
pub struct SearchResult {
    pub worktree_id: WorktreeId,
    pub name: String,
    pub byte_range: Range<usize>,
    pub file_path: PathBuf,
}

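/// Requests handled one at a time by `_db_update_task`, so that all reads and
/// writes against the `VectorDatabase` happen on a single background task.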
enum DbOperation {
    InsertFile {
        worktree_id: i64,
        documents: Vec<Document>,
        path: PathBuf,
        mtime: SystemTime,
    },
    Delete {
        worktree_id: i64,
        path: PathBuf,
    },
    FindOrCreateWorktree {
        path: PathBuf,
        sender: oneshot::Sender<Result<i64>>,
    },
    FileMTimes {
        worktree_id: i64,
        sender: oneshot::Sender<Result<HashMap<PathBuf, SystemTime>>>,
    },
}

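/// Input to the batching task: either queue a parsed file's documents for
/// embedding, or flush whatever has accumulated so far.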
enum EmbeddingJob {
    Enqueue {
        worktree_id: i64,
        path: PathBuf,
        mtime: SystemTime,
        documents: Vec<Document>,
    },
    Flush,
}

impl VectorStore {
    async fn new(
        fs: Arc<dyn Fs>,
        database_url: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
        mut cx: AsyncAppContext,
    ) -> Result<ModelHandle<Self>> {
        let database_url = Arc::new(database_url);

        let db = cx
            .background()
            .spawn(VectorDatabase::new(fs.clone(), database_url.clone()))
            .await?;

        Ok(cx.add_model(|cx| {
            // The indexing pipeline, stage by stage:
            //
            //   parsing_files_tx -> batch_files_tx -> embed_batch_tx -> db_update_tx

            // db_update_tx/rx: apply updates to the database
            let (db_update_tx, db_update_rx) = channel::unbounded();
            let _db_update_task = cx.background().spawn(async move {
                while let Ok(job) = db_update_rx.recv().await {
                    match job {
                        DbOperation::InsertFile {
                            worktree_id,
                            documents,
                            path,
                            mtime,
                        } => {
                            db.insert_file(worktree_id, path, mtime, documents)
                                .log_err();
                        }
                        DbOperation::Delete { worktree_id, path } => {
                            db.delete_file(worktree_id, path).log_err();
                        }
                        DbOperation::FindOrCreateWorktree { path, sender } => {
                            let id = db.find_or_create_worktree(&path);
                            sender.send(id).ok();
                        }
                        DbOperation::FileMTimes {
                            worktree_id: worktree_db_id,
                            sender,
                        } => {
                            let file_mtimes = db.get_file_mtimes(worktree_db_id);
                            sender.send(file_mtimes).ok();
                        }
                    }
                }
            });

            // embed_batch_tx/rx: embed a batch of documents and send the
            // results to the database task
            let (embed_batch_tx, embed_batch_rx) =
                channel::unbounded::<Vec<(i64, Vec<Document>, PathBuf, SystemTime)>>();
            let _embed_batch_task = cx.background().spawn({
                let db_update_tx = db_update_tx.clone();
                let embedding_provider = embedding_provider.clone();
                async move {
                    while let Ok(mut embeddings_queue) = embed_batch_rx.recv().await {
                        // Flatten all queued files' documents into a single batch.
                        let mut batch_documents = vec![];
                        for (_, documents, _, _) in embeddings_queue.iter() {
                            batch_documents
                                .extend(documents.iter().map(|document| document.content.as_str()));
                        }

                        if let Ok(embeddings) =
                            embedding_provider.embed_batch(batch_documents).await
                        {
                            log::trace!(
                                "created {} embeddings for {} files",
                                embeddings.len(),
                                embeddings_queue.len(),
                            );

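                            // The provider returns embeddings in the same order as
                            // `batch_documents`, so walk the queue in lockstep: `i`
                            // indexes the file, `j` the document within that file.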
                            let mut i = 0;
                            let mut j = 0;

                            for embedding in embeddings.iter() {
                                while embeddings_queue[i].1.len() == j {
                                    i += 1;
                                    j = 0;
                                }

                                embeddings_queue[i].1[j].embedding = embedding.to_owned();
                                j += 1;
                            }

                            for (worktree_id, documents, path, mtime) in
                                embeddings_queue.into_iter()
                            {
                                for document in documents.iter() {
                                    // TODO: Update this so it doesn't panic
                                    assert!(
                                        !document.embedding.is_empty(),
                                        "document embedding not complete"
                                    );
                                }

                                db_update_tx
                                    .send(DbOperation::InsertFile {
                                        worktree_id,
                                        documents,
                                        path,
                                        mtime,
                                    })
                                    .await
                                    .unwrap();
                            }
                        }
                    }
                }
            });

            // batch_files_tx/rx: accumulate parsed documents until the batch is
            // large enough (or a flush is requested), then send it for embedding
            let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
            let _batch_files_task = cx.background().spawn(async move {
                let mut queue_len = 0;
                let mut embeddings_queue = vec![];

                while let Ok(job) = batch_files_rx.recv().await {
                    let should_flush = match job {
                        EmbeddingJob::Enqueue {
                            documents,
                            worktree_id,
                            path,
                            mtime,
                        } => {
                            queue_len += documents.len();
                            embeddings_queue.push((worktree_id, documents, path, mtime));
                            queue_len >= EMBEDDINGS_BATCH_SIZE
                        }
                        EmbeddingJob::Flush => true,
                    };

                    if should_flush {
                        embed_batch_tx.try_send(embeddings_queue).unwrap();
                        embeddings_queue = vec![];
                        queue_len = 0;
                    }
                }
            });

            // parsing_files_tx/rx: parse files into embeddable documents
            let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();

            let mut _parsing_files_tasks = Vec::new();
            for _ in 0..cx.background().num_cpus() {
                let fs = fs.clone();
                let parsing_files_rx = parsing_files_rx.clone();
                let batch_files_tx = batch_files_tx.clone();
                _parsing_files_tasks.push(cx.background().spawn(async move {
                    let mut retriever = CodeContextRetriever::new();
                    while let Ok(pending_file) = parsing_files_rx.recv().await {
                        if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err()
                        {
                            if let Some(documents) = retriever
                                .parse_file(
                                    &pending_file.relative_path,
                                    &content,
                                    pending_file.language,
                                )
                                .log_err()
                            {
                                log::trace!(
                                    "parsed path {:?}: {} documents",
                                    pending_file.relative_path,
                                    documents.len()
                                );

                                batch_files_tx
                                    .try_send(EmbeddingJob::Enqueue {
                                        worktree_id: pending_file.worktree_db_id,
                                        path: pending_file.relative_path,
                                        mtime: pending_file.modified_time,
                                        documents,
                                    })
                                    .unwrap();
                            }
                        }

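                        // When the parse queue drains, flush whatever partial batch
                        // has accumulated so the last documents aren't left waiting.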
                        if parsing_files_rx.is_empty() {
                            batch_files_tx.try_send(EmbeddingJob::Flush).unwrap();
                        }
                    }
                }));
            }

            Self {
                fs,
                database_url,
                embedding_provider,
                language_registry,
                db_update_tx,
                parsing_files_tx,
                _db_update_task,
                _embed_batch_task,
                _batch_files_task,
                _parsing_files_tasks,
                projects: HashMap::new(),
            }
        }))
    }

    fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
        let (tx, rx) = oneshot::channel();
        self.db_update_tx
            .try_send(DbOperation::FindOrCreateWorktree { path, sender: tx })
            .unwrap();
        async move { rx.await? }
    }

    fn get_file_mtimes(
        &self,
        worktree_id: i64,
    ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
        let (tx, rx) = oneshot::channel();
        self.db_update_tx
            .try_send(DbOperation::FileMTimes {
                worktree_id,
                sender: tx,
            })
            .unwrap();
        async move { rx.await? }
    }

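    /// Registers `project` with the store: waits for its worktree scans to
    /// finish, resolves (or creates) a database id for each worktree, queues
    /// every embeddable file whose mtime differs from what the database last
    /// saw, and subscribes to entry changes so later edits are reindexed.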
    fn add_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        let worktree_scans_complete = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
                async move {
                    scan_complete.await;
                }
            })
            .collect::<Vec<_>>();
        let worktree_db_ids = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
            })
            .collect::<Vec<_>>();

        let fs = self.fs.clone();
        let language_registry = self.language_registry.clone();
        let database_url = self.database_url.clone();
        let db_update_tx = self.db_update_tx.clone();
        let parsing_files_tx = self.parsing_files_tx.clone();

        cx.spawn(|this, mut cx| async move {
            futures::future::join_all(worktree_scans_complete).await;

            let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;

            if let Some(db_directory) = database_url.parent() {
                fs.create_dir(db_directory).await.log_err();
            }

            let worktrees = project.read_with(&cx, |project, cx| {
                project
                    .worktrees(cx)
                    .map(|worktree| worktree.read(cx).snapshot())
                    .collect::<Vec<_>>()
            });

            let mut worktree_file_times = HashMap::new();
            let mut db_ids_by_worktree_id = HashMap::new();
            for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
                let db_id = db_id?;
                db_ids_by_worktree_id.insert(worktree.id(), db_id);
                worktree_file_times.insert(
                    worktree.id(),
                    this.read_with(&cx, |this, _| this.get_file_mtimes(db_id))
                        .await?,
                );
            }

            cx.background()
                .spawn({
                    let db_ids_by_worktree_id = db_ids_by_worktree_id.clone();
                    let db_update_tx = db_update_tx.clone();
                    let language_registry = language_registry.clone();
                    let parsing_files_tx = parsing_files_tx.clone();
                    async move {
                        let t0 = Instant::now();
                        for worktree in worktrees.into_iter() {
                            let mut file_mtimes =
                                worktree_file_times.remove(&worktree.id()).unwrap();
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path);

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
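                                    // Only languages whose grammar defines an embedding
                                    // config can be indexed; skip everything else.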
                                    if language
                                        .grammar()
                                        .and_then(|grammar| grammar.embedding_config.as_ref())
                                        .is_none()
                                    {
                                        continue;
                                    }

                                    let path_buf = file.path.to_path_buf();
                                    let stored_mtime = file_mtimes.remove(&path_buf);
                                    let already_stored = stored_mtime
                                        .map_or(false, |existing_mtime| {
                                            existing_mtime == file.mtime
                                        });

                                    if !already_stored {
                                        log::trace!("sending for parsing: {:?}", path_buf);
                                        parsing_files_tx
                                            .try_send(PendingFile {
                                                worktree_db_id: db_ids_by_worktree_id
                                                    [&worktree.id()],
                                                relative_path: path_buf,
                                                absolute_path,
                                                language,
                                                modified_time: file.mtime,
                                            })
                                            .unwrap();
                                    }
                                }
                            }
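                            // Whatever is left in `file_mtimes` was in the database but
                            // not seen during the scan, so it was deleted on disk.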
                            for file in file_mtimes.keys() {
                                db_update_tx
                                    .try_send(DbOperation::Delete {
                                        worktree_id: db_ids_by_worktree_id[&worktree.id()],
                                        path: file.to_owned(),
                                    })
                                    .unwrap();
                            }
                        }
                        log::trace!(
                            "parsing worktree completed in {:?}",
                            t0.elapsed().as_millis()
                        );
                    }
                })
                .detach();

            this.update(&mut cx, |this, cx| {
                // Handle reindexing on save: each time entries in a worktree are
                // updated, the changed files are recorded as pending and sent off
                // for indexing once the configured `reindexing_delay_seconds` has
                // elapsed since their modification time.
                let _subscription = cx.subscribe(&project, |this, project, event, cx| {
                    if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event {
                        this.project_entries_changed(project, changes.clone(), cx, worktree_id);
                    }
                });

                this.projects.insert(
                    project.downgrade(),
                    ProjectState {
                        pending_files: HashMap::new(),
                        worktree_db_ids: db_ids_by_worktree_id.into_iter().collect(),
                        _subscription,
                    },
                );
            });

            anyhow::Ok(())
        })
    }

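    /// Embeds `phrase` and returns the `limit` nearest documents across all of
    /// the project's indexed worktrees.
    ///
    /// A minimal usage sketch (a hypothetical call site, assuming the store is
    /// registered as a global and `add_project` has already been called for
    /// this project):
    ///
    /// ```ignore
    /// let vector_store = cx.global::<ModelHandle<VectorStore>>().clone();
    /// let search = vector_store.update(cx, |store, cx| {
    ///     store.search(project.clone(), "parse a file into documents".into(), 10, cx)
    /// });
    /// // `search` is a `Task<Result<Vec<SearchResult>>>` that can be awaited.
    /// ```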
    pub fn search(
        &mut self,
        project: ModelHandle<Project>,
        phrase: String,
        limit: usize,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
            state
        } else {
            return Task::ready(Err(anyhow!("project not added")));
        };

        let worktree_db_ids = project
            .read(cx)
            .worktrees(cx)
            .filter_map(|worktree| {
                let worktree_id = worktree.read(cx).id();
                project_state.db_id_for_worktree_id(worktree_id)
            })
            .collect::<Vec<_>>();

        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();
        let fs = self.fs.clone();
        cx.spawn(|this, cx| async move {
            let documents = cx
                .background()
                .spawn(async move {
                    let database = VectorDatabase::new(fs, database_url).await?;

                    let phrase_embedding = embedding_provider
                        .embed_batch(vec![&phrase])
                        .await?
                        .into_iter()
                        .next()
                        .ok_or_else(|| anyhow!("embedding provider returned no results"))?;

                    database.top_k_search(&worktree_db_ids, &phrase_embedding, limit)
                })
                .await?;

            this.read_with(&cx, |this, _| {
                let project_state = if let Some(state) = this.projects.get(&project.downgrade()) {
                    state
                } else {
                    return Err(anyhow!("project not added"));
                };

                Ok(documents
                    .into_iter()
                    .filter_map(|(worktree_db_id, file_path, byte_range, name)| {
                        let worktree_id = project_state.worktree_id_for_db_id(worktree_db_id)?;
                        Some(SearchResult {
                            worktree_id,
                            name,
                            byte_range,
                            file_path,
                        })
                    })
                    .collect())
            })
        })
    }

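    /// Responds to worktree entry changes: removed files are deleted from the
    /// database immediately, while created or updated files are scheduled to
    /// be reindexed after `reindexing_delay_seconds`.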
    fn project_entries_changed(
        &mut self,
        project: ModelHandle<Project>,
        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
        cx: &mut ModelContext<'_, VectorStore>,
        worktree_id: &WorktreeId,
    ) -> Option<()> {
        let reindexing_delay = settings::get::<VectorStoreSettings>(cx).reindexing_delay_seconds;

        let worktree = project
            .read(cx)
            .worktree_for_id(worktree_id.clone(), cx)?
            .read(cx)
            .snapshot();

        let worktree_db_id = self
            .projects
            .get(&project.downgrade())?
            .db_id_for_worktree_id(worktree.id())?;
        let file_mtimes = self.get_file_mtimes(worktree_db_id);

        let language_registry = self.language_registry.clone();

        cx.spawn(|this, mut cx| async move {
            let file_mtimes = file_mtimes.await.log_err()?;

            for change in changes.into_iter() {
                let change_path = change.0.clone();
                let absolute_path = worktree.absolutize(&change_path);

                // Skip git-ignored, symlinked, and external entries.
                if let Some(entry) = worktree.entry_for_id(change.1) {
                    if entry.is_ignored || entry.is_symlink || entry.is_external {
                        continue;
                    }
                }

                match change.2 {
                    PathChange::Removed => this.update(&mut cx, |this, _| {
                        this.db_update_tx
                            .try_send(DbOperation::Delete {
                                worktree_id: worktree_db_id,
                                // The database stores worktree-relative paths.
                                path: change_path.to_path_buf(),
                            })
                            .unwrap();
                    }),
                    _ => {
                        if let Ok(language) = language_registry
                            .language_for_file(&change_path.to_path_buf(), None)
                            .await
                        {
                            if language
                                .grammar()
                                .and_then(|grammar| grammar.embedding_config.as_ref())
                                .is_none()
                            {
                                continue;
                            }

                            let modified_time =
                                absolute_path.metadata().log_err()?.modified().log_err()?;

                            let existing_time = file_mtimes.get(&change_path.to_path_buf());
                            let already_stored = existing_time
                                .map_or(false, |existing_time| &modified_time == existing_time);

                            if !already_stored {
                                this.update(&mut cx, |this, _| {
                                    let reindex_time = modified_time
                                        + Duration::from_secs(reindexing_delay as u64);

                                    let project_state =
                                        this.projects.get_mut(&project.downgrade())?;
                                    project_state.update_pending_files(
                                        PendingFile {
                                            relative_path: change_path.to_path_buf(),
                                            absolute_path,
                                            modified_time,
                                            worktree_db_id,
                                            language: language.clone(),
                                        },
                                        reindex_time,
                                    );

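                                    // Flush every pending file whose scheduled reindex
                                    // time has already passed, not just this one.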
                                    for file in project_state.get_outstanding_files() {
                                        this.parsing_files_tx.try_send(file).unwrap();
                                    }
                                    Some(())
                                });
                            }
                        }
                    }
                }
            }

            Some(())
        })
        .detach();

        Some(())
    }
}

impl Entity for VectorStore {
    type Event = ();
}