// project_index.rs

  1use crate::{
  2    embedding::{EmbeddingProvider, TextToEmbed},
  3    summary_index::FileSummary,
  4    worktree_index::{WorktreeIndex, WorktreeIndexHandle},
  5};
  6use anyhow::{anyhow, Context, Result};
  7use collections::HashMap;
  8use fs::Fs;
  9use futures::{stream::StreamExt, FutureExt};
 10use gpui::{
 11    AppContext, Entity, EntityId, EventEmitter, Model, ModelContext, Subscription, Task, WeakModel,
 12};
 13use language::LanguageRegistry;
 14use log;
 15use project::{Project, Worktree, WorktreeId};
 16use serde::{Deserialize, Serialize};
 17use smol::channel;
 18use std::{cmp::Ordering, future::Future, num::NonZeroUsize, ops::Range, path::Path, sync::Arc};
 19use util::ResultExt;
 20
/// A single semantic-search hit, resolved to a live worktree handle.
#[derive(Debug)]
pub struct SearchResult {
    /// The worktree containing the matching file.
    pub worktree: Model<Worktree>,
    /// Path of the matching file within the worktree.
    pub path: Arc<Path>,
    /// Range of the matching chunk within the file, as stored by the
    /// embedding index.
    pub range: Range<usize>,
    /// Similarity score between the query embedding and the chunk embedding.
    pub score: f32,
}
 28
/// Intermediate search result produced by the scoring workers, keyed by
/// `WorktreeId` rather than a live worktree handle. Converted into a
/// [`SearchResult`] once the owning worktree is resolved on the project.
pub struct WorktreeSearchResult {
    /// Identifier of the worktree the match came from.
    pub worktree_id: WorktreeId,
    /// Path of the matching file within the worktree.
    pub path: Arc<Path>,
    /// Range of the matching chunk within the file.
    pub range: Range<usize>,
    /// Similarity score between the query embedding and the chunk embedding.
    pub score: f32,
}
 35
/// Aggregate indexing status across all of a project's worktree indices,
/// emitted as a gpui event whenever it changes.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum Status {
    /// All indices are loaded and no entries are being indexed.
    Idle,
    /// At least one worktree index is still being loaded.
    Loading,
    /// All indices are loaded; `remaining_count` entries across them are
    /// still being indexed.
    Scanning { remaining_count: NonZeroUsize },
}
 42
