1mod db;
2mod embedding_queue;
3mod parsing;
4pub mod semantic_index_settings;
5
6#[cfg(test)]
7mod semantic_index_tests;
8
9use crate::semantic_index_settings::SemanticIndexSettings;
10use ai::embedding::{Embedding, EmbeddingProvider};
11use ai::providers::open_ai::OpenAIEmbeddingProvider;
12use anyhow::{anyhow, Result};
13use collections::{BTreeMap, HashMap, HashSet};
14use db::VectorDatabase;
15use embedding_queue::{EmbeddingQueue, FileToEmbed};
16use futures::{future, FutureExt, StreamExt};
17use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
18use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
19use lazy_static::lazy_static;
20use ordered_float::OrderedFloat;
21use parking_lot::Mutex;
22use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
23use postage::watch;
24use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
25use smol::channel;
26use std::{
27 cmp::Reverse,
28 env,
29 future::Future,
30 mem,
31 ops::Range,
32 path::{Path, PathBuf},
33 sync::{Arc, Weak},
34 time::{Duration, Instant, SystemTime},
35};
36use util::{channel::RELEASE_CHANNEL_NAME, http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt};
37use workspace::WorkspaceCreated;
38
// Version stamp for the semantic index; presumably bumped to invalidate stale
// on-disk data when the schema or parsing output changes — TODO confirm where
// it is checked.
const SEMANTIC_INDEX_VERSION: usize = 11;
// Quiet period after file changes before a background re-index kicks off.
const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
// How long a parser task idles before flushing the embedding queue.
const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);

