1mod db;
2mod embedding;
3mod modal;
4mod parsing;
5mod vector_store_settings;
6
7#[cfg(test)]
8mod vector_store_tests;
9
10use crate::vector_store_settings::VectorStoreSettings;
11use anyhow::{anyhow, Result};
12use db::VectorDatabase;
13use embedding::{EmbeddingProvider, OpenAIEmbeddings};
14use futures::{channel::oneshot, Future};
15use gpui::{
16 AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, ViewContext,
17 WeakModelHandle,
18};
19use language::{Language, LanguageRegistry};
20use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
21use parsing::{CodeContextRetriever, Document};
22use project::{Fs, PathChange, Project, ProjectEntryId, WorktreeId};
23use smol::channel;
24use std::{
25 collections::HashMap,
26 ops::Range,
27 path::{Path, PathBuf},
28 sync::Arc,
29 time::{Duration, Instant, SystemTime},
30};
31use util::{
32 channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
33 http::HttpClient,
34 paths::EMBEDDINGS_DIR,
35 ResultExt,
36};
37use workspace::{Workspace, WorkspaceCreated};
38
// Version of the vector store — presumably bumped when the stored embedding
// format/schema changes; confirm against the `db` module before relying on it.
const VECTOR_STORE_VERSION: usize = 1;
// Number of parsed documents to accumulate before sending one batch to the
// embedding provider (see the batching task in `VectorStore::new`).
const EMBEDDINGS_BATCH_SIZE: usize = 150;
41
/// Initializes the vector store subsystem: registers its settings, wires up
/// the semantic-search modal action, and — unless running on the stable
/// release channel or disabled in settings — spawns the global `VectorStore`
/// model and subscribes it to newly created workspaces.
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    settings::register::<VectorStoreSettings>(cx);

    // The database lives under a per-release-channel directory so different
    // channels don't share one embeddings database.
    let db_file_path = EMBEDDINGS_DIR
        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
        .join("embeddings_db");

    SemanticSearch::init(cx);
    cx.add_action(
        |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
            // The modal can only be opened once the global vector store has
            // been created (which may be skipped by the gate below).
            if cx.has_global::<ModelHandle<VectorStore>>() {
                let vector_store = cx.global::<ModelHandle<VectorStore>>().clone();
                workspace.toggle_modal(cx, |workspace, cx| {
                    let project = workspace.project().clone();
                    let workspace = cx.weak_handle();
                    cx.add_view(|cx| {
                        SemanticSearch::new(
                            SemanticSearchDelegate::new(workspace, project, vector_store),
                            cx,
                        )
                    })
                });
            }
        },
    );

    // Feature gate: never enabled on stable, and respect the user setting.
    if *RELEASE_CHANNEL == ReleaseChannel::Stable
        || !settings::get::<VectorStoreSettings>(cx).enabled
    {
        return;
    }

    cx.spawn(move |mut cx| async move {
        let vector_store = VectorStore::new(
            fs,
            db_file_path,
            Arc::new(OpenAIEmbeddings {
                client: http_client,
                executor: cx.background(),
            }),
            language_registry,
            cx.clone(),
        )
        .await?;

        cx.update(|cx| {
            cx.set_global(vector_store.clone());
            // Index every local project as its workspace is created; remote
            // projects are skipped.
            cx.subscribe_global::<WorkspaceCreated, _>({
                let vector_store = vector_store.clone();
                move |event, cx| {
                    let workspace = &event.0;
                    if let Some(workspace) = workspace.upgrade(cx) {
                        let project = workspace.read(cx).project().clone();
                        if project.read(cx).is_local() {
                            vector_store.update(cx, |store, cx| {
                                store.add_project(project, cx).detach();
                            });
                        }
                    }
                }
            })
            .detach();
        });

        anyhow::Ok(())
    })
    .detach();
}
115
/// Coordinates indexing of project files into an embeddings database and
/// semantic search over it. Owns the background pipeline
/// `parsing_files -> batch_files -> embed_batch -> db_update`.
pub struct VectorStore {
    fs: Arc<dyn Fs>,
    // Path to the embeddings database file (a filesystem path despite the
    // `url` name).
    database_url: Arc<PathBuf>,
    // Produces embeddings for batches of document text (e.g. OpenAI).
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    // Entry point for database operations, drained by `_db_update_task`.
    db_update_tx: channel::Sender<DbOperation>,
    // Entry point of the indexing pipeline: files sent here are parsed,
    // batched, embedded, and written to the database.
    parsing_files_tx: channel::Sender<PendingFile>,
    // Background tasks are held here so they live as long as the store.
    _db_update_task: Task<()>,
    _embed_batch_task: Task<()>,
    _batch_files_task: Task<()>,
    _parsing_files_tasks: Vec<Task<()>>,
    // Per-project indexing state, keyed by weak handle so a dropped project
    // doesn't keep state alive.
    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
}
129
/// Indexing state tracked for one registered project.
struct ProjectState {
    // Mapping from in-memory worktree ids to their database row ids.
    worktree_db_ids: Vec<(WorktreeId, i64)>,
    // Files waiting to be reindexed, keyed by worktree-relative path, paired
    // with the time at which each becomes eligible for indexing.
    pending_files: HashMap<PathBuf, (PendingFile, SystemTime)>,
    // Keeps the project-event subscription alive for incremental updates.
    _subscription: gpui::Subscription,
}
135
136impl ProjectState {
137 fn db_id_for_worktree_id(&self, id: WorktreeId) -> Option<i64> {
138 self.worktree_db_ids
139 .iter()
140 .find_map(|(worktree_id, db_id)| {
141 if *worktree_id == id {
142 Some(*db_id)
143 } else {
144 None
145 }
146 })
147 }
148
149 fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
150 self.worktree_db_ids
151 .iter()
152 .find_map(|(worktree_id, db_id)| {
153 if *db_id == id {
154 Some(*worktree_id)
155 } else {
156 None
157 }
158 })
159 }
160
161 fn update_pending_files(&mut self, pending_file: PendingFile, indexing_time: SystemTime) {
162 // If Pending File Already Exists, Replace it with the new one
163 // but keep the old indexing time
164 if let Some(old_file) = self
165 .pending_files
166 .remove(&pending_file.relative_path.clone())
167 {
168 self.pending_files.insert(
169 pending_file.relative_path.clone(),
170 (pending_file, old_file.1),
171 );
172 } else {
173 self.pending_files.insert(
174 pending_file.relative_path.clone(),
175 (pending_file, indexing_time),
176 );
177 };
178 }
179
180 fn get_outstanding_files(&mut self) -> Vec<PendingFile> {
181 let mut outstanding_files = vec![];
182 let mut remove_keys = vec![];
183 for key in self.pending_files.keys().into_iter() {
184 if let Some(pending_details) = self.pending_files.get(key) {
185 let (pending_file, index_time) = pending_details;
186 if index_time <= &SystemTime::now() {
187 outstanding_files.push(pending_file.clone());
188 remove_keys.push(key.clone());
189 }
190 }
191 }
192
193 for key in remove_keys.iter() {
194 self.pending_files.remove(key);
195 }
196
197 return outstanding_files;
198 }
199}
200
/// A file queued for parsing and embedding.
#[derive(Clone, Debug)]
pub struct PendingFile {
    // Database id of the worktree containing this file.
    worktree_db_id: i64,
    // Path relative to the worktree root (the path stored in the database).
    relative_path: PathBuf,
    // Absolute filesystem path, used to load the file's contents.
    absolute_path: PathBuf,
    language: Arc<Language>,
    // File mtime captured when the file was queued.
    modified_time: SystemTime,
}
209
/// One semantic-search hit, locating a named document span within a worktree.
#[derive(Debug, Clone)]
pub struct SearchResult {
    pub worktree_id: WorktreeId,
    // Name of the matched document — presumably the symbol the span
    // represents; see `parsing::Document`.
    pub name: String,
    // Byte range of the matched span within the file.
    pub byte_range: Range<usize>,
    // Worktree-relative path of the file containing the match.
    pub file_path: PathBuf,
}
217
/// Requests processed serially by the database task spawned in
/// `VectorStore::new`, which owns the `VectorDatabase` connection.
enum DbOperation {
    /// Store a file's parsed documents (with embeddings) and its mtime.
    InsertFile {
        worktree_id: i64,
        documents: Vec<Document>,
        path: PathBuf,
        mtime: SystemTime,
    },
    /// Remove a file's rows from the database.
    Delete {
        worktree_id: i64,
        path: PathBuf,
    },
    /// Look up (or create) the database id for a worktree root path,
    /// replying on `sender`.
    FindOrCreateWorktree {
        path: PathBuf,
        sender: oneshot::Sender<Result<i64>>,
    },
    /// Fetch the stored mtime of every file in a worktree, replying on
    /// `sender`.
    FileMTimes {
        worktree_id: i64,
        sender: oneshot::Sender<Result<HashMap<PathBuf, SystemTime>>>,
    },
}
238
/// Messages for the batching task, which groups parsed files into batches of
/// roughly `EMBEDDINGS_BATCH_SIZE` documents for the embedding provider.
enum EmbeddingJob {
    /// Add one file's documents to the current batch.
    Enqueue {
        worktree_id: i64,
        path: PathBuf,
        mtime: SystemTime,
        documents: Vec<Document>,
    },
    /// Send whatever has accumulated, even if the batch is not yet full.
    Flush,
}
248
249impl VectorStore {
250 async fn new(
251 fs: Arc<dyn Fs>,
252 database_url: PathBuf,
253 embedding_provider: Arc<dyn EmbeddingProvider>,
254 language_registry: Arc<LanguageRegistry>,
255 mut cx: AsyncAppContext,
256 ) -> Result<ModelHandle<Self>> {
257 let database_url = Arc::new(database_url);
258
259 let db = cx
260 .background()
261 .spawn(VectorDatabase::new(fs.clone(), database_url.clone()))
262 .await?;
263
264 Ok(cx.add_model(|cx| {
265 // paths_tx -> embeddings_tx -> db_update_tx
266
267 //db_update_tx/rx: Updating Database
268 let (db_update_tx, db_update_rx) = channel::unbounded();
269 let _db_update_task = cx.background().spawn(async move {
270 while let Ok(job) = db_update_rx.recv().await {
271 match job {
272 DbOperation::InsertFile {
273 worktree_id,
274 documents,
275 path,
276 mtime,
277 } => {
278 db.insert_file(worktree_id, path, mtime, documents)
279 .log_err();
280 }
281 DbOperation::Delete { worktree_id, path } => {
282 db.delete_file(worktree_id, path).log_err();
283 }
284 DbOperation::FindOrCreateWorktree { path, sender } => {
285 let id = db.find_or_create_worktree(&path);
286 sender.send(id).ok();
287 }
288 DbOperation::FileMTimes {
289 worktree_id: worktree_db_id,
290 sender,
291 } => {
292 let file_mtimes = db.get_file_mtimes(worktree_db_id);
293 sender.send(file_mtimes).ok();
294 }
295 }
296 }
297 });
298
299 // embed_tx/rx: Embed Batch and Send to Database
300 let (embed_batch_tx, embed_batch_rx) =
301 channel::unbounded::<Vec<(i64, Vec<Document>, PathBuf, SystemTime)>>();
302 let _embed_batch_task = cx.background().spawn({
303 let db_update_tx = db_update_tx.clone();
304 let embedding_provider = embedding_provider.clone();
305 async move {
306 while let Ok(mut embeddings_queue) = embed_batch_rx.recv().await {
307 // Construct Batch
308 let mut batch_documents = vec![];
309 for (_, documents, _, _) in embeddings_queue.iter() {
310 batch_documents
311 .extend(documents.iter().map(|document| document.content.as_str()));
312 }
313
314 if let Ok(embeddings) =
315 embedding_provider.embed_batch(batch_documents).await
316 {
317 log::trace!(
318 "created {} embeddings for {} files",
319 embeddings.len(),
320 embeddings_queue.len(),
321 );
322
323 let mut i = 0;
324 let mut j = 0;
325
326 for embedding in embeddings.iter() {
327 while embeddings_queue[i].1.len() == j {
328 i += 1;
329 j = 0;
330 }
331
332 embeddings_queue[i].1[j].embedding = embedding.to_owned();
333 j += 1;
334 }
335
336 for (worktree_id, documents, path, mtime) in
337 embeddings_queue.into_iter()
338 {
339 for document in documents.iter() {
340 // TODO: Update this so it doesn't panic
341 assert!(
342 document.embedding.len() > 0,
343 "Document Embedding Not Complete"
344 );
345 }
346
347 db_update_tx
348 .send(DbOperation::InsertFile {
349 worktree_id,
350 documents,
351 path,
352 mtime,
353 })
354 .await
355 .unwrap();
356 }
357 }
358 }
359 }
360 });
361
362 // batch_tx/rx: Batch Files to Send for Embeddings
363 let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
364 let _batch_files_task = cx.background().spawn(async move {
365 let mut queue_len = 0;
366 let mut embeddings_queue = vec![];
367
368 while let Ok(job) = batch_files_rx.recv().await {
369 let should_flush = match job {
370 EmbeddingJob::Enqueue {
371 documents,
372 worktree_id,
373 path,
374 mtime,
375 } => {
376 queue_len += &documents.len();
377 embeddings_queue.push((worktree_id, documents, path, mtime));
378 queue_len >= EMBEDDINGS_BATCH_SIZE
379 }
380 EmbeddingJob::Flush => true,
381 };
382
383 if should_flush {
384 embed_batch_tx.try_send(embeddings_queue).unwrap();
385 embeddings_queue = vec![];
386 queue_len = 0;
387 }
388 }
389 });
390
391 // parsing_files_tx/rx: Parsing Files to Embeddable Documents
392 let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();
393
394 let mut _parsing_files_tasks = Vec::new();
395 for _ in 0..cx.background().num_cpus() {
396 let fs = fs.clone();
397 let parsing_files_rx = parsing_files_rx.clone();
398 let batch_files_tx = batch_files_tx.clone();
399 _parsing_files_tasks.push(cx.background().spawn(async move {
400 let mut retriever = CodeContextRetriever::new();
401 while let Ok(pending_file) = parsing_files_rx.recv().await {
402 if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err()
403 {
404 if let Some(documents) = retriever
405 .parse_file(
406 &pending_file.relative_path,
407 &content,
408 pending_file.language,
409 )
410 .log_err()
411 {
412 log::trace!(
413 "parsed path {:?}: {} documents",
414 pending_file.relative_path,
415 documents.len()
416 );
417
418 batch_files_tx
419 .try_send(EmbeddingJob::Enqueue {
420 worktree_id: pending_file.worktree_db_id,
421 path: pending_file.relative_path,
422 mtime: pending_file.modified_time,
423 documents,
424 })
425 .unwrap();
426 }
427 }
428
429 if parsing_files_rx.len() == 0 {
430 batch_files_tx.try_send(EmbeddingJob::Flush).unwrap();
431 }
432 }
433 }));
434 }
435
436 Self {
437 fs,
438 database_url,
439 embedding_provider,
440 language_registry,
441 db_update_tx,
442 parsing_files_tx,
443 _db_update_task,
444 _embed_batch_task,
445 _batch_files_task,
446 _parsing_files_tasks,
447 projects: HashMap::new(),
448 }
449 }))
450 }
451
452 fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
453 let (tx, rx) = oneshot::channel();
454 self.db_update_tx
455 .try_send(DbOperation::FindOrCreateWorktree { path, sender: tx })
456 .unwrap();
457 async move { rx.await? }
458 }
459
460 fn get_file_mtimes(
461 &self,
462 worktree_id: i64,
463 ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
464 let (tx, rx) = oneshot::channel();
465 self.db_update_tx
466 .try_send(DbOperation::FileMTimes {
467 worktree_id,
468 sender: tx,
469 })
470 .unwrap();
471 async move { rx.await? }
472 }
473
474 fn add_project(
475 &mut self,
476 project: ModelHandle<Project>,
477 cx: &mut ModelContext<Self>,
478 ) -> Task<Result<()>> {
479 let worktree_scans_complete = project
480 .read(cx)
481 .worktrees(cx)
482 .map(|worktree| {
483 let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
484 async move {
485 scan_complete.await;
486 }
487 })
488 .collect::<Vec<_>>();
489 let worktree_db_ids = project
490 .read(cx)
491 .worktrees(cx)
492 .map(|worktree| {
493 self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
494 })
495 .collect::<Vec<_>>();
496
497 let fs = self.fs.clone();
498 let language_registry = self.language_registry.clone();
499 let database_url = self.database_url.clone();
500 let db_update_tx = self.db_update_tx.clone();
501 let parsing_files_tx = self.parsing_files_tx.clone();
502
503 cx.spawn(|this, mut cx| async move {
504 futures::future::join_all(worktree_scans_complete).await;
505
506 let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;
507
508 if let Some(db_directory) = database_url.parent() {
509 fs.create_dir(db_directory).await.log_err();
510 }
511
512 let worktrees = project.read_with(&cx, |project, cx| {
513 project
514 .worktrees(cx)
515 .map(|worktree| worktree.read(cx).snapshot())
516 .collect::<Vec<_>>()
517 });
518
519 let mut worktree_file_times = HashMap::new();
520 let mut db_ids_by_worktree_id = HashMap::new();
521 for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
522 let db_id = db_id?;
523 db_ids_by_worktree_id.insert(worktree.id(), db_id);
524 worktree_file_times.insert(
525 worktree.id(),
526 this.read_with(&cx, |this, _| this.get_file_mtimes(db_id))
527 .await?,
528 );
529 }
530
531 cx.background()
532 .spawn({
533 let db_ids_by_worktree_id = db_ids_by_worktree_id.clone();
534 let db_update_tx = db_update_tx.clone();
535 let language_registry = language_registry.clone();
536 let parsing_files_tx = parsing_files_tx.clone();
537 async move {
538 let t0 = Instant::now();
539 for worktree in worktrees.into_iter() {
540 let mut file_mtimes =
541 worktree_file_times.remove(&worktree.id()).unwrap();
542 for file in worktree.files(false, 0) {
543 let absolute_path = worktree.absolutize(&file.path);
544
545 if let Ok(language) = language_registry
546 .language_for_file(&absolute_path, None)
547 .await
548 {
549 if language
550 .grammar()
551 .and_then(|grammar| grammar.embedding_config.as_ref())
552 .is_none()
553 {
554 continue;
555 }
556
557 let path_buf = file.path.to_path_buf();
558 let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
559 let already_stored = stored_mtime
560 .map_or(false, |existing_mtime| {
561 existing_mtime == file.mtime
562 });
563
564 if !already_stored {
565 log::trace!("sending for parsing: {:?}", path_buf);
566 parsing_files_tx
567 .try_send(PendingFile {
568 worktree_db_id: db_ids_by_worktree_id
569 [&worktree.id()],
570 relative_path: path_buf,
571 absolute_path,
572 language,
573 modified_time: file.mtime,
574 })
575 .unwrap();
576 }
577 }
578 }
579 for file in file_mtimes.keys() {
580 db_update_tx
581 .try_send(DbOperation::Delete {
582 worktree_id: db_ids_by_worktree_id[&worktree.id()],
583 path: file.to_owned(),
584 })
585 .unwrap();
586 }
587 }
588 log::trace!(
589 "parsing worktree completed in {:?}",
590 t0.elapsed().as_millis()
591 );
592 }
593 })
594 .detach();
595
596 // let mut pending_files: Vec<(PathBuf, ((i64, PathBuf, Arc<Language>, SystemTime), SystemTime))> = vec![];
597 this.update(&mut cx, |this, cx| {
598 // The below is managing for updated on save
599 // Currently each time a file is saved, this code is run, and for all the files that were changed, if the current time is
600 // greater than the previous embedded time by the REINDEXING_DELAY variable, we will send the file off to be indexed.
601 let _subscription = cx.subscribe(&project, |this, project, event, cx| {
602 if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event {
603 this.project_entries_changed(project, changes.clone(), cx, worktree_id);
604 }
605 });
606
607 this.projects.insert(
608 project.downgrade(),
609 ProjectState {
610 pending_files: HashMap::new(),
611 worktree_db_ids: db_ids_by_worktree_id.into_iter().collect(),
612 _subscription,
613 },
614 );
615 });
616
617 anyhow::Ok(())
618 })
619 }
620
621 pub fn search(
622 &mut self,
623 project: ModelHandle<Project>,
624 phrase: String,
625 limit: usize,
626 cx: &mut ModelContext<Self>,
627 ) -> Task<Result<Vec<SearchResult>>> {
628 let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
629 state
630 } else {
631 return Task::ready(Err(anyhow!("project not added")));
632 };
633
634 let worktree_db_ids = project
635 .read(cx)
636 .worktrees(cx)
637 .filter_map(|worktree| {
638 let worktree_id = worktree.read(cx).id();
639 project_state.db_id_for_worktree_id(worktree_id)
640 })
641 .collect::<Vec<_>>();
642
643 let embedding_provider = self.embedding_provider.clone();
644 let database_url = self.database_url.clone();
645 let fs = self.fs.clone();
646 cx.spawn(|this, cx| async move {
647 let documents = cx
648 .background()
649 .spawn(async move {
650 let database = VectorDatabase::new(fs, database_url).await?;
651
652 let phrase_embedding = embedding_provider
653 .embed_batch(vec![&phrase])
654 .await?
655 .into_iter()
656 .next()
657 .unwrap();
658
659 database.top_k_search(&worktree_db_ids, &phrase_embedding, limit)
660 })
661 .await?;
662
663 this.read_with(&cx, |this, _| {
664 let project_state = if let Some(state) = this.projects.get(&project.downgrade()) {
665 state
666 } else {
667 return Err(anyhow!("project not added"));
668 };
669
670 Ok(documents
671 .into_iter()
672 .filter_map(|(worktree_db_id, file_path, byte_range, name)| {
673 let worktree_id = project_state.worktree_id_for_db_id(worktree_db_id)?;
674 Some(SearchResult {
675 worktree_id,
676 name,
677 byte_range,
678 file_path,
679 })
680 })
681 .collect())
682 })
683 })
684 }
685
686 fn project_entries_changed(
687 &mut self,
688 project: ModelHandle<Project>,
689 changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
690 cx: &mut ModelContext<'_, VectorStore>,
691 worktree_id: &WorktreeId,
692 ) -> Option<()> {
693 let reindexing_delay = settings::get::<VectorStoreSettings>(cx).reindexing_delay_seconds;
694
695 let worktree = project
696 .read(cx)
697 .worktree_for_id(worktree_id.clone(), cx)?
698 .read(cx)
699 .snapshot();
700
701 let worktree_db_id = self
702 .projects
703 .get(&project.downgrade())?
704 .db_id_for_worktree_id(worktree.id())?;
705 let file_mtimes = self.get_file_mtimes(worktree_db_id);
706
707 let language_registry = self.language_registry.clone();
708
709 cx.spawn(|this, mut cx| async move {
710 let file_mtimes = file_mtimes.await.log_err()?;
711
712 for change in changes.into_iter() {
713 let change_path = change.0.clone();
714 let absolute_path = worktree.absolutize(&change_path);
715
716 // Skip if git ignored or symlink
717 if let Some(entry) = worktree.entry_for_id(change.1) {
718 if entry.is_ignored || entry.is_symlink || entry.is_external {
719 continue;
720 }
721 }
722
723 match change.2 {
724 PathChange::Removed => this.update(&mut cx, |this, _| {
725 this.db_update_tx
726 .try_send(DbOperation::Delete {
727 worktree_id: worktree_db_id,
728 path: absolute_path,
729 })
730 .unwrap();
731 }),
732 _ => {
733 if let Ok(language) = language_registry
734 .language_for_file(&change_path.to_path_buf(), None)
735 .await
736 {
737 if language
738 .grammar()
739 .and_then(|grammar| grammar.embedding_config.as_ref())
740 .is_none()
741 {
742 continue;
743 }
744
745 let modified_time =
746 change_path.metadata().log_err()?.modified().log_err()?;
747
748 let existing_time = file_mtimes.get(&change_path.to_path_buf());
749 let already_stored = existing_time
750 .map_or(false, |existing_time| &modified_time != existing_time);
751
752 if !already_stored {
753 this.update(&mut cx, |this, _| {
754 let reindex_time = modified_time
755 + Duration::from_secs(reindexing_delay as u64);
756
757 let project_state =
758 this.projects.get_mut(&project.downgrade())?;
759 project_state.update_pending_files(
760 PendingFile {
761 relative_path: change_path.to_path_buf(),
762 absolute_path,
763 modified_time,
764 worktree_db_id,
765 language: language.clone(),
766 },
767 reindex_time,
768 );
769
770 for file in project_state.get_outstanding_files() {
771 this.parsing_files_tx.try_send(file).unwrap();
772 }
773 Some(())
774 });
775 }
776 }
777 }
778 }
779 }
780
781 Some(())
782 })
783 .detach();
784
785 Some(())
786 }
787}
788
// `VectorStore` is a gpui model; it emits no events of its own.
impl Entity for VectorStore {
    type Event = ();
}