1mod db;
2mod embedding_queue;
3mod parsing;
4pub mod semantic_index_settings;
5
6#[cfg(test)]
7mod semantic_index_tests;
8
9use crate::semantic_index_settings::SemanticIndexSettings;
10use ai::embedding::{Embedding, EmbeddingProvider};
11use ai::providers::open_ai::OpenAiEmbeddingProvider;
12use anyhow::{anyhow, Context as _, Result};
13use collections::{BTreeMap, HashMap, HashSet};
14use db::VectorDatabase;
15use embedding_queue::{EmbeddingQueue, FileToEmbed};
16use futures::{future, FutureExt, StreamExt};
17use gpui::{
18 AppContext, AsyncAppContext, BorrowWindow, Context, Global, Model, ModelContext, Task,
19 ViewContext, WeakModel,
20};
21use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
22use lazy_static::lazy_static;
23use ordered_float::OrderedFloat;
24use parking_lot::Mutex;
25use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
26use postage::watch;
27use project::{Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
28use release_channel::ReleaseChannel;
29use settings::Settings;
30use smol::channel;
31use std::{
32 cmp::Reverse,
33 env,
34 future::Future,
35 mem,
36 ops::Range,
37 path::{Path, PathBuf},
38 sync::{Arc, Weak},
39 time::{Duration, Instant, SystemTime},
40};
41use util::paths::PathMatcher;
42use util::{http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt};
43use workspace::Workspace;
44
45const SEMANTIC_INDEX_VERSION: usize = 11;
46const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
47const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);
48
49lazy_static! {
50 static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
51}
52
/// Initializes the semantic index subsystem.
///
/// Registers the index settings, installs a per-workspace observer that
/// re-indexes previously-indexed local projects when a workspace opens, and
/// spawns a task that constructs the global [`SemanticIndex`] backed by an
/// on-disk vector database and the OpenAI embedding provider.
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    SemanticIndexSettings::register(cx);

    // Keep one database per release channel so dev/preview/stable indexes
    // don't interfere with each other.
    let db_file_path = EMBEDDINGS_DIR
        .join(Path::new(ReleaseChannel::global(cx).dev_name()))
        .join("embeddings_db");

    cx.observe_new_views(
        |workspace: &mut Workspace, cx: &mut ViewContext<Workspace>| {
            let Some(semantic_index) = SemanticIndex::global(cx) else {
                return;
            };
            let project = workspace.project().clone();

            // Only local projects are indexable; remote projects are skipped.
            if project.read(cx).is_local() {
                cx.app_mut()
                    .spawn(|mut cx| async move {
                        // Re-index automatically only if this project was
                        // already indexed in a previous session.
                        let previously_indexed = semantic_index
                            .update(&mut cx, |index, cx| {
                                index.project_previously_indexed(&project, cx)
                            })?
                            .await?;
                        if previously_indexed {
                            semantic_index
                                .update(&mut cx, |index, cx| index.index_project(project, cx))?
                                .await?;
                        }
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
            }
        },
    )
    .detach();

    cx.spawn(move |cx| async move {
        let embedding_provider =
            OpenAiEmbeddingProvider::new(http_client, cx.background_executor().clone()).await;
        let semantic_index = SemanticIndex::new(
            fs,
            db_file_path,
            Arc::new(embedding_provider),
            language_registry,
            cx.clone(),
        )
        .await?;

        // Publish the index as a gpui global so `SemanticIndex::global`
        // can find it.
        cx.update(|cx| cx.set_global(GlobalSemanticIndex(semantic_index.clone())))?;

        anyhow::Ok(())
    })
    .detach();
}
111
/// Indexing state of a project, as reported by [`SemanticIndex::status`].
#[derive(Copy, Clone, Debug)]
pub enum SemanticIndexStatus {
    /// The embedding provider has no credentials yet.
    NotAuthenticated,
    /// The project has never been registered with the index.
    NotIndexed,
    /// All worktrees are registered and no indexing pass is in flight.
    Indexed,
    /// Indexing work is still pending.
    Indexing {
        // Number of files still waiting to be parsed and embedded.
        remaining_files: usize,
        // When the provider's rate limit expires, if currently throttled.
        rate_limit_expiry: Option<Instant>,
    },
}
122
/// Model that coordinates semantic indexing and search across projects.
pub struct SemanticIndex {
    fs: Arc<dyn Fs>,
    // Connection to the on-disk vector database of span embeddings.
    db: VectorDatabase,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    // Sends files, together with embeddings already known for their span
    // digests, to the background parser workers.
    parsing_files_tx: channel::Sender<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>,
    // Writes files to the database as the embedding queue finishes them.
    _embedding_task: Task<()>,
    // One parser worker task per CPU (see `SemanticIndex::new`).
    _parsing_files_tasks: Vec<Task<()>>,
    // Per-project bookkeeping, keyed by weak handle so dropped projects age out.
    projects: HashMap<WeakModel<Project>, ProjectState>,
}
133
/// Newtype storing the app-wide [`SemanticIndex`] model as a gpui global.
struct GlobalSemanticIndex(Model<SemanticIndex>);

impl Global for GlobalSemanticIndex {}
137
/// Per-project bookkeeping held by the index.
struct ProjectState {
    // Registration state for each of the project's worktrees.
    worktrees: HashMap<WorktreeId, WorktreeState>,
    // Observes the number of files still pending indexing.
    pending_file_count_rx: watch::Receiver<usize>,
    // Shared sender used by `JobHandle`s to adjust the pending-file count.
    pending_file_count_tx: Arc<Mutex<watch::Sender<usize>>>,
    // Number of `index_project` passes currently in flight.
    pending_index: usize,
    // Keeps the project-event subscription alive.
    _subscription: gpui::Subscription,
    // Notifies observers whenever the pending-file count changes.
    _observe_pending_file_count: Task<()>,
}
146
/// Lifecycle of a worktree within the index: `Registering` while the initial
/// scan and database setup run, then `Registered` once searchable.
enum WorktreeState {
    Registering(RegisteringWorktreeState),
    Registered(RegisteredWorktreeState),
}
151
impl WorktreeState {
    /// Whether this worktree has completed registration.
    fn is_registered(&self) -> bool {
        matches!(self, Self::Registered(_))
    }

    /// Records filesystem changes for later (re-)indexing, skipping entries
    /// that are ignored, symlinks, external, or directories.
    fn paths_changed(
        &mut self,
        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
        worktree: &Worktree,
    ) {
        // Both states accumulate changes into the same shape of map; a
        // registering worktree folds them in when registration completes.
        let changed_paths = match self {
            Self::Registering(state) => &mut state.changed_paths,
            Self::Registered(state) => &mut state.changed_paths,
        };

        for (path, entry_id, change) in changes.iter() {
            let Some(entry) = worktree.entry_for_id(*entry_id) else {
                continue;
            };
            if entry.is_ignored || entry.is_symlink || entry.is_external || entry.is_dir() {
                continue;
            }
            changed_paths.insert(
                path.clone(),
                ChangedPathInfo {
                    mtime: entry.mtime,
                    is_deleted: *change == PathChange::Removed,
                },
            );
        }
    }
}
184
/// State of a worktree whose registration task is still running.
struct RegisteringWorktreeState {
    // Path changes that arrived before registration finished.
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
    // Receives `Some(())` when the registration task completes.
    done_rx: watch::Receiver<Option<()>>,
    // Keeps the registration task alive.
    _registration: Task<()>,
}
190
191impl RegisteringWorktreeState {
192 fn done(&self) -> impl Future<Output = ()> {
193 let mut done_rx = self.done_rx.clone();
194 async move {
195 while let Some(result) = done_rx.next().await {
196 if result.is_some() {
197 break;
198 }
199 }
200 }
201 }
202}
203
/// State of a fully-registered, searchable worktree.
struct RegisteredWorktreeState {
    // This worktree's id in the vector database.
    db_id: i64,
    // Paths changed since the last indexing pass.
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
}
208
/// Metadata recorded for a changed path awaiting (re-)indexing.
struct ChangedPathInfo {
    // Modification time at the moment the change was observed.
    mtime: SystemTime,
    // True when the path was removed and should be deleted from the database.
    is_deleted: bool,
}
213
/// Handle that keeps one pending-file job counted while it is alive.
#[derive(Clone)]
pub struct JobHandle {
    /// The outer Arc is here to count the clones of a JobHandle instance;
    /// when the last handle to a given job is dropped, we decrement a counter (just once).
    // Weak so a lingering handle can't keep the project's counter alive.
    tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
}
220
221impl JobHandle {
222 fn new(tx: &Arc<Mutex<watch::Sender<usize>>>) -> Self {
223 *tx.lock().borrow_mut() += 1;
224 Self {
225 tx: Arc::new(Arc::downgrade(&tx)),
226 }
227 }
228}
229
230impl ProjectState {
231 fn new(subscription: gpui::Subscription, cx: &mut ModelContext<SemanticIndex>) -> Self {
232 let (pending_file_count_tx, pending_file_count_rx) = watch::channel_with(0);
233 let pending_file_count_tx = Arc::new(Mutex::new(pending_file_count_tx));
234 Self {
235 worktrees: Default::default(),
236 pending_file_count_rx: pending_file_count_rx.clone(),
237 pending_file_count_tx,
238 pending_index: 0,
239 _subscription: subscription,
240 _observe_pending_file_count: cx.spawn({
241 let mut pending_file_count_rx = pending_file_count_rx.clone();
242 |this, mut cx| async move {
243 while let Some(_) = pending_file_count_rx.next().await {
244 if this.update(&mut cx, |_, cx| cx.notify()).is_err() {
245 break;
246 }
247 }
248 }
249 }),
250 }
251 }
252
253 fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
254 self.worktrees
255 .iter()
256 .find_map(|(worktree_id, worktree_state)| match worktree_state {
257 WorktreeState::Registered(state) if state.db_id == id => Some(*worktree_id),
258 _ => None,
259 })
260 }
261}
262
/// A file queued for parsing and embedding.
#[derive(Clone)]
pub struct PendingFile {
    // Database id of the worktree containing the file.
    worktree_db_id: i64,
    relative_path: Arc<Path>,
    absolute_path: PathBuf,
    // Filled in later by `index_project_internal`; files without a detected
    // language are skipped by `parse_file`.
    language: Option<Arc<Language>>,
    modified_time: SystemTime,
    // Keeps this file counted in the project's pending-file count.
    job_handle: JobHandle,
}
272
/// A single semantic search hit: an anchored range in a buffer plus its
/// similarity score against the query embedding.
#[derive(Clone)]
pub struct SearchResult {
    pub buffer: Model<Buffer>,
    pub range: Range<Anchor>,
    pub similarity: OrderedFloat<f32>,
}
279
280impl SemanticIndex {
281 pub fn global(cx: &mut AppContext) -> Option<Model<SemanticIndex>> {
282 cx.try_global::<GlobalSemanticIndex>()
283 .map(|semantic_index| semantic_index.0.clone())
284 }
285
    /// Ensures the embedding provider has credentials, asking it to retrieve
    /// them if necessary. Resolves to whether credentials are now available.
    pub fn authenticate(&mut self, cx: &mut AppContext) -> Task<bool> {
        if !self.embedding_provider.has_credentials() {
            let embedding_provider = self.embedding_provider.clone();
            cx.spawn(|cx| async move {
                if let Some(retrieve_credentials) = cx
                    .update(|cx| embedding_provider.retrieve_credentials(cx))
                    .log_err()
                {
                    retrieve_credentials.await;
                }

                // Report whether retrieval actually produced credentials.
                embedding_provider.has_credentials()
            })
        } else {
            // Already authenticated; nothing to do.
            Task::ready(true)
        }
    }
303
    /// Whether the embedding provider currently has credentials.
    pub fn is_authenticated(&self) -> bool {
        self.embedding_provider.has_credentials()
    }
307
    /// Whether semantic indexing is enabled in the user's settings.
    pub fn enabled(cx: &AppContext) -> bool {
        SemanticIndexSettings::get_global(cx).enabled
    }
311
312 pub fn status(&self, project: &Model<Project>) -> SemanticIndexStatus {
313 if !self.is_authenticated() {
314 return SemanticIndexStatus::NotAuthenticated;
315 }
316
317 if let Some(project_state) = self.projects.get(&project.downgrade()) {
318 if project_state
319 .worktrees
320 .values()
321 .all(|worktree| worktree.is_registered())
322 && project_state.pending_index == 0
323 {
324 SemanticIndexStatus::Indexed
325 } else {
326 SemanticIndexStatus::Indexing {
327 remaining_files: project_state.pending_file_count_rx.borrow().clone(),
328 rate_limit_expiry: self.embedding_provider.rate_limit_expiration(),
329 }
330 }
331 } else {
332 SemanticIndexStatus::NotIndexed
333 }
334 }
335
    /// Builds the semantic index model: opens (or creates) the vector
    /// database at `database_path`, starts the task that persists finished
    /// embeddings, and spawns one parser worker per CPU that turns pending
    /// files into embeddable spans.
    pub async fn new(
        fs: Arc<dyn Fs>,
        database_path: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
        mut cx: AsyncAppContext,
    ) -> Result<Model<Self>> {
        let t0 = Instant::now();
        let database_path = Arc::from(database_path);
        let db = VectorDatabase::new(fs.clone(), database_path, cx.background_executor().clone())
            .await?;

        log::trace!(
            "db initialization took {:?} milliseconds",
            t0.elapsed().as_millis()
        );

        cx.new_model(|cx| {
            let t0 = Instant::now();
            let embedding_queue =
                EmbeddingQueue::new(embedding_provider.clone(), cx.background_executor().clone());
            // Persist files to the database as the embedding queue finishes
            // them; runs until the queue's channel closes.
            let _embedding_task = cx.background_executor().spawn({
                let embedded_files = embedding_queue.finished_files();
                let db = db.clone();
                async move {
                    while let Ok(file) = embedded_files.recv().await {
                        db.insert_file(file.worktree_id, file.path, file.mtime, file.spans)
                            .await
                            .log_err();
                    }
                }
            });

            // Parse files into embeddable spans.
            let (parsing_files_tx, parsing_files_rx) =
                channel::unbounded::<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>();
            let embedding_queue = Arc::new(Mutex::new(embedding_queue));
            let mut _parsing_files_tasks = Vec::new();
            for _ in 0..cx.background_executor().num_cpus() {
                let fs = fs.clone();
                let mut parsing_files_rx = parsing_files_rx.clone();
                let embedding_provider = embedding_provider.clone();
                let embedding_queue = embedding_queue.clone();
                let background = cx.background_executor().clone();
                _parsing_files_tasks.push(cx.background_executor().spawn(async move {
                    let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
                    loop {
                        // Prefer parsing the next file; if none arrives
                        // before the timeout, flush the embedding queue so
                        // partially-filled batches still get embedded.
                        let mut timer = background.timer(EMBEDDING_QUEUE_FLUSH_TIMEOUT).fuse();
                        let mut next_file_to_parse = parsing_files_rx.next().fuse();
                        futures::select_biased! {
                            next_file_to_parse = next_file_to_parse => {
                                if let Some((embeddings_for_digest, pending_file)) = next_file_to_parse {
                                    Self::parse_file(
                                        &fs,
                                        pending_file,
                                        &mut retriever,
                                        &embedding_queue,
                                        &embeddings_for_digest,
                                    )
                                    .await
                                } else {
                                    // Channel closed; worker shuts down.
                                    break;
                                }
                            },
                            _ = timer => {
                                embedding_queue.lock().flush();
                            }
                        }
                    }
                }));
            }

            log::trace!(
                "semantic index task initialization took {:?} milliseconds",
                t0.elapsed().as_millis()
            );
            Self {
                fs,
                db,
                embedding_provider,
                language_registry,
                parsing_files_tx,
                _embedding_task,
                _parsing_files_tasks,
                projects: Default::default(),
            }
        })
    }
424
    /// Loads `pending_file` from disk, parses it into embeddable spans,
    /// reuses any embeddings already known for a span's digest, and pushes
    /// the file onto the embedding queue.
    ///
    /// Files with no detected language are skipped.
    async fn parse_file(
        fs: &Arc<dyn Fs>,
        pending_file: PendingFile,
        retriever: &mut CodeContextRetriever,
        embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
        embeddings_for_digest: &HashMap<SpanDigest, Embedding>,
    ) {
        let Some(language) = pending_file.language else {
            return;
        };

        if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
            if let Some(mut spans) = retriever
                .parse_file_with_template(Some(&pending_file.relative_path), &content, language)
                .log_err()
            {
                log::trace!(
                    "parsed path {:?}: {} spans",
                    pending_file.relative_path,
                    spans.len()
                );

                // Reuse cached embeddings for unchanged spans so the queue
                // only has to embed spans with a new digest.
                for span in &mut spans {
                    if let Some(embedding) = embeddings_for_digest.get(&span.digest) {
                        span.embedding = Some(embedding.to_owned());
                    }
                }

                embedding_queue.lock().push(FileToEmbed {
                    worktree_id: pending_file.worktree_db_id,
                    path: pending_file.relative_path,
                    mtime: pending_file.modified_time,
                    job_handle: pending_file.job_handle,
                    spans,
                });
            }
        }
    }