lazy_static! {
    // Read once at first access; `None` when the variable is unset.
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
}
46
/// Registers semantic-index settings, then asynchronously constructs the
/// global `SemanticIndex` and re-indexes any newly created workspace whose
/// local project was indexed in a previous session.
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    settings::register::<SemanticIndexSettings>(cx);

    // The database lives under a release-channel-specific subdirectory so
    // stable/preview/dev builds don't share index state.
    let db_file_path = EMBEDDINGS_DIR
        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
        .join("embeddings_db");

    cx.subscribe_global::<WorkspaceCreated, _>({
        move |event, cx| {
            let Some(semantic_index) = SemanticIndex::global(cx) else {
                return;
            };
            let workspace = &event.0;
            if let Some(workspace) = workspace.upgrade(cx) {
                let project = workspace.read(cx).project().clone();
                // Only local projects are indexed.
                if project.read(cx).is_local() {
                    cx.spawn(|mut cx| async move {
                        // Auto-index only projects that were indexed before;
                        // others wait for an explicit request.
                        let previously_indexed = semantic_index
                            .update(&mut cx, |index, cx| {
                                index.project_previously_indexed(&project, cx)
                            })
                            .await?;
                        if previously_indexed {
                            semantic_index
                                .update(&mut cx, |index, cx| index.index_project(project, cx))
                                .await?;
                        }
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                }
            }
        }
    })
    .detach();

    cx.spawn(move |mut cx| async move {
        let semantic_index = SemanticIndex::new(
            fs,
            db_file_path,
            Arc::new(OpenAIEmbeddingProvider::new(http_client, cx.background())),
            language_registry,
            cx.clone(),
        )
        .await?;

        // Publish the handle so `SemanticIndex::global` can find it.
        cx.update(|cx| {
            cx.set_global(semantic_index.clone());
        });

        anyhow::Ok(())
    })
    .detach();
}
106
/// Coarse indexing state reported for a given project (e.g. to status UIs).
#[derive(Copy, Clone, Debug)]
pub enum SemanticIndexStatus {
    /// The embedding provider has no credentials yet.
    NotAuthenticated,
    /// The project is not currently tracked by the index.
    NotIndexed,
    /// Every worktree is registered and no index pass is in flight.
    Indexed,
    /// An index pass is running.
    Indexing {
        /// Files still waiting to be parsed and embedded.
        remaining_files: usize,
        /// When the provider's rate limit lifts, if currently throttled.
        rate_limit_expiry: Option<Instant>,
    },
}
117
/// Model that owns the vector database, the parsing/embedding pipeline, and
/// per-project indexing state.
pub struct SemanticIndex {
    fs: Arc<dyn Fs>,
    db: VectorDatabase,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    // Feeds pending files (plus embeddings already known for their span
    // digests) to the background parsing tasks.
    parsing_files_tx: channel::Sender<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>,
    // Drains finished files from the embedding queue into the database.
    _embedding_task: Task<()>,
    // One parser task per CPU; held only to keep them alive.
    _parsing_files_tasks: Vec<Task<()>>,
    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
}
128
/// Per-project bookkeeping kept while a project is tracked by the index.
struct ProjectState {
    worktrees: HashMap<WorktreeId, WorktreeState>,
    // Count of files still pending for this project; written through
    // `pending_file_count_tx` by `JobHandle`s.
    pending_file_count_rx: watch::Receiver<usize>,
    pending_file_count_tx: Arc<Mutex<watch::Sender<usize>>>,
    // Number of `index_project` passes currently in flight.
    pending_index: usize,
    _subscription: gpui::Subscription,
    _observe_pending_file_count: Task<()>,
}
137
/// A worktree starts out `Registering` (scan + database setup in progress)
/// and becomes `Registered` once it has a database id.
enum WorktreeState {
    Registering(RegisteringWorktreeState),
    Registered(RegisteredWorktreeState),
}
142
143impl WorktreeState {
144 fn is_registered(&self) -> bool {
145 matches!(self, Self::Registered(_))
146 }
147
148 fn paths_changed(
149 &mut self,
150 changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
151 worktree: &Worktree,
152 ) {
153 let changed_paths = match self {
154 Self::Registering(state) => &mut state.changed_paths,
155 Self::Registered(state) => &mut state.changed_paths,
156 };
157
158 for (path, entry_id, change) in changes.iter() {
159 let Some(entry) = worktree.entry_for_id(*entry_id) else {
160 continue;
161 };
162 if entry.is_ignored || entry.is_symlink || entry.is_external || entry.is_dir() {
163 continue;
164 }
165 changed_paths.insert(
166 path.clone(),
167 ChangedPathInfo {
168 mtime: entry.mtime,
169 is_deleted: *change == PathChange::Removed,
170 },
171 );
172 }
173 }
174}
175
/// State held while a worktree's initial scan/registration task is running.
struct RegisteringWorktreeState {
    // Changes that arrive mid-registration; merged once registration completes.
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
    // Receives `Some(())` when the registration task finishes.
    done_rx: watch::Receiver<Option<()>>,
    _registration: Task<()>,
}
181
182impl RegisteringWorktreeState {
183 fn done(&self) -> impl Future<Output = ()> {
184 let mut done_rx = self.done_rx.clone();
185 async move {
186 while let Some(result) = done_rx.next().await {
187 if result.is_some() {
188 break;
189 }
190 }
191 }
192 }
193}
194
/// State for a worktree that has a row in the vector database.
struct RegisteredWorktreeState {
    db_id: i64,
    // Paths changed since the last index pass, awaiting re-indexing.
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
}
199
/// What is known about a changed path since the last index pass.
struct ChangedPathInfo {
    mtime: SystemTime,
    // `true` when the path was removed and its rows should be deleted.
    is_deleted: bool,
}
204
/// Guard representing one pending file in a project's indexing pass.
#[derive(Clone)]
pub struct JobHandle {
    /// The outer Arc is here to count the clones of a JobHandle instance;
    /// when the last handle to a given job is dropped, we decrement a counter (just once).
    tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
}
211
212impl JobHandle {
213 fn new(tx: &Arc<Mutex<watch::Sender<usize>>>) -> Self {
214 *tx.lock().borrow_mut() += 1;
215 Self {
216 tx: Arc::new(Arc::downgrade(&tx)),
217 }
218 }
219}
220
impl ProjectState {
    /// Creates fresh state for a tracked project and spawns a task that
    /// re-notifies the `SemanticIndex` model whenever the pending-file count
    /// changes, so observers (e.g. status UIs) stay current.
    fn new(subscription: gpui::Subscription, cx: &mut ModelContext<SemanticIndex>) -> Self {
        let (pending_file_count_tx, pending_file_count_rx) = watch::channel_with(0);
        let pending_file_count_tx = Arc::new(Mutex::new(pending_file_count_tx));
        Self {
            worktrees: Default::default(),
            pending_file_count_rx: pending_file_count_rx.clone(),
            pending_file_count_tx,
            pending_index: 0,
            _subscription: subscription,
            _observe_pending_file_count: cx.spawn_weak({
                let mut pending_file_count_rx = pending_file_count_rx.clone();
                |this, mut cx| async move {
                    // Forward every counter change as a model notification;
                    // the weak handle lets the task end when the model drops.
                    while let Some(_) = pending_file_count_rx.next().await {
                        if let Some(this) = this.upgrade(&cx) {
                            this.update(&mut cx, |_, cx| cx.notify());
                        }
                    }
                }
            }),
        }
    }

    /// Maps a database worktree id back to the in-memory `WorktreeId`, if
    /// that worktree has finished registering.
    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
        self.worktrees
            .iter()
            .find_map(|(worktree_id, worktree_state)| match worktree_state {
                WorktreeState::Registered(state) if state.db_id == id => Some(*worktree_id),
                _ => None,
            })
    }
}
253
/// A file queued for parsing and embedding.
#[derive(Clone)]
pub struct PendingFile {
    worktree_db_id: i64,
    relative_path: Arc<Path>,
    absolute_path: PathBuf,
    // Resolved later in the pipeline; `None` means "not yet determined".
    language: Option<Arc<Language>>,
    modified_time: SystemTime,
    // Keeps the project's pending-file counter accurate while in flight.
    job_handle: JobHandle,
}
263
/// A single semantic-search hit: an anchored range in a buffer plus its
/// similarity score against the query embedding.
#[derive(Clone)]
pub struct SearchResult {
    pub buffer: ModelHandle<Buffer>,
    pub range: Range<Anchor>,
    pub similarity: OrderedFloat<f32>,
}
270
271impl SemanticIndex {
272 pub fn global(cx: &mut AppContext) -> Option<ModelHandle<SemanticIndex>> {
273 if cx.has_global::<ModelHandle<Self>>() {
274 Some(cx.global::<ModelHandle<SemanticIndex>>().clone())
275 } else {
276 None
277 }
278 }
279
280 pub fn authenticate(&mut self, cx: &AppContext) -> bool {
281 if !self.embedding_provider.has_credentials() {
282 self.embedding_provider.retrieve_credentials(cx);
283 } else {
284 return true;
285 }
286
287 self.embedding_provider.has_credentials()
288 }
289
    /// Whether the embedding provider currently holds credentials.
    pub fn is_authenticated(&self) -> bool {
        self.embedding_provider.has_credentials()
    }
293
    /// Whether semantic indexing is enabled in the user's settings.
    pub fn enabled(cx: &AppContext) -> bool {
        settings::get::<SemanticIndexSettings>(cx).enabled
    }
297
298 pub fn status(&self, project: &ModelHandle<Project>) -> SemanticIndexStatus {
299 if !self.is_authenticated() {
300 return SemanticIndexStatus::NotAuthenticated;
301 }
302
303 if let Some(project_state) = self.projects.get(&project.downgrade()) {
304 if project_state
305 .worktrees
306 .values()
307 .all(|worktree| worktree.is_registered())
308 && project_state.pending_index == 0
309 {
310 SemanticIndexStatus::Indexed
311 } else {
312 SemanticIndexStatus::Indexing {
313 remaining_files: project_state.pending_file_count_rx.borrow().clone(),
314 rate_limit_expiry: self.embedding_provider.rate_limit_expiration(),
315 }
316 }
317 } else {
318 SemanticIndexStatus::NotIndexed
319 }
320 }
321
    /// Opens (or creates) the vector database and builds the model, spawning:
    /// * one background task that writes embedded files into the database, and
    /// * one parser task per CPU that turns pending files into embeddable spans.
    pub async fn new(
        fs: Arc<dyn Fs>,
        database_path: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
        mut cx: AsyncAppContext,
    ) -> Result<ModelHandle<Self>> {
        let t0 = Instant::now();
        let database_path = Arc::from(database_path);
        let db = VectorDatabase::new(fs.clone(), database_path, cx.background()).await?;

        log::trace!(
            "db initialization took {:?} milliseconds",
            t0.elapsed().as_millis()
        );

        Ok(cx.add_model(|cx| {
            let t0 = Instant::now();
            let embedding_queue =
                EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
            // Persist every file the embedding queue finishes.
            let _embedding_task = cx.background().spawn({
                let embedded_files = embedding_queue.finished_files();
                let db = db.clone();
                async move {
                    while let Ok(file) = embedded_files.recv().await {
                        db.insert_file(file.worktree_id, file.path, file.mtime, file.spans)
                            .await
                            .log_err();
                    }
                }
            });

            // Parse files into embeddable spans.
            let (parsing_files_tx, parsing_files_rx) =
                channel::unbounded::<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>();
            let embedding_queue = Arc::new(Mutex::new(embedding_queue));
            let mut _parsing_files_tasks = Vec::new();
            for _ in 0..cx.background().num_cpus() {
                let fs = fs.clone();
                let mut parsing_files_rx = parsing_files_rx.clone();
                let embedding_provider = embedding_provider.clone();
                let embedding_queue = embedding_queue.clone();
                let background = cx.background().clone();
                _parsing_files_tasks.push(cx.background().spawn(async move {
                    let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
                    loop {
                        // Either parse the next file, or — after a quiet
                        // period — flush whatever sits in the embedding queue.
                        let mut timer = background.timer(EMBEDDING_QUEUE_FLUSH_TIMEOUT).fuse();
                        let mut next_file_to_parse = parsing_files_rx.next().fuse();
                        futures::select_biased! {
                            next_file_to_parse = next_file_to_parse => {
                                if let Some((embeddings_for_digest, pending_file)) = next_file_to_parse {
                                    Self::parse_file(
                                        &fs,
                                        pending_file,
                                        &mut retriever,
                                        &embedding_queue,
                                        &embeddings_for_digest,
                                    )
                                    .await
                                } else {
                                    // Channel closed: the model was dropped.
                                    break;
                                }
                            },
                            _ = timer => {
                                embedding_queue.lock().flush();
                            }
                        }
                    }
                }));
            }

            log::trace!(
                "semantic index task initialization took {:?} milliseconds",
                t0.elapsed().as_millis()
            );
            Self {
                fs,
                db,
                embedding_provider,
                language_registry,
                parsing_files_tx,
                _embedding_task,
                _parsing_files_tasks,
                projects: Default::default(),
            }
        }))
    }
409
    /// Loads a pending file from disk, parses it into spans, reuses any
    /// embeddings already known for matching span digests, and pushes the
    /// result onto the embedding queue.
    async fn parse_file(
        fs: &Arc<dyn Fs>,
        pending_file: PendingFile,
        retriever: &mut CodeContextRetriever,
        embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
        embeddings_for_digest: &HashMap<SpanDigest, Embedding>,
    ) {
        // Files without a resolved language were filtered out upstream.
        let Some(language) = pending_file.language else {
            return;
        };

        if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
            if let Some(mut spans) = retriever
                .parse_file_with_template(Some(&pending_file.relative_path), &content, language)
                .log_err()
            {
                log::trace!(
                    "parsed path {:?}: {} spans",
                    pending_file.relative_path,
                    spans.len()
                );

                // Reuse cached embeddings for unchanged spans, so only new
                // content is sent to the provider.
                for span in &mut spans {
                    if let Some(embedding) = embeddings_for_digest.get(&span.digest) {
                        span.embedding = Some(embedding.to_owned());
                    }
                }

                embedding_queue.lock().push(FileToEmbed {
                    worktree_id: pending_file.worktree_db_id,
                    path: pending_file.relative_path,
                    mtime: pending_file.modified_time,
                    job_handle: pending_file.job_handle,
                    spans,
                });
            }
        }
    }
