1mod db;
2mod embedding_queue;
3mod parsing;
4pub mod semantic_index_settings;
5
6#[cfg(test)]
7mod semantic_index_tests;
8
9use crate::semantic_index_settings::SemanticIndexSettings;
10use ai::auth::ProviderCredential;
11use ai::embedding::{Embedding, EmbeddingProvider};
12use ai::providers::open_ai::OpenAIEmbeddingProvider;
13use anyhow::{anyhow, Result};
14use collections::{BTreeMap, HashMap, HashSet};
15use db::VectorDatabase;
16use embedding_queue::{EmbeddingQueue, FileToEmbed};
17use futures::{future, FutureExt, StreamExt};
18use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
19use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
20use lazy_static::lazy_static;
21use ordered_float::OrderedFloat;
22use parking_lot::Mutex;
23use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
24use postage::watch;
25use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
26use smol::channel;
27use std::{
28 cmp::Reverse,
29 env,
30 future::Future,
31 mem,
32 ops::Range,
33 path::{Path, PathBuf},
34 sync::{Arc, Weak},
35 time::{Duration, Instant, SystemTime},
36};
37use util::{channel::RELEASE_CHANNEL_NAME, http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt};
38use workspace::WorkspaceCreated;
39
// Version stamp for the semantic index; NOTE(review): presumably bumped to
// invalidate stale on-disk databases when the format changes — confirm
// against the `db` module.
const SEMANTIC_INDEX_VERSION: usize = 11;
// Delay between a file change notification and the background re-index pass.
const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
// How long a parser task waits for new work before flushing the embedding
// queue so partially-filled batches still make progress.
const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);

