vector_store.rs

  1mod db;
  2mod embedding;
  3mod modal;
  4mod parsing;
  5mod vector_store_settings;
  6
  7#[cfg(test)]
  8mod vector_store_tests;
  9
 10use crate::vector_store_settings::VectorStoreSettings;
 11use anyhow::{anyhow, Result};
 12use db::VectorDatabase;
 13use embedding::{EmbeddingProvider, OpenAIEmbeddings};
 14use futures::{channel::oneshot, Future};
 15use gpui::{
 16    AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, ViewContext,
 17    WeakModelHandle,
 18};
 19use language::{Language, LanguageRegistry};
 20use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
 21use parking_lot::Mutex;
 22use parsing::{CodeContextRetriever, Document};
 23use project::{Fs, Project, WorktreeId};
 24use smol::channel;
 25use std::{
 26    collections::{HashMap, HashSet},
 27    ops::Range,
 28    path::{Path, PathBuf},
 29    sync::{
 30        atomic::{self, AtomicUsize},
 31        Arc, Weak,
 32    },
 33    time::{Instant, SystemTime},
 34};
 35use util::{
 36    channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
 37    http::HttpClient,
 38    paths::EMBEDDINGS_DIR,
 39    ResultExt,
 40};
 41use workspace::{Workspace, WorkspaceCreated};
 42
 43const VECTOR_STORE_VERSION: usize = 1;
 44const EMBEDDINGS_BATCH_SIZE: usize = 150;
 45
 46pub fn init(
 47    fs: Arc<dyn Fs>,
 48    http_client: Arc<dyn HttpClient>,
 49    language_registry: Arc<LanguageRegistry>,
 50    cx: &mut AppContext,
 51) {
 52    settings::register::<VectorStoreSettings>(cx);
 53
 54    let db_file_path = EMBEDDINGS_DIR
 55        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
 56        .join("embeddings_db");
 57
 58    SemanticSearch::init(cx);
 59    cx.add_action(
 60        |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
 61            if cx.has_global::<ModelHandle<VectorStore>>() {
 62                let vector_store = cx.global::<ModelHandle<VectorStore>>().clone();
 63                workspace.toggle_modal(cx, |workspace, cx| {
 64                    let project = workspace.project().clone();
 65                    let workspace = cx.weak_handle();
 66                    cx.add_view(|cx| {
 67                        SemanticSearch::new(
 68                            SemanticSearchDelegate::new(workspace, project, vector_store),
 69                            cx,
 70                        )
 71                    })
 72                });
 73            }
 74        },
 75    );
 76
 77    if *RELEASE_CHANNEL == ReleaseChannel::Stable
 78        || !settings::get::<VectorStoreSettings>(cx).enabled
 79    {
 80        return;
 81    }
 82
 83    cx.spawn(move |mut cx| async move {
 84        let vector_store = VectorStore::new(
 85            fs,
 86            db_file_path,
 87            Arc::new(OpenAIEmbeddings {
 88                client: http_client,
 89                executor: cx.background(),
 90            }),
 91            language_registry,
 92            cx.clone(),
 93        )
 94        .await?;
 95
 96        cx.update(|cx| {
 97            cx.set_global(vector_store.clone());
 98            cx.subscribe_global::<WorkspaceCreated, _>({
 99                let vector_store = vector_store.clone();
100                move |event, cx| {
101                    let workspace = &event.0;
102                    if let Some(workspace) = workspace.upgrade(cx) {
103                        let project = workspace.read(cx).project().clone();
104                        if project.read(cx).is_local() {
105                            vector_store.update(cx, |store, cx| {
106                                store.index_project(project, cx).detach();
107                            });
108                        }
109                    }
110                }
111            })
112            .detach();
113        });
114
115        anyhow::Ok(())
116    })
117    .detach();
118}
119
/// Model that owns the embeddings database pipeline and the per-project
/// indexing state used for semantic search.
pub struct VectorStore {
    /// Filesystem handle, passed to `VectorDatabase::new` and used by the parsing workers to load files.
    fs: Arc<dyn Fs>,
    /// Location of the embeddings database; `search_project` reopens the database from it.
    database_url: Arc<PathBuf>,
    /// Produces embeddings for batches of document contents and for search phrases.
    embedding_provider: Arc<dyn EmbeddingProvider>,
    /// Used by `index_project` to resolve the language of each file.
    language_registry: Arc<LanguageRegistry>,
    /// Sends `DbOperation`s to the background task that owns the database.
    db_update_tx: channel::Sender<DbOperation>,
    /// Sends files that need (re)parsing to the parsing workers.
    parsing_files_tx: channel::Sender<PendingFile>,
    // The tasks below are only stored, never read.
    // NOTE(review): presumably holding them keeps the background work alive — confirm `Task` drop semantics.
    _db_update_task: Task<()>,
    _embed_batch_task: Task<()>,
    _batch_files_task: Task<()>,
    _parsing_files_tasks: Vec<Task<()>>,
    /// Monotonic counter from which `index_project` draws `JobId`s.
    next_job_id: Arc<AtomicUsize>,
    /// State recorded by `index_project`, keyed by a weak handle to the project.
    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
}
134
/// Indexing state recorded for one project by `VectorStore::index_project`.
struct ProjectState {
    /// Pairs each worktree id with the id of its row in the vector database.
    worktree_db_ids: Vec<(WorktreeId, i64)>,
    /// Ids of jobs whose `JobHandle` has not been dropped yet.
    outstanding_jobs: Arc<Mutex<HashSet<JobId>>>,
}
139
/// Identifier of a single file-indexing job, drawn from `VectorStore::next_job_id`.
type JobId = usize;
141
/// Token that travels with a file through the indexing pipeline; when it is
/// dropped, its id is removed from the outstanding-job set it points at.
struct JobHandle {
    /// Id of the job this handle represents.
    id: JobId,
    /// Weak reference to the set the id was registered in.
    set: Weak<Mutex<HashSet<JobId>>>,
}
146
147impl ProjectState {
148    fn db_id_for_worktree_id(&self, id: WorktreeId) -> Option<i64> {
149        self.worktree_db_ids
150            .iter()
151            .find_map(|(worktree_id, db_id)| {
152                if *worktree_id == id {
153                    Some(*db_id)
154                } else {
155                    None
156                }
157            })
158    }
159
160    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
161        self.worktree_db_ids
162            .iter()
163            .find_map(|(worktree_id, db_id)| {
164                if *db_id == id {
165                    Some(*worktree_id)
166                } else {
167                    None
168                }
169            })
170    }
171}
172
/// A file queued by `index_project` for the parsing workers.
pub struct PendingFile {
    /// Database id of the worktree the file belongs to.
    worktree_db_id: i64,
    /// Path of the file within its worktree; passed to the parser and stored with the results.
    relative_path: PathBuf,
    /// Path used to load the file's contents from the filesystem.
    absolute_path: PathBuf,
    /// Language the file is parsed with.
    language: Arc<Language>,
    /// Modification time observed when the file was queued.
    modified_time: SystemTime,
    /// Keeps the job counted as outstanding until the file leaves the pipeline.
    job_handle: JobHandle,
}
181
/// One hit returned by `VectorStore::search_project`.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Worktree that contains the matching file.
    pub worktree_id: WorktreeId,
    /// Name stored with the matching document.
    // NOTE(review): presumably the symbol/item name produced by the parser — confirm.
    pub name: String,
    /// Byte range stored with the matching document.
    // NOTE(review): presumably offsets into the file's contents — confirm.
    pub byte_range: Range<usize>,
    /// Path of the matching file as stored in the database.
    pub file_path: PathBuf,
}
189
/// Requests handled one at a time by the background task that owns the database.
enum DbOperation {
    /// Store a file's documents together with its modification time.
    InsertFile {
        worktree_id: i64,
        documents: Vec<Document>,
        path: PathBuf,
        mtime: SystemTime,
        // Dropped once the insert has been attempted, which marks the job as finished.
        job_handle: JobHandle,
    },
    /// Remove a file from the database.
    Delete {
        worktree_id: i64,
        path: PathBuf,
    },
    /// Resolve the database id for the worktree at `path`; the result is sent back on `sender`.
    FindOrCreateWorktree {
        path: PathBuf,
        sender: oneshot::Sender<Result<i64>>,
    },
    /// Fetch the stored modification times for a worktree; the result is sent back on `sender`.
    FileMTimes {
        worktree_id: i64,
        sender: oneshot::Sender<Result<HashMap<PathBuf, SystemTime>>>,
    },
}
211
/// Messages sent from the parsing workers to the batching task.
enum EmbeddingJob {
    /// Add a parsed file's documents to the batch being accumulated.
    Enqueue {
        worktree_id: i64,
        path: PathBuf,
        mtime: SystemTime,
        documents: Vec<Document>,
        job_handle: JobHandle,
    },
    /// Send the accumulated batch on for embedding regardless of its size.
    Flush,
}
222
223impl VectorStore {
224    async fn new(
225        fs: Arc<dyn Fs>,
226        database_url: PathBuf,
227        embedding_provider: Arc<dyn EmbeddingProvider>,
228        language_registry: Arc<LanguageRegistry>,
229        mut cx: AsyncAppContext,
230    ) -> Result<ModelHandle<Self>> {
231        let database_url = Arc::new(database_url);
232
233        let db = cx
234            .background()
235            .spawn(VectorDatabase::new(fs.clone(), database_url.clone()))
236            .await?;
237
238        Ok(cx.add_model(|cx| {
239            // paths_tx -> embeddings_tx -> db_update_tx
240
241            //db_update_tx/rx: Updating Database
242            let (db_update_tx, db_update_rx) = channel::unbounded();
243            let _db_update_task = cx.background().spawn(async move {
244                while let Ok(job) = db_update_rx.recv().await {
245                    match job {
246                        DbOperation::InsertFile {
247                            worktree_id,
248                            documents,
249                            path,
250                            mtime,
251                            job_handle,
252                        } => {
253                            db.insert_file(worktree_id, path, mtime, documents)
254                                .log_err();
255                            drop(job_handle)
256                        }
257                        DbOperation::Delete { worktree_id, path } => {
258                            db.delete_file(worktree_id, path).log_err();
259                        }
260                        DbOperation::FindOrCreateWorktree { path, sender } => {
261                            let id = db.find_or_create_worktree(&path);
262                            sender.send(id).ok();
263                        }
264                        DbOperation::FileMTimes {
265                            worktree_id: worktree_db_id,
266                            sender,
267                        } => {
268                            let file_mtimes = db.get_file_mtimes(worktree_db_id);
269                            sender.send(file_mtimes).ok();
270                        }
271                    }
272                }
273            });
274
275            // embed_tx/rx: Embed Batch and Send to Database
276            let (embed_batch_tx, embed_batch_rx) =
277                channel::unbounded::<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>();
278            let _embed_batch_task = cx.background().spawn({
279                let db_update_tx = db_update_tx.clone();
280                let embedding_provider = embedding_provider.clone();
281                async move {
282                    while let Ok(mut embeddings_queue) = embed_batch_rx.recv().await {
283                        // Construct Batch
284                        let mut batch_documents = vec![];
285                        for (_, documents, _, _, _) in embeddings_queue.iter() {
286                            batch_documents
287                                .extend(documents.iter().map(|document| document.content.as_str()));
288                        }
289
290                        if let Ok(embeddings) =
291                            embedding_provider.embed_batch(batch_documents).await
292                        {
293                            log::trace!(
294                                "created {} embeddings for {} files",
295                                embeddings.len(),
296                                embeddings_queue.len(),
297                            );
298
299                            let mut i = 0;
300                            let mut j = 0;
301
302                            for embedding in embeddings.iter() {
303                                while embeddings_queue[i].1.len() == j {
304                                    i += 1;
305                                    j = 0;
306                                }
307
308                                embeddings_queue[i].1[j].embedding = embedding.to_owned();
309                                j += 1;
310                            }
311
312                            for (worktree_id, documents, path, mtime, job_handle) in
313                                embeddings_queue.into_iter()
314                            {
315                                for document in documents.iter() {
316                                    // TODO: Update this so it doesn't panic
317                                    assert!(
318                                        document.embedding.len() > 0,
319                                        "Document Embedding Not Complete"
320                                    );
321                                }
322
323                                db_update_tx
324                                    .send(DbOperation::InsertFile {
325                                        worktree_id,
326                                        documents,
327                                        path,
328                                        mtime,
329                                        job_handle,
330                                    })
331                                    .await
332                                    .unwrap();
333                            }
334                        }
335                    }
336                }
337            });
338
339            // batch_tx/rx: Batch Files to Send for Embeddings
340            let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
341            let _batch_files_task = cx.background().spawn(async move {
342                let mut queue_len = 0;
343                let mut embeddings_queue = vec![];
344
345                while let Ok(job) = batch_files_rx.recv().await {
346                    let should_flush = match job {
347                        EmbeddingJob::Enqueue {
348                            documents,
349                            worktree_id,
350                            path,
351                            mtime,
352                            job_handle,
353                        } => {
354                            queue_len += &documents.len();
355                            embeddings_queue.push((
356                                worktree_id,
357                                documents,
358                                path,
359                                mtime,
360                                job_handle,
361                            ));
362                            queue_len >= EMBEDDINGS_BATCH_SIZE
363                        }
364                        EmbeddingJob::Flush => true,
365                    };
366
367                    if should_flush {
368                        embed_batch_tx.try_send(embeddings_queue).unwrap();
369                        embeddings_queue = vec![];
370                        queue_len = 0;
371                    }
372                }
373            });
374
375            // parsing_files_tx/rx: Parsing Files to Embeddable Documents
376            let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();
377
378            let mut _parsing_files_tasks = Vec::new();
379            for _ in 0..cx.background().num_cpus() {
380                let fs = fs.clone();
381                let parsing_files_rx = parsing_files_rx.clone();
382                let batch_files_tx = batch_files_tx.clone();
383                _parsing_files_tasks.push(cx.background().spawn(async move {
384                    let mut retriever = CodeContextRetriever::new();
385                    while let Ok(pending_file) = parsing_files_rx.recv().await {
386                        if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err()
387                        {
388                            if let Some(documents) = retriever
389                                .parse_file(
390                                    &pending_file.relative_path,
391                                    &content,
392                                    pending_file.language,
393                                )
394                                .log_err()
395                            {
396                                log::trace!(
397                                    "parsed path {:?}: {} documents",
398                                    pending_file.relative_path,
399                                    documents.len()
400                                );
401
402                                batch_files_tx
403                                    .try_send(EmbeddingJob::Enqueue {
404                                        worktree_id: pending_file.worktree_db_id,
405                                        path: pending_file.relative_path,
406                                        mtime: pending_file.modified_time,
407                                        job_handle: pending_file.job_handle,
408                                        documents,
409                                    })
410                                    .unwrap();
411                            }
412                        }
413
414                        if parsing_files_rx.len() == 0 {
415                            batch_files_tx.try_send(EmbeddingJob::Flush).unwrap();
416                        }
417                    }
418                }));
419            }
420
421            Self {
422                fs,
423                database_url,
424                embedding_provider,
425                language_registry,
426                db_update_tx,
427                next_job_id: Default::default(),
428                parsing_files_tx,
429                _db_update_task,
430                _embed_batch_task,
431                _batch_files_task,
432                _parsing_files_tasks,
433                projects: HashMap::new(),
434            }
435        }))
436    }
437
438    fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
439        let (tx, rx) = oneshot::channel();
440        self.db_update_tx
441            .try_send(DbOperation::FindOrCreateWorktree { path, sender: tx })
442            .unwrap();
443        async move { rx.await? }
444    }
445
446    fn get_file_mtimes(
447        &self,
448        worktree_id: i64,
449    ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
450        let (tx, rx) = oneshot::channel();
451        self.db_update_tx
452            .try_send(DbOperation::FileMTimes {
453                worktree_id,
454                sender: tx,
455            })
456            .unwrap();
457        async move { rx.await? }
458    }
459
460    fn index_project(
461        &mut self,
462        project: ModelHandle<Project>,
463        cx: &mut ModelContext<Self>,
464    ) -> Task<Result<usize>> {
465        let worktree_scans_complete = project
466            .read(cx)
467            .worktrees(cx)
468            .map(|worktree| {
469                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
470                async move {
471                    scan_complete.await;
472                }
473            })
474            .collect::<Vec<_>>();
475        let worktree_db_ids = project
476            .read(cx)
477            .worktrees(cx)
478            .map(|worktree| {
479                self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
480            })
481            .collect::<Vec<_>>();
482
483        let language_registry = self.language_registry.clone();
484        let db_update_tx = self.db_update_tx.clone();
485        let parsing_files_tx = self.parsing_files_tx.clone();
486        let next_job_id = self.next_job_id.clone();
487
488        cx.spawn(|this, mut cx| async move {
489            futures::future::join_all(worktree_scans_complete).await;
490
491            let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;
492
493            let worktrees = project.read_with(&cx, |project, cx| {
494                project
495                    .worktrees(cx)
496                    .map(|worktree| worktree.read(cx).snapshot())
497                    .collect::<Vec<_>>()
498            });
499
500            let mut worktree_file_mtimes = HashMap::new();
501            let mut db_ids_by_worktree_id = HashMap::new();
502            for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
503                let db_id = db_id?;
504                db_ids_by_worktree_id.insert(worktree.id(), db_id);
505                worktree_file_mtimes.insert(
506                    worktree.id(),
507                    this.read_with(&cx, |this, _| this.get_file_mtimes(db_id))
508                        .await?,
509                );
510            }
511
512            // let mut pending_files: Vec<(PathBuf, ((i64, PathBuf, Arc<Language>, SystemTime), SystemTime))> = vec![];
513            let outstanding_jobs = Arc::new(Mutex::new(HashSet::new()));
514            this.update(&mut cx, |this, _| {
515                this.projects.insert(
516                    project.downgrade(),
517                    ProjectState {
518                        worktree_db_ids: db_ids_by_worktree_id
519                            .iter()
520                            .map(|(a, b)| (*a, *b))
521                            .collect(),
522                        outstanding_jobs: outstanding_jobs.clone(),
523                    },
524                );
525            });
526
527            cx.background()
528                .spawn(async move {
529                    let mut count = 0;
530                    let t0 = Instant::now();
531                    for worktree in worktrees.into_iter() {
532                        let mut file_mtimes = worktree_file_mtimes.remove(&worktree.id()).unwrap();
533                        for file in worktree.files(false, 0) {
534                            let absolute_path = worktree.absolutize(&file.path);
535
536                            if let Ok(language) = language_registry
537                                .language_for_file(&absolute_path, None)
538                                .await
539                            {
540                                if language
541                                    .grammar()
542                                    .and_then(|grammar| grammar.embedding_config.as_ref())
543                                    .is_none()
544                                {
545                                    continue;
546                                }
547
548                                let path_buf = file.path.to_path_buf();
549                                let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
550                                let already_stored = stored_mtime
551                                    .map_or(false, |existing_mtime| existing_mtime == file.mtime);
552
553                                if !already_stored {
554                                    log::trace!("sending for parsing: {:?}", path_buf);
555                                    count += 1;
556                                    let job_id = next_job_id.fetch_add(1, atomic::Ordering::SeqCst);
557                                    let job_handle = JobHandle {
558                                        id: job_id,
559                                        set: Arc::downgrade(&outstanding_jobs),
560                                    };
561                                    outstanding_jobs.lock().insert(job_id);
562                                    parsing_files_tx
563                                        .try_send(PendingFile {
564                                            worktree_db_id: db_ids_by_worktree_id[&worktree.id()],
565                                            relative_path: path_buf,
566                                            absolute_path,
567                                            language,
568                                            job_handle,
569                                            modified_time: file.mtime,
570                                        })
571                                        .unwrap();
572                                }
573                            }
574                        }
575                        for file in file_mtimes.keys() {
576                            db_update_tx
577                                .try_send(DbOperation::Delete {
578                                    worktree_id: db_ids_by_worktree_id[&worktree.id()],
579                                    path: file.to_owned(),
580                                })
581                                .unwrap();
582                        }
583                    }
584                    log::trace!(
585                        "parsing worktree completed in {:?}",
586                        t0.elapsed().as_millis()
587                    );
588
589                    Ok(count)
590                })
591                .await
592        })
593    }
594
595    pub fn remaining_files_to_index_for_project(
596        &self,
597        project: &ModelHandle<Project>,
598    ) -> Option<usize> {
599        Some(
600            self.projects
601                .get(&project.downgrade())?
602                .outstanding_jobs
603                .lock()
604                .len(),
605        )
606    }
607
608    pub fn search_project(
609        &mut self,
610        project: ModelHandle<Project>,
611        phrase: String,
612        limit: usize,
613        cx: &mut ModelContext<Self>,
614    ) -> Task<Result<Vec<SearchResult>>> {
615        let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
616            state
617        } else {
618            return Task::ready(Err(anyhow!("project not added")));
619        };
620
621        let worktree_db_ids = project
622            .read(cx)
623            .worktrees(cx)
624            .filter_map(|worktree| {
625                let worktree_id = worktree.read(cx).id();
626                project_state.db_id_for_worktree_id(worktree_id)
627            })
628            .collect::<Vec<_>>();
629
630        let embedding_provider = self.embedding_provider.clone();
631        let database_url = self.database_url.clone();
632        let fs = self.fs.clone();
633        cx.spawn(|this, cx| async move {
634            let documents = cx
635                .background()
636                .spawn(async move {
637                    let database = VectorDatabase::new(fs, database_url).await?;
638
639                    let phrase_embedding = embedding_provider
640                        .embed_batch(vec![&phrase])
641                        .await?
642                        .into_iter()
643                        .next()
644                        .unwrap();
645
646                    database.top_k_search(&worktree_db_ids, &phrase_embedding, limit)
647                })
648                .await?;
649
650            this.read_with(&cx, |this, _| {
651                let project_state = if let Some(state) = this.projects.get(&project.downgrade()) {
652                    state
653                } else {
654                    return Err(anyhow!("project not added"));
655                };
656
657                Ok(documents
658                    .into_iter()
659                    .filter_map(|(worktree_db_id, file_path, byte_range, name)| {
660                        let worktree_id = project_state.worktree_id_for_db_id(worktree_db_id)?;
661                        Some(SearchResult {
662                            worktree_id,
663                            name,
664                            byte_range,
665                            file_path,
666                        })
667                    })
668                    .collect())
669            })
670        })
671    }
672}
673
/// Makes `VectorStore` usable as a gpui model.
impl Entity for VectorStore {
    // The store's event type is the unit type.
    type Event = ();
}
677
678impl Drop for JobHandle {
679    fn drop(&mut self) {
680        if let Some(set) = self.set.upgrade() {
681            set.lock().remove(&self.id);
682        }
683    }
684}