448
449 pub fn project_previously_indexed(
450 &mut self,
451 project: &ModelHandle<Project>,
452 cx: &mut ModelContext<Self>,
453 ) -> Task<Result<bool>> {
454 let worktrees_indexed_previously = project
455 .read(cx)
456 .worktrees(cx)
457 .map(|worktree| {
458 self.db
459 .worktree_previously_indexed(&worktree.read(cx).abs_path())
460 })
461 .collect::<Vec<_>>();
462 cx.spawn(|_, _cx| async move {
463 let worktree_indexed_previously =
464 futures::future::join_all(worktrees_indexed_previously).await;
465
466 Ok(worktree_indexed_previously
467 .iter()
468 .filter(|worktree| worktree.is_ok())
469 .all(|v| v.as_ref().log_err().is_some_and(|v| v.to_owned())))
470 })
471 }
472
473 fn project_entries_changed(
474 &mut self,
475 project: ModelHandle<Project>,
476 worktree_id: WorktreeId,
477 changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
478 cx: &mut ModelContext<Self>,
479 ) {
480 let Some(worktree) = project.read(cx).worktree_for_id(worktree_id.clone(), cx) else {
481 return;
482 };
483 let project = project.downgrade();
484 let Some(project_state) = self.projects.get_mut(&project) else {
485 return;
486 };
487
488 let worktree = worktree.read(cx);
489 let worktree_state =
490 if let Some(worktree_state) = project_state.worktrees.get_mut(&worktree_id) {
491 worktree_state
492 } else {
493 return;
494 };
495 worktree_state.paths_changed(changes, worktree);
496 if let WorktreeState::Registered(_) = worktree_state {
497 cx.spawn_weak(|this, mut cx| async move {
498 cx.background().timer(BACKGROUND_INDEXING_DELAY).await;
499 if let Some((this, project)) = this.upgrade(&cx).zip(project.upgrade(&cx)) {
500 this.update(&mut cx, |this, cx| {
501 this.index_project(project, cx).detach_and_log_err(cx)
502 });
503 }
504 })
505 .detach();
506 }
507 }
508
    /// Starts registering a local worktree: waits for its scan to complete,
    /// finds/creates its database row, diffs on-disk mtimes against the
    /// database, then transitions the worktree to `Registered` and kicks off
    /// an index pass.
    fn register_worktree(
        &mut self,
        project: ModelHandle<Project>,
        worktree: ModelHandle<Worktree>,
        cx: &mut ModelContext<Self>,
    ) {
        let project = project.downgrade();
        let project_state = if let Some(project_state) = self.projects.get_mut(&project) {
            project_state
        } else {
            return;
        };
        // Only local worktrees can be indexed.
        let worktree = if let Some(worktree) = worktree.read(cx).as_local() {
            worktree
        } else {
            return;
        };
        let worktree_abs_path = worktree.abs_path().clone();
        let scan_complete = worktree.scan_complete();
        let worktree_id = worktree.id();
        let db = self.db.clone();
        let language_registry = self.language_registry.clone();
        let (mut done_tx, done_rx) = watch::channel();
        let registration = cx.spawn(|this, mut cx| {
            async move {
                let register = async {
                    scan_complete.await;
                    let db_id = db.find_or_create_worktree(worktree_abs_path).await?;
                    let mut file_mtimes = db.get_file_mtimes(db_id).await?;
                    let worktree = if let Some(project) = project.upgrade(&cx) {
                        project
                            .read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx))
                            .ok_or_else(|| anyhow!("worktree not found"))?
                    } else {
                        // Project dropped; nothing to register.
                        return anyhow::Ok(());
                    };
                    let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot());
                    // Diff the snapshot against the database on a background
                    // thread: anything new or with a changed mtime needs
                    // (re-)indexing.
                    let mut changed_paths = cx
                        .background()
                        .spawn(async move {
                            let mut changed_paths = BTreeMap::new();
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path);

                                if file.is_external || file.is_ignored || file.is_symlink {
                                    continue;
                                }

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    // Test if file is valid parseable file
                                    if !PARSEABLE_ENTIRE_FILE_TYPES
                                        .contains(&language.name().as_ref())
                                        && &language.name().as_ref() != &"Markdown"
                                        && language
                                            .grammar()
                                            .and_then(|grammar| grammar.embedding_config.as_ref())
                                            .is_none()
                                    {
                                        continue;
                                    }

                                    let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
                                    let already_stored = stored_mtime
                                        .map_or(false, |existing_mtime| {
                                            existing_mtime == file.mtime
                                        });

                                    if !already_stored {
                                        changed_paths.insert(
                                            file.path.clone(),
                                            ChangedPathInfo {
                                                mtime: file.mtime,
                                                is_deleted: false,
                                            },
                                        );
                                    }
                                }
                            }

                            // Clean up entries from database that are no longer in the worktree.
                            for (path, mtime) in file_mtimes {
                                changed_paths.insert(
                                    path.into(),
                                    ChangedPathInfo {
                                        mtime,
                                        is_deleted: true,
                                    },
                                );
                            }

                            anyhow::Ok(changed_paths)
                        })
                        .await?;
                    this.update(&mut cx, |this, cx| {
                        let project_state = this
                            .projects
                            .get_mut(&project)
                            .ok_or_else(|| anyhow!("project not registered"))?;
                        let project = project
                            .upgrade(cx)
                            .ok_or_else(|| anyhow!("project was dropped"))?;

                        // Merge in changes that arrived mid-registration.
                        if let Some(WorktreeState::Registering(state)) =
                            project_state.worktrees.remove(&worktree_id)
                        {
                            changed_paths.extend(state.changed_paths);
                        }
                        project_state.worktrees.insert(
                            worktree_id,
                            WorktreeState::Registered(RegisteredWorktreeState {
                                db_id,
                                changed_paths,
                            }),
                        );
                        this.index_project(project, cx).detach_and_log_err(cx);

                        anyhow::Ok(())
                    })?;

                    anyhow::Ok(())
                };

                if register.await.log_err().is_none() {
                    // Stop tracking this worktree if the registration failed.
                    this.update(&mut cx, |this, _| {
                        this.projects.get_mut(&project).map(|project_state| {
                            project_state.worktrees.remove(&worktree_id);
                        });
                    })
                }

                // Wake anyone awaiting `RegisteringWorktreeState::done`.
                *done_tx.borrow_mut() = Some(());
            }
        });
        project_state.worktrees.insert(
            worktree_id,
            WorktreeState::Registering(RegisteringWorktreeState {
                changed_paths: Default::default(),
                done_rx,
                _registration: registration,
            }),
        );
    }