/// Maintains a semantic index for every visible, local worktree in a project,
/// aggregating their statuses and serving searches across all of them.
pub struct ProjectIndex {
    // heed database environment shared with every worktree index.
    db_connection: heed::Env,
    // Weak handle back to the project; indexing stops once it is dropped.
    project: WeakModel<Project>,
    // Keyed by worktree entity id; values may still be loading.
    worktree_indices: HashMap<EntityId, WorktreeIndexHandle>,
    language_registry: Arc<LanguageRegistry>,
    fs: Arc<dyn Fs>,
    // Most recently computed (and emitted) aggregate status.
    last_status: Status,
    // Cloned into each worktree index so it can signal indexing progress.
    status_tx: channel::Sender<()>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    // Background task that recomputes `last_status` on each progress signal.
    _maintain_status: Task<()>,
    // Keeps us subscribed to project worktree add/remove events.
    _subscription: Subscription,
}
 55
 56impl ProjectIndex {
 57    pub fn new(
 58        project: Model<Project>,
 59        db_connection: heed::Env,
 60        embedding_provider: Arc<dyn EmbeddingProvider>,
 61        cx: &mut ModelContext<Self>,
 62    ) -> Self {
 63        let language_registry = project.read(cx).languages().clone();
 64        let fs = project.read(cx).fs().clone();
 65        let (status_tx, mut status_rx) = channel::unbounded();
 66        let mut this = ProjectIndex {
 67            db_connection,
 68            project: project.downgrade(),
 69            worktree_indices: HashMap::default(),
 70            language_registry,
 71            fs,
 72            status_tx,
 73            last_status: Status::Idle,
 74            embedding_provider,
 75            _subscription: cx.subscribe(&project, Self::handle_project_event),
 76            _maintain_status: cx.spawn(|this, mut cx| async move {
 77                while status_rx.next().await.is_some() {
 78                    if this
 79                        .update(&mut cx, |this, cx| this.update_status(cx))
 80                        .is_err()
 81                    {
 82                        break;
 83                    }
 84                }
 85            }),
 86        };
 87        this.update_worktree_indices(cx);
 88        this
 89    }
 90
 91    pub fn status(&self) -> Status {
 92        self.last_status
 93    }
 94
 95    pub fn project(&self) -> WeakModel<Project> {
 96        self.project.clone()
 97    }
 98
 99    pub fn fs(&self) -> Arc<dyn Fs> {
100        self.fs.clone()
101    }
102
103    fn handle_project_event(
104        &mut self,
105        _: Model<Project>,
106        event: &project::Event,
107        cx: &mut ModelContext<Self>,
108    ) {
109        match event {
110            project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
111                self.update_worktree_indices(cx);
112            }
113            _ => {}
114        }
115    }
116
117    fn update_worktree_indices(&mut self, cx: &mut ModelContext<Self>) {
118        let Some(project) = self.project.upgrade() else {
119            return;
120        };
121
122        let worktrees = project
123            .read(cx)
124            .visible_worktrees(cx)
125            .filter_map(|worktree| {
126                if worktree.read(cx).is_local() {
127                    Some((worktree.entity_id(), worktree))
128                } else {
129                    None
130                }
131            })
132            .collect::<HashMap<_, _>>();
133
134        self.worktree_indices
135            .retain(|worktree_id, _| worktrees.contains_key(worktree_id));
136        for (worktree_id, worktree) in worktrees {
137            self.worktree_indices.entry(worktree_id).or_insert_with(|| {
138                let worktree_index = WorktreeIndex::load(
139                    worktree.clone(),
140                    self.db_connection.clone(),
141                    self.language_registry.clone(),
142                    self.fs.clone(),
143                    self.status_tx.clone(),
144                    self.embedding_provider.clone(),
145                    cx,
146                );
147
148                let load_worktree = cx.spawn(|this, mut cx| async move {
149                    let result = match worktree_index.await {
150                        Ok(worktree_index) => {
151                            this.update(&mut cx, |this, _| {
152                                this.worktree_indices.insert(
153                                    worktree_id,
154                                    WorktreeIndexHandle::Loaded {
155                                        index: worktree_index.clone(),
156                                    },
157                                );
158                            })?;
159                            Ok(worktree_index)
160                        }
161                        Err(error) => {
162                            this.update(&mut cx, |this, _cx| {
163                                this.worktree_indices.remove(&worktree_id)
164                            })?;
165                            Err(Arc::new(error))
166                        }
167                    };
168
169                    this.update(&mut cx, |this, cx| this.update_status(cx))?;
170
171                    result
172                });
173
174                WorktreeIndexHandle::Loading {
175                    index: load_worktree.shared(),
176                }
177            });
178        }
179
180        self.update_status(cx);
181    }
182
183    fn update_status(&mut self, cx: &mut ModelContext<Self>) {
184        let mut indexing_count = 0;
185        let mut any_loading = false;
186
187        for index in self.worktree_indices.values_mut() {
188            match index {
189                WorktreeIndexHandle::Loading { .. } => {
190                    any_loading = true;
191                    break;
192                }
193                WorktreeIndexHandle::Loaded { index, .. } => {
194                    indexing_count += index.read(cx).entry_ids_being_indexed().len();
195                }
196            }
197        }
198
199        let status = if any_loading {
200            Status::Loading
201        } else if let Some(remaining_count) = NonZeroUsize::new(indexing_count) {
202            Status::Scanning { remaining_count }
203        } else {
204            Status::Idle
205        };
206
207        if status != self.last_status {
208            self.last_status = status;
209            cx.emit(status);
210        }
211    }
212
213    pub fn search(
214        &self,
215        query: String,
216        limit: usize,
217        cx: &AppContext,
218    ) -> Task<Result<Vec<SearchResult>>> {
219        let (chunks_tx, chunks_rx) = channel::bounded(1024);
220        let mut worktree_scan_tasks = Vec::new();
221        for worktree_index in self.worktree_indices.values() {
222            let worktree_index = worktree_index.clone();
223            let chunks_tx = chunks_tx.clone();
224            worktree_scan_tasks.push(cx.spawn(|cx| async move {
225                let index = match worktree_index {
226                    WorktreeIndexHandle::Loading { index } => {
227                        index.clone().await.map_err(|error| anyhow!(error))?
228                    }
229                    WorktreeIndexHandle::Loaded { index } => index.clone(),
230                };
231
232                index
233                    .read_with(&cx, |index, cx| {
234                        let worktree_id = index.worktree().read(cx).id();
235                        let db_connection = index.db_connection().clone();
236                        let db = *index.embedding_index().db();
237                        cx.background_executor().spawn(async move {
238                            let txn = db_connection
239                                .read_txn()
240                                .context("failed to create read transaction")?;
241                            let db_entries = db.iter(&txn).context("failed to iterate database")?;
242                            for db_entry in db_entries {
243                                let (_key, db_embedded_file) = db_entry?;
244                                for chunk in db_embedded_file.chunks {
245                                    chunks_tx
246                                        .send((worktree_id, db_embedded_file.path.clone(), chunk))
247                                        .await?;
248                                }
249                            }
250                            anyhow::Ok(())
251                        })
252                    })?
253                    .await
254            }));
255        }
256        drop(chunks_tx);
257
258        let project = self.project.clone();
259        let embedding_provider = self.embedding_provider.clone();
260        cx.spawn(|cx| async move {
261            #[cfg(debug_assertions)]
262            let embedding_query_start = std::time::Instant::now();
263            log::info!("Searching for {query}");
264
265            let query_embeddings = embedding_provider
266                .embed(&[TextToEmbed::new(&query)])
267                .await?;
268            let query_embedding = query_embeddings
269                .into_iter()
270                .next()
271                .ok_or_else(|| anyhow!("no embedding for query"))?;
272
273            let mut results_by_worker = Vec::new();
274            for _ in 0..cx.background_executor().num_cpus() {
275                results_by_worker.push(Vec::<WorktreeSearchResult>::new());
276            }
277
278            #[cfg(debug_assertions)]
279            let search_start = std::time::Instant::now();
280
281            cx.background_executor()
282                .scoped(|cx| {
283                    for results in results_by_worker.iter_mut() {
284                        cx.spawn(async {
285                            while let Ok((worktree_id, path, chunk)) = chunks_rx.recv().await {
286                                let score = chunk.embedding.similarity(&query_embedding);
287                                let ix = match results.binary_search_by(|probe| {
288                                    score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
289                                }) {
290                                    Ok(ix) | Err(ix) => ix,
291                                };
292                                results.insert(
293                                    ix,
294                                    WorktreeSearchResult {
295                                        worktree_id,
296                                        path: path.clone(),
297                                        range: chunk.chunk.range.clone(),
298                                        score,
299                                    },
300                                );
301                                results.truncate(limit);
302                            }
303                        });
304                    }
305                })
306                .await;
307
308            for scan_task in futures::future::join_all(worktree_scan_tasks).await {
309                scan_task.log_err();
310            }
311
312            project.read_with(&cx, |project, cx| {
313                let mut search_results = Vec::with_capacity(results_by_worker.len() * limit);
314                for worker_results in results_by_worker {
315                    search_results.extend(worker_results.into_iter().filter_map(|result| {
316                        Some(SearchResult {
317                            worktree: project.worktree_for_id(result.worktree_id, cx)?,
318                            path: result.path,
319                            range: result.range,
320                            score: result.score,
321                        })
322                    }));
323                }
324                search_results.sort_unstable_by(|a, b| {
325                    b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal)
326                });
327                search_results.truncate(limit);
328
329                #[cfg(debug_assertions)]
330                {
331                    let search_elapsed = search_start.elapsed();
332                    log::debug!(
333                        "searched {} entries in {:?}",
334                        search_results.len(),
335                        search_elapsed
336                    );
337                    let embedding_query_elapsed = embedding_query_start.elapsed();
338                    log::debug!("embedding query took {:?}", embedding_query_elapsed);
339                }
340
341                search_results
342            })
343        })
344    }
345
346    #[cfg(test)]
347    pub fn path_count(&self, cx: &AppContext) -> Result<u64> {
348        let mut result = 0;
349        for worktree_index in self.worktree_indices.values() {
350            if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
351                result += index.read(cx).path_count()?;
352            }
353        }
354        Ok(result)
355    }
356
357    pub(crate) fn worktree_index(
358        &self,
359        worktree_id: WorktreeId,
360        cx: &AppContext,
361    ) -> Option<Model<WorktreeIndex>> {
362        for index in self.worktree_indices.values() {
363            if let WorktreeIndexHandle::Loaded { index, .. } = index {
364                if index.read(cx).worktree().read(cx).id() == worktree_id {
365                    return Some(index.clone());
366                }
367            }
368        }
369        None
370    }
371
372    pub(crate) fn worktree_indices(&self, cx: &AppContext) -> Vec<Model<WorktreeIndex>> {
373        let mut result = self
374            .worktree_indices
375            .values()
376            .filter_map(|index| {
377                if let WorktreeIndexHandle::Loaded { index, .. } = index {
378                    Some(index.clone())
379                } else {
380                    None
381                }
382            })
383            .collect::<Vec<_>>();
384        result.sort_by_key(|index| index.read(cx).worktree().read(cx).id());
385        result
386    }
387
388    pub fn all_summaries(&self, cx: &AppContext) -> Task<Result<Vec<FileSummary>>> {
389        let (summaries_tx, summaries_rx) = channel::bounded(1024);
390        let mut worktree_scan_tasks = Vec::new();
391        for worktree_index in self.worktree_indices.values() {
392            let worktree_index = worktree_index.clone();
393            let summaries_tx: channel::Sender<(String, String)> = summaries_tx.clone();
394            worktree_scan_tasks.push(cx.spawn(|cx| async move {
395                let index = match worktree_index {
396                    WorktreeIndexHandle::Loading { index } => {
397                        index.clone().await.map_err(|error| anyhow!(error))?
398                    }
399                    WorktreeIndexHandle::Loaded { index } => index.clone(),
400                };
401
402                index
403                    .read_with(&cx, |index, cx| {
404                        let db_connection = index.db_connection().clone();
405                        let summary_index = index.summary_index();
406                        let file_digest_db = summary_index.file_digest_db();
407                        let summary_db = summary_index.summary_db();
408
409                        cx.background_executor().spawn(async move {
410                            let txn = db_connection
411                                .read_txn()
412                                .context("failed to create db read transaction")?;
413                            let db_entries = file_digest_db
414                                .iter(&txn)
415                                .context("failed to iterate database")?;
416                            for db_entry in db_entries {
417                                let (file_path, db_file) = db_entry?;
418
419                                match summary_db.get(&txn, &db_file.digest) {
420                                    Ok(opt_summary) => {
421                                        // Currently, we only use summaries we already have. If the file hasn't been
422                                        // summarized yet, then we skip it and don't include it in the inferred context.
423                                        // If we want to do just-in-time summarization, this would be the place to do it!
424                                        if let Some(summary) = opt_summary {
425                                            summaries_tx
426                                                .send((file_path.to_string(), summary.to_string()))
427                                                .await?;
428                                        } else {
429                                            log::warn!("No summary found for {:?}", &db_file);
430                                        }
431                                    }
432                                    Err(err) => {
433                                        log::error!(
434                                            "Error reading from summary database: {:?}",
435                                            err
436                                        );
437                                    }
438                                }
439                            }
440                            anyhow::Ok(())
441                        })
442                    })?
443                    .await
444            }));
445        }
446        drop(summaries_tx);
447
448        let project = self.project.clone();
449        cx.spawn(|cx| async move {
450            let mut results_by_worker = Vec::new();
451            for _ in 0..cx.background_executor().num_cpus() {
452                results_by_worker.push(Vec::<FileSummary>::new());
453            }
454
455            cx.background_executor()
456                .scoped(|cx| {
457                    for results in results_by_worker.iter_mut() {
458                        cx.spawn(async {
459                            while let Ok((filename, summary)) = summaries_rx.recv().await {
460                                results.push(FileSummary { filename, summary });
461                            }
462                        });
463                    }
464                })
465                .await;
466
467            for scan_task in futures::future::join_all(worktree_scan_tasks).await {
468                scan_task.log_err();
469            }
470
471            project.read_with(&cx, |_project, _cx| {
472                results_by_worker.into_iter().flatten().collect()
473            })
474        })
475    }
476
477    /// Empty out the backlogs of all the worktrees in the project
478    pub fn flush_summary_backlogs(&self, cx: &AppContext) -> impl Future<Output = ()> {
479        let flush_start = std::time::Instant::now();
480
481        futures::future::join_all(self.worktree_indices.values().map(|worktree_index| {
482            let worktree_index = worktree_index.clone();
483
484            cx.spawn(|cx| async move {
485                let index = match worktree_index {
486                    WorktreeIndexHandle::Loading { index } => {
487                        index.clone().await.map_err(|error| anyhow!(error))?
488                    }
489                    WorktreeIndexHandle::Loaded { index } => index.clone(),
490                };
491                let worktree_abs_path =
492                    cx.update(|cx| index.read(cx).worktree().read(cx).abs_path())?;
493
494                index
495                    .read_with(&cx, |index, cx| {
496                        cx.background_executor()
497                            .spawn(index.summary_index().flush_backlog(worktree_abs_path, cx))
498                    })?
499                    .await
500            })
501        }))
502        .map(move |results| {
503            // Log any errors, but don't block the user. These summaries are supposed to
504            // improve quality by providing extra context, but they aren't hard requirements!
505            for result in results {
506                if let Err(err) = result {
507                    log::error!("Error flushing summary backlog: {:?}", err);
508                }
509            }
510
511            log::info!("Summary backlog flushed in {:?}", flush_start.elapsed());
512        })
513    }
514
515    pub fn remaining_summaries(&self, cx: &mut ModelContext<Self>) -> usize {
516        self.worktree_indices(cx)
517            .iter()
518            .map(|index| index.read(cx).summary_index().backlog_len())
519            .sum()
520    }
521}
522
/// `ProjectIndex` emits a `Status` event each time its aggregate indexing
/// status changes.
impl EventEmitter<Status> for ProjectIndex {}