lazy_static! {
    // OpenAI API key read from the environment once, if present.
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
}
47
/// Registers semantic-index settings and installs the global `SemanticIndex`.
///
/// Spawns two detached pieces of work:
/// 1. A global subscription that re-indexes a local project when one of its
///    workspaces is created, but only if that project has been indexed before.
/// 2. A task that builds the `SemanticIndex` model (on-disk vector database
///    plus the OpenAI embedding provider) and stores it as an app global.
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    settings::register::<SemanticIndexSettings>(cx);

    // The database lives under a per-release-channel directory.
    let db_file_path = EMBEDDINGS_DIR
        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
        .join("embeddings_db");

    cx.subscribe_global::<WorkspaceCreated, _>({
        move |event, cx| {
            let Some(semantic_index) = SemanticIndex::global(cx) else {
                return;
            };
            let workspace = &event.0;
            if let Some(workspace) = workspace.upgrade(cx) {
                let project = workspace.read(cx).project().clone();
                if project.read(cx).is_local() {
                    cx.spawn(|mut cx| async move {
                        // Only re-index projects that already have an index.
                        let previously_indexed = semantic_index
                            .update(&mut cx, |index, cx| {
                                index.project_previously_indexed(&project, cx)
                            })
                            .await?;
                        if previously_indexed {
                            semantic_index
                                .update(&mut cx, |index, cx| index.index_project(project, cx))
                                .await?;
                        }
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                }
            }
        }
    })
    .detach();

    cx.spawn(move |mut cx| async move {
        let semantic_index = SemanticIndex::new(
            fs,
            db_file_path,
            Arc::new(OpenAIEmbeddingProvider::new(http_client, cx.background())),
            language_registry,
            cx.clone(),
        )
        .await?;

        cx.update(|cx| {
            cx.set_global(semantic_index.clone());
        });

        anyhow::Ok(())
    })
    .detach();
}
107
/// High-level indexing state for a project, as reported to the UI.
#[derive(Copy, Clone, Debug)]
pub enum SemanticIndexStatus {
    /// No usable credentials for the embedding provider.
    NotAuthenticated,
    /// The project is not registered with the index.
    NotIndexed,
    /// All worktrees are registered and no index pass is in flight.
    Indexed,
    /// An index pass is currently running.
    Indexing {
        /// Files still waiting to be parsed and embedded.
        remaining_files: usize,
        /// When the provider's rate limit (if any) is expected to lift.
        rate_limit_expiry: Option<Instant>,
    },
}
118
/// Model that maintains a vector-embedding index of project files and serves
/// semantic searches against it.
pub struct SemanticIndex {
    fs: Arc<dyn Fs>,
    // Connection to the on-disk vector database.
    db: VectorDatabase,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    // Feeds pending files (plus embeddings already known for their span
    // digests) to the background parser tasks.
    parsing_files_tx: channel::Sender<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>,
    // Persists finished embeddings to the database; kept alive by ownership.
    _embedding_task: Task<()>,
    // One parser task per CPU; kept alive by ownership.
    _parsing_files_tasks: Vec<Task<()>>,
    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
    provider_credential: ProviderCredential,
    embedding_queue: Arc<Mutex<EmbeddingQueue>>,
}
131
/// Per-project bookkeeping for the semantic index.
struct ProjectState {
    worktrees: HashMap<WorktreeId, WorktreeState>,
    // Number of files still awaiting embedding, observable by consumers.
    pending_file_count_rx: watch::Receiver<usize>,
    pending_file_count_tx: Arc<Mutex<watch::Sender<usize>>>,
    // Count of `index_project` passes currently in flight for this project.
    pending_index: usize,
    _subscription: gpui::Subscription,
    _observe_pending_file_count: Task<()>,
}
140
/// A worktree is either still registering with the vector database, or fully
/// registered and ready to be indexed.
enum WorktreeState {
    Registering(RegisteringWorktreeState),
    Registered(RegisteredWorktreeState),
}
145
146impl WorktreeState {
147 fn is_registered(&self) -> bool {
148 matches!(self, Self::Registered(_))
149 }
150
151 fn paths_changed(
152 &mut self,
153 changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
154 worktree: &Worktree,
155 ) {
156 let changed_paths = match self {
157 Self::Registering(state) => &mut state.changed_paths,
158 Self::Registered(state) => &mut state.changed_paths,
159 };
160
161 for (path, entry_id, change) in changes.iter() {
162 let Some(entry) = worktree.entry_for_id(*entry_id) else {
163 continue;
164 };
165 if entry.is_ignored || entry.is_symlink || entry.is_external || entry.is_dir() {
166 continue;
167 }
168 changed_paths.insert(
169 path.clone(),
170 ChangedPathInfo {
171 mtime: entry.mtime,
172 is_deleted: *change == PathChange::Removed,
173 },
174 );
175 }
176 }
177}
178
/// State for a worktree whose initial database registration is in progress.
struct RegisteringWorktreeState {
    // Changes that arrive while registering; merged in once registered.
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
    // Receives `Some(())` when registration completes (success or failure).
    done_rx: watch::Receiver<Option<()>>,
    _registration: Task<()>,
}
184
185impl RegisteringWorktreeState {
186 fn done(&self) -> impl Future<Output = ()> {
187 let mut done_rx = self.done_rx.clone();
188 async move {
189 while let Some(result) = done_rx.next().await {
190 if result.is_some() {
191 break;
192 }
193 }
194 }
195 }
196}
197
/// State for a worktree that has a row in the vector database.
struct RegisteredWorktreeState {
    // Primary key of this worktree in the vector database.
    db_id: i64,
    // Paths that changed since the last index pass.
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
}
202
/// Metadata about a changed path, buffered until the next index pass.
struct ChangedPathInfo {
    mtime: SystemTime,
    is_deleted: bool,
}
207
/// Handle representing one pending file in the indexing pipeline; creating it
/// bumps the shared pending-file counter (see `JobHandle::new`).
#[derive(Clone)]
pub struct JobHandle {
    /// The outer Arc is here to count the clones of a JobHandle instance;
    /// when the last handle to a given job is dropped, we decrement a counter (just once).
    tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
}
214
215impl JobHandle {
216 fn new(tx: &Arc<Mutex<watch::Sender<usize>>>) -> Self {
217 *tx.lock().borrow_mut() += 1;
218 Self {
219 tx: Arc::new(Arc::downgrade(&tx)),
220 }
221 }
222}
223
impl ProjectState {
    /// Creates per-project state and spawns an observer that notifies the
    /// `SemanticIndex` model whenever the pending-file count changes.
    fn new(subscription: gpui::Subscription, cx: &mut ModelContext<SemanticIndex>) -> Self {
        // Watch channel carrying the number of files awaiting embedding.
        let (pending_file_count_tx, pending_file_count_rx) = watch::channel_with(0);
        let pending_file_count_tx = Arc::new(Mutex::new(pending_file_count_tx));
        Self {
            worktrees: Default::default(),
            pending_file_count_rx: pending_file_count_rx.clone(),
            pending_file_count_tx,
            pending_index: 0,
            _subscription: subscription,
            _observe_pending_file_count: cx.spawn_weak({
                let mut pending_file_count_rx = pending_file_count_rx.clone();
                |this, mut cx| async move {
                    // Re-notify observers on every count change; exits when the
                    // model or the channel is dropped.
                    while let Some(_) = pending_file_count_rx.next().await {
                        if let Some(this) = this.upgrade(&cx) {
                            this.update(&mut cx, |_, cx| cx.notify());
                        }
                    }
                }
            }),
        }
    }

    /// Maps a database worktree id back to its in-memory `WorktreeId`, if that
    /// worktree has finished registering.
    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
        self.worktrees
            .iter()
            .find_map(|(worktree_id, worktree_state)| match worktree_state {
                WorktreeState::Registered(state) if state.db_id == id => Some(*worktree_id),
                _ => None,
            })
    }
}
256
/// A file queued for parsing and embedding.
#[derive(Clone)]
pub struct PendingFile {
    worktree_db_id: i64,
    relative_path: Arc<Path>,
    absolute_path: PathBuf,
    // Filled in later, once the language registry has been consulted.
    language: Option<Arc<Language>>,
    modified_time: SystemTime,
    // Keeps the pending-file count accurate while this file is in flight.
    job_handle: JobHandle,
}
266
/// A single semantic-search hit: a range in a buffer and its similarity score.
#[derive(Clone)]
pub struct SearchResult {
    pub buffer: ModelHandle<Buffer>,
    pub range: Range<Anchor>,
    pub similarity: OrderedFloat<f32>,
}
273
274impl SemanticIndex {
275 pub fn global(cx: &mut AppContext) -> Option<ModelHandle<SemanticIndex>> {
276 if cx.has_global::<ModelHandle<Self>>() {
277 Some(cx.global::<ModelHandle<SemanticIndex>>().clone())
278 } else {
279 None
280 }
281 }
282
283 pub fn authenticate(&mut self, cx: &AppContext) -> bool {
284 let existing_credential = self.provider_credential.clone();
285 let credential = match existing_credential {
286 ProviderCredential::NoCredentials => self.embedding_provider.retrieve_credentials(cx),
287 _ => existing_credential,
288 };
289
290 self.provider_credential = credential.clone();
291 self.embedding_queue.lock().set_credential(credential);
292 self.is_authenticated()
293 }
294
295 pub fn is_authenticated(&self) -> bool {
296 let credential = &self.provider_credential;
297 match credential {
298 &ProviderCredential::Credentials { .. } => true,
299 &ProviderCredential::NotNeeded => true,
300 _ => false,
301 }
302 }
303
304 pub fn enabled(cx: &AppContext) -> bool {
305 settings::get::<SemanticIndexSettings>(cx).enabled
306 }
307
308 pub fn status(&self, project: &ModelHandle<Project>) -> SemanticIndexStatus {
309 if !self.is_authenticated() {
310 return SemanticIndexStatus::NotAuthenticated;
311 }
312
313 if let Some(project_state) = self.projects.get(&project.downgrade()) {
314 if project_state
315 .worktrees
316 .values()
317 .all(|worktree| worktree.is_registered())
318 && project_state.pending_index == 0
319 {
320 SemanticIndexStatus::Indexed
321 } else {
322 SemanticIndexStatus::Indexing {
323 remaining_files: project_state.pending_file_count_rx.borrow().clone(),
324 rate_limit_expiry: self.embedding_provider.rate_limit_expiration(),
325 }
326 }
327 } else {
328 SemanticIndexStatus::NotIndexed
329 }
330 }
331
    /// Opens (or creates) the vector database at `database_path` and builds
    /// the `SemanticIndex` model, spawning its long-lived background workers:
    /// one task that writes embedded files into the database, and one parser
    /// task per CPU that turns pending files into embeddable spans.
    pub async fn new(
        fs: Arc<dyn Fs>,
        database_path: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
        mut cx: AsyncAppContext,
    ) -> Result<ModelHandle<Self>> {
        let t0 = Instant::now();
        let database_path = Arc::from(database_path);
        let db = VectorDatabase::new(fs.clone(), database_path, cx.background()).await?;

        log::trace!(
            "db initialization took {:?} milliseconds",
            t0.elapsed().as_millis()
        );

        Ok(cx.add_model(|cx| {
            let t0 = Instant::now();
            // The queue starts without credentials; `authenticate` supplies
            // them later.
            let embedding_queue =
                EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), ProviderCredential::NoCredentials);
            // Persist each fully-embedded file to the vector database.
            let _embedding_task = cx.background().spawn({
                let embedded_files = embedding_queue.finished_files();
                let db = db.clone();
                async move {
                    while let Ok(file) = embedded_files.recv().await {
                        db.insert_file(file.worktree_id, file.path, file.mtime, file.spans)
                            .await
                            .log_err();
                    }
                }
            });

            // Parse files into embeddable spans.
            let (parsing_files_tx, parsing_files_rx) =
                channel::unbounded::<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>();
            let embedding_queue = Arc::new(Mutex::new(embedding_queue));
            let mut _parsing_files_tasks = Vec::new();
            for _ in 0..cx.background().num_cpus() {
                let fs = fs.clone();
                let mut parsing_files_rx = parsing_files_rx.clone();
                let embedding_provider = embedding_provider.clone();
                let embedding_queue = embedding_queue.clone();
                let background = cx.background().clone();
                _parsing_files_tasks.push(cx.background().spawn(async move {
                    let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
                    loop {
                        // Either parse the next pending file, or — if none
                        // arrives before the timeout — flush the embedding
                        // queue so partial batches still make progress.
                        let mut timer = background.timer(EMBEDDING_QUEUE_FLUSH_TIMEOUT).fuse();
                        let mut next_file_to_parse = parsing_files_rx.next().fuse();
                        futures::select_biased! {
                            next_file_to_parse = next_file_to_parse => {
                                if let Some((embeddings_for_digest, pending_file)) = next_file_to_parse {
                                    Self::parse_file(
                                        &fs,
                                        pending_file,
                                        &mut retriever,
                                        &embedding_queue,
                                        &embeddings_for_digest,
                                    )
                                    .await
                                } else {
                                    // Channel closed: the index was dropped.
                                    break;
                                }
                            },
                            _ = timer => {
                                embedding_queue.lock().flush();
                            }
                        }
                    }
                }));
            }

            log::trace!(
                "semantic index task initialization took {:?} milliseconds",
                t0.elapsed().as_millis()
            );
            Self {
                fs,
                db,
                embedding_provider,
                language_registry,
                parsing_files_tx,
                _embedding_task,
                _parsing_files_tasks,
                projects: Default::default(),
                provider_credential: ProviderCredential::NoCredentials,
                embedding_queue
            }
        }))
    }