655
656 fn project_worktrees_changed(
657 &mut self,
658 project: ModelHandle<Project>,
659 cx: &mut ModelContext<Self>,
660 ) {
661 let project_state = if let Some(project_state) = self.projects.get_mut(&project.downgrade())
662 {
663 project_state
664 } else {
665 return;
666 };
667
668 let mut worktrees = project
669 .read(cx)
670 .worktrees(cx)
671 .filter(|worktree| worktree.read(cx).is_local())
672 .collect::<Vec<_>>();
673 let worktree_ids = worktrees
674 .iter()
675 .map(|worktree| worktree.read(cx).id())
676 .collect::<HashSet<_>>();
677
678 // Remove worktrees that are no longer present
679 project_state
680 .worktrees
681 .retain(|worktree_id, _| worktree_ids.contains(worktree_id));
682
683 // Register new worktrees
684 worktrees.retain(|worktree| {
685 let worktree_id = worktree.read(cx).id();
686 !project_state.worktrees.contains_key(&worktree_id)
687 });
688 for worktree in worktrees {
689 self.register_worktree(project.clone(), worktree, cx);
690 }
691 }
692
693 pub fn pending_file_count(
694 &self,
695 project: &ModelHandle<Project>,
696 ) -> Option<watch::Receiver<usize>> {
697 Some(
698 self.projects
699 .get(&project.downgrade())?
700 .pending_file_count_rx
701 .clone(),
702 )
703 }
704
    /// Embeds `query` and searches both the on-disk index and any dirty open
    /// buffers, returning up to `limit` results ordered by descending
    /// similarity. An empty query short-circuits to no results.
    pub fn search_project(
        &mut self,
        project: ModelHandle<Project>,
        query: String,
        limit: usize,
        includes: Vec<PathMatcher>,
        excludes: Vec<PathMatcher>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        if query.is_empty() {
            return Task::ready(Ok(Vec::new()));
        }

        // Bring the index up to date before searching it.
        let index = self.index_project(project.clone(), cx);
        let embedding_provider = self.embedding_provider.clone();

        cx.spawn(|this, mut cx| async move {
            index.await?;
            let t0 = Instant::now();

            let query = embedding_provider
                .embed_batch(vec![query])
                .await?
                .pop()
                .ok_or_else(|| anyhow!("could not embed query"))?;
            log::trace!("Embedding Search Query: {:?}ms", t0.elapsed().as_millis());

            let search_start = Instant::now();
            // Dirty buffers are re-embedded on the fly since the index only
            // reflects what is on disk.
            let modified_buffer_results = this.update(&mut cx, |this, cx| {
                this.search_modified_buffers(
                    &project,
                    query.clone(),
                    limit,
                    &includes,
                    &excludes,
                    cx,
                )
            });
            let file_results = this.update(&mut cx, |this, cx| {
                this.search_files(project, query, limit, includes, excludes, cx)
            });
            let (modified_buffer_results, file_results) =
                futures::join!(modified_buffer_results, file_results);

            // Weave together the results from modified buffers and files,
            // preferring the in-memory version of any dirty buffer.
            let mut results = Vec::new();
            let mut modified_buffers = HashSet::default();
            for result in modified_buffer_results.log_err().unwrap_or_default() {
                modified_buffers.insert(result.buffer.clone());
                results.push(result);
            }
            for result in file_results.log_err().unwrap_or_default() {
                if !modified_buffers.contains(&result.buffer) {
                    results.push(result);
                }
            }
            results.sort_by_key(|result| Reverse(result.similarity));
            results.truncate(limit);
            log::trace!("Semantic search took {:?}", search_start.elapsed());
            Ok(results)
        })
    }