463
464 pub fn project_previously_indexed(
465 &mut self,
466 project: &Model<Project>,
467 cx: &mut ModelContext<Self>,
468 ) -> Task<Result<bool>> {
469 let worktrees_indexed_previously = project
470 .read(cx)
471 .worktrees()
472 .map(|worktree| {
473 self.db
474 .worktree_previously_indexed(&worktree.read(cx).abs_path())
475 })
476 .collect::<Vec<_>>();
477 cx.spawn(|_, _cx| async move {
478 let worktree_indexed_previously =
479 futures::future::join_all(worktrees_indexed_previously).await;
480
481 Ok(worktree_indexed_previously
482 .iter()
483 .filter(|worktree| worktree.is_ok())
484 .all(|v| v.as_ref().log_err().is_some_and(|v| v.to_owned())))
485 })
486 }
487
488 fn project_entries_changed(
489 &mut self,
490 project: Model<Project>,
491 worktree_id: WorktreeId,
492 changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
493 cx: &mut ModelContext<Self>,
494 ) {
495 let Some(worktree) = project.read(cx).worktree_for_id(worktree_id.clone(), cx) else {
496 return;
497 };
498 let project = project.downgrade();
499 let Some(project_state) = self.projects.get_mut(&project) else {
500 return;
501 };
502
503 let worktree = worktree.read(cx);
504 let worktree_state =
505 if let Some(worktree_state) = project_state.worktrees.get_mut(&worktree_id) {
506 worktree_state
507 } else {
508 return;
509 };
510 worktree_state.paths_changed(changes, worktree);
511 if let WorktreeState::Registered(_) = worktree_state {
512 cx.spawn(|this, mut cx| async move {
513 cx.background_executor()
514 .timer(BACKGROUND_INDEXING_DELAY)
515 .await;
516 if let Some((this, project)) = this.upgrade().zip(project.upgrade()) {
517 this.update(&mut cx, |this, cx| {
518 this.index_project(project, cx).detach_and_log_err(cx)
519 })?;
520 }
521 anyhow::Ok(())
522 })
523 .detach_and_log_err(cx);
524 }
525 }
526
    /// Starts registering a local worktree of `project` with the index.
    ///
    /// A background task waits for the initial scan, finds or creates the
    /// worktree's database row, diffs on-disk files against stored mtimes to
    /// build the initial set of changed paths, then transitions the worktree
    /// from `Registering` to `Registered` and kicks off an indexing pass. On
    /// failure the worktree is dropped from tracking.
    fn register_worktree(
        &mut self,
        project: Model<Project>,
        worktree: Model<Worktree>,
        cx: &mut ModelContext<Self>,
    ) {
        let project = project.downgrade();
        let project_state = if let Some(project_state) = self.projects.get_mut(&project) {
            project_state
        } else {
            return;
        };
        // Only local worktrees can be indexed.
        let worktree = if let Some(worktree) = worktree.read(cx).as_local() {
            worktree
        } else {
            return;
        };
        let worktree_abs_path = worktree.abs_path().clone();
        let scan_complete = worktree.scan_complete();
        let worktree_id = worktree.id();
        let db = self.db.clone();
        let language_registry = self.language_registry.clone();
        let (mut done_tx, done_rx) = watch::channel();
        let registration = cx.spawn(|this, mut cx| {
            async move {
                let register = async {
                    // Wait for the worktree's initial scan before diffing.
                    scan_complete.await;
                    let db_id = db.find_or_create_worktree(worktree_abs_path).await?;
                    let mut file_mtimes = db.get_file_mtimes(db_id).await?;
                    let worktree = if let Some(project) = project.upgrade() {
                        project
                            .read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx))
                            .ok()
                            .flatten()
                            .context("worktree not found")?
                    } else {
                        // Project dropped while scanning; abort quietly.
                        return anyhow::Ok(());
                    };
                    let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot())?;
                    let mut changed_paths = cx
                        .background_executor()
                        .spawn(async move {
                            let mut changed_paths = BTreeMap::new();
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path)?;

                                if file.is_external || file.is_ignored || file.is_symlink {
                                    continue;
                                }

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    // Test if file is valid parseable file
                                    if !PARSEABLE_ENTIRE_FILE_TYPES
                                        .contains(&language.name().as_ref())
                                        && &language.name().as_ref() != &"Markdown"
                                        && language
                                            .grammar()
                                            .and_then(|grammar| grammar.embedding_config.as_ref())
                                            .is_none()
                                    {
                                        continue;
                                    }

                                    // A file whose stored mtime matches the
                                    // current one is already up to date.
                                    let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
                                    let already_stored = stored_mtime
                                        .map_or(false, |existing_mtime| {
                                            existing_mtime == file.mtime
                                        });

                                    if !already_stored {
                                        changed_paths.insert(
                                            file.path.clone(),
                                            ChangedPathInfo {
                                                mtime: file.mtime,
                                                is_deleted: false,
                                            },
                                        );
                                    }
                                }
                            }

                            // Clean up entries from database that are no longer in the worktree.
                            for (path, mtime) in file_mtimes {
                                changed_paths.insert(
                                    path.into(),
                                    ChangedPathInfo {
                                        mtime,
                                        is_deleted: true,
                                    },
                                );
                            }

                            anyhow::Ok(changed_paths)
                        })
                        .await?;
                    this.update(&mut cx, |this, cx| {
                        let project_state = this
                            .projects
                            .get_mut(&project)
                            .context("project not registered")?;
                        let project = project.upgrade().context("project was dropped")?;

                        // Fold in changes observed while registering.
                        if let Some(WorktreeState::Registering(state)) =
                            project_state.worktrees.remove(&worktree_id)
                        {
                            changed_paths.extend(state.changed_paths);
                        }
                        project_state.worktrees.insert(
                            worktree_id,
                            WorktreeState::Registered(RegisteredWorktreeState {
                                db_id,
                                changed_paths,
                            }),
                        );
                        this.index_project(project, cx).detach_and_log_err(cx);

                        anyhow::Ok(())
                    })??;

                    anyhow::Ok(())
                };

                if register.await.log_err().is_none() {
                    // Stop tracking this worktree if the registration failed.
                    this.update(&mut cx, |this, _| {
                        this.projects.get_mut(&project).map(|project_state| {
                            project_state.worktrees.remove(&worktree_id);
                        });
                    })
                    .ok();
                }

                // Signal `RegisteringWorktreeState::done` waiters either way.
                *done_tx.borrow_mut() = Some(());
            }
        });
        project_state.worktrees.insert(
            worktree_id,
            WorktreeState::Registering(RegisteringWorktreeState {
                changed_paths: Default::default(),
                done_rx,
                _registration: registration,
            }),
        );
    }
