mod db;
mod embedding_queue;
mod parsing;
pub mod semantic_index_settings;

#[cfg(test)]
mod semantic_index_tests;

use crate::semantic_index_settings::SemanticIndexSettings;
use ai::embedding::{Embedding, EmbeddingProvider};
use ai::providers::open_ai::{OpenAiEmbeddingProvider, OPEN_AI_API_URL};
use anyhow::{anyhow, Context as _, Result};
use collections::{BTreeMap, HashMap, HashSet};
use db::VectorDatabase;
use embedding_queue::{EmbeddingQueue, FileToEmbed};
use futures::{future, FutureExt, StreamExt};
use gpui::{
    AppContext, AsyncAppContext, BorrowWindow, Context, Global, Model, ModelContext, Task,
    ViewContext, WeakModel,
};
use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
use lazy_static::lazy_static;
use ordered_float::OrderedFloat;
use parking_lot::Mutex;
use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
use postage::watch;
use project::{Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
use release_channel::ReleaseChannel;
use settings::Settings;
use smol::channel;
use std::{
    cmp::Reverse,
    env,
    future::Future,
    mem,
    ops::Range,
    path::{Path, PathBuf},
    sync::{Arc, Weak},
    time::{Duration, Instant, SystemTime},
};
use util::paths::PathMatcher;
use util::{http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt};
use workspace::Workspace;

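// SEMANTIC_INDEX_VERSION is presumably consumed by the `db` layer to invalidate
// previously persisted index data when the on-disk format changes (an assumption
// based on its name; it is not referenced in this file). BACKGROUND_INDEXING_DELAY
// debounces re-indexing after a worktree reports changed paths, and
// EMBEDDING_QUEUE_FLUSH_TIMEOUT bounds how long a parsing task waits before
// flushing a partially filled embedding batch.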
const SEMANTIC_INDEX_VERSION: usize = 11;
const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);

lazy_static! {
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
}

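/// Wires the semantic index into the application: registers its settings,
/// installs the global [`SemanticIndex`] model, and, whenever a workspace opens
/// a local project that was indexed in a previous session, re-indexes that
/// project in the background.
///
/// A minimal call sketch (not compiled), assuming app-startup code that already
/// holds `fs`, `http_client`, and `language_registry` handles:
///
/// ```ignore
/// semantic_index::init(fs.clone(), http_client.clone(), language_registry.clone(), cx);
/// ```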
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    SemanticIndexSettings::register(cx);

    let db_file_path = EMBEDDINGS_DIR
        .join(Path::new(ReleaseChannel::global(cx).dev_name()))
        .join("embeddings_db");

    cx.observe_new_views(
        |workspace: &mut Workspace, cx: &mut ViewContext<Workspace>| {
            let Some(semantic_index) = SemanticIndex::global(cx) else {
                return;
            };
            let project = workspace.project().clone();

            if project.read(cx).is_local() {
                cx.app_mut()
                    .spawn(|mut cx| async move {
                        let previously_indexed = semantic_index
                            .update(&mut cx, |index, cx| {
                                index.project_previously_indexed(&project, cx)
                            })?
                            .await?;
                        if previously_indexed {
                            semantic_index
                                .update(&mut cx, |index, cx| index.index_project(project, cx))?
                                .await?;
                        }
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
            }
        },
    )
    .detach();

    cx.spawn(move |cx| async move {
        let embedding_provider = OpenAiEmbeddingProvider::new(
            // TODO: Read this from configuration; it is unclear whether to reuse
            // `openai_api_url` from the assistant settings or add a dedicated setting.
            OPEN_AI_API_URL.to_string(),
            http_client,
            cx.background_executor().clone(),
        )
        .await;
        let semantic_index = SemanticIndex::new(
            fs,
            db_file_path,
            Arc::new(embedding_provider),
            language_registry,
            cx.clone(),
        )
        .await?;

        cx.update(|cx| cx.set_global(GlobalSemanticIndex(semantic_index.clone())))?;

        anyhow::Ok(())
    })
    .detach();
}

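/// The indexing state reported for a single project by [`SemanticIndex::status`].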
#[derive(Copy, Clone, Debug)]
pub enum SemanticIndexStatus {
    NotAuthenticated,
    NotIndexed,
    Indexed,
    Indexing {
        remaining_files: usize,
        rate_limit_expiry: Option<Instant>,
    },
}

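/// Coordinates the indexing pipeline: files from registered worktrees are
/// parsed into spans on background tasks, spans are batched through the
/// `EmbeddingQueue`, and the resulting embeddings are persisted in the
/// `VectorDatabase`.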
pub struct SemanticIndex {
    fs: Arc<dyn Fs>,
    db: VectorDatabase,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    parsing_files_tx: channel::Sender<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>,
    _embedding_task: Task<()>,
    _parsing_files_tasks: Vec<Task<()>>,
    projects: HashMap<WeakModel<Project>, ProjectState>,
}

struct GlobalSemanticIndex(Model<SemanticIndex>);

impl Global for GlobalSemanticIndex {}

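/// Per-project bookkeeping: the registered worktrees, a counter of files still
/// waiting to be embedded, and the number of `index_project` calls currently in
/// flight.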
struct ProjectState {
    worktrees: HashMap<WorktreeId, WorktreeState>,
    pending_file_count_rx: watch::Receiver<usize>,
    pending_file_count_tx: Arc<Mutex<watch::Sender<usize>>>,
    pending_index: usize,
    _subscription: gpui::Subscription,
    _observe_pending_file_count: Task<()>,
}

enum WorktreeState {
    Registering(RegisteringWorktreeState),
    Registered(RegisteredWorktreeState),
}

impl WorktreeState {
    fn is_registered(&self) -> bool {
        matches!(self, Self::Registered(_))
    }

    fn paths_changed(
        &mut self,
        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
        worktree: &Worktree,
    ) {
        let changed_paths = match self {
            Self::Registering(state) => &mut state.changed_paths,
            Self::Registered(state) => &mut state.changed_paths,
        };

        for (path, entry_id, change) in changes.iter() {
            let Some(entry) = worktree.entry_for_id(*entry_id) else {
                continue;
            };
            let Some(mtime) = entry.mtime else {
                continue;
            };
            if entry.is_ignored || entry.is_symlink || entry.is_external || entry.is_dir() {
                continue;
            }
            changed_paths.insert(
                path.clone(),
                ChangedPathInfo {
                    mtime,
                    is_deleted: *change == PathChange::Removed,
                },
            );
        }
    }
}

struct RegisteringWorktreeState {
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
    done_rx: watch::Receiver<Option<()>>,
    _registration: Task<()>,
}

impl RegisteringWorktreeState {
    fn done(&self) -> impl Future<Output = ()> {
        let mut done_rx = self.done_rx.clone();
        async move {
            while let Some(result) = done_rx.next().await {
                if result.is_some() {
                    break;
                }
            }
        }
    }
}

struct RegisteredWorktreeState {
    db_id: i64,
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
}

struct ChangedPathInfo {
    mtime: SystemTime,
    is_deleted: bool,
}

#[derive(Clone)]
pub struct JobHandle {
    /// The outer `Arc` counts the clones of a `JobHandle` instance; when the
    /// last handle to a given job is dropped, the pending-file counter is
    /// decremented exactly once.
    tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
}

impl JobHandle {
    fn new(tx: &Arc<Mutex<watch::Sender<usize>>>) -> Self {
        *tx.lock().borrow_mut() += 1;
        Self {
            tx: Arc::new(Arc::downgrade(tx)),
        }
    }
}

impl ProjectState {
    fn new(subscription: gpui::Subscription, cx: &mut ModelContext<SemanticIndex>) -> Self {
        let (pending_file_count_tx, pending_file_count_rx) = watch::channel_with(0);
        let pending_file_count_tx = Arc::new(Mutex::new(pending_file_count_tx));
        Self {
            worktrees: Default::default(),
            pending_file_count_rx: pending_file_count_rx.clone(),
            pending_file_count_tx,
            pending_index: 0,
            _subscription: subscription,
            _observe_pending_file_count: cx.spawn({
                let mut pending_file_count_rx = pending_file_count_rx.clone();
                |this, mut cx| async move {
                    while let Some(_) = pending_file_count_rx.next().await {
                        if this.update(&mut cx, |_, cx| cx.notify()).is_err() {
                            break;
                        }
                    }
                }
            }),
        }
    }

    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
        self.worktrees
            .iter()
            .find_map(|(worktree_id, worktree_state)| match worktree_state {
                WorktreeState::Registered(state) if state.db_id == id => Some(*worktree_id),
                _ => None,
            })
    }
}

#[derive(Clone)]
pub struct PendingFile {
    worktree_db_id: i64,
    relative_path: Arc<Path>,
    absolute_path: PathBuf,
    language: Option<Arc<Language>>,
    modified_time: SystemTime,
    job_handle: JobHandle,
}

#[derive(Clone)]
pub struct SearchResult {
    pub buffer: Model<Buffer>,
    pub range: Range<Anchor>,
    pub similarity: OrderedFloat<f32>,
}

impl SemanticIndex {
    pub fn global(cx: &mut AppContext) -> Option<Model<SemanticIndex>> {
        cx.try_global::<GlobalSemanticIndex>()
            .map(|semantic_index| semantic_index.0.clone())
    }

    pub fn authenticate(&mut self, cx: &mut AppContext) -> Task<bool> {
        if !self.embedding_provider.has_credentials() {
            let embedding_provider = self.embedding_provider.clone();
            cx.spawn(|cx| async move {
                if let Some(retrieve_credentials) = cx
                    .update(|cx| embedding_provider.retrieve_credentials(cx))
                    .log_err()
                {
                    retrieve_credentials.await;
                }

                embedding_provider.has_credentials()
            })
        } else {
            Task::ready(true)
        }
    }

    pub fn is_authenticated(&self) -> bool {
        self.embedding_provider.has_credentials()
    }

    pub fn enabled(cx: &AppContext) -> bool {
        SemanticIndexSettings::get_global(cx).enabled
    }

    pub fn status(&self, project: &Model<Project>) -> SemanticIndexStatus {
        if !self.is_authenticated() {
            return SemanticIndexStatus::NotAuthenticated;
        }

        if let Some(project_state) = self.projects.get(&project.downgrade()) {
            if project_state
                .worktrees
                .values()
                .all(|worktree| worktree.is_registered())
                && project_state.pending_index == 0
            {
                SemanticIndexStatus::Indexed
            } else {
                SemanticIndexStatus::Indexing {
                    remaining_files: *project_state.pending_file_count_rx.borrow(),
                    rate_limit_expiry: self.embedding_provider.rate_limit_expiration(),
                }
            }
        } else {
            SemanticIndexStatus::NotIndexed
        }
    }

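    /// Opens (or creates) the vector database at `database_path` and spawns the
    /// background tasks that drive indexing: one task that persists finished
    /// embeddings, and one parsing task per CPU that turns pending files into
    /// embeddable spans.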
    pub async fn new(
        fs: Arc<dyn Fs>,
        database_path: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
        mut cx: AsyncAppContext,
    ) -> Result<Model<Self>> {
        let t0 = Instant::now();
        let database_path = Arc::from(database_path);
        let db = VectorDatabase::new(fs.clone(), database_path, cx.background_executor().clone())
            .await?;

        log::trace!(
            "db initialization took {:?} milliseconds",
            t0.elapsed().as_millis()
        );

        cx.new_model(|cx| {
            let t0 = Instant::now();
            let embedding_queue =
                EmbeddingQueue::new(embedding_provider.clone(), cx.background_executor().clone());
            let _embedding_task = cx.background_executor().spawn({
                let embedded_files = embedding_queue.finished_files();
                let db = db.clone();
                async move {
                    while let Ok(file) = embedded_files.recv().await {
                        db.insert_file(file.worktree_id, file.path, file.mtime, file.spans)
                            .await
                            .log_err();
                    }
                }
            });

            // Parse files into embeddable spans.
            let (parsing_files_tx, parsing_files_rx) =
                channel::unbounded::<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>();
            let embedding_queue = Arc::new(Mutex::new(embedding_queue));
            let mut _parsing_files_tasks = Vec::new();
            for _ in 0..cx.background_executor().num_cpus() {
                let fs = fs.clone();
                let mut parsing_files_rx = parsing_files_rx.clone();
                let embedding_provider = embedding_provider.clone();
                let embedding_queue = embedding_queue.clone();
                let background = cx.background_executor().clone();
                _parsing_files_tasks.push(cx.background_executor().spawn(async move {
                    let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
                    loop {
                        let mut timer = background.timer(EMBEDDING_QUEUE_FLUSH_TIMEOUT).fuse();
                        let mut next_file_to_parse = parsing_files_rx.next().fuse();
                        futures::select_biased! {
                            next_file_to_parse = next_file_to_parse => {
                                if let Some((embeddings_for_digest, pending_file)) = next_file_to_parse {
                                    Self::parse_file(
                                        &fs,
                                        pending_file,
                                        &mut retriever,
                                        &embedding_queue,
                                        &embeddings_for_digest,
                                    )
                                    .await
                                } else {
                                    break;
                                }
                            },
                            _ = timer => {
                                embedding_queue.lock().flush();
                            }
                        }
                    }
                }));
            }

            log::trace!(
                "semantic index task initialization took {:?} milliseconds",
                t0.elapsed().as_millis()
            );
            Self {
                fs,
                db,
                embedding_provider,
                language_registry,
                parsing_files_tx,
                _embedding_task,
                _parsing_files_tasks,
                projects: Default::default(),
            }
        })
    }

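    /// Loads a pending file from disk, parses it into spans, reuses any
    /// embeddings already computed for identical span digests, and pushes the
    /// file onto the embedding queue.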
    async fn parse_file(
        fs: &Arc<dyn Fs>,
        pending_file: PendingFile,
        retriever: &mut CodeContextRetriever,
        embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
        embeddings_for_digest: &HashMap<SpanDigest, Embedding>,
    ) {
        let Some(language) = pending_file.language else {
            return;
        };

        if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
            if let Some(mut spans) = retriever
                .parse_file_with_template(Some(&pending_file.relative_path), &content, language)
                .log_err()
            {
                log::trace!(
                    "parsed path {:?}: {} spans",
                    pending_file.relative_path,
                    spans.len()
                );

                for span in &mut spans {
                    if let Some(embedding) = embeddings_for_digest.get(&span.digest) {
                        span.embedding = Some(embedding.to_owned());
                    }
                }

                embedding_queue.lock().push(FileToEmbed {
                    worktree_id: pending_file.worktree_db_id,
                    path: pending_file.relative_path,
                    mtime: pending_file.modified_time,
                    job_handle: pending_file.job_handle,
                    spans,
                });
            }
        }
    }

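    /// Returns true only if every worktree in the project has been indexed
    /// before, according to the on-disk database.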
    pub fn project_previously_indexed(
        &mut self,
        project: &Model<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<bool>> {
        let worktrees_indexed_previously = project
            .read(cx)
            .worktrees()
            .map(|worktree| {
                self.db
                    .worktree_previously_indexed(&worktree.read(cx).abs_path())
            })
            .collect::<Vec<_>>();
        cx.spawn(|_, _cx| async move {
            let worktree_indexed_previously =
                futures::future::join_all(worktrees_indexed_previously).await;

            Ok(worktree_indexed_previously
                .iter()
                .filter(|worktree| worktree.is_ok())
                .all(|v| v.as_ref().log_err().is_some_and(|v| v.to_owned())))
        })
    }

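    /// Records changed paths for the affected worktree and, once the worktree
    /// is registered, schedules a re-index after `BACKGROUND_INDEXING_DELAY`.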
    fn project_entries_changed(
        &mut self,
        project: Model<Project>,
        worktree_id: WorktreeId,
        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
        cx: &mut ModelContext<Self>,
    ) {
        let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) else {
            return;
        };
        let project = project.downgrade();
        let Some(project_state) = self.projects.get_mut(&project) else {
            return;
        };

        let worktree = worktree.read(cx);
        let worktree_state =
            if let Some(worktree_state) = project_state.worktrees.get_mut(&worktree_id) {
                worktree_state
            } else {
                return;
            };
        worktree_state.paths_changed(changes, worktree);
        if let WorktreeState::Registered(_) = worktree_state {
            cx.spawn(|this, mut cx| async move {
                cx.background_executor()
                    .timer(BACKGROUND_INDEXING_DELAY)
                    .await;
                if let Some((this, project)) = this.upgrade().zip(project.upgrade()) {
                    this.update(&mut cx, |this, cx| {
                        this.index_project(project, cx).detach_and_log_err(cx)
                    })?;
                }
                anyhow::Ok(())
            })
            .detach_and_log_err(cx);
        }
    }

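    /// Registers a local worktree: waits for its initial scan, diffs the file
    /// mtimes on disk against those stored in the database to collect changed
    /// and deleted paths, then marks the worktree as registered and kicks off
    /// indexing for the project.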
    fn register_worktree(
        &mut self,
        project: Model<Project>,
        worktree: Model<Worktree>,
        cx: &mut ModelContext<Self>,
    ) {
        let project = project.downgrade();
        let project_state = if let Some(project_state) = self.projects.get_mut(&project) {
            project_state
        } else {
            return;
        };
        let worktree = if let Some(worktree) = worktree.read(cx).as_local() {
            worktree
        } else {
            return;
        };
        let worktree_abs_path = worktree.abs_path().clone();
        let scan_complete = worktree.scan_complete();
        let worktree_id = worktree.id();
        let db = self.db.clone();
        let language_registry = self.language_registry.clone();
        let (mut done_tx, done_rx) = watch::channel();
        let registration = cx.spawn(|this, mut cx| {
            async move {
                let register = async {
                    scan_complete.await;
                    let db_id = db.find_or_create_worktree(worktree_abs_path).await?;
                    let mut file_mtimes = db.get_file_mtimes(db_id).await?;
                    let worktree = if let Some(project) = project.upgrade() {
                        project
                            .read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx))
                            .ok()
                            .flatten()
                            .context("worktree not found")?
                    } else {
                        return anyhow::Ok(());
                    };
                    let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot())?;
                    let mut changed_paths = cx
                        .background_executor()
                        .spawn(async move {
                            let mut changed_paths = BTreeMap::new();
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path)?;

                                if file.is_external || file.is_ignored || file.is_symlink {
                                    continue;
                                }

                                if let Ok(language) = language_registry
                                    .language_for_file_path(&absolute_path)
                                    .await
                                {
                                    // Skip files that can't be parsed into embeddable spans:
                                    // the language must either be parseable as an entire file,
                                    // be Markdown, or have an embedding config in its grammar.
                                    if !PARSEABLE_ENTIRE_FILE_TYPES
                                        .contains(&language.name().as_ref())
                                        && language.name().as_ref() != "Markdown"
                                        && language
                                            .grammar()
                                            .and_then(|grammar| grammar.embedding_config.as_ref())
                                            .is_none()
                                    {
                                        continue;
                                    }
                                    let Some(new_mtime) = file.mtime else {
                                        continue;
                                    };

                                    let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
                                    let already_stored = stored_mtime == Some(new_mtime);

                                    if !already_stored {
                                        changed_paths.insert(
                                            file.path.clone(),
                                            ChangedPathInfo {
                                                mtime: new_mtime,
                                                is_deleted: false,
                                            },
                                        );
                                    }
                                }
                            }

                            // Clean up entries from the database that are no longer in the worktree.
                            for (path, mtime) in file_mtimes {
                                changed_paths.insert(
                                    path.into(),
                                    ChangedPathInfo {
                                        mtime,
                                        is_deleted: true,
                                    },
                                );
                            }

                            anyhow::Ok(changed_paths)
                        })
                        .await?;
                    this.update(&mut cx, |this, cx| {
                        let project_state = this
                            .projects
                            .get_mut(&project)
                            .context("project not registered")?;
                        let project = project.upgrade().context("project was dropped")?;

                        if let Some(WorktreeState::Registering(state)) =
                            project_state.worktrees.remove(&worktree_id)
                        {
                            changed_paths.extend(state.changed_paths);
                        }
                        project_state.worktrees.insert(
                            worktree_id,
                            WorktreeState::Registered(RegisteredWorktreeState {
                                db_id,
                                changed_paths,
                            }),
                        );
                        this.index_project(project, cx).detach_and_log_err(cx);

                        anyhow::Ok(())
                    })??;

                    anyhow::Ok(())
                };

                if register.await.log_err().is_none() {
                    // Stop tracking this worktree if the registration failed.
                    this.update(&mut cx, |this, _| {
                        if let Some(project_state) = this.projects.get_mut(&project) {
                            project_state.worktrees.remove(&worktree_id);
                        }
                    })
                    .ok();
                }

                *done_tx.borrow_mut() = Some(());
            }
        });
        project_state.worktrees.insert(
            worktree_id,
            WorktreeState::Registering(RegisteringWorktreeState {
                changed_paths: Default::default(),
                done_rx,
                _registration: registration,
            }),
        );
    }

    fn project_worktrees_changed(&mut self, project: Model<Project>, cx: &mut ModelContext<Self>) {
        let project_state = if let Some(project_state) = self.projects.get_mut(&project.downgrade())
        {
            project_state
        } else {
            return;
        };

        let mut worktrees = project
            .read(cx)
            .worktrees()
            .filter(|worktree| worktree.read(cx).is_local())
            .collect::<Vec<_>>();
        let worktree_ids = worktrees
            .iter()
            .map(|worktree| worktree.read(cx).id())
            .collect::<HashSet<_>>();

        // Remove worktrees that are no longer present.
        project_state
            .worktrees
            .retain(|worktree_id, _| worktree_ids.contains(worktree_id));

        // Register new worktrees.
        worktrees.retain(|worktree| {
            let worktree_id = worktree.read(cx).id();
            !project_state.worktrees.contains_key(&worktree_id)
        });
        for worktree in worktrees {
            self.register_worktree(project.clone(), worktree, cx);
        }
    }

    pub fn pending_file_count(&self, project: &Model<Project>) -> Option<watch::Receiver<usize>> {
        Some(
            self.projects
                .get(&project.downgrade())?
                .pending_file_count_rx
                .clone(),
        )
    }

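    /// Searches the project for spans that are semantically similar to `query`,
    /// weaving together results from dirty open buffers and from the on-disk
    /// index, sorted by descending similarity and truncated to `limit`.
    ///
    /// A minimal usage sketch (not compiled), assuming an async context with a
    /// `Model<SemanticIndex>` and a `Model<Project>` in scope:
    ///
    /// ```ignore
    /// let search = semantic_index.update(&mut cx, |index, cx| {
    ///     index.search_project(
    ///         project,
    ///         "open a file and parse it into spans".into(),
    ///         10,
    ///         Vec::new(),
    ///         Vec::new(),
    ///         cx,
    ///     )
    /// })?;
    /// for result in search.await? {
    ///     println!("{:?} ({})", result.range, result.similarity);
    /// }
    /// ```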
    pub fn search_project(
        &mut self,
        project: Model<Project>,
        query: String,
        limit: usize,
        includes: Vec<PathMatcher>,
        excludes: Vec<PathMatcher>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        if query.is_empty() {
            return Task::ready(Ok(Vec::new()));
        }

        let index = self.index_project(project.clone(), cx);
        let embedding_provider = self.embedding_provider.clone();

        cx.spawn(|this, mut cx| async move {
            index.await?;
            let t0 = Instant::now();

            let query = embedding_provider
                .embed_batch(vec![query])
                .await?
                .pop()
                .context("could not embed query")?;
            log::trace!("Embedding Search Query: {:?}ms", t0.elapsed().as_millis());

            let search_start = Instant::now();
            let modified_buffer_results = this.update(&mut cx, |this, cx| {
                this.search_modified_buffers(
                    &project,
                    query.clone(),
                    limit,
                    &includes,
                    &excludes,
                    cx,
                )
            })?;
            let file_results = this.update(&mut cx, |this, cx| {
                this.search_files(project, query, limit, includes, excludes, cx)
            })?;
            let (modified_buffer_results, file_results) =
                futures::join!(modified_buffer_results, file_results);

            // Weave together the results from modified buffers and files.
            let mut results = Vec::new();
            let mut modified_buffers = HashSet::default();
            for result in modified_buffer_results.log_err().unwrap_or_default() {
                modified_buffers.insert(result.buffer.clone());
                results.push(result);
            }
            for result in file_results.log_err().unwrap_or_default() {
                if !modified_buffers.contains(&result.buffer) {
                    results.push(result);
                }
            }
            results.sort_by_key(|result| Reverse(result.similarity));
            results.truncate(limit);
            log::trace!("Semantic search took {:?}", search_start.elapsed());
            Ok(results)
        })
    }

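    /// Searches the persisted index: retrieves the candidate file ids for the
    /// project's registered worktrees, runs top-k similarity searches over
    /// batches of files in parallel, and opens buffers for the best-matching
    /// spans.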
    pub fn search_files(
        &mut self,
        project: Model<Project>,
        query: Embedding,
        limit: usize,
        includes: Vec<PathMatcher>,
        excludes: Vec<PathMatcher>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let db_path = self.db.path().clone();
        let fs = self.fs.clone();
        cx.spawn(|this, mut cx| async move {
            let database = VectorDatabase::new(
                fs.clone(),
                db_path.clone(),
                cx.background_executor().clone(),
            )
            .await?;

            let worktree_db_ids = this.read_with(&cx, |this, _| {
                let project_state = this
                    .projects
                    .get(&project.downgrade())
                    .context("project was not indexed")?;
                let worktree_db_ids = project_state
                    .worktrees
                    .values()
                    .filter_map(|worktree| {
                        if let WorktreeState::Registered(worktree) = worktree {
                            Some(worktree.db_id)
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<i64>>();
                anyhow::Ok(worktree_db_ids)
            })??;

            let file_ids = database
                .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
                .await?;

            let batch_n = cx.background_executor().num_cpus();
            let ids_len = file_ids.len();
            let minimum_batch_size = 50;

            let batch_size = {
                let size = ids_len / batch_n;
                if size < minimum_batch_size {
                    minimum_batch_size
                } else {
                    size
                }
            };

            let mut batch_results = Vec::new();
            for batch in file_ids.chunks(batch_size) {
                let batch = batch.to_vec();
                let fs = fs.clone();
                let db_path = db_path.clone();
                let query = query.clone();
                if let Some(db) =
                    VectorDatabase::new(fs, db_path.clone(), cx.background_executor().clone())
                        .await
                        .log_err()
                {
                    batch_results.push(async move {
                        db.top_k_search(&query, limit, batch.as_slice()).await
                    });
                }
            }

            let batch_results = futures::future::join_all(batch_results).await;

            let mut results = Vec::new();
            for batch_result in batch_results {
                if let Ok(batch_result) = batch_result {
                    for (id, similarity) in batch_result {
                        let ix = match results
                            .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
                        {
                            Ok(ix) => ix,
                            Err(ix) => ix,
                        };

                        results.insert(ix, (id, similarity));
                        results.truncate(limit);
                    }
                }
            }

            let ids = results.iter().map(|(id, _)| *id).collect::<Vec<i64>>();
            let scores = results
                .into_iter()
                .map(|(_, score)| score)
                .collect::<Vec<_>>();
            let spans = database.spans_for_ids(ids.as_slice()).await?;

            let mut tasks = Vec::new();
            let mut ranges = Vec::new();
            let weak_project = project.downgrade();
            project.update(&mut cx, |project, cx| {
                let this = this.upgrade().context("index was dropped")?;
                for (worktree_db_id, file_path, byte_range) in spans {
                    let project_state =
                        if let Some(state) = this.read(cx).projects.get(&weak_project) {
                            state
                        } else {
                            return Err(anyhow!("project not added"));
                        };
                    if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
                        tasks.push(project.open_buffer((worktree_id, file_path), cx));
                        ranges.push(byte_range);
                    }
                }

                Ok(())
            })??;

            let buffers = futures::future::join_all(tasks).await;
            Ok(buffers
                .into_iter()
                .zip(ranges)
                .zip(scores)
                .filter_map(|((buffer, range), similarity)| {
                    let buffer = buffer.log_err()?;
                    let range = buffer
                        .read_with(&cx, |buffer, _| {
                            let start = buffer.clip_offset(range.start, Bias::Left);
                            let end = buffer.clip_offset(range.end, Bias::Right);
                            buffer.anchor_before(start)..buffer.anchor_after(end)
                        })
                        .log_err()?;
                    Some(SearchResult {
                        buffer,
                        range,
                        similarity,
                    })
                })
                .collect())
        })
    }

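    /// Searches dirty open buffers by re-parsing and re-embedding their current
    /// contents on the fly, so that unsaved edits are still searchable.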
    fn search_modified_buffers(
        &self,
        project: &Model<Project>,
        query: Embedding,
        limit: usize,
        includes: &[PathMatcher],
        excludes: &[PathMatcher],
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let modified_buffers = project
            .read(cx)
            .opened_buffers()
            .into_iter()
            .filter_map(|buffer_handle| {
                let buffer = buffer_handle.read(cx);
                let snapshot = buffer.snapshot();
                let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
                    excludes.iter().any(|matcher| matcher.is_match(&path))
                });

                let included = if includes.is_empty() {
                    true
                } else {
                    snapshot.resolve_file_path(cx, false).map_or(false, |path| {
                        includes.iter().any(|matcher| matcher.is_match(&path))
                    })
                };

                if buffer.is_dirty() && !excluded && included {
                    Some((buffer_handle, snapshot))
                } else {
                    None
                }
            })
            .collect::<HashMap<_, _>>();

        let embedding_provider = self.embedding_provider.clone();
        let fs = self.fs.clone();
        let db_path = self.db.path().clone();
        let background = cx.background_executor().clone();
        cx.background_executor().spawn(async move {
            let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
            let mut results = Vec::<SearchResult>::new();

            let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
            for (buffer, snapshot) in modified_buffers {
                let language = snapshot
                    .language_at(0)
                    .cloned()
                    .unwrap_or_else(|| language::PLAIN_TEXT.clone());
                let mut spans = retriever
                    .parse_file_with_template(None, &snapshot.text(), language)
                    .log_err()
                    .unwrap_or_default();
                if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
                    .await
                    .log_err()
                    .is_some()
                {
                    for span in spans {
                        let similarity = span.embedding.unwrap().similarity(&query);
                        let ix = match results
                            .binary_search_by_key(&Reverse(similarity), |result| {
                                Reverse(result.similarity)
                            }) {
                            Ok(ix) => ix,
                            Err(ix) => ix,
                        };

                        let range = {
                            let start = snapshot.clip_offset(span.range.start, Bias::Left);
                            let end = snapshot.clip_offset(span.range.end, Bias::Right);
                            snapshot.anchor_before(start)..snapshot.anchor_after(end)
                        };

                        results.insert(
                            ix,
                            SearchResult {
                                buffer: buffer.clone(),
                                range,
                                similarity,
                            },
                        );
                        results.truncate(limit);
                    }
                }
            }

            Ok(results)
        })
    }

    pub fn index_project(
        &mut self,
        project: Model<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        if self.is_authenticated() {
            self.index_project_internal(project, cx)
        } else {
            let authenticate = self.authenticate(cx);
            cx.spawn(|this, mut cx| async move {
                if authenticate.await {
                    this.update(&mut cx, |this, cx| this.index_project_internal(project, cx))?
                        .await
                } else {
                    Err(anyhow!("user is not authenticated"))
                }
            })
        }
    }

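    /// Indexes the project: on the first call it subscribes to project events
    /// and registers its worktrees; it then deletes removed files from the
    /// database, feeds the remaining changed files to the parsing tasks, and
    /// waits for the pending-file count to drop back to zero.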
    fn index_project_internal(
        &mut self,
        project: Model<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        if !self.projects.contains_key(&project.downgrade()) {
            let subscription = cx.subscribe(&project, |this, project, event, cx| match event {
                project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
                    this.project_worktrees_changed(project.clone(), cx);
                }
                project::Event::WorktreeUpdatedEntries(worktree_id, changes) => {
                    this.project_entries_changed(project, *worktree_id, changes.clone(), cx);
                }
                _ => {}
            });
            let project_state = ProjectState::new(subscription, cx);
            self.projects.insert(project.downgrade(), project_state);
            self.project_worktrees_changed(project.clone(), cx);
        }
        let project_state = self.projects.get_mut(&project.downgrade()).unwrap();
        project_state.pending_index += 1;
        cx.notify();

        let mut pending_file_count_rx = project_state.pending_file_count_rx.clone();
        let db = self.db.clone();
        let language_registry = self.language_registry.clone();
        let parsing_files_tx = self.parsing_files_tx.clone();
        let worktree_registration = self.wait_for_worktree_registration(&project, cx);

        cx.spawn(|this, mut cx| async move {
            worktree_registration.await?;

            let mut pending_files = Vec::new();
            let mut files_to_delete = Vec::new();
            this.update(&mut cx, |this, cx| {
                let project_state = this
                    .projects
                    .get_mut(&project.downgrade())
                    .context("project was dropped")?;
                let pending_file_count_tx = &project_state.pending_file_count_tx;

                project_state
                    .worktrees
                    .retain(|worktree_id, worktree_state| {
                        let worktree = if let Some(worktree) =
                            project.read(cx).worktree_for_id(*worktree_id, cx)
                        {
                            worktree
                        } else {
                            return false;
                        };
                        let worktree_state =
                            if let WorktreeState::Registered(worktree_state) = worktree_state {
                                worktree_state
                            } else {
                                return true;
                            };

                        for (path, info) in &worktree_state.changed_paths {
                            if info.is_deleted {
                                files_to_delete.push((worktree_state.db_id, path.clone()));
                            } else if let Ok(absolute_path) = worktree.read(cx).absolutize(path) {
                                let job_handle = JobHandle::new(pending_file_count_tx);
                                pending_files.push(PendingFile {
                                    absolute_path,
                                    relative_path: path.clone(),
                                    language: None,
                                    job_handle,
                                    modified_time: info.mtime,
                                    worktree_db_id: worktree_state.db_id,
                                });
                            }
                        }
                        worktree_state.changed_paths.clear();
                        true
                    });

                anyhow::Ok(())
            })??;

            cx.background_executor()
                .spawn(async move {
                    for (worktree_db_id, path) in files_to_delete {
                        db.delete_file(worktree_db_id, path).await.log_err();
                    }

                    let embeddings_for_digest = {
                        let mut files = HashMap::default();
                        for pending_file in &pending_files {
                            files
                                .entry(pending_file.worktree_db_id)
                                .or_insert(Vec::new())
                                .push(pending_file.relative_path.clone());
                        }
                        Arc::new(
                            db.embeddings_for_files(files)
                                .await
                                .log_err()
                                .unwrap_or_default(),
                        )
                    };

                    for mut pending_file in pending_files {
                        if let Ok(language) = language_registry
                            .language_for_file_path(&pending_file.relative_path)
                            .await
                        {
                            if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
                                && language.name().as_ref() != "Markdown"
                                && language
                                    .grammar()
                                    .and_then(|grammar| grammar.embedding_config.as_ref())
                                    .is_none()
                            {
                                continue;
                            }
                            pending_file.language = Some(language);
                        }
                        parsing_files_tx
                            .try_send((embeddings_for_digest.clone(), pending_file))
                            .ok();
                    }

                    // Wait until we're done indexing.
                    while let Some(count) = pending_file_count_rx.next().await {
                        if count == 0 {
                            break;
                        }
                    }
                })
                .await;

            this.update(&mut cx, |this, cx| {
                let project_state = this
                    .projects
                    .get_mut(&project.downgrade())
                    .context("project was dropped")?;
                project_state.pending_index -= 1;
                cx.notify();
                anyhow::Ok(())
            })??;

            Ok(())
        })
    }

    fn wait_for_worktree_registration(
        &self,
        project: &Model<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        let project = project.downgrade();
        cx.spawn(|this, cx| async move {
            loop {
                let mut pending_worktrees = Vec::new();
                this.upgrade()
                    .context("semantic index dropped")?
                    .read_with(&cx, |this, _| {
                        if let Some(project) = this.projects.get(&project) {
                            for worktree in project.worktrees.values() {
                                if let WorktreeState::Registering(worktree) = worktree {
                                    pending_worktrees.push(worktree.done());
                                }
                            }
                        }
                    })?;

                if pending_worktrees.is_empty() {
                    break;
                } else {
                    future::join_all(pending_worktrees).await;
                }
            }
            Ok(())
        })
    }

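    /// Embeds any spans whose digests are not already stored in the database,
    /// batching requests so that each batch stays within the provider's token
    /// budget, then attaches an embedding to every span.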
    async fn embed_spans(
        spans: &mut [Span],
        embedding_provider: &dyn EmbeddingProvider,
        db: &VectorDatabase,
    ) -> Result<()> {
        let mut batch = Vec::new();
        let mut batch_tokens = 0;
        let mut embeddings = Vec::new();

        let digests = spans
            .iter()
            .map(|span| span.digest.clone())
            .collect::<Vec<_>>();
        let embeddings_for_digests = db
            .embeddings_for_digests(digests)
            .await
            .log_err()
            .unwrap_or_default();

        for span in &*spans {
            if embeddings_for_digests.contains_key(&span.digest) {
                continue;
            }

            if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
                let batch_embeddings = embedding_provider
                    .embed_batch(mem::take(&mut batch))
                    .await?;
                embeddings.extend(batch_embeddings);
                batch_tokens = 0;
            }

            batch_tokens += span.token_count;
            batch.push(span.content.clone());
        }

        if !batch.is_empty() {
            let batch_embeddings = embedding_provider
                .embed_batch(mem::take(&mut batch))
                .await?;

            embeddings.extend(batch_embeddings);
        }

        let mut embeddings = embeddings.into_iter();
        for span in spans {
            let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) {
                Some(embedding.clone())
            } else {
                embeddings.next()
            };
            let embedding = embedding.context("failed to embed spans")?;
            span.embedding = Some(embedding);
        }
        Ok(())
    }
}

impl Drop for JobHandle {
    fn drop(&mut self) {
        if let Some(inner) = Arc::get_mut(&mut self.tx) {
            // This is the last instance of the JobHandle (whether or not it was ever cloned),
            // so decrement the pending-file counter exactly once.
            if let Some(tx) = inner.upgrade() {
                let mut tx = tx.lock();
                *tx.borrow_mut() -= 1;
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_job_handle() {
        let (job_count_tx, job_count_rx) = watch::channel_with(0);
        let tx = Arc::new(Mutex::new(job_count_tx));
        let job_handle = JobHandle::new(&tx);

        assert_eq!(1, *job_count_rx.borrow());
        let new_job_handle = job_handle.clone();
        assert_eq!(1, *job_count_rx.borrow());
        drop(job_handle);
        assert_eq!(1, *job_count_rx.borrow());
        drop(new_job_handle);
        assert_eq!(0, *job_count_rx.borrow());
    }
}