mod db;
mod embedding_queue;
mod parsing;
pub mod semantic_index_settings;

#[cfg(test)]
mod semantic_index_tests;

use crate::semantic_index_settings::SemanticIndexSettings;
use ai::embedding::{Embedding, EmbeddingProvider};
use ai::providers::open_ai::OpenAIEmbeddingProvider;
use anyhow::{anyhow, Context as _, Result};
use collections::{BTreeMap, HashMap, HashSet};
use db::VectorDatabase;
use embedding_queue::{EmbeddingQueue, FileToEmbed};
use futures::{future, FutureExt, StreamExt};
use gpui::{
    AppContext, AsyncAppContext, BorrowWindow, Context, Model, ModelContext, Task, ViewContext,
    WeakModel,
};
use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
use lazy_static::lazy_static;
use ordered_float::OrderedFloat;
use parking_lot::Mutex;
use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
use postage::watch;
use project::{Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
use settings::Settings;
use smol::channel;
use std::{
    cmp::Reverse,
    env,
    future::Future,
    mem,
    ops::Range,
    path::{Path, PathBuf},
    sync::{Arc, Weak},
    time::{Duration, Instant, SystemTime},
};
use util::paths::PathMatcher;
use util::{channel::RELEASE_CHANNEL_NAME, http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt};
use workspace::Workspace;

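// SEMANTIC_INDEX_VERSION presumably tags the on-disk embedding schema so databases
// written by older builds can be detected and rebuilt. The two durations debounce
// background re-indexing after filesystem changes and control how often a partially
// filled embedding batch is flushed.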
const SEMANTIC_INDEX_VERSION: usize = 11;
const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);

lazy_static! {
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
}

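/// Registers the semantic index settings, creates the global `SemanticIndex` model
/// backed by the OpenAI embedding provider, and re-indexes previously indexed local
/// projects whenever a workspace for them is opened.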
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    SemanticIndexSettings::register(cx);

    let db_file_path = EMBEDDINGS_DIR
        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
        .join("embeddings_db");

    cx.observe_new_views(
        |workspace: &mut Workspace, cx: &mut ViewContext<Workspace>| {
            let Some(semantic_index) = SemanticIndex::global(cx) else {
                return;
            };
            let project = workspace.project().clone();

            if project.read(cx).is_local() {
                cx.app_mut()
                    .spawn(|mut cx| async move {
                        let previously_indexed = semantic_index
                            .update(&mut cx, |index, cx| {
                                index.project_previously_indexed(&project, cx)
                            })?
                            .await?;
                        if previously_indexed {
                            semantic_index
                                .update(&mut cx, |index, cx| index.index_project(project, cx))?
                                .await?;
                        }
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
            }
        },
    )
    .detach();

    cx.spawn(move |cx| async move {
        let semantic_index = SemanticIndex::new(
            fs,
            db_file_path,
            Arc::new(OpenAIEmbeddingProvider::new(
                http_client,
                cx.background_executor().clone(),
            )),
            language_registry,
            cx.clone(),
        )
        .await?;

        cx.update(|cx| cx.set_global(semantic_index.clone()))?;

        anyhow::Ok(())
    })
    .detach();
}

#[derive(Copy, Clone, Debug)]
pub enum SemanticIndexStatus {
    NotAuthenticated,
    NotIndexed,
    Indexed,
    Indexing {
        remaining_files: usize,
        rate_limit_expiry: Option<Instant>,
    },
}

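/// The index is a pipeline: `index_project` sends changed files through
/// `parsing_files_tx` to a pool of parser tasks, the parsers split files into
/// spans and push them onto the `EmbeddingQueue`, and `_embedding_task` writes
/// each finished file's embedded spans into the vector database.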
pub struct SemanticIndex {
    fs: Arc<dyn Fs>,
    db: VectorDatabase,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    parsing_files_tx: channel::Sender<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>,
    _embedding_task: Task<()>,
    _parsing_files_tasks: Vec<Task<()>>,
    projects: HashMap<WeakModel<Project>, ProjectState>,
}

struct ProjectState {
    worktrees: HashMap<WorktreeId, WorktreeState>,
    pending_file_count_rx: watch::Receiver<usize>,
    pending_file_count_tx: Arc<Mutex<watch::Sender<usize>>>,
    pending_index: usize,
    _subscription: gpui::Subscription,
    _observe_pending_file_count: Task<()>,
}

enum WorktreeState {
    Registering(RegisteringWorktreeState),
    Registered(RegisteredWorktreeState),
}

impl WorktreeState {
    fn is_registered(&self) -> bool {
        matches!(self, Self::Registered(_))
    }

    fn paths_changed(
        &mut self,
        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
        worktree: &Worktree,
    ) {
        let changed_paths = match self {
            Self::Registering(state) => &mut state.changed_paths,
            Self::Registered(state) => &mut state.changed_paths,
        };

        for (path, entry_id, change) in changes.iter() {
            let Some(entry) = worktree.entry_for_id(*entry_id) else {
                continue;
            };
            if entry.is_ignored || entry.is_symlink || entry.is_external || entry.is_dir() {
                continue;
            }
            changed_paths.insert(
                path.clone(),
                ChangedPathInfo {
                    mtime: entry.mtime,
                    is_deleted: *change == PathChange::Removed,
                },
            );
        }
    }
}

struct RegisteringWorktreeState {
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
    done_rx: watch::Receiver<Option<()>>,
    _registration: Task<()>,
}

impl RegisteringWorktreeState {
    fn done(&self) -> impl Future<Output = ()> {
        let mut done_rx = self.done_rx.clone();
        async move {
            while let Some(result) = done_rx.next().await {
                if result.is_some() {
                    break;
                }
            }
        }
    }
}

struct RegisteredWorktreeState {
    db_id: i64,
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
}

struct ChangedPathInfo {
    mtime: SystemTime,
    is_deleted: bool,
}

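/// Handle for a single file that is waiting to be embedded. Creating one
/// increments the project's pending-file counter; the counter is decremented
/// exactly once, when the last clone of the handle is dropped.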
#[derive(Clone)]
pub struct JobHandle {
    /// The outer Arc is here to count the clones of a JobHandle instance;
    /// when the last handle to a given job is dropped, we decrement a counter (just once).
    tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
}

impl JobHandle {
    fn new(tx: &Arc<Mutex<watch::Sender<usize>>>) -> Self {
        *tx.lock().borrow_mut() += 1;
        Self {
            tx: Arc::new(Arc::downgrade(&tx)),
        }
    }
}

impl ProjectState {
    fn new(subscription: gpui::Subscription, cx: &mut ModelContext<SemanticIndex>) -> Self {
        let (pending_file_count_tx, pending_file_count_rx) = watch::channel_with(0);
        let pending_file_count_tx = Arc::new(Mutex::new(pending_file_count_tx));
        Self {
            worktrees: Default::default(),
            pending_file_count_rx: pending_file_count_rx.clone(),
            pending_file_count_tx,
            pending_index: 0,
            _subscription: subscription,
            _observe_pending_file_count: cx.spawn({
                let mut pending_file_count_rx = pending_file_count_rx.clone();
                |this, mut cx| async move {
                    while let Some(_) = pending_file_count_rx.next().await {
                        if this.update(&mut cx, |_, cx| cx.notify()).is_err() {
                            break;
                        }
                    }
                }
            }),
        }
    }

    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
        self.worktrees
            .iter()
            .find_map(|(worktree_id, worktree_state)| match worktree_state {
                WorktreeState::Registered(state) if state.db_id == id => Some(*worktree_id),
                _ => None,
            })
    }
}

#[derive(Clone)]
pub struct PendingFile {
    worktree_db_id: i64,
    relative_path: Arc<Path>,
    absolute_path: PathBuf,
    language: Option<Arc<Language>>,
    modified_time: SystemTime,
    job_handle: JobHandle,
}

#[derive(Clone)]
pub struct SearchResult {
    pub buffer: Model<Buffer>,
    pub range: Range<Anchor>,
    pub similarity: OrderedFloat<f32>,
}

impl SemanticIndex {
    pub fn global(cx: &mut AppContext) -> Option<Model<SemanticIndex>> {
        cx.try_global::<Model<Self>>()
            .map(|semantic_index| semantic_index.clone())
    }

    pub fn authenticate(&mut self, cx: &mut AppContext) -> bool {
        if !self.embedding_provider.has_credentials() {
            self.embedding_provider.retrieve_credentials(cx);
        } else {
            return true;
        }

        self.embedding_provider.has_credentials()
    }

    pub fn is_authenticated(&self) -> bool {
        self.embedding_provider.has_credentials()
    }

    pub fn enabled(cx: &AppContext) -> bool {
        SemanticIndexSettings::get_global(cx).enabled
    }

    pub fn status(&self, project: &Model<Project>) -> SemanticIndexStatus {
        if !self.is_authenticated() {
            return SemanticIndexStatus::NotAuthenticated;
        }

        if let Some(project_state) = self.projects.get(&project.downgrade()) {
            if project_state
                .worktrees
                .values()
                .all(|worktree| worktree.is_registered())
                && project_state.pending_index == 0
            {
                SemanticIndexStatus::Indexed
            } else {
                SemanticIndexStatus::Indexing {
                    remaining_files: project_state.pending_file_count_rx.borrow().clone(),
                    rate_limit_expiry: self.embedding_provider.rate_limit_expiration(),
                }
            }
        } else {
            SemanticIndexStatus::NotIndexed
        }
    }

    pub async fn new(
        fs: Arc<dyn Fs>,
        database_path: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
        mut cx: AsyncAppContext,
    ) -> Result<Model<Self>> {
        let t0 = Instant::now();
        let database_path = Arc::from(database_path);
        let db = VectorDatabase::new(fs.clone(), database_path, cx.background_executor().clone())
            .await?;

        log::trace!(
            "db initialization took {:?} milliseconds",
            t0.elapsed().as_millis()
        );

        cx.new_model(|cx| {
            let t0 = Instant::now();
            let embedding_queue =
                EmbeddingQueue::new(embedding_provider.clone(), cx.background_executor().clone());
            let _embedding_task = cx.background_executor().spawn({
                let embedded_files = embedding_queue.finished_files();
                let db = db.clone();
                async move {
                    while let Ok(file) = embedded_files.recv().await {
                        db.insert_file(file.worktree_id, file.path, file.mtime, file.spans)
                            .await
                            .log_err();
                    }
                }
            });

            // Parse files into embeddable spans.
            let (parsing_files_tx, parsing_files_rx) =
                channel::unbounded::<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>();
            let embedding_queue = Arc::new(Mutex::new(embedding_queue));
            let mut _parsing_files_tasks = Vec::new();
            for _ in 0..cx.background_executor().num_cpus() {
                let fs = fs.clone();
                let mut parsing_files_rx = parsing_files_rx.clone();
                let embedding_provider = embedding_provider.clone();
                let embedding_queue = embedding_queue.clone();
                let background = cx.background_executor().clone();
                _parsing_files_tasks.push(cx.background_executor().spawn(async move {
                    let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
                    loop {
                        let mut timer = background.timer(EMBEDDING_QUEUE_FLUSH_TIMEOUT).fuse();
                        let mut next_file_to_parse = parsing_files_rx.next().fuse();
                        futures::select_biased! {
                            next_file_to_parse = next_file_to_parse => {
                                if let Some((embeddings_for_digest, pending_file)) = next_file_to_parse {
                                    Self::parse_file(
                                        &fs,
                                        pending_file,
                                        &mut retriever,
                                        &embedding_queue,
                                        &embeddings_for_digest,
                                    )
                                    .await
                                } else {
                                    break;
                                }
                            },
                            _ = timer => {
                                embedding_queue.lock().flush();
                            }
                        }
                    }
                }));
            }

            log::trace!(
                "semantic index task initialization took {:?} milliseconds",
                t0.elapsed().as_millis()
            );
            Self {
                fs,
                db,
                embedding_provider,
                language_registry,
                parsing_files_tx,
                _embedding_task,
                _parsing_files_tasks,
                projects: Default::default(),
            }
        })
    }

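    // Loads a pending file from disk, parses it into spans with the
    // `CodeContextRetriever`, reuses any embeddings already cached for identical
    // span digests, and pushes the file onto the embedding queue.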
    async fn parse_file(
        fs: &Arc<dyn Fs>,
        pending_file: PendingFile,
        retriever: &mut CodeContextRetriever,
        embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
        embeddings_for_digest: &HashMap<SpanDigest, Embedding>,
    ) {
        let Some(language) = pending_file.language else {
            return;
        };

        if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
            if let Some(mut spans) = retriever
                .parse_file_with_template(Some(&pending_file.relative_path), &content, language)
                .log_err()
            {
                log::trace!(
                    "parsed path {:?}: {} spans",
                    pending_file.relative_path,
                    spans.len()
                );

                for span in &mut spans {
                    if let Some(embedding) = embeddings_for_digest.get(&span.digest) {
                        span.embedding = Some(embedding.to_owned());
                    }
                }

                embedding_queue.lock().push(FileToEmbed {
                    worktree_id: pending_file.worktree_db_id,
                    path: pending_file.relative_path,
                    mtime: pending_file.modified_time,
                    job_handle: pending_file.job_handle,
                    spans,
                });
            }
        }
    }

    pub fn project_previously_indexed(
        &mut self,
        project: &Model<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<bool>> {
        let worktrees_indexed_previously = project
            .read(cx)
            .worktrees()
            .map(|worktree| {
                self.db
                    .worktree_previously_indexed(&worktree.read(cx).abs_path())
            })
            .collect::<Vec<_>>();
        cx.spawn(|_, _cx| async move {
            let worktree_indexed_previously =
                futures::future::join_all(worktrees_indexed_previously).await;

            Ok(worktree_indexed_previously
                .iter()
                .filter(|worktree| worktree.is_ok())
                .all(|v| v.as_ref().log_err().is_some_and(|v| v.to_owned())))
        })
    }

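    // Records filesystem changes for a worktree and, if the worktree is already
    // registered, schedules a re-index after `BACKGROUND_INDEXING_DELAY` so that
    // bursts of edits are coalesced into a single indexing pass.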
    fn project_entries_changed(
        &mut self,
        project: Model<Project>,
        worktree_id: WorktreeId,
        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
        cx: &mut ModelContext<Self>,
    ) {
        let Some(worktree) = project.read(cx).worktree_for_id(worktree_id.clone(), cx) else {
            return;
        };
        let project = project.downgrade();
        let Some(project_state) = self.projects.get_mut(&project) else {
            return;
        };

        let worktree = worktree.read(cx);
        let worktree_state =
            if let Some(worktree_state) = project_state.worktrees.get_mut(&worktree_id) {
                worktree_state
            } else {
                return;
            };
        worktree_state.paths_changed(changes, worktree);
        if let WorktreeState::Registered(_) = worktree_state {
            cx.spawn(|this, mut cx| async move {
                cx.background_executor()
                    .timer(BACKGROUND_INDEXING_DELAY)
                    .await;
                if let Some((this, project)) = this.upgrade().zip(project.upgrade()) {
                    this.update(&mut cx, |this, cx| {
                        this.index_project(project, cx).detach_and_log_err(cx)
                    })?;
                }
                anyhow::Ok(())
            })
            .detach_and_log_err(cx);
        }
    }

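    // Registers a local worktree with the vector database: waits for the initial
    // scan, diffs the files on disk against the mtimes stored in the database, and
    // merges in any changes reported while registration was still in flight before
    // kicking off an index pass.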
    fn register_worktree(
        &mut self,
        project: Model<Project>,
        worktree: Model<Worktree>,
        cx: &mut ModelContext<Self>,
    ) {
        let project = project.downgrade();
        let project_state = if let Some(project_state) = self.projects.get_mut(&project) {
            project_state
        } else {
            return;
        };
        let worktree = if let Some(worktree) = worktree.read(cx).as_local() {
            worktree
        } else {
            return;
        };
        let worktree_abs_path = worktree.abs_path().clone();
        let scan_complete = worktree.scan_complete();
        let worktree_id = worktree.id();
        let db = self.db.clone();
        let language_registry = self.language_registry.clone();
        let (mut done_tx, done_rx) = watch::channel();
        let registration = cx.spawn(|this, mut cx| {
            async move {
                let register = async {
                    scan_complete.await;
                    let db_id = db.find_or_create_worktree(worktree_abs_path).await?;
                    let mut file_mtimes = db.get_file_mtimes(db_id).await?;
                    let worktree = if let Some(project) = project.upgrade() {
                        project
                            .read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx))
                            .ok()
                            .flatten()
                            .context("worktree not found")?
                    } else {
                        return anyhow::Ok(());
                    };
                    let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot())?;
                    let mut changed_paths = cx
                        .background_executor()
                        .spawn(async move {
                            let mut changed_paths = BTreeMap::new();
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path)?;

                                if file.is_external || file.is_ignored || file.is_symlink {
                                    continue;
                                }

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    // Skip files whose language can't be parsed into embeddable spans.
                                    if !PARSEABLE_ENTIRE_FILE_TYPES
                                        .contains(&language.name().as_ref())
                                        && language.name().as_ref() != "Markdown"
                                        && language
                                            .grammar()
                                            .and_then(|grammar| grammar.embedding_config.as_ref())
                                            .is_none()
                                    {
                                        continue;
                                    }

                                    let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
                                    let already_stored = stored_mtime
                                        .map_or(false, |existing_mtime| {
                                            existing_mtime == file.mtime
                                        });

                                    if !already_stored {
                                        changed_paths.insert(
                                            file.path.clone(),
                                            ChangedPathInfo {
                                                mtime: file.mtime,
                                                is_deleted: false,
                                            },
                                        );
                                    }
                                }
                            }

                            // Clean up entries from the database that are no longer in the worktree.
                            for (path, mtime) in file_mtimes {
                                changed_paths.insert(
                                    path.into(),
                                    ChangedPathInfo {
                                        mtime,
                                        is_deleted: true,
                                    },
                                );
                            }

                            anyhow::Ok(changed_paths)
                        })
                        .await?;
                    this.update(&mut cx, |this, cx| {
                        let project_state = this
                            .projects
                            .get_mut(&project)
                            .context("project not registered")?;
                        let project = project.upgrade().context("project was dropped")?;

                        if let Some(WorktreeState::Registering(state)) =
                            project_state.worktrees.remove(&worktree_id)
                        {
                            changed_paths.extend(state.changed_paths);
                        }
                        project_state.worktrees.insert(
                            worktree_id,
                            WorktreeState::Registered(RegisteredWorktreeState {
                                db_id,
                                changed_paths,
                            }),
                        );
                        this.index_project(project, cx).detach_and_log_err(cx);

                        anyhow::Ok(())
                    })??;

                    anyhow::Ok(())
                };

                if register.await.log_err().is_none() {
                    // Stop tracking this worktree if the registration failed.
                    this.update(&mut cx, |this, _| {
                        this.projects.get_mut(&project).map(|project_state| {
                            project_state.worktrees.remove(&worktree_id);
                        });
                    })
                    .ok();
                }

                *done_tx.borrow_mut() = Some(());
            }
        });
        project_state.worktrees.insert(
            worktree_id,
            WorktreeState::Registering(RegisteringWorktreeState {
                changed_paths: Default::default(),
                done_rx,
                _registration: registration,
            }),
        );
    }

    fn project_worktrees_changed(&mut self, project: Model<Project>, cx: &mut ModelContext<Self>) {
        let project_state = if let Some(project_state) = self.projects.get_mut(&project.downgrade())
        {
            project_state
        } else {
            return;
        };

        let mut worktrees = project
            .read(cx)
            .worktrees()
            .filter(|worktree| worktree.read(cx).is_local())
            .collect::<Vec<_>>();
        let worktree_ids = worktrees
            .iter()
            .map(|worktree| worktree.read(cx).id())
            .collect::<HashSet<_>>();

        // Remove worktrees that are no longer present.
        project_state
            .worktrees
            .retain(|worktree_id, _| worktree_ids.contains(worktree_id));

        // Register new worktrees.
        worktrees.retain(|worktree| {
            let worktree_id = worktree.read(cx).id();
            !project_state.worktrees.contains_key(&worktree_id)
        });
        for worktree in worktrees {
            self.register_worktree(project.clone(), worktree, cx);
        }
    }

    pub fn pending_file_count(&self, project: &Model<Project>) -> Option<watch::Receiver<usize>> {
        Some(
            self.projects
                .get(&project.downgrade())?
                .pending_file_count_rx
                .clone(),
        )
    }

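    /// Embeds `query` and searches both the on-disk index and any dirty open
    /// buffers of `project`, returning up to `limit` results ordered by
    /// descending similarity.
    ///
    /// A minimal usage sketch (hypothetical call site; assumes `semantic_index`
    /// is the global `Model<SemanticIndex>` and `project` is a local project):
    ///
    /// ```ignore
    /// let search = semantic_index.update(cx, |index, cx| {
    ///     index.search_project(
    ///         project.clone(),
    ///         "where do we open the vector database?".to_string(),
    ///         10,
    ///         Vec::new(), // include all paths
    ///         Vec::new(), // exclude none
    ///         cx,
    ///     )
    /// });
    /// // Later, in an async context:
    /// // let results: Vec<SearchResult> = search.await?;
    /// ```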
    pub fn search_project(
        &mut self,
        project: Model<Project>,
        query: String,
        limit: usize,
        includes: Vec<PathMatcher>,
        excludes: Vec<PathMatcher>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        if query.is_empty() {
            return Task::ready(Ok(Vec::new()));
        }

        let index = self.index_project(project.clone(), cx);
        let embedding_provider = self.embedding_provider.clone();

        cx.spawn(|this, mut cx| async move {
            index.await?;
            let t0 = Instant::now();

            let query = embedding_provider
                .embed_batch(vec![query])
                .await?
                .pop()
                .context("could not embed query")?;
            log::trace!(
                "embedding search query took {:?}ms",
                t0.elapsed().as_millis()
            );

            let search_start = Instant::now();
            let modified_buffer_results = this.update(&mut cx, |this, cx| {
                this.search_modified_buffers(
                    &project,
                    query.clone(),
                    limit,
                    &includes,
                    &excludes,
                    cx,
                )
            })?;
            let file_results = this.update(&mut cx, |this, cx| {
                this.search_files(project, query, limit, includes, excludes, cx)
            })?;
            let (modified_buffer_results, file_results) =
                futures::join!(modified_buffer_results, file_results);

            // Weave together the results from modified buffers and files.
            let mut results = Vec::new();
            let mut modified_buffers = HashSet::default();
            for result in modified_buffer_results.log_err().unwrap_or_default() {
                modified_buffers.insert(result.buffer.clone());
                results.push(result);
            }
            for result in file_results.log_err().unwrap_or_default() {
                if !modified_buffers.contains(&result.buffer) {
                    results.push(result);
                }
            }
            results.sort_by_key(|result| Reverse(result.similarity));
            results.truncate(limit);
            log::trace!("semantic search took {:?}", search_start.elapsed());
            Ok(results)
        })
    }

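    /// Searches the on-disk index: collects the candidate file ids for the
    /// project's registered worktrees, fans the top-k similarity search out
    /// across one database connection per batch, and merges the per-batch
    /// results into a single ranked list before opening buffers for the hits.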
    pub fn search_files(
        &mut self,
        project: Model<Project>,
        query: Embedding,
        limit: usize,
        includes: Vec<PathMatcher>,
        excludes: Vec<PathMatcher>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let db_path = self.db.path().clone();
        let fs = self.fs.clone();
        cx.spawn(|this, mut cx| async move {
            let database = VectorDatabase::new(
                fs.clone(),
                db_path.clone(),
                cx.background_executor().clone(),
            )
            .await?;

            let worktree_db_ids = this.read_with(&cx, |this, _| {
                let project_state = this
                    .projects
                    .get(&project.downgrade())
                    .context("project was not indexed")?;
                let worktree_db_ids = project_state
                    .worktrees
                    .values()
                    .filter_map(|worktree| {
                        if let WorktreeState::Registered(worktree) = worktree {
                            Some(worktree.db_id)
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<i64>>();
                anyhow::Ok(worktree_db_ids)
            })??;

            let file_ids = database
                .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
                .await?;

            let batch_n = cx.background_executor().num_cpus();
            let ids_len = file_ids.len();
            let minimum_batch_size = 50;

            let batch_size = {
                let size = ids_len / batch_n;
                if size < minimum_batch_size {
                    minimum_batch_size
                } else {
                    size
                }
            };

            let mut batch_results = Vec::new();
            for batch in file_ids.chunks(batch_size) {
                let batch = batch.to_vec();
                let limit = limit.clone();
                let fs = fs.clone();
                let db_path = db_path.clone();
                let query = query.clone();
                if let Some(db) =
                    VectorDatabase::new(fs, db_path.clone(), cx.background_executor().clone())
                        .await
                        .log_err()
                {
                    batch_results.push(async move {
                        db.top_k_search(&query, limit, batch.as_slice()).await
                    });
                }
            }

            let batch_results = futures::future::join_all(batch_results).await;

            let mut results = Vec::new();
            for batch_result in batch_results {
                if let Ok(batch_result) = batch_result {
                    for (id, similarity) in batch_result {
                        let ix = match results
                            .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
                        {
                            Ok(ix) => ix,
                            Err(ix) => ix,
                        };

                        results.insert(ix, (id, similarity));
                        results.truncate(limit);
                    }
                }
            }

            let ids = results.iter().map(|(id, _)| *id).collect::<Vec<i64>>();
            let scores = results
                .into_iter()
                .map(|(_, score)| score)
                .collect::<Vec<_>>();
            let spans = database.spans_for_ids(ids.as_slice()).await?;

            let mut tasks = Vec::new();
            let mut ranges = Vec::new();
            let weak_project = project.downgrade();
            project.update(&mut cx, |project, cx| {
                let this = this.upgrade().context("index was dropped")?;
                for (worktree_db_id, file_path, byte_range) in spans {
                    let project_state =
                        if let Some(state) = this.read(cx).projects.get(&weak_project) {
                            state
                        } else {
                            return Err(anyhow!("project not added"));
                        };
                    if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
                        tasks.push(project.open_buffer((worktree_id, file_path), cx));
                        ranges.push(byte_range);
                    }
                }

                Ok(())
            })??;

            let buffers = futures::future::join_all(tasks).await;
            Ok(buffers
                .into_iter()
                .zip(ranges)
                .zip(scores)
                .filter_map(|((buffer, range), similarity)| {
                    let buffer = buffer.log_err()?;
                    let range = buffer
                        .read_with(&cx, |buffer, _| {
                            let start = buffer.clip_offset(range.start, Bias::Left);
                            let end = buffer.clip_offset(range.end, Bias::Right);
                            buffer.anchor_before(start)..buffer.anchor_after(end)
                        })
                        .log_err()?;
                    Some(SearchResult {
                        buffer,
                        range,
                        similarity,
                    })
                })
                .collect())
        })
    }

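    // Dirty buffers may differ from what is stored in the vector database, so they
    // are re-parsed and re-embedded on the fly and searched directly instead of
    // relying on the on-disk index.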
    fn search_modified_buffers(
        &self,
        project: &Model<Project>,
        query: Embedding,
        limit: usize,
        includes: &[PathMatcher],
        excludes: &[PathMatcher],
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let modified_buffers = project
            .read(cx)
            .opened_buffers()
            .into_iter()
            .filter_map(|buffer_handle| {
                let buffer = buffer_handle.read(cx);
                let snapshot = buffer.snapshot();
                let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
                    excludes.iter().any(|matcher| matcher.is_match(&path))
                });

                let included = if includes.is_empty() {
                    true
                } else {
                    snapshot.resolve_file_path(cx, false).map_or(false, |path| {
                        includes.iter().any(|matcher| matcher.is_match(&path))
                    })
                };

                if buffer.is_dirty() && !excluded && included {
                    Some((buffer_handle, snapshot))
                } else {
                    None
                }
            })
            .collect::<HashMap<_, _>>();

        let embedding_provider = self.embedding_provider.clone();
        let fs = self.fs.clone();
        let db_path = self.db.path().clone();
        let background = cx.background_executor().clone();
        cx.background_executor().spawn(async move {
            let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
            let mut results = Vec::<SearchResult>::new();

            let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
            for (buffer, snapshot) in modified_buffers {
                let language = snapshot
                    .language_at(0)
                    .cloned()
                    .unwrap_or_else(|| language::PLAIN_TEXT.clone());
                let mut spans = retriever
                    .parse_file_with_template(None, &snapshot.text(), language)
                    .log_err()
                    .unwrap_or_default();
                if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
                    .await
                    .log_err()
                    .is_some()
                {
                    for span in spans {
                        // `embed_spans` populated every span's embedding above.
                        let similarity = span.embedding.unwrap().similarity(&query);
                        let ix = match results
                            .binary_search_by_key(&Reverse(similarity), |result| {
                                Reverse(result.similarity)
                            }) {
                            Ok(ix) => ix,
                            Err(ix) => ix,
                        };

                        let range = {
                            let start = snapshot.clip_offset(span.range.start, Bias::Left);
                            let end = snapshot.clip_offset(span.range.end, Bias::Right);
                            snapshot.anchor_before(start)..snapshot.anchor_after(end)
                        };

                        results.insert(
                            ix,
                            SearchResult {
                                buffer: buffer.clone(),
                                range,
                                similarity,
                            },
                        );
                        results.truncate(limit);
                    }
                }
            }

            Ok(results)
        })
    }

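    /// Ensures the project is registered and its worktrees are being watched, then
    /// drains the recorded path changes into deletions and pending files, feeds the
    /// pending files to the parser tasks, and waits for the pending file count to
    /// reach zero before resolving.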
    pub fn index_project(
        &mut self,
        project: Model<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        if !self.is_authenticated() {
            if !self.authenticate(cx) {
                return Task::ready(Err(anyhow!("user is not authenticated")));
            }
        }

        if !self.projects.contains_key(&project.downgrade()) {
            let subscription = cx.subscribe(&project, |this, project, event, cx| match event {
                project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
                    this.project_worktrees_changed(project.clone(), cx);
                }
                project::Event::WorktreeUpdatedEntries(worktree_id, changes) => {
                    this.project_entries_changed(project, *worktree_id, changes.clone(), cx);
                }
                _ => {}
            });
            let project_state = ProjectState::new(subscription, cx);
            self.projects.insert(project.downgrade(), project_state);
            self.project_worktrees_changed(project.clone(), cx);
        }
        let project_state = self.projects.get_mut(&project.downgrade()).unwrap();
        project_state.pending_index += 1;
        cx.notify();

        let mut pending_file_count_rx = project_state.pending_file_count_rx.clone();
        let db = self.db.clone();
        let language_registry = self.language_registry.clone();
        let parsing_files_tx = self.parsing_files_tx.clone();
        let worktree_registration = self.wait_for_worktree_registration(&project, cx);

        cx.spawn(|this, mut cx| async move {
            worktree_registration.await?;

            let mut pending_files = Vec::new();
            let mut files_to_delete = Vec::new();
            this.update(&mut cx, |this, cx| {
                let project_state = this
                    .projects
                    .get_mut(&project.downgrade())
                    .context("project was dropped")?;
                let pending_file_count_tx = &project_state.pending_file_count_tx;

                project_state
                    .worktrees
                    .retain(|worktree_id, worktree_state| {
                        let worktree = if let Some(worktree) =
                            project.read(cx).worktree_for_id(*worktree_id, cx)
                        {
                            worktree
                        } else {
                            return false;
                        };
                        let worktree_state =
                            if let WorktreeState::Registered(worktree_state) = worktree_state {
                                worktree_state
                            } else {
                                return true;
                            };

                        for (path, info) in &worktree_state.changed_paths {
                            if info.is_deleted {
                                files_to_delete.push((worktree_state.db_id, path.clone()));
                            } else if let Ok(absolute_path) = worktree.read(cx).absolutize(path) {
                                let job_handle = JobHandle::new(pending_file_count_tx);
                                pending_files.push(PendingFile {
                                    absolute_path,
                                    relative_path: path.clone(),
                                    language: None,
                                    job_handle,
                                    modified_time: info.mtime,
                                    worktree_db_id: worktree_state.db_id,
                                });
                            }
                        }
                        worktree_state.changed_paths.clear();
                        true
                    });

                anyhow::Ok(())
            })??;

            cx.background_executor()
                .spawn(async move {
                    for (worktree_db_id, path) in files_to_delete {
                        db.delete_file(worktree_db_id, path).await.log_err();
                    }

                    let embeddings_for_digest = {
                        let mut files = HashMap::default();
                        for pending_file in &pending_files {
                            files
                                .entry(pending_file.worktree_db_id)
                                .or_insert(Vec::new())
                                .push(pending_file.relative_path.clone());
                        }
                        Arc::new(
                            db.embeddings_for_files(files)
                                .await
                                .log_err()
                                .unwrap_or_default(),
                        )
                    };

                    for mut pending_file in pending_files {
                        if let Ok(language) = language_registry
                            .language_for_file(&pending_file.relative_path, None)
                            .await
                        {
                            if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
                                && language.name().as_ref() != "Markdown"
                                && language
                                    .grammar()
                                    .and_then(|grammar| grammar.embedding_config.as_ref())
                                    .is_none()
                            {
                                continue;
                            }
                            pending_file.language = Some(language);
                        }
                        parsing_files_tx
                            .try_send((embeddings_for_digest.clone(), pending_file))
                            .ok();
                    }

                    // Wait until we're done indexing.
                    while let Some(count) = pending_file_count_rx.next().await {
                        if count == 0 {
                            break;
                        }
                    }
                })
                .await;

            this.update(&mut cx, |this, cx| {
                let project_state = this
                    .projects
                    .get_mut(&project.downgrade())
                    .context("project was dropped")?;
                project_state.pending_index -= 1;
                cx.notify();
                anyhow::Ok(())
            })??;

            Ok(())
        })
    }

    fn wait_for_worktree_registration(
        &self,
        project: &Model<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        let project = project.downgrade();
        cx.spawn(|this, cx| async move {
            loop {
                let mut pending_worktrees = Vec::new();
                this.upgrade()
                    .context("semantic index dropped")?
                    .read_with(&cx, |this, _| {
                        if let Some(project) = this.projects.get(&project) {
                            for worktree in project.worktrees.values() {
                                if let WorktreeState::Registering(worktree) = worktree {
                                    pending_worktrees.push(worktree.done());
                                }
                            }
                        }
                    })?;

                if pending_worktrees.is_empty() {
                    break;
                } else {
                    future::join_all(pending_worktrees).await;
                }
            }
            Ok(())
        })
    }

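    // Embeds every span that doesn't already have a cached embedding in the
    // database, batching requests so each batch stays under the provider's token
    // limit, then writes an embedding back onto every span.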
    async fn embed_spans(
        spans: &mut [Span],
        embedding_provider: &dyn EmbeddingProvider,
        db: &VectorDatabase,
    ) -> Result<()> {
        let mut batch = Vec::new();
        let mut batch_tokens = 0;
        let mut embeddings = Vec::new();

        let digests = spans
            .iter()
            .map(|span| span.digest.clone())
            .collect::<Vec<_>>();
        let embeddings_for_digests = db
            .embeddings_for_digests(digests)
            .await
            .log_err()
            .unwrap_or_default();

        for span in &*spans {
            if embeddings_for_digests.contains_key(&span.digest) {
                continue;
            };

            if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
                let batch_embeddings = embedding_provider
                    .embed_batch(mem::take(&mut batch))
                    .await?;
                embeddings.extend(batch_embeddings);
                batch_tokens = 0;
            }

            batch_tokens += span.token_count;
            batch.push(span.content.clone());
        }

        if !batch.is_empty() {
            let batch_embeddings = embedding_provider
                .embed_batch(mem::take(&mut batch))
                .await?;

            embeddings.extend(batch_embeddings);
        }

        let mut embeddings = embeddings.into_iter();
        for span in spans {
            let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) {
                Some(embedding.clone())
            } else {
                embeddings.next()
            };
            let embedding = embedding.context("failed to embed spans")?;
            span.embedding = Some(embedding);
        }
        Ok(())
    }
}

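// Decrement the pending-file counter exactly once, when the last clone of a
// JobHandle goes away: `Arc::get_mut` only succeeds for the final clone, and the
// inner `Weak` upgrade fails if the owning project state has already been dropped.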
impl Drop for JobHandle {
    fn drop(&mut self) {
        if let Some(inner) = Arc::get_mut(&mut self.tx) {
            // This is the last instance of the JobHandle (regardless of whether it was ever cloned).
            if let Some(tx) = inner.upgrade() {
                let mut tx = tx.lock();
                *tx.borrow_mut() -= 1;
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_job_handle() {
        let (job_count_tx, job_count_rx) = watch::channel_with(0);
        let tx = Arc::new(Mutex::new(job_count_tx));
        let job_handle = JobHandle::new(&tx);

        assert_eq!(1, *job_count_rx.borrow());
        let new_job_handle = job_handle.clone();
        assert_eq!(1, *job_count_rx.borrow());
        drop(job_handle);
        assert_eq!(1, *job_count_rx.borrow());
        drop(new_job_handle);
        assert_eq!(0, *job_count_rx.borrow());
    }
}