767
768 pub fn search_files(
769 &mut self,
770 project: ModelHandle<Project>,
771 query: Embedding,
772 limit: usize,
773 includes: Vec<PathMatcher>,
774 excludes: Vec<PathMatcher>,
775 cx: &mut ModelContext<Self>,
776 ) -> Task<Result<Vec<SearchResult>>> {
777 let db_path = self.db.path().clone();
778 let fs = self.fs.clone();
779 cx.spawn(|this, mut cx| async move {
780 let database =
781 VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?;
782
783 let worktree_db_ids = this.read_with(&cx, |this, _| {
784 let project_state = this
785 .projects
786 .get(&project.downgrade())
787 .ok_or_else(|| anyhow!("project was not indexed"))?;
788 let worktree_db_ids = project_state
789 .worktrees
790 .values()
791 .filter_map(|worktree| {
792 if let WorktreeState::Registered(worktree) = worktree {
793 Some(worktree.db_id)
794 } else {
795 None
796 }
797 })
798 .collect::<Vec<i64>>();
799 anyhow::Ok(worktree_db_ids)
800 })?;
801
802 let file_ids = database
803 .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
804 .await?;
805
806 let batch_n = cx.background().num_cpus();
807 let ids_len = file_ids.clone().len();
808 let minimum_batch_size = 50;
809
810 let batch_size = {
811 let size = ids_len / batch_n;
812 if size < minimum_batch_size {
813 minimum_batch_size
814 } else {
815 size
816 }
817 };
818
819 let mut batch_results = Vec::new();
820 for batch in file_ids.chunks(batch_size) {
821 let batch = batch.into_iter().map(|v| *v).collect::<Vec<i64>>();
822 let limit = limit.clone();
823 let fs = fs.clone();
824 let db_path = db_path.clone();
825 let query = query.clone();
826 if let Some(db) = VectorDatabase::new(fs, db_path.clone(), cx.background())
827 .await
828 .log_err()
829 {
830 batch_results.push(async move {
831 db.top_k_search(&query, limit, batch.as_slice()).await
832 });
833 }
834 }
835
836 let batch_results = futures::future::join_all(batch_results).await;
837
838 let mut results = Vec::new();
839 for batch_result in batch_results {
840 if batch_result.is_ok() {
841 for (id, similarity) in batch_result.unwrap() {
842 let ix = match results
843 .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
844 {
845 Ok(ix) => ix,
846 Err(ix) => ix,
847 };
848
849 results.insert(ix, (id, similarity));
850 results.truncate(limit);
851 }
852 }
853 }
854
855 let ids = results.iter().map(|(id, _)| *id).collect::<Vec<i64>>();
856 let scores = results
857 .into_iter()
858 .map(|(_, score)| score)
859 .collect::<Vec<_>>();
860 let spans = database.spans_for_ids(ids.as_slice()).await?;
861
862 let mut tasks = Vec::new();
863 let mut ranges = Vec::new();
864 let weak_project = project.downgrade();
865 project.update(&mut cx, |project, cx| {
866 for (worktree_db_id, file_path, byte_range) in spans {
867 let project_state =
868 if let Some(state) = this.read(cx).projects.get(&weak_project) {
869 state
870 } else {
871 return Err(anyhow!("project not added"));
872 };
873 if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
874 tasks.push(project.open_buffer((worktree_id, file_path), cx));
875 ranges.push(byte_range);
876 }
877 }
878
879 Ok(())
880 })?;
881
882 let buffers = futures::future::join_all(tasks).await;
883 Ok(buffers
884 .into_iter()
885 .zip(ranges)
886 .zip(scores)
887 .filter_map(|((buffer, range), similarity)| {
888 let buffer = buffer.log_err()?;
889 let range = buffer.read_with(&cx, |buffer, _| {
890 let start = buffer.clip_offset(range.start, Bias::Left);
891 let end = buffer.clip_offset(range.end, Bias::Right);
892 buffer.anchor_before(start)..buffer.anchor_after(end)
893 });
894 Some(SearchResult {
895 buffer,
896 range,
897 similarity,
898 })
899 })
900 .collect())
901 })
902 }
903
904 fn search_modified_buffers(
905 &self,
906 project: &ModelHandle<Project>,
907 query: Embedding,
908 limit: usize,
909 includes: &[PathMatcher],
910 excludes: &[PathMatcher],
911 cx: &mut ModelContext<Self>,
912 ) -> Task<Result<Vec<SearchResult>>> {
913 let modified_buffers = project
914 .read(cx)
915 .opened_buffers(cx)
916 .into_iter()
917 .filter_map(|buffer_handle| {
918 let buffer = buffer_handle.read(cx);
919 let snapshot = buffer.snapshot();
920 let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
921 excludes.iter().any(|matcher| matcher.is_match(&path))
922 });
923
924 let included = if includes.len() == 0 {
925 true
926 } else {
927 snapshot.resolve_file_path(cx, false).map_or(false, |path| {
928 includes.iter().any(|matcher| matcher.is_match(&path))
929 })
930 };
931
932 if buffer.is_dirty() && !excluded && included {
933 Some((buffer_handle, snapshot))
934 } else {
935 None
936 }
937 })
938 .collect::<HashMap<_, _>>();
939
940 let embedding_provider = self.embedding_provider.clone();
941 let fs = self.fs.clone();
942 let db_path = self.db.path().clone();
943 let background = cx.background().clone();
944 cx.background().spawn(async move {
945 let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
946 let mut results = Vec::<SearchResult>::new();
947
948 let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
949 for (buffer, snapshot) in modified_buffers {
950 let language = snapshot
951 .language_at(0)
952 .cloned()
953 .unwrap_or_else(|| language::PLAIN_TEXT.clone());
954 let mut spans = retriever
955 .parse_file_with_template(None, &snapshot.text(), language)
956 .log_err()
957 .unwrap_or_default();
958 if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
959 .await
960 .log_err()
961 .is_some()
962 {
963 for span in spans {
964 let similarity = span.embedding.unwrap().similarity(&query);
965 let ix = match results
966 .binary_search_by_key(&Reverse(similarity), |result| {
967 Reverse(result.similarity)
968 }) {
969 Ok(ix) => ix,
970 Err(ix) => ix,
971 };
972
973 let range = {
974 let start = snapshot.clip_offset(span.range.start, Bias::Left);
975 let end = snapshot.clip_offset(span.range.end, Bias::Right);
976 snapshot.anchor_before(start)..snapshot.anchor_after(end)
977 };
978
979 results.insert(
980 ix,
981 SearchResult {
982 buffer: buffer.clone(),
983 range,
984 similarity,
985 },
986 );
987 results.truncate(limit);
988 }
989 }
990 }
991
992 Ok(results)
993 })
994 }
995
    /// Ensures the project is tracked and runs one indexing pass: deletes
    /// removed files from the database, queues changed files for parsing and
    /// embedding, and resolves once the pending-file count drains to zero.
    pub fn index_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        if !self.is_authenticated() {
            if !self.authenticate(cx) {
                return Task::ready(Err(anyhow!("user is not authenticated")));
            }
        }

        // First pass for this project: subscribe to its worktree events and
        // register its current worktrees.
        if !self.projects.contains_key(&project.downgrade()) {
            let subscription = cx.subscribe(&project, |this, project, event, cx| match event {
                project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
                    this.project_worktrees_changed(project.clone(), cx);
                }
                project::Event::WorktreeUpdatedEntries(worktree_id, changes) => {
                    this.project_entries_changed(project, *worktree_id, changes.clone(), cx);
                }
                _ => {}
            });
            let project_state = ProjectState::new(subscription, cx);
            self.projects.insert(project.downgrade(), project_state);
            self.project_worktrees_changed(project.clone(), cx);
        }
        let project_state = self.projects.get_mut(&project.downgrade()).unwrap();
        project_state.pending_index += 1;
        cx.notify();

        let mut pending_file_count_rx = project_state.pending_file_count_rx.clone();
        let db = self.db.clone();
        let language_registry = self.language_registry.clone();
        let parsing_files_tx = self.parsing_files_tx.clone();
        let worktree_registration = self.wait_for_worktree_registration(&project, cx);

        cx.spawn(|this, mut cx| async move {
            worktree_registration.await?;

            // Drain each registered worktree's changed-path set into concrete
            // work items: deletions and files to (re-)embed.
            let mut pending_files = Vec::new();
            let mut files_to_delete = Vec::new();
            this.update(&mut cx, |this, cx| {
                let project_state = this
                    .projects
                    .get_mut(&project.downgrade())
                    .ok_or_else(|| anyhow!("project was dropped"))?;
                let pending_file_count_tx = &project_state.pending_file_count_tx;

                project_state
                    .worktrees
                    .retain(|worktree_id, worktree_state| {
                        // Drop state for worktrees the project no longer has.
                        let worktree = if let Some(worktree) =
                            project.read(cx).worktree_for_id(*worktree_id, cx)
                        {
                            worktree
                        } else {
                            return false;
                        };
                        // Still-registering worktrees keep their state as-is.
                        let worktree_state =
                            if let WorktreeState::Registered(worktree_state) = worktree_state {
                                worktree_state
                            } else {
                                return true;
                            };

                        worktree_state.changed_paths.retain(|path, info| {
                            if info.is_deleted {
                                files_to_delete.push((worktree_state.db_id, path.clone()));
                            } else {
                                let absolute_path = worktree.read(cx).absolutize(path);
                                let job_handle = JobHandle::new(pending_file_count_tx);
                                pending_files.push(PendingFile {
                                    absolute_path,
                                    relative_path: path.clone(),
                                    language: None,
                                    job_handle,
                                    modified_time: info.mtime,
                                    worktree_db_id: worktree_state.db_id,
                                });
                            }

                            // Every entry was converted into work, so clear
                            // the changed-path set.
                            false
                        });
                        true
                    });

                anyhow::Ok(())
            })?;

            cx.background()
                .spawn(async move {
                    for (worktree_db_id, path) in files_to_delete {
                        db.delete_file(worktree_db_id, path).await.log_err();
                    }

                    // Fetch embeddings we already have for these paths so
                    // unchanged spans skip the provider round-trip.
                    let embeddings_for_digest = {
                        let mut files = HashMap::default();
                        for pending_file in &pending_files {
                            files
                                .entry(pending_file.worktree_db_id)
                                .or_insert(Vec::new())
                                .push(pending_file.relative_path.clone());
                        }
                        Arc::new(
                            db.embeddings_for_files(files)
                                .await
                                .log_err()
                                .unwrap_or_default(),
                        )
                    };

                    for mut pending_file in pending_files {
                        if let Ok(language) = language_registry
                            .language_for_file(&pending_file.relative_path, None)
                            .await
                        {
                            // Skip languages we can neither parse wholesale
                            // nor via an embedding grammar (Markdown excepted).
                            if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
                                && &language.name().as_ref() != &"Markdown"
                                && language
                                    .grammar()
                                    .and_then(|grammar| grammar.embedding_config.as_ref())
                                    .is_none()
                            {
                                continue;
                            }
                            pending_file.language = Some(language);
                        }
                        parsing_files_tx
                            .try_send((embeddings_for_digest.clone(), pending_file))
                            .ok();
                    }

                    // Wait until we're done indexing.
                    while let Some(count) = pending_file_count_rx.next().await {
                        if count == 0 {
                            break;
                        }
                    }
                })
                .await;

            this.update(&mut cx, |this, cx| {
                let project_state = this
                    .projects
                    .get_mut(&project.downgrade())
                    .ok_or_else(|| anyhow!("project was dropped"))?;
                project_state.pending_index -= 1;
                cx.notify();
                anyhow::Ok(())
            })?;

            Ok(())
        })
    }
