1mod db;
2mod embedding_queue;
3mod parsing;
4pub mod semantic_index_settings;
5
6#[cfg(test)]
7mod semantic_index_tests;
8
9use crate::semantic_index_settings::SemanticIndexSettings;
10use ai::embedding::{Embedding, EmbeddingProvider};
11use ai::providers::open_ai::OpenAIEmbeddingProvider;
12use anyhow::{anyhow, Result};
13use collections::{BTreeMap, HashMap, HashSet};
14use db::VectorDatabase;
15use embedding_queue::{EmbeddingQueue, FileToEmbed};
16use futures::{future, FutureExt, StreamExt};
17use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
18use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
19use lazy_static::lazy_static;
20use ordered_float::OrderedFloat;
21use parking_lot::Mutex;
22use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
23use postage::watch;
24use project::{Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
25use smol::channel;
26use std::{
27 cmp::Reverse,
28 env,
29 future::Future,
30 mem,
31 ops::Range,
32 path::{Path, PathBuf},
33 sync::{Arc, Weak},
34 time::{Duration, Instant, SystemTime},
35};
36use util::paths::PathMatcher;
37use util::{channel::RELEASE_CHANNEL_NAME, http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt};
38use workspace::WorkspaceCreated;
39
// Version of the semantic index's persisted format; presumably bumped to
// invalidate stale databases — the consumer is not visible in this file.
const SEMANTIC_INDEX_VERSION: usize = 11;
// Debounce delay between a worktree file change and the background re-index
// it triggers (see `project_entries_changed`).
const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
// How often an idle parser task flushes partially-filled embedding batches.
const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);