674
675 fn project_worktrees_changed(&mut self, project: Model<Project>, cx: &mut ModelContext<Self>) {
676 let project_state = if let Some(project_state) = self.projects.get_mut(&project.downgrade())
677 {
678 project_state
679 } else {
680 return;
681 };
682
683 let mut worktrees = project
684 .read(cx)
685 .worktrees()
686 .filter(|worktree| worktree.read(cx).is_local())
687 .collect::<Vec<_>>();
688 let worktree_ids = worktrees
689 .iter()
690 .map(|worktree| worktree.read(cx).id())
691 .collect::<HashSet<_>>();
692
693 // Remove worktrees that are no longer present
694 project_state
695 .worktrees
696 .retain(|worktree_id, _| worktree_ids.contains(worktree_id));
697
698 // Register new worktrees
699 worktrees.retain(|worktree| {
700 let worktree_id = worktree.read(cx).id();
701 !project_state.worktrees.contains_key(&worktree_id)
702 });
703 for worktree in worktrees {
704 self.register_worktree(project.clone(), worktree, cx);
705 }
706 }
707
708 pub fn pending_file_count(&self, project: &Model<Project>) -> Option<watch::Receiver<usize>> {
709 Some(
710 self.projects
711 .get(&project.downgrade())?
712 .pending_file_count_rx
713 .clone(),
714 )
715 }
716
    /// Semantic search entry point: indexes the project if needed, embeds
    /// `query`, searches both dirty open buffers and the on-disk index
    /// concurrently, then merges the two result sets — preferring the
    /// in-memory version of any modified buffer — into at most `limit`
    /// results sorted by descending similarity.
    pub fn search_project(
        &mut self,
        project: Model<Project>,
        query: String,
        limit: usize,
        includes: Vec<PathMatcher>,
        excludes: Vec<PathMatcher>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        if query.is_empty() {
            return Task::ready(Ok(Vec::new()));
        }

        let index = self.index_project(project.clone(), cx);
        let embedding_provider = self.embedding_provider.clone();

        cx.spawn(|this, mut cx| async move {
            // Make sure the index is up to date before searching.
            index.await?;
            let t0 = Instant::now();

            let query = embedding_provider
                .embed_batch(vec![query])
                .await?
                .pop()
                .context("could not embed query")?;
            log::trace!("Embedding Search Query: {:?}ms", t0.elapsed().as_millis());

            let search_start = Instant::now();
            let modified_buffer_results = this.update(&mut cx, |this, cx| {
                this.search_modified_buffers(
                    &project,
                    query.clone(),
                    limit,
                    &includes,
                    &excludes,
                    cx,
                )
            })?;
            let file_results = this.update(&mut cx, |this, cx| {
                this.search_files(project, query, limit, includes, excludes, cx)
            })?;
            let (modified_buffer_results, file_results) =
                futures::join!(modified_buffer_results, file_results);

            // Weave together the results from modified buffers and files.
            let mut results = Vec::new();
            let mut modified_buffers = HashSet::default();
            for result in modified_buffer_results.log_err().unwrap_or_default() {
                modified_buffers.insert(result.buffer.clone());
                results.push(result);
            }
            for result in file_results.log_err().unwrap_or_default() {
                // On-disk hits for a dirty buffer are stale; skip them.
                if !modified_buffers.contains(&result.buffer) {
                    results.push(result);
                }
            }
            results.sort_by_key(|result| Reverse(result.similarity));
            results.truncate(limit);
            log::trace!("Semantic search took {:?}", search_start.elapsed());
            Ok(results)
        })
    }
