semantic_index.rs

   1mod chunking;
   2mod embedding;
   3mod project_index_debug_view;
   4
   5use anyhow::{anyhow, Context as _, Result};
   6use chunking::{chunk_text, Chunk};
   7use collections::{Bound, HashMap, HashSet};
   8pub use embedding::*;
   9use fs::Fs;
  10use futures::stream::StreamExt;
  11use futures_batch::ChunksTimeoutStreamExt;
  12use gpui::{
  13    AppContext, AsyncAppContext, BorrowAppContext, Context, Entity, EntityId, EventEmitter, Global,
  14    Model, ModelContext, Subscription, Task, WeakModel,
  15};
  16use heed::types::{SerdeBincode, Str};
  17use language::LanguageRegistry;
  18use parking_lot::Mutex;
  19use project::{Entry, Project, ProjectEntryId, UpdatedEntriesSet, Worktree, WorktreeId};
  20use serde::{Deserialize, Serialize};
  21use smol::channel;
  22use std::{
  23    cmp::Ordering,
  24    future::Future,
  25    iter,
  26    num::NonZeroUsize,
  27    ops::Range,
  28    path::{Path, PathBuf},
  29    sync::{Arc, Weak},
  30    time::{Duration, SystemTime},
  31};
  32use util::ResultExt;
  33use worktree::LocalSnapshot;
  34
  35pub use project_index_debug_view::ProjectIndexDebugView;
  36
/// Global owner of all semantic-search state: a single LMDB environment
/// shared by every project, plus one `ProjectIndex` per open project.
pub struct SemanticIndex {
    // Provider used to turn file chunks and queries into embeddings.
    embedding_provider: Arc<dyn EmbeddingProvider>,
    // Shared LMDB environment; one database per worktree is created inside it.
    db_connection: heed::Env,
    // Index per project, keyed by a weak handle so entries can be evicted
    // when the project is released.
    project_indices: HashMap<WeakModel<Project>, Model<ProjectIndex>>,
}

impl Global for SemanticIndex {}
  44
  45impl SemanticIndex {
  46    pub async fn new(
  47        db_path: PathBuf,
  48        embedding_provider: Arc<dyn EmbeddingProvider>,
  49        cx: &mut AsyncAppContext,
  50    ) -> Result<Self> {
  51        let db_connection = cx
  52            .background_executor()
  53            .spawn(async move {
  54                std::fs::create_dir_all(&db_path)?;
  55                unsafe {
  56                    heed::EnvOpenOptions::new()
  57                        .map_size(1024 * 1024 * 1024)
  58                        .max_dbs(3000)
  59                        .open(db_path)
  60                }
  61            })
  62            .await
  63            .context("opening database connection")?;
  64
  65        Ok(SemanticIndex {
  66            db_connection,
  67            embedding_provider,
  68            project_indices: HashMap::default(),
  69        })
  70    }
  71
  72    pub fn project_index(
  73        &mut self,
  74        project: Model<Project>,
  75        cx: &mut AppContext,
  76    ) -> Model<ProjectIndex> {
  77        let project_weak = project.downgrade();
  78        project.update(cx, move |_, cx| {
  79            cx.on_release(move |_, cx| {
  80                if cx.has_global::<SemanticIndex>() {
  81                    cx.update_global::<SemanticIndex, _>(|this, _| {
  82                        this.project_indices.remove(&project_weak);
  83                    })
  84                }
  85            })
  86            .detach();
  87        });
  88
  89        self.project_indices
  90            .entry(project.downgrade())
  91            .or_insert_with(|| {
  92                cx.new_model(|cx| {
  93                    ProjectIndex::new(
  94                        project,
  95                        self.db_connection.clone(),
  96                        self.embedding_provider.clone(),
  97                        cx,
  98                    )
  99                })
 100            })
 101            .clone()
 102    }
 103}
 104
/// Aggregates the per-worktree indices of a single project and reports a
/// combined indexing `Status` to observers.
pub struct ProjectIndex {
    db_connection: heed::Env,
    project: WeakModel<Project>,
    // One handle per visible local worktree, keyed by worktree entity id.
    worktree_indices: HashMap<EntityId, WorktreeIndexHandle>,
    language_registry: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    // Last status emitted; used to suppress duplicate events.
    last_status: Status,
    // Pinged whenever a worktree's indexing progress may have changed.
    status_tx: channel::Sender<()>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    _maintain_status: Task<()>,
    _subscription: Subscription,
}