421
422 async fn parse_file(
423 fs: &Arc<dyn Fs>,
424 pending_file: PendingFile,
425 retriever: &mut CodeContextRetriever,
426 embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
427 embeddings_for_digest: &HashMap<SpanDigest, Embedding>,
428 ) {
429 let Some(language) = pending_file.language else {
430 return;
431 };
432
433 if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
434 if let Some(mut spans) = retriever
435 .parse_file_with_template(Some(&pending_file.relative_path), &content, language)
436 .log_err()
437 {
438 log::trace!(
439 "parsed path {:?}: {} spans",
440 pending_file.relative_path,
441 spans.len()
442 );
443
444 for span in &mut spans {
445 if let Some(embedding) = embeddings_for_digest.get(&span.digest) {
446 span.embedding = Some(embedding.to_owned());
447 }
448 }
449
450 embedding_queue.lock().push(FileToEmbed {
451 worktree_id: pending_file.worktree_db_id,
452 path: pending_file.relative_path,
453 mtime: pending_file.modified_time,
454 job_handle: pending_file.job_handle,
455 spans,
456 });
457 }
458 }
459 }
460
    /// Returns whether every worktree of `project` already has an entry in the
    /// vector database from a previous indexing run.
    pub fn project_previously_indexed(
        &mut self,
        project: &ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<bool>> {
        let worktrees_indexed_previously = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                self.db
                    .worktree_previously_indexed(&worktree.read(cx).abs_path())
            })
            .collect::<Vec<_>>();
        cx.spawn(|_, _cx| async move {
            let worktree_indexed_previously =
                futures::future::join_all(worktrees_indexed_previously).await;

            // NOTE(review): database errors are filtered out before the `all`,
            // so a worktree whose lookup failed does not count against
            // "previously indexed" — confirm this leniency is intentional.
            Ok(worktree_indexed_previously
                .iter()
                .filter(|worktree| worktree.is_ok())
                .all(|v| v.as_ref().log_err().is_some_and(|v| v.to_owned())))
        })
    }