779
780 pub fn search_files(
781 &mut self,
782 project: Model<Project>,
783 query: Embedding,
784 limit: usize,
785 includes: Vec<PathMatcher>,
786 excludes: Vec<PathMatcher>,
787 cx: &mut ModelContext<Self>,
788 ) -> Task<Result<Vec<SearchResult>>> {
789 let db_path = self.db.path().clone();
790 let fs = self.fs.clone();
791 cx.spawn(|this, mut cx| async move {
792 let database = VectorDatabase::new(
793 fs.clone(),
794 db_path.clone(),
795 cx.background_executor().clone(),
796 )
797 .await?;
798
799 let worktree_db_ids = this.read_with(&cx, |this, _| {
800 let project_state = this
801 .projects
802 .get(&project.downgrade())
803 .context("project was not indexed")?;
804 let worktree_db_ids = project_state
805 .worktrees
806 .values()
807 .filter_map(|worktree| {
808 if let WorktreeState::Registered(worktree) = worktree {
809 Some(worktree.db_id)
810 } else {
811 None
812 }
813 })
814 .collect::<Vec<i64>>();
815 anyhow::Ok(worktree_db_ids)
816 })??;
817
818 let file_ids = database
819 .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
820 .await?;
821
822 let batch_n = cx.background_executor().num_cpus();
823 let ids_len = file_ids.clone().len();
824 let minimum_batch_size = 50;
825
826 let batch_size = {
827 let size = ids_len / batch_n;
828 if size < minimum_batch_size {
829 minimum_batch_size
830 } else {
831 size
832 }
833 };
834
835 let mut batch_results = Vec::new();
836 for batch in file_ids.chunks(batch_size) {
837 let batch = batch.into_iter().map(|v| *v).collect::<Vec<i64>>();
838 let limit = limit.clone();
839 let fs = fs.clone();
840 let db_path = db_path.clone();
841 let query = query.clone();
842 if let Some(db) =
843 VectorDatabase::new(fs, db_path.clone(), cx.background_executor().clone())
844 .await
845 .log_err()
846 {
847 batch_results.push(async move {
848 db.top_k_search(&query, limit, batch.as_slice()).await
849 });
850 }
851 }
852
853 let batch_results = futures::future::join_all(batch_results).await;
854
855 let mut results = Vec::new();
856 for batch_result in batch_results {
857 if batch_result.is_ok() {
858 for (id, similarity) in batch_result.unwrap() {
859 let ix = match results
860 .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
861 {
862 Ok(ix) => ix,
863 Err(ix) => ix,
864 };
865
866 results.insert(ix, (id, similarity));
867 results.truncate(limit);
868 }
869 }
870 }
871
872 let ids = results.iter().map(|(id, _)| *id).collect::<Vec<i64>>();
873 let scores = results
874 .into_iter()
875 .map(|(_, score)| score)
876 .collect::<Vec<_>>();
877 let spans = database.spans_for_ids(ids.as_slice()).await?;
878
879 let mut tasks = Vec::new();
880 let mut ranges = Vec::new();
881 let weak_project = project.downgrade();
882 project.update(&mut cx, |project, cx| {
883 let this = this.upgrade().context("index was dropped")?;
884 for (worktree_db_id, file_path, byte_range) in spans {
885 let project_state =
886 if let Some(state) = this.read(cx).projects.get(&weak_project) {
887 state
888 } else {
889 return Err(anyhow!("project not added"));
890 };
891 if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
892 tasks.push(project.open_buffer((worktree_id, file_path), cx));
893 ranges.push(byte_range);
894 }
895 }
896
897 Ok(())
898 })??;
899
900 let buffers = futures::future::join_all(tasks).await;
901 Ok(buffers
902 .into_iter()
903 .zip(ranges)
904 .zip(scores)
905 .filter_map(|((buffer, range), similarity)| {
906 let buffer = buffer.log_err()?;
907 let range = buffer
908 .read_with(&cx, |buffer, _| {
909 let start = buffer.clip_offset(range.start, Bias::Left);
910 let end = buffer.clip_offset(range.end, Bias::Right);
911 buffer.anchor_before(start)..buffer.anchor_after(end)
912 })
913 .log_err()?;
914 Some(SearchResult {
915 buffer,
916 range,
917 similarity,
918 })
919 })
920 .collect())
921 })
922 }
923
924 fn search_modified_buffers(
925 &self,
926 project: &Model<Project>,
927 query: Embedding,
928 limit: usize,
929 includes: &[PathMatcher],
930 excludes: &[PathMatcher],
931 cx: &mut ModelContext<Self>,
932 ) -> Task<Result<Vec<SearchResult>>> {
933 let modified_buffers = project
934 .read(cx)
935 .opened_buffers()
936 .into_iter()
937 .filter_map(|buffer_handle| {
938 let buffer = buffer_handle.read(cx);
939 let snapshot = buffer.snapshot();
940 let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
941 excludes.iter().any(|matcher| matcher.is_match(&path))
942 });
943
944 let included = if includes.len() == 0 {
945 true
946 } else {
947 snapshot.resolve_file_path(cx, false).map_or(false, |path| {
948 includes.iter().any(|matcher| matcher.is_match(&path))
949 })
950 };
951
952 if buffer.is_dirty() && !excluded && included {
953 Some((buffer_handle, snapshot))
954 } else {
955 None
956 }
957 })
958 .collect::<HashMap<_, _>>();
959
960 let embedding_provider = self.embedding_provider.clone();
961 let fs = self.fs.clone();
962 let db_path = self.db.path().clone();
963 let background = cx.background_executor().clone();
964 cx.background_executor().spawn(async move {
965 let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
966 let mut results = Vec::<SearchResult>::new();
967
968 let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
969 for (buffer, snapshot) in modified_buffers {
970 let language = snapshot
971 .language_at(0)
972 .cloned()
973 .unwrap_or_else(|| language::PLAIN_TEXT.clone());
974 let mut spans = retriever
975 .parse_file_with_template(None, &snapshot.text(), language)
976 .log_err()
977 .unwrap_or_default();
978 if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
979 .await
980 .log_err()
981 .is_some()
982 {
983 for span in spans {
984 let similarity = span.embedding.unwrap().similarity(&query);
985 let ix = match results
986 .binary_search_by_key(&Reverse(similarity), |result| {
987 Reverse(result.similarity)
988 }) {
989 Ok(ix) => ix,
990 Err(ix) => ix,
991 };
992
993 let range = {
994 let start = snapshot.clip_offset(span.range.start, Bias::Left);
995 let end = snapshot.clip_offset(span.range.end, Bias::Right);
996 snapshot.anchor_before(start)..snapshot.anchor_after(end)
997 };
998
999 results.insert(
1000 ix,
1001 SearchResult {
1002 buffer: buffer.clone(),
1003 range,
1004 similarity,
1005 },
1006 );
1007 results.truncate(limit);
1008 }
1009 }
1010 }
1011
1012 Ok(results)
1013 })
1014 }
1015
1016 pub fn index_project(
1017 &mut self,
1018 project: Model<Project>,
1019 cx: &mut ModelContext<Self>,
1020 ) -> Task<Result<()>> {
1021 if self.is_authenticated() {
1022 self.index_project_internal(project, cx)
1023 } else {
1024 let authenticate = self.authenticate(cx);
1025 cx.spawn(|this, mut cx| async move {
1026 if authenticate.await {
1027 this.update(&mut cx, |this, cx| this.index_project_internal(project, cx))?
1028 .await
1029 } else {
1030 Err(anyhow!("user is not authenticated"))
1031 }
1032 })
1033 }
1034 }
1035
    /// Performs one indexing pass over `project`: registers it (and its
    /// worktrees) on first use, waits for worktree registration, deletes
    /// removed files from the database, and feeds every changed file to the
    /// parser workers, resolving once the pending-file count returns to zero.
    fn index_project_internal(
        &mut self,
        project: Model<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        // First-time setup: subscribe to project events and register the
        // project's current worktrees.
        if !self.projects.contains_key(&project.downgrade()) {
            let subscription = cx.subscribe(&project, |this, project, event, cx| match event {
                project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
                    this.project_worktrees_changed(project.clone(), cx);
                }
                project::Event::WorktreeUpdatedEntries(worktree_id, changes) => {
                    this.project_entries_changed(project, *worktree_id, changes.clone(), cx);
                }
                _ => {}
            });
            let project_state = ProjectState::new(subscription, cx);
            self.projects.insert(project.downgrade(), project_state);
            self.project_worktrees_changed(project.clone(), cx);
        }
        // Present by construction: inserted just above if it was missing.
        let project_state = self.projects.get_mut(&project.downgrade()).unwrap();
        project_state.pending_index += 1;
        cx.notify();

        let mut pending_file_count_rx = project_state.pending_file_count_rx.clone();
        let db = self.db.clone();
        let language_registry = self.language_registry.clone();
        let parsing_files_tx = self.parsing_files_tx.clone();
        let worktree_registration = self.wait_for_worktree_registration(&project, cx);

        cx.spawn(|this, mut cx| async move {
            worktree_registration.await?;

            // Drain each registered worktree's accumulated path changes into
            // deletion and (re-)indexing work lists.
            let mut pending_files = Vec::new();
            let mut files_to_delete = Vec::new();
            this.update(&mut cx, |this, cx| {
                let project_state = this
                    .projects
                    .get_mut(&project.downgrade())
                    .context("project was dropped")?;
                let pending_file_count_tx = &project_state.pending_file_count_tx;

                project_state
                    .worktrees
                    .retain(|worktree_id, worktree_state| {
                        // Drop state for worktrees the project no longer has.
                        let worktree = if let Some(worktree) =
                            project.read(cx).worktree_for_id(*worktree_id, cx)
                        {
                            worktree
                        } else {
                            return false;
                        };
                        // Still-registering worktrees are kept untouched.
                        let worktree_state =
                            if let WorktreeState::Registered(worktree_state) = worktree_state {
                                worktree_state
                            } else {
                                return true;
                            };

                        for (path, info) in &worktree_state.changed_paths {
                            if info.is_deleted {
                                files_to_delete.push((worktree_state.db_id, path.clone()));
                            } else if let Ok(absolute_path) = worktree.read(cx).absolutize(path) {
                                let job_handle = JobHandle::new(pending_file_count_tx);
                                pending_files.push(PendingFile {
                                    absolute_path,
                                    relative_path: path.clone(),
                                    language: None,
                                    job_handle,
                                    modified_time: info.mtime,
                                    worktree_db_id: worktree_state.db_id,
                                });
                            }
                        }
                        worktree_state.changed_paths.clear();
                        true
                    });

                anyhow::Ok(())
            })??;

            cx.background_executor()
                .spawn(async move {
                    for (worktree_db_id, path) in files_to_delete {
                        db.delete_file(worktree_db_id, path).await.log_err();
                    }

                    // Fetch embeddings we already have for these paths so
                    // unchanged spans aren't re-embedded.
                    let embeddings_for_digest = {
                        let mut files = HashMap::default();
                        for pending_file in &pending_files {
                            files
                                .entry(pending_file.worktree_db_id)
                                .or_insert(Vec::new())
                                .push(pending_file.relative_path.clone());
                        }
                        Arc::new(
                            db.embeddings_for_files(files)
                                .await
                                .log_err()
                                .unwrap_or_default(),
                        )
                    };

                    for mut pending_file in pending_files {
                        if let Ok(language) = language_registry
                            .language_for_file(&pending_file.relative_path, None)
                            .await
                        {
                            // Skip languages with no embedding support: not a
                            // whole-file-parseable type, not Markdown, and no
                            // grammar-level embedding config.
                            if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
                                && &language.name().as_ref() != &"Markdown"
                                && language
                                    .grammar()
                                    .and_then(|grammar| grammar.embedding_config.as_ref())
                                    .is_none()
                            {
                                continue;
                            }
                            pending_file.language = Some(language);
                        }
                        parsing_files_tx
                            .try_send((embeddings_for_digest.clone(), pending_file))
                            .ok();
                    }

                    // Wait until we're done indexing.
                    while let Some(count) = pending_file_count_rx.next().await {
                        if count == 0 {
                            break;
                        }
                    }
                })
                .await;

            this.update(&mut cx, |this, cx| {
                let project_state = this
                    .projects
                    .get_mut(&project.downgrade())
                    .context("project was dropped")?;
                project_state.pending_index -= 1;
                cx.notify();
                anyhow::Ok(())
            })??;

            Ok(())
        })
    }
