// semantic_index.rs

  1mod db;
  2mod embedding;
  3mod parsing;
  4mod semantic_index_settings;
  5
  6#[cfg(test)]
  7mod semantic_index_tests;
  8
  9use crate::semantic_index_settings::SemanticIndexSettings;
 10use anyhow::{anyhow, Result};
 11use db::VectorDatabase;
 12use embedding::{EmbeddingProvider, OpenAIEmbeddings};
 13use futures::{channel::oneshot, Future};
 14use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
 15use language::{Language, LanguageRegistry};
 16use parking_lot::Mutex;
 17use parsing::{CodeContextRetriever, Document, PARSEABLE_ENTIRE_FILE_TYPES};
 18use postage::watch;
 19use project::{Fs, Project, WorktreeId};
 20use smol::channel;
 21use std::{
 22    collections::HashMap,
 23    mem,
 24    ops::Range,
 25    path::{Path, PathBuf},
 26    sync::{Arc, Weak},
 27    time::SystemTime,
 28};
 29use util::{
 30    channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
 31    http::HttpClient,
 32    paths::EMBEDDINGS_DIR,
 33    ResultExt,
 34};
 35
// Schema/content version of the index — presumably bumped to invalidate
// previously-built embeddings databases (not referenced in this chunk; verify).
const SEMANTIC_INDEX_VERSION: usize = 3;
// Number of queued documents that triggers a flush to the embedding provider
// (see `enqueue_documents_to_embed`).
const EMBEDDINGS_BATCH_SIZE: usize = 150;
 38
 39pub fn init(
 40    fs: Arc<dyn Fs>,
 41    http_client: Arc<dyn HttpClient>,
 42    language_registry: Arc<LanguageRegistry>,
 43    cx: &mut AppContext,
 44) {
 45    settings::register::<SemanticIndexSettings>(cx);
 46
 47    let db_file_path = EMBEDDINGS_DIR
 48        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
 49        .join("embeddings_db");
 50
 51    if *RELEASE_CHANNEL == ReleaseChannel::Stable
 52        || !settings::get::<SemanticIndexSettings>(cx).enabled
 53    {
 54        return;
 55    }
 56
 57    cx.spawn(move |mut cx| async move {
 58        let semantic_index = SemanticIndex::new(
 59            fs,
 60            db_file_path,
 61            Arc::new(OpenAIEmbeddings {
 62                client: http_client,
 63                executor: cx.background(),
 64            }),
 65            language_registry,
 66            cx.clone(),
 67        )
 68        .await?;
 69
 70        cx.update(|cx| {
 71            cx.set_global(semantic_index.clone());
 72        });
 73
 74        anyhow::Ok(())
 75    })
 76    .detach();
 77}
 78
/// Global model coordinating the background pipeline that parses project
/// files into documents, embeds them, and stores the results in a vector
/// database on disk.
pub struct SemanticIndex {
    fs: Arc<dyn Fs>,
    database_url: Arc<PathBuf>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    // Work queues feeding the background tasks below.
    db_update_tx: channel::Sender<DbOperation>,
    parsing_files_tx: channel::Sender<PendingFile>,
    // Held only to keep the background tasks alive for the model's lifetime.
    _db_update_task: Task<()>,
    _embed_batch_task: Task<()>,
    _batch_files_task: Task<()>,
    _parsing_files_tasks: Vec<Task<()>>,
    // Per-project indexing state, keyed by weak project handle.
    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
}
 92
/// Indexing state tracked for each project registered via `index_project`.
struct ProjectState {
    // Mapping between worktree ids and their database ids.
    worktree_db_ids: Vec<(WorktreeId, i64)>,
    // Watch channel counting jobs still in flight for this project.
    outstanding_job_count_rx: watch::Receiver<usize>,
    outstanding_job_count_tx: Arc<Mutex<watch::Sender<usize>>>,
}
 98
