mod db;
mod embedding;
mod modal;
mod parsing;

#[cfg(test)]
mod vector_store_tests;

use anyhow::{anyhow, Result};
use db::VectorDatabase;
use embedding::{EmbeddingProvider, OpenAIEmbeddings};
use futures::{channel::oneshot, Future};
use gpui::{
    AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, ViewContext,
    WeakModelHandle,
};
use language::{Language, LanguageRegistry};
use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
use parsing::{CodeContextRetriever, ParsedFile};
use project::{Fs, PathChange, Project, ProjectEntryId, WorktreeId};
use smol::channel;
use std::{
    collections::HashMap,
    path::{Path, PathBuf},
    sync::Arc,
    time::{Duration, Instant, SystemTime},
};
use tree_sitter::{Parser, QueryCursor};
use util::{
    channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
    http::HttpClient,
    paths::EMBEDDINGS_DIR,
    ResultExt,
};
use workspace::{Workspace, WorkspaceCreated};

const REINDEXING_DELAY_SECONDS: u64 = 3;
const EMBEDDINGS_BATCH_SIZE: usize = 150;

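/// Initializes semantic search: opens (or creates) the embeddings database,
/// indexes the project of every newly created local workspace, and registers
/// the `Toggle` action that opens the semantic search modal. Disabled on the
/// stable release channel.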
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    if *RELEASE_CHANNEL == ReleaseChannel::Stable {
        return;
    }

    let db_file_path = EMBEDDINGS_DIR
        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
        .join("embeddings_db");

    cx.spawn(move |mut cx| async move {
        let vector_store = VectorStore::new(
            fs,
            db_file_path,
            // Arc::new(embedding::DummyEmbeddings {}),
            Arc::new(OpenAIEmbeddings {
                client: http_client,
                executor: cx.background(),
            }),
            language_registry,
            cx.clone(),
        )
        .await?;

        cx.update(|cx| {
            cx.subscribe_global::<WorkspaceCreated, _>({
                let vector_store = vector_store.clone();
                move |event, cx| {
                    let workspace = &event.0;
                    if let Some(workspace) = workspace.upgrade(cx) {
                        let project = workspace.read(cx).project().clone();
                        if project.read(cx).is_local() {
                            vector_store.update(cx, |store, cx| {
                                store.add_project(project, cx).detach();
                            });
                        }
                    }
                }
            })
            .detach();

            cx.add_action({
                move |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
                    let vector_store = vector_store.clone();
                    workspace.toggle_modal(cx, |workspace, cx| {
                        let project = workspace.project().clone();
                        let workspace = cx.weak_handle();
                        cx.add_view(|cx| {
                            SemanticSearch::new(
                                SemanticSearchDelegate::new(workspace, project, vector_store),
                                cx,
                            )
                        })
                    })
                }
            });

            SemanticSearch::init(cx);
        });

        anyhow::Ok(())
    })
    .detach();
}

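/// Coordinates the indexing pipeline. Files flow through a chain of
/// background tasks (parse -> batch -> embed -> write to the database),
/// connected by the channels held here, while per-project state tracks
/// which worktrees have been indexed.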
pub struct VectorStore {
    fs: Arc<dyn Fs>,
    database_url: Arc<PathBuf>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    db_update_tx: channel::Sender<DbOperation>,
    parsing_files_tx: channel::Sender<PendingFile>,
    _db_update_task: Task<()>,
    _embed_batch_task: Task<()>,
    _batch_files_task: Task<()>,
    _parsing_files_tasks: Vec<Task<()>>,
    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
}

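/// Per-project indexing state: the database ids of the project's worktrees
/// and the set of changed files waiting out the reindexing delay.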
struct ProjectState {
    worktree_db_ids: Vec<(WorktreeId, i64)>,
    pending_files: HashMap<PathBuf, (PendingFile, SystemTime)>,
    _subscription: gpui::Subscription,
}

impl ProjectState {
    fn db_id_for_worktree_id(&self, id: WorktreeId) -> Option<i64> {
        self.worktree_db_ids
            .iter()
            .find_map(|(worktree_id, db_id)| {
                if *worktree_id == id {
                    Some(*db_id)
                } else {
                    None
                }
            })
    }

    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
        self.worktree_db_ids
            .iter()
            .find_map(|(worktree_id, db_id)| {
                if *db_id == id {
                    Some(*worktree_id)
                } else {
                    None
                }
            })
    }

    fn update_pending_files(&mut self, pending_file: PendingFile, indexing_time: SystemTime) {
        // If the file is already pending, replace it with the new one, but
        // keep the previously scheduled indexing time.
        if let Some(old_file) = self.pending_files.remove(&pending_file.relative_path) {
            self.pending_files.insert(
                pending_file.relative_path.clone(),
                (pending_file, old_file.1),
            );
        } else {
            self.pending_files.insert(
                pending_file.relative_path.clone(),
                (pending_file, indexing_time),
            );
        }
    }

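    /// Removes and returns every pending file whose scheduled indexing time
    /// has passed.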
    fn get_outstanding_files(&mut self) -> Vec<PendingFile> {
        let now = SystemTime::now();
        let mut outstanding_files = vec![];
        let mut remove_keys = vec![];
        for (key, (pending_file, index_time)) in self.pending_files.iter() {
            if *index_time <= now {
                outstanding_files.push(pending_file.clone());
                remove_keys.push(key.clone());
            }
        }

        for key in remove_keys.iter() {
            self.pending_files.remove(key);
        }

        outstanding_files
    }
}

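/// A file that has been queued for parsing and embedding.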
#[derive(Clone, Debug)]
pub struct PendingFile {
    worktree_db_id: i64,
    relative_path: PathBuf,
    absolute_path: PathBuf,
    language: Arc<Language>,
    modified_time: SystemTime,
}

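/// A single document matched by [`VectorStore::search`].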
#[derive(Debug, Clone)]
pub struct SearchResult {
    pub worktree_id: WorktreeId,
    pub name: String,
    pub offset: usize,
    pub file_path: PathBuf,
}

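/// Operations processed serially by a single background task, which keeps
/// all access to the underlying `VectorDatabase` in one place.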
enum DbOperation {
    InsertFile {
        worktree_id: i64,
        indexed_file: ParsedFile,
    },
    Delete {
        worktree_id: i64,
        path: PathBuf,
    },
    FindOrCreateWorktree {
        path: PathBuf,
        sender: oneshot::Sender<Result<i64>>,
    },
    FileMTimes {
        worktree_id: i64,
        sender: oneshot::Sender<Result<HashMap<PathBuf, SystemTime>>>,
    },
}

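/// Jobs for the batching task, which accumulates document spans until a
/// batch reaches `EMBEDDINGS_BATCH_SIZE` (or a flush is requested) before
/// sending it off to be embedded.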
enum EmbeddingJob {
    Enqueue {
        worktree_id: i64,
        parsed_file: ParsedFile,
        document_spans: Vec<String>,
    },
    Flush,
}

impl VectorStore {
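    /// Opens (or creates) the vector database at `database_url` and spawns
    /// the background tasks that make up the indexing pipeline.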
    async fn new(
        fs: Arc<dyn Fs>,
        database_url: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
        mut cx: AsyncAppContext,
    ) -> Result<ModelHandle<Self>> {
        let database_url = Arc::new(database_url);

        let db = cx
            .background()
            .spawn({
                let fs = fs.clone();
                let database_url = database_url.clone();
                async move {
                    if let Some(db_directory) = database_url.parent() {
                        fs.create_dir(db_directory).await.log_err();
                    }

                    let db = VectorDatabase::new(database_url.to_string_lossy().to_string())?;
                    anyhow::Ok(db)
                }
            })
            .await?;

        Ok(cx.add_model(|cx| {
            // Files move through the pipeline in stages:
            // parsing_files_tx -> batch_files_tx -> embed_batch_tx -> db_update_tx

            // db_update_tx/rx: updating the database
            let (db_update_tx, db_update_rx) = channel::unbounded();
            let _db_update_task = cx.background().spawn(async move {
                while let Ok(job) = db_update_rx.recv().await {
                    match job {
                        DbOperation::InsertFile {
                            worktree_id,
                            indexed_file,
                        } => {
                            log::info!("Inserting Data for {:?}", &indexed_file.path);
                            db.insert_file(worktree_id, indexed_file).log_err();
                        }
                        DbOperation::Delete { worktree_id, path } => {
                            db.delete_file(worktree_id, path).log_err();
                        }
                        DbOperation::FindOrCreateWorktree { path, sender } => {
                            let id = db.find_or_create_worktree(&path);
                            sender.send(id).ok();
                        }
                        DbOperation::FileMTimes {
                            worktree_id: worktree_db_id,
                            sender,
                        } => {
                            let file_mtimes = db.get_file_mtimes(worktree_db_id);
                            sender.send(file_mtimes).ok();
                        }
                    }
                }
            });

            // embed_batch_tx/rx: embed batches of document spans and send the
            // results on to the database task
            let (embed_batch_tx, embed_batch_rx) =
                channel::unbounded::<Vec<(i64, ParsedFile, Vec<String>)>>();
            let _embed_batch_task = cx.background().spawn({
                let db_update_tx = db_update_tx.clone();
                let embedding_provider = embedding_provider.clone();
                async move {
                    while let Ok(mut embeddings_queue) = embed_batch_rx.recv().await {
                        // Flatten every file's spans into a single batch request.
                        let mut document_spans = vec![];
                        for (_, _, document_span) in embeddings_queue.iter() {
                            document_spans.extend(document_span.iter().map(|s| s.as_str()));
                        }

                        if let Ok(embeddings) =
                            embedding_provider.embed_batch(document_spans).await
                        {
                            let mut i = 0;
                            let mut j = 0;

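                            // The batch was built by flattening each file's
                            // documents in order, so walk (i, j) = (file,
                            // document) forward as embeddings are consumed,
                            // assigning each one back to its source document.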
                            for embedding in embeddings.iter() {
                                while embeddings_queue[i].1.documents.len() == j {
                                    i += 1;
                                    j = 0;
                                }

                                embeddings_queue[i].1.documents[j].embedding = embedding.to_owned();
                                j += 1;
                            }

                            for (worktree_id, indexed_file, _) in embeddings_queue.into_iter() {
                                for document in indexed_file.documents.iter() {
                                    // TODO: Update this so it doesn't panic
                                    assert!(
                                        !document.embedding.is_empty(),
                                        "Document Embedding Not Complete"
                                    );
                                }

                                db_update_tx
                                    .send(DbOperation::InsertFile {
                                        worktree_id,
                                        indexed_file,
                                    })
                                    .await
                                    .unwrap();
                            }
                        }
                    }
                }
            });

            // batch_files_tx/rx: batch files to send for embeddings
            let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
            let _batch_files_task = cx.background().spawn(async move {
                let mut queue_len = 0;
                let mut embeddings_queue = vec![];

                while let Ok(job) = batch_files_rx.recv().await {
                    let should_flush = match job {
                        EmbeddingJob::Enqueue {
                            document_spans,
                            worktree_id,
                            parsed_file,
                        } => {
                            queue_len += document_spans.len();
                            embeddings_queue.push((worktree_id, parsed_file, document_spans));
                            queue_len >= EMBEDDINGS_BATCH_SIZE
                        }
                        EmbeddingJob::Flush => true,
                    };

                    if should_flush {
                        embed_batch_tx.try_send(embeddings_queue).unwrap();
                        embeddings_queue = vec![];
                        queue_len = 0;
                    }
                }
            });

            // parsing_files_tx/rx: parse files into embeddable documents
            let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();

            let mut _parsing_files_tasks = Vec::new();
            for _ in 0..cx.background().num_cpus() {
                let fs = fs.clone();
                let parsing_files_rx = parsing_files_rx.clone();
                let batch_files_tx = batch_files_tx.clone();
                _parsing_files_tasks.push(cx.background().spawn(async move {
                    let parser = Parser::new();
                    let cursor = QueryCursor::new();
                    let mut retriever = CodeContextRetriever { parser, cursor, fs };
                    while let Ok(pending_file) = parsing_files_rx.recv().await {
                        log::info!("Parsing File: {:?}", &pending_file.relative_path);

                        if let Some((indexed_file, document_spans)) =
                            retriever.parse_file(pending_file.clone()).await.log_err()
                        {
                            batch_files_tx
                                .try_send(EmbeddingJob::Enqueue {
                                    worktree_id: pending_file.worktree_db_id,
                                    parsed_file: indexed_file,
                                    document_spans,
                                })
                                .unwrap();
                        }

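                        // The parse queue has drained, so flush any partially
                        // filled batch instead of waiting for more files.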
                        if parsing_files_rx.is_empty() {
                            batch_files_tx.try_send(EmbeddingJob::Flush).unwrap();
                        }
                    }
                }));
            }

            Self {
                fs,
                database_url,
                embedding_provider,
                language_registry,
                db_update_tx,
                parsing_files_tx,
                _db_update_task,
                _embed_batch_task,
                _batch_files_task,
                _parsing_files_tasks,
                projects: HashMap::new(),
            }
        }))
    }

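    /// Resolves the database id for the worktree rooted at `path`, creating
    /// a record for it if one doesn't exist yet.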
    fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
        let (tx, rx) = oneshot::channel();
        self.db_update_tx
            .try_send(DbOperation::FindOrCreateWorktree { path, sender: tx })
            .unwrap();
        async move { rx.await? }
    }

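    /// Fetches the stored modification time of every previously indexed file
    /// in the given worktree.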
    fn get_file_mtimes(
        &self,
        worktree_id: i64,
    ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
        let (tx, rx) = oneshot::channel();
        self.db_update_tx
            .try_send(DbOperation::FileMTimes {
                worktree_id,
                sender: tx,
            })
            .unwrap();
        async move { rx.await? }
    }

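    /// Begins indexing a project: waits for its worktree scans to finish,
    /// reindexes files whose modification times have changed since they were
    /// stored, and subscribes to entry changes so future edits are reindexed
    /// after `REINDEXING_DELAY_SECONDS`.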
    fn add_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        let worktree_scans_complete = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
                async move {
                    scan_complete.await;
                }
            })
            .collect::<Vec<_>>();
        let worktree_db_ids = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
            })
            .collect::<Vec<_>>();

        let fs = self.fs.clone();
        let language_registry = self.language_registry.clone();
        let database_url = self.database_url.clone();
        let db_update_tx = self.db_update_tx.clone();
        let parsing_files_tx = self.parsing_files_tx.clone();

        cx.spawn(|this, mut cx| async move {
            let t0 = Instant::now();
            futures::future::join_all(worktree_scans_complete).await;

            let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;
            log::info!("Worktree scanning done in {:?}ms", t0.elapsed().as_millis());

            if let Some(db_directory) = database_url.parent() {
                fs.create_dir(db_directory).await.log_err();
            }

            let worktrees = project.read_with(&cx, |project, cx| {
                project
                    .worktrees(cx)
                    .map(|worktree| worktree.read(cx).snapshot())
                    .collect::<Vec<_>>()
            });

            let mut worktree_file_times = HashMap::new();
            let mut db_ids_by_worktree_id = HashMap::new();
            for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
                let db_id = db_id?;
                db_ids_by_worktree_id.insert(worktree.id(), db_id);
                worktree_file_times.insert(
                    worktree.id(),
                    this.read_with(&cx, |this, _| this.get_file_mtimes(db_id))
                        .await?,
                );
            }

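            // Scan every worktree in the background, enqueueing files with
            // missing or stale mtimes for parsing and deleting database rows
            // for files that have disappeared from disk.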
            cx.background()
                .spawn({
                    let db_ids_by_worktree_id = db_ids_by_worktree_id.clone();
                    let db_update_tx = db_update_tx.clone();
                    let language_registry = language_registry.clone();
                    let parsing_files_tx = parsing_files_tx.clone();
                    async move {
                        let t0 = Instant::now();
                        for worktree in worktrees.into_iter() {
                            let mut file_mtimes =
                                worktree_file_times.remove(&worktree.id()).unwrap();
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path);

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
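                                    // Only languages whose grammar provides an
                                    // embedding config can be indexed.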
                                    if language
                                        .grammar()
                                        .and_then(|grammar| grammar.embedding_config.as_ref())
                                        .is_none()
                                    {
                                        continue;
                                    }

                                    let path_buf = file.path.to_path_buf();
                                    let stored_mtime = file_mtimes.remove(&path_buf);
                                    let already_stored = stored_mtime
                                        .map_or(false, |existing_mtime| {
                                            existing_mtime == file.mtime
                                        });

                                    if !already_stored {
                                        parsing_files_tx
                                            .try_send(PendingFile {
                                                worktree_db_id: db_ids_by_worktree_id
                                                    [&worktree.id()],
                                                relative_path: path_buf,
                                                absolute_path,
                                                language,
                                                modified_time: file.mtime,
                                            })
                                            .unwrap();
                                    }
                                }
                            }
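                            // Any mtimes left over belong to files that were
                            // previously indexed but no longer exist on disk.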
                            for file in file_mtimes.keys() {
                                db_update_tx
                                    .try_send(DbOperation::Delete {
                                        worktree_id: db_ids_by_worktree_id[&worktree.id()],
                                        path: file.to_owned(),
                                    })
                                    .unwrap();
                            }
                        }
                        log::info!(
                            "Parsing worktree completed in {:?}ms",
                            t0.elapsed().as_millis()
                        );
                    }
                })
                .detach();

            this.update(&mut cx, |this, cx| {
                // Handle reindexing on save: whenever entries in a worktree
                // change, the changed files are scheduled to be reindexed once
                // REINDEXING_DELAY_SECONDS have elapsed since their last
                // modification.
                let _subscription = cx.subscribe(&project, |this, project, event, cx| {
                    if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event {
                        this.project_entries_changed(project, changes.clone(), cx, worktree_id);
                    }
                });

                this.projects.insert(
                    project.downgrade(),
                    ProjectState {
                        pending_files: HashMap::new(),
                        worktree_db_ids: db_ids_by_worktree_id.into_iter().collect(),
                        _subscription,
                    },
                );
            });

            anyhow::Ok(())
        })
    }

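    /// Embeds `phrase` and returns up to `limit` of the most similar indexed
    /// documents across the project's worktrees.
    ///
    /// A minimal sketch of a call site (the handles and context are assumed
    /// to be in scope):
    ///
    /// ```ignore
    /// let results = vector_store.update(cx, |store, cx| {
    ///     store.search(project.clone(), "where are files parsed?".into(), 10, cx)
    /// });
    /// ```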
    pub fn search(
        &mut self,
        project: ModelHandle<Project>,
        phrase: String,
        limit: usize,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
            state
        } else {
            return Task::ready(Err(anyhow!("project not added")));
        };

        let worktree_db_ids = project
            .read(cx)
            .worktrees(cx)
            .filter_map(|worktree| {
                let worktree_id = worktree.read(cx).id();
                project_state.db_id_for_worktree_id(worktree_id)
            })
            .collect::<Vec<_>>();

        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();
        cx.spawn(|this, cx| async move {
            let documents = cx
                .background()
                .spawn(async move {
                    let database = VectorDatabase::new(database_url.to_string_lossy().into())?;

                    let phrase_embedding = embedding_provider
                        .embed_batch(vec![&phrase])
                        .await?
                        .into_iter()
                        .next()
                        .unwrap();

                    database.top_k_search(&worktree_db_ids, &phrase_embedding, limit)
                })
                .await?;

            this.read_with(&cx, |this, _| {
                let project_state = if let Some(state) = this.projects.get(&project.downgrade()) {
                    state
                } else {
                    return Err(anyhow!("project not added"));
                };

                Ok(documents
                    .into_iter()
                    .filter_map(|(worktree_db_id, file_path, offset, name)| {
                        let worktree_id = project_state.worktree_id_for_db_id(worktree_db_id)?;
                        Some(SearchResult {
                            worktree_id,
                            name,
                            offset,
                            file_path,
                        })
                    })
                    .collect())
            })
        })
    }

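    /// Handles worktree entry changes by scheduling modified files to be
    /// reindexed `REINDEXING_DELAY_SECONDS` after their modification time,
    /// then flushing any pending files whose delay has already elapsed.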
    fn project_entries_changed(
        &mut self,
        project: ModelHandle<Project>,
        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
        cx: &mut ModelContext<'_, VectorStore>,
        worktree_id: &WorktreeId,
    ) -> Option<()> {
        let worktree = project
            .read(cx)
            .worktree_for_id(*worktree_id, cx)?
            .read(cx)
            .snapshot();

        let worktree_db_id = self
            .projects
            .get(&project.downgrade())?
            .db_id_for_worktree_id(worktree.id())?;
        let file_mtimes = self.get_file_mtimes(worktree_db_id);

        let language_registry = self.language_registry.clone();

        cx.spawn(|this, mut cx| async move {
            let file_mtimes = file_mtimes.await.log_err()?;

            for change in changes.into_iter() {
                let change_path = change.0.clone();
                let absolute_path = worktree.absolutize(&change_path);

                // Skip entries that are git-ignored, symlinks, or external.
                if let Some(entry) = worktree.entry_for_id(change.1) {
                    if entry.is_ignored || entry.is_symlink || entry.is_external {
                        continue;
                    }
                }

                if let Ok(language) = language_registry
                    .language_for_file(&change_path.to_path_buf(), None)
                    .await
                {
                    if language
                        .grammar()
                        .and_then(|grammar| grammar.embedding_config.as_ref())
                        .is_none()
                    {
                        continue;
                    }

                    let modified_time =
                        absolute_path.metadata().log_err()?.modified().log_err()?;

                    let existing_time = file_mtimes.get(&change_path.to_path_buf());
                    let already_stored = existing_time
                        .map_or(false, |existing_time| &modified_time == existing_time);

                    if !already_stored {
                        this.update(&mut cx, |this, _| {
                            let reindex_time =
                                modified_time + Duration::from_secs(REINDEXING_DELAY_SECONDS);

                            let project_state = this.projects.get_mut(&project.downgrade())?;
                            project_state.update_pending_files(
                                PendingFile {
                                    relative_path: change_path.to_path_buf(),
                                    absolute_path,
                                    modified_time,
                                    worktree_db_id,
                                    language: language.clone(),
                                },
                                reindex_time,
                            );

                            for file in project_state.get_outstanding_files() {
                                this.parsing_files_tx.try_send(file).unwrap();
                            }
                            Some(())
                        });
                    }
                }
            }

            Some(())
        })
        .detach();

        Some(())
    }
}

impl Entity for VectorStore {
    type Event = ();
}