mod db;
mod embedding;
mod modal;
mod parsing;
mod vector_store_settings;

#[cfg(test)]
mod vector_store_tests;

use crate::vector_store_settings::VectorStoreSettings;
use anyhow::{anyhow, Result};
use db::VectorDatabase;
use embedding::{EmbeddingProvider, OpenAIEmbeddings};
use futures::{channel::oneshot, Future};
use gpui::{
    AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, ViewContext,
    WeakModelHandle,
};
use language::{Language, LanguageRegistry};
use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
use parsing::{CodeContextRetriever, ParsedFile};
use project::{Fs, PathChange, Project, ProjectEntryId, WorktreeId};
use smol::channel;
use std::{
    collections::HashMap,
    path::{Path, PathBuf},
    sync::Arc,
    time::{Duration, Instant, SystemTime},
};
use tree_sitter::{Parser, QueryCursor};
use util::{
    channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
    http::HttpClient,
    paths::EMBEDDINGS_DIR,
    ResultExt,
};
use workspace::{Workspace, WorkspaceCreated};
const VECTOR_STORE_VERSION: usize = 0;

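/// Registers the vector store for this app session: vector store settings,
/// the semantic search modal and its `Toggle` action, and a global
/// `ModelHandle<VectorStore>` that indexes local projects as workspaces are
/// created. Does nothing on the stable release channel or when the feature is
/// disabled in settings.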
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    if *RELEASE_CHANNEL == ReleaseChannel::Stable {
        return;
    }

    settings::register::<VectorStoreSettings>(cx);
    if !settings::get::<VectorStoreSettings>(cx).enable {
        return;
    }

    let db_file_path = EMBEDDINGS_DIR
        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
        .join("embeddings_db");

    SemanticSearch::init(cx);
    cx.add_action(
        |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
            eprintln!("semantic_search::Toggle action");

            if cx.has_global::<ModelHandle<VectorStore>>() {
                let vector_store = cx.global::<ModelHandle<VectorStore>>().clone();
                workspace.toggle_modal(cx, |workspace, cx| {
                    let project = workspace.project().clone();
                    let workspace = cx.weak_handle();
                    cx.add_view(|cx| {
                        SemanticSearch::new(
                            SemanticSearchDelegate::new(workspace, project, vector_store),
                            cx,
                        )
                    })
                });
            }
        },
    );

    cx.spawn(move |mut cx| async move {
        let vector_store = VectorStore::new(
            fs,
            db_file_path,
            // Arc::new(embedding::DummyEmbeddings {}),
            Arc::new(OpenAIEmbeddings {
                client: http_client,
                executor: cx.background(),
            }),
            language_registry,
            cx.clone(),
        )
        .await?;

        cx.update(|cx| {
            cx.set_global(vector_store.clone());
            cx.subscribe_global::<WorkspaceCreated, _>({
                let vector_store = vector_store.clone();
                move |event, cx| {
                    let workspace = &event.0;
                    if let Some(workspace) = workspace.upgrade(cx) {
                        let project = workspace.read(cx).project().clone();
                        if project.read(cx).is_local() {
                            vector_store.update(cx, |store, cx| {
                                store.add_project(project, cx).detach();
                            });
                        }
                    }
                }
            })
            .detach();
        });

        anyhow::Ok(())
    })
    .detach();
}

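/// Owns the indexing pipeline. Files to index are sent down
/// `parsing_files_tx`; background tasks parse them into documents, batch the
/// documents, request embeddings, and finally persist the results through
/// `db_update_tx`. The tasks are stored on this struct so they live as long
/// as the model.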
pub struct VectorStore {
    fs: Arc<dyn Fs>,
    database_url: Arc<PathBuf>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    db_update_tx: channel::Sender<DbOperation>,
    parsing_files_tx: channel::Sender<PendingFile>,
    _db_update_task: Task<()>,
    _embed_batch_task: Task<()>,
    _batch_files_task: Task<()>,
    _parsing_files_tasks: Vec<Task<()>>,
    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
}

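/// Per-project indexing state: the mapping between worktrees and their
/// database ids, plus recently changed files waiting out the reindexing
/// delay.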
struct ProjectState {
    worktree_db_ids: Vec<(WorktreeId, i64)>,
    pending_files: HashMap<PathBuf, (PendingFile, SystemTime)>,
    _subscription: gpui::Subscription,
}

impl ProjectState {
    fn db_id_for_worktree_id(&self, id: WorktreeId) -> Option<i64> {
        self.worktree_db_ids
            .iter()
            .find_map(|(worktree_id, db_id)| {
                if *worktree_id == id {
                    Some(*db_id)
                } else {
                    None
                }
            })
    }

    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
        self.worktree_db_ids
            .iter()
            .find_map(|(worktree_id, db_id)| {
                if *db_id == id {
                    Some(*worktree_id)
                } else {
                    None
                }
            })
    }

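    /// Schedules a file for reindexing at `indexing_time`. If the file is
    /// already pending, the originally scheduled time is kept so repeated
    /// saves don't postpone indexing indefinitely.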
    fn update_pending_files(&mut self, pending_file: PendingFile, indexing_time: SystemTime) {
        // If the file is already pending, replace it with the new version,
        // but keep the old indexing time.
        if let Some(old_file) = self.pending_files.remove(&pending_file.relative_path) {
            self.pending_files.insert(
                pending_file.relative_path.clone(),
                (pending_file, old_file.1),
            );
        } else {
            self.pending_files.insert(
                pending_file.relative_path.clone(),
                (pending_file, indexing_time),
            );
        }
    }

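    /// Removes and returns all pending files whose scheduled indexing time
    /// has passed.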
    fn get_outstanding_files(&mut self) -> Vec<PendingFile> {
        let now = SystemTime::now();
        let mut outstanding_files = Vec::new();
        self.pending_files.retain(|_, (pending_file, index_time)| {
            if *index_time <= now {
                outstanding_files.push(pending_file.clone());
                false
            } else {
                true
            }
        });
        outstanding_files
    }
}

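/// A file that is queued for parsing and embedding.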
#[derive(Clone, Debug)]
pub struct PendingFile {
    worktree_db_id: i64,
    relative_path: PathBuf,
    absolute_path: PathBuf,
    language: Arc<Language>,
    modified_time: SystemTime,
}

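/// A single match returned by `VectorStore::search`.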
#[derive(Debug, Clone)]
pub struct SearchResult {
    pub worktree_id: WorktreeId,
    pub name: String,
    pub offset: usize,
    pub file_path: PathBuf,
}

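/// Requests handled by the database task. Reads reply through the enclosed
/// oneshot senders; writes are fire-and-forget.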
enum DbOperation {
    InsertFile {
        worktree_id: i64,
        indexed_file: ParsedFile,
    },
    Delete {
        worktree_id: i64,
        path: PathBuf,
    },
    FindOrCreateWorktree {
        path: PathBuf,
        sender: oneshot::Sender<Result<i64>>,
    },
    FileMTimes {
        worktree_id: i64,
        sender: oneshot::Sender<Result<HashMap<PathBuf, SystemTime>>>,
    },
}

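/// Work items for the batching task: either a parsed file to queue for
/// embedding, or a request to flush the current batch.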
enum EmbeddingJob {
    Enqueue {
        worktree_id: i64,
        parsed_file: ParsedFile,
        document_spans: Vec<String>,
    },
    Flush,
}

impl VectorStore {
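    /// Opens (or creates) the vector database at `database_url` and spawns
    /// the background tasks that form the indexing pipeline.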
    async fn new(
        fs: Arc<dyn Fs>,
        database_url: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
        mut cx: AsyncAppContext,
    ) -> Result<ModelHandle<Self>> {
        let database_url = Arc::new(database_url);

        let db = cx
            .background()
            .spawn({
                let fs = fs.clone();
                let database_url = database_url.clone();
                async move {
                    if let Some(db_directory) = database_url.parent() {
                        fs.create_dir(db_directory).await.log_err();
                    }

                    let db = VectorDatabase::new(database_url.to_string_lossy().to_string())?;
                    anyhow::Ok(db)
                }
            })
            .await?;

        Ok(cx.add_model(|cx| {
            // The indexing pipeline:
            // parsing_files_tx -> batch_files_tx -> embed_batch_tx -> db_update_tx

            // db_update_tx/rx: apply writes and queries to the database.
            let (db_update_tx, db_update_rx) = channel::unbounded();
            let _db_update_task = cx.background().spawn(async move {
                while let Ok(job) = db_update_rx.recv().await {
                    match job {
                        DbOperation::InsertFile {
                            worktree_id,
                            indexed_file,
                        } => {
                            db.insert_file(worktree_id, indexed_file).log_err();
                        }
                        DbOperation::Delete { worktree_id, path } => {
                            db.delete_file(worktree_id, path).log_err();
                        }
                        DbOperation::FindOrCreateWorktree { path, sender } => {
                            let id = db.find_or_create_worktree(&path);
                            sender.send(id).ok();
                        }
                        DbOperation::FileMTimes {
                            worktree_id: worktree_db_id,
                            sender,
                        } => {
                            let file_mtimes = db.get_file_mtimes(worktree_db_id);
                            sender.send(file_mtimes).ok();
                        }
                    }
                }
            });

            // embed_batch_tx/rx: embed a batch of document spans and send the
            // results on to the database task.
            let (embed_batch_tx, embed_batch_rx) =
                channel::unbounded::<Vec<(i64, ParsedFile, Vec<String>)>>();
            let _embed_batch_task = cx.background().spawn({
                let db_update_tx = db_update_tx.clone();
                let embedding_provider = embedding_provider.clone();
                async move {
                    while let Ok(mut embeddings_queue) = embed_batch_rx.recv().await {
                        // Flatten the queue into a single batch of spans.
                        let mut document_spans = vec![];
                        for (_, _, document_span) in embeddings_queue.iter() {
                            document_spans.extend(document_span.iter().map(|s| s.as_str()));
                        }

                        if let Ok(embeddings) =
                            embedding_provider.embed_batch(document_spans).await
                        {
                            // Walk the queue in lockstep with the returned
                            // embeddings: `i` indexes the file, `j` the
                            // document within that file.
                            let mut i = 0;
                            let mut j = 0;

                            for embedding in embeddings.iter() {
                                while embeddings_queue[i].1.documents.len() == j {
                                    i += 1;
                                    j = 0;
                                }

                                embeddings_queue[i].1.documents[j].embedding =
                                    embedding.to_owned();
                                j += 1;
                            }

                            for (worktree_id, indexed_file, _) in embeddings_queue.into_iter() {
                                for document in indexed_file.documents.iter() {
                                    // TODO: Update this so it doesn't panic.
                                    assert!(
                                        !document.embedding.is_empty(),
                                        "Document Embedding Not Complete"
                                    );
                                }

                                db_update_tx
                                    .send(DbOperation::InsertFile {
                                        worktree_id,
                                        indexed_file,
                                    })
                                    .await
                                    .unwrap();
                            }
                        }
                    }
                }
            });

            // batch_files_tx/rx: collect parsed files until the batch reaches
            // `embedding_batch_size`, then flush it to the embedding task.
            let batch_size = settings::get::<VectorStoreSettings>(cx).embedding_batch_size;
            let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
            let _batch_files_task = cx.background().spawn(async move {
                let mut queue_len = 0;
                let mut embeddings_queue = vec![];

                while let Ok(job) = batch_files_rx.recv().await {
                    let should_flush = match job {
                        EmbeddingJob::Enqueue {
                            document_spans,
                            worktree_id,
                            parsed_file,
                        } => {
                            queue_len += document_spans.len();
                            embeddings_queue.push((worktree_id, parsed_file, document_spans));
                            queue_len >= batch_size
                        }
                        EmbeddingJob::Flush => true,
                    };

                    if should_flush {
                        embed_batch_tx.try_send(embeddings_queue).unwrap();
                        embeddings_queue = vec![];
                        queue_len = 0;
                    }
                }
            });

            // parsing_files_tx/rx: parse pending files into embeddable documents.
            let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();

            let mut _parsing_files_tasks = Vec::new();
            // for _ in 0..cx.background().num_cpus() {
            for _ in 0..1 {
                let fs = fs.clone();
                let parsing_files_rx = parsing_files_rx.clone();
                let batch_files_tx = batch_files_tx.clone();
                _parsing_files_tasks.push(cx.background().spawn(async move {
                    let parser = Parser::new();
                    let cursor = QueryCursor::new();
                    let mut retriever = CodeContextRetriever { parser, cursor, fs };
                    while let Ok(pending_file) = parsing_files_rx.recv().await {
                        if let Some((indexed_file, document_spans)) =
                            retriever.parse_file(pending_file.clone()).await.log_err()
                        {
                            batch_files_tx
                                .try_send(EmbeddingJob::Enqueue {
                                    worktree_id: pending_file.worktree_db_id,
                                    parsed_file: indexed_file,
                                    document_spans,
                                })
                                .unwrap();
                        }

                        // If no more files are queued, flush whatever has
                        // been batched so far.
                        if parsing_files_rx.is_empty() {
                            batch_files_tx.try_send(EmbeddingJob::Flush).unwrap();
                        }
                    }
                }));
            }

            Self {
                fs,
                database_url,
                embedding_provider,
                language_registry,
                db_update_tx,
                parsing_files_tx,
                _db_update_task,
                _embed_batch_task,
                _batch_files_task,
                _parsing_files_tasks,
                projects: HashMap::new(),
            }
        }))
    }

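    /// Asks the database task for the id of the worktree at `path`, creating
    /// a row for it if none exists.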
    fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
        let (tx, rx) = oneshot::channel();
        self.db_update_tx
            .try_send(DbOperation::FindOrCreateWorktree { path, sender: tx })
            .unwrap();
        async move { rx.await? }
    }

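    /// Fetches the stored modified times of all files in a worktree, keyed by
    /// path relative to the worktree root.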
    fn get_file_mtimes(
        &self,
        worktree_id: i64,
    ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
        let (tx, rx) = oneshot::channel();
        self.db_update_tx
            .try_send(DbOperation::FileMTimes {
                worktree_id,
                sender: tx,
            })
            .unwrap();
        async move { rx.await? }
    }

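    /// Starts tracking a project: waits for worktree scans to complete,
    /// indexes files that are new or changed since they were last stored,
    /// deletes database rows for files that no longer exist, and subscribes
    /// to worktree updates so future changes are reindexed.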
    fn add_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        let worktree_scans_complete = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
                async move {
                    scan_complete.await;
                }
            })
            .collect::<Vec<_>>();
        let worktree_db_ids = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
            })
            .collect::<Vec<_>>();

        let fs = self.fs.clone();
        let language_registry = self.language_registry.clone();
        let database_url = self.database_url.clone();
        let db_update_tx = self.db_update_tx.clone();
        let parsing_files_tx = self.parsing_files_tx.clone();

        cx.spawn(|this, mut cx| async move {
            futures::future::join_all(worktree_scans_complete).await;

            let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;

            if let Some(db_directory) = database_url.parent() {
                fs.create_dir(db_directory).await.log_err();
            }

            let worktrees = project.read_with(&cx, |project, cx| {
                project
                    .worktrees(cx)
                    .map(|worktree| worktree.read(cx).snapshot())
                    .collect::<Vec<_>>()
            });

            let mut worktree_file_times = HashMap::new();
            let mut db_ids_by_worktree_id = HashMap::new();
            for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
                let db_id = db_id?;
                db_ids_by_worktree_id.insert(worktree.id(), db_id);
                worktree_file_times.insert(
                    worktree.id(),
                    this.read_with(&cx, |this, _| this.get_file_mtimes(db_id))
                        .await?,
                );
            }

            cx.background()
                .spawn({
                    let db_ids_by_worktree_id = db_ids_by_worktree_id.clone();
                    let db_update_tx = db_update_tx.clone();
                    let language_registry = language_registry.clone();
                    let parsing_files_tx = parsing_files_tx.clone();
                    async move {
                        let t0 = Instant::now();
                        for worktree in worktrees.into_iter() {
                            let mut file_mtimes =
                                worktree_file_times.remove(&worktree.id()).unwrap();
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path);

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    // Skip languages without an embedding query.
                                    if language
                                        .grammar()
                                        .and_then(|grammar| grammar.embedding_config.as_ref())
                                        .is_none()
                                    {
                                        continue;
                                    }

                                    let path_buf = file.path.to_path_buf();
                                    let stored_mtime = file_mtimes.remove(&path_buf);
                                    let already_stored = stored_mtime
                                        .map_or(false, |existing_mtime| {
                                            existing_mtime == file.mtime
                                        });

                                    if !already_stored {
                                        parsing_files_tx
                                            .try_send(PendingFile {
                                                worktree_db_id: db_ids_by_worktree_id
                                                    [&worktree.id()],
                                                relative_path: path_buf,
                                                absolute_path,
                                                language,
                                                modified_time: file.mtime,
                                            })
                                            .unwrap();
                                    }
                                }
                            }

                            // Any paths left in `file_mtimes` no longer exist
                            // in the worktree; delete them from the database.
                            for file in file_mtimes.keys() {
                                db_update_tx
                                    .try_send(DbOperation::Delete {
                                        worktree_id: db_ids_by_worktree_id[&worktree.id()],
                                        path: file.to_owned(),
                                    })
                                    .unwrap();
                            }
                        }
                        log::info!("Parsing worktree completed in {}ms", t0.elapsed().as_millis());
                    }
                })
                .detach();

            this.update(&mut cx, |this, cx| {
                // Handle updates on save: each time entries change, the
                // changed files are scheduled for reindexing once they have
                // been quiet for `reindexing_delay_seconds`.
                let _subscription = cx.subscribe(&project, |this, project, event, cx| {
                    if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event {
                        this.project_entries_changed(project, changes.clone(), cx, worktree_id);
                    }
                });

                this.projects.insert(
                    project.downgrade(),
                    ProjectState {
                        pending_files: HashMap::new(),
                        worktree_db_ids: db_ids_by_worktree_id.into_iter().collect(),
                        _subscription,
                    },
                );
            });

            anyhow::Ok(())
        })
    }

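    /// Embeds `phrase` and returns the `limit` nearest documents across the
    /// project's worktrees.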
    pub fn search(
        &mut self,
        project: ModelHandle<Project>,
        phrase: String,
        limit: usize,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
            state
        } else {
            return Task::ready(Err(anyhow!("project not added")));
        };

        let worktree_db_ids = project
            .read(cx)
            .worktrees(cx)
            .filter_map(|worktree| {
                let worktree_id = worktree.read(cx).id();
                project_state.db_id_for_worktree_id(worktree_id)
            })
            .collect::<Vec<_>>();

        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();
        cx.spawn(|this, cx| async move {
            let documents = cx
                .background()
                .spawn(async move {
                    let database = VectorDatabase::new(database_url.to_string_lossy().into())?;

                    // Embed the search phrase and take the `limit` nearest
                    // documents across the project's worktrees.
                    let phrase_embedding = embedding_provider
                        .embed_batch(vec![&phrase])
                        .await?
                        .into_iter()
                        .next()
                        .unwrap();

                    database.top_k_search(&worktree_db_ids, &phrase_embedding, limit)
                })
                .await?;

            this.read_with(&cx, |this, _| {
                let project_state = if let Some(state) = this.projects.get(&project.downgrade()) {
                    state
                } else {
                    return Err(anyhow!("project not added"));
                };

                // Map database worktree ids back to project worktree ids.
                Ok(documents
                    .into_iter()
                    .filter_map(|(worktree_db_id, file_path, offset, name)| {
                        let worktree_id = project_state.worktree_id_for_db_id(worktree_db_id)?;
                        Some(SearchResult {
                            worktree_id,
                            name,
                            offset,
                            file_path,
                        })
                    })
                    .collect())
            })
        })
    }

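    /// Handles worktree entry changes: removed files are deleted from the
    /// database, and changed files with an embeddable language are scheduled
    /// for reindexing after the configured delay.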
    fn project_entries_changed(
        &mut self,
        project: ModelHandle<Project>,
        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
        cx: &mut ModelContext<'_, VectorStore>,
        worktree_id: &WorktreeId,
    ) -> Option<()> {
        let reindexing_delay = settings::get::<VectorStoreSettings>(cx).reindexing_delay_seconds;

        let worktree = project
            .read(cx)
            .worktree_for_id(*worktree_id, cx)?
            .read(cx)
            .snapshot();

        let worktree_db_id = self
            .projects
            .get(&project.downgrade())?
            .db_id_for_worktree_id(worktree.id())?;
        let file_mtimes = self.get_file_mtimes(worktree_db_id);

        let language_registry = self.language_registry.clone();

        cx.spawn(|this, mut cx| async move {
            let file_mtimes = file_mtimes.await.log_err()?;

            for change in changes.into_iter() {
                let change_path = change.0.clone();
                let absolute_path = worktree.absolutize(&change_path);

                // Skip entries that are git-ignored, symlinks, or external.
                if let Some(entry) = worktree.entry_for_id(change.1) {
                    if entry.is_ignored || entry.is_symlink || entry.is_external {
                        continue;
                    }
                }

                match change.2 {
                    PathChange::Removed => this.update(&mut cx, |this, _| {
                        this.db_update_tx
                            .try_send(DbOperation::Delete {
                                worktree_id: worktree_db_id,
                                // The database keys files by worktree-relative
                                // path, so delete by the relative path.
                                path: change_path.to_path_buf(),
                            })
                            .unwrap();
                    }),
                    _ => {
                        if let Ok(language) = language_registry
                            .language_for_file(&change_path.to_path_buf(), None)
                            .await
                        {
                            // Skip languages without an embedding query.
                            if language
                                .grammar()
                                .and_then(|grammar| grammar.embedding_config.as_ref())
                                .is_none()
                            {
                                continue;
                            }

                            // Stat via the absolute path; `change_path` is
                            // relative to the worktree root.
                            let modified_time =
                                absolute_path.metadata().log_err()?.modified().log_err()?;

                            let existing_time = file_mtimes.get(&change_path.to_path_buf());
                            let already_stored = existing_time
                                .map_or(false, |existing_time| &modified_time == existing_time);

                            if !already_stored {
                                this.update(&mut cx, |this, _| {
                                    let reindex_time = modified_time
                                        + Duration::from_secs(reindexing_delay as u64);

                                    let project_state =
                                        this.projects.get_mut(&project.downgrade())?;
                                    project_state.update_pending_files(
                                        PendingFile {
                                            relative_path: change_path.to_path_buf(),
                                            absolute_path,
                                            modified_time,
                                            worktree_db_id,
                                            language: language.clone(),
                                        },
                                        reindex_time,
                                    );

                                    for file in project_state.get_outstanding_files() {
                                        this.parsing_files_tx.try_send(file).unwrap();
                                    }
                                    Some(())
                                });
                            }
                        }
                    }
                }
            }

            Some(())
        })
        .detach();

        Some(())
    }
}

impl Entity for VectorStore {
    type Event = ();
}