lazy_static! {
    // Read once at startup; `None` when the environment variable is unset.
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
}
47
/// Registers semantic-index settings, wires up automatic re-indexing of
/// newly-created workspaces, and spawns creation of the global
/// `SemanticIndex` model backed by the OpenAI embedding provider.
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    settings::register::<SemanticIndexSettings>(cx);

    // The embeddings database lives in a per-release-channel directory so
    // different channels don't share each other's index files.
    let db_file_path = EMBEDDINGS_DIR
        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
        .join("embeddings_db");

    cx.subscribe_global::<WorkspaceCreated, _>({
        move |event, cx| {
            let Some(semantic_index) = SemanticIndex::global(cx) else {
                return;
            };
            let workspace = &event.0;
            if let Some(workspace) = workspace.upgrade(cx) {
                let project = workspace.read(cx).project().clone();
                // Only local projects can be indexed from this process.
                if project.read(cx).is_local() {
                    cx.spawn(|mut cx| async move {
                        // Re-index automatically only if this project was
                        // already indexed in a previous session.
                        let previously_indexed = semantic_index
                            .update(&mut cx, |index, cx| {
                                index.project_previously_indexed(&project, cx)
                            })
                            .await?;
                        if previously_indexed {
                            semantic_index
                                .update(&mut cx, |index, cx| index.index_project(project, cx))
                                .await?;
                        }
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                }
            }
        }
    })
    .detach();

    cx.spawn(move |mut cx| async move {
        let semantic_index = SemanticIndex::new(
            fs,
            db_file_path,
            Arc::new(OpenAIEmbeddingProvider::new(http_client, cx.background())),
            language_registry,
            cx.clone(),
        )
        .await?;

        // Publish the index as an app-global so `SemanticIndex::global`
        // (used by the subscription above) can find it.
        cx.update(|cx| {
            cx.set_global(semantic_index.clone());
        });

        anyhow::Ok(())
    })
    .detach();
}
107
/// High-level indexing state of a project, as reported by
/// `SemanticIndex::status` (e.g. for UI display).
#[derive(Copy, Clone, Debug)]
pub enum SemanticIndexStatus {
    /// The embedding provider has no credentials yet.
    NotAuthenticated,
    /// The project is not known to the index.
    NotIndexed,
    /// All worktrees are registered and no index pass is in flight.
    Indexed,
    /// Worktrees are still registering or files are still pending.
    Indexing {
        // Current value of the project's pending-file counter.
        remaining_files: usize,
        // When the embedding provider's rate limit (if any) is expected to lift.
        rate_limit_expiry: Option<Instant>,
    },
}
118
/// Global model that maintains vector embeddings of project files and
/// answers semantic search queries against them.
pub struct SemanticIndex {
    fs: Arc<dyn Fs>,
    db: VectorDatabase,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    // Feeds pending files (plus any embeddings already known for their span
    // digests) to the background parser tasks.
    parsing_files_tx: channel::Sender<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>,
    // Drains finished files from the embedding queue into the database;
    // kept alive for the lifetime of the index.
    _embedding_task: Task<()>,
    // One parser task per CPU (see `SemanticIndex::new`).
    _parsing_files_tasks: Vec<Task<()>>,
    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
}
129
/// Per-project bookkeeping for the semantic index.
struct ProjectState {
    worktrees: HashMap<WorktreeId, WorktreeState>,
    // Count of files still awaiting embedding for this project.
    pending_file_count_rx: watch::Receiver<usize>,
    pending_file_count_tx: Arc<Mutex<watch::Sender<usize>>>,
    // Number of `index_project` passes currently in flight.
    pending_index: usize,
    _subscription: gpui::Subscription,
    // Re-notifies observers whenever the pending file count changes.
    _observe_pending_file_count: Task<()>,
}
138
/// A worktree starts out `Registering` (scanning and reconciling against the
/// database) and becomes `Registered` once it has a database id.
enum WorktreeState {
    Registering(RegisteringWorktreeState),
    Registered(RegisteredWorktreeState),
}
143
144impl WorktreeState {
145 fn is_registered(&self) -> bool {
146 matches!(self, Self::Registered(_))
147 }
148
149 fn paths_changed(
150 &mut self,
151 changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
152 worktree: &Worktree,
153 ) {
154 let changed_paths = match self {
155 Self::Registering(state) => &mut state.changed_paths,
156 Self::Registered(state) => &mut state.changed_paths,
157 };
158
159 for (path, entry_id, change) in changes.iter() {
160 let Some(entry) = worktree.entry_for_id(*entry_id) else {
161 continue;
162 };
163 if entry.is_ignored || entry.is_symlink || entry.is_external || entry.is_dir() {
164 continue;
165 }
166 changed_paths.insert(
167 path.clone(),
168 ChangedPathInfo {
169 mtime: entry.mtime,
170 is_deleted: *change == PathChange::Removed,
171 },
172 );
173 }
174 }
175}
176
/// State for a worktree whose registration task is still running.
struct RegisteringWorktreeState {
    // Changes observed while registration was in progress; merged into the
    // registered state once registration completes.
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
    // Receives `Some(())` when the registration task finishes.
    done_rx: watch::Receiver<Option<()>>,
    // Keeps the registration task alive.
    _registration: Task<()>,
}
182
impl RegisteringWorktreeState {
    /// Returns a future that resolves once registration has finished, i.e.
    /// when `done_rx` observes a `Some(())` value (or the sender is dropped).
    fn done(&self) -> impl Future<Output = ()> {
        let mut done_rx = self.done_rx.clone();
        async move {
            while let Some(result) = done_rx.next().await {
                if result.is_some() {
                    break;
                }
            }
        }
    }
}
195
/// State for a fully-registered worktree.
struct RegisteredWorktreeState {
    // Database id assigned by `VectorDatabase::find_or_create_worktree`.
    db_id: i64,
    // Paths changed since the last index pass; drained by `index_project`.
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
}
200
/// A single recorded change to an indexable file.
struct ChangedPathInfo {
    // Modification time of the entry at the moment the change was observed.
    mtime: SystemTime,
    // True when the file was removed and its embeddings should be deleted.
    is_deleted: bool,
}
205
/// Tracks one pending file job; the shared pending-file counter is
/// decremented when the last clone of a handle is dropped (see the `Drop`
/// impl at the bottom of this file).
#[derive(Clone)]
pub struct JobHandle {
    /// The outer Arc is here to count the clones of a JobHandle instance;
    /// when the last handle to a given job is dropped, we decrement a counter (just once).
    tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
}
212
impl JobHandle {
    /// Creates a handle and increments the shared pending-file counter.
    fn new(tx: &Arc<Mutex<watch::Sender<usize>>>) -> Self {
        *tx.lock().borrow_mut() += 1;
        Self {
            // Hold only a weak reference so a lingering handle can't keep
            // the counter alive after its project state is dropped.
            tx: Arc::new(Arc::downgrade(&tx)),
        }
    }
}
221
impl ProjectState {
    /// Creates fresh project state with a zeroed pending-file counter and a
    /// task that notifies observers whenever that counter changes.
    fn new(subscription: gpui::Subscription, cx: &mut ModelContext<SemanticIndex>) -> Self {
        let (pending_file_count_tx, pending_file_count_rx) = watch::channel_with(0);
        let pending_file_count_tx = Arc::new(Mutex::new(pending_file_count_tx));
        Self {
            worktrees: Default::default(),
            pending_file_count_rx: pending_file_count_rx.clone(),
            pending_file_count_tx,
            pending_index: 0,
            _subscription: subscription,
            // Weak spawn: the observer task must not keep the index alive.
            _observe_pending_file_count: cx.spawn_weak({
                let mut pending_file_count_rx = pending_file_count_rx.clone();
                |this, mut cx| async move {
                    while let Some(_) = pending_file_count_rx.next().await {
                        if let Some(this) = this.upgrade(&cx) {
                            this.update(&mut cx, |_, cx| cx.notify());
                        }
                    }
                }
            }),
        }
    }

    /// Maps a database worktree id back to the worktree's in-memory id.
    /// Returns `None` for unknown ids or still-registering worktrees.
    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
        self.worktrees
            .iter()
            .find_map(|(worktree_id, worktree_state)| match worktree_state {
                WorktreeState::Registered(state) if state.db_id == id => Some(*worktree_id),
                _ => None,
            })
    }
}
254
/// A file queued for parsing and embedding.
#[derive(Clone)]
pub struct PendingFile {
    worktree_db_id: i64,
    relative_path: Arc<Path>,
    absolute_path: PathBuf,
    // Filled in by `index_project` once the language is resolved; files with
    // no recognized language are skipped by `parse_file`.
    language: Option<Arc<Language>>,
    modified_time: SystemTime,
    // Keeps the project's pending-file counter accurate while queued.
    job_handle: JobHandle,
}
264
/// One semantic search hit: a span in a buffer with its similarity score.
#[derive(Clone)]
pub struct SearchResult {
    pub buffer: ModelHandle<Buffer>,
    pub range: Range<Anchor>,
    // Cosine-style similarity to the query; higher is more relevant.
    pub similarity: OrderedFloat<f32>,
}
271
272impl SemanticIndex {
273 pub fn global(cx: &mut AppContext) -> Option<ModelHandle<SemanticIndex>> {
274 if cx.has_global::<ModelHandle<Self>>() {
275 Some(cx.global::<ModelHandle<SemanticIndex>>().clone())
276 } else {
277 None
278 }
279 }
280
281 pub fn authenticate(&mut self, cx: &AppContext) -> bool {
282 if !self.embedding_provider.has_credentials() {
283 self.embedding_provider.retrieve_credentials(cx);
284 } else {
285 return true;
286 }
287
288 self.embedding_provider.has_credentials()
289 }
290
    /// Whether the embedding provider currently has credentials.
    pub fn is_authenticated(&self) -> bool {
        self.embedding_provider.has_credentials()
    }
294
    /// Whether semantic indexing is enabled in the user's settings.
    pub fn enabled(cx: &AppContext) -> bool {
        settings::get::<SemanticIndexSettings>(cx).enabled
    }
298
299 pub fn status(&self, project: &ModelHandle<Project>) -> SemanticIndexStatus {
300 if !self.is_authenticated() {
301 return SemanticIndexStatus::NotAuthenticated;
302 }
303
304 if let Some(project_state) = self.projects.get(&project.downgrade()) {
305 if project_state
306 .worktrees
307 .values()
308 .all(|worktree| worktree.is_registered())
309 && project_state.pending_index == 0
310 {
311 SemanticIndexStatus::Indexed
312 } else {
313 SemanticIndexStatus::Indexing {
314 remaining_files: project_state.pending_file_count_rx.borrow().clone(),
315 rate_limit_expiry: self.embedding_provider.rate_limit_expiration(),
316 }
317 }
318 } else {
319 SemanticIndexStatus::NotIndexed
320 }
321 }
322
    /// Opens (or creates) the vector database at `database_path`, then builds
    /// the `SemanticIndex` model with its background tasks: one task that
    /// persists embedded files to the database, and one parser task per CPU
    /// that turns pending files into embeddable spans.
    pub async fn new(
        fs: Arc<dyn Fs>,
        database_path: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
        mut cx: AsyncAppContext,
    ) -> Result<ModelHandle<Self>> {
        let t0 = Instant::now();
        let database_path = Arc::from(database_path);
        let db = VectorDatabase::new(fs.clone(), database_path, cx.background()).await?;

        log::trace!(
            "db initialization took {:?} milliseconds",
            t0.elapsed().as_millis()
        );

        Ok(cx.add_model(|cx| {
            let t0 = Instant::now();
            let embedding_queue =
                EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
            // Persist each fully-embedded file as it comes off the queue.
            let _embedding_task = cx.background().spawn({
                let embedded_files = embedding_queue.finished_files();
                let db = db.clone();
                async move {
                    while let Ok(file) = embedded_files.recv().await {
                        db.insert_file(file.worktree_id, file.path, file.mtime, file.spans)
                            .await
                            .log_err();
                    }
                }
            });

            // Parse files into embeddable spans.
            let (parsing_files_tx, parsing_files_rx) =
                channel::unbounded::<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>();
            let embedding_queue = Arc::new(Mutex::new(embedding_queue));
            let mut _parsing_files_tasks = Vec::new();
            for _ in 0..cx.background().num_cpus() {
                let fs = fs.clone();
                let mut parsing_files_rx = parsing_files_rx.clone();
                let embedding_provider = embedding_provider.clone();
                let embedding_queue = embedding_queue.clone();
                let background = cx.background().clone();
                _parsing_files_tasks.push(cx.background().spawn(async move {
                    let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
                    loop {
                        // Either parse the next pending file, or — if none
                        // arrives before the timeout — flush the embedding
                        // queue so partial batches still make progress.
                        let mut timer = background.timer(EMBEDDING_QUEUE_FLUSH_TIMEOUT).fuse();
                        let mut next_file_to_parse = parsing_files_rx.next().fuse();
                        futures::select_biased! {
                            next_file_to_parse = next_file_to_parse => {
                                if let Some((embeddings_for_digest, pending_file)) = next_file_to_parse {
                                    Self::parse_file(
                                        &fs,
                                        pending_file,
                                        &mut retriever,
                                        &embedding_queue,
                                        &embeddings_for_digest,
                                    )
                                    .await
                                } else {
                                    // Channel closed: the index was dropped.
                                    break;
                                }
                            },
                            _ = timer => {
                                embedding_queue.lock().flush();
                            }
                        }
                    }
                }));
            }

            log::trace!(
                "semantic index task initialization took {:?} milliseconds",
                t0.elapsed().as_millis()
            );
            Self {
                fs,
                db,
                embedding_provider,
                language_registry,
                parsing_files_tx,
                _embedding_task,
                _parsing_files_tasks,
                projects: Default::default(),
            }
        }))
    }
410
411 async fn parse_file(
412 fs: &Arc<dyn Fs>,
413 pending_file: PendingFile,
414 retriever: &mut CodeContextRetriever,
415 embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
416 embeddings_for_digest: &HashMap<SpanDigest, Embedding>,
417 ) {
418 let Some(language) = pending_file.language else {
419 return;
420 };
421
422 if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
423 if let Some(mut spans) = retriever
424 .parse_file_with_template(Some(&pending_file.relative_path), &content, language)
425 .log_err()
426 {
427 log::trace!(
428 "parsed path {:?}: {} spans",
429 pending_file.relative_path,
430 spans.len()
431 );
432
433 for span in &mut spans {
434 if let Some(embedding) = embeddings_for_digest.get(&span.digest) {
435 span.embedding = Some(embedding.to_owned());
436 }
437 }
438
439 embedding_queue.lock().push(FileToEmbed {
440 worktree_id: pending_file.worktree_db_id,
441 path: pending_file.relative_path,
442 mtime: pending_file.modified_time,
443 job_handle: pending_file.job_handle,
444 spans,
445 });
446 }
447 }
448 }
449
    /// Returns whether *all* of the project's worktrees were indexed in a
    /// previous session, according to the database.
    pub fn project_previously_indexed(
        &mut self,
        project: &ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<bool>> {
        // Kick off one database query per worktree, then join them.
        let worktrees_indexed_previously = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                self.db
                    .worktree_previously_indexed(&worktree.read(cx).abs_path())
            })
            .collect::<Vec<_>>();
        cx.spawn(|_, _cx| async move {
            let worktree_indexed_previously =
                futures::future::join_all(worktrees_indexed_previously).await;

            // Queries that errored are logged and skipped; every remaining
            // worktree must report `true` for the project to count as
            // previously indexed.
            Ok(worktree_indexed_previously
                .iter()
                .filter(|worktree| worktree.is_ok())
                .all(|v| v.as_ref().log_err().is_some_and(|v| v.to_owned())))
        })
    }