484
485 fn project_entries_changed(
486 &mut self,
487 project: ModelHandle<Project>,
488 worktree_id: WorktreeId,
489 changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
490 cx: &mut ModelContext<Self>,
491 ) {
492 let Some(worktree) = project.read(cx).worktree_for_id(worktree_id.clone(), cx) else {
493 return;
494 };
495 let project = project.downgrade();
496 let Some(project_state) = self.projects.get_mut(&project) else {
497 return;
498 };
499
500 let worktree = worktree.read(cx);
501 let worktree_state =
502 if let Some(worktree_state) = project_state.worktrees.get_mut(&worktree_id) {
503 worktree_state
504 } else {
505 return;
506 };
507 worktree_state.paths_changed(changes, worktree);
508 if let WorktreeState::Registered(_) = worktree_state {
509 cx.spawn_weak(|this, mut cx| async move {
510 cx.background().timer(BACKGROUND_INDEXING_DELAY).await;
511 if let Some((this, project)) = this.upgrade(&cx).zip(project.upgrade(&cx)) {
512 this.update(&mut cx, |this, cx| {
513 this.index_project(project, cx).detach_and_log_err(cx)
514 });
515 }
516 })
517 .detach();
518 }
519 }
520
    /// Registers `worktree` with the vector database: waits for the initial
    /// scan, diffs the worktree's files against the stored mtimes, records
    /// changed and deleted paths, and finally kicks off an index pass. While
    /// this runs the worktree is tracked as `Registering`, so concurrent
    /// change events are buffered rather than lost.
    fn register_worktree(
        &mut self,
        project: ModelHandle<Project>,
        worktree: ModelHandle<Worktree>,
        cx: &mut ModelContext<Self>,
    ) {
        let project = project.downgrade();
        let project_state = if let Some(project_state) = self.projects.get_mut(&project) {
            project_state
        } else {
            return;
        };
        // Only local worktrees can be indexed.
        let worktree = if let Some(worktree) = worktree.read(cx).as_local() {
            worktree
        } else {
            return;
        };
        let worktree_abs_path = worktree.abs_path().clone();
        let scan_complete = worktree.scan_complete();
        let worktree_id = worktree.id();
        let db = self.db.clone();
        let language_registry = self.language_registry.clone();
        let (mut done_tx, done_rx) = watch::channel();
        let registration = cx.spawn(|this, mut cx| {
            async move {
                let register = async {
                    scan_complete.await;
                    let db_id = db.find_or_create_worktree(worktree_abs_path).await?;
                    let mut file_mtimes = db.get_file_mtimes(db_id).await?;
                    let worktree = if let Some(project) = project.upgrade(&cx) {
                        project
                            .read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx))
                            .ok_or_else(|| anyhow!("worktree not found"))?
                    } else {
                        return anyhow::Ok(());
                    };
                    let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot());
                    // Compute the changed-path set on the background pool.
                    let mut changed_paths = cx
                        .background()
                        .spawn(async move {
                            let mut changed_paths = BTreeMap::new();
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path);

                                if file.is_external || file.is_ignored || file.is_symlink {
                                    continue;
                                }

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    // Test if file is valid parseable file
                                    if !PARSEABLE_ENTIRE_FILE_TYPES
                                        .contains(&language.name().as_ref())
                                        && &language.name().as_ref() != &"Markdown"
                                        && language
                                            .grammar()
                                            .and_then(|grammar| grammar.embedding_config.as_ref())
                                            .is_none()
                                    {
                                        continue;
                                    }

                                    // A file needs (re)indexing if its mtime no
                                    // longer matches what the database stored.
                                    let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
                                    let already_stored = stored_mtime
                                        .map_or(false, |existing_mtime| {
                                            existing_mtime == file.mtime
                                        });

                                    if !already_stored {
                                        changed_paths.insert(
                                            file.path.clone(),
                                            ChangedPathInfo {
                                                mtime: file.mtime,
                                                is_deleted: false,
                                            },
                                        );
                                    }
                                }
                            }

                            // Clean up entries from database that are no longer in the worktree.
                            for (path, mtime) in file_mtimes {
                                changed_paths.insert(
                                    path.into(),
                                    ChangedPathInfo {
                                        mtime,
                                        is_deleted: true,
                                    },
                                );
                            }

                            anyhow::Ok(changed_paths)
                        })
                        .await?;
                    this.update(&mut cx, |this, cx| {
                        let project_state = this
                            .projects
                            .get_mut(&project)
                            .ok_or_else(|| anyhow!("project not registered"))?;
                        let project = project
                            .upgrade(cx)
                            .ok_or_else(|| anyhow!("project was dropped"))?;

                        // Fold in any changes that arrived while registering.
                        if let Some(WorktreeState::Registering(state)) =
                            project_state.worktrees.remove(&worktree_id)
                        {
                            changed_paths.extend(state.changed_paths);
                        }
                        project_state.worktrees.insert(
                            worktree_id,
                            WorktreeState::Registered(RegisteredWorktreeState {
                                db_id,
                                changed_paths,
                            }),
                        );
                        this.index_project(project, cx).detach_and_log_err(cx);

                        anyhow::Ok(())
                    })?;

                    anyhow::Ok(())
                };

                if register.await.log_err().is_none() {
                    // Stop tracking this worktree if the registration failed.
                    this.update(&mut cx, |this, _| {
                        this.projects.get_mut(&project).map(|project_state| {
                            project_state.worktrees.remove(&worktree_id);
                        });
                    })
                }

                // Wake `done()` waiters regardless of success or failure.
                *done_tx.borrow_mut() = Some(());
            }
        });
        project_state.worktrees.insert(
            worktree_id,
            WorktreeState::Registering(RegisteringWorktreeState {
                changed_paths: Default::default(),
                done_rx,
                _registration: registration,
            }),
        );
    }