1149
    /// Resolves once every worktree of `project` has finished registering.
    /// Loops because new worktrees may begin registering while earlier ones
    /// complete.
    fn wait_for_worktree_registration(
        &self,
        project: &ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        let project = project.downgrade();
        cx.spawn_weak(|this, cx| async move {
            loop {
                let mut pending_worktrees = Vec::new();
                this.upgrade(&cx)
                    .ok_or_else(|| anyhow!("semantic index dropped"))?
                    .read_with(&cx, |this, _| {
                        if let Some(project) = this.projects.get(&project) {
                            for worktree in project.worktrees.values() {
                                if let WorktreeState::Registering(worktree) = worktree {
                                    pending_worktrees.push(worktree.done());
                                }
                            }
                        }
                    });

                if pending_worktrees.is_empty() {
                    break;
                } else {
                    future::join_all(pending_worktrees).await;
                }
            }
            Ok(())
        })
    }
1180
    /// Fills in `span.embedding` for every span, reusing database-cached
    /// embeddings by digest and batching the rest to the provider within its
    /// per-batch token budget.
    async fn embed_spans(
        spans: &mut [Span],
        embedding_provider: &dyn EmbeddingProvider,
        db: &VectorDatabase,
    ) -> Result<()> {
        let mut batch = Vec::new();
        let mut batch_tokens = 0;
        let mut embeddings = Vec::new();

        let digests = spans
            .iter()
            .map(|span| span.digest.clone())
            .collect::<Vec<_>>();
        let embeddings_for_digests = db
            .embeddings_for_digests(digests)
            .await
            .log_err()
            .unwrap_or_default();

        for span in &*spans {
            // Already cached; no need to re-embed.
            if embeddings_for_digests.contains_key(&span.digest) {
                continue;
            };

            // Flush the current batch before it would exceed the provider's
            // per-batch token budget.
            if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
                let batch_embeddings = embedding_provider
                    .embed_batch(mem::take(&mut batch))
                    .await?;
                embeddings.extend(batch_embeddings);
                batch_tokens = 0;
            }

            batch_tokens += span.token_count;
            batch.push(span.content.clone());
        }

        if !batch.is_empty() {
            let batch_embeddings = embedding_provider
                .embed_batch(mem::take(&mut batch))
                .await?;

            embeddings.extend(batch_embeddings);
        }

        // Provider results come back in submission order, so zip them with
        // the uncached spans; cached spans take theirs from the digest map.
        let mut embeddings = embeddings.into_iter();
        for span in spans {
            let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) {
                Some(embedding.clone())
            } else {
                embeddings.next()
            };
            let embedding = embedding.ok_or_else(|| anyhow!("failed to embed spans"))?;
            span.embedding = Some(embedding);
        }
        Ok(())
    }
