1mod db;
2mod embedding_queue;
3mod parsing;
4pub mod semantic_index_settings;
5
6#[cfg(test)]
7mod semantic_index_tests;
8
9use crate::semantic_index_settings::SemanticIndexSettings;
10use ai::embedding::{Embedding, EmbeddingProvider};
11use ai::providers::open_ai::OpenAIEmbeddingProvider;
12use anyhow::{anyhow, Context as _, Result};
13use collections::{BTreeMap, HashMap, HashSet};
14use db::VectorDatabase;
15use embedding_queue::{EmbeddingQueue, FileToEmbed};
16use futures::{future, FutureExt, StreamExt};
17use gpui::{
18 AppContext, AsyncAppContext, BorrowWindow, Context, Model, ModelContext, Task, ViewContext,
19 WeakModel,
20};
21use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
22use lazy_static::lazy_static;
23use ordered_float::OrderedFloat;
24use parking_lot::Mutex;
25use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
26use postage::watch;
27use project::{Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
28use settings::Settings;
29use smol::channel;
30use std::{
31 cmp::Reverse,
32 env,
33 future::Future,
34 mem,
35 ops::Range,
36 path::{Path, PathBuf},
37 sync::{Arc, Weak},
38 time::{Duration, Instant, SystemTime},
39};
40use util::paths::PathMatcher;
41use util::{channel::RELEASE_CHANNEL_NAME, http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt};
42use workspace::Workspace;
43
/// Version number of the semantic index.
/// NOTE(review): not referenced in this part of the file; presumably used to
/// invalidate databases written by older versions — confirm at its use site.
const SEMANTIC_INDEX_VERSION: usize = 11;
/// Delay between a change in a registered worktree and the re-index it triggers
/// (five minutes).
const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(300);
/// How long a parsing worker waits for the next file before flushing the
/// embedding queue.
const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);
47
48lazy_static! {
49 static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
50}
51
52pub fn init(
53 fs: Arc<dyn Fs>,
54 http_client: Arc<dyn HttpClient>,
55 language_registry: Arc<LanguageRegistry>,
56 cx: &mut AppContext,
57) {
58 SemanticIndexSettings::register(cx);
59
60 let db_file_path = EMBEDDINGS_DIR
61 .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
62 .join("embeddings_db");
63
64 cx.observe_new_views(
65 |workspace: &mut Workspace, cx: &mut ViewContext<Workspace>| {
66 let Some(semantic_index) = SemanticIndex::global(cx) else {
67 return;
68 };
69 let project = workspace.project().clone();
70
71 if project.read(cx).is_local() {
72 cx.app_mut()
73 .spawn(|mut cx| async move {
74 let previously_indexed = semantic_index
75 .update(&mut cx, |index, cx| {
76 index.project_previously_indexed(&project, cx)
77 })?
78 .await?;
79 if previously_indexed {
80 semantic_index
81 .update(&mut cx, |index, cx| index.index_project(project, cx))?
82 .await?;
83 }
84 anyhow::Ok(())
85 })
86 .detach_and_log_err(cx);
87 }
88 },
89 )
90 .detach();
91
92 cx.spawn(move |cx| async move {
93 let semantic_index = SemanticIndex::new(
94 fs,
95 db_file_path,
96 Arc::new(OpenAIEmbeddingProvider::new(
97 http_client,
98 cx.background_executor().clone(),
99 )),
100 language_registry,
101 cx.clone(),
102 )
103 .await?;
104
105 cx.update(|cx| cx.set_global(semantic_index.clone()))?;
106
107 anyhow::Ok(())
108 })
109 .detach();
110}
111
/// Indexing state of a project, as reported by `SemanticIndex::status`.
#[derive(Clone, Copy, Debug)]
pub enum SemanticIndexStatus {
    /// The embedding provider has no credentials.
    NotAuthenticated,
    /// The project has never been passed to `SemanticIndex::index_project`.
    NotIndexed,
    /// Every worktree is registered and no indexing pass is running.
    Indexed,
    /// Worktrees are still registering or an indexing pass is in flight.
    Indexing {
        /// Number of files whose embedding job is still outstanding.
        remaining_files: usize,
        /// When the embedding provider's current rate limit expires, if any.
        rate_limit_expiry: Option<Instant>,
    },
}
122
/// Builds and queries an embedding index of code spans for local projects.
///
/// Files are parsed into spans by background workers, embedded through
/// `embedding_provider`, and persisted in `db` (see `SemanticIndex::new`).
pub struct SemanticIndex {
    fs: Arc<dyn Fs>,
    db: VectorDatabase,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    /// Feeds the parsing workers: each file is paired with the embeddings already
    /// stored for the span digests of the files being indexed.
    parsing_files_tx: channel::Sender<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>,
    /// Writes every file finished by the embedding queue to `db`.
    _embedding_task: Task<()>,
    /// One parsing worker per CPU.
    _parsing_files_tasks: Vec<Task<()>>,
    /// Per-project state, keyed by weak project handle.
    projects: HashMap<WeakModel<Project>, ProjectState>,
}
133
/// Indexing state the [`SemanticIndex`] keeps for one project.
struct ProjectState {
    /// Registration and change-tracking state of each local worktree.
    worktrees: HashMap<WorktreeId, WorktreeState>,
    /// Number of files whose embedding job ([`JobHandle`]) is still outstanding.
    pending_file_count_rx: watch::Receiver<usize>,
    /// Counter shared with every [`JobHandle`] created for this project.
    pending_file_count_tx: Arc<Mutex<watch::Sender<usize>>>,
    /// Number of `index_project` passes currently in flight for this project.
    pending_index: usize,
    /// Subscription to the project's worktree events.
    _subscription: gpui::Subscription,
    /// Notifies observers of the index whenever the pending-file count changes.
    _observe_pending_file_count: Task<()>,
}
142
143enum WorktreeState {
144 Registering(RegisteringWorktreeState),
145 Registered(RegisteredWorktreeState),
146}
147
148impl WorktreeState {
149 fn is_registered(&self) -> bool {
150 matches!(self, Self::Registered(_))
151 }
152
153 fn paths_changed(
154 &mut self,
155 changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
156 worktree: &Worktree,
157 ) {
158 let changed_paths = match self {
159 Self::Registering(state) => &mut state.changed_paths,
160 Self::Registered(state) => &mut state.changed_paths,
161 };
162
163 for (path, entry_id, change) in changes.iter() {
164 let Some(entry) = worktree.entry_for_id(*entry_id) else {
165 continue;
166 };
167 if entry.is_ignored || entry.is_symlink || entry.is_external || entry.is_dir() {
168 continue;
169 }
170 changed_paths.insert(
171 path.clone(),
172 ChangedPathInfo {
173 mtime: entry.mtime,
174 is_deleted: *change == PathChange::Removed,
175 },
176 );
177 }
178 }
179}
180
181struct RegisteringWorktreeState {
182 changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
183 done_rx: watch::Receiver<Option<()>>,
184 _registration: Task<()>,
185}
186
187impl RegisteringWorktreeState {
188 fn done(&self) -> impl Future<Output = ()> {
189 let mut done_rx = self.done_rx.clone();
190 async move {
191 while let Some(result) = done_rx.next().await {
192 if result.is_some() {
193 break;
194 }
195 }
196 }
197 }
198}
199
/// A worktree that has been matched to a row in the vector database.
struct RegisteredWorktreeState {
    /// Id of this worktree in the vector database.
    db_id: i64,
    /// Paths changed since the last indexing pass; cleared by `index_project`.
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
}

