semantic_index.rs

   1mod db;
   2mod embedding_queue;
   3mod parsing;
   4pub mod semantic_index_settings;
   5
   6#[cfg(test)]
   7mod semantic_index_tests;
   8
   9use crate::semantic_index_settings::SemanticIndexSettings;
  10use ai::embedding::{Embedding, EmbeddingProvider};
  11use ai::providers::open_ai::OpenAIEmbeddingProvider;
  12use anyhow::{anyhow, Result};
  13use collections::{BTreeMap, HashMap, HashSet};
  14use db::VectorDatabase;
  15use embedding_queue::{EmbeddingQueue, FileToEmbed};
  16use futures::{future, FutureExt, StreamExt};
  17use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
  18use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
  19use lazy_static::lazy_static;
  20use ordered_float::OrderedFloat;
  21use parking_lot::Mutex;
  22use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
  23use postage::watch;
  24use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
  25use smol::channel;
  26use std::{
  27    cmp::Reverse,
  28    env,
  29    future::Future,
  30    mem,
  31    ops::Range,
  32    path::{Path, PathBuf},
  33    sync::{Arc, Weak},
  34    time::{Duration, Instant, SystemTime},
  35};
  36use util::{channel::RELEASE_CHANNEL_NAME, http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt};
  37use workspace::WorkspaceCreated;
  38
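// SEMANTIC_INDEX_VERSION is the version number of the semantic index's data format.
// BACKGROUND_INDEXING_DELAY is how long we wait after a registered worktree reports changes
// before re-indexing the project, and EMBEDDING_QUEUE_FLUSH_TIMEOUT is how long a parsing
// task sits idle before flushing a partially filled embedding batch.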
  39const SEMANTIC_INDEX_VERSION: usize = 11;
  40const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
  41const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);
  42
  43lazy_static! {
  44    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
  45}
  46
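/// Wires the semantic index into the application: registers its settings, re-indexes
/// previously indexed local projects whenever a workspace is created, and installs a global
/// `SemanticIndex` backed by the OpenAI embedding provider and an on-disk vector database
/// under `EMBEDDINGS_DIR`.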
  47pub fn init(
  48    fs: Arc<dyn Fs>,
  49    http_client: Arc<dyn HttpClient>,
  50    language_registry: Arc<LanguageRegistry>,
  51    cx: &mut AppContext,
  52) {
  53    settings::register::<SemanticIndexSettings>(cx);
  54
  55    let db_file_path = EMBEDDINGS_DIR
  56        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
  57        .join("embeddings_db");
  58
  59    cx.subscribe_global::<WorkspaceCreated, _>({
  60        move |event, cx| {
  61            let Some(semantic_index) = SemanticIndex::global(cx) else {
  62                return;
  63            };
  64            let workspace = &event.0;
  65            if let Some(workspace) = workspace.upgrade(cx) {
  66                let project = workspace.read(cx).project().clone();
  67                if project.read(cx).is_local() {
  68                    cx.spawn(|mut cx| async move {
  69                        let previously_indexed = semantic_index
  70                            .update(&mut cx, |index, cx| {
  71                                index.project_previously_indexed(&project, cx)
  72                            })
  73                            .await?;
  74                        if previously_indexed {
  75                            semantic_index
  76                                .update(&mut cx, |index, cx| index.index_project(project, cx))
  77                                .await?;
  78                        }
  79                        anyhow::Ok(())
  80                    })
  81                    .detach_and_log_err(cx);
  82                }
  83            }
  84        }
  85    })
  86    .detach();
  87
  88    cx.spawn(move |mut cx| async move {
  89        let semantic_index = SemanticIndex::new(
  90            fs,
  91            db_file_path,
  92            Arc::new(OpenAIEmbeddingProvider::new(http_client, cx.background())),
  93            language_registry,
  94            cx.clone(),
  95        )
  96        .await?;
  97
  98        cx.update(|cx| {
  99            cx.set_global(semantic_index.clone());
 100        });
 101
 102        anyhow::Ok(())
 103    })
 104    .detach();
 105}
 106
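/// Indexing state reported for a single project: not yet authenticated with the embedding
/// provider, never indexed, fully indexed, or still indexing (with the number of remaining
/// files and an optional rate-limit expiry reported by the embedding provider).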
 107#[derive(Copy, Clone, Debug)]
 108pub enum SemanticIndexStatus {
 109    NotAuthenticated,
 110    NotIndexed,
 111    Indexed,
 112    Indexing {
 113        remaining_files: usize,
 114        rate_limit_expiry: Option<Instant>,
 115    },
 116}
 117
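/// Coordinates semantic indexing and search. Files are parsed into spans by background
/// tasks, embedded in batches through the `EmbeddingProvider`, persisted in the
/// `VectorDatabase`, and tracked per project so that searches can combine on-disk results
/// with results from dirty open buffers.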
 118pub struct SemanticIndex {
 119    fs: Arc<dyn Fs>,
 120    db: VectorDatabase,
 121    embedding_provider: Arc<dyn EmbeddingProvider>,
 122    language_registry: Arc<LanguageRegistry>,
 123    parsing_files_tx: channel::Sender<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>,
 124    _embedding_task: Task<()>,
 125    _parsing_files_tasks: Vec<Task<()>>,
 126    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
 127}
 128
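/// Per-project bookkeeping: the registration state of each worktree, a watch channel that
/// counts files still waiting to be embedded, and the number of in-flight `index_project`
/// calls.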
 129struct ProjectState {
 130    worktrees: HashMap<WorktreeId, WorktreeState>,
 131    pending_file_count_rx: watch::Receiver<usize>,
 132    pending_file_count_tx: Arc<Mutex<watch::Sender<usize>>>,
 133    pending_index: usize,
 134    _subscription: gpui::Subscription,
 135    _observe_pending_file_count: Task<()>,
 136}
 137
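/// A worktree is `Registering` while its initial scan is diffed against the database and
/// `Registered` once it has a database id; path changes that arrive in the meantime are
/// buffered on either variant and folded in when registration completes.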
 138enum WorktreeState {
 139    Registering(RegisteringWorktreeState),
 140    Registered(RegisteredWorktreeState),
 141}
 142
 143impl WorktreeState {
 144    fn is_registered(&self) -> bool {
 145        matches!(self, Self::Registered(_))
 146    }
 147
 148    fn paths_changed(
 149        &mut self,
 150        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
 151        worktree: &Worktree,
 152    ) {
 153        let changed_paths = match self {
 154            Self::Registering(state) => &mut state.changed_paths,
 155            Self::Registered(state) => &mut state.changed_paths,
 156        };
 157
 158        for (path, entry_id, change) in changes.iter() {
 159            let Some(entry) = worktree.entry_for_id(*entry_id) else {
 160                continue;
 161            };
 162            if entry.is_ignored || entry.is_symlink || entry.is_external || entry.is_dir() {
 163                continue;
 164            }
 165            changed_paths.insert(
 166                path.clone(),
 167                ChangedPathInfo {
 168                    mtime: entry.mtime,
 169                    is_deleted: *change == PathChange::Removed,
 170                },
 171            );
 172        }
 173    }
 174}
 175
 176struct RegisteringWorktreeState {
 177    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
 178    done_rx: watch::Receiver<Option<()>>,
 179    _registration: Task<()>,
 180}
 181
 182impl RegisteringWorktreeState {
 183    fn done(&self) -> impl Future<Output = ()> {
 184        let mut done_rx = self.done_rx.clone();
 185        async move {
 186            while let Some(result) = done_rx.next().await {
 187                if result.is_some() {
 188                    break;
 189                }
 190            }
 191        }
 192    }
 193}
 194
 195struct RegisteredWorktreeState {
 196    db_id: i64,
 197    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
 198}
 199
 200struct ChangedPathInfo {
 201    mtime: SystemTime,
 202    is_deleted: bool,
 203}
 204
 205#[derive(Clone)]
 206pub struct JobHandle {
 207    /// The outer Arc is here to count the clones of a JobHandle instance;
 208    /// when the last clone of a given job's handle is dropped, the pending-file counter is decremented exactly once.
 209    tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
 210}
 211
 212impl JobHandle {
 213    fn new(tx: &Arc<Mutex<watch::Sender<usize>>>) -> Self {
 214        *tx.lock().borrow_mut() += 1;
 215        Self {
 216            tx: Arc::new(Arc::downgrade(&tx)),
 217        }
 218    }
 219}
 220
 221impl ProjectState {
 222    fn new(subscription: gpui::Subscription, cx: &mut ModelContext<SemanticIndex>) -> Self {
 223        let (pending_file_count_tx, pending_file_count_rx) = watch::channel_with(0);
 224        let pending_file_count_tx = Arc::new(Mutex::new(pending_file_count_tx));
 225        Self {
 226            worktrees: Default::default(),
 227            pending_file_count_rx: pending_file_count_rx.clone(),
 228            pending_file_count_tx,
 229            pending_index: 0,
 230            _subscription: subscription,
 231            _observe_pending_file_count: cx.spawn_weak({
 232                let mut pending_file_count_rx = pending_file_count_rx.clone();
 233                |this, mut cx| async move {
 234                    while let Some(_) = pending_file_count_rx.next().await {
 235                        if let Some(this) = this.upgrade(&cx) {
 236                            this.update(&mut cx, |_, cx| cx.notify());
 237                        }
 238                    }
 239                }
 240            }),
 241        }
 242    }
 243
 244    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
 245        self.worktrees
 246            .iter()
 247            .find_map(|(worktree_id, worktree_state)| match worktree_state {
 248                WorktreeState::Registered(state) if state.db_id == id => Some(*worktree_id),
 249                _ => None,
 250            })
 251    }
 252}
 253
 254#[derive(Clone)]
 255pub struct PendingFile {
 256    worktree_db_id: i64,
 257    relative_path: Arc<Path>,
 258    absolute_path: PathBuf,
 259    language: Option<Arc<Language>>,
 260    modified_time: SystemTime,
 261    job_handle: JobHandle,
 262}
 263
 264#[derive(Clone)]
 265pub struct SearchResult {
 266    pub buffer: ModelHandle<Buffer>,
 267    pub range: Range<Anchor>,
 268    pub similarity: OrderedFloat<f32>,
 269}
 270
 271impl SemanticIndex {
 272    pub fn global(cx: &mut AppContext) -> Option<ModelHandle<SemanticIndex>> {
 273        if cx.has_global::<ModelHandle<Self>>() {
 274            Some(cx.global::<ModelHandle<SemanticIndex>>().clone())
 275        } else {
 276            None
 277        }
 278    }
 279
 280    pub fn authenticate(&mut self, cx: &AppContext) -> bool {
 281        if !self.embedding_provider.has_credentials() {
 282            self.embedding_provider.retrieve_credentials(cx);
 283        } else {
 284            return true;
 285        }
 286
 287        self.embedding_provider.has_credentials()
 288    }
 289
 290    pub fn is_authenticated(&self) -> bool {
 291        self.embedding_provider.has_credentials()
 292    }
 293
 294    pub fn enabled(cx: &AppContext) -> bool {
 295        settings::get::<SemanticIndexSettings>(cx).enabled
 296    }
 297
 298    pub fn status(&self, project: &ModelHandle<Project>) -> SemanticIndexStatus {
 299        if !self.is_authenticated() {
 300            return SemanticIndexStatus::NotAuthenticated;
 301        }
 302
 303        if let Some(project_state) = self.projects.get(&project.downgrade()) {
 304            if project_state
 305                .worktrees
 306                .values()
 307                .all(|worktree| worktree.is_registered())
 308                && project_state.pending_index == 0
 309            {
 310                SemanticIndexStatus::Indexed
 311            } else {
 312                SemanticIndexStatus::Indexing {
 313                    remaining_files: project_state.pending_file_count_rx.borrow().clone(),
 314                    rate_limit_expiry: self.embedding_provider.rate_limit_expiration(),
 315                }
 316            }
 317        } else {
 318            SemanticIndexStatus::NotIndexed
 319        }
 320    }
 321
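    /// Opens (or creates) the vector database, then constructs the model together with its
    /// background work: one task that writes embedded files coming off the embedding queue
    /// into the database, and one parsing task per CPU that turns pending files into spans,
    /// feeding the queue and flushing it whenever no new file arrives within
    /// `EMBEDDING_QUEUE_FLUSH_TIMEOUT`.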
 322    pub async fn new(
 323        fs: Arc<dyn Fs>,
 324        database_path: PathBuf,
 325        embedding_provider: Arc<dyn EmbeddingProvider>,
 326        language_registry: Arc<LanguageRegistry>,
 327        mut cx: AsyncAppContext,
 328    ) -> Result<ModelHandle<Self>> {
 329        let t0 = Instant::now();
 330        let database_path = Arc::from(database_path);
 331        let db = VectorDatabase::new(fs.clone(), database_path, cx.background()).await?;
 332
 333        log::trace!(
 334            "db initialization took {:?} milliseconds",
 335            t0.elapsed().as_millis()
 336        );
 337
 338        Ok(cx.add_model(|cx| {
 339            let t0 = Instant::now();
 340            let embedding_queue =
 341                EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
 342            let _embedding_task = cx.background().spawn({
 343                let embedded_files = embedding_queue.finished_files();
 344                let db = db.clone();
 345                async move {
 346                    while let Ok(file) = embedded_files.recv().await {
 347                        db.insert_file(file.worktree_id, file.path, file.mtime, file.spans)
 348                            .await
 349                            .log_err();
 350                    }
 351                }
 352            });
 353
 354            // Parse files into embeddable spans.
 355            let (parsing_files_tx, parsing_files_rx) =
 356                channel::unbounded::<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>();
 357            let embedding_queue = Arc::new(Mutex::new(embedding_queue));
 358            let mut _parsing_files_tasks = Vec::new();
 359            for _ in 0..cx.background().num_cpus() {
 360                let fs = fs.clone();
 361                let mut parsing_files_rx = parsing_files_rx.clone();
 362                let embedding_provider = embedding_provider.clone();
 363                let embedding_queue = embedding_queue.clone();
 364                let background = cx.background().clone();
 365                _parsing_files_tasks.push(cx.background().spawn(async move {
 366                    let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
 367                    loop {
 368                        let mut timer = background.timer(EMBEDDING_QUEUE_FLUSH_TIMEOUT).fuse();
 369                        let mut next_file_to_parse = parsing_files_rx.next().fuse();
 370                        futures::select_biased! {
 371                            next_file_to_parse = next_file_to_parse => {
 372                                if let Some((embeddings_for_digest, pending_file)) = next_file_to_parse {
 373                                    Self::parse_file(
 374                                        &fs,
 375                                        pending_file,
 376                                        &mut retriever,
 377                                        &embedding_queue,
 378                                        &embeddings_for_digest,
 379                                    )
 380                                    .await
 381                                } else {
 382                                    break;
 383                                }
 384                            },
 385                            _ = timer => {
 386                                embedding_queue.lock().flush();
 387                            }
 388                        }
 389                    }
 390                }));
 391            }
 392
 393            log::trace!(
 394                "semantic index task initialization took {:?} milliseconds",
 395                t0.elapsed().as_millis()
 396            );
 397            Self {
 398                fs,
 399                db,
 400                embedding_provider,
 401                language_registry,
 402                parsing_files_tx,
 403                _embedding_task,
 404                _parsing_files_tasks,
 405                projects: Default::default(),
 406            }
 407        }))
 408    }
 409
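    /// Loads a single pending file from disk, parses it into spans with the
    /// `CodeContextRetriever`, reuses any embeddings already cached for matching span
    /// digests, and pushes the result onto the embedding queue.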
 410    async fn parse_file(
 411        fs: &Arc<dyn Fs>,
 412        pending_file: PendingFile,
 413        retriever: &mut CodeContextRetriever,
 414        embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
 415        embeddings_for_digest: &HashMap<SpanDigest, Embedding>,
 416    ) {
 417        let Some(language) = pending_file.language else {
 418            return;
 419        };
 420
 421        if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
 422            if let Some(mut spans) = retriever
 423                .parse_file_with_template(Some(&pending_file.relative_path), &content, language)
 424                .log_err()
 425            {
 426                log::trace!(
 427                    "parsed path {:?}: {} spans",
 428                    pending_file.relative_path,
 429                    spans.len()
 430                );
 431
 432                for span in &mut spans {
 433                    if let Some(embedding) = embeddings_for_digest.get(&span.digest) {
 434                        span.embedding = Some(embedding.to_owned());
 435                    }
 436                }
 437
 438                embedding_queue.lock().push(FileToEmbed {
 439                    worktree_id: pending_file.worktree_db_id,
 440                    path: pending_file.relative_path,
 441                    mtime: pending_file.modified_time,
 442                    job_handle: pending_file.job_handle,
 443                    spans,
 444                });
 445            }
 446        }
 447    }
 448
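    /// Resolves to true if every worktree of the project has been indexed before, as
    /// recorded in the database by the worktree's absolute path.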
 449    pub fn project_previously_indexed(
 450        &mut self,
 451        project: &ModelHandle<Project>,
 452        cx: &mut ModelContext<Self>,
 453    ) -> Task<Result<bool>> {
 454        let worktrees_indexed_previously = project
 455            .read(cx)
 456            .worktrees(cx)
 457            .map(|worktree| {
 458                self.db
 459                    .worktree_previously_indexed(&worktree.read(cx).abs_path())
 460            })
 461            .collect::<Vec<_>>();
 462        cx.spawn(|_, _cx| async move {
 463            let worktree_indexed_previously =
 464                futures::future::join_all(worktrees_indexed_previously).await;
 465
 466            Ok(worktree_indexed_previously
 467                .iter()
 468                .filter(|worktree| worktree.is_ok())
 469                .all(|v| v.as_ref().log_err().is_some_and(|v| v.to_owned())))
 470        })
 471    }
 472
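    /// Handles `WorktreeUpdatedEntries` events: records the changed paths on the worktree's
    /// state and, if the worktree is already registered, schedules a re-index of the project
    /// after `BACKGROUND_INDEXING_DELAY`.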
 473    fn project_entries_changed(
 474        &mut self,
 475        project: ModelHandle<Project>,
 476        worktree_id: WorktreeId,
 477        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
 478        cx: &mut ModelContext<Self>,
 479    ) {
 480        let Some(worktree) = project.read(cx).worktree_for_id(worktree_id.clone(), cx) else {
 481            return;
 482        };
 483        let project = project.downgrade();
 484        let Some(project_state) = self.projects.get_mut(&project) else {
 485            return;
 486        };
 487
 488        let worktree = worktree.read(cx);
 489        let worktree_state =
 490            if let Some(worktree_state) = project_state.worktrees.get_mut(&worktree_id) {
 491                worktree_state
 492            } else {
 493                return;
 494            };
 495        worktree_state.paths_changed(changes, worktree);
 496        if let WorktreeState::Registered(_) = worktree_state {
 497            cx.spawn_weak(|this, mut cx| async move {
 498                cx.background().timer(BACKGROUND_INDEXING_DELAY).await;
 499                if let Some((this, project)) = this.upgrade(&cx).zip(project.upgrade(&cx)) {
 500                    this.update(&mut cx, |this, cx| {
 501                        this.index_project(project, cx).detach_and_log_err(cx)
 502                    });
 503                }
 504            })
 505            .detach();
 506        }
 507    }
 508
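    /// Starts tracking a local worktree: waits for its scan to complete, finds or creates
    /// its database row, diffs the worktree's files against the stored mtimes to build the
    /// initial set of changed (or deleted) paths, then flips the worktree from `Registering`
    /// to `Registered` and kicks off an index pass.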
 509    fn register_worktree(
 510        &mut self,
 511        project: ModelHandle<Project>,
 512        worktree: ModelHandle<Worktree>,
 513        cx: &mut ModelContext<Self>,
 514    ) {
 515        let project = project.downgrade();
 516        let project_state = if let Some(project_state) = self.projects.get_mut(&project) {
 517            project_state
 518        } else {
 519            return;
 520        };
 521        let worktree = if let Some(worktree) = worktree.read(cx).as_local() {
 522            worktree
 523        } else {
 524            return;
 525        };
 526        let worktree_abs_path = worktree.abs_path().clone();
 527        let scan_complete = worktree.scan_complete();
 528        let worktree_id = worktree.id();
 529        let db = self.db.clone();
 530        let language_registry = self.language_registry.clone();
 531        let (mut done_tx, done_rx) = watch::channel();
 532        let registration = cx.spawn(|this, mut cx| {
 533            async move {
 534                let register = async {
 535                    scan_complete.await;
 536                    let db_id = db.find_or_create_worktree(worktree_abs_path).await?;
 537                    let mut file_mtimes = db.get_file_mtimes(db_id).await?;
 538                    let worktree = if let Some(project) = project.upgrade(&cx) {
 539                        project
 540                            .read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx))
 541                            .ok_or_else(|| anyhow!("worktree not found"))?
 542                    } else {
 543                        return anyhow::Ok(());
 544                    };
 545                    let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot());
 546                    let mut changed_paths = cx
 547                        .background()
 548                        .spawn(async move {
 549                            let mut changed_paths = BTreeMap::new();
 550                            for file in worktree.files(false, 0) {
 551                                let absolute_path = worktree.absolutize(&file.path);
 552
 553                                if file.is_external || file.is_ignored || file.is_symlink {
 554                                    continue;
 555                                }
 556
 557                                if let Ok(language) = language_registry
 558                                    .language_for_file(&absolute_path, None)
 559                                    .await
 560                                {
 561                                    // Skip languages that we can't parse into embeddable spans.
 562                                    if !PARSEABLE_ENTIRE_FILE_TYPES
 563                                        .contains(&language.name().as_ref())
 564                                        && language.name().as_ref() != "Markdown"
 565                                        && language
 566                                            .grammar()
 567                                            .and_then(|grammar| grammar.embedding_config.as_ref())
 568                                            .is_none()
 569                                    {
 570                                        continue;
 571                                    }
 572
 573                                    let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
 574                                    let already_stored = stored_mtime
 575                                        .map_or(false, |existing_mtime| {
 576                                            existing_mtime == file.mtime
 577                                        });
 578
 579                                    if !already_stored {
 580                                        changed_paths.insert(
 581                                            file.path.clone(),
 582                                            ChangedPathInfo {
 583                                                mtime: file.mtime,
 584                                                is_deleted: false,
 585                                            },
 586                                        );
 587                                    }
 588                                }
 589                            }
 590
 591                            // Clean up entries from database that are no longer in the worktree.
 592                            for (path, mtime) in file_mtimes {
 593                                changed_paths.insert(
 594                                    path.into(),
 595                                    ChangedPathInfo {
 596                                        mtime,
 597                                        is_deleted: true,
 598                                    },
 599                                );
 600                            }
 601
 602                            anyhow::Ok(changed_paths)
 603                        })
 604                        .await?;
 605                    this.update(&mut cx, |this, cx| {
 606                        let project_state = this
 607                            .projects
 608                            .get_mut(&project)
 609                            .ok_or_else(|| anyhow!("project not registered"))?;
 610                        let project = project
 611                            .upgrade(cx)
 612                            .ok_or_else(|| anyhow!("project was dropped"))?;
 613
 614                        if let Some(WorktreeState::Registering(state)) =
 615                            project_state.worktrees.remove(&worktree_id)
 616                        {
 617                            changed_paths.extend(state.changed_paths);
 618                        }
 619                        project_state.worktrees.insert(
 620                            worktree_id,
 621                            WorktreeState::Registered(RegisteredWorktreeState {
 622                                db_id,
 623                                changed_paths,
 624                            }),
 625                        );
 626                        this.index_project(project, cx).detach_and_log_err(cx);
 627
 628                        anyhow::Ok(())
 629                    })?;
 630
 631                    anyhow::Ok(())
 632                };
 633
 634                if register.await.log_err().is_none() {
 635                    // Stop tracking this worktree if the registration failed.
 636                    this.update(&mut cx, |this, _| {
 637                        this.projects.get_mut(&project).map(|project_state| {
 638                            project_state.worktrees.remove(&worktree_id);
 639                        });
 640                    })
 641                }
 642
 643                *done_tx.borrow_mut() = Some(());
 644            }
 645        });
 646        project_state.worktrees.insert(
 647            worktree_id,
 648            WorktreeState::Registering(RegisteringWorktreeState {
 649                changed_paths: Default::default(),
 650                done_rx,
 651                _registration: registration,
 652            }),
 653        );
 654    }
 655
 656    fn project_worktrees_changed(
 657        &mut self,
 658        project: ModelHandle<Project>,
 659        cx: &mut ModelContext<Self>,
 660    ) {
 661        let project_state = if let Some(project_state) = self.projects.get_mut(&project.downgrade())
 662        {
 663            project_state
 664        } else {
 665            return;
 666        };
 667
 668        let mut worktrees = project
 669            .read(cx)
 670            .worktrees(cx)
 671            .filter(|worktree| worktree.read(cx).is_local())
 672            .collect::<Vec<_>>();
 673        let worktree_ids = worktrees
 674            .iter()
 675            .map(|worktree| worktree.read(cx).id())
 676            .collect::<HashSet<_>>();
 677
 678        // Remove worktrees that are no longer present
 679        project_state
 680            .worktrees
 681            .retain(|worktree_id, _| worktree_ids.contains(worktree_id));
 682
 683        // Register new worktrees
 684        worktrees.retain(|worktree| {
 685            let worktree_id = worktree.read(cx).id();
 686            !project_state.worktrees.contains_key(&worktree_id)
 687        });
 688        for worktree in worktrees {
 689            self.register_worktree(project.clone(), worktree, cx);
 690        }
 691    }
 692
 693    pub fn pending_file_count(
 694        &self,
 695        project: &ModelHandle<Project>,
 696    ) -> Option<watch::Receiver<usize>> {
 697        Some(
 698            self.projects
 699                .get(&project.downgrade())?
 700                .pending_file_count_rx
 701                .clone(),
 702        )
 703    }
 704
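    /// Embeds the query and searches the project semantically, weaving together matches
    /// from dirty open buffers with matches from the on-disk index and returning up to
    /// `limit` results sorted by similarity.
    ///
    /// A minimal sketch of a call site (the surrounding handles are illustrative, not taken
    /// from this file):
    ///
    /// ```ignore
    /// let search = semantic_index.update(cx, |index, cx| {
    ///     index.search_project(project.clone(), "parse files into spans".into(), 10, vec![], vec![], cx)
    /// });
    /// let results = search.await?;
    /// ```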
 705    pub fn search_project(
 706        &mut self,
 707        project: ModelHandle<Project>,
 708        query: String,
 709        limit: usize,
 710        includes: Vec<PathMatcher>,
 711        excludes: Vec<PathMatcher>,
 712        cx: &mut ModelContext<Self>,
 713    ) -> Task<Result<Vec<SearchResult>>> {
 714        if query.is_empty() {
 715            return Task::ready(Ok(Vec::new()));
 716        }
 717
 718        let index = self.index_project(project.clone(), cx);
 719        let embedding_provider = self.embedding_provider.clone();
 720
 721        cx.spawn(|this, mut cx| async move {
 722            index.await?;
 723            let t0 = Instant::now();
 724
 725            let query = embedding_provider
 726                .embed_batch(vec![query])
 727                .await?
 728                .pop()
 729                .ok_or_else(|| anyhow!("could not embed query"))?;
 730            log::trace!("Embedding Search Query: {:?}ms", t0.elapsed().as_millis());
 731
 732            let search_start = Instant::now();
 733            let modified_buffer_results = this.update(&mut cx, |this, cx| {
 734                this.search_modified_buffers(
 735                    &project,
 736                    query.clone(),
 737                    limit,
 738                    &includes,
 739                    &excludes,
 740                    cx,
 741                )
 742            });
 743            let file_results = this.update(&mut cx, |this, cx| {
 744                this.search_files(project, query, limit, includes, excludes, cx)
 745            });
 746            let (modified_buffer_results, file_results) =
 747                futures::join!(modified_buffer_results, file_results);
 748
 749            // Weave together the results from modified buffers and files.
 750            let mut results = Vec::new();
 751            let mut modified_buffers = HashSet::default();
 752            for result in modified_buffer_results.log_err().unwrap_or_default() {
 753                modified_buffers.insert(result.buffer.clone());
 754                results.push(result);
 755            }
 756            for result in file_results.log_err().unwrap_or_default() {
 757                if !modified_buffers.contains(&result.buffer) {
 758                    results.push(result);
 759                }
 760            }
 761            results.sort_by_key(|result| Reverse(result.similarity));
 762            results.truncate(limit);
 763            log::trace!("Semantic search took {:?}", search_start.elapsed());
 764            Ok(results)
 765        })
 766    }
 767
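    /// Searches the persisted index: collects the registered worktrees' database ids,
    /// filters file ids through the include/exclude matchers, runs top-k similarity search
    /// over those files in parallel batches, and finally opens the matching buffers so the
    /// stored byte ranges can be returned as anchors.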
 768    pub fn search_files(
 769        &mut self,
 770        project: ModelHandle<Project>,
 771        query: Embedding,
 772        limit: usize,
 773        includes: Vec<PathMatcher>,
 774        excludes: Vec<PathMatcher>,
 775        cx: &mut ModelContext<Self>,
 776    ) -> Task<Result<Vec<SearchResult>>> {
 777        let db_path = self.db.path().clone();
 778        let fs = self.fs.clone();
 779        cx.spawn(|this, mut cx| async move {
 780            let database =
 781                VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?;
 782
 783            let worktree_db_ids = this.read_with(&cx, |this, _| {
 784                let project_state = this
 785                    .projects
 786                    .get(&project.downgrade())
 787                    .ok_or_else(|| anyhow!("project was not indexed"))?;
 788                let worktree_db_ids = project_state
 789                    .worktrees
 790                    .values()
 791                    .filter_map(|worktree| {
 792                        if let WorktreeState::Registered(worktree) = worktree {
 793                            Some(worktree.db_id)
 794                        } else {
 795                            None
 796                        }
 797                    })
 798                    .collect::<Vec<i64>>();
 799                anyhow::Ok(worktree_db_ids)
 800            })?;
 801
 802            let file_ids = database
 803                .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
 804                .await?;
 805
 806            let batch_n = cx.background().num_cpus();
 807            let ids_len = file_ids.len();
 808            let minimum_batch_size = 50;
 809
 810            let batch_size = {
 811                let size = ids_len / batch_n;
 812                if size < minimum_batch_size {
 813                    minimum_batch_size
 814                } else {
 815                    size
 816                }
 817            };
 818
 819            let mut batch_results = Vec::new();
 820            for batch in file_ids.chunks(batch_size) {
 821                let batch = batch.to_vec();
 822                let limit = limit.clone();
 823                let fs = fs.clone();
 824                let db_path = db_path.clone();
 825                let query = query.clone();
 826                if let Some(db) = VectorDatabase::new(fs, db_path.clone(), cx.background())
 827                    .await
 828                    .log_err()
 829                {
 830                    batch_results.push(async move {
 831                        db.top_k_search(&query, limit, batch.as_slice()).await
 832                    });
 833                }
 834            }
 835
 836            let batch_results = futures::future::join_all(batch_results).await;
 837
 838            let mut results = Vec::new();
 839            for batch_result in batch_results {
 840                if let Ok(batch_result) = batch_result {
 841                    for (id, similarity) in batch_result {
 842                        let ix = match results
 843                            .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
 844                        {
 845                            Ok(ix) => ix,
 846                            Err(ix) => ix,
 847                        };
 848
 849                        results.insert(ix, (id, similarity));
 850                        results.truncate(limit);
 851                    }
 852                }
 853            }
 854
 855            let ids = results.iter().map(|(id, _)| *id).collect::<Vec<i64>>();
 856            let scores = results
 857                .into_iter()
 858                .map(|(_, score)| score)
 859                .collect::<Vec<_>>();
 860            let spans = database.spans_for_ids(ids.as_slice()).await?;
 861
 862            let mut tasks = Vec::new();
 863            let mut ranges = Vec::new();
 864            let weak_project = project.downgrade();
 865            project.update(&mut cx, |project, cx| {
 866                for (worktree_db_id, file_path, byte_range) in spans {
 867                    let project_state =
 868                        if let Some(state) = this.read(cx).projects.get(&weak_project) {
 869                            state
 870                        } else {
 871                            return Err(anyhow!("project not added"));
 872                        };
 873                    if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
 874                        tasks.push(project.open_buffer((worktree_id, file_path), cx));
 875                        ranges.push(byte_range);
 876                    }
 877                }
 878
 879                Ok(())
 880            })?;
 881
 882            let buffers = futures::future::join_all(tasks).await;
 883            Ok(buffers
 884                .into_iter()
 885                .zip(ranges)
 886                .zip(scores)
 887                .filter_map(|((buffer, range), similarity)| {
 888                    let buffer = buffer.log_err()?;
 889                    let range = buffer.read_with(&cx, |buffer, _| {
 890                        let start = buffer.clip_offset(range.start, Bias::Left);
 891                        let end = buffer.clip_offset(range.end, Bias::Right);
 892                        buffer.anchor_before(start)..buffer.anchor_after(end)
 893                    });
 894                    Some(SearchResult {
 895                        buffer,
 896                        range,
 897                        similarity,
 898                    })
 899                })
 900                .collect())
 901        })
 902    }
 903
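    /// Searches open buffers with unsaved changes, whose on-disk index may be stale: every
    /// dirty buffer that passes the include/exclude filters is re-parsed, missing span
    /// embeddings are computed, and its spans are ranked against the query directly.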
 904    fn search_modified_buffers(
 905        &self,
 906        project: &ModelHandle<Project>,
 907        query: Embedding,
 908        limit: usize,
 909        includes: &[PathMatcher],
 910        excludes: &[PathMatcher],
 911        cx: &mut ModelContext<Self>,
 912    ) -> Task<Result<Vec<SearchResult>>> {
 913        let modified_buffers = project
 914            .read(cx)
 915            .opened_buffers(cx)
 916            .into_iter()
 917            .filter_map(|buffer_handle| {
 918                let buffer = buffer_handle.read(cx);
 919                let snapshot = buffer.snapshot();
 920                let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
 921                    excludes.iter().any(|matcher| matcher.is_match(&path))
 922                });
 923
 924                let included = if includes.is_empty() {
 925                    true
 926                } else {
 927                    snapshot.resolve_file_path(cx, false).map_or(false, |path| {
 928                        includes.iter().any(|matcher| matcher.is_match(&path))
 929                    })
 930                };
 931
 932                if buffer.is_dirty() && !excluded && included {
 933                    Some((buffer_handle, snapshot))
 934                } else {
 935                    None
 936                }
 937            })
 938            .collect::<HashMap<_, _>>();
 939
 940        let embedding_provider = self.embedding_provider.clone();
 941        let fs = self.fs.clone();
 942        let db_path = self.db.path().clone();
 943        let background = cx.background().clone();
 944        cx.background().spawn(async move {
 945            let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
 946            let mut results = Vec::<SearchResult>::new();
 947
 948            let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
 949            for (buffer, snapshot) in modified_buffers {
 950                let language = snapshot
 951                    .language_at(0)
 952                    .cloned()
 953                    .unwrap_or_else(|| language::PLAIN_TEXT.clone());
 954                let mut spans = retriever
 955                    .parse_file_with_template(None, &snapshot.text(), language)
 956                    .log_err()
 957                    .unwrap_or_default();
 958                if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
 959                    .await
 960                    .log_err()
 961                    .is_some()
 962                {
 963                    for span in spans {
 964                        let similarity = span.embedding.unwrap().similarity(&query);
 965                        let ix = match results
 966                            .binary_search_by_key(&Reverse(similarity), |result| {
 967                                Reverse(result.similarity)
 968                            }) {
 969                            Ok(ix) => ix,
 970                            Err(ix) => ix,
 971                        };
 972
 973                        let range = {
 974                            let start = snapshot.clip_offset(span.range.start, Bias::Left);
 975                            let end = snapshot.clip_offset(span.range.end, Bias::Right);
 976                            snapshot.anchor_before(start)..snapshot.anchor_after(end)
 977                        };
 978
 979                        results.insert(
 980                            ix,
 981                            SearchResult {
 982                                buffer: buffer.clone(),
 983                                range,
 984                                similarity,
 985                            },
 986                        );
 987                        results.truncate(limit);
 988                    }
 989                }
 990            }
 991
 992            Ok(results)
 993        })
 994    }
 995
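    /// Indexes a project: registers it (and its worktrees) on first use, waits for worktree
    /// registration to finish, deletes removed files from the database, and sends every
    /// changed file to the parsing tasks, resolving once the pending-file count drops back
    /// to zero.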
 996    pub fn index_project(
 997        &mut self,
 998        project: ModelHandle<Project>,
 999        cx: &mut ModelContext<Self>,
1000    ) -> Task<Result<()>> {
1001        if !self.is_authenticated() {
1002            if !self.authenticate(cx) {
1003                return Task::ready(Err(anyhow!("user is not authenticated")));
1004            }
1005        }
1006
1007        if !self.projects.contains_key(&project.downgrade()) {
1008            let subscription = cx.subscribe(&project, |this, project, event, cx| match event {
1009                project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
1010                    this.project_worktrees_changed(project.clone(), cx);
1011                }
1012                project::Event::WorktreeUpdatedEntries(worktree_id, changes) => {
1013                    this.project_entries_changed(project, *worktree_id, changes.clone(), cx);
1014                }
1015                _ => {}
1016            });
1017            let project_state = ProjectState::new(subscription, cx);
1018            self.projects.insert(project.downgrade(), project_state);
1019            self.project_worktrees_changed(project.clone(), cx);
1020        }
1021        let project_state = self.projects.get_mut(&project.downgrade()).unwrap();
1022        project_state.pending_index += 1;
1023        cx.notify();
1024
1025        let mut pending_file_count_rx = project_state.pending_file_count_rx.clone();
1026        let db = self.db.clone();
1027        let language_registry = self.language_registry.clone();
1028        let parsing_files_tx = self.parsing_files_tx.clone();
1029        let worktree_registration = self.wait_for_worktree_registration(&project, cx);
1030
1031        cx.spawn(|this, mut cx| async move {
1032            worktree_registration.await?;
1033
1034            let mut pending_files = Vec::new();
1035            let mut files_to_delete = Vec::new();
1036            this.update(&mut cx, |this, cx| {
1037                let project_state = this
1038                    .projects
1039                    .get_mut(&project.downgrade())
1040                    .ok_or_else(|| anyhow!("project was dropped"))?;
1041                let pending_file_count_tx = &project_state.pending_file_count_tx;
1042
1043                project_state
1044                    .worktrees
1045                    .retain(|worktree_id, worktree_state| {
1046                        let worktree = if let Some(worktree) =
1047                            project.read(cx).worktree_for_id(*worktree_id, cx)
1048                        {
1049                            worktree
1050                        } else {
1051                            return false;
1052                        };
1053                        let worktree_state =
1054                            if let WorktreeState::Registered(worktree_state) = worktree_state {
1055                                worktree_state
1056                            } else {
1057                                return true;
1058                            };
1059
1060                        worktree_state.changed_paths.retain(|path, info| {
1061                            if info.is_deleted {
1062                                files_to_delete.push((worktree_state.db_id, path.clone()));
1063                            } else {
1064                                let absolute_path = worktree.read(cx).absolutize(path);
1065                                let job_handle = JobHandle::new(pending_file_count_tx);
1066                                pending_files.push(PendingFile {
1067                                    absolute_path,
1068                                    relative_path: path.clone(),
1069                                    language: None,
1070                                    job_handle,
1071                                    modified_time: info.mtime,
1072                                    worktree_db_id: worktree_state.db_id,
1073                                });
1074                            }
1075
1076                            false
1077                        });
1078                        true
1079                    });
1080
1081                anyhow::Ok(())
1082            })?;
1083
1084            cx.background()
1085                .spawn(async move {
1086                    for (worktree_db_id, path) in files_to_delete {
1087                        db.delete_file(worktree_db_id, path).await.log_err();
1088                    }
1089
1090                    let embeddings_for_digest = {
1091                        let mut files = HashMap::default();
1092                        for pending_file in &pending_files {
1093                            files
1094                                .entry(pending_file.worktree_db_id)
1095                                .or_insert(Vec::new())
1096                                .push(pending_file.relative_path.clone());
1097                        }
1098                        Arc::new(
1099                            db.embeddings_for_files(files)
1100                                .await
1101                                .log_err()
1102                                .unwrap_or_default(),
1103                        )
1104                    };
1105
1106                    for mut pending_file in pending_files {
1107                        if let Ok(language) = language_registry
1108                            .language_for_file(&pending_file.relative_path, None)
1109                            .await
1110                        {
1111                            if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
1112                                && language.name().as_ref() != "Markdown"
1113                                && language
1114                                    .grammar()
1115                                    .and_then(|grammar| grammar.embedding_config.as_ref())
1116                                    .is_none()
1117                            {
1118                                continue;
1119                            }
1120                            pending_file.language = Some(language);
1121                        }
1122                        parsing_files_tx
1123                            .try_send((embeddings_for_digest.clone(), pending_file))
1124                            .ok();
1125                    }
1126
1127                    // Wait until we're done indexing.
1128                    while let Some(count) = pending_file_count_rx.next().await {
1129                        if count == 0 {
1130                            break;
1131                        }
1132                    }
1133                })
1134                .await;
1135
1136            this.update(&mut cx, |this, cx| {
1137                let project_state = this
1138                    .projects
1139                    .get_mut(&project.downgrade())
1140                    .ok_or_else(|| anyhow!("project was dropped"))?;
1141                project_state.pending_index -= 1;
1142                cx.notify();
1143                anyhow::Ok(())
1144            })?;
1145
1146            Ok(())
1147        })
1148    }
1149
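    /// Resolves once no worktree in the project is still in the `Registering` state.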
1150    fn wait_for_worktree_registration(
1151        &self,
1152        project: &ModelHandle<Project>,
1153        cx: &mut ModelContext<Self>,
1154    ) -> Task<Result<()>> {
1155        let project = project.downgrade();
1156        cx.spawn_weak(|this, cx| async move {
1157            loop {
1158                let mut pending_worktrees = Vec::new();
1159                this.upgrade(&cx)
1160                    .ok_or_else(|| anyhow!("semantic index dropped"))?
1161                    .read_with(&cx, |this, _| {
1162                        if let Some(project) = this.projects.get(&project) {
1163                            for worktree in project.worktrees.values() {
1164                                if let WorktreeState::Registering(worktree) = worktree {
1165                                    pending_worktrees.push(worktree.done());
1166                                }
1167                            }
1168                        }
1169                    });
1170
1171                if pending_worktrees.is_empty() {
1172                    break;
1173                } else {
1174                    future::join_all(pending_worktrees).await;
1175                }
1176            }
1177            Ok(())
1178        })
1179    }
1180
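    /// Computes embeddings for the given spans, reusing any embeddings the database already
    /// holds for a span's digest and batching the remainder so each request stays within the
    /// provider's `max_tokens_per_batch`.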
1181    async fn embed_spans(
1182        spans: &mut [Span],
1183        embedding_provider: &dyn EmbeddingProvider,
1184        db: &VectorDatabase,
1185    ) -> Result<()> {
1186        let mut batch = Vec::new();
1187        let mut batch_tokens = 0;
1188        let mut embeddings = Vec::new();
1189
1190        let digests = spans
1191            .iter()
1192            .map(|span| span.digest.clone())
1193            .collect::<Vec<_>>();
1194        let embeddings_for_digests = db
1195            .embeddings_for_digests(digests)
1196            .await
1197            .log_err()
1198            .unwrap_or_default();
1199
1200        for span in &*spans {
1201            if embeddings_for_digests.contains_key(&span.digest) {
1202                continue;
1203            };
1204
1205            if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
1206                let batch_embeddings = embedding_provider
1207                    .embed_batch(mem::take(&mut batch))
1208                    .await?;
1209                embeddings.extend(batch_embeddings);
1210                batch_tokens = 0;
1211            }
1212
1213            batch_tokens += span.token_count;
1214            batch.push(span.content.clone());
1215        }
1216
1217        if !batch.is_empty() {
1218            let batch_embeddings = embedding_provider
1219                .embed_batch(mem::take(&mut batch))
1220                .await?;
1221
1222            embeddings.extend(batch_embeddings);
1223        }
1224
1225        let mut embeddings = embeddings.into_iter();
1226        for span in spans {
1227            let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) {
1228                Some(embedding.clone())
1229            } else {
1230                embeddings.next()
1231            };
1232            let embedding = embedding.ok_or_else(|| anyhow!("failed to embed spans"))?;
1233            span.embedding = Some(embedding);
1234        }
1235        Ok(())
1236    }
1237}
1238
1239impl Entity for SemanticIndex {
1240    type Event = ();
1241}
1242
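// The pending-file counter is decremented only when the final clone of a JobHandle is
// dropped: `Arc::get_mut` succeeds only for the last strong reference, and the inner `Weak`
// means a counter that has already gone away is simply ignored.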
1243impl Drop for JobHandle {
1244    fn drop(&mut self) {
1245        if let Some(inner) = Arc::get_mut(&mut self.tx) {
1246            // This is the last instance of the JobHandle (regardless of its origin, i.e. whether it was cloned or not)
1247            if let Some(tx) = inner.upgrade() {
1248                let mut tx = tx.lock();
1249                *tx.borrow_mut() -= 1;
1250            }
1251        }
1252    }
1253}
1254
1255#[cfg(test)]
1256mod tests {
1257
1258    use super::*;
1259    #[test]
1260    fn test_job_handle() {
1261        let (job_count_tx, job_count_rx) = watch::channel_with(0);
1262        let tx = Arc::new(Mutex::new(job_count_tx));
1263        let job_handle = JobHandle::new(&tx);
1264
1265        assert_eq!(1, *job_count_rx.borrow());
1266        let new_job_handle = job_handle.clone();
1267        assert_eq!(1, *job_count_rx.borrow());
1268        drop(job_handle);
1269        assert_eq!(1, *job_count_rx.borrow());
1270        drop(new_job_handle);
1271        assert_eq!(0, *job_count_rx.borrow());
1272    }
1273}