/// A worktree index that is either still loading or ready for use.
enum WorktreeIndexHandle {
    Loading { _task: Task<Result<()>> },
    Loaded { index: Model<WorktreeIndex> },
}
 122
 123impl ProjectIndex {
    /// Creates a project index, subscribes to project events, and spawns a
    /// task that recomputes the aggregate status whenever `status_tx` fires.
    fn new(
        project: Model<Project>,
        db_connection: heed::Env,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let language_registry = project.read(cx).languages().clone();
        let fs = project.read(cx).fs().clone();
        let (status_tx, mut status_rx) = channel::unbounded();
        let mut this = ProjectIndex {
            db_connection,
            project: project.downgrade(),
            worktree_indices: HashMap::default(),
            language_registry,
            fs,
            status_tx,
            last_status: Status::Idle,
            embedding_provider,
            _subscription: cx.subscribe(&project, Self::handle_project_event),
            // Recompute the aggregate status on every ping; stop once the
            // model has been released (`update` returns Err).
            _maintain_status: cx.spawn(|this, mut cx| async move {
                while status_rx.next().await.is_some() {
                    if this
                        .update(&mut cx, |this, cx| this.update_status(cx))
                        .is_err()
                    {
                        break;
                    }
                }
            }),
        };
        this.update_worktree_indices(cx);
        this
    }
 157
    /// The most recently computed aggregate indexing status.
    pub fn status(&self) -> Status {
        self.last_status
    }

    /// Weak handle to the project this index belongs to.
    pub fn project(&self) -> WeakModel<Project> {
        self.project.clone()
    }
 165
 166    fn handle_project_event(
 167        &mut self,
 168        _: Model<Project>,
 169        event: &project::Event,
 170        cx: &mut ModelContext<Self>,
 171    ) {
 172        match event {
 173            project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
 174                self.update_worktree_indices(cx);
 175            }
 176            _ => {}
 177        }
 178    }
 179
    /// Reconciles `worktree_indices` with the project's current set of
    /// visible local worktrees: drops indices for removed worktrees and
    /// starts loading indices for new ones, then refreshes the status.
    fn update_worktree_indices(&mut self, cx: &mut ModelContext<Self>) {
        let Some(project) = self.project.upgrade() else {
            return;
        };

        // Only local worktrees are indexed; remote ones are skipped.
        let worktrees = project
            .read(cx)
            .visible_worktrees(cx)
            .filter_map(|worktree| {
                if worktree.read(cx).is_local() {
                    Some((worktree.entity_id(), worktree))
                } else {
                    None
                }
            })
            .collect::<HashMap<_, _>>();

        // Drop indices whose worktree is no longer part of the project.
        self.worktree_indices
            .retain(|worktree_id, _| worktrees.contains_key(worktree_id));
        for (worktree_id, worktree) in worktrees {
            self.worktree_indices.entry(worktree_id).or_insert_with(|| {
                let worktree_index = WorktreeIndex::load(
                    worktree.clone(),
                    self.db_connection.clone(),
                    self.language_registry.clone(),
                    self.fs.clone(),
                    self.status_tx.clone(),
                    self.embedding_provider.clone(),
                    cx,
                );

                // Swap the `Loading` handle for `Loaded` once the load
                // finishes, or drop the entry entirely if loading failed.
                let load_worktree = cx.spawn(|this, mut cx| async move {
                    if let Some(worktree_index) = worktree_index.await.log_err() {
                        this.update(&mut cx, |this, _| {
                            this.worktree_indices.insert(
                                worktree_id,
                                WorktreeIndexHandle::Loaded {
                                    index: worktree_index,
                                },
                            );
                        })?;
                    } else {
                        this.update(&mut cx, |this, _cx| {
                            this.worktree_indices.remove(&worktree_id)
                        })?;
                    }

                    this.update(&mut cx, |this, cx| this.update_status(cx))
                });

                WorktreeIndexHandle::Loading {
                    _task: load_worktree,
                }
            });
        }

        self.update_status(cx);
    }
 238
 239    fn update_status(&mut self, cx: &mut ModelContext<Self>) {
 240        let mut indexing_count = 0;
 241        let mut any_loading = false;
 242
 243        for index in self.worktree_indices.values_mut() {
 244            match index {
 245                WorktreeIndexHandle::Loading { .. } => {
 246                    any_loading = true;
 247                    break;
 248                }
 249                WorktreeIndexHandle::Loaded { index, .. } => {
 250                    indexing_count += index.read(cx).entry_ids_being_indexed.len();
 251                }
 252            }
 253        }
 254
 255        let status = if any_loading {
 256            Status::Loading
 257        } else if let Some(remaining_count) = NonZeroUsize::new(indexing_count) {
 258            Status::Scanning { remaining_count }
 259        } else {
 260            Status::Idle
 261        };
 262
 263        if status != self.last_status {
 264            self.last_status = status;
 265            cx.emit(status);
 266        }
 267    }
 268
    /// Performs a semantic search over every loaded worktree index.
    ///
    /// Streams all stored chunks out of the per-worktree databases, embeds
    /// `query`, scores each chunk against the query embedding on a pool of
    /// background workers, and returns up to `limit` results ordered by
    /// descending similarity score.
    pub fn search(
        &self,
        query: String,
        limit: usize,
        cx: &AppContext,
    ) -> Task<Result<Vec<SearchResult>>> {
        let (chunks_tx, chunks_rx) = channel::bounded(1024);
        let mut worktree_scan_tasks = Vec::new();
        for worktree_index in self.worktree_indices.values() {
            if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
                let chunks_tx = chunks_tx.clone();
                index.read_with(cx, |index, cx| {
                    let worktree_id = index.worktree.read(cx).id();
                    let db_connection = index.db_connection.clone();
                    let db = index.db;
                    // Stream every stored chunk of this worktree into the
                    // shared channel consumed by the scoring workers below.
                    worktree_scan_tasks.push(cx.background_executor().spawn({
                        async move {
                            let txn = db_connection
                                .read_txn()
                                .context("failed to create read transaction")?;
                            let db_entries = db.iter(&txn).context("failed to iterate database")?;
                            for db_entry in db_entries {
                                let (_key, db_embedded_file) = db_entry?;
                                for chunk in db_embedded_file.chunks {
                                    chunks_tx
                                        .send((worktree_id, db_embedded_file.path.clone(), chunk))
                                        .await?;
                                }
                            }
                            anyhow::Ok(())
                        }
                    }));
                })
            }
        }
        // Drop the original sender so the workers' receive loops terminate
        // once all scan tasks have finished (or failed).
        drop(chunks_tx);

        let project = self.project.clone();
        let embedding_provider = self.embedding_provider.clone();
        cx.spawn(|cx| async move {
            #[cfg(debug_assertions)]
            let embedding_query_start = std::time::Instant::now();
            log::info!("Searching for {query}");

            let query_embeddings = embedding_provider
                .embed(&[TextToEmbed::new(&query)])
                .await?;
            let query_embedding = query_embeddings
                .into_iter()
                .next()
                .ok_or_else(|| anyhow!("no embedding for query"))?;

            // One result list per CPU; each worker maintains its own
            // descending-score top-`limit` list.
            let mut results_by_worker = Vec::new();
            for _ in 0..cx.background_executor().num_cpus() {
                results_by_worker.push(Vec::<WorktreeSearchResult>::new());
            }

            #[cfg(debug_assertions)]
            let search_start = std::time::Instant::now();

            cx.background_executor()
                .scoped(|cx| {
                    for results in results_by_worker.iter_mut() {
                        cx.spawn(async {
                            while let Ok((worktree_id, path, chunk)) = chunks_rx.recv().await {
                                let score = chunk.embedding.similarity(&query_embedding);
                                // Binary-search the insertion point to keep
                                // the worker's list sorted by descending
                                // score, then cap it at `limit` entries.
                                let ix = match results.binary_search_by(|probe| {
                                    score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
                                }) {
                                    Ok(ix) | Err(ix) => ix,
                                };
                                results.insert(
                                    ix,
                                    WorktreeSearchResult {
                                        worktree_id,
                                        path: path.clone(),
                                        range: chunk.chunk.range.clone(),
                                        score,
                                    },
                                );
                                results.truncate(limit);
                            }
                        });
                    }
                })
                .await;

            // Surface any database errors from the scan tasks.
            futures::future::try_join_all(worktree_scan_tasks).await?;

            project.read_with(&cx, |project, cx| {
                // Merge the per-worker lists, dropping results whose worktree
                // has been removed from the project in the meantime.
                let mut search_results = Vec::with_capacity(results_by_worker.len() * limit);
                for worker_results in results_by_worker {
                    search_results.extend(worker_results.into_iter().filter_map(|result| {
                        Some(SearchResult {
                            worktree: project.worktree_for_id(result.worktree_id, cx)?,
                            path: result.path,
                            range: result.range,
                            score: result.score,
                        })
                    }));
                }
                search_results.sort_unstable_by(|a, b| {
                    b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal)
                });
                search_results.truncate(limit);

                #[cfg(debug_assertions)]
                {
                    let search_elapsed = search_start.elapsed();
                    log::debug!(
                        "searched {} entries in {:?}",
                        search_results.len(),
                        search_elapsed
                    );
                    let embedding_query_elapsed = embedding_query_start.elapsed();
                    log::debug!("embedding query took {:?}", embedding_query_elapsed);
                }

                search_results
            })
        })
    }
 391
 392    #[cfg(test)]
 393    pub fn path_count(&self, cx: &AppContext) -> Result<u64> {
 394        let mut result = 0;
 395        for worktree_index in self.worktree_indices.values() {
 396            if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
 397                result += index.read(cx).path_count()?;
 398            }
 399        }
 400        Ok(result)
 401    }
 402
 403    pub(crate) fn worktree_index(
 404        &self,
 405        worktree_id: WorktreeId,
 406        cx: &AppContext,
 407    ) -> Option<Model<WorktreeIndex>> {
 408        for index in self.worktree_indices.values() {
 409            if let WorktreeIndexHandle::Loaded { index, .. } = index {
 410                if index.read(cx).worktree.read(cx).id() == worktree_id {
 411                    return Some(index.clone());
 412                }
 413            }
 414        }
 415        None
 416    }
 417
 418    pub(crate) fn worktree_indices(&self, cx: &AppContext) -> Vec<Model<WorktreeIndex>> {
 419        let mut result = self
 420            .worktree_indices
 421            .values()
 422            .filter_map(|index| {
 423                if let WorktreeIndexHandle::Loaded { index, .. } = index {
 424                    Some(index.clone())
 425                } else {
 426                    None
 427                }
 428            })
 429            .collect::<Vec<_>>();
 430        result.sort_by_key(|index| index.read(cx).worktree.read(cx).id());
 431        result
 432    }
 433}
 434
