1mod db;
2mod embedding_queue;
3mod parsing;
4pub mod semantic_index_settings;
5
6#[cfg(test)]
7mod semantic_index_tests;
8
9use crate::semantic_index_settings::SemanticIndexSettings;
10use ai::embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
11use anyhow::{anyhow, Result};
12use collections::{BTreeMap, HashMap, HashSet};
13use db::VectorDatabase;
14use embedding_queue::{EmbeddingQueue, FileToEmbed};
15use futures::{future, FutureExt, StreamExt};
16use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
17use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
18use lazy_static::lazy_static;
19use ordered_float::OrderedFloat;
20use parking_lot::Mutex;
21use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
22use postage::watch;
23use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
24use smol::channel;
25use std::{
26 cmp::Reverse,
27 env,
28 future::Future,
29 mem,
30 ops::Range,
31 path::{Path, PathBuf},
32 sync::{Arc, Weak},
33 time::{Duration, Instant, SystemTime},
34};
35use util::{channel::RELEASE_CHANNEL_NAME, http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt};
36use workspace::WorkspaceCreated;
37
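// Version stamp for the on-disk index, plus timing knobs for background indexing and for
// flushing the embedding queue. The `db` module is expected to discard databases written
// with an older `SEMANTIC_INDEX_VERSION` rather than reuse them.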
38const SEMANTIC_INDEX_VERSION: usize = 11;
39const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
40const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);
41
42lazy_static! {
43 static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
44}
45
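/// Registers the semantic index settings and, for each newly created workspace whose local
/// project was previously indexed, kicks off re-indexing. Also constructs the global
/// [`SemanticIndex`] model backed by OpenAI embeddings and the on-disk vector database.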
46pub fn init(
47 fs: Arc<dyn Fs>,
48 http_client: Arc<dyn HttpClient>,
49 language_registry: Arc<LanguageRegistry>,
50 cx: &mut AppContext,
51) {
52 settings::register::<SemanticIndexSettings>(cx);
53
54 let db_file_path = EMBEDDINGS_DIR
55 .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
56 .join("embeddings_db");
57
58 cx.subscribe_global::<WorkspaceCreated, _>({
59 move |event, cx| {
60 let Some(semantic_index) = SemanticIndex::global(cx) else {
61 return;
62 };
63 let workspace = &event.0;
64 if let Some(workspace) = workspace.upgrade(cx) {
65 let project = workspace.read(cx).project().clone();
66 if project.read(cx).is_local() {
67 cx.spawn(|mut cx| async move {
68 let previously_indexed = semantic_index
69 .update(&mut cx, |index, cx| {
70 index.project_previously_indexed(&project, cx)
71 })
72 .await?;
73 if previously_indexed {
74 semantic_index
75 .update(&mut cx, |index, cx| index.index_project(project, cx))
76 .await?;
77 }
78 anyhow::Ok(())
79 })
80 .detach_and_log_err(cx);
81 }
82 }
83 }
84 })
85 .detach();
86
87 cx.spawn(move |mut cx| async move {
88 let semantic_index = SemanticIndex::new(
89 fs,
90 db_file_path,
91 Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
92 language_registry,
93 cx.clone(),
94 )
95 .await?;
96
97 cx.update(|cx| {
98 cx.set_global(semantic_index.clone());
99 });
100
101 anyhow::Ok(())
102 })
103 .detach();
104}
105
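/// Indexing state of a single project, as reported by [`SemanticIndex::status`].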
106#[derive(Copy, Clone, Debug)]
107pub enum SemanticIndexStatus {
108 NotAuthenticated,
109 NotIndexed,
110 Indexed,
111 Indexing {
112 remaining_files: usize,
113 rate_limit_expiry: Option<Instant>,
114 },
115}
116
117pub struct SemanticIndex {
118 fs: Arc<dyn Fs>,
119 db: VectorDatabase,
120 embedding_provider: Arc<dyn EmbeddingProvider>,
121 language_registry: Arc<LanguageRegistry>,
122 parsing_files_tx: channel::Sender<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>,
123 _embedding_task: Task<()>,
124 _parsing_files_tasks: Vec<Task<()>>,
125 projects: HashMap<WeakModelHandle<Project>, ProjectState>,
126 api_key: Option<String>,
127 embedding_queue: Arc<Mutex<EmbeddingQueue>>,
128}
129
130struct ProjectState {
131 worktrees: HashMap<WorktreeId, WorktreeState>,
132 pending_file_count_rx: watch::Receiver<usize>,
133 pending_file_count_tx: Arc<Mutex<watch::Sender<usize>>>,
134 pending_index: usize,
135 _subscription: gpui::Subscription,
136 _observe_pending_file_count: Task<()>,
137}
138
139enum WorktreeState {
140 Registering(RegisteringWorktreeState),
141 Registered(RegisteredWorktreeState),
142}
143
144impl WorktreeState {
145 fn is_registered(&self) -> bool {
146 matches!(self, Self::Registered(_))
147 }
148
149 fn paths_changed(
150 &mut self,
151 changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
152 worktree: &Worktree,
153 ) {
154 let changed_paths = match self {
155 Self::Registering(state) => &mut state.changed_paths,
156 Self::Registered(state) => &mut state.changed_paths,
157 };
158
159 for (path, entry_id, change) in changes.iter() {
160 let Some(entry) = worktree.entry_for_id(*entry_id) else {
161 continue;
162 };
163 if entry.is_ignored || entry.is_symlink || entry.is_external || entry.is_dir() {
164 continue;
165 }
166 changed_paths.insert(
167 path.clone(),
168 ChangedPathInfo {
169 mtime: entry.mtime,
170 is_deleted: *change == PathChange::Removed,
171 },
172 );
173 }
174 }
175}
176
177struct RegisteringWorktreeState {
178 changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
179 done_rx: watch::Receiver<Option<()>>,
180 _registration: Task<()>,
181}
182
183impl RegisteringWorktreeState {
184 fn done(&self) -> impl Future<Output = ()> {
185 let mut done_rx = self.done_rx.clone();
186 async move {
187 while let Some(result) = done_rx.next().await {
188 if result.is_some() {
189 break;
190 }
191 }
192 }
193 }
194}
195
196struct RegisteredWorktreeState {
197 db_id: i64,
198 changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
199}
200
201struct ChangedPathInfo {
202 mtime: SystemTime,
203 is_deleted: bool,
204}
205
206#[derive(Clone)]
207pub struct JobHandle {
208 /// The outer Arc is here to count the clones of a JobHandle instance;
209 /// when the last handle to a given job is dropped, we decrement a counter (just once).
210 tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
211}
212
213impl JobHandle {
214 fn new(tx: &Arc<Mutex<watch::Sender<usize>>>) -> Self {
215 *tx.lock().borrow_mut() += 1;
216 Self {
217 tx: Arc::new(Arc::downgrade(&tx)),
218 }
219 }
220}
221
222impl ProjectState {
223 fn new(subscription: gpui::Subscription, cx: &mut ModelContext<SemanticIndex>) -> Self {
224 let (pending_file_count_tx, pending_file_count_rx) = watch::channel_with(0);
225 let pending_file_count_tx = Arc::new(Mutex::new(pending_file_count_tx));
226 Self {
227 worktrees: Default::default(),
228 pending_file_count_rx: pending_file_count_rx.clone(),
229 pending_file_count_tx,
230 pending_index: 0,
231 _subscription: subscription,
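            // Re-notify observers of the `SemanticIndex` model whenever the pending file
            // count changes.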
232 _observe_pending_file_count: cx.spawn_weak({
233 let mut pending_file_count_rx = pending_file_count_rx.clone();
234 |this, mut cx| async move {
                    while pending_file_count_rx.next().await.is_some() {
236 if let Some(this) = this.upgrade(&cx) {
237 this.update(&mut cx, |_, cx| cx.notify());
238 }
239 }
240 }
241 }),
242 }
243 }
244
245 fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
246 self.worktrees
247 .iter()
248 .find_map(|(worktree_id, worktree_state)| match worktree_state {
249 WorktreeState::Registered(state) if state.db_id == id => Some(*worktree_id),
250 _ => None,
251 })
252 }
253}
254
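/// A file that has been queued for parsing and embedding but not yet processed.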
255#[derive(Clone)]
256pub struct PendingFile {
257 worktree_db_id: i64,
258 relative_path: Arc<Path>,
259 absolute_path: PathBuf,
260 language: Option<Arc<Language>>,
261 modified_time: SystemTime,
262 job_handle: JobHandle,
263}
264
265#[derive(Clone)]
266pub struct SearchResult {
267 pub buffer: ModelHandle<Buffer>,
268 pub range: Range<Anchor>,
269 pub similarity: OrderedFloat<f32>,
270}
271
272impl SemanticIndex {
273 pub fn global(cx: &mut AppContext) -> Option<ModelHandle<SemanticIndex>> {
274 if cx.has_global::<ModelHandle<Self>>() {
275 Some(cx.global::<ModelHandle<SemanticIndex>>().clone())
276 } else {
277 None
278 }
279 }
280
281 pub fn authenticate(&mut self, cx: &AppContext) {
282 if self.api_key.is_none() {
283 self.api_key = self.embedding_provider.retrieve_credentials(cx);
284
285 self.embedding_queue
286 .lock()
287 .set_api_key(self.api_key.clone());
288 }
289 }
290
291 pub fn is_authenticated(&self) -> bool {
292 self.api_key.is_some()
293 }
294
295 pub fn enabled(cx: &AppContext) -> bool {
296 settings::get::<SemanticIndexSettings>(cx).enabled
297 }
298
299 pub fn status(&self, project: &ModelHandle<Project>) -> SemanticIndexStatus {
300 if !self.is_authenticated() {
301 return SemanticIndexStatus::NotAuthenticated;
302 }
303
304 if let Some(project_state) = self.projects.get(&project.downgrade()) {
305 if project_state
306 .worktrees
307 .values()
308 .all(|worktree| worktree.is_registered())
309 && project_state.pending_index == 0
310 {
311 SemanticIndexStatus::Indexed
312 } else {
313 SemanticIndexStatus::Indexing {
                    remaining_files: *project_state.pending_file_count_rx.borrow(),
315 rate_limit_expiry: self.embedding_provider.rate_limit_expiration(),
316 }
317 }
318 } else {
319 SemanticIndexStatus::NotIndexed
320 }
321 }
322
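    /// Opens (or creates) the vector database and spawns the background work that drives
    /// indexing: a task that writes embedded files into the database as they finish, and one
    /// parsing task per CPU that turns pending files into embeddable spans, flushing the
    /// embedding queue whenever it goes idle.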
323 pub async fn new(
324 fs: Arc<dyn Fs>,
325 database_path: PathBuf,
326 embedding_provider: Arc<dyn EmbeddingProvider>,
327 language_registry: Arc<LanguageRegistry>,
328 mut cx: AsyncAppContext,
329 ) -> Result<ModelHandle<Self>> {
330 let t0 = Instant::now();
331 let database_path = Arc::from(database_path);
332 let db = VectorDatabase::new(fs.clone(), database_path, cx.background()).await?;
333
334 log::trace!(
335 "db initialization took {:?} milliseconds",
336 t0.elapsed().as_millis()
337 );
338
339 Ok(cx.add_model(|cx| {
340 let t0 = Instant::now();
341 let embedding_queue =
342 EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), None);
343 let _embedding_task = cx.background().spawn({
344 let embedded_files = embedding_queue.finished_files();
345 let db = db.clone();
346 async move {
347 while let Ok(file) = embedded_files.recv().await {
348 db.insert_file(file.worktree_id, file.path, file.mtime, file.spans)
349 .await
350 .log_err();
351 }
352 }
353 });
354
355 // Parse files into embeddable spans.
356 let (parsing_files_tx, parsing_files_rx) =
357 channel::unbounded::<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>();
358 let embedding_queue = Arc::new(Mutex::new(embedding_queue));
359 let mut _parsing_files_tasks = Vec::new();
360 for _ in 0..cx.background().num_cpus() {
361 let fs = fs.clone();
362 let mut parsing_files_rx = parsing_files_rx.clone();
363 let embedding_provider = embedding_provider.clone();
364 let embedding_queue = embedding_queue.clone();
365 let background = cx.background().clone();
366 _parsing_files_tasks.push(cx.background().spawn(async move {
367 let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
368 loop {
369 let mut timer = background.timer(EMBEDDING_QUEUE_FLUSH_TIMEOUT).fuse();
370 let mut next_file_to_parse = parsing_files_rx.next().fuse();
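                        // Either parse the next pending file, or flush the embedding queue
                        // if nothing arrives before the flush timeout elapses.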
371 futures::select_biased! {
372 next_file_to_parse = next_file_to_parse => {
373 if let Some((embeddings_for_digest, pending_file)) = next_file_to_parse {
374 Self::parse_file(
375 &fs,
376 pending_file,
377 &mut retriever,
378 &embedding_queue,
379 &embeddings_for_digest,
380 )
381 .await
382 } else {
383 break;
384 }
385 },
386 _ = timer => {
387 embedding_queue.lock().flush();
388 }
389 }
390 }
391 }));
392 }
393
394 log::trace!(
395 "semantic index task initialization took {:?} milliseconds",
396 t0.elapsed().as_millis()
397 );
398 Self {
399 fs,
400 db,
401 embedding_provider,
402 language_registry,
403 parsing_files_tx,
404 _embedding_task,
405 _parsing_files_tasks,
406 projects: Default::default(),
407 api_key: None,
                embedding_queue,
409 }
410 }))
411 }
412
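    /// Loads a file from disk, splits it into spans with the language-aware retriever, reuses
    /// any embeddings already computed for identical span digests, and queues the file for
    /// embedding.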
413 async fn parse_file(
414 fs: &Arc<dyn Fs>,
415 pending_file: PendingFile,
416 retriever: &mut CodeContextRetriever,
417 embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
418 embeddings_for_digest: &HashMap<SpanDigest, Embedding>,
419 ) {
420 let Some(language) = pending_file.language else {
421 return;
422 };
423
424 if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
425 if let Some(mut spans) = retriever
426 .parse_file_with_template(Some(&pending_file.relative_path), &content, language)
427 .log_err()
428 {
429 log::trace!(
430 "parsed path {:?}: {} spans",
431 pending_file.relative_path,
432 spans.len()
433 );
434
435 for span in &mut spans {
436 if let Some(embedding) = embeddings_for_digest.get(&span.digest) {
437 span.embedding = Some(embedding.to_owned());
438 }
439 }
440
441 embedding_queue.lock().push(FileToEmbed {
442 worktree_id: pending_file.worktree_db_id,
443 path: pending_file.relative_path,
444 mtime: pending_file.modified_time,
445 job_handle: pending_file.job_handle,
446 spans,
447 });
448 }
449 }
450 }
451
452 pub fn project_previously_indexed(
453 &mut self,
454 project: &ModelHandle<Project>,
455 cx: &mut ModelContext<Self>,
456 ) -> Task<Result<bool>> {
457 let worktrees_indexed_previously = project
458 .read(cx)
459 .worktrees(cx)
460 .map(|worktree| {
461 self.db
462 .worktree_previously_indexed(&worktree.read(cx).abs_path())
463 })
464 .collect::<Vec<_>>();
465 cx.spawn(|_, _cx| async move {
466 let worktree_indexed_previously =
467 futures::future::join_all(worktrees_indexed_previously).await;
468
469 Ok(worktree_indexed_previously
470 .iter()
471 .filter(|worktree| worktree.is_ok())
472 .all(|v| v.as_ref().log_err().is_some_and(|v| v.to_owned())))
473 })
474 }
475
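    /// Records the paths that changed in a worktree and, if the worktree is already
    /// registered, schedules a re-index after `BACKGROUND_INDEXING_DELAY` so that bursts of
    /// file events are coalesced.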
476 fn project_entries_changed(
477 &mut self,
478 project: ModelHandle<Project>,
479 worktree_id: WorktreeId,
480 changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
481 cx: &mut ModelContext<Self>,
482 ) {
483 let Some(worktree) = project.read(cx).worktree_for_id(worktree_id.clone(), cx) else {
484 return;
485 };
486 let project = project.downgrade();
487 let Some(project_state) = self.projects.get_mut(&project) else {
488 return;
489 };
490
491 let worktree = worktree.read(cx);
492 let worktree_state =
493 if let Some(worktree_state) = project_state.worktrees.get_mut(&worktree_id) {
494 worktree_state
495 } else {
496 return;
497 };
498 worktree_state.paths_changed(changes, worktree);
499 if let WorktreeState::Registered(_) = worktree_state {
500 cx.spawn_weak(|this, mut cx| async move {
501 cx.background().timer(BACKGROUND_INDEXING_DELAY).await;
502 if let Some((this, project)) = this.upgrade(&cx).zip(project.upgrade(&cx)) {
503 this.update(&mut cx, |this, cx| {
504 this.index_project(project, cx).detach_and_log_err(cx)
505 });
506 }
507 })
508 .detach();
509 }
510 }
511
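    /// Registers a local worktree with the vector database: waits for its initial scan,
    /// diffs its files against the stored mtimes, and records which paths need indexing or
    /// deletion before kicking off `index_project`.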
512 fn register_worktree(
513 &mut self,
514 project: ModelHandle<Project>,
515 worktree: ModelHandle<Worktree>,
516 cx: &mut ModelContext<Self>,
517 ) {
518 let project = project.downgrade();
519 let project_state = if let Some(project_state) = self.projects.get_mut(&project) {
520 project_state
521 } else {
522 return;
523 };
524 let worktree = if let Some(worktree) = worktree.read(cx).as_local() {
525 worktree
526 } else {
527 return;
528 };
529 let worktree_abs_path = worktree.abs_path().clone();
530 let scan_complete = worktree.scan_complete();
531 let worktree_id = worktree.id();
532 let db = self.db.clone();
533 let language_registry = self.language_registry.clone();
534 let (mut done_tx, done_rx) = watch::channel();
535 let registration = cx.spawn(|this, mut cx| {
536 async move {
537 let register = async {
538 scan_complete.await;
539 let db_id = db.find_or_create_worktree(worktree_abs_path).await?;
540 let mut file_mtimes = db.get_file_mtimes(db_id).await?;
541 let worktree = if let Some(project) = project.upgrade(&cx) {
542 project
543 .read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx))
544 .ok_or_else(|| anyhow!("worktree not found"))?
545 } else {
546 return anyhow::Ok(());
547 };
548 let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot());
549 let mut changed_paths = cx
550 .background()
551 .spawn(async move {
552 let mut changed_paths = BTreeMap::new();
553 for file in worktree.files(false, 0) {
554 let absolute_path = worktree.absolutize(&file.path);
555
556 if file.is_external || file.is_ignored || file.is_symlink {
557 continue;
558 }
559
560 if let Ok(language) = language_registry
561 .language_for_file(&absolute_path, None)
562 .await
563 {
                                // Check whether this file's language can be parsed into embeddable spans.
565 if !PARSEABLE_ENTIRE_FILE_TYPES
566 .contains(&language.name().as_ref())
                                    && language.name().as_ref() != "Markdown"
568 && language
569 .grammar()
570 .and_then(|grammar| grammar.embedding_config.as_ref())
571 .is_none()
572 {
573 continue;
574 }
575
576 let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
577 let already_stored = stored_mtime
578 .map_or(false, |existing_mtime| {
579 existing_mtime == file.mtime
580 });
581
582 if !already_stored {
583 changed_paths.insert(
584 file.path.clone(),
585 ChangedPathInfo {
586 mtime: file.mtime,
587 is_deleted: false,
588 },
589 );
590 }
591 }
592 }
593
594 // Clean up entries from database that are no longer in the worktree.
595 for (path, mtime) in file_mtimes {
596 changed_paths.insert(
597 path.into(),
598 ChangedPathInfo {
599 mtime,
600 is_deleted: true,
601 },
602 );
603 }
604
605 anyhow::Ok(changed_paths)
606 })
607 .await?;
608 this.update(&mut cx, |this, cx| {
609 let project_state = this
610 .projects
611 .get_mut(&project)
612 .ok_or_else(|| anyhow!("project not registered"))?;
613 let project = project
614 .upgrade(cx)
615 .ok_or_else(|| anyhow!("project was dropped"))?;
616
617 if let Some(WorktreeState::Registering(state)) =
618 project_state.worktrees.remove(&worktree_id)
619 {
620 changed_paths.extend(state.changed_paths);
621 }
622 project_state.worktrees.insert(
623 worktree_id,
624 WorktreeState::Registered(RegisteredWorktreeState {
625 db_id,
626 changed_paths,
627 }),
628 );
629 this.index_project(project, cx).detach_and_log_err(cx);
630
631 anyhow::Ok(())
632 })?;
633
634 anyhow::Ok(())
635 };
636
637 if register.await.log_err().is_none() {
638 // Stop tracking this worktree if the registration failed.
639 this.update(&mut cx, |this, _| {
                        if let Some(project_state) = this.projects.get_mut(&project) {
                            project_state.worktrees.remove(&worktree_id);
                        }
643 })
644 }
645
646 *done_tx.borrow_mut() = Some(());
647 }
648 });
649 project_state.worktrees.insert(
650 worktree_id,
651 WorktreeState::Registering(RegisteringWorktreeState {
652 changed_paths: Default::default(),
653 done_rx,
654 _registration: registration,
655 }),
656 );
657 }
658
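    /// Reconciles the tracked worktrees with the project's current set of local worktrees,
    /// dropping stale entries and registering new ones.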
659 fn project_worktrees_changed(
660 &mut self,
661 project: ModelHandle<Project>,
662 cx: &mut ModelContext<Self>,
663 ) {
664 let project_state = if let Some(project_state) = self.projects.get_mut(&project.downgrade())
665 {
666 project_state
667 } else {
668 return;
669 };
670
671 let mut worktrees = project
672 .read(cx)
673 .worktrees(cx)
674 .filter(|worktree| worktree.read(cx).is_local())
675 .collect::<Vec<_>>();
676 let worktree_ids = worktrees
677 .iter()
678 .map(|worktree| worktree.read(cx).id())
679 .collect::<HashSet<_>>();
680
681 // Remove worktrees that are no longer present
682 project_state
683 .worktrees
684 .retain(|worktree_id, _| worktree_ids.contains(worktree_id));
685
686 // Register new worktrees
687 worktrees.retain(|worktree| {
688 let worktree_id = worktree.read(cx).id();
689 !project_state.worktrees.contains_key(&worktree_id)
690 });
691 for worktree in worktrees {
692 self.register_worktree(project.clone(), worktree, cx);
693 }
694 }
695
696 pub fn pending_file_count(
697 &self,
698 project: &ModelHandle<Project>,
699 ) -> Option<watch::Receiver<usize>> {
700 Some(
701 self.projects
702 .get(&project.downgrade())?
703 .pending_file_count_rx
704 .clone(),
705 )
706 }
707
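    /// Searches a project semantically: indexes it if needed, embeds the query, searches both
    /// the persisted index and any dirty open buffers, and merges the results by similarity.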
708 pub fn search_project(
709 &mut self,
710 project: ModelHandle<Project>,
711 query: String,
712 limit: usize,
713 includes: Vec<PathMatcher>,
714 excludes: Vec<PathMatcher>,
715 cx: &mut ModelContext<Self>,
716 ) -> Task<Result<Vec<SearchResult>>> {
717 if query.is_empty() {
718 return Task::ready(Ok(Vec::new()));
719 }
720
721 let index = self.index_project(project.clone(), cx);
722 let embedding_provider = self.embedding_provider.clone();
723 let api_key = self.api_key.clone();
724
725 cx.spawn(|this, mut cx| async move {
726 index.await?;
727 let t0 = Instant::now();
728 let query = embedding_provider
729 .embed_batch(vec![query], api_key)
730 .await?
731 .pop()
732 .ok_or_else(|| anyhow!("could not embed query"))?;
733 log::trace!("Embedding Search Query: {:?}ms", t0.elapsed().as_millis());
734
735 let search_start = Instant::now();
736 let modified_buffer_results = this.update(&mut cx, |this, cx| {
737 this.search_modified_buffers(
738 &project,
739 query.clone(),
740 limit,
741 &includes,
742 &excludes,
743 cx,
744 )
745 });
746 let file_results = this.update(&mut cx, |this, cx| {
747 this.search_files(project, query, limit, includes, excludes, cx)
748 });
749 let (modified_buffer_results, file_results) =
750 futures::join!(modified_buffer_results, file_results);
751
752 // Weave together the results from modified buffers and files.
753 let mut results = Vec::new();
754 let mut modified_buffers = HashSet::default();
755 for result in modified_buffer_results.log_err().unwrap_or_default() {
756 modified_buffers.insert(result.buffer.clone());
757 results.push(result);
758 }
759 for result in file_results.log_err().unwrap_or_default() {
760 if !modified_buffers.contains(&result.buffer) {
761 results.push(result);
762 }
763 }
764 results.sort_by_key(|result| Reverse(result.similarity));
765 results.truncate(limit);
766 log::trace!("Semantic search took {:?}", search_start.elapsed());
767 Ok(results)
768 })
769 }
770
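    /// Searches the persisted index for the project's registered worktrees, fanning the
    /// similarity search out over batches of candidate file ids.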
771 pub fn search_files(
772 &mut self,
773 project: ModelHandle<Project>,
774 query: Embedding,
775 limit: usize,
776 includes: Vec<PathMatcher>,
777 excludes: Vec<PathMatcher>,
778 cx: &mut ModelContext<Self>,
779 ) -> Task<Result<Vec<SearchResult>>> {
780 let db_path = self.db.path().clone();
781 let fs = self.fs.clone();
782 cx.spawn(|this, mut cx| async move {
783 let database =
784 VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?;
785
786 let worktree_db_ids = this.read_with(&cx, |this, _| {
787 let project_state = this
788 .projects
789 .get(&project.downgrade())
790 .ok_or_else(|| anyhow!("project was not indexed"))?;
791 let worktree_db_ids = project_state
792 .worktrees
793 .values()
794 .filter_map(|worktree| {
795 if let WorktreeState::Registered(worktree) = worktree {
796 Some(worktree.db_id)
797 } else {
798 None
799 }
800 })
801 .collect::<Vec<i64>>();
802 anyhow::Ok(worktree_db_ids)
803 })?;
804
805 let file_ids = database
806 .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
807 .await?;
808
809 let batch_n = cx.background().num_cpus();
            let ids_len = file_ids.len();
811 let minimum_batch_size = 50;
812
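            // Split the candidate file ids across the available CPUs, but keep each batch at
            // least `minimum_batch_size` long so the per-batch overhead stays worthwhile.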
813 let batch_size = {
814 let size = ids_len / batch_n;
815 if size < minimum_batch_size {
816 minimum_batch_size
817 } else {
818 size
819 }
820 };
821
822 let mut batch_results = Vec::new();
823 for batch in file_ids.chunks(batch_size) {
                let batch = batch.to_vec();
825 let limit = limit.clone();
826 let fs = fs.clone();
827 let db_path = db_path.clone();
828 let query = query.clone();
829 if let Some(db) = VectorDatabase::new(fs, db_path.clone(), cx.background())
830 .await
831 .log_err()
832 {
833 batch_results.push(async move {
834 db.top_k_search(&query, limit, batch.as_slice()).await
835 });
836 }
837 }
838
839 let batch_results = futures::future::join_all(batch_results).await;
840
841 let mut results = Vec::new();
842 for batch_result in batch_results {
                if let Ok(batch_result) = batch_result {
                    for (id, similarity) in batch_result {
845 let ix = match results
846 .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
847 {
848 Ok(ix) => ix,
849 Err(ix) => ix,
850 };
851
852 results.insert(ix, (id, similarity));
853 results.truncate(limit);
854 }
855 }
856 }
857
858 let ids = results.iter().map(|(id, _)| *id).collect::<Vec<i64>>();
859 let scores = results
860 .into_iter()
861 .map(|(_, score)| score)
862 .collect::<Vec<_>>();
863 let spans = database.spans_for_ids(ids.as_slice()).await?;
864
865 let mut tasks = Vec::new();
866 let mut ranges = Vec::new();
867 let weak_project = project.downgrade();
868 project.update(&mut cx, |project, cx| {
869 for (worktree_db_id, file_path, byte_range) in spans {
870 let project_state =
871 if let Some(state) = this.read(cx).projects.get(&weak_project) {
872 state
873 } else {
874 return Err(anyhow!("project not added"));
875 };
876 if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
877 tasks.push(project.open_buffer((worktree_id, file_path), cx));
878 ranges.push(byte_range);
879 }
880 }
881
882 Ok(())
883 })?;
884
885 let buffers = futures::future::join_all(tasks).await;
886 Ok(buffers
887 .into_iter()
888 .zip(ranges)
889 .zip(scores)
890 .filter_map(|((buffer, range), similarity)| {
891 let buffer = buffer.log_err()?;
892 let range = buffer.read_with(&cx, |buffer, _| {
893 let start = buffer.clip_offset(range.start, Bias::Left);
894 let end = buffer.clip_offset(range.end, Bias::Right);
895 buffer.anchor_before(start)..buffer.anchor_after(end)
896 });
897 Some(SearchResult {
898 buffer,
899 range,
900 similarity,
901 })
902 })
903 .collect())
904 })
905 }
906
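    /// Searches buffers with unsaved modifications by parsing and embedding their current
    /// contents directly, since the persisted index may be stale for them.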
907 fn search_modified_buffers(
908 &self,
909 project: &ModelHandle<Project>,
910 query: Embedding,
911 limit: usize,
912 includes: &[PathMatcher],
913 excludes: &[PathMatcher],
914 cx: &mut ModelContext<Self>,
915 ) -> Task<Result<Vec<SearchResult>>> {
916 let modified_buffers = project
917 .read(cx)
918 .opened_buffers(cx)
919 .into_iter()
920 .filter_map(|buffer_handle| {
921 let buffer = buffer_handle.read(cx);
922 let snapshot = buffer.snapshot();
923 let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
924 excludes.iter().any(|matcher| matcher.is_match(&path))
925 });
926
                let included = if includes.is_empty() {
928 true
929 } else {
930 snapshot.resolve_file_path(cx, false).map_or(false, |path| {
931 includes.iter().any(|matcher| matcher.is_match(&path))
932 })
933 };
934
935 if buffer.is_dirty() && !excluded && included {
936 Some((buffer_handle, snapshot))
937 } else {
938 None
939 }
940 })
941 .collect::<HashMap<_, _>>();
942
943 let embedding_provider = self.embedding_provider.clone();
944 let fs = self.fs.clone();
945 let db_path = self.db.path().clone();
946 let background = cx.background().clone();
947 let api_key = self.api_key.clone();
948 cx.background().spawn(async move {
949 let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
950 let mut results = Vec::<SearchResult>::new();
951
952 let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
953 for (buffer, snapshot) in modified_buffers {
954 let language = snapshot
955 .language_at(0)
956 .cloned()
957 .unwrap_or_else(|| language::PLAIN_TEXT.clone());
958 let mut spans = retriever
959 .parse_file_with_template(None, &snapshot.text(), language)
960 .log_err()
961 .unwrap_or_default();
962 if Self::embed_spans(
963 &mut spans,
964 embedding_provider.as_ref(),
965 &db,
966 api_key.clone(),
967 )
968 .await
969 .log_err()
970 .is_some()
971 {
972 for span in spans {
973 let similarity = span.embedding.unwrap().similarity(&query);
974 let ix = match results
975 .binary_search_by_key(&Reverse(similarity), |result| {
976 Reverse(result.similarity)
977 }) {
978 Ok(ix) => ix,
979 Err(ix) => ix,
980 };
981
982 let range = {
983 let start = snapshot.clip_offset(span.range.start, Bias::Left);
984 let end = snapshot.clip_offset(span.range.end, Bias::Right);
985 snapshot.anchor_before(start)..snapshot.anchor_after(end)
986 };
987
988 results.insert(
989 ix,
990 SearchResult {
991 buffer: buffer.clone(),
992 range,
993 similarity,
994 },
995 );
996 results.truncate(limit);
997 }
998 }
999 }
1000
1001 Ok(results)
1002 })
1003 }
1004
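    /// Indexes (or re-indexes) a project: registers its worktrees, deletes database entries
    /// for removed files, feeds changed files into the parsing pipeline, and resolves once
    /// the pending file count drops back to zero.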
1005 pub fn index_project(
1006 &mut self,
1007 project: ModelHandle<Project>,
1008 cx: &mut ModelContext<Self>,
1009 ) -> Task<Result<()>> {
1010 if self.api_key.is_none() {
1011 self.authenticate(cx);
1012 if self.api_key.is_none() {
1013 return Task::ready(Err(anyhow!("user is not authenticated")));
1014 }
1015 }
1016
1017 if !self.projects.contains_key(&project.downgrade()) {
1018 let subscription = cx.subscribe(&project, |this, project, event, cx| match event {
1019 project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
1020 this.project_worktrees_changed(project.clone(), cx);
1021 }
1022 project::Event::WorktreeUpdatedEntries(worktree_id, changes) => {
1023 this.project_entries_changed(project, *worktree_id, changes.clone(), cx);
1024 }
1025 _ => {}
1026 });
1027 let project_state = ProjectState::new(subscription, cx);
1028 self.projects.insert(project.downgrade(), project_state);
1029 self.project_worktrees_changed(project.clone(), cx);
1030 }
1031 let project_state = self.projects.get_mut(&project.downgrade()).unwrap();
1032 project_state.pending_index += 1;
1033 cx.notify();
1034
1035 let mut pending_file_count_rx = project_state.pending_file_count_rx.clone();
1036 let db = self.db.clone();
1037 let language_registry = self.language_registry.clone();
1038 let parsing_files_tx = self.parsing_files_tx.clone();
1039 let worktree_registration = self.wait_for_worktree_registration(&project, cx);
1040
1041 cx.spawn(|this, mut cx| async move {
1042 worktree_registration.await?;
1043
1044 let mut pending_files = Vec::new();
1045 let mut files_to_delete = Vec::new();
1046 this.update(&mut cx, |this, cx| {
1047 let project_state = this
1048 .projects
1049 .get_mut(&project.downgrade())
1050 .ok_or_else(|| anyhow!("project was dropped"))?;
1051 let pending_file_count_tx = &project_state.pending_file_count_tx;
1052
1053 project_state
1054 .worktrees
1055 .retain(|worktree_id, worktree_state| {
1056 let worktree = if let Some(worktree) =
1057 project.read(cx).worktree_for_id(*worktree_id, cx)
1058 {
1059 worktree
1060 } else {
1061 return false;
1062 };
1063 let worktree_state =
1064 if let WorktreeState::Registered(worktree_state) = worktree_state {
1065 worktree_state
1066 } else {
1067 return true;
1068 };
1069
1070 worktree_state.changed_paths.retain(|path, info| {
1071 if info.is_deleted {
1072 files_to_delete.push((worktree_state.db_id, path.clone()));
1073 } else {
1074 let absolute_path = worktree.read(cx).absolutize(path);
1075 let job_handle = JobHandle::new(pending_file_count_tx);
1076 pending_files.push(PendingFile {
1077 absolute_path,
1078 relative_path: path.clone(),
1079 language: None,
1080 job_handle,
1081 modified_time: info.mtime,
1082 worktree_db_id: worktree_state.db_id,
1083 });
1084 }
1085
1086 false
1087 });
1088 true
1089 });
1090
1091 anyhow::Ok(())
1092 })?;
1093
1094 cx.background()
1095 .spawn(async move {
1096 for (worktree_db_id, path) in files_to_delete {
1097 db.delete_file(worktree_db_id, path).await.log_err();
1098 }
1099
1100 let embeddings_for_digest = {
1101 let mut files = HashMap::default();
1102 for pending_file in &pending_files {
1103 files
1104 .entry(pending_file.worktree_db_id)
1105 .or_insert(Vec::new())
1106 .push(pending_file.relative_path.clone());
1107 }
1108 Arc::new(
1109 db.embeddings_for_files(files)
1110 .await
1111 .log_err()
1112 .unwrap_or_default(),
1113 )
1114 };
1115
1116 for mut pending_file in pending_files {
1117 if let Ok(language) = language_registry
1118 .language_for_file(&pending_file.relative_path, None)
1119 .await
1120 {
1121 if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
                                && language.name().as_ref() != "Markdown"
1123 && language
1124 .grammar()
1125 .and_then(|grammar| grammar.embedding_config.as_ref())
1126 .is_none()
1127 {
1128 continue;
1129 }
1130 pending_file.language = Some(language);
1131 }
1132 parsing_files_tx
1133 .try_send((embeddings_for_digest.clone(), pending_file))
1134 .ok();
1135 }
1136
1137 // Wait until we're done indexing.
1138 while let Some(count) = pending_file_count_rx.next().await {
1139 if count == 0 {
1140 break;
1141 }
1142 }
1143 })
1144 .await;
1145
1146 this.update(&mut cx, |this, cx| {
1147 let project_state = this
1148 .projects
1149 .get_mut(&project.downgrade())
1150 .ok_or_else(|| anyhow!("project was dropped"))?;
1151 project_state.pending_index -= 1;
1152 cx.notify();
1153 anyhow::Ok(())
1154 })?;
1155
1156 Ok(())
1157 })
1158 }
1159
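    /// Resolves once every worktree of the given project has finished registering with the
    /// vector database.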
1160 fn wait_for_worktree_registration(
1161 &self,
1162 project: &ModelHandle<Project>,
1163 cx: &mut ModelContext<Self>,
1164 ) -> Task<Result<()>> {
1165 let project = project.downgrade();
1166 cx.spawn_weak(|this, cx| async move {
1167 loop {
1168 let mut pending_worktrees = Vec::new();
1169 this.upgrade(&cx)
1170 .ok_or_else(|| anyhow!("semantic index dropped"))?
1171 .read_with(&cx, |this, _| {
1172 if let Some(project) = this.projects.get(&project) {
1173 for worktree in project.worktrees.values() {
1174 if let WorktreeState::Registering(worktree) = worktree {
1175 pending_worktrees.push(worktree.done());
1176 }
1177 }
1178 }
1179 });
1180
1181 if pending_worktrees.is_empty() {
1182 break;
1183 } else {
1184 future::join_all(pending_worktrees).await;
1185 }
1186 }
1187 Ok(())
1188 })
1189 }
1190
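    /// Embeds the given spans, skipping spans whose digests already have embeddings in the
    /// database and batching requests so that each batch stays under the provider's token
    /// limit.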
1191 async fn embed_spans(
1192 spans: &mut [Span],
1193 embedding_provider: &dyn EmbeddingProvider,
1194 db: &VectorDatabase,
1195 api_key: Option<String>,
1196 ) -> Result<()> {
1197 let mut batch = Vec::new();
1198 let mut batch_tokens = 0;
1199 let mut embeddings = Vec::new();
1200
1201 let digests = spans
1202 .iter()
1203 .map(|span| span.digest.clone())
1204 .collect::<Vec<_>>();
1205 let embeddings_for_digests = db
1206 .embeddings_for_digests(digests)
1207 .await
1208 .log_err()
1209 .unwrap_or_default();
1210
1211 for span in &*spans {
1212 if embeddings_for_digests.contains_key(&span.digest) {
1213 continue;
1214 };
1215
1216 if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
1217 let batch_embeddings = embedding_provider
1218 .embed_batch(mem::take(&mut batch), api_key.clone())
1219 .await?;
1220 embeddings.extend(batch_embeddings);
1221 batch_tokens = 0;
1222 }
1223
1224 batch_tokens += span.token_count;
1225 batch.push(span.content.clone());
1226 }
1227
1228 if !batch.is_empty() {
1229 let batch_embeddings = embedding_provider
1230 .embed_batch(mem::take(&mut batch), api_key)
1231 .await?;
1232
1233 embeddings.extend(batch_embeddings);
1234 }
1235
1236 let mut embeddings = embeddings.into_iter();
1237 for span in spans {
1238 let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) {
1239 Some(embedding.clone())
1240 } else {
1241 embeddings.next()
1242 };
1243 let embedding = embedding.ok_or_else(|| anyhow!("failed to embed spans"))?;
1244 span.embedding = Some(embedding);
1245 }
1246 Ok(())
1247 }
1248}
1249
1250impl Entity for SemanticIndex {
1251 type Event = ();
1252}
1253
1254impl Drop for JobHandle {
1255 fn drop(&mut self) {
1256 if let Some(inner) = Arc::get_mut(&mut self.tx) {
            // This is the last instance of the JobHandle (regardless of its origin, i.e. whether it was cloned or not)
1258 if let Some(tx) = inner.upgrade() {
1259 let mut tx = tx.lock();
1260 *tx.borrow_mut() -= 1;
1261 }
1262 }
1263 }
1264}
1265
1266#[cfg(test)]
1267mod tests {
1268
1269 use super::*;
1270 #[test]
1271 fn test_job_handle() {
1272 let (job_count_tx, job_count_rx) = watch::channel_with(0);
1273 let tx = Arc::new(Mutex::new(job_count_tx));
1274 let job_handle = JobHandle::new(&tx);
1275
1276 assert_eq!(1, *job_count_rx.borrow());
1277 let new_job_handle = job_handle.clone();
1278 assert_eq!(1, *job_count_rx.borrow());
1279 drop(job_handle);
1280 assert_eq!(1, *job_count_rx.borrow());
1281 drop(new_job_handle);
1282 assert_eq!(0, *job_count_rx.borrow());
1283 }
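
    #[test]
    fn test_job_handle_multiple_jobs() {
        // A small sketch of two independent jobs sharing one counter; it relies only on
        // `JobHandle::new` and the `Drop` impl defined above.
        let (job_count_tx, job_count_rx) = watch::channel_with(0);
        let tx = Arc::new(Mutex::new(job_count_tx));

        let first_job = JobHandle::new(&tx);
        let second_job = JobHandle::new(&tx);
        assert_eq!(2, *job_count_rx.borrow());

        // Dropping one clone of a job must not decrement the counter...
        let first_job_clone = first_job.clone();
        drop(first_job);
        assert_eq!(2, *job_count_rx.borrow());

        // ...but dropping the last handle to each job does, exactly once per job.
        drop(first_job_clone);
        assert_eq!(1, *job_count_rx.borrow());
        drop(second_job);
        assert_eq!(0, *job_count_rx.borrow());
    }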
1284}