/// What is known about a path that needs to be re-embedded or removed.
struct ChangedPathInfo {
    /// Modification time recorded for the path.
    mtime: SystemTime,
    /// `true` if the path should be deleted from the database instead of re-embedded.
    is_deleted: bool,
}
209
210#[derive(Clone)]
211pub struct JobHandle {
212 /// The outer Arc is here to count the clones of a JobHandle instance;
213 /// when the last handle to a given job is dropped, we decrement a counter (just once).
214 tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
215}
216
217impl JobHandle {
218 fn new(tx: &Arc<Mutex<watch::Sender<usize>>>) -> Self {
219 *tx.lock().borrow_mut() += 1;
220 Self {
221 tx: Arc::new(Arc::downgrade(&tx)),
222 }
223 }
224}
225
226impl ProjectState {
227 fn new(subscription: gpui::Subscription, cx: &mut ModelContext<SemanticIndex>) -> Self {
228 let (pending_file_count_tx, pending_file_count_rx) = watch::channel_with(0);
229 let pending_file_count_tx = Arc::new(Mutex::new(pending_file_count_tx));
230 Self {
231 worktrees: Default::default(),
232 pending_file_count_rx: pending_file_count_rx.clone(),
233 pending_file_count_tx,
234 pending_index: 0,
235 _subscription: subscription,
236 _observe_pending_file_count: cx.spawn({
237 let mut pending_file_count_rx = pending_file_count_rx.clone();
238 |this, mut cx| async move {
239 while let Some(_) = pending_file_count_rx.next().await {
240 if this.update(&mut cx, |_, cx| cx.notify()).is_err() {
241 break;
242 }
243 }
244 }
245 }),
246 }
247 }
248
249 fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
250 self.worktrees
251 .iter()
252 .find_map(|(worktree_id, worktree_state)| match worktree_state {
253 WorktreeState::Registered(state) if state.db_id == id => Some(*worktree_id),
254 _ => None,
255 })
256 }
257}
258
/// A file queued for parsing and embedding.
#[derive(Clone)]
pub struct PendingFile {
    /// Database id of the worktree containing the file.
    worktree_db_id: i64,
    /// Path of the file relative to its worktree root.
    relative_path: Arc<Path>,
    /// Absolute path used to load the file's contents.
    absolute_path: PathBuf,
    /// Language used to parse the file; `parse_file` skips files without one.
    language: Option<Arc<Language>>,
    /// Modification time stored alongside the file's embeddings.
    modified_time: SystemTime,
    /// Keeps the project's pending-file count incremented until this handle and
    /// all its clones are dropped.
    job_handle: JobHandle,
}
268
/// One semantic-search hit.
#[derive(Clone)]
pub struct SearchResult {
    /// Buffer containing the matching span.
    pub buffer: Model<Buffer>,
    /// Anchored range of the matching span within `buffer`.
    pub range: Range<Anchor>,
    /// Similarity between the span and the query; results are ranked highest first.
    pub similarity: OrderedFloat<f32>,
}
275
impl SemanticIndex {
    /// Returns the global [`SemanticIndex`] model, or `None` when none has been
    /// registered yet (e.g. `init` has not finished creating it).
    pub fn global(cx: &mut AppContext) -> Option<Model<SemanticIndex>> {
        cx.has_global::<Model<Self>>()
            .then(|| cx.global::<Model<SemanticIndex>>().clone())
    }
284
285 pub fn authenticate(&mut self, cx: &mut AppContext) -> bool {
286 if !self.embedding_provider.has_credentials() {
287 self.embedding_provider.retrieve_credentials(cx);
288 } else {
289 return true;
290 }
291
292 self.embedding_provider.has_credentials()
293 }
294
295 pub fn is_authenticated(&self) -> bool {
296 self.embedding_provider.has_credentials()
297 }
298
299 pub fn enabled(cx: &AppContext) -> bool {
300 SemanticIndexSettings::get_global(cx).enabled
301 }
302
303 pub fn status(&self, project: &Model<Project>) -> SemanticIndexStatus {
304 if !self.is_authenticated() {
305 return SemanticIndexStatus::NotAuthenticated;
306 }
307
308 if let Some(project_state) = self.projects.get(&project.downgrade()) {
309 if project_state
310 .worktrees
311 .values()
312 .all(|worktree| worktree.is_registered())
313 && project_state.pending_index == 0
314 {
315 SemanticIndexStatus::Indexed
316 } else {
317 SemanticIndexStatus::Indexing {
318 remaining_files: project_state.pending_file_count_rx.borrow().clone(),
319 rate_limit_expiry: self.embedding_provider.rate_limit_expiration(),
320 }
321 }
322 } else {
323 SemanticIndexStatus::NotIndexed
324 }
325 }
326
/// Opens (or creates) the vector database at `database_path` and builds the index model.
///
/// Besides the model, this spawns:
/// - one background task that writes every file finished by the embedding queue to
///   the database;
/// - one parsing worker per CPU. Each worker pulls files from the parsing channel,
///   parses them into spans and pushes them onto the shared embedding queue. When no
///   file arrives within `EMBEDDING_QUEUE_FLUSH_TIMEOUT`, it flushes the queue instead.
pub async fn new(
    fs: Arc<dyn Fs>,
    database_path: PathBuf,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    mut cx: AsyncAppContext,
) -> Result<Model<Self>> {
    let t0 = Instant::now();
    let database_path = Arc::from(database_path);
    let db = VectorDatabase::new(fs.clone(), database_path, cx.background_executor().clone())
        .await?;

    log::trace!(
        "db initialization took {:?} milliseconds",
        t0.elapsed().as_millis()
    );

    cx.new_model(|cx| {
        let t0 = Instant::now();
        let embedding_queue =
            EmbeddingQueue::new(embedding_provider.clone(), cx.background_executor().clone());
        // Persist each file as soon as the embedding queue reports it finished.
        let _embedding_task = cx.background_executor().spawn({
            let embedded_files = embedding_queue.finished_files();
            let db = db.clone();
            async move {
                while let Ok(file) = embedded_files.recv().await {
                    db.insert_file(file.worktree_id, file.path, file.mtime, file.spans)
                        .await
                        .log_err();
                }
            }
        });

        // Parse files into embeddable spans.
        let (parsing_files_tx, parsing_files_rx) =
            channel::unbounded::<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>();
        let embedding_queue = Arc::new(Mutex::new(embedding_queue));
        let mut _parsing_files_tasks = Vec::new();
        for _ in 0..cx.background_executor().num_cpus() {
            let fs = fs.clone();
            let mut parsing_files_rx = parsing_files_rx.clone();
            let embedding_provider = embedding_provider.clone();
            let embedding_queue = embedding_queue.clone();
            let background = cx.background_executor().clone();
            _parsing_files_tasks.push(cx.background_executor().spawn(async move {
                let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
                loop {
                    // A fresh timer each iteration: the queue is flushed only after this
                    // worker has gone EMBEDDING_QUEUE_FLUSH_TIMEOUT without a new file.
                    let mut timer = background.timer(EMBEDDING_QUEUE_FLUSH_TIMEOUT).fuse();
                    let mut next_file_to_parse = parsing_files_rx.next().fuse();
                    // `select_biased!` polls its branches in order, so an available
                    // file takes priority over an expired timer.
                    futures::select_biased! {
                        next_file_to_parse = next_file_to_parse => {
                            if let Some((embeddings_for_digest, pending_file)) = next_file_to_parse {
                                Self::parse_file(
                                    &fs,
                                    pending_file,
                                    &mut retriever,
                                    &embedding_queue,
                                    &embeddings_for_digest,
                                )
                                .await
                            } else {
                                // The parsing channel has closed; stop this worker.
                                break;
                            }
                        },
                        _ = timer => {
                            embedding_queue.lock().flush();
                        }
                    }
                }
            }));
        }

        log::trace!(
            "semantic index task initialization took {:?} milliseconds",
            t0.elapsed().as_millis()
        );
        Self {
            fs,
            db,
            embedding_provider,
            language_registry,
            parsing_files_tx,
            _embedding_task,
            _parsing_files_tasks,
            projects: Default::default(),
        }
    })
}
415
416 async fn parse_file(
417 fs: &Arc<dyn Fs>,
418 pending_file: PendingFile,
419 retriever: &mut CodeContextRetriever,
420 embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
421 embeddings_for_digest: &HashMap<SpanDigest, Embedding>,
422 ) {
423 let Some(language) = pending_file.language else {
424 return;
425 };
426
427 if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
428 if let Some(mut spans) = retriever
429 .parse_file_with_template(Some(&pending_file.relative_path), &content, language)
430 .log_err()
431 {
432 log::trace!(
433 "parsed path {:?}: {} spans",
434 pending_file.relative_path,
435 spans.len()
436 );
437
438 for span in &mut spans {
439 if let Some(embedding) = embeddings_for_digest.get(&span.digest) {
440 span.embedding = Some(embedding.to_owned());
441 }
442 }
443
444 embedding_queue.lock().push(FileToEmbed {
445 worktree_id: pending_file.worktree_db_id,
446 path: pending_file.relative_path,
447 mtime: pending_file.modified_time,
448 job_handle: pending_file.job_handle,
449 spans,
450 });
451 }
452 }
453 }
454
455 pub fn project_previously_indexed(
456 &mut self,
457 project: &Model<Project>,
458 cx: &mut ModelContext<Self>,
459 ) -> Task<Result<bool>> {
460 let worktrees_indexed_previously = project
461 .read(cx)
462 .worktrees()
463 .map(|worktree| {
464 self.db
465 .worktree_previously_indexed(&worktree.read(cx).abs_path())
466 })
467 .collect::<Vec<_>>();
468 cx.spawn(|_, _cx| async move {
469 let worktree_indexed_previously =
470 futures::future::join_all(worktrees_indexed_previously).await;
471
472 Ok(worktree_indexed_previously
473 .iter()
474 .filter(|worktree| worktree.is_ok())
475 .all(|v| v.as_ref().log_err().is_some_and(|v| v.to_owned())))
476 })
477 }
478
479 fn project_entries_changed(
480 &mut self,
481 project: Model<Project>,
482 worktree_id: WorktreeId,
483 changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
484 cx: &mut ModelContext<Self>,
485 ) {
486 let Some(worktree) = project.read(cx).worktree_for_id(worktree_id.clone(), cx) else {
487 return;
488 };
489 let project = project.downgrade();
490 let Some(project_state) = self.projects.get_mut(&project) else {
491 return;
492 };
493
494 let worktree = worktree.read(cx);
495 let worktree_state =
496 if let Some(worktree_state) = project_state.worktrees.get_mut(&worktree_id) {
497 worktree_state
498 } else {
499 return;
500 };
501 worktree_state.paths_changed(changes, worktree);
502 if let WorktreeState::Registered(_) = worktree_state {
503 cx.spawn(|this, mut cx| async move {
504 cx.background_executor()
505 .timer(BACKGROUND_INDEXING_DELAY)
506 .await;
507 if let Some((this, project)) = this.upgrade().zip(project.upgrade()) {
508 this.update(&mut cx, |this, cx| {
509 this.index_project(project, cx).detach_and_log_err(cx)
510 })?;
511 }
512 anyhow::Ok(())
513 })
514 .detach_and_log_err(cx);
515 }
516 }
517
/// Starts registering the local `worktree` of `project` with the index.
///
/// The spawned registration task:
/// - waits for the worktree's initial scan;
/// - finds or creates the worktree's database row;
/// - diffs the worktree's indexable files against the mtimes stored in the database;
/// - marks every stored file that is no longer among those indexable files as deleted.
///
/// Changes that arrive meanwhile are merged in, the worktree becomes `Registered`, and
/// an indexing pass is started. If registration fails, the worktree is dropped from the
/// project's state. Either way `done_rx` is signalled at the end.
///
/// Does nothing if the project is not tracked or the worktree is not local.
fn register_worktree(
    &mut self,
    project: Model<Project>,
    worktree: Model<Worktree>,
    cx: &mut ModelContext<Self>,
) {
    let project = project.downgrade();
    let project_state = if let Some(project_state) = self.projects.get_mut(&project) {
        project_state
    } else {
        return;
    };
    let worktree = if let Some(worktree) = worktree.read(cx).as_local() {
        worktree
    } else {
        return;
    };
    let worktree_abs_path = worktree.abs_path().clone();
    let scan_complete = worktree.scan_complete();
    let worktree_id = worktree.id();
    let db = self.db.clone();
    let language_registry = self.language_registry.clone();
    // `done_rx` lets `wait_for_worktree_registration` await this registration.
    let (mut done_tx, done_rx) = watch::channel();
    let registration = cx.spawn(|this, mut cx| {
        async move {
            let register = async {
                scan_complete.await;
                let db_id = db.find_or_create_worktree(worktree_abs_path).await?;
                let mut file_mtimes = db.get_file_mtimes(db_id).await?;
                let worktree = if let Some(project) = project.upgrade() {
                    project
                        .read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx))
                        .ok()
                        .flatten()
                        .context("worktree not found")?
                } else {
                    return anyhow::Ok(());
                };
                let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot())?;
                // Diff the snapshot against the database on the background executor.
                let mut changed_paths = cx
                    .background_executor()
                    .spawn(async move {
                        let mut changed_paths = BTreeMap::new();
                        for file in worktree.files(false, 0) {
                            // NOTE(review): a single `absolutize` failure aborts the
                            // whole diff (and thus the registration) via `?`.
                            let absolute_path = worktree.absolutize(&file.path)?;

                            if file.is_external || file.is_ignored || file.is_symlink {
                                continue;
                            }

                            if let Ok(language) = language_registry
                                .language_for_file(&absolute_path, None)
                                .await
                            {
                                // Test if file is valid parseable file
                                if !PARSEABLE_ENTIRE_FILE_TYPES
                                    .contains(&language.name().as_ref())
                                    && &language.name().as_ref() != &"Markdown"
                                    && language
                                        .grammar()
                                        .and_then(|grammar| grammar.embedding_config.as_ref())
                                        .is_none()
                                {
                                    continue;
                                }

                                // Removing the entry means whatever is left in
                                // `file_mtimes` after this loop was not seen as an
                                // indexable file of the worktree.
                                let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
                                let already_stored = stored_mtime
                                    .map_or(false, |existing_mtime| {
                                        existing_mtime == file.mtime
                                    });

                                if !already_stored {
                                    changed_paths.insert(
                                        file.path.clone(),
                                        ChangedPathInfo {
                                            mtime: file.mtime,
                                            is_deleted: false,
                                        },
                                    );
                                }
                            }
                        }

                        // Clean up entries from database that are no longer in the worktree.
                        for (path, mtime) in file_mtimes {
                            changed_paths.insert(
                                path.into(),
                                ChangedPathInfo {
                                    mtime,
                                    is_deleted: true,
                                },
                            );
                        }

                        anyhow::Ok(changed_paths)
                    })
                    .await?;
                this.update(&mut cx, |this, cx| {
                    let project_state = this
                        .projects
                        .get_mut(&project)
                        .context("project not registered")?;
                    let project = project.upgrade().context("project was dropped")?;

                    // Merge changes that were recorded while registration was running.
                    if let Some(WorktreeState::Registering(state)) =
                        project_state.worktrees.remove(&worktree_id)
                    {
                        changed_paths.extend(state.changed_paths);
                    }
                    project_state.worktrees.insert(
                        worktree_id,
                        WorktreeState::Registered(RegisteredWorktreeState {
                            db_id,
                            changed_paths,
                        }),
                    );
                    this.index_project(project, cx).detach_and_log_err(cx);

                    anyhow::Ok(())
                })??;

                anyhow::Ok(())
            };

            if register.await.log_err().is_none() {
                // Stop tracking this worktree if the registration failed.
                this.update(&mut cx, |this, _| {
                    this.projects.get_mut(&project).map(|project_state| {
                        project_state.worktrees.remove(&worktree_id);
                    });
                })
                .ok();
            }

            // Signal completion (successful or not) to anyone awaiting `done()`.
            *done_tx.borrow_mut() = Some(());
        }
    });
    project_state.worktrees.insert(
        worktree_id,
        WorktreeState::Registering(RegisteringWorktreeState {
            changed_paths: Default::default(),
            done_rx,
            _registration: registration,
        }),
    );
}
665
666 fn project_worktrees_changed(&mut self, project: Model<Project>, cx: &mut ModelContext<Self>) {
667 let project_state = if let Some(project_state) = self.projects.get_mut(&project.downgrade())
668 {
669 project_state
670 } else {
671 return;
672 };
673
674 let mut worktrees = project
675 .read(cx)
676 .worktrees()
677 .filter(|worktree| worktree.read(cx).is_local())
678 .collect::<Vec<_>>();
679 let worktree_ids = worktrees
680 .iter()
681 .map(|worktree| worktree.read(cx).id())
682 .collect::<HashSet<_>>();
683
684 // Remove worktrees that are no longer present
685 project_state
686 .worktrees
687 .retain(|worktree_id, _| worktree_ids.contains(worktree_id));
688
689 // Register new worktrees
690 worktrees.retain(|worktree| {
691 let worktree_id = worktree.read(cx).id();
692 !project_state.worktrees.contains_key(&worktree_id)
693 });
694 for worktree in worktrees {
695 self.register_worktree(project.clone(), worktree, cx);
696 }
697 }
698
699 pub fn pending_file_count(&self, project: &Model<Project>) -> Option<watch::Receiver<usize>> {
700 Some(
701 self.projects
702 .get(&project.downgrade())?
703 .pending_file_count_rx
704 .clone(),
705 )
706 }
707
708 pub fn search_project(
709 &mut self,
710 project: Model<Project>,
711 query: String,
712 limit: usize,
713 includes: Vec<PathMatcher>,
714 excludes: Vec<PathMatcher>,
715 cx: &mut ModelContext<Self>,
716 ) -> Task<Result<Vec<SearchResult>>> {
717 if query.is_empty() {
718 return Task::ready(Ok(Vec::new()));
719 }
720
721 let index = self.index_project(project.clone(), cx);
722 let embedding_provider = self.embedding_provider.clone();
723
724 cx.spawn(|this, mut cx| async move {
725 index.await?;
726 let t0 = Instant::now();
727
728 let query = embedding_provider
729 .embed_batch(vec![query])
730 .await?
731 .pop()
732 .context("could not embed query")?;
733 log::trace!("Embedding Search Query: {:?}ms", t0.elapsed().as_millis());
734
735 let search_start = Instant::now();
736 let modified_buffer_results = this.update(&mut cx, |this, cx| {
737 this.search_modified_buffers(
738 &project,
739 query.clone(),
740 limit,
741 &includes,
742 &excludes,
743 cx,
744 )
745 })?;
746 let file_results = this.update(&mut cx, |this, cx| {
747 this.search_files(project, query, limit, includes, excludes, cx)
748 })?;
749 let (modified_buffer_results, file_results) =
750 futures::join!(modified_buffer_results, file_results);
751
752 // Weave together the results from modified buffers and files.
753 let mut results = Vec::new();
754 let mut modified_buffers = HashSet::default();
755 for result in modified_buffer_results.log_err().unwrap_or_default() {
756 modified_buffers.insert(result.buffer.clone());
757 results.push(result);
758 }
759 for result in file_results.log_err().unwrap_or_default() {
760 if !modified_buffers.contains(&result.buffer) {
761 results.push(result);
762 }
763 }
764 results.sort_by_key(|result| Reverse(result.similarity));
765 results.truncate(limit);
766 log::trace!("Semantic search took {:?}", search_start.elapsed());
767 Ok(results)
768 })
769 }
770
771 pub fn search_files(
772 &mut self,
773 project: Model<Project>,
774 query: Embedding,
775 limit: usize,
776 includes: Vec<PathMatcher>,
777 excludes: Vec<PathMatcher>,
778 cx: &mut ModelContext<Self>,
779 ) -> Task<Result<Vec<SearchResult>>> {
780 let db_path = self.db.path().clone();
781 let fs = self.fs.clone();
782 cx.spawn(|this, mut cx| async move {
783 let database = VectorDatabase::new(
784 fs.clone(),
785 db_path.clone(),
786 cx.background_executor().clone(),
787 )
788 .await?;
789
790 let worktree_db_ids = this.read_with(&cx, |this, _| {
791 let project_state = this
792 .projects
793 .get(&project.downgrade())
794 .context("project was not indexed")?;
795 let worktree_db_ids = project_state
796 .worktrees
797 .values()
798 .filter_map(|worktree| {
799 if let WorktreeState::Registered(worktree) = worktree {
800 Some(worktree.db_id)
801 } else {
802 None
803 }
804 })
805 .collect::<Vec<i64>>();
806 anyhow::Ok(worktree_db_ids)
807 })??;
808
809 let file_ids = database
810 .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
811 .await?;
812
813 let batch_n = cx.background_executor().num_cpus();
814 let ids_len = file_ids.clone().len();
815 let minimum_batch_size = 50;
816
817 let batch_size = {
818 let size = ids_len / batch_n;
819 if size < minimum_batch_size {
820 minimum_batch_size
821 } else {
822 size
823 }
824 };
825
826 let mut batch_results = Vec::new();
827 for batch in file_ids.chunks(batch_size) {
828 let batch = batch.into_iter().map(|v| *v).collect::<Vec<i64>>();
829 let limit = limit.clone();
830 let fs = fs.clone();
831 let db_path = db_path.clone();
832 let query = query.clone();
833 if let Some(db) =
834 VectorDatabase::new(fs, db_path.clone(), cx.background_executor().clone())
835 .await
836 .log_err()
837 {
838 batch_results.push(async move {
839 db.top_k_search(&query, limit, batch.as_slice()).await
840 });
841 }
842 }
843
844 let batch_results = futures::future::join_all(batch_results).await;
845
846 let mut results = Vec::new();
847 for batch_result in batch_results {
848 if batch_result.is_ok() {
849 for (id, similarity) in batch_result.unwrap() {
850 let ix = match results
851 .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
852 {
853 Ok(ix) => ix,
854 Err(ix) => ix,
855 };
856
857 results.insert(ix, (id, similarity));
858 results.truncate(limit);
859 }
860 }
861 }
862
863 let ids = results.iter().map(|(id, _)| *id).collect::<Vec<i64>>();
864 let scores = results
865 .into_iter()
866 .map(|(_, score)| score)
867 .collect::<Vec<_>>();
868 let spans = database.spans_for_ids(ids.as_slice()).await?;
869
870 let mut tasks = Vec::new();
871 let mut ranges = Vec::new();
872 let weak_project = project.downgrade();
873 project.update(&mut cx, |project, cx| {
874 let this = this.upgrade().context("index was dropped")?;
875 for (worktree_db_id, file_path, byte_range) in spans {
876 let project_state =
877 if let Some(state) = this.read(cx).projects.get(&weak_project) {
878 state
879 } else {
880 return Err(anyhow!("project not added"));
881 };
882 if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
883 tasks.push(project.open_buffer((worktree_id, file_path), cx));
884 ranges.push(byte_range);
885 }
886 }
887
888 Ok(())
889 })??;
890
891 let buffers = futures::future::join_all(tasks).await;
892 Ok(buffers
893 .into_iter()
894 .zip(ranges)
895 .zip(scores)
896 .filter_map(|((buffer, range), similarity)| {
897 let buffer = buffer.log_err()?;
898 let range = buffer
899 .read_with(&cx, |buffer, _| {
900 let start = buffer.clip_offset(range.start, Bias::Left);
901 let end = buffer.clip_offset(range.end, Bias::Right);
902 buffer.anchor_before(start)..buffer.anchor_after(end)
903 })
904 .log_err()?;
905 Some(SearchResult {
906 buffer,
907 range,
908 similarity,
909 })
910 })
911 .collect())
912 })
913 }
914
915 fn search_modified_buffers(
916 &self,
917 project: &Model<Project>,
918 query: Embedding,
919 limit: usize,
920 includes: &[PathMatcher],
921 excludes: &[PathMatcher],
922 cx: &mut ModelContext<Self>,
923 ) -> Task<Result<Vec<SearchResult>>> {
924 let modified_buffers = project
925 .read(cx)
926 .opened_buffers()
927 .into_iter()
928 .filter_map(|buffer_handle| {
929 let buffer = buffer_handle.read(cx);
930 let snapshot = buffer.snapshot();
931 let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
932 excludes.iter().any(|matcher| matcher.is_match(&path))
933 });
934
935 let included = if includes.len() == 0 {
936 true
937 } else {
938 snapshot.resolve_file_path(cx, false).map_or(false, |path| {
939 includes.iter().any(|matcher| matcher.is_match(&path))
940 })
941 };
942
943 if buffer.is_dirty() && !excluded && included {
944 Some((buffer_handle, snapshot))
945 } else {
946 None
947 }
948 })
949 .collect::<HashMap<_, _>>();
950
951 let embedding_provider = self.embedding_provider.clone();
952 let fs = self.fs.clone();
953 let db_path = self.db.path().clone();
954 let background = cx.background_executor().clone();
955 cx.background_executor().spawn(async move {
956 let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
957 let mut results = Vec::<SearchResult>::new();
958
959 let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
960 for (buffer, snapshot) in modified_buffers {
961 let language = snapshot
962 .language_at(0)
963 .cloned()
964 .unwrap_or_else(|| language::PLAIN_TEXT.clone());
965 let mut spans = retriever
966 .parse_file_with_template(None, &snapshot.text(), language)
967 .log_err()
968 .unwrap_or_default();
969 if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
970 .await
971 .log_err()
972 .is_some()
973 {
974 for span in spans {
975 let similarity = span.embedding.unwrap().similarity(&query);
976 let ix = match results
977 .binary_search_by_key(&Reverse(similarity), |result| {
978 Reverse(result.similarity)
979 }) {
980 Ok(ix) => ix,
981 Err(ix) => ix,
982 };
983
984 let range = {
985 let start = snapshot.clip_offset(span.range.start, Bias::Left);
986 let end = snapshot.clip_offset(span.range.end, Bias::Right);
987 snapshot.anchor_before(start)..snapshot.anchor_after(end)
988 };
989
990 results.insert(
991 ix,
992 SearchResult {
993 buffer: buffer.clone(),
994 range,
995 similarity,
996 },
997 );
998 results.truncate(limit);
999 }
1000 }
1001 }
1002
1003 Ok(results)
1004 })
1005 }
1006
/// Runs an indexing pass for `project`. The returned task completes when the pass
/// is done.
///
/// The first call for a project starts tracking it: the index subscribes to its
/// worktree events and begins registering its local worktrees. Each pass then:
/// - waits for worktree registration;
/// - takes every registered worktree's accumulated changed paths;
/// - deletes removed files from the database;
/// - queues the remaining files for parsing;
/// - waits until the project's pending-file count reaches zero.
///
/// # Errors
/// Fails immediately if the embedding provider cannot be authenticated.
///
/// NOTE(review): `pending_index` is incremented synchronously but only decremented at
/// the end of the returned task. If that task exits early via `?`, or is dropped before
/// completing (presumably dropping a gpui `Task` cancels it — confirm), the count stays
/// raised and `status` keeps reporting `Indexing`.
pub fn index_project(
    &mut self,
    project: Model<Project>,
    cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
    if !self.is_authenticated() {
        if !self.authenticate(cx) {
            return Task::ready(Err(anyhow!("user is not authenticated")));
        }
    }

    // First pass for this project: start tracking it and its worktrees.
    if !self.projects.contains_key(&project.downgrade()) {
        let subscription = cx.subscribe(&project, |this, project, event, cx| match event {
            project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
                this.project_worktrees_changed(project.clone(), cx);
            }
            project::Event::WorktreeUpdatedEntries(worktree_id, changes) => {
                this.project_entries_changed(project, *worktree_id, changes.clone(), cx);
            }
            _ => {}
        });
        let project_state = ProjectState::new(subscription, cx);
        self.projects.insert(project.downgrade(), project_state);
        self.project_worktrees_changed(project.clone(), cx);
    }
    // Cannot fail: the entry was inserted above if it was missing.
    let project_state = self.projects.get_mut(&project.downgrade()).unwrap();
    project_state.pending_index += 1;
    cx.notify();

    let mut pending_file_count_rx = project_state.pending_file_count_rx.clone();
    let db = self.db.clone();
    let language_registry = self.language_registry.clone();
    let parsing_files_tx = self.parsing_files_tx.clone();
    let worktree_registration = self.wait_for_worktree_registration(&project, cx);

    cx.spawn(|this, mut cx| async move {
        worktree_registration.await?;

        // Take the changes accumulated by each registered worktree.
        let mut pending_files = Vec::new();
        let mut files_to_delete = Vec::new();
        this.update(&mut cx, |this, cx| {
            let project_state = this
                .projects
                .get_mut(&project.downgrade())
                .context("project was dropped")?;
            let pending_file_count_tx = &project_state.pending_file_count_tx;

            project_state
                .worktrees
                .retain(|worktree_id, worktree_state| {
                    // Forget worktrees that no longer exist in the project.
                    let worktree = if let Some(worktree) =
                        project.read(cx).worktree_for_id(*worktree_id, cx)
                    {
                        worktree
                    } else {
                        return false;
                    };
                    // Worktrees still registering keep their changes; they are merged
                    // and indexed when registration finishes.
                    let worktree_state =
                        if let WorktreeState::Registered(worktree_state) = worktree_state {
                            worktree_state
                        } else {
                            return true;
                        };

                    for (path, info) in &worktree_state.changed_paths {
                        if info.is_deleted {
                            files_to_delete.push((worktree_state.db_id, path.clone()));
                        } else if let Ok(absolute_path) = worktree.read(cx).absolutize(path) {
                            // Each queued file holds a job handle, raising the
                            // project's pending-file count until it is processed.
                            let job_handle = JobHandle::new(pending_file_count_tx);
                            pending_files.push(PendingFile {
                                absolute_path,
                                relative_path: path.clone(),
                                language: None,
                                job_handle,
                                modified_time: info.mtime,
                                worktree_db_id: worktree_state.db_id,
                            });
                        }
                    }
                    worktree_state.changed_paths.clear();
                    true
                });

            anyhow::Ok(())
        })??;

        cx.background_executor()
            .spawn(async move {
                for (worktree_db_id, path) in files_to_delete {
                    db.delete_file(worktree_db_id, path).await.log_err();
                }

                // Embeddings already stored for the pending files, keyed by span digest.
                // `parse_file` attaches them to spans with a matching digest.
                let embeddings_for_digest = {
                    let mut files = HashMap::default();
                    for pending_file in &pending_files {
                        files
                            .entry(pending_file.worktree_db_id)
                            .or_insert(Vec::new())
                            .push(pending_file.relative_path.clone());
                    }
                    Arc::new(
                        db.embeddings_for_files(files)
                            .await
                            .log_err()
                            .unwrap_or_default(),
                    )
                };

                // Files whose language has no embedding support are skipped here (their
                // job handle is dropped). Files with no detected language are still
                // queued; `parse_file` ignores them.
                // NOTE(review): detection uses the relative path here, whereas
                // `register_worktree` uses the absolute path.
                for mut pending_file in pending_files {
                    if let Ok(language) = language_registry
                        .language_for_file(&pending_file.relative_path, None)
                        .await
                    {
                        if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
                            && &language.name().as_ref() != &"Markdown"
                            && language
                                .grammar()
                                .and_then(|grammar| grammar.embedding_config.as_ref())
                                .is_none()
                        {
                            continue;
                        }
                        pending_file.language = Some(language);
                    }
                    parsing_files_tx
                        .try_send((embeddings_for_digest.clone(), pending_file))
                        .ok();
                }

                // Wait until we're done indexing.
                while let Some(count) = pending_file_count_rx.next().await {
                    if count == 0 {
                        break;
                    }
                }
            })
            .await;

        this.update(&mut cx, |this, cx| {
            let project_state = this
                .projects
                .get_mut(&project.downgrade())
                .context("project was dropped")?;
            project_state.pending_index -= 1;
            cx.notify();
            anyhow::Ok(())
        })??;

        Ok(())
    })
}
1158
1159 fn wait_for_worktree_registration(
1160 &self,
1161 project: &Model<Project>,
1162 cx: &mut ModelContext<Self>,
1163 ) -> Task<Result<()>> {
1164 let project = project.downgrade();
1165 cx.spawn(|this, cx| async move {
1166 loop {
1167 let mut pending_worktrees = Vec::new();
1168 this.upgrade()
1169 .context("semantic index dropped")?
1170 .read_with(&cx, |this, _| {
1171 if let Some(project) = this.projects.get(&project) {
1172 for worktree in project.worktrees.values() {
1173 if let WorktreeState::Registering(worktree) = worktree {
1174 pending_worktrees.push(worktree.done());
1175 }
1176 }
1177 }
1178 })?;
1179
1180 if pending_worktrees.is_empty() {
1181 break;
1182 } else {
1183 future::join_all(pending_worktrees).await;
1184 }
1185 }
1186 Ok(())
1187 })
1188 }
1189
/// Fills in `embedding` for every span in `spans`.
///
/// Embeddings already stored in `db` for a span's digest are reused; a failed lookup
/// is logged and treated as "nothing stored". The remaining spans go to
/// `embedding_provider` in batches whose token total stays within
/// `max_tokens_per_batch()`. A span that exceeds the limit on its own is sent in a
/// batch by itself.
///
/// # Errors
/// Returns an error if the provider fails, or if it returns fewer embeddings than the
/// spans it was asked to embed.
async fn embed_spans(
    spans: &mut [Span],
    embedding_provider: &dyn EmbeddingProvider,
    db: &VectorDatabase,
) -> Result<()> {
    let mut batch = Vec::new();
    let mut batch_tokens = 0;
    let mut embeddings = Vec::new();

    let digests = spans
        .iter()
        .map(|span| span.digest.clone())
        .collect::<Vec<_>>();
    let embeddings_for_digests = db
        .embeddings_for_digests(digests)
        .await
        .log_err()
        .unwrap_or_default();

    for span in &*spans {
        if embeddings_for_digests.contains_key(&span.digest) {
            continue;
        };

        // Flush before this span would overflow the batch, but never flush an empty
        // batch: an oversized span must not trigger a request with no inputs.
        if !batch.is_empty()
            && batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch()
        {
            let batch_embeddings = embedding_provider
                .embed_batch(mem::take(&mut batch))
                .await?;
            embeddings.extend(batch_embeddings);
            batch_tokens = 0;
        }

        batch_tokens += span.token_count;
        batch.push(span.content.clone());
    }

    if !batch.is_empty() {
        let batch_embeddings = embedding_provider
            .embed_batch(mem::take(&mut batch))
            .await?;

        embeddings.extend(batch_embeddings);
    }

    // Fresh embeddings come back in the same order their spans were batched.
    let mut embeddings = embeddings.into_iter();
    for span in spans {
        let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) {
            Some(embedding.clone())
        } else {
            embeddings.next()
        };
        let embedding = embedding.context("failed to embed spans")?;
        span.embedding = Some(embedding);
    }
    Ok(())
}
}
1247
1248impl Drop for JobHandle {
1249 fn drop(&mut self) {
1250 if let Some(inner) = Arc::get_mut(&mut self.tx) {
1251 // This is the last instance of the JobHandle (regardless of its origin - whether it was cloned or not)
1252 if let Some(tx) = inner.upgrade() {
1253 let mut tx = tx.lock();
1254 *tx.borrow_mut() -= 1;
1255 }
1256 }
1257 }
1258}
1259
#[cfg(test)]
mod tests {
    use super::*;

    // Clones share a single job: the counter rises once when the handle is created
    // and falls back only after the last clone has been dropped.
    #[test]
    fn test_job_handle() {
        let (counter_tx, counter_rx) = watch::channel_with(0);
        let counter_tx = Arc::new(Mutex::new(counter_tx));
        let pending = || *counter_rx.borrow();

        let original = JobHandle::new(&counter_tx);
        assert_eq!(1, pending());

        let cloned = original.clone();
        assert_eq!(1, pending());

        drop(original);
        assert_eq!(1, pending());

        drop(cloned);
        assert_eq!(0, pending());
    }
}