1mod db;
2mod embedding_queue;
3mod parsing;
4pub mod semantic_index_settings;
5
6#[cfg(test)]
7mod semantic_index_tests;
8
9use crate::semantic_index_settings::SemanticIndexSettings;
10use ai::embedding::{Embedding, EmbeddingProvider};
11use ai::providers::open_ai::OpenAIEmbeddingProvider;
12use anyhow::{anyhow, Result};
13use collections::{BTreeMap, HashMap, HashSet};
14use db::VectorDatabase;
15use embedding_queue::{EmbeddingQueue, FileToEmbed};
16use futures::{future, FutureExt, StreamExt};
17use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
18use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
19use lazy_static::lazy_static;
20use ordered_float::OrderedFloat;
21use parking_lot::Mutex;
22use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
23use postage::watch;
24use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
25use smol::channel;
26use std::{
27 cmp::Reverse,
28 env,
29 future::Future,
30 mem,
31 ops::Range,
32 path::{Path, PathBuf},
33 sync::{Arc, Weak},
34 time::{Duration, Instant, SystemTime},
35};
36use util::{channel::RELEASE_CHANNEL_NAME, http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt};
37use workspace::WorkspaceCreated;
38
// Version stamp for the index; not referenced in this file — presumably used
// elsewhere to invalidate stale on-disk databases. TODO confirm.
const SEMANTIC_INDEX_VERSION: usize = 11;
// Debounce delay before re-indexing a project after background file changes.
const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
// How long a parsing worker waits for the next file before flushing the
// embedding queue, so partially-filled batches still get embedded.
const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);