473
474 fn project_entries_changed(
475 &mut self,
476 project: ModelHandle<Project>,
477 worktree_id: WorktreeId,
478 changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
479 cx: &mut ModelContext<Self>,
480 ) {
481 let Some(worktree) = project.read(cx).worktree_for_id(worktree_id.clone(), cx) else {
482 return;
483 };
484 let project = project.downgrade();
485 let Some(project_state) = self.projects.get_mut(&project) else {
486 return;
487 };
488
489 let worktree = worktree.read(cx);
490 let worktree_state =
491 if let Some(worktree_state) = project_state.worktrees.get_mut(&worktree_id) {
492 worktree_state
493 } else {
494 return;
495 };
496 worktree_state.paths_changed(changes, worktree);
497 if let WorktreeState::Registered(_) = worktree_state {
498 cx.spawn_weak(|this, mut cx| async move {
499 cx.background().timer(BACKGROUND_INDEXING_DELAY).await;
500 if let Some((this, project)) = this.upgrade(&cx).zip(project.upgrade(&cx)) {
501 this.update(&mut cx, |this, cx| {
502 this.index_project(project, cx).detach_and_log_err(cx)
503 });
504 }
505 })
506 .detach();
507 }
508 }
509
    /// Begins tracking `worktree` for `project`: waits for its scan to
    /// complete, reconciles its files against the database, and transitions
    /// its state from `Registering` to `Registered` before kicking off an
    /// index pass.
    fn register_worktree(
        &mut self,
        project: ModelHandle<Project>,
        worktree: ModelHandle<Worktree>,
        cx: &mut ModelContext<Self>,
    ) {
        let project = project.downgrade();
        let project_state = if let Some(project_state) = self.projects.get_mut(&project) {
            project_state
        } else {
            return;
        };
        // Only local worktrees can be scanned and embedded.
        let worktree = if let Some(worktree) = worktree.read(cx).as_local() {
            worktree
        } else {
            return;
        };
        let worktree_abs_path = worktree.abs_path().clone();
        let scan_complete = worktree.scan_complete();
        let worktree_id = worktree.id();
        let db = self.db.clone();
        let language_registry = self.language_registry.clone();
        // `done_tx` signals waiters (see `RegisteringWorktreeState::done`)
        // when registration finishes, successfully or not.
        let (mut done_tx, done_rx) = watch::channel();
        let registration = cx.spawn(|this, mut cx| {
            async move {
                let register = async {
                    scan_complete.await;
                    let db_id = db.find_or_create_worktree(worktree_abs_path).await?;
                    let mut file_mtimes = db.get_file_mtimes(db_id).await?;
                    let worktree = if let Some(project) = project.upgrade(&cx) {
                        project
                            .read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx))
                            .ok_or_else(|| anyhow!("worktree not found"))?
                    } else {
                        return anyhow::Ok(());
                    };
                    let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot());
                    // Diff the worktree snapshot against the database to find
                    // files that are new, modified, or deleted.
                    let mut changed_paths = cx
                        .background()
                        .spawn(async move {
                            let mut changed_paths = BTreeMap::new();
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path);

                                if file.is_external || file.is_ignored || file.is_symlink {
                                    continue;
                                }

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    // Test if file is valid parseable file
                                    if !PARSEABLE_ENTIRE_FILE_TYPES
                                        .contains(&language.name().as_ref())
                                        && &language.name().as_ref() != &"Markdown"
                                        && language
                                            .grammar()
                                            .and_then(|grammar| grammar.embedding_config.as_ref())
                                            .is_none()
                                    {
                                        continue;
                                    }

                                    // A file needs (re-)embedding unless its
                                    // stored mtime matches the current one.
                                    let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
                                    let already_stored = stored_mtime
                                        .map_or(false, |existing_mtime| {
                                            existing_mtime == file.mtime
                                        });

                                    if !already_stored {
                                        changed_paths.insert(
                                            file.path.clone(),
                                            ChangedPathInfo {
                                                mtime: file.mtime,
                                                is_deleted: false,
                                            },
                                        );
                                    }
                                }
                            }

                            // Clean up entries from database that are no longer in the worktree.
                            for (path, mtime) in file_mtimes {
                                changed_paths.insert(
                                    path.into(),
                                    ChangedPathInfo {
                                        mtime,
                                        is_deleted: true,
                                    },
                                );
                            }

                            anyhow::Ok(changed_paths)
                        })
                        .await?;
                    this.update(&mut cx, |this, cx| {
                        let project_state = this
                            .projects
                            .get_mut(&project)
                            .ok_or_else(|| anyhow!("project not registered"))?;
                        let project = project
                            .upgrade(cx)
                            .ok_or_else(|| anyhow!("project was dropped"))?;

                        // Merge in any changes recorded while we were still
                        // registering, then mark the worktree registered.
                        if let Some(WorktreeState::Registering(state)) =
                            project_state.worktrees.remove(&worktree_id)
                        {
                            changed_paths.extend(state.changed_paths);
                        }
                        project_state.worktrees.insert(
                            worktree_id,
                            WorktreeState::Registered(RegisteredWorktreeState {
                                db_id,
                                changed_paths,
                            }),
                        );
                        this.index_project(project, cx).detach_and_log_err(cx);

                        anyhow::Ok(())
                    })?;

                    anyhow::Ok(())
                };

                if register.await.log_err().is_none() {
                    // Stop tracking this worktree if the registration failed.
                    this.update(&mut cx, |this, _| {
                        this.projects.get_mut(&project).map(|project_state| {
                            project_state.worktrees.remove(&worktree_id);
                        });
                    })
                }

                // Wake anyone awaiting `RegisteringWorktreeState::done`.
                *done_tx.borrow_mut() = Some(());
            }
        });
        project_state.worktrees.insert(
            worktree_id,
            WorktreeState::Registering(RegisteringWorktreeState {
                changed_paths: Default::default(),
                done_rx,
                _registration: registration,
            }),
        );
    }
