vector_store.rs

  1mod db;
  2mod embedding;
  3mod modal;
  4mod parsing;
  5
  6#[cfg(test)]
  7mod vector_store_tests;
  8
  9use anyhow::{anyhow, Result};
 10use db::VectorDatabase;
 11use embedding::{EmbeddingProvider, OpenAIEmbeddings};
 12use futures::{channel::oneshot, Future};
 13use gpui::{
 14    AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, ViewContext,
 15    WeakModelHandle,
 16};
 17use language::{Language, LanguageRegistry};
 18use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
 19use parsing::{CodeContextRetriever, ParsedFile};
 20use project::{Fs, PathChange, Project, ProjectEntryId, WorktreeId};
 21use smol::channel;
 22use std::{
 23    collections::HashMap,
 24    path::{Path, PathBuf},
 25    sync::Arc,
 26    time::{Duration, Instant, SystemTime},
 27};
 28use tree_sitter::{Parser, QueryCursor};
 29use util::{
 30    channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
 31    http::HttpClient,
 32    paths::EMBEDDINGS_DIR,
 33    ResultExt,
 34};
 35use workspace::{Workspace, WorkspaceCreated};
 36
// Delay between a file changing on disk and it being re-indexed; edits that
// arrive within this window coalesce into a single re-index (see
// `ProjectState::update_pending_files`).
const REINDEXING_DELAY_SECONDS: u64 = 3;
// Approximate number of document spans accumulated before a batch is sent to
// the embedding provider (see the batching task in `VectorStore::new`).
const EMBEDDINGS_BATCH_SIZE: usize = 150;
 39
 40pub fn init(
 41    fs: Arc<dyn Fs>,
 42    http_client: Arc<dyn HttpClient>,
 43    language_registry: Arc<LanguageRegistry>,
 44    cx: &mut AppContext,
 45) {
 46    if *RELEASE_CHANNEL == ReleaseChannel::Stable {
 47        return;
 48    }
 49
 50    let db_file_path = EMBEDDINGS_DIR
 51        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
 52        .join("embeddings_db");
 53
 54    cx.spawn(move |mut cx| async move {
 55        let vector_store = VectorStore::new(
 56            fs,
 57            db_file_path,
 58            // Arc::new(embedding::DummyEmbeddings {}),
 59            Arc::new(OpenAIEmbeddings {
 60                client: http_client,
 61                executor: cx.background(),
 62            }),
 63            language_registry,
 64            cx.clone(),
 65        )
 66        .await?;
 67
 68        cx.update(|cx| {
 69            cx.subscribe_global::<WorkspaceCreated, _>({
 70                let vector_store = vector_store.clone();
 71                move |event, cx| {
 72                    let workspace = &event.0;
 73                    if let Some(workspace) = workspace.upgrade(cx) {
 74                        let project = workspace.read(cx).project().clone();
 75                        if project.read(cx).is_local() {
 76                            vector_store.update(cx, |store, cx| {
 77                                store.add_project(project, cx).detach();
 78                            });
 79                        }
 80                    }
 81                }
 82            })
 83            .detach();
 84
 85            cx.add_action({
 86                move |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
 87                    let vector_store = vector_store.clone();
 88                    workspace.toggle_modal(cx, |workspace, cx| {
 89                        let project = workspace.project().clone();
 90                        let workspace = cx.weak_handle();
 91                        cx.add_view(|cx| {
 92                            SemanticSearch::new(
 93                                SemanticSearchDelegate::new(workspace, project, vector_store),
 94                                cx,
 95                            )
 96                        })
 97                    })
 98                }
 99            });
100
101            SemanticSearch::init(cx);
102        });
103
104        anyhow::Ok(())
105    })
106    .detach();
107}
108
/// Model that maintains an on-disk vector index of project files and serves
/// semantic searches against it. Owns the background tasks that parse, embed,
/// and persist files (the pipeline is assembled in `VectorStore::new`).
pub struct VectorStore {
    fs: Arc<dyn Fs>,
    // Path to the embeddings database file.
    database_url: Arc<PathBuf>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    // All database reads/writes are funneled through this channel to a
    // single background task that owns the connection.
    db_update_tx: channel::Sender<DbOperation>,
    // Feeds files into the pool of parsing workers.
    parsing_files_tx: channel::Sender<PendingFile>,
    // Held only to keep the background tasks alive for the store's lifetime.
    _db_update_task: Task<()>,
    _embed_batch_task: Task<()>,
    _batch_files_task: Task<()>,
    _parsing_files_tasks: Vec<Task<()>>,
    // Per-project indexing state, keyed by a weak handle to the project.
    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
}
122
/// Indexing bookkeeping for a single registered project.
struct ProjectState {
    // Mapping between in-memory worktree ids and their database row ids.
    worktree_db_ids: Vec<(WorktreeId, i64)>,
    // Files awaiting (re-)indexing, keyed by worktree-relative path, paired
    // with the time at which they become due for indexing.
    pending_files: HashMap<PathBuf, (PendingFile, SystemTime)>,
    // Keeps the subscription to project events alive as long as this state.
    _subscription: gpui::Subscription,
}
128
129impl ProjectState {
130    fn db_id_for_worktree_id(&self, id: WorktreeId) -> Option<i64> {
131        self.worktree_db_ids
132            .iter()
133            .find_map(|(worktree_id, db_id)| {
134                if *worktree_id == id {
135                    Some(*db_id)
136                } else {
137                    None
138                }
139            })
140    }
141
142    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
143        self.worktree_db_ids
144            .iter()
145            .find_map(|(worktree_id, db_id)| {
146                if *db_id == id {
147                    Some(*worktree_id)
148                } else {
149                    None
150                }
151            })
152    }
153
154    fn update_pending_files(&mut self, pending_file: PendingFile, indexing_time: SystemTime) {
155        // If Pending File Already Exists, Replace it with the new one
156        // but keep the old indexing time
157        if let Some(old_file) = self
158            .pending_files
159            .remove(&pending_file.relative_path.clone())
160        {
161            self.pending_files.insert(
162                pending_file.relative_path.clone(),
163                (pending_file, old_file.1),
164            );
165        } else {
166            self.pending_files.insert(
167                pending_file.relative_path.clone(),
168                (pending_file, indexing_time),
169            );
170        };
171    }
172
173    fn get_outstanding_files(&mut self) -> Vec<PendingFile> {
174        let mut outstanding_files = vec![];
175        let mut remove_keys = vec![];
176        for key in self.pending_files.keys().into_iter() {
177            if let Some(pending_details) = self.pending_files.get(key) {
178                let (pending_file, index_time) = pending_details;
179                if index_time <= &SystemTime::now() {
180                    outstanding_files.push(pending_file.clone());
181                    remove_keys.push(key.clone());
182                }
183            }
184        }
185
186        for key in remove_keys.iter() {
187            self.pending_files.remove(key);
188        }
189
190        return outstanding_files;
191    }
192}
193
/// A file queued for parsing and embedding.
#[derive(Clone, Debug)]
pub struct PendingFile {
    // Database id of the worktree that contains this file.
    worktree_db_id: i64,
    // Path relative to the worktree root; used as the pending-files map key.
    relative_path: PathBuf,
    absolute_path: PathBuf,
    language: Arc<Language>,
    // Modification time observed when the file was queued.
    modified_time: SystemTime,
}
202
/// A single hit returned by `VectorStore::search`.
#[derive(Debug, Clone)]
pub struct SearchResult {
    pub worktree_id: WorktreeId,
    // `name` and `offset` are returned as-is from `top_k_search`; offset
    // appears to be a position within the file — confirm against db.rs.
    pub name: String,
    pub offset: usize,
    pub file_path: PathBuf,
}
210
/// Requests handled by the single database task spawned in `VectorStore::new`.
enum DbOperation {
    /// Persist a fully embedded file.
    InsertFile {
        worktree_id: i64,
        indexed_file: ParsedFile,
    },
    /// Remove a file's entries for the given worktree.
    Delete {
        worktree_id: i64,
        path: PathBuf,
    },
    /// Resolve (creating if needed) the database id for a worktree root
    /// path, replying on `sender`.
    FindOrCreateWorktree {
        path: PathBuf,
        sender: oneshot::Sender<Result<i64>>,
    },
    /// Fetch the stored modification times of all files in a worktree,
    /// replying on `sender`.
    FileMTimes {
        worktree_id: i64,
        sender: oneshot::Sender<Result<HashMap<PathBuf, SystemTime>>>,
    },
}
229
/// Messages consumed by the batching task in `VectorStore::new`.
enum EmbeddingJob {
    /// Add a parsed file's document spans to the current batch.
    Enqueue {
        worktree_id: i64,
        parsed_file: ParsedFile,
        document_spans: Vec<String>,
    },
    /// Send whatever is queued to the embedding provider immediately, even
    /// if the batch is not full.
    Flush,
}
238
239impl VectorStore {
    /// Creates the `VectorStore` model and starts its background indexing
    /// pipeline. Work flows through four stages connected by channels:
    ///
    /// `parsing_files_tx -> batch_files_tx -> embed_batch_tx -> db_update_tx`
    ///
    /// * one parsing task per CPU turns `PendingFile`s into parsed documents,
    /// * one batching task groups document spans until a batch is full,
    /// * one embedding task calls the embedding provider per batch,
    /// * one database task serializes every read/write to the database.
    ///
    /// Fails if the database file cannot be opened or created.
    async fn new(
        fs: Arc<dyn Fs>,
        database_url: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
        mut cx: AsyncAppContext,
    ) -> Result<ModelHandle<Self>> {
        let database_url = Arc::new(database_url);

        // Open (creating the parent directory if needed) off the main thread.
        let db = cx
            .background()
            .spawn({
                let fs = fs.clone();
                let database_url = database_url.clone();
                async move {
                    if let Some(db_directory) = database_url.parent() {
                        fs.create_dir(db_directory).await.log_err();
                    }

                    let db = VectorDatabase::new(database_url.to_string_lossy().to_string())?;
                    anyhow::Ok(db)
                }
            })
            .await?;

        Ok(cx.add_model(|cx| {
            // Pipeline: parsing_files_tx -> batch_files_tx -> embed_batch_tx -> db_update_tx

            // db_update_tx/rx: the single consumer below owns `db`, so all
            // database operations are serialized through this channel.
            let (db_update_tx, db_update_rx) = channel::unbounded();
            let _db_update_task = cx.background().spawn(async move {
                while let Ok(job) = db_update_rx.recv().await {
                    match job {
                        DbOperation::InsertFile {
                            worktree_id,
                            indexed_file,
                        } => {
                            log::info!("Inserting Data for {:?}", &indexed_file.path);
                            db.insert_file(worktree_id, indexed_file).log_err();
                        }
                        DbOperation::Delete { worktree_id, path } => {
                            db.delete_file(worktree_id, path).log_err();
                        }
                        DbOperation::FindOrCreateWorktree { path, sender } => {
                            let id = db.find_or_create_worktree(&path);
                            // The requester may have been dropped; ignore.
                            sender.send(id).ok();
                        }
                        DbOperation::FileMTimes {
                            worktree_id: worktree_db_id,
                            sender,
                        } => {
                            let file_mtimes = db.get_file_mtimes(worktree_db_id);
                            sender.send(file_mtimes).ok();
                        }
                    }
                }
            });

            // embed_batch_tx/rx: embeds a batch of document spans and sends
            // the completed files on to the database task.
            let (embed_batch_tx, embed_batch_rx) =
                channel::unbounded::<Vec<(i64, ParsedFile, Vec<String>)>>();
            let _embed_batch_task = cx.background().spawn({
                let db_update_tx = db_update_tx.clone();
                let embedding_provider = embedding_provider.clone();
                async move {
                    while let Ok(mut embeddings_queue) = embed_batch_rx.recv().await {
                        // Flatten the spans of all queued files into a single
                        // provider request.
                        let mut document_spans = vec![];
                        for (_, _, document_span) in embeddings_queue.iter() {
                            document_spans.extend(document_span.iter().map(|s| s.as_str()));
                        }

                        if let Ok(embeddings) = embedding_provider.embed_batch(document_spans).await
                        {
                            // Walk the flat embedding list back onto the
                            // per-file documents: `i` indexes the file in the
                            // queue, `j` the document within that file.
                            let mut i = 0;
                            let mut j = 0;

                            for embedding in embeddings.iter() {
                                while embeddings_queue[i].1.documents.len() == j {
                                    i += 1;
                                    j = 0;
                                }

                                embeddings_queue[i].1.documents[j].embedding = embedding.to_owned();
                                j += 1;
                            }

                            for (worktree_id, indexed_file, _) in embeddings_queue.into_iter() {
                                for document in indexed_file.documents.iter() {
                                    // TODO: Update this so it doesn't panic
                                    assert!(
                                        document.embedding.len() > 0,
                                        "Document Embedding Not Complete"
                                    );
                                }

                                db_update_tx
                                    .send(DbOperation::InsertFile {
                                        worktree_id,
                                        indexed_file,
                                    })
                                    .await
                                    .unwrap();
                            }
                        }
                    }
                }
            });

            // batch_files_tx/rx: accumulates parsed files until roughly
            // EMBEDDINGS_BATCH_SIZE spans are queued (or a Flush arrives),
            // then hands the whole batch to the embedding task.
            let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
            let _batch_files_task = cx.background().spawn(async move {
                let mut queue_len = 0;
                let mut embeddings_queue = vec![];

                while let Ok(job) = batch_files_rx.recv().await {
                    let should_flush = match job {
                        EmbeddingJob::Enqueue {
                            document_spans,
                            worktree_id,
                            parsed_file,
                        } => {
                            queue_len += &document_spans.len();
                            embeddings_queue.push((worktree_id, parsed_file, document_spans));
                            queue_len >= EMBEDDINGS_BATCH_SIZE
                        }
                        EmbeddingJob::Flush => true,
                    };

                    if should_flush {
                        embed_batch_tx.try_send(embeddings_queue).unwrap();
                        embeddings_queue = vec![];
                        queue_len = 0;
                    }
                }
            });

            // parsing_files_tx/rx: fan file parsing out across all CPUs; each
            // worker owns its own tree-sitter parser and query cursor.
            let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();

            let mut _parsing_files_tasks = Vec::new();
            for _ in 0..cx.background().num_cpus() {
                let fs = fs.clone();
                let parsing_files_rx = parsing_files_rx.clone();
                let batch_files_tx = batch_files_tx.clone();
                _parsing_files_tasks.push(cx.background().spawn(async move {
                    let parser = Parser::new();
                    let cursor = QueryCursor::new();
                    let mut retriever = CodeContextRetriever { parser, cursor, fs };
                    while let Ok(pending_file) = parsing_files_rx.recv().await {
                        log::info!("Parsing File: {:?}", &pending_file.relative_path);

                        if let Some((indexed_file, document_spans)) =
                            retriever.parse_file(pending_file.clone()).await.log_err()
                        {
                            batch_files_tx
                                .try_send(EmbeddingJob::Enqueue {
                                    worktree_id: pending_file.worktree_db_id,
                                    parsed_file: indexed_file,
                                    document_spans,
                                })
                                .unwrap();
                        }

                        // Queue drained: flush any partial batch so the tail
                        // of the work doesn't sit waiting for a full batch.
                        if parsing_files_rx.len() == 0 {
                            batch_files_tx.try_send(EmbeddingJob::Flush).unwrap();
                        }
                    }
                }));
            }

            Self {
                fs,
                database_url,
                embedding_provider,
                language_registry,
                db_update_tx,
                parsing_files_tx,
                _db_update_task,
                _embed_batch_task,
                _batch_files_task,
                _parsing_files_tasks,
                projects: HashMap::new(),
            }
        }))
    }
426
427    fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
428        let (tx, rx) = oneshot::channel();
429        self.db_update_tx
430            .try_send(DbOperation::FindOrCreateWorktree { path, sender: tx })
431            .unwrap();
432        async move { rx.await? }
433    }
434
435    fn get_file_mtimes(
436        &self,
437        worktree_id: i64,
438    ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
439        let (tx, rx) = oneshot::channel();
440        self.db_update_tx
441            .try_send(DbOperation::FileMTimes {
442                worktree_id,
443                sender: tx,
444            })
445            .unwrap();
446        async move { rx.await? }
447    }
448
    /// Registers `project` with the store and kicks off its initial indexing.
    ///
    /// Waits for all worktree scans to complete, resolves each worktree's
    /// database id, then in the background: re-queues every file whose stored
    /// mtime differs from disk (or that was never stored) and deletes rows
    /// for files that no longer exist. Finally subscribes to the project's
    /// entry-change events so subsequent edits are re-indexed incrementally.
    fn add_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        // NOTE(review): `as_local().unwrap()` assumes every worktree is
        // local; callers in `init` gate on `project.is_local()` — confirm no
        // other caller can pass a remote project.
        let worktree_scans_complete = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
                async move {
                    scan_complete.await;
                }
            })
            .collect::<Vec<_>>();
        let worktree_db_ids = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
            })
            .collect::<Vec<_>>();

        let fs = self.fs.clone();
        let language_registry = self.language_registry.clone();
        let database_url = self.database_url.clone();
        let db_update_tx = self.db_update_tx.clone();
        let parsing_files_tx = self.parsing_files_tx.clone();

        cx.spawn(|this, mut cx| async move {
            let t0 = Instant::now();
            futures::future::join_all(worktree_scans_complete).await;

            let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;
            log::info!("Worktree Scanning Done in {:?}", t0.elapsed().as_millis());

            if let Some(db_directory) = database_url.parent() {
                fs.create_dir(db_directory).await.log_err();
            }

            // Snapshot the worktrees so the background scan below can run
            // without touching the app state again.
            let worktrees = project.read_with(&cx, |project, cx| {
                project
                    .worktrees(cx)
                    .map(|worktree| worktree.read(cx).snapshot())
                    .collect::<Vec<_>>()
            });

            // Pair every worktree with its database id and its stored mtimes.
            let mut worktree_file_times = HashMap::new();
            let mut db_ids_by_worktree_id = HashMap::new();
            for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
                let db_id = db_id?;
                db_ids_by_worktree_id.insert(worktree.id(), db_id);
                worktree_file_times.insert(
                    worktree.id(),
                    this.read_with(&cx, |this, _| this.get_file_mtimes(db_id))
                        .await?,
                );
            }

            // Walk every file of every worktree in the background, queueing
            // new/changed files for parsing and deleting vanished ones.
            cx.background()
                .spawn({
                    let db_ids_by_worktree_id = db_ids_by_worktree_id.clone();
                    let db_update_tx = db_update_tx.clone();
                    let language_registry = language_registry.clone();
                    let parsing_files_tx = parsing_files_tx.clone();
                    async move {
                        let t0 = Instant::now();
                        for worktree in worktrees.into_iter() {
                            let mut file_mtimes =
                                worktree_file_times.remove(&worktree.id()).unwrap();
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path);

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    // Only languages with an embedding query
                                    // in their grammar can be indexed.
                                    if language
                                        .grammar()
                                        .and_then(|grammar| grammar.embedding_config.as_ref())
                                        .is_none()
                                    {
                                        continue;
                                    }

                                    // Remove the entry so whatever remains in
                                    // `file_mtimes` afterwards is stale.
                                    let path_buf = file.path.to_path_buf();
                                    let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
                                    let already_stored = stored_mtime
                                        .map_or(false, |existing_mtime| {
                                            existing_mtime == file.mtime
                                        });

                                    if !already_stored {
                                        parsing_files_tx
                                            .try_send(PendingFile {
                                                worktree_db_id: db_ids_by_worktree_id
                                                    [&worktree.id()],
                                                relative_path: path_buf,
                                                absolute_path,
                                                language,
                                                modified_time: file.mtime,
                                            })
                                            .unwrap();
                                    }
                                }
                            }
                            // Entries left over were stored but no longer
                            // exist on disk; delete their rows.
                            for file in file_mtimes.keys() {
                                db_update_tx
                                    .try_send(DbOperation::Delete {
                                        worktree_id: db_ids_by_worktree_id[&worktree.id()],
                                        path: file.to_owned(),
                                    })
                                    .unwrap();
                            }
                        }
                        log::info!(
                            "Parsing Worktree Completed in {:?}",
                            t0.elapsed().as_millis()
                        );
                    }
                })
                .detach();

            this.update(&mut cx, |this, cx| {
                // Re-index on save: whenever entries change, files whose
                // reindex delay has elapsed are sent off to be indexed.
                let _subscription = cx.subscribe(&project, |this, project, event, cx| {
                    if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event {
                        this.project_entries_changed(project, changes.clone(), cx, worktree_id);
                    }
                });

                this.projects.insert(
                    project.downgrade(),
                    ProjectState {
                        pending_files: HashMap::new(),
                        worktree_db_ids: db_ids_by_worktree_id.into_iter().collect(),
                        _subscription,
                    },
                );
            });

            anyhow::Ok(())
        })
    }