667
668 fn project_worktrees_changed(
669 &mut self,
670 project: ModelHandle<Project>,
671 cx: &mut ModelContext<Self>,
672 ) {
673 let project_state = if let Some(project_state) = self.projects.get_mut(&project.downgrade())
674 {
675 project_state
676 } else {
677 return;
678 };
679
680 let mut worktrees = project
681 .read(cx)
682 .worktrees(cx)
683 .filter(|worktree| worktree.read(cx).is_local())
684 .collect::<Vec<_>>();
685 let worktree_ids = worktrees
686 .iter()
687 .map(|worktree| worktree.read(cx).id())
688 .collect::<HashSet<_>>();
689
690 // Remove worktrees that are no longer present
691 project_state
692 .worktrees
693 .retain(|worktree_id, _| worktree_ids.contains(worktree_id));
694
695 // Register new worktrees
696 worktrees.retain(|worktree| {
697 let worktree_id = worktree.read(cx).id();
698 !project_state.worktrees.contains_key(&worktree_id)
699 });
700 for worktree in worktrees {
701 self.register_worktree(project.clone(), worktree, cx);
702 }
703 }
704
705 pub fn pending_file_count(
706 &self,
707 project: &ModelHandle<Project>,
708 ) -> Option<watch::Receiver<usize>> {
709 Some(
710 self.projects
711 .get(&project.downgrade())?
712 .pending_file_count_rx
713 .clone(),
714 )
715 }
716
    /// Runs a semantic search over `project`: (re)indexes it, embeds the
    /// query string, searches both unsaved (dirty) buffers and the on-disk
    /// index concurrently, and merges the results preferring dirty-buffer
    /// hits.
    pub fn search_project(
        &mut self,
        project: ModelHandle<Project>,
        query: String,
        limit: usize,
        includes: Vec<PathMatcher>,
        excludes: Vec<PathMatcher>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        if query.is_empty() {
            return Task::ready(Ok(Vec::new()));
        }

        let index = self.index_project(project.clone(), cx);
        let embedding_provider = self.embedding_provider.clone();
        let credential = self.provider_credential.clone();

        cx.spawn(|this, mut cx| async move {
            // Make sure the index is up to date before searching it.
            index.await?;
            let t0 = Instant::now();

            let query = embedding_provider
                .embed_batch(vec![query], credential)
                .await?
                .pop()
                .ok_or_else(|| anyhow!("could not embed query"))?;
            log::trace!("Embedding Search Query: {:?}ms", t0.elapsed().as_millis());

            let search_start = Instant::now();
            // Kick off both searches, then await them concurrently.
            let modified_buffer_results = this.update(&mut cx, |this, cx| {
                this.search_modified_buffers(
                    &project,
                    query.clone(),
                    limit,
                    &includes,
                    &excludes,
                    cx,
                )
            });
            let file_results = this.update(&mut cx, |this, cx| {
                this.search_files(project, query, limit, includes, excludes, cx)
            });
            let (modified_buffer_results, file_results) =
                futures::join!(modified_buffer_results, file_results);

            // Weave together the results from modified buffers and files.
            let mut results = Vec::new();
            let mut modified_buffers = HashSet::default();
            for result in modified_buffer_results.log_err().unwrap_or_default() {
                modified_buffers.insert(result.buffer.clone());
                results.push(result);
            }
            for result in file_results.log_err().unwrap_or_default() {
                // A dirty buffer's live results supersede the on-disk results
                // for the same buffer.
                if !modified_buffers.contains(&result.buffer) {
                    results.push(result);
                }
            }
            results.sort_by_key(|result| Reverse(result.similarity));
            results.truncate(limit);
            log::trace!("Semantic search took {:?}", search_start.elapsed());
            Ok(results)
        })
    }