lazy_static! {
    // Read once at startup; `None` when the variable is unset. Not referenced
    // in this file.
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
}
46
/// Registers semantic-index settings, installs the global [`SemanticIndex`]
/// (backed by an on-disk vector database and the OpenAI embedding provider),
/// and re-indexes previously indexed local projects when a workspace opens.
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    settings::register::<SemanticIndexSettings>(cx);

    // Keep databases for different release channels separate.
    let db_file_path = EMBEDDINGS_DIR
        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
        .join("embeddings_db");

    cx.subscribe_global::<WorkspaceCreated, _>({
        move |event, cx| {
            let Some(semantic_index) = SemanticIndex::global(cx) else {
                return;
            };
            let workspace = &event.0;
            if let Some(workspace) = workspace.upgrade(cx) {
                let project = workspace.read(cx).project().clone();
                if project.read(cx).is_local() {
                    cx.spawn(|mut cx| async move {
                        // Only auto-index projects that were indexed in a
                        // previous session.
                        let previously_indexed = semantic_index
                            .update(&mut cx, |index, cx| {
                                index.project_previously_indexed(&project, cx)
                            })
                            .await?;
                        if previously_indexed {
                            semantic_index
                                .update(&mut cx, |index, cx| index.index_project(project, cx))
                                .await?;
                        }
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                }
            }
        }
    })
    .detach();

    cx.spawn(move |mut cx| async move {
        let semantic_index = SemanticIndex::new(
            fs,
            db_file_path,
            Arc::new(OpenAIEmbeddingProvider::new(http_client, cx.background())),
            language_registry,
            cx.clone(),
        )
        .await?;

        // Expose the index app-wide so `SemanticIndex::global` can find it.
        cx.update(|cx| {
            cx.set_global(semantic_index.clone());
        });

        anyhow::Ok(())
    })
    .detach();
}
106
/// Coarse indexing state of a project, as reported by [`SemanticIndex::status`].
#[derive(Copy, Clone, Debug)]
pub enum SemanticIndexStatus {
    // The embedding provider has no credentials, so indexing cannot run.
    NotAuthenticated,
    // The project is unknown to the index.
    NotIndexed,
    // All worktrees are registered and no index pass is in flight.
    Indexed,
    // An index pass is currently running.
    Indexing {
        // Files still waiting to be embedded.
        remaining_files: usize,
        // When the provider's rate limit lifts, if it is currently limited.
        rate_limit_expiry: Option<Instant>,
    },
}
117
/// App-global coordinator that parses project files into spans, embeds them,
/// stores the embeddings in a vector database, and serves similarity searches.
pub struct SemanticIndex {
    fs: Arc<dyn Fs>,
    db: VectorDatabase,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    // Feeds pending files — paired with embeddings already known for their
    // span digests — to the background parsing workers.
    parsing_files_tx: channel::Sender<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>,
    // Persists embedded files to the database as the queue finishes them.
    _embedding_task: Task<()>,
    // One parsing worker per CPU; held so the tasks keep running.
    _parsing_files_tasks: Vec<Task<()>>,
    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
}
128
/// Per-project indexing bookkeeping.
struct ProjectState {
    worktrees: HashMap<WorktreeId, WorktreeState>,
    // Count of files queued for embedding; reaches zero when indexing is done.
    pending_file_count_rx: watch::Receiver<usize>,
    pending_file_count_tx: Arc<Mutex<watch::Sender<usize>>>,
    // Number of `index_project` passes currently in flight.
    pending_index: usize,
    _subscription: gpui::Subscription,
    // Notifies observers whenever the pending file count changes.
    _observe_pending_file_count: Task<()>,
}
137
/// A worktree starts out `Registering` (scanning and reconciling with the
/// database) and becomes `Registered` once it has a database id.
enum WorktreeState {
    Registering(RegisteringWorktreeState),
    Registered(RegisteredWorktreeState),
}
142
impl WorktreeState {
    fn is_registered(&self) -> bool {
        matches!(self, Self::Registered(_))
    }

    /// Records filesystem changes so a later index pass can process them.
    /// Ignored, symlinked, external, and directory entries are skipped.
    fn paths_changed(
        &mut self,
        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
        worktree: &Worktree,
    ) {
        // Both states accumulate changes; a registering worktree's changes
        // are merged into the registered state when registration completes.
        let changed_paths = match self {
            Self::Registering(state) => &mut state.changed_paths,
            Self::Registered(state) => &mut state.changed_paths,
        };

        for (path, entry_id, change) in changes.iter() {
            let Some(entry) = worktree.entry_for_id(*entry_id) else {
                continue;
            };
            if entry.is_ignored || entry.is_symlink || entry.is_external || entry.is_dir() {
                continue;
            }
            changed_paths.insert(
                path.clone(),
                ChangedPathInfo {
                    mtime: entry.mtime,
                    is_deleted: *change == PathChange::Removed,
                },
            );
        }
    }
}
175
/// State for a worktree whose initial scan/reconciliation is still running.
struct RegisteringWorktreeState {
    // Changes reported while registering; merged into the registered state
    // once registration completes.
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
    // Receives `Some(())` when registration finishes.
    done_rx: watch::Receiver<Option<()>>,
    // Keeps the registration task alive.
    _registration: Task<()>,
}
181
impl RegisteringWorktreeState {
    /// Resolves once registration finishes (the sender writes `Some(())`),
    /// or when the registration task is dropped and the channel closes.
    fn done(&self) -> impl Future<Output = ()> {
        let mut done_rx = self.done_rx.clone();
        async move {
            while let Some(result) = done_rx.next().await {
                if result.is_some() {
                    break;
                }
            }
        }
    }
}
194
/// State for a worktree that has a row in the vector database.
struct RegisteredWorktreeState {
    db_id: i64,
    // Paths changed since the last index pass; drained by `index_project`.
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
}
199
/// What we know about a changed path: its mtime and whether it was removed.
struct ChangedPathInfo {
    mtime: SystemTime,
    is_deleted: bool,
}
204
/// Tracks one in-flight indexing job; cloned as the job moves between queues.
#[derive(Clone)]
pub struct JobHandle {
    /// The outer Arc is here to count the clones of a JobHandle instance;
    /// when the last handle to a given job is dropped, we decrement a counter (just once).
    /// The inner Weak avoids keeping a dropped project's counter alive.
    tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
}
211
212impl JobHandle {
213 fn new(tx: &Arc<Mutex<watch::Sender<usize>>>) -> Self {
214 *tx.lock().borrow_mut() += 1;
215 Self {
216 tx: Arc::new(Arc::downgrade(&tx)),
217 }
218 }
219}
220
impl ProjectState {
    fn new(subscription: gpui::Subscription, cx: &mut ModelContext<SemanticIndex>) -> Self {
        let (pending_file_count_tx, pending_file_count_rx) = watch::channel_with(0);
        let pending_file_count_tx = Arc::new(Mutex::new(pending_file_count_tx));
        Self {
            worktrees: Default::default(),
            pending_file_count_rx: pending_file_count_rx.clone(),
            pending_file_count_tx,
            pending_index: 0,
            _subscription: subscription,
            // Re-notify the SemanticIndex model whenever the pending file
            // count changes so observers (e.g. status UI) can refresh.
            _observe_pending_file_count: cx.spawn_weak({
                let mut pending_file_count_rx = pending_file_count_rx.clone();
                |this, mut cx| async move {
                    while let Some(_) = pending_file_count_rx.next().await {
                        if let Some(this) = this.upgrade(&cx) {
                            this.update(&mut cx, |_, cx| cx.notify());
                        }
                    }
                }
            }),
        }
    }

    /// Maps a database worktree id back to the in-app `WorktreeId`, if the
    /// corresponding worktree has finished registering.
    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
        self.worktrees
            .iter()
            .find_map(|(worktree_id, worktree_state)| match worktree_state {
                WorktreeState::Registered(state) if state.db_id == id => Some(*worktree_id),
                _ => None,
            })
    }
}
253
/// A file queued for parsing and embedding.
#[derive(Clone)]
pub struct PendingFile {
    worktree_db_id: i64,
    relative_path: Arc<Path>,
    absolute_path: PathBuf,
    // Filled in once the language is detected; files left `None` are skipped
    // by `parse_file`.
    language: Option<Arc<Language>>,
    modified_time: SystemTime,
    // Keeps the project's pending-file counter accurate while in flight.
    job_handle: JobHandle,
}
263
/// A single semantic-search hit: a span within a buffer and its similarity
/// to the query embedding.
#[derive(Clone)]
pub struct SearchResult {
    pub buffer: ModelHandle<Buffer>,
    pub range: Range<Anchor>,
    pub similarity: OrderedFloat<f32>,
}
270
271impl SemanticIndex {
272 pub fn global(cx: &AppContext) -> Option<ModelHandle<SemanticIndex>> {
273 if cx.has_global::<ModelHandle<Self>>() {
274 Some(cx.global::<ModelHandle<SemanticIndex>>().clone())
275 } else {
276 None
277 }
278 }
279
    /// Whether semantic indexing is enabled in the user's settings.
    pub fn enabled(cx: &AppContext) -> bool {
        settings::get::<SemanticIndexSettings>(cx).enabled
    }
283
    /// Reports the indexing status of `project` for display in the UI.
    pub fn status(&self, project: &ModelHandle<Project>) -> SemanticIndexStatus {
        if !self.embedding_provider.is_authenticated() {
            return SemanticIndexStatus::NotAuthenticated;
        }

        if let Some(project_state) = self.projects.get(&project.downgrade()) {
            // Indexed only when every worktree has finished registering and
            // no index pass is currently running.
            if project_state
                .worktrees
                .values()
                .all(|worktree| worktree.is_registered())
                && project_state.pending_index == 0
            {
                SemanticIndexStatus::Indexed
            } else {
                SemanticIndexStatus::Indexing {
                    remaining_files: project_state.pending_file_count_rx.borrow().clone(),
                    rate_limit_expiry: self.embedding_provider.rate_limit_expiration(),
                }
            }
        } else {
            SemanticIndexStatus::NotIndexed
        }
    }
307
    /// Opens (or creates) the vector database at `database_path` and spawns
    /// the background machinery: one task that persists embedded files, and
    /// one parsing worker per CPU that turns pending files into embeddable
    /// spans, flushing the embedding queue when idle.
    pub async fn new(
        fs: Arc<dyn Fs>,
        database_path: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
        mut cx: AsyncAppContext,
    ) -> Result<ModelHandle<Self>> {
        let t0 = Instant::now();
        let database_path = Arc::from(database_path);
        let db = VectorDatabase::new(fs.clone(), database_path, cx.background()).await?;

        log::trace!(
            "db initialization took {:?} milliseconds",
            t0.elapsed().as_millis()
        );

        Ok(cx.add_model(|cx| {
            let t0 = Instant::now();
            let embedding_queue =
                EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
            // Persist files to the database as their embeddings complete.
            let _embedding_task = cx.background().spawn({
                let embedded_files = embedding_queue.finished_files();
                let db = db.clone();
                async move {
                    while let Ok(file) = embedded_files.recv().await {
                        db.insert_file(file.worktree_id, file.path, file.mtime, file.spans)
                            .await
                            .log_err();
                    }
                }
            });

            // Parse files into embeddable spans.
            let (parsing_files_tx, parsing_files_rx) =
                channel::unbounded::<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>();
            let embedding_queue = Arc::new(Mutex::new(embedding_queue));
            let mut _parsing_files_tasks = Vec::new();
            for _ in 0..cx.background().num_cpus() {
                let fs = fs.clone();
                let mut parsing_files_rx = parsing_files_rx.clone();
                let embedding_provider = embedding_provider.clone();
                let embedding_queue = embedding_queue.clone();
                let background = cx.background().clone();
                _parsing_files_tasks.push(cx.background().spawn(async move {
                    let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
                    loop {
                        // Either parse the next file, or — if none arrives in
                        // time — flush the embedding queue so partially-filled
                        // batches still get embedded.
                        let mut timer = background.timer(EMBEDDING_QUEUE_FLUSH_TIMEOUT).fuse();
                        let mut next_file_to_parse = parsing_files_rx.next().fuse();
                        futures::select_biased! {
                            next_file_to_parse = next_file_to_parse => {
                                if let Some((embeddings_for_digest, pending_file)) = next_file_to_parse {
                                    Self::parse_file(
                                        &fs,
                                        pending_file,
                                        &mut retriever,
                                        &embedding_queue,
                                        &embeddings_for_digest,
                                    )
                                    .await
                                } else {
                                    // Channel closed: the index was dropped.
                                    break;
                                }
                            },
                            _ = timer => {
                                embedding_queue.lock().flush();
                            }
                        }
                    }
                }));
            }

            log::trace!(
                "semantic index task initialization took {:?} milliseconds",
                t0.elapsed().as_millis()
            );
            Self {
                fs,
                db,
                embedding_provider,
                language_registry,
                parsing_files_tx,
                _embedding_task,
                _parsing_files_tasks,
                projects: Default::default(),
            }
        }))
    }
395
    /// Parses one pending file into spans, reuses any embeddings already
    /// known for a span's digest, and pushes the result onto the embedding
    /// queue. Files with no detected language are skipped.
    async fn parse_file(
        fs: &Arc<dyn Fs>,
        pending_file: PendingFile,
        retriever: &mut CodeContextRetriever,
        embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
        embeddings_for_digest: &HashMap<SpanDigest, Embedding>,
    ) {
        let Some(language) = pending_file.language else {
            return;
        };

        if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
            if let Some(mut spans) = retriever
                .parse_file_with_template(Some(&pending_file.relative_path), &content, language)
                .log_err()
            {
                log::trace!(
                    "parsed path {:?}: {} spans",
                    pending_file.relative_path,
                    spans.len()
                );

                // Reuse cached embeddings so unchanged spans skip the provider.
                for span in &mut spans {
                    if let Some(embedding) = embeddings_for_digest.get(&span.digest) {
                        span.embedding = Some(embedding.to_owned());
                    }
                }

                embedding_queue.lock().push(FileToEmbed {
                    worktree_id: pending_file.worktree_db_id,
                    path: pending_file.relative_path,
                    mtime: pending_file.modified_time,
                    job_handle: pending_file.job_handle,
                    spans,
                });
            }
        }
    }