/// A search hit resolved against a live worktree handle.
pub struct SearchResult {
    pub worktree: Model<Worktree>,
    // Worktree-relative path of the file containing the hit.
    pub path: Arc<Path>,
    // Range of the matching chunk within the file.
    pub range: Range<usize>,
    // Similarity between the chunk's embedding and the query embedding.
    pub score: f32,
}

/// An intermediate search hit, identified by worktree id rather than a live
/// worktree handle.
pub struct WorktreeSearchResult {
    pub worktree_id: WorktreeId,
    pub path: Arc<Path>,
    pub range: Range<usize>,
    pub score: f32,
}

/// Aggregate indexing state of a `ProjectIndex`, emitted as an event when it
/// changes.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Status {
    /// Nothing is loading and no entries are being indexed.
    Idle,
    /// At least one worktree index is still loading.
    Loading,
    /// All indices are loaded; `remaining_count` entries are being indexed.
    Scanning { remaining_count: NonZeroUsize },
}

impl EventEmitter<Status> for ProjectIndex {}
 457
/// Maintains the embedding database for a single local worktree and keeps it
/// up to date as entries change on disk.
struct WorktreeIndex {
    worktree: Model<Worktree>,
    db_connection: heed::Env,
    // Maps a path-derived string key to the file's embedded chunks.
    db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
    language_registry: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    // Entries currently being indexed, used for progress reporting.
    entry_ids_being_indexed: Arc<IndexingEntrySet>,
    _index_entries: Task<Result<()>>,
    _subscription: Subscription,
}
 469
 470impl WorktreeIndex {
    /// Creates (or opens) this worktree's database — named after the
    /// worktree's absolute path — and then constructs the index model.
    pub fn load(
        worktree: Model<Worktree>,
        db_connection: heed::Env,
        language_registry: Arc<LanguageRegistry>,
        fs: Arc<dyn Fs>,
        status_tx: channel::Sender<()>,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut AppContext,
    ) -> Task<Result<Model<Self>>> {
        let worktree_abs_path = worktree.read(cx).abs_path();
        cx.spawn(|mut cx| async move {
            // Database creation needs a write transaction, so run it on the
            // background executor.
            let db = cx
                .background_executor()
                .spawn({
                    let db_connection = db_connection.clone();
                    async move {
                        let mut txn = db_connection.write_txn()?;
                        let db_name = worktree_abs_path.to_string_lossy();
                        let db = db_connection.create_database(&mut txn, Some(&db_name))?;
                        txn.commit()?;
                        anyhow::Ok(db)
                    }
                })
                .await?;
            cx.new_model(|cx| {
                Self::new(
                    worktree,
                    db_connection,
                    db,
                    status_tx,
                    language_registry,
                    fs,
                    embedding_provider,
                    cx,
                )
            })
        })
    }
 509
    #[allow(clippy::too_many_arguments)]
    /// Wires up the index: forwards worktree entry-update events into a
    /// channel and spawns the long-lived task that drains it.
    fn new(
        worktree: Model<Worktree>,
        db_connection: heed::Env,
        db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
        status: channel::Sender<()>,
        language_registry: Arc<LanguageRegistry>,
        fs: Arc<dyn Fs>,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let (updated_entries_tx, updated_entries_rx) = channel::unbounded();
        // `try_send` on an unbounded channel only fails if the receiver is
        // gone, in which case the update is intentionally dropped.
        let _subscription = cx.subscribe(&worktree, move |_this, _worktree, event, _cx| {
            if let worktree::Event::UpdatedEntries(update) = event {
                _ = updated_entries_tx.try_send(update.clone());
            }
        });

        Self {
            db_connection,
            db,
            worktree,
            language_registry,
            fs,
            embedding_provider,
            entry_ids_being_indexed: Arc::new(IndexingEntrySet::new(status)),
            _index_entries: cx.spawn(|this, cx| Self::index_entries(this, updated_entries_rx, cx)),
            _subscription,
        }
    }
 540
    /// Long-lived driver task: first indexes everything that changed on disk
    /// since the last run, then incrementally indexes each batch of updated
    /// entries as it arrives.
    ///
    /// Individual batch failures are logged and skipped; the task exits when
    /// the model is released (`update` fails) or the channel closes.
    async fn index_entries(
        this: WeakModel<Self>,
        updated_entries: channel::Receiver<UpdatedEntriesSet>,
        mut cx: AsyncAppContext,
    ) -> Result<()> {
        let index = this.update(&mut cx, |this, cx| this.index_entries_changed_on_disk(cx))?;
        index.await.log_err();

        while let Ok(updated_entries) = updated_entries.recv().await {
            let index = this.update(&mut cx, |this, cx| {
                this.index_updated_entries(updated_entries, cx)
            })?;
            index.await.log_err();
        }

        Ok(())
    }
 558
 559    fn index_entries_changed_on_disk(&self, cx: &AppContext) -> impl Future<Output = Result<()>> {
 560        let worktree = self.worktree.read(cx).as_local().unwrap().snapshot();
 561        let worktree_abs_path = worktree.abs_path().clone();
 562        let scan = self.scan_entries(worktree.clone(), cx);
 563        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
 564        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
 565        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
 566        async move {
 567            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
 568            Ok(())
 569        }
 570    }
 571
 572    fn index_updated_entries(
 573        &self,
 574        updated_entries: UpdatedEntriesSet,
 575        cx: &AppContext,
 576    ) -> impl Future<Output = Result<()>> {
 577        let worktree = self.worktree.read(cx).as_local().unwrap().snapshot();
 578        let worktree_abs_path = worktree.abs_path().clone();
 579        let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
 580        let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
 581        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
 582        let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
 583        async move {
 584            futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
 585            Ok(())
 586        }
 587    }
 588
 589    fn scan_entries(&self, worktree: LocalSnapshot, cx: &AppContext) -> ScanEntries {
 590        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
 591        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
 592        let db_connection = self.db_connection.clone();
 593        let db = self.db;
 594        let entries_being_indexed = self.entry_ids_being_indexed.clone();
 595        let task = cx.background_executor().spawn(async move {
 596            let txn = db_connection
 597                .read_txn()
 598                .context("failed to create read transaction")?;
 599            let mut db_entries = db
 600                .iter(&txn)
 601                .context("failed to create iterator")?
 602                .move_between_keys()
 603                .peekable();
 604
 605            let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None;
 606            for entry in worktree.files(false, 0) {
 607                let entry_db_key = db_key_for_path(&entry.path);
 608
 609                let mut saved_mtime = None;
 610                while let Some(db_entry) = db_entries.peek() {
 611                    match db_entry {
 612                        Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) {
 613                            Ordering::Less => {
 614                                if let Some(deletion_range) = deletion_range.as_mut() {
 615                                    deletion_range.1 = Bound::Included(db_path);
 616                                } else {
 617                                    deletion_range =
 618                                        Some((Bound::Included(db_path), Bound::Included(db_path)));
 619                                }
 620
 621                                db_entries.next();
 622                            }
 623                            Ordering::Equal => {
 624                                if let Some(deletion_range) = deletion_range.take() {
 625                                    deleted_entry_ranges_tx
 626                                        .send((
 627                                            deletion_range.0.map(ToString::to_string),
 628                                            deletion_range.1.map(ToString::to_string),
 629                                        ))
 630                                        .await?;
 631                                }
 632                                saved_mtime = db_embedded_file.mtime;
 633                                db_entries.next();
 634                                break;
 635                            }
 636                            Ordering::Greater => {
 637                                break;
 638                            }
 639                        },
 640                        Err(_) => return Err(db_entries.next().unwrap().unwrap_err())?,
 641                    }
 642                }
 643
 644                if entry.mtime != saved_mtime {
 645                    let handle = entries_being_indexed.insert(entry.id);
 646                    updated_entries_tx.send((entry.clone(), handle)).await?;
 647                }
 648            }
 649
 650            if let Some(db_entry) = db_entries.next() {
 651                let (db_path, _) = db_entry?;
 652                deleted_entry_ranges_tx
 653                    .send((Bound::Included(db_path.to_string()), Bound::Unbounded))
 654                    .await?;
 655            }
 656
 657            Ok(())
 658        });
 659
 660        ScanEntries {
 661            updated_entries: updated_entries_rx,
 662            deleted_entry_ranges: deleted_entry_ranges_rx,
 663            task,
 664        }
 665    }
 666
    /// Translates an explicit batch of entry changes into indexing work:
    /// added/updated files are sent for re-indexing, removed paths become
    /// single-key deletion ranges, and `Loaded` notifications are ignored.
    fn scan_updated_entries(
        &self,
        worktree: LocalSnapshot,
        updated_entries: UpdatedEntriesSet,
        cx: &AppContext,
    ) -> ScanEntries {
        let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
        let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
        let entries_being_indexed = self.entry_ids_being_indexed.clone();
        let task = cx.background_executor().spawn(async move {
            for (path, entry_id, status) in updated_entries.iter() {
                match status {
                    project::PathChange::Added
                    | project::PathChange::Updated
                    | project::PathChange::AddedOrUpdated => {
                        // Only files are indexed; directories are skipped.
                        if let Some(entry) = worktree.entry_for_id(*entry_id) {
                            if entry.is_file() {
                                let handle = entries_being_indexed.insert(entry.id);
                                updated_entries_tx.send((entry.clone(), handle)).await?;
                            }
                        }
                    }
                    project::PathChange::Removed => {
                        // A removed path maps to a deletion range covering
                        // exactly one db key.
                        let db_path = db_key_for_path(path);
                        deleted_entry_ranges_tx
                            .send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
                            .await?;
                    }
                    project::PathChange::Loaded => {
                        // Do nothing.
                    }
                }
            }

            Ok(())
        });

        ScanEntries {
            updated_entries: updated_entries_rx,
            deleted_entry_ranges: deleted_entry_ranges_rx,
            task,
        }
    }
 710
    /// Spawns one worker per CPU that reads each updated entry from disk,
    /// splits its text into chunks (using the file's language when one is
    /// detected), and forwards the chunked files for embedding.
    fn chunk_files(
        &self,
        worktree_abs_path: Arc<Path>,
        entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
        cx: &AppContext,
    ) -> ChunkFiles {
        let language_registry = self.language_registry.clone();
        let fs = self.fs.clone();
        let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
        let task = cx.spawn(|cx| async move {
            cx.background_executor()
                .scoped(|cx| {
                    for _ in 0..cx.num_cpus() {
                        cx.spawn(async {
                            while let Ok((entry, handle)) = entries.recv().await {
                                let entry_abs_path = worktree_abs_path.join(&entry.path);
                                // Unreadable files are logged and skipped
                                // rather than aborting the whole scan.
                                let Some(text) = fs
                                    .load(&entry_abs_path)
                                    .await
                                    .with_context(|| {
                                        format!("failed to read path {entry_abs_path:?}")
                                    })
                                    .log_err()
                                else {
                                    continue;
                                };
                                let language = language_registry
                                    .language_for_file_path(&entry.path)
                                    .await
                                    .ok();
                                let chunked_file = ChunkedFile {
                                    chunks: chunk_text(&text, language.as_ref(), &entry.path),
                                    handle,
                                    path: entry.path,
                                    mtime: entry.mtime,
                                    text,
                                };

                                // A closed receiver means the pipeline is
                                // shutting down; stop this worker.
                                if chunked_files_tx.send(chunked_file).await.is_err() {
                                    return;
                                }
                            }
                        });
                    }
                })
                .await;
            Ok(())
        });

        ChunkFiles {
            files: chunked_files_rx,
            task,
        }
    }
 765
 766    fn embed_files(
 767        embedding_provider: Arc<dyn EmbeddingProvider>,
 768        chunked_files: channel::Receiver<ChunkedFile>,
 769        cx: &AppContext,
 770    ) -> EmbedFiles {
 771        let embedding_provider = embedding_provider.clone();
 772        let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
 773        let task = cx.background_executor().spawn(async move {
 774            let mut chunked_file_batches =
 775                chunked_files.chunks_timeout(512, Duration::from_secs(2));
 776            while let Some(chunked_files) = chunked_file_batches.next().await {
 777                // View the batch of files as a vec of chunks
 778                // Flatten out to a vec of chunks that we can subdivide into batch sized pieces
 779                // Once those are done, reassemble them back into the files in which they belong
 780                // If any embeddings fail for a file, the entire file is discarded
 781
 782                let chunks: Vec<TextToEmbed> = chunked_files
 783                    .iter()
 784                    .flat_map(|file| {
 785                        file.chunks.iter().map(|chunk| TextToEmbed {
 786                            text: &file.text[chunk.range.clone()],
 787                            digest: chunk.digest,
 788                        })
 789                    })
 790                    .collect::<Vec<_>>();
 791
 792                let mut embeddings: Vec<Option<Embedding>> = Vec::new();
 793                for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
 794                    if let Some(batch_embeddings) =
 795                        embedding_provider.embed(embedding_batch).await.log_err()
 796                    {
 797                        if batch_embeddings.len() == embedding_batch.len() {
 798                            embeddings.extend(batch_embeddings.into_iter().map(Some));
 799                            continue;
 800                        }
 801                        log::error!(
 802                            "embedding provider returned unexpected embedding count {}, expected {}",
 803                            batch_embeddings.len(), embedding_batch.len()
 804                        );
 805                    }
 806
 807                    embeddings.extend(iter::repeat(None).take(embedding_batch.len()));
 808                }
 809
 810                let mut embeddings = embeddings.into_iter();
 811                for chunked_file in chunked_files {
 812                    let mut embedded_file = EmbeddedFile {
 813                        path: chunked_file.path,
 814                        mtime: chunked_file.mtime,
 815                        chunks: Vec::new(),
 816                    };
 817
 818                    let mut embedded_all_chunks = true;
 819                    for (chunk, embedding) in
 820                        chunked_file.chunks.into_iter().zip(embeddings.by_ref())
 821                    {
 822                        if let Some(embedding) = embedding {
 823                            embedded_file
 824                                .chunks
 825                                .push(EmbeddedChunk { chunk, embedding });
 826                        } else {
 827                            embedded_all_chunks = false;
 828                        }
 829                    }
 830
 831                    if embedded_all_chunks {
 832                        embedded_files_tx
 833                            .send((embedded_file, chunked_file.handle))
 834                            .await?;
 835                    }
 836                }
 837            }
 838            Ok(())
 839        });
 840
 841        EmbedFiles {
 842            files: embedded_files_rx,
 843            task,
 844        }
 845    }
 846
    /// Writes deletions and freshly-embedded files to the database.
    ///
    /// Runs in two sequential phases: first drains `deleted_entry_ranges`,
    /// removing each key range in its own write transaction; then batches
    /// `embedded_files` (up to 4096 files or 2 seconds, whichever comes
    /// first) and persists each batch in a single transaction.
    ///
    /// NOTE(review): the second phase only starts once the deletion channel
    /// is closed — this assumes the producer closes it independently of the
    /// embedding pipeline's progress; verify against the scan/embed tasks.
    fn persist_embeddings(
        &self,
        mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
        embedded_files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
        cx: &AppContext,
    ) -> Task<Result<()>> {
        let db_connection = self.db_connection.clone();
        let db = self.db;
        cx.background_executor().spawn(async move {
            // Phase 1: apply all pending deletions, one write txn per range.
            while let Some(deletion_range) = deleted_entry_ranges.next().await {
                let mut txn = db_connection.write_txn()?;
                let start = deletion_range.0.as_ref().map(|start| start.as_str());
                let end = deletion_range.1.as_ref().map(|end| end.as_str());
                log::debug!("deleting embeddings in range {:?}", &(start, end));
                db.delete_range(&mut txn, &(start, end))?;
                txn.commit()?;
            }

            // Phase 2: persist embedded files in large batches to amortize
            // transaction overhead. Note this shadows the receiver above.
            let mut embedded_files = embedded_files.chunks_timeout(4096, Duration::from_secs(2));
            while let Some(embedded_files) = embedded_files.next().await {
                let mut txn = db_connection.write_txn()?;
                for (file, _) in &embedded_files {
                    log::debug!("saving embedding for file {:?}", file.path);
                    let key = db_key_for_path(&file.path);
                    db.put(&mut txn, &key, file)?;
                }
                txn.commit()?;

                // Dropping the batch only after the commit releases its
                // `IndexingEntryHandle`s, so entries stay marked as
                // in-progress until their embeddings are durably stored.
                drop(embedded_files);
                log::debug!("committed");
            }

            Ok(())
        })
    }
 882
 883    fn paths(&self, cx: &AppContext) -> Task<Result<Vec<Arc<Path>>>> {
 884        let connection = self.db_connection.clone();
 885        let db = self.db;
 886        cx.background_executor().spawn(async move {
 887            let tx = connection
 888                .read_txn()
 889                .context("failed to create read transaction")?;
 890            let result = db
 891                .iter(&tx)?
 892                .map(|entry| Ok(entry?.1.path.clone()))
 893                .collect::<Result<Vec<Arc<Path>>>>();
 894            drop(tx);
 895            result
 896        })
 897    }
 898
 899    fn chunks_for_path(
 900        &self,
 901        path: Arc<Path>,
 902        cx: &AppContext,
 903    ) -> Task<Result<Vec<EmbeddedChunk>>> {
 904        let connection = self.db_connection.clone();
 905        let db = self.db;
 906        cx.background_executor().spawn(async move {
 907            let tx = connection
 908                .read_txn()
 909                .context("failed to create read transaction")?;
 910            Ok(db
 911                .get(&tx, &db_key_for_path(&path))?
 912                .ok_or_else(|| anyhow!("no such path"))?
 913                .chunks
 914                .clone())
 915        })
 916    }
 917
 918    #[cfg(test)]
 919    fn path_count(&self) -> Result<u64> {
 920        let txn = self
 921            .db_connection
 922            .read_txn()
 923            .context("failed to create read transaction")?;
 924        Ok(self.db.len(&txn)?)
 925    }
 926}
 927