1237}
1238
impl Entity for SemanticIndex {
    // No events are emitted; observers rely on `cx.notify` alone.
    type Event = ();
}
1242
impl Drop for JobHandle {
    fn drop(&mut self) {
        // `Arc::get_mut` succeeds only for the last clone of this handle
        // (regardless of its origin — cloned or not), so the pending counter
        // is decremented exactly once per job.
        if let Some(inner) = Arc::get_mut(&mut self.tx) {
            // The counter may already be gone if the project state dropped.
            if let Some(tx) = inner.upgrade() {
                let mut tx = tx.lock();
                *tx.borrow_mut() -= 1;
            }
        }
    }
}
1254
#[cfg(test)]
mod tests {

    use super::*;
    /// The counter must increment once per job (not per clone) and decrement
    /// only when the last remaining handle for that job is dropped.
    #[test]
    fn test_job_handle() {
        let (job_count_tx, job_count_rx) = watch::channel_with(0);
        let tx = Arc::new(Mutex::new(job_count_tx));
        let job_handle = JobHandle::new(&tx);

        assert_eq!(1, *job_count_rx.borrow());
        let new_job_handle = job_handle.clone();
        // Cloning must not bump the counter again.
        assert_eq!(1, *job_count_rx.borrow());
        drop(job_handle);
        // One handle remains, so the job is still counted.
        assert_eq!(1, *job_count_rx.borrow());
        drop(new_job_handle);
        assert_eq!(0, *job_count_rx.borrow());
    }
}