434
    /// Reports whether every worktree of `project` was indexed in a previous
    /// session. Note: database lookups that fail are filtered out (and
    /// logged) before the `all` check, so the answer is best-effort.
    pub fn project_previously_indexed(
        &mut self,
        project: &ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<bool>> {
        let worktrees_indexed_previously = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                self.db
                    .worktree_previously_indexed(&worktree.read(cx).abs_path())
            })
            .collect::<Vec<_>>();
        cx.spawn(|_, _cx| async move {
            let worktree_indexed_previously =
                futures::future::join_all(worktrees_indexed_previously).await;

            Ok(worktree_indexed_previously
                .iter()
                .filter(|worktree| worktree.is_ok())
                .all(|v| v.as_ref().log_err().is_some_and(|v| v.to_owned())))
        })
    }
458
459 fn project_entries_changed(
460 &mut self,
461 project: ModelHandle<Project>,
462 worktree_id: WorktreeId,
463 changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
464 cx: &mut ModelContext<Self>,
465 ) {
466 let Some(worktree) = project.read(cx).worktree_for_id(worktree_id.clone(), cx) else {
467 return;
468 };
469 let project = project.downgrade();
470 let Some(project_state) = self.projects.get_mut(&project) else {
471 return;
472 };
473
474 let worktree = worktree.read(cx);
475 let worktree_state =
476 if let Some(worktree_state) = project_state.worktrees.get_mut(&worktree_id) {
477 worktree_state
478 } else {
479 return;
480 };
481 worktree_state.paths_changed(changes, worktree);
482 if let WorktreeState::Registered(_) = worktree_state {
483 cx.spawn_weak(|this, mut cx| async move {
484 cx.background().timer(BACKGROUND_INDEXING_DELAY).await;
485 if let Some((this, project)) = this.upgrade(&cx).zip(project.upgrade(&cx)) {
486 this.update(&mut cx, |this, cx| {
487 this.index_project(project, cx).detach_and_log_err(cx)
488 });
489 }
490 })
491 .detach();
492 }
493 }
494
    /// Begins tracking a local worktree: waits for its scan to complete,
    /// finds or creates its row in the vector database, diffs the worktree's
    /// files against stored mtimes to seed `changed_paths`, then promotes the
    /// state to `Registered` and kicks off an index pass.
    fn register_worktree(
        &mut self,
        project: ModelHandle<Project>,
        worktree: ModelHandle<Worktree>,
        cx: &mut ModelContext<Self>,
    ) {
        let project = project.downgrade();
        let project_state = if let Some(project_state) = self.projects.get_mut(&project) {
            project_state
        } else {
            return;
        };
        // Only local worktrees can be indexed.
        let worktree = if let Some(worktree) = worktree.read(cx).as_local() {
            worktree
        } else {
            return;
        };
        let worktree_abs_path = worktree.abs_path().clone();
        let scan_complete = worktree.scan_complete();
        let worktree_id = worktree.id();
        let db = self.db.clone();
        let language_registry = self.language_registry.clone();
        let (mut done_tx, done_rx) = watch::channel();
        let registration = cx.spawn(|this, mut cx| {
            async move {
                let register = async {
                    scan_complete.await;
                    let db_id = db.find_or_create_worktree(worktree_abs_path).await?;
                    let mut file_mtimes = db.get_file_mtimes(db_id).await?;
                    let worktree = if let Some(project) = project.upgrade(&cx) {
                        project
                            .read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx))
                            .ok_or_else(|| anyhow!("worktree not found"))?
                    } else {
                        // Project dropped while scanning; nothing to register.
                        return anyhow::Ok(());
                    };
                    let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot());
                    // Diff the worktree snapshot against the database on a
                    // background thread: new/modified files become changed
                    // paths; database entries with no matching file become
                    // deletions.
                    let mut changed_paths = cx
                        .background()
                        .spawn(async move {
                            let mut changed_paths = BTreeMap::new();
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path);

                                if file.is_external || file.is_ignored || file.is_symlink {
                                    continue;
                                }

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    // Test if file is valid parseable file
                                    if !PARSEABLE_ENTIRE_FILE_TYPES
                                        .contains(&language.name().as_ref())
                                        && &language.name().as_ref() != &"Markdown"
                                        && language
                                            .grammar()
                                            .and_then(|grammar| grammar.embedding_config.as_ref())
                                            .is_none()
                                    {
                                        continue;
                                    }

                                    // Removing the entry here leaves only the
                                    // stale database rows behind for cleanup.
                                    let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
                                    let already_stored = stored_mtime
                                        .map_or(false, |existing_mtime| {
                                            existing_mtime == file.mtime
                                        });

                                    if !already_stored {
                                        changed_paths.insert(
                                            file.path.clone(),
                                            ChangedPathInfo {
                                                mtime: file.mtime,
                                                is_deleted: false,
                                            },
                                        );
                                    }
                                }
                            }

                            // Clean up entries from database that are no longer in the worktree.
                            for (path, mtime) in file_mtimes {
                                changed_paths.insert(
                                    path.into(),
                                    ChangedPathInfo {
                                        mtime,
                                        is_deleted: true,
                                    },
                                );
                            }

                            anyhow::Ok(changed_paths)
                        })
                        .await?;
                    this.update(&mut cx, |this, cx| {
                        let project_state = this
                            .projects
                            .get_mut(&project)
                            .ok_or_else(|| anyhow!("project not registered"))?;
                        let project = project
                            .upgrade(cx)
                            .ok_or_else(|| anyhow!("project was dropped"))?;

                        // Merge in any changes reported while registering.
                        if let Some(WorktreeState::Registering(state)) =
                            project_state.worktrees.remove(&worktree_id)
                        {
                            changed_paths.extend(state.changed_paths);
                        }
                        project_state.worktrees.insert(
                            worktree_id,
                            WorktreeState::Registered(RegisteredWorktreeState {
                                db_id,
                                changed_paths,
                            }),
                        );
                        this.index_project(project, cx).detach_and_log_err(cx);

                        anyhow::Ok(())
                    })?;

                    anyhow::Ok(())
                };

                if register.await.log_err().is_none() {
                    // Stop tracking this worktree if the registration failed.
                    this.update(&mut cx, |this, _| {
                        this.projects.get_mut(&project).map(|project_state| {
                            project_state.worktrees.remove(&worktree_id);
                        });
                    })
                }

                // Wake anyone waiting on `RegisteringWorktreeState::done`.
                *done_tx.borrow_mut() = Some(());
            }
        });
        project_state.worktrees.insert(
            worktree_id,
            WorktreeState::Registering(RegisteringWorktreeState {
                changed_paths: Default::default(),
                done_rx,
                _registration: registration,
            }),
        );
    }