656
657 fn project_worktrees_changed(
658 &mut self,
659 project: ModelHandle<Project>,
660 cx: &mut ModelContext<Self>,
661 ) {
662 let project_state = if let Some(project_state) = self.projects.get_mut(&project.downgrade())
663 {
664 project_state
665 } else {
666 return;
667 };
668
669 let mut worktrees = project
670 .read(cx)
671 .worktrees(cx)
672 .filter(|worktree| worktree.read(cx).is_local())
673 .collect::<Vec<_>>();
674 let worktree_ids = worktrees
675 .iter()
676 .map(|worktree| worktree.read(cx).id())
677 .collect::<HashSet<_>>();
678
679 // Remove worktrees that are no longer present
680 project_state
681 .worktrees
682 .retain(|worktree_id, _| worktree_ids.contains(worktree_id));
683
684 // Register new worktrees
685 worktrees.retain(|worktree| {
686 let worktree_id = worktree.read(cx).id();
687 !project_state.worktrees.contains_key(&worktree_id)
688 });
689 for worktree in worktrees {
690 self.register_worktree(project.clone(), worktree, cx);
691 }
692 }
693
694 pub fn pending_file_count(
695 &self,
696 project: &ModelHandle<Project>,
697 ) -> Option<watch::Receiver<usize>> {
698 Some(
699 self.projects
700 .get(&project.downgrade())?
701 .pending_file_count_rx
702 .clone(),
703 )
704 }
705
    /// Performs a semantic search over `project`: (re-)indexes it, embeds the
    /// query, searches both the on-disk index and dirty open buffers, and
    /// returns up to `limit` results sorted by descending similarity.
    pub fn search_project(
        &mut self,
        project: ModelHandle<Project>,
        query: String,
        limit: usize,
        includes: Vec<PathMatcher>,
        excludes: Vec<PathMatcher>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        if query.is_empty() {
            return Task::ready(Ok(Vec::new()));
        }

        // Make sure the index is up to date before searching it.
        let index = self.index_project(project.clone(), cx);
        let embedding_provider = self.embedding_provider.clone();

        cx.spawn(|this, mut cx| async move {
            index.await?;
            let t0 = Instant::now();

            let query = embedding_provider
                .embed_batch(vec![query])
                .await?
                .pop()
                .ok_or_else(|| anyhow!("could not embed query"))?;
            log::trace!("Embedding Search Query: {:?}ms", t0.elapsed().as_millis());

            let search_start = Instant::now();
            // Dirty buffers are searched in memory since the on-disk index
            // is stale for them; everything else comes from the database.
            let modified_buffer_results = this.update(&mut cx, |this, cx| {
                this.search_modified_buffers(
                    &project,
                    query.clone(),
                    limit,
                    &includes,
                    &excludes,
                    cx,
                )
            });
            let file_results = this.update(&mut cx, |this, cx| {
                this.search_files(project, query, limit, includes, excludes, cx)
            });
            let (modified_buffer_results, file_results) =
                futures::join!(modified_buffer_results, file_results);

            // Weave together the results from modified buffers and files,
            // preferring the in-memory result when a buffer appears in both.
            let mut results = Vec::new();
            let mut modified_buffers = HashSet::default();
            for result in modified_buffer_results.log_err().unwrap_or_default() {
                modified_buffers.insert(result.buffer.clone());
                results.push(result);
            }
            for result in file_results.log_err().unwrap_or_default() {
                if !modified_buffers.contains(&result.buffer) {
                    results.push(result);
                }
            }
            results.sort_by_key(|result| Reverse(result.similarity));
            results.truncate(limit);
            log::trace!("Semantic search took {:?}", search_start.elapsed());
            Ok(results)
        })
    }