/// Output of the worktree scan phase: streams of entries to (re)index and
/// database key ranges to delete, plus the task driving the scan.
struct ScanEntries {
    /// Entries whose contents need (re)chunking, each paired with a handle
    /// marking it as in-progress until indexing completes.
    updated_entries: channel::Receiver<(Entry, IndexingEntryHandle)>,
    /// Key ranges (see `db_key_for_path`) whose embeddings should be removed.
    deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
    /// Background task feeding the two channels above.
    task: Task<Result<()>>,
}
 933
/// Output of the chunking phase: a stream of files split into chunks, plus
/// the task producing them.
struct ChunkFiles {
    files: channel::Receiver<ChunkedFile>,
    task: Task<Result<()>>,
}
 938
/// A file whose text has been loaded and split into chunks, ready for
/// embedding.
struct ChunkedFile {
    /// Worktree-relative path of the file.
    pub path: Arc<Path>,
    /// Modification time captured when the file was scanned, if available.
    pub mtime: Option<SystemTime>,
    /// Keeps the file's project entry marked as in-progress until the
    /// pipeline finishes with it (see `IndexingEntryHandle`).
    pub handle: IndexingEntryHandle,
    /// The file's full text; chunk ranges index into this string.
    pub text: String,
    /// Byte ranges of `text` to embed, with per-chunk content digests.
    pub chunks: Vec<Chunk>,
}
 946
