mod db;
mod embedding_queue;
mod parsing;
pub mod semantic_index_settings;

#[cfg(test)]
mod semantic_index_tests;

use crate::semantic_index_settings::SemanticIndexSettings;
use ai::embedding::{Embedding, EmbeddingProvider};
use ai::providers::open_ai::OpenAIEmbeddingProvider;
use anyhow::{anyhow, Context as _, Result};
use collections::{BTreeMap, HashMap, HashSet};
use db::VectorDatabase;
use embedding_queue::{EmbeddingQueue, FileToEmbed};
use futures::{future, FutureExt, StreamExt};
use gpui::{
    AppContext, AsyncAppContext, BorrowWindow, Context, Model, ModelContext, Task, ViewContext,
    WeakModel,
};
use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
use lazy_static::lazy_static;
use ordered_float::OrderedFloat;
use parking_lot::Mutex;
use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
use postage::watch;
use project::{Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
use settings::Settings;
use smol::channel;
use std::{
    cmp::Reverse,
    env,
    future::Future,
    mem,
    ops::Range,
    path::{Path, PathBuf},
    sync::{Arc, Weak},
    time::{Duration, Instant, SystemTime},
};
use util::paths::PathMatcher;
use util::{channel::RELEASE_CHANNEL_NAME, http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt};
use workspace::Workspace;

const SEMANTIC_INDEX_VERSION: usize = 11;
const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);

lazy_static! {
    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
}

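/// Registers the semantic index settings, re-indexes previously indexed local projects when a
/// workspace is opened, and installs the global [`SemanticIndex`] model backed by the on-disk
/// embeddings database.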
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    SemanticIndexSettings::register(cx);

    let db_file_path = EMBEDDINGS_DIR
        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
        .join("embeddings_db");

    cx.observe_new_views(
        |workspace: &mut Workspace, cx: &mut ViewContext<Workspace>| {
            let Some(semantic_index) = SemanticIndex::global(cx) else {
                return;
            };
            let project = workspace.project().clone();

            if project.read(cx).is_local() {
                cx.app_mut()
                    .spawn(|mut cx| async move {
                        let previously_indexed = semantic_index
                            .update(&mut cx, |index, cx| {
                                index.project_previously_indexed(&project, cx)
                            })?
                            .await?;
                        if previously_indexed {
                            semantic_index
                                .update(&mut cx, |index, cx| index.index_project(project, cx))?
                                .await?;
                        }
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
            }
        },
    )
    .detach();

    cx.spawn(move |cx| async move {
        let embedding_provider =
            OpenAIEmbeddingProvider::new(http_client, cx.background_executor().clone()).await;
        let semantic_index = SemanticIndex::new(
            fs,
            db_file_path,
            Arc::new(embedding_provider),
            language_registry,
            cx.clone(),
        )
        .await?;

        cx.update(|cx| cx.set_global(semantic_index.clone()))?;

        anyhow::Ok(())
    })
    .detach();
}

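/// The externally observable indexing state for a given project, as reported by
/// [`SemanticIndex::status`].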
#[derive(Copy, Clone, Debug)]
pub enum SemanticIndexStatus {
    NotAuthenticated,
    NotIndexed,
    Indexed,
    Indexing {
        remaining_files: usize,
        rate_limit_expiry: Option<Instant>,
    },
}

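/// Coordinates parsing, embedding, and storage of project files, and serves semantic searches
/// against the resulting vector database.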
pub struct SemanticIndex {
    fs: Arc<dyn Fs>,
    db: VectorDatabase,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    parsing_files_tx: channel::Sender<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>,
    _embedding_task: Task<()>,
    _parsing_files_tasks: Vec<Task<()>>,
    projects: HashMap<WeakModel<Project>, ProjectState>,
}

struct ProjectState {
    worktrees: HashMap<WorktreeId, WorktreeState>,
    pending_file_count_rx: watch::Receiver<usize>,
    pending_file_count_tx: Arc<Mutex<watch::Sender<usize>>>,
    pending_index: usize,
    _subscription: gpui::Subscription,
    _observe_pending_file_count: Task<()>,
}

enum WorktreeState {
    Registering(RegisteringWorktreeState),
    Registered(RegisteredWorktreeState),
}

impl WorktreeState {
    fn is_registered(&self) -> bool {
        matches!(self, Self::Registered(_))
    }

    fn paths_changed(
        &mut self,
        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
        worktree: &Worktree,
    ) {
        let changed_paths = match self {
            Self::Registering(state) => &mut state.changed_paths,
            Self::Registered(state) => &mut state.changed_paths,
        };

        for (path, entry_id, change) in changes.iter() {
            let Some(entry) = worktree.entry_for_id(*entry_id) else {
                continue;
            };
            if entry.is_ignored || entry.is_symlink || entry.is_external || entry.is_dir() {
                continue;
            }
            changed_paths.insert(
                path.clone(),
                ChangedPathInfo {
                    mtime: entry.mtime,
                    is_deleted: *change == PathChange::Removed,
                },
            );
        }
    }
}

struct RegisteringWorktreeState {
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
    done_rx: watch::Receiver<Option<()>>,
    _registration: Task<()>,
}

impl RegisteringWorktreeState {
    fn done(&self) -> impl Future<Output = ()> {
        let mut done_rx = self.done_rx.clone();
        async move {
            while let Some(result) = done_rx.next().await {
                if result.is_some() {
                    break;
                }
            }
        }
    }
}

struct RegisteredWorktreeState {
    db_id: i64,
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
}

struct ChangedPathInfo {
    mtime: SystemTime,
    is_deleted: bool,
}

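/// Tracks a single pending indexing job. Clones share the same job; the project's pending file
/// count is incremented in [`JobHandle::new`] and decremented exactly once, when the last clone
/// is dropped.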
#[derive(Clone)]
pub struct JobHandle {
    /// The outer Arc is here to count the clones of a JobHandle instance;
    /// when the last handle to a given job is dropped, we decrement a counter (just once).
    tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
}

impl JobHandle {
    fn new(tx: &Arc<Mutex<watch::Sender<usize>>>) -> Self {
        *tx.lock().borrow_mut() += 1;
        Self {
            tx: Arc::new(Arc::downgrade(&tx)),
        }
    }
}

impl ProjectState {
    fn new(subscription: gpui::Subscription, cx: &mut ModelContext<SemanticIndex>) -> Self {
        let (pending_file_count_tx, pending_file_count_rx) = watch::channel_with(0);
        let pending_file_count_tx = Arc::new(Mutex::new(pending_file_count_tx));
        Self {
            worktrees: Default::default(),
            pending_file_count_rx: pending_file_count_rx.clone(),
            pending_file_count_tx,
            pending_index: 0,
            _subscription: subscription,
            _observe_pending_file_count: cx.spawn({
                let mut pending_file_count_rx = pending_file_count_rx.clone();
                |this, mut cx| async move {
                    while let Some(_) = pending_file_count_rx.next().await {
                        if this.update(&mut cx, |_, cx| cx.notify()).is_err() {
                            break;
                        }
                    }
                }
            }),
        }
    }

    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
        self.worktrees
            .iter()
            .find_map(|(worktree_id, worktree_state)| match worktree_state {
                WorktreeState::Registered(state) if state.db_id == id => Some(*worktree_id),
                _ => None,
            })
    }
}

#[derive(Clone)]
pub struct PendingFile {
    worktree_db_id: i64,
    relative_path: Arc<Path>,
    absolute_path: PathBuf,
    language: Option<Arc<Language>>,
    modified_time: SystemTime,
    job_handle: JobHandle,
}

#[derive(Clone)]
pub struct SearchResult {
    pub buffer: Model<Buffer>,
    pub range: Range<Anchor>,
    pub similarity: OrderedFloat<f32>,
}

impl SemanticIndex {
    pub fn global(cx: &mut AppContext) -> Option<Model<SemanticIndex>> {
        cx.try_global::<Model<Self>>()
            .map(|semantic_index| semantic_index.clone())
    }

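    /// Ensures the embedding provider has credentials, asking it to retrieve them if needed.
    /// Resolves to `true` once credentials are available.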
    pub fn authenticate(&mut self, cx: &mut AppContext) -> Task<bool> {
        if !self.embedding_provider.has_credentials() {
            let embedding_provider = self.embedding_provider.clone();
            cx.spawn(|cx| async move {
                if let Some(retrieve_credentials) = cx
                    .update(|cx| embedding_provider.retrieve_credentials(cx))
                    .log_err()
                {
                    retrieve_credentials.await;
                }

                embedding_provider.has_credentials()
            })
        } else {
            Task::ready(true)
        }
    }

    pub fn is_authenticated(&self) -> bool {
        self.embedding_provider.has_credentials()
    }

    pub fn enabled(cx: &AppContext) -> bool {
        SemanticIndexSettings::get_global(cx).enabled
    }

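    /// Reports the indexing status for `project`: not authenticated, not indexed, indexed, or
    /// still indexing (with the number of remaining files and any rate-limit expiry).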
    pub fn status(&self, project: &Model<Project>) -> SemanticIndexStatus {
        if !self.is_authenticated() {
            return SemanticIndexStatus::NotAuthenticated;
        }

        if let Some(project_state) = self.projects.get(&project.downgrade()) {
            if project_state
                .worktrees
                .values()
                .all(|worktree| worktree.is_registered())
                && project_state.pending_index == 0
            {
                SemanticIndexStatus::Indexed
            } else {
                SemanticIndexStatus::Indexing {
                    remaining_files: project_state.pending_file_count_rx.borrow().clone(),
                    rate_limit_expiry: self.embedding_provider.rate_limit_expiration(),
                }
            }
        } else {
            SemanticIndexStatus::NotIndexed
        }
    }

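    /// Opens (or creates) the vector database at `database_path` and spawns the background tasks
    /// that write finished embeddings to the database and parse pending files into spans, one
    /// parsing task per CPU.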
    pub async fn new(
        fs: Arc<dyn Fs>,
        database_path: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
        mut cx: AsyncAppContext,
    ) -> Result<Model<Self>> {
        let t0 = Instant::now();
        let database_path = Arc::from(database_path);
        let db = VectorDatabase::new(fs.clone(), database_path, cx.background_executor().clone())
            .await?;

        log::trace!(
            "db initialization took {:?} milliseconds",
            t0.elapsed().as_millis()
        );

        cx.new_model(|cx| {
            let t0 = Instant::now();
            let embedding_queue =
                EmbeddingQueue::new(embedding_provider.clone(), cx.background_executor().clone());
            let _embedding_task = cx.background_executor().spawn({
                let embedded_files = embedding_queue.finished_files();
                let db = db.clone();
                async move {
                    while let Ok(file) = embedded_files.recv().await {
                        db.insert_file(file.worktree_id, file.path, file.mtime, file.spans)
                            .await
                            .log_err();
                    }
                }
            });

            // Parse files into embeddable spans.
            let (parsing_files_tx, parsing_files_rx) =
                channel::unbounded::<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>();
            let embedding_queue = Arc::new(Mutex::new(embedding_queue));
            let mut _parsing_files_tasks = Vec::new();
            for _ in 0..cx.background_executor().num_cpus() {
                let fs = fs.clone();
                let mut parsing_files_rx = parsing_files_rx.clone();
                let embedding_provider = embedding_provider.clone();
                let embedding_queue = embedding_queue.clone();
                let background = cx.background_executor().clone();
                _parsing_files_tasks.push(cx.background_executor().spawn(async move {
                    let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
                    loop {
                        let mut timer = background.timer(EMBEDDING_QUEUE_FLUSH_TIMEOUT).fuse();
                        let mut next_file_to_parse = parsing_files_rx.next().fuse();
                        futures::select_biased! {
                            next_file_to_parse = next_file_to_parse => {
                                if let Some((embeddings_for_digest, pending_file)) = next_file_to_parse {
                                    Self::parse_file(
                                        &fs,
                                        pending_file,
                                        &mut retriever,
                                        &embedding_queue,
                                        &embeddings_for_digest,
                                    )
                                    .await
                                } else {
                                    break;
                                }
                            },
                            _ = timer => {
                                embedding_queue.lock().flush();
                            }
                        }
                    }
                }));
            }

            log::trace!(
                "semantic index task initialization took {:?} milliseconds",
                t0.elapsed().as_millis()
            );
            Self {
                fs,
                db,
                embedding_provider,
                language_registry,
                parsing_files_tx,
                _embedding_task,
                _parsing_files_tasks,
                projects: Default::default(),
            }
        })
    }

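    /// Loads a pending file from disk, parses it into spans, reuses any embeddings already known
    /// for matching span digests, and pushes the file onto the embedding queue.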
    async fn parse_file(
        fs: &Arc<dyn Fs>,
        pending_file: PendingFile,
        retriever: &mut CodeContextRetriever,
        embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
        embeddings_for_digest: &HashMap<SpanDigest, Embedding>,
    ) {
        let Some(language) = pending_file.language else {
            return;
        };

        if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
            if let Some(mut spans) = retriever
                .parse_file_with_template(Some(&pending_file.relative_path), &content, language)
                .log_err()
            {
                log::trace!(
                    "parsed path {:?}: {} spans",
                    pending_file.relative_path,
                    spans.len()
                );

                for span in &mut spans {
                    if let Some(embedding) = embeddings_for_digest.get(&span.digest) {
                        span.embedding = Some(embedding.to_owned());
                    }
                }

                embedding_queue.lock().push(FileToEmbed {
                    worktree_id: pending_file.worktree_db_id,
                    path: pending_file.relative_path,
                    mtime: pending_file.modified_time,
                    job_handle: pending_file.job_handle,
                    spans,
                });
            }
        }
    }

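    /// Returns whether every worktree in `project` has been indexed before, according to the
    /// database.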
    pub fn project_previously_indexed(
        &mut self,
        project: &Model<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<bool>> {
        let worktrees_indexed_previously = project
            .read(cx)
            .worktrees()
            .map(|worktree| {
                self.db
                    .worktree_previously_indexed(&worktree.read(cx).abs_path())
            })
            .collect::<Vec<_>>();
        cx.spawn(|_, _cx| async move {
            let worktree_indexed_previously =
                futures::future::join_all(worktrees_indexed_previously).await;

            Ok(worktree_indexed_previously
                .iter()
                .filter(|worktree| worktree.is_ok())
                .all(|v| v.as_ref().log_err().is_some_and(|v| v.to_owned())))
        })
    }

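    /// Records changed worktree entries and, if the worktree is already registered, schedules a
    /// re-index after `BACKGROUND_INDEXING_DELAY`.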
    fn project_entries_changed(
        &mut self,
        project: Model<Project>,
        worktree_id: WorktreeId,
        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
        cx: &mut ModelContext<Self>,
    ) {
        let Some(worktree) = project.read(cx).worktree_for_id(worktree_id.clone(), cx) else {
            return;
        };
        let project = project.downgrade();
        let Some(project_state) = self.projects.get_mut(&project) else {
            return;
        };

        let worktree = worktree.read(cx);
        let worktree_state =
            if let Some(worktree_state) = project_state.worktrees.get_mut(&worktree_id) {
                worktree_state
            } else {
                return;
            };
        worktree_state.paths_changed(changes, worktree);
        if let WorktreeState::Registered(_) = worktree_state {
            cx.spawn(|this, mut cx| async move {
                cx.background_executor()
                    .timer(BACKGROUND_INDEXING_DELAY)
                    .await;
                if let Some((this, project)) = this.upgrade().zip(project.upgrade()) {
                    this.update(&mut cx, |this, cx| {
                        this.index_project(project, cx).detach_and_log_err(cx)
                    })?;
                }
                anyhow::Ok(())
            })
            .detach_and_log_err(cx);
        }
    }

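    /// Registers a local worktree with the database: waits for the initial scan, diffs file
    /// mtimes against what is already stored, records changed and deleted paths, and then kicks
    /// off indexing for the project.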
    fn register_worktree(
        &mut self,
        project: Model<Project>,
        worktree: Model<Worktree>,
        cx: &mut ModelContext<Self>,
    ) {
        let project = project.downgrade();
        let project_state = if let Some(project_state) = self.projects.get_mut(&project) {
            project_state
        } else {
            return;
        };
        let worktree = if let Some(worktree) = worktree.read(cx).as_local() {
            worktree
        } else {
            return;
        };
        let worktree_abs_path = worktree.abs_path().clone();
        let scan_complete = worktree.scan_complete();
        let worktree_id = worktree.id();
        let db = self.db.clone();
        let language_registry = self.language_registry.clone();
        let (mut done_tx, done_rx) = watch::channel();
        let registration = cx.spawn(|this, mut cx| {
            async move {
                let register = async {
                    scan_complete.await;
                    let db_id = db.find_or_create_worktree(worktree_abs_path).await?;
                    let mut file_mtimes = db.get_file_mtimes(db_id).await?;
                    let worktree = if let Some(project) = project.upgrade() {
                        project
                            .read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx))
                            .ok()
                            .flatten()
                            .context("worktree not found")?
                    } else {
                        return anyhow::Ok(());
                    };
                    let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot())?;
                    let mut changed_paths = cx
                        .background_executor()
                        .spawn(async move {
                            let mut changed_paths = BTreeMap::new();
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path)?;

                                if file.is_external || file.is_ignored || file.is_symlink {
                                    continue;
                                }

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    // Only index files whose language we can parse into
                                    // embeddable spans.
                                    if !PARSEABLE_ENTIRE_FILE_TYPES
                                        .contains(&language.name().as_ref())
                                        && &language.name().as_ref() != &"Markdown"
                                        && language
                                            .grammar()
                                            .and_then(|grammar| grammar.embedding_config.as_ref())
                                            .is_none()
                                    {
                                        continue;
                                    }

                                    let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
                                    let already_stored = stored_mtime
                                        .map_or(false, |existing_mtime| {
                                            existing_mtime == file.mtime
                                        });

                                    if !already_stored {
                                        changed_paths.insert(
                                            file.path.clone(),
                                            ChangedPathInfo {
                                                mtime: file.mtime,
                                                is_deleted: false,
                                            },
                                        );
                                    }
                                }
                            }

                            // Clean up entries from database that are no longer in the worktree.
                            for (path, mtime) in file_mtimes {
                                changed_paths.insert(
                                    path.into(),
                                    ChangedPathInfo {
                                        mtime,
                                        is_deleted: true,
                                    },
                                );
                            }

                            anyhow::Ok(changed_paths)
                        })
                        .await?;
                    this.update(&mut cx, |this, cx| {
                        let project_state = this
                            .projects
                            .get_mut(&project)
                            .context("project not registered")?;
                        let project = project.upgrade().context("project was dropped")?;

                        if let Some(WorktreeState::Registering(state)) =
                            project_state.worktrees.remove(&worktree_id)
                        {
                            changed_paths.extend(state.changed_paths);
                        }
                        project_state.worktrees.insert(
                            worktree_id,
                            WorktreeState::Registered(RegisteredWorktreeState {
                                db_id,
                                changed_paths,
                            }),
                        );
                        this.index_project(project, cx).detach_and_log_err(cx);

                        anyhow::Ok(())
                    })??;

                    anyhow::Ok(())
                };

                if register.await.log_err().is_none() {
                    // Stop tracking this worktree if the registration failed.
                    this.update(&mut cx, |this, _| {
                        this.projects.get_mut(&project).map(|project_state| {
                            project_state.worktrees.remove(&worktree_id);
                        });
                    })
                    .ok();
                }

                *done_tx.borrow_mut() = Some(());
            }
        });
        project_state.worktrees.insert(
            worktree_id,
            WorktreeState::Registering(RegisteringWorktreeState {
                changed_paths: Default::default(),
                done_rx,
                _registration: registration,
            }),
        );
    }

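    /// Reconciles the tracked worktrees with the project's current set, dropping removed ones and
    /// registering new local ones.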
    fn project_worktrees_changed(&mut self, project: Model<Project>, cx: &mut ModelContext<Self>) {
        let project_state = if let Some(project_state) = self.projects.get_mut(&project.downgrade())
        {
            project_state
        } else {
            return;
        };

        let mut worktrees = project
            .read(cx)
            .worktrees()
            .filter(|worktree| worktree.read(cx).is_local())
            .collect::<Vec<_>>();
        let worktree_ids = worktrees
            .iter()
            .map(|worktree| worktree.read(cx).id())
            .collect::<HashSet<_>>();

        // Remove worktrees that are no longer present
        project_state
            .worktrees
            .retain(|worktree_id, _| worktree_ids.contains(worktree_id));

        // Register new worktrees
        worktrees.retain(|worktree| {
            let worktree_id = worktree.read(cx).id();
            !project_state.worktrees.contains_key(&worktree_id)
        });
        for worktree in worktrees {
            self.register_worktree(project.clone(), worktree, cx);
        }
    }

    pub fn pending_file_count(&self, project: &Model<Project>) -> Option<watch::Receiver<usize>> {
        Some(
            self.projects
                .get(&project.downgrade())?
                .pending_file_count_rx
                .clone(),
        )
    }

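    /// Embeds `query` and searches both the on-disk index and any dirty open buffers, returning
    /// up to `limit` results ordered by descending similarity.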
    pub fn search_project(
        &mut self,
        project: Model<Project>,
        query: String,
        limit: usize,
        includes: Vec<PathMatcher>,
        excludes: Vec<PathMatcher>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        if query.is_empty() {
            return Task::ready(Ok(Vec::new()));
        }

        let index = self.index_project(project.clone(), cx);
        let embedding_provider = self.embedding_provider.clone();

        cx.spawn(|this, mut cx| async move {
            index.await?;
            let t0 = Instant::now();

            let query = embedding_provider
                .embed_batch(vec![query])
                .await?
                .pop()
                .context("could not embed query")?;
            log::trace!("Embedding Search Query: {:?}ms", t0.elapsed().as_millis());

            let search_start = Instant::now();
            let modified_buffer_results = this.update(&mut cx, |this, cx| {
                this.search_modified_buffers(
                    &project,
                    query.clone(),
                    limit,
                    &includes,
                    &excludes,
                    cx,
                )
            })?;
            let file_results = this.update(&mut cx, |this, cx| {
                this.search_files(project, query, limit, includes, excludes, cx)
            })?;
            let (modified_buffer_results, file_results) =
                futures::join!(modified_buffer_results, file_results);

            // Weave together the results from modified buffers and files.
            let mut results = Vec::new();
            let mut modified_buffers = HashSet::default();
            for result in modified_buffer_results.log_err().unwrap_or_default() {
                modified_buffers.insert(result.buffer.clone());
                results.push(result);
            }
            for result in file_results.log_err().unwrap_or_default() {
                if !modified_buffers.contains(&result.buffer) {
                    results.push(result);
                }
            }
            results.sort_by_key(|result| Reverse(result.similarity));
            results.truncate(limit);
            log::trace!("Semantic search took {:?}", search_start.elapsed());
            Ok(results)
        })
    }

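    /// Searches the vector database for spans similar to `query`, splitting the candidate file
    /// ids into per-CPU batches of top-k searches, then opens the matching buffers and resolves
    /// each span to an anchor range.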
    pub fn search_files(
        &mut self,
        project: Model<Project>,
        query: Embedding,
        limit: usize,
        includes: Vec<PathMatcher>,
        excludes: Vec<PathMatcher>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let db_path = self.db.path().clone();
        let fs = self.fs.clone();
        cx.spawn(|this, mut cx| async move {
            let database = VectorDatabase::new(
                fs.clone(),
                db_path.clone(),
                cx.background_executor().clone(),
            )
            .await?;

            let worktree_db_ids = this.read_with(&cx, |this, _| {
                let project_state = this
                    .projects
                    .get(&project.downgrade())
                    .context("project was not indexed")?;
                let worktree_db_ids = project_state
                    .worktrees
                    .values()
                    .filter_map(|worktree| {
                        if let WorktreeState::Registered(worktree) = worktree {
                            Some(worktree.db_id)
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<i64>>();
                anyhow::Ok(worktree_db_ids)
            })??;

            let file_ids = database
                .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
                .await?;

            let batch_n = cx.background_executor().num_cpus();
            let ids_len = file_ids.clone().len();
            let minimum_batch_size = 50;

            let batch_size = {
                let size = ids_len / batch_n;
                if size < minimum_batch_size {
                    minimum_batch_size
                } else {
                    size
                }
            };

            let mut batch_results = Vec::new();
            for batch in file_ids.chunks(batch_size) {
                let batch = batch.into_iter().map(|v| *v).collect::<Vec<i64>>();
                let limit = limit.clone();
                let fs = fs.clone();
                let db_path = db_path.clone();
                let query = query.clone();
                if let Some(db) =
                    VectorDatabase::new(fs, db_path.clone(), cx.background_executor().clone())
                        .await
                        .log_err()
                {
                    batch_results.push(async move {
                        db.top_k_search(&query, limit, batch.as_slice()).await
                    });
                }
            }

            let batch_results = futures::future::join_all(batch_results).await;

            let mut results = Vec::new();
            for batch_result in batch_results {
                if batch_result.is_ok() {
                    for (id, similarity) in batch_result.unwrap() {
                        let ix = match results
                            .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
                        {
                            Ok(ix) => ix,
                            Err(ix) => ix,
                        };

                        results.insert(ix, (id, similarity));
                        results.truncate(limit);
                    }
                }
            }

            let ids = results.iter().map(|(id, _)| *id).collect::<Vec<i64>>();
            let scores = results
                .into_iter()
                .map(|(_, score)| score)
                .collect::<Vec<_>>();
            let spans = database.spans_for_ids(ids.as_slice()).await?;

            let mut tasks = Vec::new();
            let mut ranges = Vec::new();
            let weak_project = project.downgrade();
            project.update(&mut cx, |project, cx| {
                let this = this.upgrade().context("index was dropped")?;
                for (worktree_db_id, file_path, byte_range) in spans {
                    let project_state =
                        if let Some(state) = this.read(cx).projects.get(&weak_project) {
                            state
                        } else {
                            return Err(anyhow!("project not added"));
                        };
                    if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
                        tasks.push(project.open_buffer((worktree_id, file_path), cx));
                        ranges.push(byte_range);
                    }
                }

                Ok(())
            })??;

            let buffers = futures::future::join_all(tasks).await;
            Ok(buffers
                .into_iter()
                .zip(ranges)
                .zip(scores)
                .filter_map(|((buffer, range), similarity)| {
                    let buffer = buffer.log_err()?;
                    let range = buffer
                        .read_with(&cx, |buffer, _| {
                            let start = buffer.clip_offset(range.start, Bias::Left);
                            let end = buffer.clip_offset(range.end, Bias::Right);
                            buffer.anchor_before(start)..buffer.anchor_after(end)
                        })
                        .log_err()?;
                    Some(SearchResult {
                        buffer,
                        range,
                        similarity,
                    })
                })
                .collect())
        })
    }

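    /// Searches dirty open buffers by parsing and embedding their current contents on the fly.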
    fn search_modified_buffers(
        &self,
        project: &Model<Project>,
        query: Embedding,
        limit: usize,
        includes: &[PathMatcher],
        excludes: &[PathMatcher],
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let modified_buffers = project
            .read(cx)
            .opened_buffers()
            .into_iter()
            .filter_map(|buffer_handle| {
                let buffer = buffer_handle.read(cx);
                let snapshot = buffer.snapshot();
                let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
                    excludes.iter().any(|matcher| matcher.is_match(&path))
                });

                let included = if includes.len() == 0 {
                    true
                } else {
                    snapshot.resolve_file_path(cx, false).map_or(false, |path| {
                        includes.iter().any(|matcher| matcher.is_match(&path))
                    })
                };

                if buffer.is_dirty() && !excluded && included {
                    Some((buffer_handle, snapshot))
                } else {
                    None
                }
            })
            .collect::<HashMap<_, _>>();

        let embedding_provider = self.embedding_provider.clone();
        let fs = self.fs.clone();
        let db_path = self.db.path().clone();
        let background = cx.background_executor().clone();
        cx.background_executor().spawn(async move {
            let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
            let mut results = Vec::<SearchResult>::new();

            let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
            for (buffer, snapshot) in modified_buffers {
                let language = snapshot
                    .language_at(0)
                    .cloned()
                    .unwrap_or_else(|| language::PLAIN_TEXT.clone());
                let mut spans = retriever
                    .parse_file_with_template(None, &snapshot.text(), language)
                    .log_err()
                    .unwrap_or_default();
                if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
                    .await
                    .log_err()
                    .is_some()
                {
                    for span in spans {
                        let similarity = span.embedding.unwrap().similarity(&query);
                        let ix = match results
                            .binary_search_by_key(&Reverse(similarity), |result| {
                                Reverse(result.similarity)
                            }) {
                            Ok(ix) => ix,
                            Err(ix) => ix,
                        };

                        let range = {
                            let start = snapshot.clip_offset(span.range.start, Bias::Left);
                            let end = snapshot.clip_offset(span.range.end, Bias::Right);
                            snapshot.anchor_before(start)..snapshot.anchor_after(end)
                        };

                        results.insert(
                            ix,
                            SearchResult {
                                buffer: buffer.clone(),
                                range,
                                similarity,
                            },
                        );
                        results.truncate(limit);
                    }
                }
            }

            Ok(results)
        })
    }

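    /// Indexes `project`, authenticating with the embedding provider first if necessary.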
    pub fn index_project(
        &mut self,
        project: Model<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        if self.is_authenticated() {
            self.index_project_internal(project, cx)
        } else {
            let authenticate = self.authenticate(cx);
            cx.spawn(|this, mut cx| async move {
                if authenticate.await {
                    this.update(&mut cx, |this, cx| this.index_project_internal(project, cx))?
                        .await
                } else {
                    Err(anyhow!("user is not authenticated"))
                }
            })
        }
    }

    fn index_project_internal(
        &mut self,
        project: Model<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        if !self.projects.contains_key(&project.downgrade()) {
            let subscription = cx.subscribe(&project, |this, project, event, cx| match event {
                project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
                    this.project_worktrees_changed(project.clone(), cx);
                }
                project::Event::WorktreeUpdatedEntries(worktree_id, changes) => {
                    this.project_entries_changed(project, *worktree_id, changes.clone(), cx);
                }
                _ => {}
            });
            let project_state = ProjectState::new(subscription, cx);
            self.projects.insert(project.downgrade(), project_state);
            self.project_worktrees_changed(project.clone(), cx);
        }
        let project_state = self.projects.get_mut(&project.downgrade()).unwrap();
        project_state.pending_index += 1;
        cx.notify();

        let mut pending_file_count_rx = project_state.pending_file_count_rx.clone();
        let db = self.db.clone();
        let language_registry = self.language_registry.clone();
        let parsing_files_tx = self.parsing_files_tx.clone();
        let worktree_registration = self.wait_for_worktree_registration(&project, cx);

        cx.spawn(|this, mut cx| async move {
            worktree_registration.await?;

            let mut pending_files = Vec::new();
            let mut files_to_delete = Vec::new();
            this.update(&mut cx, |this, cx| {
                let project_state = this
                    .projects
                    .get_mut(&project.downgrade())
                    .context("project was dropped")?;
                let pending_file_count_tx = &project_state.pending_file_count_tx;

                project_state
                    .worktrees
                    .retain(|worktree_id, worktree_state| {
                        let worktree = if let Some(worktree) =
                            project.read(cx).worktree_for_id(*worktree_id, cx)
                        {
                            worktree
                        } else {
                            return false;
                        };
                        let worktree_state =
                            if let WorktreeState::Registered(worktree_state) = worktree_state {
                                worktree_state
                            } else {
                                return true;
                            };

                        for (path, info) in &worktree_state.changed_paths {
                            if info.is_deleted {
                                files_to_delete.push((worktree_state.db_id, path.clone()));
                            } else if let Ok(absolute_path) = worktree.read(cx).absolutize(path) {
                                let job_handle = JobHandle::new(pending_file_count_tx);
                                pending_files.push(PendingFile {
                                    absolute_path,
                                    relative_path: path.clone(),
                                    language: None,
                                    job_handle,
                                    modified_time: info.mtime,
                                    worktree_db_id: worktree_state.db_id,
                                });
                            }
                        }
                        worktree_state.changed_paths.clear();
                        true
                    });

                anyhow::Ok(())
            })??;

            cx.background_executor()
                .spawn(async move {
                    for (worktree_db_id, path) in files_to_delete {
                        db.delete_file(worktree_db_id, path).await.log_err();
                    }

                    let embeddings_for_digest = {
                        let mut files = HashMap::default();
                        for pending_file in &pending_files {
                            files
                                .entry(pending_file.worktree_db_id)
                                .or_insert(Vec::new())
                                .push(pending_file.relative_path.clone());
                        }
                        Arc::new(
                            db.embeddings_for_files(files)
                                .await
                                .log_err()
                                .unwrap_or_default(),
                        )
                    };

                    for mut pending_file in pending_files {
                        if let Ok(language) = language_registry
                            .language_for_file(&pending_file.relative_path, None)
                            .await
                        {
                            if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
                                && &language.name().as_ref() != &"Markdown"
                                && language
                                    .grammar()
                                    .and_then(|grammar| grammar.embedding_config.as_ref())
                                    .is_none()
                            {
                                continue;
                            }
                            pending_file.language = Some(language);
                        }
                        parsing_files_tx
                            .try_send((embeddings_for_digest.clone(), pending_file))
                            .ok();
                    }

                    // Wait until we're done indexing.
                    while let Some(count) = pending_file_count_rx.next().await {
                        if count == 0 {
                            break;
                        }
                    }
                })
                .await;

            this.update(&mut cx, |this, cx| {
                let project_state = this
                    .projects
                    .get_mut(&project.downgrade())
                    .context("project was dropped")?;
                project_state.pending_index -= 1;
                cx.notify();
                anyhow::Ok(())
            })??;

            Ok(())
        })
    }

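    /// Resolves once every worktree in `project` has finished registering with the database.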
    fn wait_for_worktree_registration(
        &self,
        project: &Model<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        let project = project.downgrade();
        cx.spawn(|this, cx| async move {
            loop {
                let mut pending_worktrees = Vec::new();
                this.upgrade()
                    .context("semantic index dropped")?
                    .read_with(&cx, |this, _| {
                        if let Some(project) = this.projects.get(&project) {
                            for worktree in project.worktrees.values() {
                                if let WorktreeState::Registering(worktree) = worktree {
                                    pending_worktrees.push(worktree.done());
                                }
                            }
                        }
                    })?;

                if pending_worktrees.is_empty() {
                    break;
                } else {
                    future::join_all(pending_worktrees).await;
                }
            }
            Ok(())
        })
    }

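    /// Embeds the given spans, skipping digests already present in the database and batching
    /// requests up to the provider's token limit, then assigns an embedding to every span.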
    async fn embed_spans(
        spans: &mut [Span],
        embedding_provider: &dyn EmbeddingProvider,
        db: &VectorDatabase,
    ) -> Result<()> {
        let mut batch = Vec::new();
        let mut batch_tokens = 0;
        let mut embeddings = Vec::new();

        let digests = spans
            .iter()
            .map(|span| span.digest.clone())
            .collect::<Vec<_>>();
        let embeddings_for_digests = db
            .embeddings_for_digests(digests)
            .await
            .log_err()
            .unwrap_or_default();

        for span in &*spans {
            if embeddings_for_digests.contains_key(&span.digest) {
                continue;
            };

            if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
                let batch_embeddings = embedding_provider
                    .embed_batch(mem::take(&mut batch))
                    .await?;
                embeddings.extend(batch_embeddings);
                batch_tokens = 0;
            }

            batch_tokens += span.token_count;
            batch.push(span.content.clone());
        }

        if !batch.is_empty() {
            let batch_embeddings = embedding_provider
                .embed_batch(mem::take(&mut batch))
                .await?;

            embeddings.extend(batch_embeddings);
        }

        let mut embeddings = embeddings.into_iter();
        for span in spans {
            let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) {
                Some(embedding.clone())
            } else {
                embeddings.next()
            };
            let embedding = embedding.context("failed to embed spans")?;
            span.embedding = Some(embedding);
        }
        Ok(())
    }
}

impl Drop for JobHandle {
    fn drop(&mut self) {
        if let Some(inner) = Arc::get_mut(&mut self.tx) {
            // This is the last instance of the JobHandle (regardless of its origin - whether it was cloned or not)
            if let Some(tx) = inner.upgrade() {
                let mut tx = tx.lock();
                *tx.borrow_mut() -= 1;
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_job_handle() {
        let (job_count_tx, job_count_rx) = watch::channel_with(0);
        let tx = Arc::new(Mutex::new(job_count_tx));
        let job_handle = JobHandle::new(&tx);

        assert_eq!(1, *job_count_rx.borrow());
        let new_job_handle = job_handle.clone();
        assert_eq!(1, *job_count_rx.borrow());
        drop(job_handle);
        assert_eq!(1, *job_count_rx.borrow());
        drop(new_job_handle);
        assert_eq!(0, *job_count_rx.borrow());
    }
}