1mod db;
2mod embedding;
3mod modal;
4mod parsing;
5mod vector_store_settings;
6
7#[cfg(test)]
8mod vector_store_tests;
9
10use crate::vector_store_settings::VectorStoreSettings;
11use anyhow::{anyhow, Result};
12use db::VectorDatabase;
13use embedding::{EmbeddingProvider, OpenAIEmbeddings};
14use futures::{channel::oneshot, Future};
15use gpui::{
16 AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, ViewContext,
17 WeakModelHandle,
18};
19use language::{Language, LanguageRegistry};
20use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
21use parsing::{CodeContextRetriever, ParsedFile};
22use project::{Fs, PathChange, Project, ProjectEntryId, WorktreeId};
23use smol::channel;
24use std::{
25 collections::HashMap,
26 path::{Path, PathBuf},
27 sync::Arc,
28 time::{Duration, Instant, SystemTime},
29};
30use tree_sitter::{Parser, QueryCursor};
31use util::{
32 channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
33 http::HttpClient,
34 paths::EMBEDDINGS_DIR,
35 ResultExt,
36};
37use workspace::{Workspace, WorkspaceCreated};
38
/// NOTE(review): not referenced in this file; presumably a version marker for the stored
/// embeddings/index format (e.g. used by the `db` module) — confirm against its users.
const VECTOR_STORE_VERSION: usize = 0;
40
41pub fn init(
42 fs: Arc<dyn Fs>,
43 http_client: Arc<dyn HttpClient>,
44 language_registry: Arc<LanguageRegistry>,
45 cx: &mut AppContext,
46) {
47 settings::register::<VectorStoreSettings>(cx);
48
49 let db_file_path = EMBEDDINGS_DIR
50 .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
51 .join("embeddings_db");
52
53 SemanticSearch::init(cx);
54 cx.add_action(
55 |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
56 if cx.has_global::<ModelHandle<VectorStore>>() {
57 let vector_store = cx.global::<ModelHandle<VectorStore>>().clone();
58 workspace.toggle_modal(cx, |workspace, cx| {
59 let project = workspace.project().clone();
60 let workspace = cx.weak_handle();
61 cx.add_view(|cx| {
62 SemanticSearch::new(
63 SemanticSearchDelegate::new(workspace, project, vector_store),
64 cx,
65 )
66 })
67 });
68 }
69 },
70 );
71
72 if *RELEASE_CHANNEL == ReleaseChannel::Stable
73 || !settings::get::<VectorStoreSettings>(cx).enable
74 {
75 return;
76 }
77
78 cx.spawn(move |mut cx| async move {
79 let vector_store = VectorStore::new(
80 fs,
81 db_file_path,
82 // Arc::new(embedding::DummyEmbeddings {}),
83 Arc::new(OpenAIEmbeddings {
84 client: http_client,
85 executor: cx.background(),
86 }),
87 language_registry,
88 cx.clone(),
89 )
90 .await?;
91
92 cx.update(|cx| {
93 cx.set_global(vector_store.clone());
94 cx.subscribe_global::<WorkspaceCreated, _>({
95 let vector_store = vector_store.clone();
96 move |event, cx| {
97 let workspace = &event.0;
98 if let Some(workspace) = workspace.upgrade(cx) {
99 let project = workspace.read(cx).project().clone();
100 if project.read(cx).is_local() {
101 vector_store.update(cx, |store, cx| {
102 store.add_project(project, cx).detach();
103 });
104 }
105 }
106 }
107 })
108 .detach();
109 });
110
111 anyhow::Ok(())
112 })
113 .detach();
114}
115
/// Indexes the files of local projects into an embeddings database and answers semantic
/// searches against it.
///
/// NOTE: fields are dropped in declaration order, which includes the background tasks below;
/// keep that in mind before reordering.
pub struct VectorStore {
    fs: Arc<dyn Fs>,
    /// Path of the embeddings database file.
    database_url: Arc<PathBuf>,
    /// Turns text spans (and search phrases) into embeddings.
    embedding_provider: Arc<dyn EmbeddingProvider>,
    /// Used to detect each file's language and whether it has an embedding config.
    language_registry: Arc<LanguageRegistry>,
    /// Requests for the task that owns the database connection.
    db_update_tx: channel::Sender<DbOperation>,
    /// Entry point of the indexing pipeline: files waiting to be parsed.
    parsing_files_tx: channel::Sender<PendingFile>,
    // Background tasks spawned in `VectorStore::new`; held only so they live as long as the
    // store. NOTE(review): presumably dropping a gpui `Task` cancels it — confirm.
    _db_update_task: Task<()>,
    _embed_batch_task: Task<()>,
    _batch_files_task: Task<()>,
    _parsing_files_tasks: Vec<Task<()>>,
    /// Per-project indexing state, keyed by a weak project handle.
    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
}
129
/// Indexing state the vector store keeps for each added project.
struct ProjectState {
    /// (worktree id, database id) pairs for the project's worktrees when it was added.
    worktree_db_ids: Vec<(WorktreeId, i64)>,
    /// Changed files awaiting re-indexing, keyed by their worktree-relative path, together with
    /// the time after which they may be sent to the parsing pipeline.
    pending_files: HashMap<PathBuf, (PendingFile, SystemTime)>,
    /// Keeps the subscription to the project's worktree-change events alive.
    _subscription: gpui::Subscription,
}
135
136impl ProjectState {
137 fn db_id_for_worktree_id(&self, id: WorktreeId) -> Option<i64> {
138 self.worktree_db_ids
139 .iter()
140 .find_map(|(worktree_id, db_id)| {
141 if *worktree_id == id {
142 Some(*db_id)
143 } else {
144 None
145 }
146 })
147 }
148
149 fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
150 self.worktree_db_ids
151 .iter()
152 .find_map(|(worktree_id, db_id)| {
153 if *db_id == id {
154 Some(*worktree_id)
155 } else {
156 None
157 }
158 })
159 }
160
161 fn update_pending_files(&mut self, pending_file: PendingFile, indexing_time: SystemTime) {
162 // If Pending File Already Exists, Replace it with the new one
163 // but keep the old indexing time
164 if let Some(old_file) = self
165 .pending_files
166 .remove(&pending_file.relative_path.clone())
167 {
168 self.pending_files.insert(
169 pending_file.relative_path.clone(),
170 (pending_file, old_file.1),
171 );
172 } else {
173 self.pending_files.insert(
174 pending_file.relative_path.clone(),
175 (pending_file, indexing_time),
176 );
177 };
178 }
179
180 fn get_outstanding_files(&mut self) -> Vec<PendingFile> {
181 let mut outstanding_files = vec![];
182 let mut remove_keys = vec![];
183 for key in self.pending_files.keys().into_iter() {
184 if let Some(pending_details) = self.pending_files.get(key) {
185 let (pending_file, index_time) = pending_details;
186 if index_time <= &SystemTime::now() {
187 outstanding_files.push(pending_file.clone());
188 remove_keys.push(key.clone());
189 }
190 }
191 }
192
193 for key in remove_keys.iter() {
194 self.pending_files.remove(key);
195 }
196
197 return outstanding_files;
198 }
199}
200
/// A file queued for parsing and embedding.
#[derive(Clone, Debug)]
pub struct PendingFile {
    /// Database id of the worktree that contains the file.
    worktree_db_id: i64,
    /// Path relative to the worktree root (the form used as the key of `get_file_mtimes`).
    relative_path: PathBuf,
    /// Absolute path of the file on disk.
    absolute_path: PathBuf,
    /// Language detected for the file.
    language: Arc<Language>,
    /// Modification time observed when the file was queued.
    modified_time: SystemTime,
}
209
/// A single semantic-search match returned by `VectorStore::search`.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Worktree containing the matched file.
    pub worktree_id: WorktreeId,
    /// Name of the matched document, as stored in the database.
    pub name: String,
    /// Offset of the matched document within its file.
    /// NOTE(review): presumably a byte offset — confirm against the parser.
    pub offset: usize,
    /// Path of the matched file. NOTE(review): presumably worktree-relative — confirm.
    pub file_path: PathBuf,
}
217
/// A request for the single background task that owns the database connection
/// (spawned in `VectorStore::new`). Every `worktree_id` here is a database id.
enum DbOperation {
    /// Store a parsed file together with its embedded documents.
    InsertFile {
        worktree_id: i64,
        indexed_file: ParsedFile,
    },
    /// Remove a file's entry from the index.
    Delete {
        worktree_id: i64,
        path: PathBuf,
    },
    /// Look up (or create) the database id of the worktree rooted at `path`; the result is
    /// sent back on `sender`.
    FindOrCreateWorktree {
        path: PathBuf,
        sender: oneshot::Sender<Result<i64>>,
    },
    /// Fetch the stored modification time of every indexed file in a worktree; the result is
    /// sent back on `sender`.
    FileMTimes {
        worktree_id: i64,
        sender: oneshot::Sender<Result<HashMap<PathBuf, SystemTime>>>,
    },
}
236
/// A message for the task that groups parsed files into embedding batches.
enum EmbeddingJob {
    /// Add a parsed file (and the text spans of its documents) to the current batch.
    Enqueue {
        worktree_id: i64,
        parsed_file: ParsedFile,
        document_spans: Vec<String>,
    },
    /// Send the current batch to be embedded even if it has not reached the batch size.
    Flush,
}
245
246impl VectorStore {
247 async fn new(
248 fs: Arc<dyn Fs>,
249 database_url: PathBuf,
250 embedding_provider: Arc<dyn EmbeddingProvider>,
251 language_registry: Arc<LanguageRegistry>,
252 mut cx: AsyncAppContext,
253 ) -> Result<ModelHandle<Self>> {
254 let database_url = Arc::new(database_url);
255
256 let db = cx
257 .background()
258 .spawn({
259 let fs = fs.clone();
260 let database_url = database_url.clone();
261 async move {
262 if let Some(db_directory) = database_url.parent() {
263 fs.create_dir(db_directory).await.log_err();
264 }
265
266 let db = VectorDatabase::new(database_url.to_string_lossy().to_string())?;
267 anyhow::Ok(db)
268 }
269 })
270 .await?;
271
272 Ok(cx.add_model(|cx| {
273 // paths_tx -> embeddings_tx -> db_update_tx
274
275 //db_update_tx/rx: Updating Database
276 let (db_update_tx, db_update_rx) = channel::unbounded();
277 let _db_update_task = cx.background().spawn(async move {
278 while let Ok(job) = db_update_rx.recv().await {
279 match job {
280 DbOperation::InsertFile {
281 worktree_id,
282 indexed_file,
283 } => {
284 db.insert_file(worktree_id, indexed_file).log_err();
285 }
286 DbOperation::Delete { worktree_id, path } => {
287 db.delete_file(worktree_id, path).log_err();
288 }
289 DbOperation::FindOrCreateWorktree { path, sender } => {
290 let id = db.find_or_create_worktree(&path);
291 sender.send(id).ok();
292 }
293 DbOperation::FileMTimes {
294 worktree_id: worktree_db_id,
295 sender,
296 } => {
297 let file_mtimes = db.get_file_mtimes(worktree_db_id);
298 sender.send(file_mtimes).ok();
299 }
300 }
301 }
302 });
303
304 // embed_tx/rx: Embed Batch and Send to Database
305 let (embed_batch_tx, embed_batch_rx) =
306 channel::unbounded::<Vec<(i64, ParsedFile, Vec<String>)>>();
307 let _embed_batch_task = cx.background().spawn({
308 let db_update_tx = db_update_tx.clone();
309 let embedding_provider = embedding_provider.clone();
310 async move {
311 while let Ok(mut embeddings_queue) = embed_batch_rx.recv().await {
312 // Construct Batch
313 let mut document_spans = vec![];
314 for (_, _, document_span) in embeddings_queue.iter() {
315 document_spans.extend(document_span.iter().map(|s| s.as_str()));
316 }
317
318 if let Ok(embeddings) = embedding_provider.embed_batch(document_spans).await
319 {
320 let mut i = 0;
321 let mut j = 0;
322
323 for embedding in embeddings.iter() {
324 while embeddings_queue[i].1.documents.len() == j {
325 i += 1;
326 j = 0;
327 }
328
329 embeddings_queue[i].1.documents[j].embedding = embedding.to_owned();
330 j += 1;
331 }
332
333 for (worktree_id, indexed_file, _) in embeddings_queue.into_iter() {
334 for document in indexed_file.documents.iter() {
335 // TODO: Update this so it doesn't panic
336 assert!(
337 document.embedding.len() > 0,
338 "Document Embedding Not Complete"
339 );
340 }
341
342 db_update_tx
343 .send(DbOperation::InsertFile {
344 worktree_id,
345 indexed_file,
346 })
347 .await
348 .unwrap();
349 }
350 }
351 }
352 }
353 });
354
355 // batch_tx/rx: Batch Files to Send for Embeddings
356 let batch_size = settings::get::<VectorStoreSettings>(cx).embedding_batch_size;
357 let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
358 let _batch_files_task = cx.background().spawn(async move {
359 let mut queue_len = 0;
360 let mut embeddings_queue = vec![];
361
362 while let Ok(job) = batch_files_rx.recv().await {
363 let should_flush = match job {
364 EmbeddingJob::Enqueue {
365 document_spans,
366 worktree_id,
367 parsed_file,
368 } => {
369 queue_len += &document_spans.len();
370 embeddings_queue.push((worktree_id, parsed_file, document_spans));
371 queue_len >= batch_size
372 }
373 EmbeddingJob::Flush => true,
374 };
375
376 if should_flush {
377 embed_batch_tx.try_send(embeddings_queue).unwrap();
378 embeddings_queue = vec![];
379 queue_len = 0;
380 }
381 }
382 });
383
384 // parsing_files_tx/rx: Parsing Files to Embeddable Documents
385 let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();
386
387 let mut _parsing_files_tasks = Vec::new();
388 // for _ in 0..cx.background().num_cpus() {
389 for _ in 0..1 {
390 let fs = fs.clone();
391 let parsing_files_rx = parsing_files_rx.clone();
392 let batch_files_tx = batch_files_tx.clone();
393 _parsing_files_tasks.push(cx.background().spawn(async move {
394 let parser = Parser::new();
395 let cursor = QueryCursor::new();
396 let mut retriever = CodeContextRetriever { parser, cursor, fs };
397 while let Ok(pending_file) = parsing_files_rx.recv().await {
398 if let Some((indexed_file, document_spans)) =
399 retriever.parse_file(pending_file.clone()).await.log_err()
400 {
401 batch_files_tx
402 .try_send(EmbeddingJob::Enqueue {
403 worktree_id: pending_file.worktree_db_id,
404 parsed_file: indexed_file,
405 document_spans,
406 })
407 .unwrap();
408 }
409
410 if parsing_files_rx.len() == 0 {
411 batch_files_tx.try_send(EmbeddingJob::Flush).unwrap();
412 }
413 }
414 }));
415 }
416
417 Self {
418 fs,
419 database_url,
420 embedding_provider,
421 language_registry,
422 db_update_tx,
423 parsing_files_tx,
424 _db_update_task,
425 _embed_batch_task,
426 _batch_files_task,
427 _parsing_files_tasks,
428 projects: HashMap::new(),
429 }
430 }))
431 }
432
433 fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
434 let (tx, rx) = oneshot::channel();
435 self.db_update_tx
436 .try_send(DbOperation::FindOrCreateWorktree { path, sender: tx })
437 .unwrap();
438 async move { rx.await? }
439 }
440
441 fn get_file_mtimes(
442 &self,
443 worktree_id: i64,
444 ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
445 let (tx, rx) = oneshot::channel();
446 self.db_update_tx
447 .try_send(DbOperation::FileMTimes {
448 worktree_id,
449 sender: tx,
450 })
451 .unwrap();
452 async move { rx.await? }
453 }
454
455 fn add_project(
456 &mut self,
457 project: ModelHandle<Project>,
458 cx: &mut ModelContext<Self>,
459 ) -> Task<Result<()>> {
460 let worktree_scans_complete = project
461 .read(cx)
462 .worktrees(cx)
463 .map(|worktree| {
464 let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
465 async move {
466 scan_complete.await;
467 }
468 })
469 .collect::<Vec<_>>();
470 let worktree_db_ids = project
471 .read(cx)
472 .worktrees(cx)
473 .map(|worktree| {
474 self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
475 })
476 .collect::<Vec<_>>();
477
478 let fs = self.fs.clone();
479 let language_registry = self.language_registry.clone();
480 let database_url = self.database_url.clone();
481 let db_update_tx = self.db_update_tx.clone();
482 let parsing_files_tx = self.parsing_files_tx.clone();
483
484 cx.spawn(|this, mut cx| async move {
485 futures::future::join_all(worktree_scans_complete).await;
486
487 let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;
488
489 if let Some(db_directory) = database_url.parent() {
490 fs.create_dir(db_directory).await.log_err();
491 }
492
493 let worktrees = project.read_with(&cx, |project, cx| {
494 project
495 .worktrees(cx)
496 .map(|worktree| worktree.read(cx).snapshot())
497 .collect::<Vec<_>>()
498 });
499
500 let mut worktree_file_times = HashMap::new();
501 let mut db_ids_by_worktree_id = HashMap::new();
502 for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
503 let db_id = db_id?;
504 db_ids_by_worktree_id.insert(worktree.id(), db_id);
505 worktree_file_times.insert(
506 worktree.id(),
507 this.read_with(&cx, |this, _| this.get_file_mtimes(db_id))
508 .await?,
509 );
510 }
511
512 cx.background()
513 .spawn({
514 let db_ids_by_worktree_id = db_ids_by_worktree_id.clone();
515 let db_update_tx = db_update_tx.clone();
516 let language_registry = language_registry.clone();
517 let parsing_files_tx = parsing_files_tx.clone();
518 async move {
519 let t0 = Instant::now();
520 for worktree in worktrees.into_iter() {
521 let mut file_mtimes =
522 worktree_file_times.remove(&worktree.id()).unwrap();
523 for file in worktree.files(false, 0) {
524 let absolute_path = worktree.absolutize(&file.path);
525
526 if let Ok(language) = language_registry
527 .language_for_file(&absolute_path, None)
528 .await
529 {
530 if language
531 .grammar()
532 .and_then(|grammar| grammar.embedding_config.as_ref())
533 .is_none()
534 {
535 continue;
536 }
537
538 let path_buf = file.path.to_path_buf();
539 let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
540 let already_stored = stored_mtime
541 .map_or(false, |existing_mtime| {
542 existing_mtime == file.mtime
543 });
544
545 if !already_stored {
546 parsing_files_tx
547 .try_send(PendingFile {
548 worktree_db_id: db_ids_by_worktree_id
549 [&worktree.id()],
550 relative_path: path_buf,
551 absolute_path,
552 language,
553 modified_time: file.mtime,
554 })
555 .unwrap();
556 }
557 }
558 }
559 for file in file_mtimes.keys() {
560 db_update_tx
561 .try_send(DbOperation::Delete {
562 worktree_id: db_ids_by_worktree_id[&worktree.id()],
563 path: file.to_owned(),
564 })
565 .unwrap();
566 }
567 }
568 log::info!(
569 "Parsing Worktree Completed in {:?}",
570 t0.elapsed().as_millis()
571 );
572 }
573 })
574 .detach();
575
576 // let mut pending_files: Vec<(PathBuf, ((i64, PathBuf, Arc<Language>, SystemTime), SystemTime))> = vec![];
577 this.update(&mut cx, |this, cx| {
578 // The below is managing for updated on save
579 // Currently each time a file is saved, this code is run, and for all the files that were changed, if the current time is
580 // greater than the previous embedded time by the REINDEXING_DELAY variable, we will send the file off to be indexed.
581 let _subscription = cx.subscribe(&project, |this, project, event, cx| {
582 if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event {
583 this.project_entries_changed(project, changes.clone(), cx, worktree_id);
584 }
585 });
586
587 this.projects.insert(
588 project.downgrade(),
589 ProjectState {
590 pending_files: HashMap::new(),
591 worktree_db_ids: db_ids_by_worktree_id.into_iter().collect(),
592 _subscription,
593 },
594 );
595 });
596
597 anyhow::Ok(())
598 })
599 }
600
601 pub fn search(
602 &mut self,
603 project: ModelHandle<Project>,
604 phrase: String,
605 limit: usize,
606 cx: &mut ModelContext<Self>,
607 ) -> Task<Result<Vec<SearchResult>>> {
608 let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
609 state
610 } else {
611 return Task::ready(Err(anyhow!("project not added")));
612 };
613
614 let worktree_db_ids = project
615 .read(cx)
616 .worktrees(cx)
617 .filter_map(|worktree| {
618 let worktree_id = worktree.read(cx).id();
619 project_state.db_id_for_worktree_id(worktree_id)
620 })
621 .collect::<Vec<_>>();
622
623 let embedding_provider = self.embedding_provider.clone();
624 let database_url = self.database_url.clone();
625 cx.spawn(|this, cx| async move {
626 let documents = cx
627 .background()
628 .spawn(async move {
629 let database = VectorDatabase::new(database_url.to_string_lossy().into())?;
630
631 let phrase_embedding = embedding_provider
632 .embed_batch(vec![&phrase])
633 .await?
634 .into_iter()
635 .next()
636 .unwrap();
637
638 database.top_k_search(&worktree_db_ids, &phrase_embedding, limit)
639 })
640 .await?;
641
642 this.read_with(&cx, |this, _| {
643 let project_state = if let Some(state) = this.projects.get(&project.downgrade()) {
644 state
645 } else {
646 return Err(anyhow!("project not added"));
647 };
648
649 Ok(documents
650 .into_iter()
651 .filter_map(|(worktree_db_id, file_path, offset, name)| {
652 let worktree_id = project_state.worktree_id_for_db_id(worktree_db_id)?;
653 Some(SearchResult {
654 worktree_id,
655 name,
656 offset,
657 file_path,
658 })
659 })
660 .collect())
661 })
662 })
663 }
664
665 fn project_entries_changed(
666 &mut self,
667 project: ModelHandle<Project>,
668 changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
669 cx: &mut ModelContext<'_, VectorStore>,
670 worktree_id: &WorktreeId,
671 ) -> Option<()> {
672 let reindexing_delay = settings::get::<VectorStoreSettings>(cx).reindexing_delay_seconds;
673
674 let worktree = project
675 .read(cx)
676 .worktree_for_id(worktree_id.clone(), cx)?
677 .read(cx)
678 .snapshot();
679
680 let worktree_db_id = self
681 .projects
682 .get(&project.downgrade())?
683 .db_id_for_worktree_id(worktree.id())?;
684 let file_mtimes = self.get_file_mtimes(worktree_db_id);
685
686 let language_registry = self.language_registry.clone();
687
688 cx.spawn(|this, mut cx| async move {
689 let file_mtimes = file_mtimes.await.log_err()?;
690
691 for change in changes.into_iter() {
692 let change_path = change.0.clone();
693 let absolute_path = worktree.absolutize(&change_path);
694
695 // Skip if git ignored or symlink
696 if let Some(entry) = worktree.entry_for_id(change.1) {
697 if entry.is_ignored || entry.is_symlink || entry.is_external {
698 continue;
699 }
700 }
701
702 match change.2 {
703 PathChange::Removed => this.update(&mut cx, |this, _| {
704 this.db_update_tx
705 .try_send(DbOperation::Delete {
706 worktree_id: worktree_db_id,
707 path: absolute_path,
708 })
709 .unwrap();
710 }),
711 _ => {
712 if let Ok(language) = language_registry
713 .language_for_file(&change_path.to_path_buf(), None)
714 .await
715 {
716 if language
717 .grammar()
718 .and_then(|grammar| grammar.embedding_config.as_ref())
719 .is_none()
720 {
721 continue;
722 }
723
724 let modified_time =
725 change_path.metadata().log_err()?.modified().log_err()?;
726
727 let existing_time = file_mtimes.get(&change_path.to_path_buf());
728 let already_stored = existing_time
729 .map_or(false, |existing_time| &modified_time != existing_time);
730
731 if !already_stored {
732 this.update(&mut cx, |this, _| {
733 let reindex_time = modified_time
734 + Duration::from_secs(reindexing_delay as u64);
735
736 let project_state =
737 this.projects.get_mut(&project.downgrade())?;
738 project_state.update_pending_files(
739 PendingFile {
740 relative_path: change_path.to_path_buf(),
741 absolute_path,
742 modified_time,
743 worktree_db_id,
744 language: language.clone(),
745 },
746 reindex_time,
747 );
748
749 for file in project_state.get_outstanding_files() {
750 this.parsing_files_tx.try_send(file).unwrap();
751 }
752 Some(())
753 });
754 }
755 }
756 }
757 }
758 }
759
760 Some(())
761 })
762 .detach();
763
764 Some(())
765 }
766}
767
/// Makes `VectorStore` a gpui model; it declares no events (`Event = ()`).
impl Entity for VectorStore {
    type Event = ();
}