/// Output of the embedding phase: a stream of fully-embedded files (with
/// their in-progress handles), plus the task producing them.
struct EmbedFiles {
    files: channel::Receiver<(EmbeddedFile, IndexingEntryHandle)>,
    task: Task<Result<()>>,
}
 951
/// A file together with the embeddings of all its chunks. Serialized (via
/// serde) as the value stored in the embeddings database.
#[derive(Debug, Serialize, Deserialize)]
struct EmbeddedFile {
    /// Worktree-relative path; its string form is the database key
    /// (see `db_key_for_path`).
    path: Arc<Path>,
    /// Modification time captured when the file was indexed, if available.
    mtime: Option<SystemTime>,
    /// One entry per successfully-embedded chunk of the file.
    chunks: Vec<EmbeddedChunk>,
}
 958
/// A single chunk of a file paired with its embedding vector.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct EmbeddedChunk {
    chunk: Chunk,
    embedding: Embedding,
}
 964
/// The set of entries that are currently being indexed.
struct IndexingEntrySet {
    /// IDs of project entries with in-flight indexing work.
    entry_ids: Mutex<HashSet<ProjectEntryId>>,
    /// Receives a unit message whenever the set changes, so observers can
    /// re-read `len` for progress reporting.
    tx: channel::Sender<()>,
}
 970