780
781 pub fn search_files(
782 &mut self,
783 project: ModelHandle<Project>,
784 query: Embedding,
785 limit: usize,
786 includes: Vec<PathMatcher>,
787 excludes: Vec<PathMatcher>,
788 cx: &mut ModelContext<Self>,
789 ) -> Task<Result<Vec<SearchResult>>> {
790 let db_path = self.db.path().clone();
791 let fs = self.fs.clone();
792 cx.spawn(|this, mut cx| async move {
793 let database =
794 VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?;
795
796 let worktree_db_ids = this.read_with(&cx, |this, _| {
797 let project_state = this
798 .projects
799 .get(&project.downgrade())
800 .ok_or_else(|| anyhow!("project was not indexed"))?;
801 let worktree_db_ids = project_state
802 .worktrees
803 .values()
804 .filter_map(|worktree| {
805 if let WorktreeState::Registered(worktree) = worktree {
806 Some(worktree.db_id)
807 } else {
808 None
809 }
810 })
811 .collect::<Vec<i64>>();
812 anyhow::Ok(worktree_db_ids)
813 })?;
814
815 let file_ids = database
816 .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
817 .await?;
818
819 let batch_n = cx.background().num_cpus();
820 let ids_len = file_ids.clone().len();
821 let minimum_batch_size = 50;
822
823 let batch_size = {
824 let size = ids_len / batch_n;
825 if size < minimum_batch_size {
826 minimum_batch_size
827 } else {
828 size
829 }
830 };
831
832 let mut batch_results = Vec::new();
833 for batch in file_ids.chunks(batch_size) {
834 let batch = batch.into_iter().map(|v| *v).collect::<Vec<i64>>();
835 let limit = limit.clone();
836 let fs = fs.clone();
837 let db_path = db_path.clone();
838 let query = query.clone();
839 if let Some(db) = VectorDatabase::new(fs, db_path.clone(), cx.background())
840 .await
841 .log_err()
842 {
843 batch_results.push(async move {
844 db.top_k_search(&query, limit, batch.as_slice()).await
845 });
846 }
847 }
848
849 let batch_results = futures::future::join_all(batch_results).await;
850
851 let mut results = Vec::new();
852 for batch_result in batch_results {
853 if batch_result.is_ok() {
854 for (id, similarity) in batch_result.unwrap() {
855 let ix = match results
856 .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
857 {
858 Ok(ix) => ix,
859 Err(ix) => ix,
860 };
861
862 results.insert(ix, (id, similarity));
863 results.truncate(limit);
864 }
865 }
866 }
867
868 let ids = results.iter().map(|(id, _)| *id).collect::<Vec<i64>>();
869 let scores = results
870 .into_iter()
871 .map(|(_, score)| score)
872 .collect::<Vec<_>>();
873 let spans = database.spans_for_ids(ids.as_slice()).await?;
874
875 let mut tasks = Vec::new();
876 let mut ranges = Vec::new();
877 let weak_project = project.downgrade();
878 project.update(&mut cx, |project, cx| {
879 for (worktree_db_id, file_path, byte_range) in spans {
880 let project_state =
881 if let Some(state) = this.read(cx).projects.get(&weak_project) {
882 state
883 } else {
884 return Err(anyhow!("project not added"));
885 };
886 if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
887 tasks.push(project.open_buffer((worktree_id, file_path), cx));
888 ranges.push(byte_range);
889 }
890 }
891
892 Ok(())
893 })?;
894
895 let buffers = futures::future::join_all(tasks).await;
896 Ok(buffers
897 .into_iter()
898 .zip(ranges)
899 .zip(scores)
900 .filter_map(|((buffer, range), similarity)| {
901 let buffer = buffer.log_err()?;
902 let range = buffer.read_with(&cx, |buffer, _| {
903 let start = buffer.clip_offset(range.start, Bias::Left);
904 let end = buffer.clip_offset(range.end, Bias::Right);
905 buffer.anchor_before(start)..buffer.anchor_after(end)
906 });
907 Some(SearchResult {
908 buffer,
909 range,
910 similarity,
911 })
912 })
913 .collect())
914 })
915 }
916
917 fn search_modified_buffers(
918 &self,
919 project: &ModelHandle<Project>,
920 query: Embedding,
921 limit: usize,
922 includes: &[PathMatcher],
923 excludes: &[PathMatcher],
924 cx: &mut ModelContext<Self>,
925 ) -> Task<Result<Vec<SearchResult>>> {
926 let modified_buffers = project
927 .read(cx)
928 .opened_buffers(cx)
929 .into_iter()
930 .filter_map(|buffer_handle| {
931 let buffer = buffer_handle.read(cx);
932 let snapshot = buffer.snapshot();
933 let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
934 excludes.iter().any(|matcher| matcher.is_match(&path))
935 });
936
937 let included = if includes.len() == 0 {
938 true
939 } else {
940 snapshot.resolve_file_path(cx, false).map_or(false, |path| {
941 includes.iter().any(|matcher| matcher.is_match(&path))
942 })
943 };
944
945 if buffer.is_dirty() && !excluded && included {
946 Some((buffer_handle, snapshot))
947 } else {
948 None
949 }
950 })
951 .collect::<HashMap<_, _>>();
952
953 let embedding_provider = self.embedding_provider.clone();
954 let fs = self.fs.clone();
955 let db_path = self.db.path().clone();
956 let background = cx.background().clone();
957 let credential = self.provider_credential.clone();
958 cx.background().spawn(async move {
959 let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
960 let mut results = Vec::<SearchResult>::new();
961
962 let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
963 for (buffer, snapshot) in modified_buffers {
964 let language = snapshot
965 .language_at(0)
966 .cloned()
967 .unwrap_or_else(|| language::PLAIN_TEXT.clone());
968 let mut spans = retriever
969 .parse_file_with_template(None, &snapshot.text(), language)
970 .log_err()
971 .unwrap_or_default();
972 if Self::embed_spans(
973 &mut spans,
974 embedding_provider.as_ref(),
975 &db,
976 credential.clone(),
977 )
978 .await
979 .log_err()
980 .is_some()
981 {
982 for span in spans {
983 let similarity = span.embedding.unwrap().similarity(&query);
984 let ix = match results
985 .binary_search_by_key(&Reverse(similarity), |result| {
986 Reverse(result.similarity)
987 }) {
988 Ok(ix) => ix,
989 Err(ix) => ix,
990 };
991
992 let range = {
993 let start = snapshot.clip_offset(span.range.start, Bias::Left);
994 let end = snapshot.clip_offset(span.range.end, Bias::Right);
995 snapshot.anchor_before(start)..snapshot.anchor_after(end)
996 };
997
998 results.insert(
999 ix,
1000 SearchResult {
1001 buffer: buffer.clone(),
1002 range,
1003 similarity,
1004 },
1005 );
1006 results.truncate(limit);
1007 }
1008 }
1009 }
1010
1011 Ok(results)
1012 })
1013 }
1014
    /// Ensures `project` is registered and runs a full index pass: waits for
    /// all worktrees to finish registering, collects changed/deleted paths,
    /// deletes removed files from the database, feeds changed files to the
    /// parser tasks, and resolves once the pending-file count drains to zero.
    pub fn index_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        if !self.is_authenticated() {
            if !self.authenticate(cx) {
                return Task::ready(Err(anyhow!("user is not authenticated")));
            }
        }

        // First pass for this project: subscribe to worktree events and
        // register its current worktrees.
        if !self.projects.contains_key(&project.downgrade()) {
            let subscription = cx.subscribe(&project, |this, project, event, cx| match event {
                project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
                    this.project_worktrees_changed(project.clone(), cx);
                }
                project::Event::WorktreeUpdatedEntries(worktree_id, changes) => {
                    this.project_entries_changed(project, *worktree_id, changes.clone(), cx);
                }
                _ => {}
            });
            let project_state = ProjectState::new(subscription, cx);
            self.projects.insert(project.downgrade(), project_state);
            self.project_worktrees_changed(project.clone(), cx);
        }
        let project_state = self.projects.get_mut(&project.downgrade()).unwrap();
        project_state.pending_index += 1;
        cx.notify();

        let mut pending_file_count_rx = project_state.pending_file_count_rx.clone();
        let db = self.db.clone();
        let language_registry = self.language_registry.clone();
        let parsing_files_tx = self.parsing_files_tx.clone();
        let worktree_registration = self.wait_for_worktree_registration(&project, cx);

        cx.spawn(|this, mut cx| async move {
            worktree_registration.await?;

            // Drain each registered worktree's changed-path set into concrete
            // work items, dropping worktrees that no longer exist.
            let mut pending_files = Vec::new();
            let mut files_to_delete = Vec::new();
            this.update(&mut cx, |this, cx| {
                let project_state = this
                    .projects
                    .get_mut(&project.downgrade())
                    .ok_or_else(|| anyhow!("project was dropped"))?;
                let pending_file_count_tx = &project_state.pending_file_count_tx;

                project_state
                    .worktrees
                    .retain(|worktree_id, worktree_state| {
                        let worktree = if let Some(worktree) =
                            project.read(cx).worktree_for_id(*worktree_id, cx)
                        {
                            worktree
                        } else {
                            return false;
                        };
                        let worktree_state =
                            if let WorktreeState::Registered(worktree_state) = worktree_state {
                                worktree_state
                            } else {
                                return true;
                            };

                        worktree_state.changed_paths.retain(|path, info| {
                            if info.is_deleted {
                                files_to_delete.push((worktree_state.db_id, path.clone()));
                            } else {
                                let absolute_path = worktree.read(cx).absolutize(path);
                                // Creating a JobHandle bumps the pending-file
                                // counter that this pass waits on below.
                                let job_handle = JobHandle::new(pending_file_count_tx);
                                pending_files.push(PendingFile {
                                    absolute_path,
                                    relative_path: path.clone(),
                                    language: None,
                                    job_handle,
                                    modified_time: info.mtime,
                                    worktree_db_id: worktree_state.db_id,
                                });
                            }

                            false
                        });
                        true
                    });

                anyhow::Ok(())
            })?;

            cx.background()
                .spawn(async move {
                    for (worktree_db_id, path) in files_to_delete {
                        db.delete_file(worktree_db_id, path).await.log_err();
                    }

                    // Look up embeddings already stored for these files so
                    // unchanged spans need not be re-embedded.
                    let embeddings_for_digest = {
                        let mut files = HashMap::default();
                        for pending_file in &pending_files {
                            files
                                .entry(pending_file.worktree_db_id)
                                .or_insert(Vec::new())
                                .push(pending_file.relative_path.clone());
                        }
                        Arc::new(
                            db.embeddings_for_files(files)
                                .await
                                .log_err()
                                .unwrap_or_default(),
                        )
                    };

                    for mut pending_file in pending_files {
                        if let Ok(language) = language_registry
                            .language_for_file(&pending_file.relative_path, None)
                            .await
                        {
                            // Skip languages we can neither parse wholesale
                            // nor split via an embedding grammar config.
                            if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
                                && &language.name().as_ref() != &"Markdown"
                                && language
                                    .grammar()
                                    .and_then(|grammar| grammar.embedding_config.as_ref())
                                    .is_none()
                            {
                                continue;
                            }
                            pending_file.language = Some(language);
                        }
                        parsing_files_tx
                            .try_send((embeddings_for_digest.clone(), pending_file))
                            .ok();
                    }

                    // Wait until we're done indexing.
                    while let Some(count) = pending_file_count_rx.next().await {
                        if count == 0 {
                            break;
                        }
                    }
                })
                .await;

            this.update(&mut cx, |this, cx| {
                let project_state = this
                    .projects
                    .get_mut(&project.downgrade())
                    .ok_or_else(|| anyhow!("project was dropped"))?;
                project_state.pending_index -= 1;
                cx.notify();
                anyhow::Ok(())
            })?;

            Ok(())
        })
    }
