semantic_index.rs

mod chunking;
mod embedding;

use anyhow::{anyhow, Context as _, Result};
use chunking::{chunk_text, Chunk};
use collections::{Bound, HashMap, HashSet};
pub use embedding::*;
use fs::Fs;
use futures::stream::StreamExt;
use futures_batch::ChunksTimeoutStreamExt;
use gpui::{
    AppContext, AsyncAppContext, BorrowAppContext, Context, Entity, EntityId, EventEmitter, Global,
    Model, ModelContext, Subscription, Task, WeakModel,
};
use heed::types::{SerdeBincode, Str};
use language::LanguageRegistry;
use parking_lot::Mutex;
use project::{Entry, Project, ProjectEntryId, UpdatedEntriesSet, Worktree};
use serde::{Deserialize, Serialize};
use smol::channel;
use std::{
    cmp::Ordering,
    future::Future,
    iter,
    num::NonZeroUsize,
    ops::Range,
    path::{Path, PathBuf},
    sync::{Arc, Weak},
    time::{Duration, SystemTime},
};
use util::ResultExt;
use worktree::LocalSnapshot;

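/// A GPUI global that owns the embedding database and hands out one
/// [`ProjectIndex`] per open project, all sharing a single LMDB environment.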
pub struct SemanticIndex {
    embedding_provider: Arc<dyn EmbeddingProvider>,
    db_connection: heed::Env,
    project_indices: HashMap<WeakModel<Project>, Model<ProjectIndex>>,
}

impl Global for SemanticIndex {}

impl SemanticIndex {
    pub async fn new(
        db_path: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut AsyncAppContext,
    ) -> Result<Self> {
        let db_connection = cx
            .background_executor()
            .spawn(async move {
                std::fs::create_dir_all(&db_path)?;
                unsafe {
                    heed::EnvOpenOptions::new()
                        .map_size(1024 * 1024 * 1024)
                        .max_dbs(3000)
                        .open(db_path)
                }
            })
            .await
            .context("opening database connection")?;

        Ok(SemanticIndex {
            db_connection,
            embedding_provider,
            project_indices: HashMap::default(),
        })
    }

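    /// Returns the [`ProjectIndex`] for the given project, creating and
    /// caching it on first access. The cached entry is removed when the
    /// project model is released.
    ///
    /// A minimal call-site sketch (hypothetical, assuming the index was
    /// previously registered with `cx.set_global`):
    ///
    /// ```ignore
    /// let project_index = cx.update_global::<SemanticIndex, _>(|index, cx| {
    ///     index.project_index(project.clone(), cx)
    /// });
    /// ```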
    pub fn project_index(
        &mut self,
        project: Model<Project>,
        cx: &mut AppContext,
    ) -> Model<ProjectIndex> {
        let project_weak = project.downgrade();
        project.update(cx, move |_, cx| {
            cx.on_release(move |_, cx| {
                if cx.has_global::<SemanticIndex>() {
                    cx.update_global::<SemanticIndex, _>(|this, _| {
                        this.project_indices.remove(&project_weak);
                    })
                }
            })
            .detach();
        });

        self.project_indices
            .entry(project.downgrade())
            .or_insert_with(|| {
                cx.new_model(|cx| {
                    ProjectIndex::new(
                        project,
                        self.db_connection.clone(),
                        self.embedding_provider.clone(),
                        cx,
                    )
                })
            })
            .clone()
    }
}

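/// The semantic index for a single project: one [`WorktreeIndex`] per visible
/// local worktree, plus aggregated status reporting across all of them.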
pub struct ProjectIndex {
    db_connection: heed::Env,
    project: WeakModel<Project>,
    worktree_indices: HashMap<EntityId, WorktreeIndexHandle>,
    language_registry: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    last_status: Status,
    status_tx: channel::Sender<()>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    _maintain_status: Task<()>,
    _subscription: Subscription,
}

enum WorktreeIndexHandle {
    Loading { _task: Task<Result<()>> },
    Loaded { index: Model<WorktreeIndex> },
}

impl ProjectIndex {
    fn new(
        project: Model<Project>,
        db_connection: heed::Env,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let language_registry = project.read(cx).languages().clone();
        let fs = project.read(cx).fs().clone();
        let (status_tx, mut status_rx) = channel::unbounded();
        let mut this = ProjectIndex {
            db_connection,
            project: project.downgrade(),
            worktree_indices: HashMap::default(),
            language_registry,
            fs,
            status_tx,
            last_status: Status::Idle,
            embedding_provider,
            _subscription: cx.subscribe(&project, Self::handle_project_event),
            _maintain_status: cx.spawn(|this, mut cx| async move {
                while status_rx.next().await.is_some() {
                    if this
                        .update(&mut cx, |this, cx| this.update_status(cx))
                        .is_err()
                    {
                        break;
                    }
                }
            }),
        };
        this.update_worktree_indices(cx);
        this
    }

    pub fn status(&self) -> Status {
        self.last_status
    }

    fn handle_project_event(
        &mut self,
        _: Model<Project>,
        event: &project::Event,
        cx: &mut ModelContext<Self>,
    ) {
        match event {
            project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
                self.update_worktree_indices(cx);
            }
            _ => {}
        }
    }

    fn update_worktree_indices(&mut self, cx: &mut ModelContext<Self>) {
        let Some(project) = self.project.upgrade() else {
            return;
        };

        let worktrees = project
            .read(cx)
            .visible_worktrees(cx)
            .filter_map(|worktree| {
                if worktree.read(cx).is_local() {
                    Some((worktree.entity_id(), worktree))
                } else {
                    None
                }
            })
            .collect::<HashMap<_, _>>();

        self.worktree_indices
            .retain(|worktree_id, _| worktrees.contains_key(worktree_id));
        for (worktree_id, worktree) in worktrees {
            self.worktree_indices.entry(worktree_id).or_insert_with(|| {
                let worktree_index = WorktreeIndex::load(
                    worktree.clone(),
                    self.db_connection.clone(),
                    self.language_registry.clone(),
                    self.fs.clone(),
                    self.status_tx.clone(),
                    self.embedding_provider.clone(),
                    cx,
                );

                let load_worktree = cx.spawn(|this, mut cx| async move {
                    if let Some(worktree_index) = worktree_index.await.log_err() {
                        this.update(&mut cx, |this, _| {
                            this.worktree_indices.insert(
                                worktree_id,
                                WorktreeIndexHandle::Loaded {
                                    index: worktree_index,
                                },
                            );
                        })?;
                    } else {
                        this.update(&mut cx, |this, _cx| {
                            this.worktree_indices.remove(&worktree_id)
                        })?;
                    }

                    this.update(&mut cx, |this, cx| this.update_status(cx))
                });

                WorktreeIndexHandle::Loading {
                    _task: load_worktree,
                }
            });
        }

        self.update_status(cx);
    }

    fn update_status(&mut self, cx: &mut ModelContext<Self>) {
        let mut indexing_count = 0;
        let mut any_loading = false;

        for index in self.worktree_indices.values_mut() {
            match index {
                WorktreeIndexHandle::Loading { .. } => {
                    any_loading = true;
                    break;
                }
                WorktreeIndexHandle::Loaded { index, .. } => {
                    indexing_count += index.read(cx).entry_ids_being_indexed.len();
                }
            }
        }

        let status = if any_loading {
            Status::Loading
        } else if let Some(remaining_count) = NonZeroUsize::new(indexing_count) {
            Status::Scanning { remaining_count }
        } else {
            Status::Idle
        };

        if status != self.last_status {
            self.last_status = status;
            cx.emit(status);
        }
    }

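    /// Searches all loaded worktree indices concurrently, then merges their
    /// results by descending score and keeps the top `limit`.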
    pub fn search(&self, query: &str, limit: usize, cx: &AppContext) -> Task<Vec<SearchResult>> {
        let mut worktree_searches = Vec::new();
        for worktree_index in self.worktree_indices.values() {
            if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
                worktree_searches
                    .push(index.read_with(cx, |index, cx| index.search(query, limit, cx)));
            }
        }

        cx.spawn(|_| async move {
            let mut results = Vec::new();
            let worktree_searches = futures::future::join_all(worktree_searches).await;

            for worktree_search_results in worktree_searches {
                if let Some(worktree_search_results) = worktree_search_results.log_err() {
                    results.extend(worktree_search_results);
                }
            }

            results
                .sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
            results.truncate(limit);

            results
        })
    }

    #[cfg(test)]
    pub fn path_count(&self, cx: &AppContext) -> Result<u64> {
        let mut result = 0;
        for worktree_index in self.worktree_indices.values() {
            if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
                result += index.read(cx).path_count()?;
            }
        }
        Ok(result)
    }

    pub fn debug(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let indices = self
            .worktree_indices
            .values()
            .filter_map(|worktree_index| {
                if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
                    Some(index.clone())
                } else {
                    None
                }
            })
            .collect::<Vec<_>>();

        cx.spawn(|_, mut cx| async move {
            eprintln!("semantic index contents:");
            for index in indices {
                index.update(&mut cx, |index, cx| index.debug(cx))?.await?
            }
            Ok(())
        })
    }
}

pub struct SearchResult {
    pub worktree: Model<Worktree>,
    pub path: Arc<Path>,
    pub range: Range<usize>,
    pub score: f32,
}

#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Status {
    Idle,
    Loading,
    Scanning { remaining_count: NonZeroUsize },
}

impl EventEmitter<Status> for ProjectIndex {}

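/// The index for a single local worktree: a background task keeps the stored
/// embeddings in sync with the files on disk, and searches run against the
/// same database.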
struct WorktreeIndex {
    worktree: Model<Worktree>,
    db_connection: heed::Env,
    db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
    language_registry: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    entry_ids_being_indexed: Arc<IndexingEntrySet>,
    _index_entries: Task<Result<()>>,
    _subscription: Subscription,
}

impl WorktreeIndex {
    pub fn load(
        worktree: Model<Worktree>,
        db_connection: heed::Env,
        language_registry: Arc<LanguageRegistry>,
        fs: Arc<dyn Fs>,
        status_tx: channel::Sender<()>,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut AppContext,
    ) -> Task<Result<Model<Self>>> {
        let worktree_abs_path = worktree.read(cx).abs_path();
        cx.spawn(|mut cx| async move {
            let db = cx
                .background_executor()
                .spawn({
                    let db_connection = db_connection.clone();
                    async move {
                        let mut txn = db_connection.write_txn()?;
                        let db_name = worktree_abs_path.to_string_lossy();
                        let db = db_connection.create_database(&mut txn, Some(&db_name))?;
                        txn.commit()?;
                        anyhow::Ok(db)
                    }
                })
                .await?;
            cx.new_model(|cx| {
                Self::new(
                    worktree,
                    db_connection,
                    db,
                    status_tx,
                    language_registry,
                    fs,
                    embedding_provider,
                    cx,
                )
            })
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn new(
        worktree: Model<Worktree>,
        db_connection: heed::Env,
        db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
        status: channel::Sender<()>,
        language_registry: Arc<LanguageRegistry>,
        fs: Arc<dyn Fs>,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let (updated_entries_tx, updated_entries_rx) = channel::unbounded();
        let _subscription = cx.subscribe(&worktree, move |_this, _worktree, event, _cx| {
            if let worktree::Event::UpdatedEntries(update) = event {
                _ = updated_entries_tx.try_send(update.clone());
            }
        });

        Self {
            db_connection,
            db,
            worktree,
            language_registry,
            fs,
            embedding_provider,
            entry_ids_being_indexed: Arc::new(IndexingEntrySet::new(status)),
            _index_entries: cx.spawn(|this, cx| Self::index_entries(this, updated_entries_rx, cx)),
            _subscription,
        }
    }

    async fn index_entries(
        this: WeakModel<Self>,
        updated_entries: channel::Receiver<UpdatedEntriesSet>,
        mut cx: AsyncAppContext,
    ) -> Result<()> {
        let index = this.update(&mut cx, |this, cx| this.index_entries_changed_on_disk(cx))?;
        index.await.log_err();

        while let Ok(updated_entries) = updated_entries.recv().await {
            let index = this.update(&mut cx, |this, cx| {
                this.index_updated_entries(updated_entries, cx)
            })?;
            index.await.log_err();
        }

        Ok(())
    }

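    /// Runs the full indexing pipeline over the current state of the worktree:
    /// scan entries against the database, chunk the changed files, embed the
    /// chunks, and persist embeddings and deletions. The four stages run
    /// concurrently, connected by bounded channels.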
    fn index_entries_changed_on_disk(&self, cx: &AppContext) -> impl Future<Output = Result<()>> {
        let worktree = self.worktree.read(cx).as_local().unwrap().snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_entries(worktree.clone(), cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
    }

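    /// Like [`Self::index_entries_changed_on_disk`], but scans only the
    /// entries named in `updated_entries` instead of the whole worktree.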
    fn index_updated_entries(
        &self,
        updated_entries: UpdatedEntriesSet,
        cx: &AppContext,
    ) -> impl Future<Output = Result<()>> {
        let worktree = self.worktree.read(cx).as_local().unwrap().snapshot();
        let worktree_abs_path = worktree.abs_path().clone();
        let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
        async move {
            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
            Ok(())
        }
    }

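    /// Walks the worktree's files and the database's keys in a single merge
    /// pass (both are sorted consistently; see [`db_key_for_path`]). Files
    /// whose mtime changed are emitted on `updated_entries`, and contiguous
    /// runs of database keys with no matching file are emitted as deletion
    /// ranges on `deleted_entry_ranges`.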
    fn scan_entries(&self, worktree: LocalSnapshot, cx: &AppContext) -> ScanEntries {
        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
        let db_connection = self.db_connection.clone();
        let db = self.db;
        let entries_being_indexed = self.entry_ids_being_indexed.clone();
        let task = cx.background_executor().spawn(async move {
            let txn = db_connection
                .read_txn()
                .context("failed to create read transaction")?;
            let mut db_entries = db
                .iter(&txn)
                .context("failed to create iterator")?
                .move_between_keys()
                .peekable();

            let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None;
            for entry in worktree.files(false, 0) {
                let entry_db_key = db_key_for_path(&entry.path);

                let mut saved_mtime = None;
                while let Some(db_entry) = db_entries.peek() {
                    match db_entry {
                        Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) {
                            Ordering::Less => {
                                if let Some(deletion_range) = deletion_range.as_mut() {
                                    deletion_range.1 = Bound::Included(db_path);
                                } else {
                                    deletion_range =
                                        Some((Bound::Included(db_path), Bound::Included(db_path)));
                                }

                                db_entries.next();
                            }
                            Ordering::Equal => {
                                if let Some(deletion_range) = deletion_range.take() {
                                    deleted_entry_ranges_tx
                                        .send((
                                            deletion_range.0.map(ToString::to_string),
                                            deletion_range.1.map(ToString::to_string),
                                        ))
                                        .await?;
                                }
                                saved_mtime = db_embedded_file.mtime;
                                db_entries.next();
                                break;
                            }
                            Ordering::Greater => {
                                break;
                            }
                        },
                        Err(_) => return Err(db_entries.next().unwrap().unwrap_err().into()),
                    }
                }

                if entry.mtime != saved_mtime {
                    let handle = entries_being_indexed.insert(entry.id);
                    updated_entries_tx.send((entry.clone(), handle)).await?;
                }
            }

            // Flush any deletion range that was still pending when the file
            // iterator ended, so trailing stale database entries are removed.
            if let Some(deletion_range) = deletion_range.take() {
                deleted_entry_ranges_tx
                    .send((
                        deletion_range.0.map(ToString::to_string),
                        deletion_range.1.map(ToString::to_string),
                    ))
                    .await?;
            }

            if let Some(db_entry) = db_entries.next() {
                let (db_path, _) = db_entry?;
                deleted_entry_ranges_tx
                    .send((Bound::Included(db_path.to_string()), Bound::Unbounded))
                    .await?;
            }

            Ok(())
        });

        ScanEntries {
            updated_entries: updated_entries_rx,
            deleted_entry_ranges: deleted_entry_ranges_rx,
            task,
        }
    }

    fn scan_updated_entries(
        &self,
        worktree: LocalSnapshot,
        updated_entries: UpdatedEntriesSet,
        cx: &AppContext,
    ) -> ScanEntries {
        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
        let entries_being_indexed = self.entry_ids_being_indexed.clone();
        let task = cx.background_executor().spawn(async move {
            for (path, entry_id, status) in updated_entries.iter() {
                match status {
                    project::PathChange::Added
                    | project::PathChange::Updated
                    | project::PathChange::AddedOrUpdated => {
                        if let Some(entry) = worktree.entry_for_id(*entry_id) {
                            if entry.is_file() {
                                let handle = entries_being_indexed.insert(entry.id);
                                updated_entries_tx.send((entry.clone(), handle)).await?;
                            }
                        }
                    }
                    project::PathChange::Removed => {
                        let db_path = db_key_for_path(path);
                        deleted_entry_ranges_tx
                            .send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
                            .await?;
                    }
                    project::PathChange::Loaded => {
                        // Do nothing.
                    }
                }
            }

            Ok(())
        });

        ScanEntries {
            updated_entries: updated_entries_rx,
            deleted_entry_ranges: deleted_entry_ranges_rx,
            task,
        }
    }

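    /// Loads and chunks changed files in parallel, one worker per CPU core,
    /// forwarding the chunked files to the embedding stage.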
    fn chunk_files(
        &self,
        worktree_abs_path: Arc<Path>,
        entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
        cx: &AppContext,
    ) -> ChunkFiles {
        let language_registry = self.language_registry.clone();
        let fs = self.fs.clone();
        let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
        let task = cx.spawn(|cx| async move {
            cx.background_executor()
                .scoped(|cx| {
                    for _ in 0..cx.num_cpus() {
                        cx.spawn(async {
                            while let Ok((entry, handle)) = entries.recv().await {
                                let entry_abs_path = worktree_abs_path.join(&entry.path);
                                let Some(text) = fs
                                    .load(&entry_abs_path)
                                    .await
                                    .with_context(|| {
                                        format!("failed to read path {entry_abs_path:?}")
                                    })
                                    .log_err()
                                else {
                                    continue;
                                };
                                let language = language_registry
                                    .language_for_file_path(&entry.path)
                                    .await
                                    .ok();
                                let grammar =
                                    language.as_ref().and_then(|language| language.grammar());
                                let chunked_file = ChunkedFile {
                                    chunks: chunk_text(&text, grammar),
                                    handle,
                                    path: entry.path,
                                    mtime: entry.mtime,
                                    text,
                                };

                                if chunked_files_tx.send(chunked_file).await.is_err() {
                                    return;
                                }
                            }
                        });
                    }
                })
                .await;
            Ok(())
        });

        ChunkFiles {
            files: chunked_files_rx,
            task,
        }
    }

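    /// Embeds chunked files, accumulating chunks across files with
    /// `chunks_timeout` so the provider sees full batches even when
    /// individual files are small. A file is forwarded for persistence only
    /// if all of its chunks embed successfully.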
    fn embed_files(
        embedding_provider: Arc<dyn EmbeddingProvider>,
        chunked_files: channel::Receiver<ChunkedFile>,
        cx: &AppContext,
    ) -> EmbedFiles {
        let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
        let task = cx.background_executor().spawn(async move {
            let mut chunked_file_batches =
                chunked_files.chunks_timeout(512, Duration::from_secs(2));
            while let Some(chunked_files) = chunked_file_batches.next().await {
                // Flatten the batch of files into a single list of chunks,
                // subdivide that list into provider-sized batches, and then
                // reassemble the embeddings back into the files they belong to.
                // If any chunk of a file fails to embed, the entire file is
                // discarded.

                let chunks: Vec<TextToEmbed> = chunked_files
                    .iter()
                    .flat_map(|file| {
                        file.chunks.iter().map(|chunk| TextToEmbed {
                            text: &file.text[chunk.range.clone()],
                            digest: chunk.digest,
                        })
                    })
                    .collect();

                let mut embeddings: Vec<Option<Embedding>> = Vec::new();
                for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
                    if let Some(batch_embeddings) =
                        embedding_provider.embed(embedding_batch).await.log_err()
                    {
                        if batch_embeddings.len() == embedding_batch.len() {
                            embeddings.extend(batch_embeddings.into_iter().map(Some));
                            continue;
                        }
                        log::error!(
                            "embedding provider returned unexpected embedding count {}, expected {}",
                            batch_embeddings.len(),
                            embedding_batch.len()
                        );
                    }

                    embeddings.extend(iter::repeat(None).take(embedding_batch.len()));
                }

                let mut embeddings = embeddings.into_iter();
                for chunked_file in chunked_files {
                    let mut embedded_file = EmbeddedFile {
                        path: chunked_file.path,
                        mtime: chunked_file.mtime,
                        chunks: Vec::new(),
                    };

                    let mut embedded_all_chunks = true;
                    for (chunk, embedding) in
                        chunked_file.chunks.into_iter().zip(embeddings.by_ref())
                    {
                        if let Some(embedding) = embedding {
                            embedded_file
                                .chunks
                                .push(EmbeddedChunk { chunk, embedding });
                        } else {
                            embedded_all_chunks = false;
                        }
                    }

                    if embedded_all_chunks {
                        embedded_files_tx
                            .send((embedded_file, chunked_file.handle))
                            .await?;
                    }
                }
            }
            Ok(())
        });

        EmbedFiles {
            files: embedded_files_rx,
            task,
        }
    }

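    /// Writes the pipeline's output to the database: first the deleted key
    /// ranges, then the embedded files in batches. Each batch's indexing
    /// handles are dropped only after its transaction commits.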
    fn persist_embeddings(
        &self,
        mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
        embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
        cx: &AppContext,
    ) -> Task<Result<()>> {
        let db_connection = self.db_connection.clone();
        let db = self.db;
        cx.background_executor().spawn(async move {
            while let Some(deletion_range) = deleted_entry_ranges.next().await {
                let mut txn = db_connection.write_txn()?;
                let start = deletion_range.0.as_ref().map(|start| start.as_str());
                let end = deletion_range.1.as_ref().map(|end| end.as_str());
                log::debug!("deleting embeddings in range {:?}", &(start, end));
                db.delete_range(&mut txn, &(start, end))?;
                txn.commit()?;
            }

            let mut embedded_files = embedded_files.chunks_timeout(4096, Duration::from_secs(2));
            while let Some(embedded_files) = embedded_files.next().await {
                let mut txn = db_connection.write_txn()?;
                for (file, _) in &embedded_files {
                    log::debug!("saving embedding for file {:?}", file.path);
                    let key = db_key_for_path(&file.path);
                    db.put(&mut txn, &key, file)?;
                }
                txn.commit()?;
                log::debug!("committed {} embedded files", embedded_files.len());

                // Dropping the batch releases the indexing handles, which
                // removes these entries from the set being indexed.
                drop(embedded_files);
            }

            Ok(())
        })
    }

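    /// Streams every stored chunk out of the database and scores it against
    /// the query embedding on all CPU cores; each worker keeps its own
    /// top-`limit` list, and the per-worker lists are merged at the end.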
    fn search(
        &self,
        query: &str,
        limit: usize,
        cx: &AppContext,
    ) -> Task<Result<Vec<SearchResult>>> {
        let (chunks_tx, chunks_rx) = channel::bounded(1024);

        let db_connection = self.db_connection.clone();
        let db = self.db;
        let scan_chunks = cx.background_executor().spawn({
            async move {
                let txn = db_connection
                    .read_txn()
                    .context("failed to create read transaction")?;
                let db_entries = db.iter(&txn).context("failed to iterate database")?;
                for db_entry in db_entries {
                    let (_key, db_embedded_file) = db_entry?;
                    for chunk in db_embedded_file.chunks {
                        chunks_tx
                            .send((db_embedded_file.path.clone(), chunk))
                            .await?;
                    }
                }
                anyhow::Ok(())
            }
        });

        let query = query.to_string();
        let embedding_provider = self.embedding_provider.clone();
        let worktree = self.worktree.clone();
        cx.spawn(|cx| async move {
            #[cfg(debug_assertions)]
            let embedding_query_start = std::time::Instant::now();
            log::info!("Searching for {query}");

            let mut query_embeddings = embedding_provider
                .embed(&[TextToEmbed::new(&query)])
                .await?;
            let query_embedding = query_embeddings
                .pop()
                .ok_or_else(|| anyhow!("no embedding for query"))?;
            let mut workers = Vec::new();
            for _ in 0..cx.background_executor().num_cpus() {
                workers.push(Vec::<SearchResult>::new());
            }

            #[cfg(debug_assertions)]
            let search_start = std::time::Instant::now();

            cx.background_executor()
                .scoped(|cx| {
                    for worker_results in workers.iter_mut() {
                        cx.spawn(async {
                            while let Ok((path, embedded_chunk)) = chunks_rx.recv().await {
                                let score = embedded_chunk.embedding.similarity(&query_embedding);
                                let ix = match worker_results.binary_search_by(|probe| {
                                    score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
                                }) {
                                    Ok(ix) | Err(ix) => ix,
                                };
                                worker_results.insert(
                                    ix,
                                    SearchResult {
                                        worktree: worktree.clone(),
                                        path: path.clone(),
                                        range: embedded_chunk.chunk.range.clone(),
                                        score,
                                    },
                                );
                                worker_results.truncate(limit);
                            }
                        });
                    }
                })
                .await;
            scan_chunks.await?;

            let mut search_results = Vec::with_capacity(workers.len() * limit);
            for worker_results in workers {
                search_results.extend(worker_results);
            }
            search_results
                .sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
            search_results.truncate(limit);
            #[cfg(debug_assertions)]
            {
                let search_elapsed = search_start.elapsed();
                log::debug!(
                    "searched {} entries in {:?}",
                    search_results.len(),
                    search_elapsed
                );
                let embedding_query_elapsed = embedding_query_start.elapsed();
                log::debug!("embedding query took {:?}", embedding_query_elapsed);
            }

            Ok(search_results)
        })
    }

    fn debug(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let connection = self.db_connection.clone();
        let db = self.db;
        cx.background_executor().spawn(async move {
            let tx = connection
                .read_txn()
                .context("failed to create read transaction")?;
            for record in db.iter(&tx)? {
                let (key, _) = record?;
                eprintln!("{}", path_for_db_key(key));
            }
            Ok(())
        })
    }

    #[cfg(test)]
    fn path_count(&self) -> Result<u64> {
        let txn = self
            .db_connection
            .read_txn()
            .context("failed to create read transaction")?;
        Ok(self.db.len(&txn)?)
    }
}

struct ScanEntries {
    updated_entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
    deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
    task: Task<Result<()>>,
}

struct ChunkFiles {
    files: channel::Receiver<ChunkedFile>,
    task: Task<Result<()>>,
}

struct ChunkedFile {
    pub path: Arc<Path>,
    pub mtime: Option<SystemTime>,
    pub handle: IndexingEntryHandle,
    pub text: String,
    pub chunks: Vec<Chunk>,
}

struct EmbedFiles {
    files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
    task: Task<Result<()>>,
}

#[derive(Debug, Serialize, Deserialize)]
struct EmbeddedFile {
    path: Arc<Path>,
    mtime: Option<SystemTime>,
    chunks: Vec<EmbeddedChunk>,
}

#[derive(Debug, Serialize, Deserialize)]
struct EmbeddedChunk {
    chunk: Chunk,
    embedding: Embedding,
}

/// The set of entries that are currently being indexed.
struct IndexingEntrySet {
    entry_ids: Mutex<HashSet<ProjectEntryId>>,
    tx: channel::Sender<()>,
}

/// When dropped, removes the entry from the set of entries that are being indexed.
#[derive(Clone)]
struct IndexingEntryHandle {
    entry_id: ProjectEntryId,
    set: Weak<IndexingEntrySet>,
}

impl IndexingEntrySet {
    fn new(tx: channel::Sender<()>) -> Self {
        Self {
            entry_ids: Default::default(),
            tx,
        }
    }

    fn insert(self: &Arc<Self>, entry_id: ProjectEntryId) -> IndexingEntryHandle {
        self.entry_ids.lock().insert(entry_id);
        self.tx.send_blocking(()).ok();
        IndexingEntryHandle {
            entry_id,
            set: Arc::downgrade(self),
        }
    }

    pub fn len(&self) -> usize {
        self.entry_ids.lock().len()
    }
}

impl Drop for IndexingEntryHandle {
    fn drop(&mut self) {
        if let Some(set) = self.set.upgrade() {
            set.tx.send_blocking(()).ok();
            set.entry_ids.lock().remove(&self.entry_id);
        }
    }
}

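// Database keys are paths with '/' replaced by '\0', e.g. "src/main.rs"
// becomes "src\0main.rs". Since '\0' sorts before any other byte, a
// directory's contents stay contiguous in key order, which is presumably what
// keeps the database's key order aligned with the worktree traversal order
// that `scan_entries` relies on.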
fn db_key_for_path(path: &Arc<Path>) -> String {
    path.to_string_lossy().replace('/', "\0")
}

fn path_for_db_key(key: &str) -> String {
    key.replace('\0', "/")
}

#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, FutureExt};
    use gpui::TestAppContext;
    use language::language_settings::AllLanguageSettings;
    use project::Project;
    use settings::SettingsStore;
    use std::{future, path::Path, sync::Arc};

    fn init_test(cx: &mut TestAppContext) {
        _ = cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            language::init(cx);
            Project::init_settings(cx);
            SettingsStore::update(cx, |store, cx| {
                store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
            });
        });
    }

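    /// A deterministic [`EmbeddingProvider`] for tests, delegating to a
    /// caller-supplied function.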
    pub struct TestEmbeddingProvider {
        batch_size: usize,
        compute_embedding: Box<dyn Fn(&str) -> Result<Embedding> + Send + Sync>,
    }

    impl TestEmbeddingProvider {
        pub fn new(
            batch_size: usize,
            compute_embedding: impl 'static + Fn(&str) -> Result<Embedding> + Send + Sync,
        ) -> Self {
            Self {
                batch_size,
                compute_embedding: Box::new(compute_embedding),
            }
        }
    }

    impl EmbeddingProvider for TestEmbeddingProvider {
        fn embed<'a>(
            &'a self,
            texts: &'a [TextToEmbed<'a>],
        ) -> BoxFuture<'a, Result<Vec<Embedding>>> {
            let embeddings = texts
                .iter()
                .map(|to_embed| (self.compute_embedding)(to_embed.text))
                .collect();
            future::ready(embeddings).boxed()
        }

        fn batch_size(&self) -> usize {
            self.batch_size
        }
    }

    #[gpui::test]
    async fn test_search(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        init_test(cx);

        let temp_dir = tempfile::tempdir().unwrap();

        let mut semantic_index = SemanticIndex::new(
            temp_dir.path().into(),
            Arc::new(TestEmbeddingProvider::new(16, |text| {
                let mut embedding = vec![0f32; 2];
                // If the text contains "garbage in", set the first dimension to 0.9.
                if text.contains("garbage in") {
                    embedding[0] = 0.9;
                } else {
                    embedding[0] = -0.9;
                }

                if text.contains("garbage out") {
                    embedding[1] = 0.9;
                } else {
                    embedding[1] = -0.9;
                }

                Ok(Embedding::new(embedding))
            })),
            &mut cx.to_async(),
        )
        .await
        .unwrap();

        let project_path = Path::new("./fixture");

        let project = cx
            .spawn(|mut cx| async move { Project::example([project_path], &mut cx).await })
            .await;

        cx.update(|cx| {
            let language_registry = project.read(cx).languages().clone();
            let node_runtime = project.read(cx).node_runtime().unwrap().clone();
            languages::init(language_registry, node_runtime, cx);
        });

        let project_index = cx.update(|cx| semantic_index.project_index(project.clone(), cx));

        while project_index
            .read_with(cx, |index, cx| index.path_count(cx))
            .unwrap()
            == 0
        {
            project_index.next_event(cx).await;
        }

        let results = cx
            .update(|cx| {
                let project_index = project_index.read(cx);
                let query = "garbage in, garbage out";
                project_index.search(query, 4, cx)
            })
            .await;

        assert!(results.len() > 1, "should have found at least two results");

        for result in &results {
            println!("result: {:?}", result.path);
            println!("score: {:?}", result.score);
        }

        // Find a result whose score is greater than 0.9.
        let search_result = results.iter().find(|result| result.score > 0.9).unwrap();

        assert_eq!(search_result.path.to_string_lossy(), "needle.md");

        let content = cx
            .update(|cx| {
                let worktree = search_result.worktree.read(cx);
                let entry_abs_path = worktree.abs_path().join(search_result.path.clone());
                let fs = project.read(cx).fs().clone();
                cx.spawn(|_| async move { fs.load(&entry_abs_path).await.unwrap() })
            })
            .await;

        let range = search_result.range.clone();
        let content = content[range.clone()].to_owned();

        assert!(content.contains("garbage in, garbage out"));
    }

    #[gpui::test]
    async fn test_embed_files(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let provider = Arc::new(TestEmbeddingProvider::new(3, |text| {
            if text.contains('g') {
                Err(anyhow!("cannot embed text containing a 'g' character"))
            } else {
                Ok(Embedding::new(
                    ('a'..'z')
                        .map(|char| text.chars().filter(|c| *c == char).count() as f32)
                        .collect(),
                ))
            }
        }));

        let (indexing_progress_tx, _) = channel::unbounded();
        let indexing_entries = Arc::new(IndexingEntrySet::new(indexing_progress_tx));

        let (chunked_files_tx, chunked_files_rx) = channel::unbounded::<ChunkedFile>();
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test1.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(0)),
                text: "abcdefghijklmnop".to_string(),
                chunks: [0..4, 4..8, 8..12, 12..16]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test2.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(1)),
                text: "qrstuvwxyz".to_string(),
                chunks: [0..4, 4..8, 8..10]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        chunked_files_tx.close();

        let embed_files_task =
            cx.update(|cx| WorktreeIndex::embed_files(provider.clone(), chunked_files_rx, cx));
        embed_files_task.task.await.unwrap();

        // The provider fails any batch whose text contains a 'g', so some of
        // "test1.md"'s chunks never embed and the whole file is discarded;
        // only "test2.md" should come through.
        let mut embedded_files_rx = embed_files_task.files;
        let mut embedded_files = Vec::new();
        while let Some((embedded_file, _)) = embedded_files_rx.next().await {
            embedded_files.push(embedded_file);
        }

        assert_eq!(embedded_files.len(), 1);
        assert_eq!(embedded_files[0].path.as_ref(), Path::new("test2.md"));
        assert_eq!(
            embedded_files[0]
                .chunks
                .iter()
                .map(|embedded_chunk| embedded_chunk.embedding.clone())
                .collect::<Vec<Embedding>>(),
            vec![
                (provider.compute_embedding)("qrst").unwrap(),
                (provider.compute_embedding)("uvwx").unwrap(),
                (provider.compute_embedding)("yz").unwrap(),
            ],
        );
    }
}