/// When dropped, removes the entry from the set of entries that are being indexed.
///
/// NOTE(review): the handle is `Clone`, and dropping *any* clone removes the
/// entry even if other clones remain — confirm that is intended.
#[derive(Clone)]
struct IndexingEntryHandle {
    /// The project entry this handle keeps marked as in-progress.
    entry_id: ProjectEntryId,
    /// Weak back-reference to the owning set, so a lingering handle cannot
    /// keep the set alive on its own.
    set: Weak<IndexingEntrySet>,
}
 977
 978impl IndexingEntrySet {
 979    fn new(tx: channel::Sender<()>) -> Self {
 980        Self {
 981            entry_ids: Default::default(),
 982            tx,
 983        }
 984    }
 985
 986    fn insert(self: &Arc<Self>, entry_id: ProjectEntryId) -> IndexingEntryHandle {
 987        self.entry_ids.lock().insert(entry_id);
 988        self.tx.send_blocking(()).ok();
 989        IndexingEntryHandle {
 990            entry_id,
 991            set: Arc::downgrade(self),
 992        }
 993    }
 994
 995    pub fn len(&self) -> usize {
 996        self.entry_ids.lock().len()
 997    }
 998}
 999
1000impl Drop for IndexingEntryHandle {
1001    fn drop(&mut self) {
1002        if let Some(set) = self.set.upgrade() {
1003            set.tx.send_blocking(()).ok();
1004            set.entry_ids.lock().remove(&self.entry_id);
1005        }
1006    }
1007}
1008
/// Converts a worktree-relative path into a database key.
///
/// Separators are replaced with NUL bytes; since `'\0'` sorts before every
/// other character, this presumably keeps all keys under a directory in one
/// contiguous lexicographic range so whole subtrees can be removed with
/// `delete_range` — TODO confirm against the key-range producers.
///
/// Takes `&Path` rather than `&Arc<Path>`: call sites holding an
/// `Arc<Path>` deref-coerce transparently, and the function is usable with
/// any path.
fn db_key_for_path(path: &Path) -> String {
    path.to_string_lossy().replace('/', "\0")
}
1012
1013#[cfg(test)]
1014mod tests {
1015    use super::*;
1016    use futures::{future::BoxFuture, FutureExt};
1017    use gpui::TestAppContext;
1018    use language::language_settings::AllLanguageSettings;
1019    use project::Project;
1020    use settings::SettingsStore;
1021    use std::{future, path::Path, sync::Arc};
1022
    /// Installs the global settings, language, and project state required
    /// before the indexing code can run under a `TestAppContext`.
    fn init_test(cx: &mut TestAppContext) {
        _ = cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            language::init(cx);
            Project::init_settings(cx);
            SettingsStore::update(cx, |store, cx| {
                // NOTE(review): no-op settings update; presumably forces the
                // default language settings to materialize — confirm.
                store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
            });
        });
    }
