1mod db;
2mod embedding;
3mod modal;
4mod parsing;
5mod vector_store_settings;
6
7#[cfg(test)]
8mod vector_store_tests;
9
10use crate::vector_store_settings::VectorStoreSettings;
11use anyhow::{anyhow, Result};
12use db::VectorDatabase;
13use embedding::{EmbeddingProvider, OpenAIEmbeddings};
14use futures::{channel::oneshot, Future};
15use gpui::{
16 AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, ViewContext,
17 WeakModelHandle,
18};
19use language::{Language, LanguageRegistry};
20use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
21use parsing::{CodeContextRetriever, ParsedFile};
22use project::{Fs, PathChange, Project, ProjectEntryId, WorktreeId};
23use smol::channel;
24use std::{
25 collections::HashMap,
26 path::{Path, PathBuf},
27 sync::Arc,
28 time::{Duration, Instant, SystemTime},
29};
30use tree_sitter::{Parser, QueryCursor};
31use util::{
32 channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
33 http::HttpClient,
34 paths::EMBEDDINGS_DIR,
35 ResultExt,
36};
37use workspace::{Workspace, WorkspaceCreated};
38
/// Version number of the vector store.
///
/// NOTE(review): nothing in the visible part of this file reads this constant; presumably it
/// is consumed elsewhere (e.g. by the database layer) — confirm before changing or removing it.
const VECTOR_STORE_VERSION: usize = 0;
40
41pub fn init(
42 fs: Arc<dyn Fs>,
43 http_client: Arc<dyn HttpClient>,
44 language_registry: Arc<LanguageRegistry>,
45 cx: &mut AppContext,
46) {
47 if *RELEASE_CHANNEL == ReleaseChannel::Stable {
48 return;
49 }
50
51 settings::register::<VectorStoreSettings>(cx);
52
53 if !settings::get::<VectorStoreSettings>(cx).enable {
54 return;
55 }
56
57 let db_file_path = EMBEDDINGS_DIR
58 .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
59 .join("embeddings_db");
60
61 cx.spawn(move |mut cx| async move {
62 let vector_store = VectorStore::new(
63 fs,
64 db_file_path,
65 // Arc::new(embedding::DummyEmbeddings {}),
66 Arc::new(OpenAIEmbeddings {
67 client: http_client,
68 executor: cx.background(),
69 }),
70 language_registry,
71 cx.clone(),
72 )
73 .await?;
74
75 cx.update(|cx| {
76 cx.subscribe_global::<WorkspaceCreated, _>({
77 let vector_store = vector_store.clone();
78 move |event, cx| {
79 let workspace = &event.0;
80 if let Some(workspace) = workspace.upgrade(cx) {
81 let project = workspace.read(cx).project().clone();
82 if project.read(cx).is_local() {
83 vector_store.update(cx, |store, cx| {
84 store.add_project(project, cx).detach();
85 });
86 }
87 }
88 }
89 })
90 .detach();
91
92 cx.add_action({
93 // "semantic search: Toggle"
94 move |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
95 let vector_store = vector_store.clone();
96 workspace.toggle_modal(cx, |workspace, cx| {
97 let project = workspace.project().clone();
98 let workspace = cx.weak_handle();
99 cx.add_view(|cx| {
100 SemanticSearch::new(
101 SemanticSearchDelegate::new(workspace, project, vector_store),
102 cx,
103 )
104 })
105 })
106 }
107 });
108
109 SemanticSearch::init(cx);
110 });
111
112 anyhow::Ok(())
113 })
114 .detach();
115}
116
117pub struct VectorStore {
118 fs: Arc<dyn Fs>,
119 database_url: Arc<PathBuf>,
120 embedding_provider: Arc<dyn EmbeddingProvider>,
121 language_registry: Arc<LanguageRegistry>,
122 db_update_tx: channel::Sender<DbOperation>,
123 parsing_files_tx: channel::Sender<PendingFile>,
124 _db_update_task: Task<()>,
125 _embed_batch_task: Task<()>,
126 _batch_files_task: Task<()>,
127 _parsing_files_tasks: Vec<Task<()>>,
128 projects: HashMap<WeakModelHandle<Project>, ProjectState>,
129}
130
131struct ProjectState {
132 worktree_db_ids: Vec<(WorktreeId, i64)>,
133 pending_files: HashMap<PathBuf, (PendingFile, SystemTime)>,
134 _subscription: gpui::Subscription,
135}
136
137impl ProjectState {
138 fn db_id_for_worktree_id(&self, id: WorktreeId) -> Option<i64> {
139 self.worktree_db_ids
140 .iter()
141 .find_map(|(worktree_id, db_id)| {
142 if *worktree_id == id {
143 Some(*db_id)
144 } else {
145 None
146 }
147 })
148 }
149
150 fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
151 self.worktree_db_ids
152 .iter()
153 .find_map(|(worktree_id, db_id)| {
154 if *db_id == id {
155 Some(*worktree_id)
156 } else {
157 None
158 }
159 })
160 }
161
162 fn update_pending_files(&mut self, pending_file: PendingFile, indexing_time: SystemTime) {
163 // If Pending File Already Exists, Replace it with the new one
164 // but keep the old indexing time
165 if let Some(old_file) = self
166 .pending_files
167 .remove(&pending_file.relative_path.clone())
168 {
169 self.pending_files.insert(
170 pending_file.relative_path.clone(),
171 (pending_file, old_file.1),
172 );
173 } else {
174 self.pending_files.insert(
175 pending_file.relative_path.clone(),
176 (pending_file, indexing_time),
177 );
178 };
179 }
180
181 fn get_outstanding_files(&mut self) -> Vec<PendingFile> {
182 let mut outstanding_files = vec![];
183 let mut remove_keys = vec![];
184 for key in self.pending_files.keys().into_iter() {
185 if let Some(pending_details) = self.pending_files.get(key) {
186 let (pending_file, index_time) = pending_details;
187 if index_time <= &SystemTime::now() {
188 outstanding_files.push(pending_file.clone());
189 remove_keys.push(key.clone());
190 }
191 }
192 }
193
194 for key in remove_keys.iter() {
195 self.pending_files.remove(key);
196 }
197
198 return outstanding_files;
199 }
200}
201
202#[derive(Clone, Debug)]
203pub struct PendingFile {
204 worktree_db_id: i64,
205 relative_path: PathBuf,
206 absolute_path: PathBuf,
207 language: Arc<Language>,
208 modified_time: SystemTime,
209}
210
211#[derive(Debug, Clone)]
212pub struct SearchResult {
213 pub worktree_id: WorktreeId,
214 pub name: String,
215 pub offset: usize,
216 pub file_path: PathBuf,
217}
218
219enum DbOperation {
220 InsertFile {
221 worktree_id: i64,
222 indexed_file: ParsedFile,
223 },
224 Delete {
225 worktree_id: i64,
226 path: PathBuf,
227 },
228 FindOrCreateWorktree {
229 path: PathBuf,
230 sender: oneshot::Sender<Result<i64>>,
231 },
232 FileMTimes {
233 worktree_id: i64,
234 sender: oneshot::Sender<Result<HashMap<PathBuf, SystemTime>>>,
235 },
236}
237
238enum EmbeddingJob {
239 Enqueue {
240 worktree_id: i64,
241 parsed_file: ParsedFile,
242 document_spans: Vec<String>,
243 },
244 Flush,
245}
246
247impl VectorStore {
248 async fn new(
249 fs: Arc<dyn Fs>,
250 database_url: PathBuf,
251 embedding_provider: Arc<dyn EmbeddingProvider>,
252 language_registry: Arc<LanguageRegistry>,
253 mut cx: AsyncAppContext,
254 ) -> Result<ModelHandle<Self>> {
255 let database_url = Arc::new(database_url);
256
257 let db = cx
258 .background()
259 .spawn({
260 let fs = fs.clone();
261 let database_url = database_url.clone();
262 async move {
263 if let Some(db_directory) = database_url.parent() {
264 fs.create_dir(db_directory).await.log_err();
265 }
266
267 let db = VectorDatabase::new(database_url.to_string_lossy().to_string())?;
268 anyhow::Ok(db)
269 }
270 })
271 .await?;
272
273 Ok(cx.add_model(|cx| {
274 // paths_tx -> embeddings_tx -> db_update_tx
275
276 //db_update_tx/rx: Updating Database
277 let (db_update_tx, db_update_rx) = channel::unbounded();
278 let _db_update_task = cx.background().spawn(async move {
279 while let Ok(job) = db_update_rx.recv().await {
280 match job {
281 DbOperation::InsertFile {
282 worktree_id,
283 indexed_file,
284 } => {
285 db.insert_file(worktree_id, indexed_file).log_err();
286 }
287 DbOperation::Delete { worktree_id, path } => {
288 db.delete_file(worktree_id, path).log_err();
289 }
290 DbOperation::FindOrCreateWorktree { path, sender } => {
291 let id = db.find_or_create_worktree(&path);
292 sender.send(id).ok();
293 }
294 DbOperation::FileMTimes {
295 worktree_id: worktree_db_id,
296 sender,
297 } => {
298 let file_mtimes = db.get_file_mtimes(worktree_db_id);
299 sender.send(file_mtimes).ok();
300 }
301 }
302 }
303 });
304
305 // embed_tx/rx: Embed Batch and Send to Database
306 let (embed_batch_tx, embed_batch_rx) =
307 channel::unbounded::<Vec<(i64, ParsedFile, Vec<String>)>>();
308 let _embed_batch_task = cx.background().spawn({
309 let db_update_tx = db_update_tx.clone();
310 let embedding_provider = embedding_provider.clone();
311 async move {
312 while let Ok(mut embeddings_queue) = embed_batch_rx.recv().await {
313 // Construct Batch
314 let mut document_spans = vec![];
315 for (_, _, document_span) in embeddings_queue.iter() {
316 document_spans.extend(document_span.iter().map(|s| s.as_str()));
317 }
318
319 if let Ok(embeddings) = embedding_provider.embed_batch(document_spans).await
320 {
321 let mut i = 0;
322 let mut j = 0;
323
324 for embedding in embeddings.iter() {
325 while embeddings_queue[i].1.documents.len() == j {
326 i += 1;
327 j = 0;
328 }
329
330 embeddings_queue[i].1.documents[j].embedding = embedding.to_owned();
331 j += 1;
332 }
333
334 for (worktree_id, indexed_file, _) in embeddings_queue.into_iter() {
335 for document in indexed_file.documents.iter() {
336 // TODO: Update this so it doesn't panic
337 assert!(
338 document.embedding.len() > 0,
339 "Document Embedding Not Complete"
340 );
341 }
342
343 db_update_tx
344 .send(DbOperation::InsertFile {
345 worktree_id,
346 indexed_file,
347 })
348 .await
349 .unwrap();
350 }
351 }
352 }
353 }
354 });
355
356 // batch_tx/rx: Batch Files to Send for Embeddings
357 let batch_size = settings::get::<VectorStoreSettings>(cx).embedding_batch_size;
358 let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
359 let _batch_files_task = cx.background().spawn(async move {
360 let mut queue_len = 0;
361 let mut embeddings_queue = vec![];
362
363 while let Ok(job) = batch_files_rx.recv().await {
364 let should_flush = match job {
365 EmbeddingJob::Enqueue {
366 document_spans,
367 worktree_id,
368 parsed_file,
369 } => {
370 queue_len += &document_spans.len();
371 embeddings_queue.push((worktree_id, parsed_file, document_spans));
372 queue_len >= batch_size
373 }
374 EmbeddingJob::Flush => true,
375 };
376
377 if should_flush {
378 embed_batch_tx.try_send(embeddings_queue).unwrap();
379 embeddings_queue = vec![];
380 queue_len = 0;
381 }
382 }
383 });
384
385 // parsing_files_tx/rx: Parsing Files to Embeddable Documents
386 let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();
387
388 let mut _parsing_files_tasks = Vec::new();
389 // for _ in 0..cx.background().num_cpus() {
390 for _ in 0..1 {
391 let fs = fs.clone();
392 let parsing_files_rx = parsing_files_rx.clone();
393 let batch_files_tx = batch_files_tx.clone();
394 _parsing_files_tasks.push(cx.background().spawn(async move {
395 let parser = Parser::new();
396 let cursor = QueryCursor::new();
397 let mut retriever = CodeContextRetriever { parser, cursor, fs };
398 while let Ok(pending_file) = parsing_files_rx.recv().await {
399 if let Some((indexed_file, document_spans)) =
400 retriever.parse_file(pending_file.clone()).await.log_err()
401 {
402 batch_files_tx
403 .try_send(EmbeddingJob::Enqueue {
404 worktree_id: pending_file.worktree_db_id,
405 parsed_file: indexed_file,
406 document_spans,
407 })
408 .unwrap();
409 }
410
411 if parsing_files_rx.len() == 0 {
412 batch_files_tx.try_send(EmbeddingJob::Flush).unwrap();
413 }
414 }
415 }));
416 }
417
418 Self {
419 fs,
420 database_url,
421 embedding_provider,
422 language_registry,
423 db_update_tx,
424 parsing_files_tx,
425 _db_update_task,
426 _embed_batch_task,
427 _batch_files_task,
428 _parsing_files_tasks,
429 projects: HashMap::new(),
430 }
431 }))
432 }
433
434 fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
435 let (tx, rx) = oneshot::channel();
436 self.db_update_tx
437 .try_send(DbOperation::FindOrCreateWorktree { path, sender: tx })
438 .unwrap();
439 async move { rx.await? }
440 }
441
442 fn get_file_mtimes(
443 &self,
444 worktree_id: i64,
445 ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
446 let (tx, rx) = oneshot::channel();
447 self.db_update_tx
448 .try_send(DbOperation::FileMTimes {
449 worktree_id,
450 sender: tx,
451 })
452 .unwrap();
453 async move { rx.await? }
454 }
455
456 fn add_project(
457 &mut self,
458 project: ModelHandle<Project>,
459 cx: &mut ModelContext<Self>,
460 ) -> Task<Result<()>> {
461 let worktree_scans_complete = project
462 .read(cx)
463 .worktrees(cx)
464 .map(|worktree| {
465 let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
466 async move {
467 scan_complete.await;
468 }
469 })
470 .collect::<Vec<_>>();
471 let worktree_db_ids = project
472 .read(cx)
473 .worktrees(cx)
474 .map(|worktree| {
475 self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
476 })
477 .collect::<Vec<_>>();
478
479 let fs = self.fs.clone();
480 let language_registry = self.language_registry.clone();
481 let database_url = self.database_url.clone();
482 let db_update_tx = self.db_update_tx.clone();
483 let parsing_files_tx = self.parsing_files_tx.clone();
484
485 cx.spawn(|this, mut cx| async move {
486 futures::future::join_all(worktree_scans_complete).await;
487
488 let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;
489
490 if let Some(db_directory) = database_url.parent() {
491 fs.create_dir(db_directory).await.log_err();
492 }
493
494 let worktrees = project.read_with(&cx, |project, cx| {
495 project
496 .worktrees(cx)
497 .map(|worktree| worktree.read(cx).snapshot())
498 .collect::<Vec<_>>()
499 });
500
501 let mut worktree_file_times = HashMap::new();
502 let mut db_ids_by_worktree_id = HashMap::new();
503 for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
504 let db_id = db_id?;
505 db_ids_by_worktree_id.insert(worktree.id(), db_id);
506 worktree_file_times.insert(
507 worktree.id(),
508 this.read_with(&cx, |this, _| this.get_file_mtimes(db_id))
509 .await?,
510 );
511 }
512
513 cx.background()
514 .spawn({
515 let db_ids_by_worktree_id = db_ids_by_worktree_id.clone();
516 let db_update_tx = db_update_tx.clone();
517 let language_registry = language_registry.clone();
518 let parsing_files_tx = parsing_files_tx.clone();
519 async move {
520 let t0 = Instant::now();
521 for worktree in worktrees.into_iter() {
522 let mut file_mtimes =
523 worktree_file_times.remove(&worktree.id()).unwrap();
524 for file in worktree.files(false, 0) {
525 let absolute_path = worktree.absolutize(&file.path);
526
527 if let Ok(language) = language_registry
528 .language_for_file(&absolute_path, None)
529 .await
530 {
531 if language
532 .grammar()
533 .and_then(|grammar| grammar.embedding_config.as_ref())
534 .is_none()
535 {
536 continue;
537 }
538
539 let path_buf = file.path.to_path_buf();
540 let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
541 let already_stored = stored_mtime
542 .map_or(false, |existing_mtime| {
543 existing_mtime == file.mtime
544 });
545
546 if !already_stored {
547 parsing_files_tx
548 .try_send(PendingFile {
549 worktree_db_id: db_ids_by_worktree_id
550 [&worktree.id()],
551 relative_path: path_buf,
552 absolute_path,
553 language,
554 modified_time: file.mtime,
555 })
556 .unwrap();
557 }
558 }
559 }
560 for file in file_mtimes.keys() {
561 db_update_tx
562 .try_send(DbOperation::Delete {
563 worktree_id: db_ids_by_worktree_id[&worktree.id()],
564 path: file.to_owned(),
565 })
566 .unwrap();
567 }
568 }
569 log::info!(
570 "Parsing Worktree Completed in {:?}",
571 t0.elapsed().as_millis()
572 );
573 }
574 })
575 .detach();
576
577 // let mut pending_files: Vec<(PathBuf, ((i64, PathBuf, Arc<Language>, SystemTime), SystemTime))> = vec![];
578 this.update(&mut cx, |this, cx| {
579 // The below is managing for updated on save
580 // Currently each time a file is saved, this code is run, and for all the files that were changed, if the current time is
581 // greater than the previous embedded time by the REINDEXING_DELAY variable, we will send the file off to be indexed.
582 let _subscription = cx.subscribe(&project, |this, project, event, cx| {
583 if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event {
584 this.project_entries_changed(project, changes.clone(), cx, worktree_id);
585 }
586 });
587
588 this.projects.insert(
589 project.downgrade(),
590 ProjectState {
591 pending_files: HashMap::new(),
592 worktree_db_ids: db_ids_by_worktree_id.into_iter().collect(),
593 _subscription,
594 },
595 );
596 });
597
598 anyhow::Ok(())
599 })
600 }
601
602 pub fn search(
603 &mut self,
604 project: ModelHandle<Project>,
605 phrase: String,
606 limit: usize,
607 cx: &mut ModelContext<Self>,
608 ) -> Task<Result<Vec<SearchResult>>> {
609 let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
610 state
611 } else {
612 return Task::ready(Err(anyhow!("project not added")));
613 };
614
615 let worktree_db_ids = project
616 .read(cx)
617 .worktrees(cx)
618 .filter_map(|worktree| {
619 let worktree_id = worktree.read(cx).id();
620 project_state.db_id_for_worktree_id(worktree_id)
621 })
622 .collect::<Vec<_>>();
623
624 let embedding_provider = self.embedding_provider.clone();
625 let database_url = self.database_url.clone();
626 cx.spawn(|this, cx| async move {
627 let documents = cx
628 .background()
629 .spawn(async move {
630 let database = VectorDatabase::new(database_url.to_string_lossy().into())?;
631
632 let phrase_embedding = embedding_provider
633 .embed_batch(vec![&phrase])
634 .await?
635 .into_iter()
636 .next()
637 .unwrap();
638
639 database.top_k_search(&worktree_db_ids, &phrase_embedding, limit)
640 })
641 .await?;
642
643 this.read_with(&cx, |this, _| {
644 let project_state = if let Some(state) = this.projects.get(&project.downgrade()) {
645 state
646 } else {
647 return Err(anyhow!("project not added"));
648 };
649
650 Ok(documents
651 .into_iter()
652 .filter_map(|(worktree_db_id, file_path, offset, name)| {
653 let worktree_id = project_state.worktree_id_for_db_id(worktree_db_id)?;
654 Some(SearchResult {
655 worktree_id,
656 name,
657 offset,
658 file_path,
659 })
660 })
661 .collect())
662 })
663 })
664 }
665
666 fn project_entries_changed(
667 &mut self,
668 project: ModelHandle<Project>,
669 changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
670 cx: &mut ModelContext<'_, VectorStore>,
671 worktree_id: &WorktreeId,
672 ) -> Option<()> {
673 let reindexing_delay = settings::get::<VectorStoreSettings>(cx).reindexing_delay_seconds;
674
675 let worktree = project
676 .read(cx)
677 .worktree_for_id(worktree_id.clone(), cx)?
678 .read(cx)
679 .snapshot();
680
681 let worktree_db_id = self
682 .projects
683 .get(&project.downgrade())?
684 .db_id_for_worktree_id(worktree.id())?;
685 let file_mtimes = self.get_file_mtimes(worktree_db_id);
686
687 let language_registry = self.language_registry.clone();
688
689 cx.spawn(|this, mut cx| async move {
690 let file_mtimes = file_mtimes.await.log_err()?;
691
692 for change in changes.into_iter() {
693 let change_path = change.0.clone();
694 let absolute_path = worktree.absolutize(&change_path);
695
696 // Skip if git ignored or symlink
697 if let Some(entry) = worktree.entry_for_id(change.1) {
698 if entry.is_ignored || entry.is_symlink || entry.is_external {
699 continue;
700 }
701 }
702
703 match change.2 {
704 PathChange::Removed => this.update(&mut cx, |this, _| {
705 this.db_update_tx
706 .try_send(DbOperation::Delete {
707 worktree_id: worktree_db_id,
708 path: absolute_path,
709 })
710 .unwrap();
711 }),
712 _ => {
713 if let Ok(language) = language_registry
714 .language_for_file(&change_path.to_path_buf(), None)
715 .await
716 {
717 if language
718 .grammar()
719 .and_then(|grammar| grammar.embedding_config.as_ref())
720 .is_none()
721 {
722 continue;
723 }
724
725 let modified_time =
726 change_path.metadata().log_err()?.modified().log_err()?;
727
728 let existing_time = file_mtimes.get(&change_path.to_path_buf());
729 let already_stored = existing_time
730 .map_or(false, |existing_time| &modified_time != existing_time);
731
732 if !already_stored {
733 this.update(&mut cx, |this, _| {
734 let reindex_time = modified_time
735 + Duration::from_secs(reindexing_delay as u64);
736
737 let project_state =
738 this.projects.get_mut(&project.downgrade())?;
739 project_state.update_pending_files(
740 PendingFile {
741 relative_path: change_path.to_path_buf(),
742 absolute_path,
743 modified_time,
744 worktree_db_id,
745 language: language.clone(),
746 },
747 reindex_time,
748 );
749
750 for file in project_state.get_outstanding_files() {
751 this.parsing_files_tx.try_send(file).unwrap();
752 }
753 Some(())
754 });
755 }
756 }
757 }
758 }
759 }
760
761 Some(())
762 })
763 .detach();
764
765 Some(())
766 }
767}
768
769impl Entity for VectorStore {
770 type Event = ();
771}