mod db;
mod embedding_queue;
mod parsing;
pub mod semantic_index_settings;

#[cfg(test)]
mod semantic_index_tests;

use crate::semantic_index_settings::SemanticIndexSettings;
use ai::embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
use anyhow::{anyhow, Result};
use collections::{BTreeMap, HashMap, HashSet};
use db::VectorDatabase;
use embedding_queue::{EmbeddingQueue, FileToEmbed};
use futures::{future, FutureExt, StreamExt};
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
use lazy_static::lazy_static;
use ordered_float::OrderedFloat;
use parking_lot::Mutex;
use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
use postage::watch;
use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
use smol::channel;
use std::{
    cmp::Reverse,
    env,
    future::Future,
    mem,
    ops::Range,
    path::{Path, PathBuf},
    sync::{Arc, Weak},
    time::{Duration, Instant, SystemTime},
};
use util::{channel::RELEASE_CHANNEL_NAME, http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt};
use workspace::WorkspaceCreated;

const SEMANTIC_INDEX_VERSION: usize = 11;
const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);

lazy_static! {
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
}

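/// Entry point for the semantic index: registers its settings, subscribes to
/// workspace creation so that previously indexed local projects are re-indexed
/// when they are reopened, and spawns construction of the global
/// [`SemanticIndex`] model.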
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    settings::register::<SemanticIndexSettings>(cx);

    let db_file_path = EMBEDDINGS_DIR
        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
        .join("embeddings_db");

    cx.subscribe_global::<WorkspaceCreated, _>({
        move |event, cx| {
            let Some(semantic_index) = SemanticIndex::global(cx) else {
                return;
            };
            let workspace = &event.0;
            if let Some(workspace) = workspace.upgrade(cx) {
                let project = workspace.read(cx).project().clone();
                if project.read(cx).is_local() {
                    cx.spawn(|mut cx| async move {
                        let previously_indexed = semantic_index
                            .update(&mut cx, |index, cx| {
                                index.project_previously_indexed(&project, cx)
                            })
                            .await?;
                        if previously_indexed {
                            semantic_index
                                .update(&mut cx, |index, cx| index.index_project(project, cx))
                                .await?;
                        }
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                }
            }
        }
    })
    .detach();

    cx.spawn(move |mut cx| async move {
        let semantic_index = SemanticIndex::new(
            fs,
            db_file_path,
            Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
            language_registry,
            cx.clone(),
        )
        .await?;

        cx.update(|cx| {
            cx.set_global(semantic_index.clone());
        });

        anyhow::Ok(())
    })
    .detach();
}

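/// The user-facing indexing state of a project: not authenticated with the
/// embedding provider, never indexed, fully indexed, or still indexing (with
/// the number of files remaining and, if the provider is throttling us, when
/// the rate limit expires).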
#[derive(Copy, Clone, Debug)]
pub enum SemanticIndexStatus {
    NotAuthenticated,
    NotIndexed,
    Indexed,
    Indexing {
        remaining_files: usize,
        rate_limit_expiry: Option<Instant>,
    },
}

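/// The global semantic index model. Files queued on `parsing_files_tx` are
/// parsed into embeddable spans by a pool of background tasks
/// (`_parsing_files_tasks`), batched through the embedding provider, and
/// finally persisted to the vector database by `_embedding_task`.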
pub struct SemanticIndex {
    fs: Arc<dyn Fs>,
    db: VectorDatabase,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    parsing_files_tx: channel::Sender<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>,
    _embedding_task: Task<()>,
    _parsing_files_tasks: Vec<Task<()>>,
    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
}

struct ProjectState {
    worktrees: HashMap<WorktreeId, WorktreeState>,
    pending_file_count_rx: watch::Receiver<usize>,
    pending_file_count_tx: Arc<Mutex<watch::Sender<usize>>>,
    pending_index: usize,
    _subscription: gpui::Subscription,
    _observe_pending_file_count: Task<()>,
}

enum WorktreeState {
    Registering(RegisteringWorktreeState),
    Registered(RegisteredWorktreeState),
}

impl WorktreeState {
    fn is_registered(&self) -> bool {
        matches!(self, Self::Registered(_))
    }

    fn paths_changed(
        &mut self,
        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
        worktree: &Worktree,
    ) {
        let changed_paths = match self {
            Self::Registering(state) => &mut state.changed_paths,
            Self::Registered(state) => &mut state.changed_paths,
        };

        for (path, entry_id, change) in changes.iter() {
            let Some(entry) = worktree.entry_for_id(*entry_id) else {
                continue;
            };
            if entry.is_ignored || entry.is_symlink || entry.is_external || entry.is_dir() {
                continue;
            }
            changed_paths.insert(
                path.clone(),
                ChangedPathInfo {
                    mtime: entry.mtime,
                    is_deleted: *change == PathChange::Removed,
                },
            );
        }
    }
}

struct RegisteringWorktreeState {
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
    done_rx: watch::Receiver<Option<()>>,
    _registration: Task<()>,
}

impl RegisteringWorktreeState {
    fn done(&self) -> impl Future<Output = ()> {
        let mut done_rx = self.done_rx.clone();
        async move {
            while let Some(result) = done_rx.next().await {
                if result.is_some() {
                    break;
                }
            }
        }
    }
}

struct RegisteredWorktreeState {
    db_id: i64,
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
}

struct ChangedPathInfo {
    mtime: SystemTime,
    is_deleted: bool,
}

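/// Reference-counted handle to a pending indexing job. Creating a handle
/// increments a project's pending-file counter; dropping the last clone of a
/// handle decrements it. A minimal sketch of the lifecycle (see the test at
/// the bottom of this file):
///
/// ```ignore
/// let handle = JobHandle::new(&tx); // counter: 0 -> 1
/// let clone = handle.clone();       // counter unchanged
/// drop(handle);                     // counter unchanged
/// drop(clone);                      // last handle dropped, counter: 1 -> 0
/// ```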
#[derive(Clone)]
pub struct JobHandle {
    /// The outer Arc is here to count the clones of a JobHandle instance;
    /// when the last handle to a given job is dropped, we decrement a counter (just once).
    tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
}

impl JobHandle {
    fn new(tx: &Arc<Mutex<watch::Sender<usize>>>) -> Self {
        *tx.lock().borrow_mut() += 1;
        Self {
            tx: Arc::new(Arc::downgrade(tx)),
        }
    }
}

impl ProjectState {
    fn new(subscription: gpui::Subscription, cx: &mut ModelContext<SemanticIndex>) -> Self {
        let (pending_file_count_tx, pending_file_count_rx) = watch::channel_with(0);
        let pending_file_count_tx = Arc::new(Mutex::new(pending_file_count_tx));
        Self {
            worktrees: Default::default(),
            pending_file_count_rx: pending_file_count_rx.clone(),
            pending_file_count_tx,
            pending_index: 0,
            _subscription: subscription,
            _observe_pending_file_count: cx.spawn_weak({
                let mut pending_file_count_rx = pending_file_count_rx.clone();
                |this, mut cx| async move {
                    while pending_file_count_rx.next().await.is_some() {
                        if let Some(this) = this.upgrade(&cx) {
                            this.update(&mut cx, |_, cx| cx.notify());
                        }
                    }
                }
            }),
        }
    }

    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
        self.worktrees
            .iter()
            .find_map(|(worktree_id, worktree_state)| match worktree_state {
                WorktreeState::Registered(state) if state.db_id == id => Some(*worktree_id),
                _ => None,
            })
    }
}

#[derive(Clone)]
pub struct PendingFile {
    worktree_db_id: i64,
    relative_path: Arc<Path>,
    absolute_path: PathBuf,
    language: Option<Arc<Language>>,
    modified_time: SystemTime,
    job_handle: JobHandle,
}

#[derive(Clone)]
pub struct SearchResult {
    pub buffer: ModelHandle<Buffer>,
    pub range: Range<Anchor>,
    pub similarity: OrderedFloat<f32>,
}

impl SemanticIndex {
    pub fn global(cx: &AppContext) -> Option<ModelHandle<SemanticIndex>> {
        if cx.has_global::<ModelHandle<Self>>() {
            Some(cx.global::<ModelHandle<SemanticIndex>>().clone())
        } else {
            None
        }
    }

    pub fn enabled(cx: &AppContext) -> bool {
        settings::get::<SemanticIndexSettings>(cx).enabled
    }

    pub fn status(&self, project: &ModelHandle<Project>) -> SemanticIndexStatus {
        if !self.embedding_provider.is_authenticated() {
            return SemanticIndexStatus::NotAuthenticated;
        }

        if let Some(project_state) = self.projects.get(&project.downgrade()) {
            if project_state
                .worktrees
                .values()
                .all(|worktree| worktree.is_registered())
                && project_state.pending_index == 0
            {
                SemanticIndexStatus::Indexed
            } else {
                SemanticIndexStatus::Indexing {
                    remaining_files: *project_state.pending_file_count_rx.borrow(),
                    rate_limit_expiry: self.embedding_provider.rate_limit_expiration(),
                }
            }
        } else {
            SemanticIndexStatus::NotIndexed
        }
    }

    pub async fn new(
        fs: Arc<dyn Fs>,
        database_path: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
        mut cx: AsyncAppContext,
    ) -> Result<ModelHandle<Self>> {
        let t0 = Instant::now();
        let database_path = Arc::from(database_path);
        let db = VectorDatabase::new(fs.clone(), database_path, cx.background()).await?;

        log::trace!(
            "db initialization took {:?} milliseconds",
            t0.elapsed().as_millis()
        );

        Ok(cx.add_model(|cx| {
            let t0 = Instant::now();
            let embedding_queue =
                EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
            let _embedding_task = cx.background().spawn({
                let embedded_files = embedding_queue.finished_files();
                let db = db.clone();
                async move {
                    while let Ok(file) = embedded_files.recv().await {
                        db.insert_file(file.worktree_id, file.path, file.mtime, file.spans)
                            .await
                            .log_err();
                    }
                }
            });

            // Parse files into embeddable spans.
            let (parsing_files_tx, parsing_files_rx) =
                channel::unbounded::<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>();
            let embedding_queue = Arc::new(Mutex::new(embedding_queue));
            let mut _parsing_files_tasks = Vec::new();
            for _ in 0..cx.background().num_cpus() {
                let fs = fs.clone();
                let mut parsing_files_rx = parsing_files_rx.clone();
                let embedding_provider = embedding_provider.clone();
                let embedding_queue = embedding_queue.clone();
                let background = cx.background().clone();
                _parsing_files_tasks.push(cx.background().spawn(async move {
                    let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
                    loop {
                        let mut timer = background.timer(EMBEDDING_QUEUE_FLUSH_TIMEOUT).fuse();
                        let mut next_file_to_parse = parsing_files_rx.next().fuse();
                        futures::select_biased! {
                            next_file_to_parse = next_file_to_parse => {
                                if let Some((embeddings_for_digest, pending_file)) = next_file_to_parse {
                                    Self::parse_file(
                                        &fs,
                                        pending_file,
                                        &mut retriever,
                                        &embedding_queue,
                                        &embeddings_for_digest,
                                    )
                                    .await
                                } else {
                                    break;
                                }
                            },
                            _ = timer => {
                                embedding_queue.lock().flush();
                            }
                        }
                    }
                }));
            }

            log::trace!(
                "semantic index task initialization took {:?} milliseconds",
                t0.elapsed().as_millis()
            );
            Self {
                fs,
                db,
                embedding_provider,
                language_registry,
                parsing_files_tx,
                _embedding_task,
                _parsing_files_tasks,
                projects: Default::default(),
            }
        }))
    }

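    /// Loads a pending file from disk, parses it into spans, reuses cached
    /// embeddings for any span whose digest has already been embedded, and
    /// pushes the result onto the embedding queue. Files without a recognized
    /// language are skipped.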
    async fn parse_file(
        fs: &Arc<dyn Fs>,
        pending_file: PendingFile,
        retriever: &mut CodeContextRetriever,
        embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
        embeddings_for_digest: &HashMap<SpanDigest, Embedding>,
    ) {
        let Some(language) = pending_file.language else {
            return;
        };

        if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
            if let Some(mut spans) = retriever
                .parse_file_with_template(Some(&pending_file.relative_path), &content, language)
                .log_err()
            {
                log::trace!(
                    "parsed path {:?}: {} spans",
                    pending_file.relative_path,
                    spans.len()
                );

                for span in &mut spans {
                    if let Some(embedding) = embeddings_for_digest.get(&span.digest) {
                        span.embedding = Some(embedding.to_owned());
                    }
                }

                embedding_queue.lock().push(FileToEmbed {
                    worktree_id: pending_file.worktree_db_id,
                    path: pending_file.relative_path,
                    mtime: pending_file.modified_time,
                    job_handle: pending_file.job_handle,
                    spans,
                });
            }
        }
    }

    pub fn project_previously_indexed(
        &mut self,
        project: &ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<bool>> {
        let worktrees_indexed_previously = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                self.db
                    .worktree_previously_indexed(&worktree.read(cx).abs_path())
            })
            .collect::<Vec<_>>();
        cx.spawn(|_, _cx| async move {
            let worktrees_indexed_previously =
                futures::future::join_all(worktrees_indexed_previously).await;

            // The project counts as previously indexed if every worktree we
            // could check was indexed before; errors are skipped.
            Ok(worktrees_indexed_previously
                .into_iter()
                .flatten()
                .all(|indexed| indexed))
        })
    }

    fn project_entries_changed(
        &mut self,
        project: ModelHandle<Project>,
        worktree_id: WorktreeId,
        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
        cx: &mut ModelContext<Self>,
    ) {
        let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) else {
            return;
        };
        let project = project.downgrade();
        let Some(project_state) = self.projects.get_mut(&project) else {
            return;
        };

        let worktree = worktree.read(cx);
        let Some(worktree_state) = project_state.worktrees.get_mut(&worktree_id) else {
            return;
        };
        worktree_state.paths_changed(changes, worktree);
        if let WorktreeState::Registered(_) = worktree_state {
            cx.spawn_weak(|this, mut cx| async move {
                cx.background().timer(BACKGROUND_INDEXING_DELAY).await;
                if let Some((this, project)) = this.upgrade(&cx).zip(project.upgrade(&cx)) {
                    this.update(&mut cx, |this, cx| {
                        this.index_project(project, cx).detach_and_log_err(cx)
                    });
                }
            })
            .detach();
        }
    }

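    /// Registers a local worktree with the vector database: waits for the
    /// initial scan, diffs the worktree's files against the stored mtimes to
    /// find what changed, marks rows for files that disappeared as deleted,
    /// and then kicks off an index pass. While this runs, the worktree sits
    /// in the `Registering` state so that concurrent file events are buffered
    /// and merged in once registration completes.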
    fn register_worktree(
        &mut self,
        project: ModelHandle<Project>,
        worktree: ModelHandle<Worktree>,
        cx: &mut ModelContext<Self>,
    ) {
        let project = project.downgrade();
        let Some(project_state) = self.projects.get_mut(&project) else {
            return;
        };
        let Some(worktree) = worktree.read(cx).as_local() else {
            return;
        };
        let worktree_abs_path = worktree.abs_path().clone();
        let scan_complete = worktree.scan_complete();
        let worktree_id = worktree.id();
        let db = self.db.clone();
        let language_registry = self.language_registry.clone();
        let (mut done_tx, done_rx) = watch::channel();
        let registration = cx.spawn(|this, mut cx| {
            async move {
                let register = async {
                    scan_complete.await;
                    let db_id = db.find_or_create_worktree(worktree_abs_path).await?;
                    let mut file_mtimes = db.get_file_mtimes(db_id).await?;
                    let worktree = if let Some(project) = project.upgrade(&cx) {
                        project
                            .read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx))
                            .ok_or_else(|| anyhow!("worktree not found"))?
                    } else {
                        return anyhow::Ok(());
                    };
                    let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot());
                    let mut changed_paths = cx
                        .background()
                        .spawn(async move {
                            let mut changed_paths = BTreeMap::new();
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path);

                                if file.is_external || file.is_ignored || file.is_symlink {
                                    continue;
                                }

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    // Only index files whose language we can either parse in
                                    // their entirety or chunk via an embedding-aware grammar.
                                    if !PARSEABLE_ENTIRE_FILE_TYPES
                                        .contains(&language.name().as_ref())
                                        && language.name().as_ref() != "Markdown"
                                        && language
                                            .grammar()
                                            .and_then(|grammar| grammar.embedding_config.as_ref())
                                            .is_none()
                                    {
                                        continue;
                                    }

                                    let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
                                    let already_stored = stored_mtime
                                        .map_or(false, |existing_mtime| {
                                            existing_mtime == file.mtime
                                        });

                                    if !already_stored {
                                        changed_paths.insert(
                                            file.path.clone(),
                                            ChangedPathInfo {
                                                mtime: file.mtime,
                                                is_deleted: false,
                                            },
                                        );
                                    }
                                }
                            }

                            // Clean up entries from the database that are no longer in the worktree.
                            for (path, mtime) in file_mtimes {
                                changed_paths.insert(
                                    path.into(),
                                    ChangedPathInfo {
                                        mtime,
                                        is_deleted: true,
                                    },
                                );
                            }

                            anyhow::Ok(changed_paths)
                        })
                        .await?;
                    this.update(&mut cx, |this, cx| {
                        let project_state = this
                            .projects
                            .get_mut(&project)
                            .ok_or_else(|| anyhow!("project not registered"))?;
                        let project = project
                            .upgrade(cx)
                            .ok_or_else(|| anyhow!("project was dropped"))?;

                        if let Some(WorktreeState::Registering(state)) =
                            project_state.worktrees.remove(&worktree_id)
                        {
                            changed_paths.extend(state.changed_paths);
                        }
                        project_state.worktrees.insert(
                            worktree_id,
                            WorktreeState::Registered(RegisteredWorktreeState {
                                db_id,
                                changed_paths,
                            }),
                        );
                        this.index_project(project, cx).detach_and_log_err(cx);

                        anyhow::Ok(())
                    })?;

                    anyhow::Ok(())
                };

                if register.await.log_err().is_none() {
                    // Stop tracking this worktree if the registration failed.
                    this.update(&mut cx, |this, _| {
                        if let Some(project_state) = this.projects.get_mut(&project) {
                            project_state.worktrees.remove(&worktree_id);
                        }
                    })
                }

                *done_tx.borrow_mut() = Some(());
            }
        });
        project_state.worktrees.insert(
            worktree_id,
            WorktreeState::Registering(RegisteringWorktreeState {
                changed_paths: Default::default(),
                done_rx,
                _registration: registration,
            }),
        );
    }

    fn project_worktrees_changed(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) {
        let Some(project_state) = self.projects.get_mut(&project.downgrade()) else {
            return;
        };

        let mut worktrees = project
            .read(cx)
            .worktrees(cx)
            .filter(|worktree| worktree.read(cx).is_local())
            .collect::<Vec<_>>();
        let worktree_ids = worktrees
            .iter()
            .map(|worktree| worktree.read(cx).id())
            .collect::<HashSet<_>>();

        // Remove worktrees that are no longer present.
        project_state
            .worktrees
            .retain(|worktree_id, _| worktree_ids.contains(worktree_id));

        // Register new worktrees.
        worktrees.retain(|worktree| {
            let worktree_id = worktree.read(cx).id();
            !project_state.worktrees.contains_key(&worktree_id)
        });
        for worktree in worktrees {
            self.register_worktree(project.clone(), worktree, cx);
        }
    }

    pub fn pending_file_count(
        &self,
        project: &ModelHandle<Project>,
    ) -> Option<watch::Receiver<usize>> {
        Some(
            self.projects
                .get(&project.downgrade())?
                .pending_file_count_rx
                .clone(),
        )
    }

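    /// Searches the project semantically: ensures the project is indexed,
    /// embeds the query, then searches the on-disk index and any dirty open
    /// buffers in parallel, preferring results from modified buffers when a
    /// file appears in both. Results come back sorted by descending
    /// similarity, truncated to `limit`. A minimal sketch of a call site,
    /// assuming a `ModelHandle<SemanticIndex>` and a local project are in
    /// scope:
    ///
    /// ```ignore
    /// let search = semantic_index.update(cx, |index, cx| {
    ///     index.search_project(
    ///         project.clone(),
    ///         "where do we deserialize the user settings?".into(),
    ///         10,     // limit
    ///         vec![], // include patterns
    ///         vec![], // exclude patterns
    ///         cx,
    ///     )
    /// });
    /// ```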
    pub fn search_project(
        &mut self,
        project: ModelHandle<Project>,
        query: String,
        limit: usize,
        includes: Vec<PathMatcher>,
        excludes: Vec<PathMatcher>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        if query.is_empty() {
            return Task::ready(Ok(Vec::new()));
        }

        let index = self.index_project(project.clone(), cx);
        let embedding_provider = self.embedding_provider.clone();

        cx.spawn(|this, mut cx| async move {
            index.await?;
            let t0 = Instant::now();
            let query = embedding_provider
                .embed_batch(vec![query])
                .await?
                .pop()
                .ok_or_else(|| anyhow!("could not embed query"))?;
            log::trace!(
                "embedding search query took {:?} milliseconds",
                t0.elapsed().as_millis()
            );

            let search_start = Instant::now();
            let modified_buffer_results = this.update(&mut cx, |this, cx| {
                this.search_modified_buffers(
                    &project,
                    query.clone(),
                    limit,
                    &includes,
                    &excludes,
                    cx,
                )
            });
            let file_results = this.update(&mut cx, |this, cx| {
                this.search_files(project, query, limit, includes, excludes, cx)
            });
            let (modified_buffer_results, file_results) =
                futures::join!(modified_buffer_results, file_results);

            // Weave together the results from modified buffers and files,
            // preferring modified buffers for any file present in both.
            let mut results = Vec::new();
            let mut modified_buffers = HashSet::default();
            for result in modified_buffer_results.log_err().unwrap_or_default() {
                modified_buffers.insert(result.buffer.clone());
                results.push(result);
            }
            for result in file_results.log_err().unwrap_or_default() {
                if !modified_buffers.contains(&result.buffer) {
                    results.push(result);
                }
            }
            results.sort_by_key(|result| Reverse(result.similarity));
            results.truncate(limit);
            log::trace!("semantic search took {:?}", search_start.elapsed());
            Ok(results)
        })
    }

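    /// Searches the persisted index. File ids for the project's registered
    /// worktrees are partitioned into batches (at least 50 ids each, spread
    /// across the available CPUs), each batch runs a top-k similarity search
    /// against its own database connection, and the per-batch winners are
    /// merged into a single list of the `limit` best spans, which are then
    /// resolved to buffers and anchor ranges.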
    pub fn search_files(
        &mut self,
        project: ModelHandle<Project>,
        query: Embedding,
        limit: usize,
        includes: Vec<PathMatcher>,
        excludes: Vec<PathMatcher>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let db_path = self.db.path().clone();
        let fs = self.fs.clone();
        cx.spawn(|this, mut cx| async move {
            let database =
                VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?;

            let worktree_db_ids = this.read_with(&cx, |this, _| {
                let project_state = this
                    .projects
                    .get(&project.downgrade())
                    .ok_or_else(|| anyhow!("project was not indexed"))?;
                let worktree_db_ids = project_state
                    .worktrees
                    .values()
                    .filter_map(|worktree| {
                        if let WorktreeState::Registered(worktree) = worktree {
                            Some(worktree.db_id)
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<i64>>();
                anyhow::Ok(worktree_db_ids)
            })?;

            let file_ids = database
                .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
                .await?;

            let batch_n = cx.background().num_cpus();
            let ids_len = file_ids.len();
            let minimum_batch_size = 50;
            let batch_size = (ids_len / batch_n).max(minimum_batch_size);

            let mut batch_results = Vec::new();
            for batch in file_ids.chunks(batch_size) {
                let batch = batch.to_vec();
                let fs = fs.clone();
                let db_path = db_path.clone();
                let query = query.clone();
                if let Some(db) = VectorDatabase::new(fs, db_path.clone(), cx.background())
                    .await
                    .log_err()
                {
                    batch_results.push(async move {
                        db.top_k_search(&query, limit, batch.as_slice()).await
                    });
                }
            }

            let batch_results = futures::future::join_all(batch_results).await;

            // Merge the per-batch results, keeping the list sorted by
            // descending similarity and capped at `limit`.
            let mut results = Vec::new();
            for batch_result in batch_results {
                if let Ok(batch_result) = batch_result {
                    for (id, similarity) in batch_result {
                        let ix = match results
                            .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
                        {
                            Ok(ix) => ix,
                            Err(ix) => ix,
                        };

                        results.insert(ix, (id, similarity));
                        results.truncate(limit);
                    }
                }
            }

            let ids = results.iter().map(|(id, _)| *id).collect::<Vec<i64>>();
            let scores = results
                .into_iter()
                .map(|(_, score)| score)
                .collect::<Vec<_>>();
            let spans = database.spans_for_ids(ids.as_slice()).await?;

            let mut tasks = Vec::new();
            let mut ranges = Vec::new();
            let weak_project = project.downgrade();
            project.update(&mut cx, |project, cx| {
                for (worktree_db_id, file_path, byte_range) in spans {
                    let Some(project_state) = this.read(cx).projects.get(&weak_project) else {
                        return Err(anyhow!("project not added"));
                    };
                    if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
                        tasks.push(project.open_buffer((worktree_id, file_path), cx));
                        ranges.push(byte_range);
                    }
                }

                Ok(())
            })?;

            let buffers = futures::future::join_all(tasks).await;
            Ok(buffers
                .into_iter()
                .zip(ranges)
                .zip(scores)
                .filter_map(|((buffer, range), similarity)| {
                    let buffer = buffer.log_err()?;
                    let range = buffer.read_with(&cx, |buffer, _| {
                        let start = buffer.clip_offset(range.start, Bias::Left);
                        let end = buffer.clip_offset(range.end, Bias::Right);
                        buffer.anchor_before(start)..buffer.anchor_after(end)
                    });
                    Some(SearchResult {
                        buffer,
                        range,
                        similarity,
                    })
                })
                .collect())
        })
    }

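    /// Searches dirty open buffers, whose on-disk index entries may be stale:
    /// each matching buffer's current text is re-parsed into spans and
    /// re-embedded (reusing embeddings cached in the database where possible)
    /// before being scored against the query.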
    fn search_modified_buffers(
        &self,
        project: &ModelHandle<Project>,
        query: Embedding,
        limit: usize,
        includes: &[PathMatcher],
        excludes: &[PathMatcher],
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let modified_buffers = project
            .read(cx)
            .opened_buffers(cx)
            .into_iter()
            .filter_map(|buffer_handle| {
                let buffer = buffer_handle.read(cx);
                let snapshot = buffer.snapshot();
                let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
                    excludes.iter().any(|matcher| matcher.is_match(&path))
                });

                let included = if includes.is_empty() {
                    true
                } else {
                    snapshot.resolve_file_path(cx, false).map_or(false, |path| {
                        includes.iter().any(|matcher| matcher.is_match(&path))
                    })
                };

                if buffer.is_dirty() && !excluded && included {
                    Some((buffer_handle, snapshot))
                } else {
                    None
                }
            })
            .collect::<HashMap<_, _>>();

        let embedding_provider = self.embedding_provider.clone();
        let fs = self.fs.clone();
        let db_path = self.db.path().clone();
        let background = cx.background().clone();
        cx.background().spawn(async move {
            let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
            let mut results = Vec::<SearchResult>::new();

            let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
            for (buffer, snapshot) in modified_buffers {
                let language = snapshot
                    .language_at(0)
                    .cloned()
                    .unwrap_or_else(|| language::PLAIN_TEXT.clone());
                let mut spans = retriever
                    .parse_file_with_template(None, &snapshot.text(), language)
                    .log_err()
                    .unwrap_or_default();
                if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
                    .await
                    .log_err()
                    .is_some()
                {
                    for span in spans {
                        // `embed_spans` guarantees that every span has an embedding.
                        let similarity = span.embedding.unwrap().similarity(&query);
                        let ix = match results
                            .binary_search_by_key(&Reverse(similarity), |result| {
                                Reverse(result.similarity)
                            }) {
                            Ok(ix) => ix,
                            Err(ix) => ix,
                        };

                        let range = {
                            let start = snapshot.clip_offset(span.range.start, Bias::Left);
                            let end = snapshot.clip_offset(span.range.end, Bias::Right);
                            snapshot.anchor_before(start)..snapshot.anchor_after(end)
                        };

                        results.insert(
                            ix,
                            SearchResult {
                                buffer: buffer.clone(),
                                range,
                                similarity,
                            },
                        );
                        results.truncate(limit);
                    }
                }
            }

            Ok(results)
        })
    }

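    /// Indexes (or re-indexes) a project. On the first call this subscribes
    /// to the project's worktree events and registers its worktrees; every
    /// call then drains the accumulated `changed_paths`, deleting removed
    /// files from the database and queueing the rest for parsing and
    /// embedding, and resolves once the pending-file count returns to zero.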
    pub fn index_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        if !self.embedding_provider.is_authenticated() {
            return Task::ready(Err(anyhow!("user is not authenticated")));
        }

        if !self.projects.contains_key(&project.downgrade()) {
            let subscription = cx.subscribe(&project, |this, project, event, cx| match event {
                project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
                    this.project_worktrees_changed(project.clone(), cx);
                }
                project::Event::WorktreeUpdatedEntries(worktree_id, changes) => {
                    this.project_entries_changed(project, *worktree_id, changes.clone(), cx);
                }
                _ => {}
            });
            let project_state = ProjectState::new(subscription, cx);
            self.projects.insert(project.downgrade(), project_state);
            self.project_worktrees_changed(project.clone(), cx);
        }
        let project_state = self.projects.get_mut(&project.downgrade()).unwrap();
        project_state.pending_index += 1;
        cx.notify();

        let mut pending_file_count_rx = project_state.pending_file_count_rx.clone();
        let db = self.db.clone();
        let language_registry = self.language_registry.clone();
        let parsing_files_tx = self.parsing_files_tx.clone();
        let worktree_registration = self.wait_for_worktree_registration(&project, cx);

        cx.spawn(|this, mut cx| async move {
            worktree_registration.await?;

            let mut pending_files = Vec::new();
            let mut files_to_delete = Vec::new();
            this.update(&mut cx, |this, cx| {
                let project_state = this
                    .projects
                    .get_mut(&project.downgrade())
                    .ok_or_else(|| anyhow!("project was dropped"))?;
                let pending_file_count_tx = &project_state.pending_file_count_tx;

                project_state
                    .worktrees
                    .retain(|worktree_id, worktree_state| {
                        let Some(worktree) = project.read(cx).worktree_for_id(*worktree_id, cx)
                        else {
                            return false;
                        };
                        let worktree_state =
                            if let WorktreeState::Registered(worktree_state) = worktree_state {
                                worktree_state
                            } else {
                                return true;
                            };

                        worktree_state.changed_paths.retain(|path, info| {
                            if info.is_deleted {
                                files_to_delete.push((worktree_state.db_id, path.clone()));
                            } else {
                                let absolute_path = worktree.read(cx).absolutize(path);
                                let job_handle = JobHandle::new(pending_file_count_tx);
                                pending_files.push(PendingFile {
                                    absolute_path,
                                    relative_path: path.clone(),
                                    language: None,
                                    job_handle,
                                    modified_time: info.mtime,
                                    worktree_db_id: worktree_state.db_id,
                                });
                            }

                            false
                        });
                        true
                    });

                anyhow::Ok(())
            })?;

            cx.background()
                .spawn(async move {
                    for (worktree_db_id, path) in files_to_delete {
                        db.delete_file(worktree_db_id, path).await.log_err();
                    }

                    let embeddings_for_digest = {
                        let mut files = HashMap::default();
                        for pending_file in &pending_files {
                            files
                                .entry(pending_file.worktree_db_id)
                                .or_insert(Vec::new())
                                .push(pending_file.relative_path.clone());
                        }
                        Arc::new(
                            db.embeddings_for_files(files)
                                .await
                                .log_err()
                                .unwrap_or_default(),
                        )
                    };

                    for mut pending_file in pending_files {
                        if let Ok(language) = language_registry
                            .language_for_file(&pending_file.relative_path, None)
                            .await
                        {
                            if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
                                && language.name().as_ref() != "Markdown"
                                && language
                                    .grammar()
                                    .and_then(|grammar| grammar.embedding_config.as_ref())
                                    .is_none()
                            {
                                continue;
                            }
                            pending_file.language = Some(language);
                        }
                        parsing_files_tx
                            .try_send((embeddings_for_digest.clone(), pending_file))
                            .ok();
                    }

                    // Wait until we're done indexing.
                    while let Some(count) = pending_file_count_rx.next().await {
                        if count == 0 {
                            break;
                        }
                    }
                })
                .await;

            this.update(&mut cx, |this, cx| {
                let project_state = this
                    .projects
                    .get_mut(&project.downgrade())
                    .ok_or_else(|| anyhow!("project was dropped"))?;
                project_state.pending_index -= 1;
                cx.notify();
                anyhow::Ok(())
            })?;

            Ok(())
        })
    }

    fn wait_for_worktree_registration(
        &self,
        project: &ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        let project = project.downgrade();
        cx.spawn_weak(|this, cx| async move {
            loop {
                let mut pending_worktrees = Vec::new();
                this.upgrade(&cx)
                    .ok_or_else(|| anyhow!("semantic index dropped"))?
                    .read_with(&cx, |this, _| {
                        if let Some(project) = this.projects.get(&project) {
                            for worktree in project.worktrees.values() {
                                if let WorktreeState::Registering(worktree) = worktree {
                                    pending_worktrees.push(worktree.done());
                                }
                            }
                        }
                    });

                if pending_worktrees.is_empty() {
                    break;
                } else {
                    future::join_all(pending_worktrees).await;
                }
            }
            Ok(())
        })
    }

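    /// Embeds any spans that don't already have a cached embedding in the
    /// database, batching requests so that no single call to the provider
    /// exceeds `max_tokens_per_batch()`. For example, with a 2,000-token
    /// budget, spans of 900, 900, and 400 tokens go out as two requests:
    /// [900, 900] and [400].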
    async fn embed_spans(
        spans: &mut [Span],
        embedding_provider: &dyn EmbeddingProvider,
        db: &VectorDatabase,
    ) -> Result<()> {
        let mut batch = Vec::new();
        let mut batch_tokens = 0;
        let mut embeddings = Vec::new();

        let digests = spans
            .iter()
            .map(|span| span.digest.clone())
            .collect::<Vec<_>>();
        let embeddings_for_digests = db
            .embeddings_for_digests(digests)
            .await
            .log_err()
            .unwrap_or_default();

        for span in &*spans {
            if embeddings_for_digests.contains_key(&span.digest) {
                continue;
            }

            if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
                let batch_embeddings = embedding_provider
                    .embed_batch(mem::take(&mut batch))
                    .await?;
                embeddings.extend(batch_embeddings);
                batch_tokens = 0;
            }

            batch_tokens += span.token_count;
            batch.push(span.content.clone());
        }

        if !batch.is_empty() {
            let batch_embeddings = embedding_provider
                .embed_batch(mem::take(&mut batch))
                .await?;

            embeddings.extend(batch_embeddings);
        }

        let mut embeddings = embeddings.into_iter();
        for span in spans {
            let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) {
                Some(embedding.clone())
            } else {
                embeddings.next()
            };
            let embedding = embedding.ok_or_else(|| anyhow!("failed to embed spans"))?;
            span.embedding = Some(embedding);
        }
        Ok(())
    }
}

impl Entity for SemanticIndex {
    type Event = ();
}

impl Drop for JobHandle {
    fn drop(&mut self) {
        if let Some(inner) = Arc::get_mut(&mut self.tx) {
            // This is the last instance of the JobHandle (regardless of its origin,
            // the original handle or one of its clones), so decrement the counter.
            if let Some(tx) = inner.upgrade() {
                let mut tx = tx.lock();
                *tx.borrow_mut() -= 1;
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_job_handle() {
        let (job_count_tx, job_count_rx) = watch::channel_with(0);
        let tx = Arc::new(Mutex::new(job_count_tx));
        let job_handle = JobHandle::new(&tx);

        assert_eq!(1, *job_count_rx.borrow());
        let new_job_handle = job_handle.clone();
        assert_eq!(1, *job_count_rx.borrow());
        drop(job_handle);
        assert_eq!(1, *job_count_rx.borrow());
        drop(new_job_handle);
        assert_eq!(0, *job_count_rx.borrow());
    }
}