1034
    /// Deterministic embedding provider for tests: delegates each chunk to a
    /// caller-supplied closure and reports a fixed batch size.
    pub struct TestEmbeddingProvider {
        /// Maximum number of chunks accepted per `embed` call.
        batch_size: usize,
        /// Computes the embedding for one chunk's text; may fail to let
        /// tests exercise error paths.
        compute_embedding: Box<dyn Fn(&str) -> Result<Embedding> + Send + Sync>,
    }
1039
1040    impl TestEmbeddingProvider {
1041        pub fn new(
1042            batch_size: usize,
1043            compute_embedding: impl 'static + Fn(&str) -> Result<Embedding> + Send + Sync,
1044        ) -> Self {
1045            return Self {
1046                batch_size,
1047                compute_embedding: Box::new(compute_embedding),
1048            };
1049        }
1050    }
1051
1052    impl EmbeddingProvider for TestEmbeddingProvider {
1053        fn embed<'a>(
1054            &'a self,
1055            texts: &'a [TextToEmbed<'a>],
1056        ) -> BoxFuture<'a, Result<Vec<Embedding>>> {
1057            let embeddings = texts
1058                .iter()
1059                .map(|to_embed| (self.compute_embedding)(to_embed.text))
1060                .collect();
1061            future::ready(embeddings).boxed()
1062        }
1063
1064        fn batch_size(&self) -> usize {
1065            self.batch_size
1066        }
1067    }
1068
    /// End-to-end test: indexes the `./fixture` project, then verifies that
    /// a semantic search for "garbage in, garbage out" surfaces `needle.md`
    /// and that the reported range really contains the phrase.
    #[gpui::test]
    async fn test_search(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        init_test(cx);

        let temp_dir = tempfile::tempdir().unwrap();

        // Fake provider: embeds into 2 dimensions, one per half of the
        // phrase, so the full query is close only to text containing both.
        let mut semantic_index = SemanticIndex::new(
            temp_dir.path().into(),
            Arc::new(TestEmbeddingProvider::new(16, |text| {
                let mut embedding = vec![0f32; 2];
                // Push dimension 0 positive when the text mentions
                // "garbage in", negative otherwise.
                if text.contains("garbage in") {
                    embedding[0] = 0.9;
                } else {
                    embedding[0] = -0.9;
                }

                // Dimension 1 likewise tracks "garbage out".
                if text.contains("garbage out") {
                    embedding[1] = 0.9;
                } else {
                    embedding[1] = -0.9;
                }

                Ok(Embedding::new(embedding))
            })),
            &mut cx.to_async(),
        )
        .await
        .unwrap();

        let project_path = Path::new("./fixture");

        let project = cx
            .spawn(|mut cx| async move { Project::example([project_path], &mut cx).await })
            .await;

        cx.update(|cx| {
            let language_registry = project.read(cx).languages().clone();
            let node_runtime = project.read(cx).node_runtime().unwrap().clone();
            languages::init(language_registry, node_runtime, cx);
        });

        let project_index = cx.update(|cx| semantic_index.project_index(project.clone(), cx));

        // Pump project-index events until indexing has produced at least
        // one entry.
        while project_index
            .read_with(cx, |index, cx| index.path_count(cx))
            .unwrap()
            == 0
        {
            project_index.next_event(cx).await;
        }

        let results = cx
            .update(|cx| {
                let project_index = project_index.read(cx);
                let query = "garbage in, garbage out";
                project_index.search(query.into(), 4, cx)
            })
            .await
            .unwrap();

        assert!(results.len() > 1, "should have found some results");

        for result in &results {
            println!("result: {:?}", result.path);
            println!("score: {:?}", result.score);
        }

        // Find a result scoring above 0.9 — only a chunk matching both
        // halves of the phrase can score that high with this provider.
        let search_result = results.iter().find(|result| result.score > 0.9).unwrap();

        assert_eq!(search_result.path.to_string_lossy(), "needle.md");

        // Load the matched file from disk and confirm the reported range
        // actually contains the query phrase.
        let content = cx
            .update(|cx| {
                let worktree = search_result.worktree.read(cx);
                let entry_abs_path = worktree.abs_path().join(&search_result.path);
                let fs = project.read(cx).fs().clone();
                cx.background_executor()
                    .spawn(async move { fs.load(&entry_abs_path).await.unwrap() })
            })
            .await;

        let range = search_result.range.clone();
        let content = content[range.clone()].to_owned();

        assert!(content.contains("garbage in, garbage out"));
    }