641
642 fn project_worktrees_changed(
643 &mut self,
644 project: ModelHandle<Project>,
645 cx: &mut ModelContext<Self>,
646 ) {
647 let project_state = if let Some(project_state) = self.projects.get_mut(&project.downgrade())
648 {
649 project_state
650 } else {
651 return;
652 };
653
654 let mut worktrees = project
655 .read(cx)
656 .worktrees(cx)
657 .filter(|worktree| worktree.read(cx).is_local())
658 .collect::<Vec<_>>();
659 let worktree_ids = worktrees
660 .iter()
661 .map(|worktree| worktree.read(cx).id())
662 .collect::<HashSet<_>>();
663
664 // Remove worktrees that are no longer present
665 project_state
666 .worktrees
667 .retain(|worktree_id, _| worktree_ids.contains(worktree_id));
668
669 // Register new worktrees
670 worktrees.retain(|worktree| {
671 let worktree_id = worktree.read(cx).id();
672 !project_state.worktrees.contains_key(&worktree_id)
673 });
674 for worktree in worktrees {
675 self.register_worktree(project.clone(), worktree, cx);
676 }
677 }
678
679 pub fn pending_file_count(
680 &self,
681 project: &ModelHandle<Project>,
682 ) -> Option<watch::Receiver<usize>> {
683 Some(
684 self.projects
685 .get(&project.downgrade())?
686 .pending_file_count_rx
687 .clone(),
688 )
689 }
690
    /// Embeds `query` and returns up to `limit` results, combining matches
    /// from dirty open buffers with matches from the on-disk index. Runs an
    /// index pass first so on-disk results are fresh.
    pub fn search_project(
        &mut self,
        project: ModelHandle<Project>,
        query: String,
        limit: usize,
        includes: Vec<PathMatcher>,
        excludes: Vec<PathMatcher>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        if query.is_empty() {
            return Task::ready(Ok(Vec::new()));
        }

        let index = self.index_project(project.clone(), cx);
        let embedding_provider = self.embedding_provider.clone();

        cx.spawn(|this, mut cx| async move {
            index.await?;
            let t0 = Instant::now();
            let query = embedding_provider
                .embed_batch(vec![query])
                .await?
                .pop()
                .ok_or_else(|| anyhow!("could not embed query"))?;
            log::trace!("Embedding Search Query: {:?}ms", t0.elapsed().as_millis());

            let search_start = Instant::now();
            // Search dirty in-memory buffers and the on-disk index concurrently.
            let modified_buffer_results = this.update(&mut cx, |this, cx| {
                this.search_modified_buffers(
                    &project,
                    query.clone(),
                    limit,
                    &includes,
                    &excludes,
                    cx,
                )
            });
            let file_results = this.update(&mut cx, |this, cx| {
                this.search_files(project, query, limit, includes, excludes, cx)
            });
            let (modified_buffer_results, file_results) =
                futures::join!(modified_buffer_results, file_results);

            // Weave together the results from modified buffers and files.
            // In-memory results win over stale on-disk entries for the same buffer.
            let mut results = Vec::new();
            let mut modified_buffers = HashSet::default();
            for result in modified_buffer_results.log_err().unwrap_or_default() {
                modified_buffers.insert(result.buffer.clone());
                results.push(result);
            }
            for result in file_results.log_err().unwrap_or_default() {
                if !modified_buffers.contains(&result.buffer) {
                    results.push(result);
                }
            }
            results.sort_by_key(|result| Reverse(result.similarity));
            results.truncate(limit);
            log::trace!("Semantic search took {:?}", search_start.elapsed());
            Ok(results)
        })
    }