596
597    pub fn search(
598        &mut self,
599        project: ModelHandle<Project>,
600        phrase: String,
601        limit: usize,
602        cx: &mut ModelContext<Self>,
603    ) -> Task<Result<Vec<SearchResult>>> {
604        let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
605            state
606        } else {
607            return Task::ready(Err(anyhow!("project not added")));
608        };
609
610        let worktree_db_ids = project
611            .read(cx)
612            .worktrees(cx)
613            .filter_map(|worktree| {
614                let worktree_id = worktree.read(cx).id();
615                project_state.db_id_for_worktree_id(worktree_id)
616            })
617            .collect::<Vec<_>>();
618
619        let embedding_provider = self.embedding_provider.clone();
620        let database_url = self.database_url.clone();
621        cx.spawn(|this, cx| async move {
622            let documents = cx
623                .background()
624                .spawn(async move {
625                    let database = VectorDatabase::new(database_url.to_string_lossy().into())?;
626
627                    let phrase_embedding = embedding_provider
628                        .embed_batch(vec![&phrase])
629                        .await?
630                        .into_iter()
631                        .next()
632                        .unwrap();
633
634                    database.top_k_search(&worktree_db_ids, &phrase_embedding, limit)
635                })
636                .await?;
637
638            dbg!(&documents);
639
640            this.read_with(&cx, |this, _| {
641                let project_state = if let Some(state) = this.projects.get(&project.downgrade()) {
642                    state
643                } else {
644                    return Err(anyhow!("project not added"));
645                };
646
647                Ok(documents
648                    .into_iter()
649                    .filter_map(|(worktree_db_id, file_path, offset, name)| {
650                        let worktree_id = project_state.worktree_id_for_db_id(worktree_db_id)?;
651                        Some(SearchResult {
652                            worktree_id,
653                            name,
654                            offset,
655                            file_path,
656                        })
657                    })
658                    .collect())
659            })
660        })
661    }
662
663    fn project_entries_changed(
664        &mut self,
665        project: ModelHandle<Project>,
666        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
667        cx: &mut ModelContext<'_, VectorStore>,
668        worktree_id: &WorktreeId,
669    ) -> Option<()> {
670        let worktree = project
671            .read(cx)
672            .worktree_for_id(worktree_id.clone(), cx)?
673            .read(cx)
674            .snapshot();
675
676        let worktree_db_id = self
677            .projects
678            .get(&project.downgrade())?
679            .db_id_for_worktree_id(worktree.id())?;
680        let file_mtimes = self.get_file_mtimes(worktree_db_id);
681
682        let language_registry = self.language_registry.clone();
683
684        cx.spawn(|this, mut cx| async move {
685            let file_mtimes = file_mtimes.await.log_err()?;
686
687            for change in changes.into_iter() {
688                let change_path = change.0.clone();
689                let absolute_path = worktree.absolutize(&change_path);
690                // Skip if git ignored or symlink
691                if let Some(entry) = worktree.entry_for_id(change.1) {
692                    if entry.is_ignored || entry.is_symlink || entry.is_external {
693                        continue;
694                    }
695                }
696
697                if let Ok(language) = language_registry
698                    .language_for_file(&change_path.to_path_buf(), None)
699                    .await
700                {
701                    if language
702                        .grammar()
703                        .and_then(|grammar| grammar.embedding_config.as_ref())
704                        .is_none()
705                    {
706                        continue;
707                    }
708
709                    let modified_time = change_path.metadata().log_err()?.modified().log_err()?;
710
711                    let existing_time = file_mtimes.get(&change_path.to_path_buf());
712                    let already_stored = existing_time
713                        .map_or(false, |existing_time| &modified_time != existing_time);
714
715                    if !already_stored {
716                        this.update(&mut cx, |this, _| {
717                            let reindex_time =
718                                modified_time + Duration::from_secs(REINDEXING_DELAY_SECONDS);
719
720                            let project_state = this.projects.get_mut(&project.downgrade())?;
721                            project_state.update_pending_files(
722                                PendingFile {
723                                    relative_path: change_path.to_path_buf(),
724                                    absolute_path,
725                                    modified_time,
726                                    worktree_db_id,
727                                    language: language.clone(),
728                                },
729                                reindex_time,
730                            );
731
732                            for file in project_state.get_outstanding_files() {
733                                this.parsing_files_tx.try_send(file).unwrap();
734                            }
735                            Some(())
736                        });
737                    }
738                }
739            }
740
741            Some(())
742        })
743        .detach();
744
745        Some(())
746    }
747}
748
impl Entity for VectorStore {
    // The store emits no events; it is driven by subscriptions and tasks.
    type Event = ();
}