768
769 pub fn search_files(
770 &mut self,
771 project: ModelHandle<Project>,
772 query: Embedding,
773 limit: usize,
774 includes: Vec<PathMatcher>,
775 excludes: Vec<PathMatcher>,
776 cx: &mut ModelContext<Self>,
777 ) -> Task<Result<Vec<SearchResult>>> {
778 let db_path = self.db.path().clone();
779 let fs = self.fs.clone();
780 cx.spawn(|this, mut cx| async move {
781 let database =
782 VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?;
783
784 let worktree_db_ids = this.read_with(&cx, |this, _| {
785 let project_state = this
786 .projects
787 .get(&project.downgrade())
788 .ok_or_else(|| anyhow!("project was not indexed"))?;
789 let worktree_db_ids = project_state
790 .worktrees
791 .values()
792 .filter_map(|worktree| {
793 if let WorktreeState::Registered(worktree) = worktree {
794 Some(worktree.db_id)
795 } else {
796 None
797 }
798 })
799 .collect::<Vec<i64>>();
800 anyhow::Ok(worktree_db_ids)
801 })?;
802
803 let file_ids = database
804 .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
805 .await?;
806
807 let batch_n = cx.background().num_cpus();
808 let ids_len = file_ids.clone().len();
809 let minimum_batch_size = 50;
810
811 let batch_size = {
812 let size = ids_len / batch_n;
813 if size < minimum_batch_size {
814 minimum_batch_size
815 } else {
816 size
817 }
818 };
819
820 let mut batch_results = Vec::new();
821 for batch in file_ids.chunks(batch_size) {
822 let batch = batch.into_iter().map(|v| *v).collect::<Vec<i64>>();
823 let limit = limit.clone();
824 let fs = fs.clone();
825 let db_path = db_path.clone();
826 let query = query.clone();
827 if let Some(db) = VectorDatabase::new(fs, db_path.clone(), cx.background())
828 .await
829 .log_err()
830 {
831 batch_results.push(async move {
832 db.top_k_search(&query, limit, batch.as_slice()).await
833 });
834 }
835 }
836
837 let batch_results = futures::future::join_all(batch_results).await;
838
839 let mut results = Vec::new();
840 for batch_result in batch_results {
841 if batch_result.is_ok() {
842 for (id, similarity) in batch_result.unwrap() {
843 let ix = match results
844 .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
845 {
846 Ok(ix) => ix,
847 Err(ix) => ix,
848 };
849
850 results.insert(ix, (id, similarity));
851 results.truncate(limit);
852 }
853 }
854 }
855
856 let ids = results.iter().map(|(id, _)| *id).collect::<Vec<i64>>();
857 let scores = results
858 .into_iter()
859 .map(|(_, score)| score)
860 .collect::<Vec<_>>();
861 let spans = database.spans_for_ids(ids.as_slice()).await?;
862
863 let mut tasks = Vec::new();
864 let mut ranges = Vec::new();
865 let weak_project = project.downgrade();
866 project.update(&mut cx, |project, cx| {
867 for (worktree_db_id, file_path, byte_range) in spans {
868 let project_state =
869 if let Some(state) = this.read(cx).projects.get(&weak_project) {
870 state
871 } else {
872 return Err(anyhow!("project not added"));
873 };
874 if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
875 tasks.push(project.open_buffer((worktree_id, file_path), cx));
876 ranges.push(byte_range);
877 }
878 }
879
880 Ok(())
881 })?;
882
883 let buffers = futures::future::join_all(tasks).await;
884 Ok(buffers
885 .into_iter()
886 .zip(ranges)
887 .zip(scores)
888 .filter_map(|((buffer, range), similarity)| {
889 let buffer = buffer.log_err()?;
890 let range = buffer.read_with(&cx, |buffer, _| {
891 let start = buffer.clip_offset(range.start, Bias::Left);
892 let end = buffer.clip_offset(range.end, Bias::Right);
893 buffer.anchor_before(start)..buffer.anchor_after(end)
894 });
895 Some(SearchResult {
896 buffer,
897 range,
898 similarity,
899 })
900 })
901 .collect())
902 })
903 }
904
905 fn search_modified_buffers(
906 &self,
907 project: &ModelHandle<Project>,
908 query: Embedding,
909 limit: usize,
910 includes: &[PathMatcher],
911 excludes: &[PathMatcher],
912 cx: &mut ModelContext<Self>,
913 ) -> Task<Result<Vec<SearchResult>>> {
914 let modified_buffers = project
915 .read(cx)
916 .opened_buffers(cx)
917 .into_iter()
918 .filter_map(|buffer_handle| {
919 let buffer = buffer_handle.read(cx);
920 let snapshot = buffer.snapshot();
921 let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
922 excludes.iter().any(|matcher| matcher.is_match(&path))
923 });
924
925 let included = if includes.len() == 0 {
926 true
927 } else {
928 snapshot.resolve_file_path(cx, false).map_or(false, |path| {
929 includes.iter().any(|matcher| matcher.is_match(&path))
930 })
931 };
932
933 if buffer.is_dirty() && !excluded && included {
934 Some((buffer_handle, snapshot))
935 } else {
936 None
937 }
938 })
939 .collect::<HashMap<_, _>>();
940
941 let embedding_provider = self.embedding_provider.clone();
942 let fs = self.fs.clone();
943 let db_path = self.db.path().clone();
944 let background = cx.background().clone();
945 cx.background().spawn(async move {
946 let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
947 let mut results = Vec::<SearchResult>::new();
948
949 let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
950 for (buffer, snapshot) in modified_buffers {
951 let language = snapshot
952 .language_at(0)
953 .cloned()
954 .unwrap_or_else(|| language::PLAIN_TEXT.clone());
955 let mut spans = retriever
956 .parse_file_with_template(None, &snapshot.text(), language)
957 .log_err()
958 .unwrap_or_default();
959 if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
960 .await
961 .log_err()
962 .is_some()
963 {
964 for span in spans {
965 let similarity = span.embedding.unwrap().similarity(&query);
966 let ix = match results
967 .binary_search_by_key(&Reverse(similarity), |result| {
968 Reverse(result.similarity)
969 }) {
970 Ok(ix) => ix,
971 Err(ix) => ix,
972 };
973
974 let range = {
975 let start = snapshot.clip_offset(span.range.start, Bias::Left);
976 let end = snapshot.clip_offset(span.range.end, Bias::Right);
977 snapshot.anchor_before(start)..snapshot.anchor_after(end)
978 };
979
980 results.insert(
981 ix,
982 SearchResult {
983 buffer: buffer.clone(),
984 range,
985 similarity,
986 },
987 );
988 results.truncate(limit);
989 }
990 }
991 }
992
993 Ok(results)
994 })
995 }
996
    /// Indexes (or re-indexes) `project`: registers it on first use, waits
    /// for all worktrees to finish registering, drains each worktree's
    /// recorded changes into delete/parse work, and resolves once the
    /// pending-file counter drops back to zero.
    pub fn index_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        if !self.is_authenticated() {
            if !self.authenticate(cx) {
                return Task::ready(Err(anyhow!("user is not authenticated")));
            }
        }

        // First call for this project: subscribe to worktree events and
        // register its current worktrees.
        if !self.projects.contains_key(&project.downgrade()) {
            let subscription = cx.subscribe(&project, |this, project, event, cx| match event {
                project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
                    this.project_worktrees_changed(project.clone(), cx);
                }
                project::Event::WorktreeUpdatedEntries(worktree_id, changes) => {
                    this.project_entries_changed(project, *worktree_id, changes.clone(), cx);
                }
                _ => {}
            });
            let project_state = ProjectState::new(subscription, cx);
            self.projects.insert(project.downgrade(), project_state);
            self.project_worktrees_changed(project.clone(), cx);
        }
        let project_state = self.projects.get_mut(&project.downgrade()).unwrap();
        project_state.pending_index += 1;
        cx.notify();

        let mut pending_file_count_rx = project_state.pending_file_count_rx.clone();
        let db = self.db.clone();
        let language_registry = self.language_registry.clone();
        let parsing_files_tx = self.parsing_files_tx.clone();
        let worktree_registration = self.wait_for_worktree_registration(&project, cx);

        cx.spawn(|this, mut cx| async move {
            worktree_registration.await?;

            // Drain each registered worktree's recorded changes into a list
            // of files to delete and files to (re-)parse.
            let mut pending_files = Vec::new();
            let mut files_to_delete = Vec::new();
            this.update(&mut cx, |this, cx| {
                let project_state = this
                    .projects
                    .get_mut(&project.downgrade())
                    .ok_or_else(|| anyhow!("project was dropped"))?;
                let pending_file_count_tx = &project_state.pending_file_count_tx;

                project_state
                    .worktrees
                    .retain(|worktree_id, worktree_state| {
                        // Drop state for worktrees the project no longer has.
                        let worktree = if let Some(worktree) =
                            project.read(cx).worktree_for_id(*worktree_id, cx)
                        {
                            worktree
                        } else {
                            return false;
                        };
                        let worktree_state =
                            if let WorktreeState::Registered(worktree_state) = worktree_state {
                                worktree_state
                            } else {
                                return true;
                            };

                        // Consume all recorded changes (the retain closure
                        // always returns false, emptying `changed_paths`).
                        worktree_state.changed_paths.retain(|path, info| {
                            if info.is_deleted {
                                files_to_delete.push((worktree_state.db_id, path.clone()));
                            } else {
                                let absolute_path = worktree.read(cx).absolutize(path);
                                // The job handle increments the pending-file
                                // counter; it decrements on drop.
                                let job_handle = JobHandle::new(pending_file_count_tx);
                                pending_files.push(PendingFile {
                                    absolute_path,
                                    relative_path: path.clone(),
                                    language: None,
                                    job_handle,
                                    modified_time: info.mtime,
                                    worktree_db_id: worktree_state.db_id,
                                });
                            }

                            false
                        });
                        true
                    });

                anyhow::Ok(())
            })?;

            cx.background()
                .spawn(async move {
                    // Purge deleted files from the database first.
                    for (worktree_db_id, path) in files_to_delete {
                        db.delete_file(worktree_db_id, path).await.log_err();
                    }

                    // Look up any embeddings we already have for these files'
                    // spans so unchanged spans aren't re-embedded.
                    let embeddings_for_digest = {
                        let mut files = HashMap::default();
                        for pending_file in &pending_files {
                            files
                                .entry(pending_file.worktree_db_id)
                                .or_insert(Vec::new())
                                .push(pending_file.relative_path.clone());
                        }
                        Arc::new(
                            db.embeddings_for_files(files)
                                .await
                                .log_err()
                                .unwrap_or_default(),
                        )
                    };

                    for mut pending_file in pending_files {
                        if let Ok(language) = language_registry
                            .language_for_file(&pending_file.relative_path, None)
                            .await
                        {
                            // Skip languages we can't produce spans for:
                            // not whole-file parseable, not Markdown, and
                            // without an embedding grammar config.
                            if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
                                && &language.name().as_ref() != &"Markdown"
                                && language
                                    .grammar()
                                    .and_then(|grammar| grammar.embedding_config.as_ref())
                                    .is_none()
                            {
                                continue;
                            }
                            pending_file.language = Some(language);
                        }
                        parsing_files_tx
                            .try_send((embeddings_for_digest.clone(), pending_file))
                            .ok();
                    }

                    // Wait until we're done indexing.
                    while let Some(count) = pending_file_count_rx.next().await {
                        if count == 0 {
                            break;
                        }
                    }
                })
                .await;

            this.update(&mut cx, |this, cx| {
                let project_state = this
                    .projects
                    .get_mut(&project.downgrade())
                    .ok_or_else(|| anyhow!("project was dropped"))?;
                project_state.pending_index -= 1;
                cx.notify();
                anyhow::Ok(())
            })?;

            Ok(())
        })
    }