752
753 pub fn search_files(
754 &mut self,
755 project: ModelHandle<Project>,
756 query: Embedding,
757 limit: usize,
758 includes: Vec<PathMatcher>,
759 excludes: Vec<PathMatcher>,
760 cx: &mut ModelContext<Self>,
761 ) -> Task<Result<Vec<SearchResult>>> {
762 let db_path = self.db.path().clone();
763 let fs = self.fs.clone();
764 cx.spawn(|this, mut cx| async move {
765 let database =
766 VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?;
767
768 let worktree_db_ids = this.read_with(&cx, |this, _| {
769 let project_state = this
770 .projects
771 .get(&project.downgrade())
772 .ok_or_else(|| anyhow!("project was not indexed"))?;
773 let worktree_db_ids = project_state
774 .worktrees
775 .values()
776 .filter_map(|worktree| {
777 if let WorktreeState::Registered(worktree) = worktree {
778 Some(worktree.db_id)
779 } else {
780 None
781 }
782 })
783 .collect::<Vec<i64>>();
784 anyhow::Ok(worktree_db_ids)
785 })?;
786
787 let file_ids = database
788 .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
789 .await?;
790
791 let batch_n = cx.background().num_cpus();
792 let ids_len = file_ids.clone().len();
793 let minimum_batch_size = 50;
794
795 let batch_size = {
796 let size = ids_len / batch_n;
797 if size < minimum_batch_size {
798 minimum_batch_size
799 } else {
800 size
801 }
802 };
803
804 let mut batch_results = Vec::new();
805 for batch in file_ids.chunks(batch_size) {
806 let batch = batch.into_iter().map(|v| *v).collect::<Vec<i64>>();
807 let limit = limit.clone();
808 let fs = fs.clone();
809 let db_path = db_path.clone();
810 let query = query.clone();
811 if let Some(db) = VectorDatabase::new(fs, db_path.clone(), cx.background())
812 .await
813 .log_err()
814 {
815 batch_results.push(async move {
816 db.top_k_search(&query, limit, batch.as_slice()).await
817 });
818 }
819 }
820
821 let batch_results = futures::future::join_all(batch_results).await;
822
823 let mut results = Vec::new();
824 for batch_result in batch_results {
825 if batch_result.is_ok() {
826 for (id, similarity) in batch_result.unwrap() {
827 let ix = match results
828 .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
829 {
830 Ok(ix) => ix,
831 Err(ix) => ix,
832 };
833
834 results.insert(ix, (id, similarity));
835 results.truncate(limit);
836 }
837 }
838 }
839
840 let ids = results.iter().map(|(id, _)| *id).collect::<Vec<i64>>();
841 let scores = results
842 .into_iter()
843 .map(|(_, score)| score)
844 .collect::<Vec<_>>();
845 let spans = database.spans_for_ids(ids.as_slice()).await?;
846
847 let mut tasks = Vec::new();
848 let mut ranges = Vec::new();
849 let weak_project = project.downgrade();
850 project.update(&mut cx, |project, cx| {
851 for (worktree_db_id, file_path, byte_range) in spans {
852 let project_state =
853 if let Some(state) = this.read(cx).projects.get(&weak_project) {
854 state
855 } else {
856 return Err(anyhow!("project not added"));
857 };
858 if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
859 tasks.push(project.open_buffer((worktree_id, file_path), cx));
860 ranges.push(byte_range);
861 }
862 }
863
864 Ok(())
865 })?;
866
867 let buffers = futures::future::join_all(tasks).await;
868 Ok(buffers
869 .into_iter()
870 .zip(ranges)
871 .zip(scores)
872 .filter_map(|((buffer, range), similarity)| {
873 let buffer = buffer.log_err()?;
874 let range = buffer.read_with(&cx, |buffer, _| {
875 let start = buffer.clip_offset(range.start, Bias::Left);
876 let end = buffer.clip_offset(range.end, Bias::Right);
877 buffer.anchor_before(start)..buffer.anchor_after(end)
878 });
879 Some(SearchResult {
880 buffer,
881 range,
882 similarity,
883 })
884 })
885 .collect())
886 })
887 }
888
889 fn search_modified_buffers(
890 &self,
891 project: &ModelHandle<Project>,
892 query: Embedding,
893 limit: usize,
894 includes: &[PathMatcher],
895 excludes: &[PathMatcher],
896 cx: &mut ModelContext<Self>,
897 ) -> Task<Result<Vec<SearchResult>>> {
898 let modified_buffers = project
899 .read(cx)
900 .opened_buffers(cx)
901 .into_iter()
902 .filter_map(|buffer_handle| {
903 let buffer = buffer_handle.read(cx);
904 let snapshot = buffer.snapshot();
905 let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
906 excludes.iter().any(|matcher| matcher.is_match(&path))
907 });
908
909 let included = if includes.len() == 0 {
910 true
911 } else {
912 snapshot.resolve_file_path(cx, false).map_or(false, |path| {
913 includes.iter().any(|matcher| matcher.is_match(&path))
914 })
915 };
916
917 if buffer.is_dirty() && !excluded && included {
918 Some((buffer_handle, snapshot))
919 } else {
920 None
921 }
922 })
923 .collect::<HashMap<_, _>>();
924
925 let embedding_provider = self.embedding_provider.clone();
926 let fs = self.fs.clone();
927 let db_path = self.db.path().clone();
928 let background = cx.background().clone();
929 cx.background().spawn(async move {
930 let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
931 let mut results = Vec::<SearchResult>::new();
932
933 let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
934 for (buffer, snapshot) in modified_buffers {
935 let language = snapshot
936 .language_at(0)
937 .cloned()
938 .unwrap_or_else(|| language::PLAIN_TEXT.clone());
939 let mut spans = retriever
940 .parse_file_with_template(None, &snapshot.text(), language)
941 .log_err()
942 .unwrap_or_default();
943 if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
944 .await
945 .log_err()
946 .is_some()
947 {
948 for span in spans {
949 let similarity = span.embedding.unwrap().similarity(&query);
950 let ix = match results
951 .binary_search_by_key(&Reverse(similarity), |result| {
952 Reverse(result.similarity)
953 }) {
954 Ok(ix) => ix,
955 Err(ix) => ix,
956 };
957
958 let range = {
959 let start = snapshot.clip_offset(span.range.start, Bias::Left);
960 let end = snapshot.clip_offset(span.range.end, Bias::Right);
961 snapshot.anchor_before(start)..snapshot.anchor_after(end)
962 };
963
964 results.insert(
965 ix,
966 SearchResult {
967 buffer: buffer.clone(),
968 range,
969 similarity,
970 },
971 );
972 results.truncate(limit);
973 }
974 }
975 }
976
977 Ok(results)
978 })
979 }
980
    /// Indexes `project`: registers it (and its worktrees) on first call,
    /// drains every registered worktree's changed paths into deletions and
    /// pending files, feeds the pending files to the parsing workers, and
    /// resolves once the pending file count returns to zero.
    pub fn index_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        if !self.embedding_provider.is_authenticated() {
            return Task::ready(Err(anyhow!("user is not authenticated")));
        }

        // First time we see this project: subscribe to its worktree events
        // and register its current worktrees.
        if !self.projects.contains_key(&project.downgrade()) {
            let subscription = cx.subscribe(&project, |this, project, event, cx| match event {
                project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
                    this.project_worktrees_changed(project.clone(), cx);
                }
                project::Event::WorktreeUpdatedEntries(worktree_id, changes) => {
                    this.project_entries_changed(project, *worktree_id, changes.clone(), cx);
                }
                _ => {}
            });
            let project_state = ProjectState::new(subscription, cx);
            self.projects.insert(project.downgrade(), project_state);
            self.project_worktrees_changed(project.clone(), cx);
        }
        let project_state = self.projects.get_mut(&project.downgrade()).unwrap();
        project_state.pending_index += 1;
        cx.notify();

        let mut pending_file_count_rx = project_state.pending_file_count_rx.clone();
        let db = self.db.clone();
        let language_registry = self.language_registry.clone();
        let parsing_files_tx = self.parsing_files_tx.clone();
        let worktree_registration = self.wait_for_worktree_registration(&project, cx);

        cx.spawn(|this, mut cx| async move {
            worktree_registration.await?;

            // Drain each registered worktree's changed paths into deletions
            // and files to (re)index.
            let mut pending_files = Vec::new();
            let mut files_to_delete = Vec::new();
            this.update(&mut cx, |this, cx| {
                let project_state = this
                    .projects
                    .get_mut(&project.downgrade())
                    .ok_or_else(|| anyhow!("project was dropped"))?;
                let pending_file_count_tx = &project_state.pending_file_count_tx;

                project_state
                    .worktrees
                    .retain(|worktree_id, worktree_state| {
                        let worktree = if let Some(worktree) =
                            project.read(cx).worktree_for_id(*worktree_id, cx)
                        {
                            worktree
                        } else {
                            // Drop state for worktrees the project lost.
                            return false;
                        };
                        let worktree_state =
                            if let WorktreeState::Registered(worktree_state) = worktree_state {
                                worktree_state
                            } else {
                                return true;
                            };

                        worktree_state.changed_paths.retain(|path, info| {
                            if info.is_deleted {
                                files_to_delete.push((worktree_state.db_id, path.clone()));
                            } else {
                                let absolute_path = worktree.read(cx).absolutize(path);
                                let job_handle = JobHandle::new(pending_file_count_tx);
                                pending_files.push(PendingFile {
                                    absolute_path,
                                    relative_path: path.clone(),
                                    language: None,
                                    job_handle,
                                    modified_time: info.mtime,
                                    worktree_db_id: worktree_state.db_id,
                                });
                            }

                            // Every entry has been consumed; clear the map.
                            false
                        });
                        true
                    });

                anyhow::Ok(())
            })?;

            cx.background()
                .spawn(async move {
                    for (worktree_db_id, path) in files_to_delete {
                        db.delete_file(worktree_db_id, path).await.log_err();
                    }

                    // Fetch embeddings already stored for the pending paths so
                    // unchanged spans can skip the provider entirely.
                    let embeddings_for_digest = {
                        let mut files = HashMap::default();
                        for pending_file in &pending_files {
                            files
                                .entry(pending_file.worktree_db_id)
                                .or_insert(Vec::new())
                                .push(pending_file.relative_path.clone());
                        }
                        Arc::new(
                            db.embeddings_for_files(files)
                                .await
                                .log_err()
                                .unwrap_or_default(),
                        )
                    };

                    for mut pending_file in pending_files {
                        if let Ok(language) = language_registry
                            .language_for_file(&pending_file.relative_path, None)
                            .await
                        {
                            // Skip languages we can't produce spans for: not
                            // whole-file parseable, not Markdown, and no
                            // embedding grammar config.
                            if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
                                && &language.name().as_ref() != &"Markdown"
                                && language
                                    .grammar()
                                    .and_then(|grammar| grammar.embedding_config.as_ref())
                                    .is_none()
                            {
                                continue;
                            }
                            pending_file.language = Some(language);
                        }
                        parsing_files_tx
                            .try_send((embeddings_for_digest.clone(), pending_file))
                            .ok();
                    }

                    // Wait until we're done indexing.
                    while let Some(count) = pending_file_count_rx.next().await {
                        if count == 0 {
                            break;
                        }
                    }
                })
                .await;

            this.update(&mut cx, |this, cx| {
                let project_state = this
                    .projects
                    .get_mut(&project.downgrade())
                    .ok_or_else(|| anyhow!("project was dropped"))?;
                project_state.pending_index -= 1;
                cx.notify();
                anyhow::Ok(())
            })?;

            Ok(())
        })
    }