1181
1182 fn wait_for_worktree_registration(
1183 &self,
1184 project: &Model<Project>,
1185 cx: &mut ModelContext<Self>,
1186 ) -> Task<Result<()>> {
1187 let project = project.downgrade();
1188 cx.spawn(|this, cx| async move {
1189 loop {
1190 let mut pending_worktrees = Vec::new();
1191 this.upgrade()
1192 .context("semantic index dropped")?
1193 .read_with(&cx, |this, _| {
1194 if let Some(project) = this.projects.get(&project) {
1195 for worktree in project.worktrees.values() {
1196 if let WorktreeState::Registering(worktree) = worktree {
1197 pending_worktrees.push(worktree.done());
1198 }
1199 }
1200 }
1201 })?;
1202
1203 if pending_worktrees.is_empty() {
1204 break;
1205 } else {
1206 future::join_all(pending_worktrees).await;
1207 }
1208 }
1209 Ok(())
1210 })
1211 }
1212
1213 async fn embed_spans(
1214 spans: &mut [Span],
1215 embedding_provider: &dyn EmbeddingProvider,
1216 db: &VectorDatabase,
1217 ) -> Result<()> {
1218 let mut batch = Vec::new();
1219 let mut batch_tokens = 0;
1220 let mut embeddings = Vec::new();
1221
1222 let digests = spans
1223 .iter()
1224 .map(|span| span.digest.clone())
1225 .collect::<Vec<_>>();
1226 let embeddings_for_digests = db
1227 .embeddings_for_digests(digests)
1228 .await
1229 .log_err()
1230 .unwrap_or_default();
1231
1232 for span in &*spans {
1233 if embeddings_for_digests.contains_key(&span.digest) {
1234 continue;
1235 };
1236
1237 if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
1238 let batch_embeddings = embedding_provider
1239 .embed_batch(mem::take(&mut batch))
1240 .await?;
1241 embeddings.extend(batch_embeddings);
1242 batch_tokens = 0;
1243 }
1244
1245 batch_tokens += span.token_count;
1246 batch.push(span.content.clone());
1247 }
1248
1249 if !batch.is_empty() {
1250 let batch_embeddings = embedding_provider
1251 .embed_batch(mem::take(&mut batch))
1252 .await?;
1253
1254 embeddings.extend(batch_embeddings);
1255 }
1256
1257 let mut embeddings = embeddings.into_iter();
1258 for span in spans {
1259 let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) {
1260 Some(embedding.clone())
1261 } else {
1262 embeddings.next()
1263 };
1264 let embedding = embedding.context("failed to embed spans")?;
1265 span.embedding = Some(embedding);
1266 }
1267 Ok(())
1268 }
1269}
1270
1271impl Drop for JobHandle {
1272 fn drop(&mut self) {
1273 if let Some(inner) = Arc::get_mut(&mut self.tx) {
1274 // This is the last instance of the JobHandle (regardless of its origin - whether it was cloned or not)
1275 if let Some(tx) = inner.upgrade() {
1276 let mut tx = tx.lock();
1277 *tx.borrow_mut() -= 1;
1278 }
1279 }
1280 }
1281}
1282
#[cfg(test)]
mod tests {
    use super::*;

    /// The pending-file counter must reach zero only after the *last* clone
    /// of a `JobHandle` is dropped, not after each individual drop.
    #[test]
    fn test_job_handle() {
        let (count_tx, count_rx) = watch::channel_with(0);
        let shared_tx = Arc::new(Mutex::new(count_tx));

        // Creating the handle bumps the counter once.
        let first = JobHandle::new(&shared_tx);
        assert_eq!(*count_rx.borrow(), 1);

        // Cloning shares the same job; the counter stays put.
        let second = first.clone();
        assert_eq!(*count_rx.borrow(), 1);

        // Dropping one of two clones must not decrement.
        drop(first);
        assert_eq!(*count_rx.borrow(), 1);

        // Dropping the final clone releases the job.
        drop(second);
        assert_eq!(*count_rx.borrow(), 0);
    }
}