1150
    /// Resolves once none of `project`'s worktrees are still in the
    /// `Registering` state, polling again after each batch of registrations
    /// completes (new worktrees may appear meanwhile).
    fn wait_for_worktree_registration(
        &self,
        project: &ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        let project = project.downgrade();
        cx.spawn_weak(|this, cx| async move {
            loop {
                let mut pending_worktrees = Vec::new();
                this.upgrade(&cx)
                    .ok_or_else(|| anyhow!("semantic index dropped"))?
                    .read_with(&cx, |this, _| {
                        if let Some(project) = this.projects.get(&project) {
                            for worktree in project.worktrees.values() {
                                if let WorktreeState::Registering(worktree) = worktree {
                                    pending_worktrees.push(worktree.done());
                                }
                            }
                        }
                    });

                if pending_worktrees.is_empty() {
                    break;
                } else {
                    future::join_all(pending_worktrees).await;
                }
            }
            Ok(())
        })
    }
1181
    /// Fills in the `embedding` field of every span, reusing cached
    /// embeddings from the database by digest and batching the rest through
    /// the provider without exceeding its per-batch token budget.
    ///
    /// Fails if the provider returns fewer embeddings than requested.
    async fn embed_spans(
        spans: &mut [Span],
        embedding_provider: &dyn EmbeddingProvider,
        db: &VectorDatabase,
    ) -> Result<()> {
        let mut batch = Vec::new();
        let mut batch_tokens = 0;
        let mut embeddings = Vec::new();

        let digests = spans
            .iter()
            .map(|span| span.digest.clone())
            .collect::<Vec<_>>();
        let embeddings_for_digests = db
            .embeddings_for_digests(digests)
            .await
            .log_err()
            .unwrap_or_default();

        // Phase 1: embed all uncached spans in token-bounded batches,
        // collecting results in span order.
        for span in &*spans {
            if embeddings_for_digests.contains_key(&span.digest) {
                continue;
            };

            // Flush the current batch before it would exceed the provider's
            // token limit.
            if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
                let batch_embeddings = embedding_provider
                    .embed_batch(mem::take(&mut batch))
                    .await?;
                embeddings.extend(batch_embeddings);
                batch_tokens = 0;
            }

            batch_tokens += span.token_count;
            batch.push(span.content.clone());
        }

        if !batch.is_empty() {
            let batch_embeddings = embedding_provider
                .embed_batch(mem::take(&mut batch))
                .await?;

            embeddings.extend(batch_embeddings);
        }

        // Phase 2: assign embeddings back to spans — cached ones by digest,
        // freshly computed ones in the same order they were batched.
        let mut embeddings = embeddings.into_iter();
        for span in spans {
            let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) {
                Some(embedding.clone())
            } else {
                embeddings.next()
            };
            let embedding = embedding.ok_or_else(|| anyhow!("failed to embed spans"))?;
            span.embedding = Some(embedding);
        }
        Ok(())
    }