1168
1169 fn wait_for_worktree_registration(
1170 &self,
1171 project: &ModelHandle<Project>,
1172 cx: &mut ModelContext<Self>,
1173 ) -> Task<Result<()>> {
1174 let project = project.downgrade();
1175 cx.spawn_weak(|this, cx| async move {
1176 loop {
1177 let mut pending_worktrees = Vec::new();
1178 this.upgrade(&cx)
1179 .ok_or_else(|| anyhow!("semantic index dropped"))?
1180 .read_with(&cx, |this, _| {
1181 if let Some(project) = this.projects.get(&project) {
1182 for worktree in project.worktrees.values() {
1183 if let WorktreeState::Registering(worktree) = worktree {
1184 pending_worktrees.push(worktree.done());
1185 }
1186 }
1187 }
1188 });
1189
1190 if pending_worktrees.is_empty() {
1191 break;
1192 } else {
1193 future::join_all(pending_worktrees).await;
1194 }
1195 }
1196 Ok(())
1197 })
1198 }
1199
1200 async fn embed_spans(
1201 spans: &mut [Span],
1202 embedding_provider: &dyn EmbeddingProvider,
1203 db: &VectorDatabase,
1204 credential: ProviderCredential,
1205 ) -> Result<()> {
1206 let mut batch = Vec::new();
1207 let mut batch_tokens = 0;
1208 let mut embeddings = Vec::new();
1209
1210 let digests = spans
1211 .iter()
1212 .map(|span| span.digest.clone())
1213 .collect::<Vec<_>>();
1214 let embeddings_for_digests = db
1215 .embeddings_for_digests(digests)
1216 .await
1217 .log_err()
1218 .unwrap_or_default();
1219
1220 for span in &*spans {
1221 if embeddings_for_digests.contains_key(&span.digest) {
1222 continue;
1223 };
1224
1225 if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
1226 let batch_embeddings = embedding_provider
1227 .embed_batch(mem::take(&mut batch), credential.clone())
1228 .await?;
1229 embeddings.extend(batch_embeddings);
1230 batch_tokens = 0;
1231 }
1232
1233 batch_tokens += span.token_count;
1234 batch.push(span.content.clone());
1235 }
1236
1237 if !batch.is_empty() {
1238 let batch_embeddings = embedding_provider
1239 .embed_batch(mem::take(&mut batch), credential)
1240 .await?;
1241
1242 embeddings.extend(batch_embeddings);
1243 }
1244
1245 let mut embeddings = embeddings.into_iter();
1246 for span in spans {
1247 let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) {
1248 Some(embedding.clone())
1249 } else {
1250 embeddings.next()
1251 };
1252 let embedding = embedding.ok_or_else(|| anyhow!("failed to embed spans"))?;
1253 span.embedding = Some(embedding);
1254 }
1255 Ok(())
1256 }
1257}
1258
// `SemanticIndex` emits no events; implementing `Entity` lets it be owned and
// updated as a GPUI model (`ModelHandle<SemanticIndex>`).
impl Entity for SemanticIndex {
    type Event = ();
}
1262
impl Drop for JobHandle {
    // Decrements the pending-file counter exactly once per logical job, no
    // matter how many times the handle was cloned: clones share the same
    // `Arc`, and `Arc::get_mut` succeeds only when this is the sole remaining
    // strong reference.
    fn drop(&mut self) {
        if let Some(inner) = Arc::get_mut(&mut self.tx) {
            // This is the last instance of the JobHandle (regardless of its origin - whether it was cloned or not)
            if let Some(tx) = inner.upgrade() {
                // The counter channel may already be gone (e.g. the project
                // state was dropped); in that case there is nothing to update.
                let mut tx = tx.lock();
                *tx.borrow_mut() -= 1;
            }
        }
    }
}
1274
#[cfg(test)]
mod tests {

    use super::*;

    /// Verifies that the pending-job counter is incremented once per job and
    /// decremented only when the final clone of a `JobHandle` is dropped.
    #[test]
    fn test_job_handle() {
        let (count_tx, count_rx) = watch::channel_with(0);
        let count_tx = Arc::new(Mutex::new(count_tx));

        let first_handle = JobHandle::new(&count_tx);
        assert_eq!(1, *count_rx.borrow());

        // Cloning must not bump the counter...
        let second_handle = first_handle.clone();
        assert_eq!(1, *count_rx.borrow());

        // ...and dropping one clone must not decrement it while another
        // clone is still alive.
        drop(first_handle);
        assert_eq!(1, *count_rx.borrow());

        // Dropping the last remaining handle finally decrements the counter.
        drop(second_handle);
        assert_eq!(0, *count_rx.borrow());
    }
}