1159
    /// Verifies that `embed_files` discards a file entirely when any of its
    /// chunks fails to embed, and keeps files whose chunks all succeed.
    #[gpui::test]
    async fn test_embed_files(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        // Provider that fails on any text containing 'g'; otherwise returns
        // a per-letter frequency histogram over 'a'..'z' (exclusive).
        let provider = Arc::new(TestEmbeddingProvider::new(3, |text| {
            if text.contains('g') {
                Err(anyhow!("cannot embed text containing a 'g' character"))
            } else {
                Ok(Embedding::new(
                    ('a'..'z')
                        .map(|char| text.chars().filter(|c| *c == char).count() as f32)
                        .collect(),
                ))
            }
        }));

        let (indexing_progress_tx, _) = channel::unbounded();
        let indexing_entries = Arc::new(IndexingEntrySet::new(indexing_progress_tx));

        let (chunked_files_tx, chunked_files_rx) = channel::unbounded::<ChunkedFile>();
        // test1.md: its second chunk ("efgh") contains a 'g', so the whole
        // file should be dropped by the embedding pipeline.
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test1.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(0)),
                text: "abcdefghijklmnop".to_string(),
                chunks: [0..4, 4..8, 8..12, 12..16]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        // test2.md: contains no 'g', so all three chunks should embed.
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test2.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(1)),
                text: "qrstuvwxyz".to_string(),
                chunks: [0..4, 4..8, 8..10]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        // Close the input channel so the embedding task terminates.
        chunked_files_tx.close();

        let embed_files_task =
            cx.update(|cx| WorktreeIndex::embed_files(provider.clone(), chunked_files_rx, cx));
        embed_files_task.task.await.unwrap();

        // Drain everything the pipeline produced.
        let mut embedded_files_rx = embed_files_task.files;
        let mut embedded_files = Vec::new();
        while let Some((embedded_file, _)) = embedded_files_rx.next().await {
            embedded_files.push(embedded_file);
        }

        // Only test2.md survives, with one embedding per chunk, each equal
        // to what the provider computes for that chunk's text.
        assert_eq!(embedded_files.len(), 1);
        assert_eq!(embedded_files[0].path.as_ref(), Path::new("test2.md"));
        assert_eq!(
            embedded_files[0]
                .chunks
                .iter()
                .map(|embedded_chunk| { embedded_chunk.embedding.clone() })
                .collect::<Vec<Embedding>>(),
            vec![
                (provider.compute_embedding)("qrst").unwrap(),
                (provider.compute_embedding)("uvwx").unwrap(),
                (provider.compute_embedding)("yz").unwrap(),
            ],
        );
    }
1237}