/// RAII-style token for one outstanding indexing job; dropping it decrements
/// the owning project's outstanding-job counter (see `Drop` impl).
struct JobHandle {
    // Weak reference so the counter can be freed along with its ProjectState.
    tx: Weak<Mutex<watch::Sender<usize>>>,
}
102
103impl ProjectState {
104    fn db_id_for_worktree_id(&self, id: WorktreeId) -> Option<i64> {
105        self.worktree_db_ids
106            .iter()
107            .find_map(|(worktree_id, db_id)| {
108                if *worktree_id == id {
109                    Some(*db_id)
110                } else {
111                    None
112                }
113            })
114    }
115
116    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
117        self.worktree_db_ids
118            .iter()
119            .find_map(|(worktree_id, db_id)| {
120                if *db_id == id {
121                    Some(*worktree_id)
122                } else {
123                    None
124                }
125            })
126    }
127}
128
/// A file that has been scheduled for parsing and embedding.
pub struct PendingFile {
    worktree_db_id: i64,
    relative_path: PathBuf,
    absolute_path: PathBuf,
    language: Arc<Language>,
    modified_time: SystemTime,
    // Dropped once the file has been fully processed, decrementing the
    // project's outstanding-job counter.
    job_handle: JobHandle,
}
137
/// A single hit returned by `SemanticIndex::search_project`.
#[derive(Debug, Clone)]
pub struct SearchResult {
    pub worktree_id: WorktreeId,
    pub name: String,
    // Byte range of the matching document within the file.
    pub byte_range: Range<usize>,
    pub file_path: PathBuf,
}
145
/// Requests processed serially by the database task (see `run_db_operation`).
enum DbOperation {
    /// Store a file's documents (with embeddings) and its mtime.
    InsertFile {
        worktree_id: i64,
        documents: Vec<Document>,
        path: PathBuf,
        mtime: SystemTime,
        // Dropped after the insert completes, decrementing the job counter.
        job_handle: JobHandle,
    },
    /// Remove a file's rows from the database.
    Delete {
        worktree_id: i64,
        path: PathBuf,
    },
    /// Resolve (or create) the database id for a worktree root path.
    FindOrCreateWorktree {
        path: PathBuf,
        sender: oneshot::Sender<Result<i64>>,
    },
    /// Fetch the stored mtimes for every file in a worktree.
    FileMTimes {
        worktree_id: i64,
        sender: oneshot::Sender<Result<HashMap<PathBuf, SystemTime>>>,
    },
}
167
/// Messages consumed by the batching task (see `enqueue_documents_to_embed`).
enum EmbeddingJob {
    /// Queue one parsed file's documents for embedding.
    Enqueue {
        worktree_id: i64,
        path: PathBuf,
        mtime: SystemTime,
        documents: Vec<Document>,
        job_handle: JobHandle,
    },
    /// Force the current (possibly partial) batch to be embedded now.
    Flush,
}
178
179impl SemanticIndex {
180    pub fn global(cx: &AppContext) -> Option<ModelHandle<SemanticIndex>> {
181        if cx.has_global::<ModelHandle<Self>>() {
182            Some(cx.global::<ModelHandle<SemanticIndex>>().clone())
183        } else {
184            None
185        }
186    }
187
    /// Constructs the `SemanticIndex` model: opens (or creates) the vector
    /// database at `database_url`, then spawns the background pipeline:
    /// parser tasks (one per CPU) -> file batcher -> embedding task -> db writer.
    async fn new(
        fs: Arc<dyn Fs>,
        database_url: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
        mut cx: AsyncAppContext,
    ) -> Result<ModelHandle<Self>> {
        let database_url = Arc::new(database_url);

        let db = cx
            .background()
            .spawn(VectorDatabase::new(fs.clone(), database_url.clone()))
            .await?;

        Ok(cx.add_model(|cx| {
            // Perform database operations serially on one background task.
            let (db_update_tx, db_update_rx) = channel::unbounded();
            let _db_update_task = cx.background().spawn({
                async move {
                    while let Ok(job) = db_update_rx.recv().await {
                        Self::run_db_operation(&db, job)
                    }
                }
            });

            // Send completed batches of documents to the embedding provider.
            let (embed_batch_tx, embed_batch_rx) =
                channel::unbounded::<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>();
            let _embed_batch_task = cx.background().spawn({
                let db_update_tx = db_update_tx.clone();
                let embedding_provider = embedding_provider.clone();
                async move {
                    while let Ok(embeddings_queue) = embed_batch_rx.recv().await {
                        Self::compute_embeddings_for_batch(
                            embeddings_queue,
                            &embedding_provider,
                            &db_update_tx,
                        )
                        .await;
                    }
                }
            });

            // Group parsed files into batches for the embedding task above.
            let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
            let _batch_files_task = cx.background().spawn(async move {
                let mut queue_len = 0;
                let mut embeddings_queue = vec![];
                while let Ok(job) = batch_files_rx.recv().await {
                    Self::enqueue_documents_to_embed(
                        job,
                        &mut queue_len,
                        &mut embeddings_queue,
                        &embed_batch_tx,
                    );
                }
            });

            // Parse files into embeddable documents — one worker per CPU,
            // all pulling from the same channel.
            let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();
            let mut _parsing_files_tasks = Vec::new();
            for _ in 0..cx.background().num_cpus() {
                let fs = fs.clone();
                let parsing_files_rx = parsing_files_rx.clone();
                let batch_files_tx = batch_files_tx.clone();
                let db_update_tx = db_update_tx.clone();
                _parsing_files_tasks.push(cx.background().spawn(async move {
                    let mut retriever = CodeContextRetriever::new();
                    while let Ok(pending_file) = parsing_files_rx.recv().await {
                        Self::parse_file(
                            &fs,
                            pending_file,
                            &mut retriever,
                            &batch_files_tx,
                            &parsing_files_rx,
                            &db_update_tx,
                        )
                        .await;
                    }
                }));
            }

            Self {
                fs,
                database_url,
                embedding_provider,
                language_registry,
                db_update_tx,
                parsing_files_tx,
                _db_update_task,
                _embed_batch_task,
                _batch_files_task,
                _parsing_files_tasks,
                projects: HashMap::new(),
            }
        }))
    }
285
    /// Executes one queued operation against the vector database. Runs on the
    /// dedicated db task, so all database access is serialized.
    fn run_db_operation(db: &VectorDatabase, job: DbOperation) {
        match job {
            DbOperation::InsertFile {
                worktree_id,
                documents,
                path,
                mtime,
                job_handle,
            } => {
                db.insert_file(worktree_id, path, mtime, documents)
                    .log_err();
                // Drop the handle only after the write, so the outstanding-job
                // counter reflects completed persistence.
                drop(job_handle)
            }
            DbOperation::Delete { worktree_id, path } => {
                db.delete_file(worktree_id, path).log_err();
            }
            DbOperation::FindOrCreateWorktree { path, sender } => {
                let id = db.find_or_create_worktree(&path);
                // The requester may have gone away; ignore send failures.
                sender.send(id).ok();
            }
            DbOperation::FileMTimes {
                worktree_id: worktree_db_id,
                sender,
            } => {
                let file_mtimes = db.get_file_mtimes(worktree_db_id);
                sender.send(file_mtimes).ok();
            }
        }
    }
315
316    async fn compute_embeddings_for_batch(
317        mut embeddings_queue: Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
318        embedding_provider: &Arc<dyn EmbeddingProvider>,
319        db_update_tx: &channel::Sender<DbOperation>,
320    ) {
321        let mut batch_documents = vec![];
322        for (_, documents, _, _, _) in embeddings_queue.iter() {
323            batch_documents.extend(documents.iter().map(|document| document.content.as_str()));
324        }
325
326        if let Ok(embeddings) = embedding_provider.embed_batch(batch_documents).await {
327            log::trace!(
328                "created {} embeddings for {} files",
329                embeddings.len(),
330                embeddings_queue.len(),
331            );
332
333            let mut i = 0;
334            let mut j = 0;
335
336            for embedding in embeddings.iter() {
337                while embeddings_queue[i].1.len() == j {
338                    i += 1;
339                    j = 0;
340                }
341
342                embeddings_queue[i].1[j].embedding = embedding.to_owned();
343                j += 1;
344            }
345
346            for (worktree_id, documents, path, mtime, job_handle) in embeddings_queue.into_iter() {
347                // for document in documents.iter() {
348                //     // TODO: Update this so it doesn't panic
349                //     assert!(
350                //         document.embedding.len() > 0,
351                //         "Document Embedding Not Complete"
352                //     );
353                // }
354
355                db_update_tx
356                    .send(DbOperation::InsertFile {
357                        worktree_id,
358                        documents,
359                        path,
360                        mtime,
361                        job_handle,
362                    })
363                    .await
364                    .unwrap();
365            }
366        }
367    }
368
369    fn enqueue_documents_to_embed(
370        job: EmbeddingJob,
371        queue_len: &mut usize,
372        embeddings_queue: &mut Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
373        embed_batch_tx: &channel::Sender<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>,
374    ) {
375        let should_flush = match job {
376            EmbeddingJob::Enqueue {
377                documents,
378                worktree_id,
379                path,
380                mtime,
381                job_handle,
382            } => {
383                *queue_len += &documents.len();
384                embeddings_queue.push((worktree_id, documents, path, mtime, job_handle));
385                *queue_len >= EMBEDDINGS_BATCH_SIZE
386            }
387            EmbeddingJob::Flush => true,
388        };
389
390        if should_flush {
391            embed_batch_tx
392                .try_send(mem::take(embeddings_queue))
393                .unwrap();
394            *queue_len = 0;
395        }
396    }
397
398    async fn parse_file(
399        fs: &Arc<dyn Fs>,
400        pending_file: PendingFile,
401        retriever: &mut CodeContextRetriever,
402        batch_files_tx: &channel::Sender<EmbeddingJob>,
403        parsing_files_rx: &channel::Receiver<PendingFile>,
404        db_update_tx: &channel::Sender<DbOperation>,
405    ) {
406        if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
407            if let Some(documents) = retriever
408                .parse_file(&pending_file.relative_path, &content, pending_file.language)
409                .log_err()
410            {
411                log::trace!(
412                    "parsed path {:?}: {} documents",
413                    pending_file.relative_path,
414                    documents.len()
415                );
416
417                if documents.len() == 0 {
418                    db_update_tx
419                        .send(DbOperation::InsertFile {
420                            worktree_id: pending_file.worktree_db_id,
421                            documents,
422                            path: pending_file.relative_path,
423                            mtime: pending_file.modified_time,
424                            job_handle: pending_file.job_handle,
425                        })
426                        .await
427                        .unwrap();
428                } else {
429                    batch_files_tx
430                        .try_send(EmbeddingJob::Enqueue {
431                            worktree_id: pending_file.worktree_db_id,
432                            path: pending_file.relative_path,
433                            mtime: pending_file.modified_time,
434                            job_handle: pending_file.job_handle,
435                            documents,
436                        })
437                        .unwrap();
438                }
439            }
440        }
441
442        if parsing_files_rx.len() == 0 {
443            batch_files_tx.try_send(EmbeddingJob::Flush).unwrap();
444        }
445    }
446
447    fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
448        let (tx, rx) = oneshot::channel();
449        self.db_update_tx
450            .try_send(DbOperation::FindOrCreateWorktree { path, sender: tx })
451            .unwrap();
452        async move { rx.await? }
453    }
454
455    fn get_file_mtimes(
456        &self,
457        worktree_id: i64,
458    ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
459        let (tx, rx) = oneshot::channel();
460        self.db_update_tx
461            .try_send(DbOperation::FileMTimes {
462                worktree_id,
463                sender: tx,
464            })
465            .unwrap();
466        async move { rx.await? }
467    }
468
    /// Indexes all worktrees of `project`: waits for worktree scans to finish,
    /// resolves database ids, then queues parsing for every file whose mtime
    /// changed since it was last stored and deletes database rows for files
    /// that no longer exist on disk. Returns the number of files queued and a
    /// watch receiver that counts down as outstanding jobs complete.
    pub fn index_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<(usize, watch::Receiver<usize>)>> {
        // Futures resolving when each local worktree finishes its initial scan.
        let worktree_scans_complete = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
                async move {
                    scan_complete.await;
                }
            })
            .collect::<Vec<_>>();
        let worktree_db_ids = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
            })
            .collect::<Vec<_>>();

        let language_registry = self.language_registry.clone();
        let db_update_tx = self.db_update_tx.clone();
        let parsing_files_tx = self.parsing_files_tx.clone();

        cx.spawn(|this, mut cx| async move {
            futures::future::join_all(worktree_scans_complete).await;

            let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;

            let worktrees = project.read_with(&cx, |project, cx| {
                project
                    .worktrees(cx)
                    .map(|worktree| worktree.read(cx).snapshot())
                    .collect::<Vec<_>>()
            });

            // Previously-stored mtimes per worktree, used to skip unchanged files.
            let mut worktree_file_mtimes = HashMap::new();
            let mut db_ids_by_worktree_id = HashMap::new();
            for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
                let db_id = db_id?;
                db_ids_by_worktree_id.insert(worktree.id(), db_id);
                worktree_file_mtimes.insert(
                    worktree.id(),
                    this.read_with(&cx, |this, _| this.get_file_mtimes(db_id))
                        .await?,
                );
            }

            // Counter of outstanding jobs; each JobHandle drop decrements it.
            let (job_count_tx, job_count_rx) = watch::channel_with(0);
            let job_count_tx = Arc::new(Mutex::new(job_count_tx));
            this.update(&mut cx, |this, _| {
                this.projects.insert(
                    project.downgrade(),
                    ProjectState {
                        worktree_db_ids: db_ids_by_worktree_id
                            .iter()
                            .map(|(a, b)| (*a, *b))
                            .collect(),
                        outstanding_job_count_rx: job_count_rx.clone(),
                        outstanding_job_count_tx: job_count_tx.clone(),
                    },
                );
            });

            cx.background()
                .spawn(async move {
                    let mut count = 0;
                    for worktree in worktrees.into_iter() {
                        let mut file_mtimes = worktree_file_mtimes.remove(&worktree.id()).unwrap();
                        for file in worktree.files(false, 0) {
                            let absolute_path = worktree.absolutize(&file.path);

                            if let Ok(language) = language_registry
                                .language_for_file(&absolute_path, None)
                                .await
                            {
                                // Skip languages we can't chunk: no embedding
                                // config and not whitelisted for whole-file parsing.
                                if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
                                    && language
                                        .grammar()
                                        .and_then(|grammar| grammar.embedding_config.as_ref())
                                        .is_none()
                                {
                                    continue;
                                }

                                let path_buf = file.path.to_path_buf();
                                // A file is current if its stored mtime matches
                                // the one reported by the scan.
                                let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
                                let already_stored = stored_mtime
                                    .map_or(false, |existing_mtime| existing_mtime == file.mtime);

                                if !already_stored {
                                    count += 1;
                                    *job_count_tx.lock().borrow_mut() += 1;
                                    let job_handle = JobHandle {
                                        tx: Arc::downgrade(&job_count_tx),
                                    };
                                    parsing_files_tx
                                        .try_send(PendingFile {
                                            worktree_db_id: db_ids_by_worktree_id[&worktree.id()],
                                            relative_path: path_buf,
                                            absolute_path,
                                            language,
                                            job_handle,
                                            modified_time: file.mtime,
                                        })
                                        .unwrap();
                                }
                            }
                        }
                        // Entries left in `file_mtimes` were not seen on disk:
                        // the files were deleted, so purge their db rows.
                        for file in file_mtimes.keys() {
                            db_update_tx
                                .try_send(DbOperation::Delete {
                                    worktree_id: db_ids_by_worktree_id[&worktree.id()],
                                    path: file.to_owned(),
                                })
                                .unwrap();
                        }
                    }

                    anyhow::Ok((count, job_count_rx))
                })
                .await
        })
    }