1238}
1239
// `SemanticIndex` is a gpui model; it emits no events of its own.
impl Entity for SemanticIndex {
    type Event = ();
}
1243
impl Drop for JobHandle {
    fn drop(&mut self) {
        // `Arc::get_mut` succeeds only for the last clone of this handle,
        // guaranteeing the counter is decremented exactly once per job.
        if let Some(inner) = Arc::get_mut(&mut self.tx) {
            // This is the last instance of the JobHandle (regardless of its origin - whether it was cloned or not)
            if let Some(tx) = inner.upgrade() {
                let mut tx = tx.lock();
                *tx.borrow_mut() -= 1;
            }
        }
    }
}
1255
#[cfg(test)]
mod tests {

    use super::*;

    /// Verifies the `JobHandle` invariant: the counter is incremented once
    /// per job and decremented only when the *last* clone is dropped.
    #[test]
    fn test_job_handle() {
        let (job_count_tx, job_count_rx) = watch::channel_with(0);
        let tx = Arc::new(Mutex::new(job_count_tx));
        let job_handle = JobHandle::new(&tx);

        assert_eq!(1, *job_count_rx.borrow());
        let new_job_handle = job_handle.clone();
        // Cloning must not bump the counter again.
        assert_eq!(1, *job_count_rx.borrow());
        drop(job_handle);
        // Dropping a non-final clone must not decrement.
        assert_eq!(1, *job_count_rx.borrow());
        drop(new_job_handle);
        assert_eq!(0, *job_count_rx.borrow());
    }
}