1132
    /// Completes once no worktree of `project` is still registering. Loops
    /// because new worktrees may begin registering while we wait.
    fn wait_for_worktree_registration(
        &self,
        project: &ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        let project = project.downgrade();
        cx.spawn_weak(|this, cx| async move {
            loop {
                let mut pending_worktrees = Vec::new();
                this.upgrade(&cx)
                    .ok_or_else(|| anyhow!("semantic index dropped"))?
                    .read_with(&cx, |this, _| {
                        if let Some(project) = this.projects.get(&project) {
                            for worktree in project.worktrees.values() {
                                if let WorktreeState::Registering(worktree) = worktree {
                                    pending_worktrees.push(worktree.done());
                                }
                            }
                        }
                    });

                if pending_worktrees.is_empty() {
                    break;
                } else {
                    future::join_all(pending_worktrees).await;
                }
            }
            Ok(())
        })
    }
1163
    /// Computes embeddings for any spans whose digest isn't already in the
    /// database, batching requests up to the provider's token budget, then
    /// assigns every span its embedding (cached or freshly computed).
    async fn embed_spans(
        spans: &mut [Span],
        embedding_provider: &dyn EmbeddingProvider,
        db: &VectorDatabase,
    ) -> Result<()> {
        let mut batch = Vec::new();
        let mut batch_tokens = 0;
        let mut embeddings = Vec::new();

        let digests = spans
            .iter()
            .map(|span| span.digest.clone())
            .collect::<Vec<_>>();
        let embeddings_for_digests = db
            .embeddings_for_digests(digests)
            .await
            .log_err()
            .unwrap_or_default();

        for span in &*spans {
            // Skip spans we already have embeddings for.
            if embeddings_for_digests.contains_key(&span.digest) {
                continue;
            };

            // Flush the current batch before it would exceed the provider's
            // per-request token budget.
            if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
                let batch_embeddings = embedding_provider
                    .embed_batch(mem::take(&mut batch))
                    .await?;
                embeddings.extend(batch_embeddings);
                batch_tokens = 0;
            }

            batch_tokens += span.token_count;
            batch.push(span.content.clone());
        }

        if !batch.is_empty() {
            let batch_embeddings = embedding_provider
                .embed_batch(mem::take(&mut batch))
                .await?;

            embeddings.extend(batch_embeddings);
        }

        // Fresh embeddings come back in the same order the spans were
        // batched, so one forward iterator lines them up with exactly the
        // spans that lacked a cached embedding.
        let mut embeddings = embeddings.into_iter();
        for span in spans {
            let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) {
                Some(embedding.clone())
            } else {
                embeddings.next()
            };
            let embedding = embedding.ok_or_else(|| anyhow!("failed to embed spans"))?;
            span.embedding = Some(embedding);
        }
        Ok(())
    }