596
597    pub fn outstanding_job_count_rx(
598        &self,
599        project: &ModelHandle<Project>,
600    ) -> Option<watch::Receiver<usize>> {
601        Some(
602            self.projects
603                .get(&project.downgrade())?
604                .outstanding_job_count_rx
605                .clone(),
606        )
607    }
608
    /// Embeds `phrase` and runs a top-k nearest-neighbor search over the
    /// project's worktrees in the vector database, returning up to `limit`
    /// results. Fails if the project was never registered via `index_project`.
    pub fn search_project(
        &mut self,
        project: ModelHandle<Project>,
        phrase: String,
        limit: usize,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
            state
        } else {
            return Task::ready(Err(anyhow!("project not added")));
        };

        // Restrict the search to worktrees we have database ids for.
        let worktree_db_ids = project
            .read(cx)
            .worktrees(cx)
            .filter_map(|worktree| {
                let worktree_id = worktree.read(cx).id();
                project_state.db_id_for_worktree_id(worktree_id)
            })
            .collect::<Vec<_>>();

        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();
        let fs = self.fs.clone();
        cx.spawn(|this, cx| async move {
            let documents = cx
                .background()
                .spawn(async move {
                    // Open a fresh connection; the db task's handle isn't shared.
                    let database = VectorDatabase::new(fs, database_url).await?;

                    let phrase_embedding = embedding_provider
                        .embed_batch(vec![&phrase])
                        .await?
                        .into_iter()
                        .next()
                        .unwrap();

                    database.top_k_search(&worktree_db_ids, &phrase_embedding, limit)
                })
                .await?;

            // Re-check project state: it may have been dropped while searching.
            this.read_with(&cx, |this, _| {
                let project_state = if let Some(state) = this.projects.get(&project.downgrade()) {
                    state
                } else {
                    return Err(anyhow!("project not added"));
                };

                Ok(documents
                    .into_iter()
                    .filter_map(|(worktree_db_id, file_path, byte_range, name)| {
                        let worktree_id = project_state.worktree_id_for_db_id(worktree_db_id)?;
                        Some(SearchResult {
                            worktree_id,
                            name,
                            byte_range,
                            file_path,
                        })
                    })
                    .collect())
            })
        })
    }
673}
674
impl Entity for SemanticIndex {
    // The index emits no events.
    type Event = ();
}
678
impl Drop for JobHandle {
    /// Decrements the owning project's outstanding-job counter when a job
    /// finishes (its handle is dropped). If the ProjectState is already gone,
    /// the weak sender fails to upgrade and the drop is a no-op.
    fn drop(&mut self) {
        if let Some(tx) = self.tx.upgrade() {
            let mut tx = tx.lock();
            // NOTE(review): this underflows if a handle exists without a
            // matching increment — relies on `index_project` incrementing the
            // counter exactly once per JobHandle it creates; verify if
            // JobHandle ever becomes Clone.
            *tx.borrow_mut() -= 1;
        }
    }
}