1220}
1221
impl Entity for SemanticIndex {
    // No events are emitted; observers are driven via `cx.notify` instead.
    type Event = ();
}
1225
impl Drop for JobHandle {
    fn drop(&mut self) {
        // `Arc::get_mut` only succeeds when this is the sole remaining clone,
        // so the pending-file counter is decremented exactly once per job.
        if let Some(inner) = Arc::get_mut(&mut self.tx) {
            // This is the last instance of the JobHandle (regardless of its origin - whether it was cloned or not)
            if let Some(tx) = inner.upgrade() {
                let mut tx = tx.lock();
                *tx.borrow_mut() -= 1;
            }
        }
    }
}
1237
#[cfg(test)]
mod tests {

    use super::*;
    // Verifies the JobHandle contract: the pending-job counter is incremented
    // once per job, unaffected by cloning/dropping individual clones, and
    // decremented only when the last clone is dropped.
    #[test]
    fn test_job_handle() {
        let (job_count_tx, job_count_rx) = watch::channel_with(0);
        let tx = Arc::new(Mutex::new(job_count_tx));
        let job_handle = JobHandle::new(&tx);

        assert_eq!(1, *job_count_rx.borrow());
        let new_job_handle = job_handle.clone();
        assert_eq!(1, *job_count_rx.borrow());
        drop(job_handle);
        assert_eq!(1, *job_count_rx.borrow());
        drop(new_job_handle);
        assert_eq!(0, *job_count_rx.borrow());
    }
}