1mod db;
2mod embedding;
3mod parsing;
4mod semantic_index_settings;
5
6#[cfg(test)]
7mod semantic_index_tests;
8
9use crate::semantic_index_settings::SemanticIndexSettings;
10use anyhow::{anyhow, Result};
11use db::VectorDatabase;
12use embedding::{EmbeddingProvider, OpenAIEmbeddings};
13use futures::{channel::oneshot, Future};
14use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
15use language::{Language, LanguageRegistry};
16use parking_lot::Mutex;
17use parsing::{CodeContextRetriever, Document, PARSEABLE_ENTIRE_FILE_TYPES};
18use postage::watch;
19use project::{Fs, Project, WorktreeId};
20use smol::channel;
21use std::{
22 collections::HashMap,
23 mem,
24 ops::Range,
25 path::{Path, PathBuf},
26 sync::{Arc, Weak},
27 time::SystemTime,
28};
29use util::{
30 channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
31 http::HttpClient,
32 paths::EMBEDDINGS_DIR,
33 ResultExt,
34};
35
/// Version number of the semantic index.
/// NOTE(review): not referenced in the visible part of this file; presumably read
/// by code elsewhere in the crate — confirm before changing it.
const SEMANTIC_INDEX_VERSION: usize = 3;
/// Number of queued documents at which `enqueue_documents_to_embed` hands the
/// accumulated queue over for embedding as a single batch.
const EMBEDDINGS_BATCH_SIZE: usize = 150;
38
39pub fn init(
40 fs: Arc<dyn Fs>,
41 http_client: Arc<dyn HttpClient>,
42 language_registry: Arc<LanguageRegistry>,
43 cx: &mut AppContext,
44) {
45 settings::register::<SemanticIndexSettings>(cx);
46
47 let db_file_path = EMBEDDINGS_DIR
48 .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
49 .join("embeddings_db");
50
51 if *RELEASE_CHANNEL == ReleaseChannel::Stable
52 || !settings::get::<SemanticIndexSettings>(cx).enabled
53 {
54 return;
55 }
56
57 cx.spawn(move |mut cx| async move {
58 let semantic_index = SemanticIndex::new(
59 fs,
60 db_file_path,
61 Arc::new(OpenAIEmbeddings {
62 client: http_client,
63 executor: cx.background(),
64 }),
65 language_registry,
66 cx.clone(),
67 )
68 .await?;
69
70 cx.update(|cx| {
71 cx.set_global(semantic_index.clone());
72 });
73
74 anyhow::Ok(())
75 })
76 .detach();
77}
78
/// Model owning the semantic index: the channels that feed the background
/// parse / batch / embed / store tasks, plus per-project bookkeeping.
pub struct SemanticIndex {
    /// File system handle; cloned into the parsing tasks and into each search.
    fs: Arc<dyn Fs>,
    /// Path passed to `VectorDatabase::new` whenever the database is opened.
    database_url: Arc<PathBuf>,
    /// Produces embeddings for document contents and for search phrases.
    embedding_provider: Arc<dyn EmbeddingProvider>,
    /// Consulted by `index_project` to find the language of each file.
    language_registry: Arc<LanguageRegistry>,
    /// Sends `DbOperation`s to the task that owns the database connection.
    db_update_tx: channel::Sender<DbOperation>,
    /// Sends files that need (re)indexing to the parsing tasks.
    parsing_files_tx: channel::Sender<PendingFile>,
    // The task handles below are stored but never read.
    // NOTE(review): presumably they are held so the background tasks live as long
    // as the index — confirm against gpui's `Task` drop semantics.
    _db_update_task: Task<()>,
    _embed_batch_task: Task<()>,
    _batch_files_task: Task<()>,
    _parsing_files_tasks: Vec<Task<()>>,
    /// State recorded by `index_project`, keyed by a weak handle to the project.
    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
}
92
/// Bookkeeping stored for each project that has been passed to `index_project`.
struct ProjectState {
    /// Pairs of (worktree id, id of that worktree in the database).
    worktree_db_ids: Vec<(WorktreeId, i64)>,
    /// Receiving side of the outstanding-job counter; clones are handed to callers.
    outstanding_job_count_rx: watch::Receiver<usize>,
    /// Sending side of the same counter: incremented when a file is queued for
    /// parsing and decremented when the corresponding `JobHandle` is dropped.
    outstanding_job_count_tx: Arc<Mutex<watch::Sender<usize>>>,
}
98
/// Token representing one outstanding indexing job. Its `Drop` impl decrements
/// the job counter it points at, provided that counter is still alive.
struct JobHandle {
    /// Weak reference to the counter's sender, so a handle does not keep it alive.
    tx: Weak<Mutex<watch::Sender<usize>>>,
}
102
103impl ProjectState {
104 fn db_id_for_worktree_id(&self, id: WorktreeId) -> Option<i64> {
105 self.worktree_db_ids
106 .iter()
107 .find_map(|(worktree_id, db_id)| {
108 if *worktree_id == id {
109 Some(*db_id)
110 } else {
111 None
112 }
113 })
114 }
115
116 fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
117 self.worktree_db_ids
118 .iter()
119 .find_map(|(worktree_id, db_id)| {
120 if *db_id == id {
121 Some(*worktree_id)
122 } else {
123 None
124 }
125 })
126 }
127}
128
/// A file queued by `index_project` for parsing and embedding.
pub struct PendingFile {
    /// Database id of the worktree that contains the file.
    worktree_db_id: i64,
    /// Path taken from the worktree file entry; this is the path stored in the database.
    relative_path: PathBuf,
    /// Path produced by `absolutize`; used to load the file's contents.
    absolute_path: PathBuf,
    /// Language the registry resolved for the file; passed to the parser.
    language: Arc<Language>,
    /// The file's mtime at the time it was queued.
    modified_time: SystemTime,
    /// Decrements the project's outstanding-job counter when dropped.
    job_handle: JobHandle,
}
137
/// One hit returned by `SemanticIndex::search_project`.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Worktree the matching document belongs to.
    pub worktree_id: WorktreeId,
    /// Name stored with the matching document.
    /// NOTE(review): presumably the name of the matched code item — confirm in `parsing`.
    pub name: String,
    /// Byte range stored with the matching document.
    /// NOTE(review): presumably offsets into the file's text — confirm in `parsing`.
    pub byte_range: Range<usize>,
    /// Path of the file as stored in the database.
    pub file_path: PathBuf,
}
145
/// Requests handled one at a time by the database task (see `run_db_operation`).
enum DbOperation {
    /// Store `documents` for the file at `path`; `job_handle` is dropped once the
    /// insert has been attempted.
    InsertFile {
        worktree_id: i64,
        documents: Vec<Document>,
        path: PathBuf,
        mtime: SystemTime,
        job_handle: JobHandle,
    },
    /// Remove the file at `path` from the given worktree.
    Delete {
        worktree_id: i64,
        path: PathBuf,
    },
    /// Look up (or create) the worktree for `path` and reply with its database id.
    FindOrCreateWorktree {
        path: PathBuf,
        sender: oneshot::Sender<Result<i64>>,
    },
    /// Reply with the mtime stored for each file of the given worktree.
    FileMTimes {
        worktree_id: i64,
        sender: oneshot::Sender<Result<HashMap<PathBuf, SystemTime>>>,
    },
}
167
/// Messages consumed by the batching task (see `enqueue_documents_to_embed`).
enum EmbeddingJob {
    /// Add one parsed file's documents to the pending batch.
    Enqueue {
        worktree_id: i64,
        path: PathBuf,
        mtime: SystemTime,
        documents: Vec<Document>,
        job_handle: JobHandle,
    },
    /// Send the pending batch on for embedding regardless of its size.
    Flush,
}
178
179impl SemanticIndex {
180 pub fn global(cx: &AppContext) -> Option<ModelHandle<SemanticIndex>> {
181 if cx.has_global::<ModelHandle<Self>>() {
182 Some(cx.global::<ModelHandle<SemanticIndex>>().clone())
183 } else {
184 None
185 }
186 }
187
188 async fn new(
189 fs: Arc<dyn Fs>,
190 database_url: PathBuf,
191 embedding_provider: Arc<dyn EmbeddingProvider>,
192 language_registry: Arc<LanguageRegistry>,
193 mut cx: AsyncAppContext,
194 ) -> Result<ModelHandle<Self>> {
195 let database_url = Arc::new(database_url);
196
197 let db = cx
198 .background()
199 .spawn(VectorDatabase::new(fs.clone(), database_url.clone()))
200 .await?;
201
202 Ok(cx.add_model(|cx| {
203 // Perform database operations
204 let (db_update_tx, db_update_rx) = channel::unbounded();
205 let _db_update_task = cx.background().spawn({
206 async move {
207 while let Ok(job) = db_update_rx.recv().await {
208 Self::run_db_operation(&db, job)
209 }
210 }
211 });
212
213 // Group documents into batches and send them to the embedding provider.
214 let (embed_batch_tx, embed_batch_rx) =
215 channel::unbounded::<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>();
216 let _embed_batch_task = cx.background().spawn({
217 let db_update_tx = db_update_tx.clone();
218 let embedding_provider = embedding_provider.clone();
219 async move {
220 while let Ok(embeddings_queue) = embed_batch_rx.recv().await {
221 Self::compute_embeddings_for_batch(
222 embeddings_queue,
223 &embedding_provider,
224 &db_update_tx,
225 )
226 .await;
227 }
228 }
229 });
230
231 // Group documents into batches and send them to the embedding provider.
232 let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
233 let _batch_files_task = cx.background().spawn(async move {
234 let mut queue_len = 0;
235 let mut embeddings_queue = vec![];
236 while let Ok(job) = batch_files_rx.recv().await {
237 Self::enqueue_documents_to_embed(
238 job,
239 &mut queue_len,
240 &mut embeddings_queue,
241 &embed_batch_tx,
242 );
243 }
244 });
245
246 // Parse files into embeddable documents.
247 let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();
248 let mut _parsing_files_tasks = Vec::new();
249 for _ in 0..cx.background().num_cpus() {
250 let fs = fs.clone();
251 let parsing_files_rx = parsing_files_rx.clone();
252 let batch_files_tx = batch_files_tx.clone();
253 let db_update_tx = db_update_tx.clone();
254 _parsing_files_tasks.push(cx.background().spawn(async move {
255 let mut retriever = CodeContextRetriever::new();
256 while let Ok(pending_file) = parsing_files_rx.recv().await {
257 Self::parse_file(
258 &fs,
259 pending_file,
260 &mut retriever,
261 &batch_files_tx,
262 &parsing_files_rx,
263 &db_update_tx,
264 )
265 .await;
266 }
267 }));
268 }
269
270 Self {
271 fs,
272 database_url,
273 embedding_provider,
274 language_registry,
275 db_update_tx,
276 parsing_files_tx,
277 _db_update_task,
278 _embed_batch_task,
279 _batch_files_task,
280 _parsing_files_tasks,
281 projects: HashMap::new(),
282 }
283 }))
284 }
285
286 fn run_db_operation(db: &VectorDatabase, job: DbOperation) {
287 match job {
288 DbOperation::InsertFile {
289 worktree_id,
290 documents,
291 path,
292 mtime,
293 job_handle,
294 } => {
295 db.insert_file(worktree_id, path, mtime, documents)
296 .log_err();
297 drop(job_handle)
298 }
299 DbOperation::Delete { worktree_id, path } => {
300 db.delete_file(worktree_id, path).log_err();
301 }
302 DbOperation::FindOrCreateWorktree { path, sender } => {
303 let id = db.find_or_create_worktree(&path);
304 sender.send(id).ok();
305 }
306 DbOperation::FileMTimes {
307 worktree_id: worktree_db_id,
308 sender,
309 } => {
310 let file_mtimes = db.get_file_mtimes(worktree_db_id);
311 sender.send(file_mtimes).ok();
312 }
313 }
314 }
315
316 async fn compute_embeddings_for_batch(
317 mut embeddings_queue: Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
318 embedding_provider: &Arc<dyn EmbeddingProvider>,
319 db_update_tx: &channel::Sender<DbOperation>,
320 ) {
321 let mut batch_documents = vec![];
322 for (_, documents, _, _, _) in embeddings_queue.iter() {
323 batch_documents.extend(documents.iter().map(|document| document.content.as_str()));
324 }
325
326 if let Ok(embeddings) = embedding_provider.embed_batch(batch_documents).await {
327 log::trace!(
328 "created {} embeddings for {} files",
329 embeddings.len(),
330 embeddings_queue.len(),
331 );
332
333 let mut i = 0;
334 let mut j = 0;
335
336 for embedding in embeddings.iter() {
337 while embeddings_queue[i].1.len() == j {
338 i += 1;
339 j = 0;
340 }
341
342 embeddings_queue[i].1[j].embedding = embedding.to_owned();
343 j += 1;
344 }
345
346 for (worktree_id, documents, path, mtime, job_handle) in embeddings_queue.into_iter() {
347 // for document in documents.iter() {
348 // // TODO: Update this so it doesn't panic
349 // assert!(
350 // document.embedding.len() > 0,
351 // "Document Embedding Not Complete"
352 // );
353 // }
354
355 db_update_tx
356 .send(DbOperation::InsertFile {
357 worktree_id,
358 documents,
359 path,
360 mtime,
361 job_handle,
362 })
363 .await
364 .unwrap();
365 }
366 }
367 }
368
369 fn enqueue_documents_to_embed(
370 job: EmbeddingJob,
371 queue_len: &mut usize,
372 embeddings_queue: &mut Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
373 embed_batch_tx: &channel::Sender<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>,
374 ) {
375 let should_flush = match job {
376 EmbeddingJob::Enqueue {
377 documents,
378 worktree_id,
379 path,
380 mtime,
381 job_handle,
382 } => {
383 *queue_len += &documents.len();
384 embeddings_queue.push((worktree_id, documents, path, mtime, job_handle));
385 *queue_len >= EMBEDDINGS_BATCH_SIZE
386 }
387 EmbeddingJob::Flush => true,
388 };
389
390 if should_flush {
391 embed_batch_tx
392 .try_send(mem::take(embeddings_queue))
393 .unwrap();
394 *queue_len = 0;
395 }
396 }
397
398 async fn parse_file(
399 fs: &Arc<dyn Fs>,
400 pending_file: PendingFile,
401 retriever: &mut CodeContextRetriever,
402 batch_files_tx: &channel::Sender<EmbeddingJob>,
403 parsing_files_rx: &channel::Receiver<PendingFile>,
404 db_update_tx: &channel::Sender<DbOperation>,
405 ) {
406 if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
407 if let Some(documents) = retriever
408 .parse_file(&pending_file.relative_path, &content, pending_file.language)
409 .log_err()
410 {
411 log::trace!(
412 "parsed path {:?}: {} documents",
413 pending_file.relative_path,
414 documents.len()
415 );
416
417 if documents.len() == 0 {
418 db_update_tx
419 .send(DbOperation::InsertFile {
420 worktree_id: pending_file.worktree_db_id,
421 documents,
422 path: pending_file.relative_path,
423 mtime: pending_file.modified_time,
424 job_handle: pending_file.job_handle,
425 })
426 .await
427 .unwrap();
428 } else {
429 batch_files_tx
430 .try_send(EmbeddingJob::Enqueue {
431 worktree_id: pending_file.worktree_db_id,
432 path: pending_file.relative_path,
433 mtime: pending_file.modified_time,
434 job_handle: pending_file.job_handle,
435 documents,
436 })
437 .unwrap();
438 }
439 }
440 }
441
442 if parsing_files_rx.len() == 0 {
443 batch_files_tx.try_send(EmbeddingJob::Flush).unwrap();
444 }
445 }
446
447 fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
448 let (tx, rx) = oneshot::channel();
449 self.db_update_tx
450 .try_send(DbOperation::FindOrCreateWorktree { path, sender: tx })
451 .unwrap();
452 async move { rx.await? }
453 }
454
455 fn get_file_mtimes(
456 &self,
457 worktree_id: i64,
458 ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
459 let (tx, rx) = oneshot::channel();
460 self.db_update_tx
461 .try_send(DbOperation::FileMTimes {
462 worktree_id,
463 sender: tx,
464 })
465 .unwrap();
466 async move { rx.await? }
467 }
468
469 pub fn index_project(
470 &mut self,
471 project: ModelHandle<Project>,
472 cx: &mut ModelContext<Self>,
473 ) -> Task<Result<(usize, watch::Receiver<usize>)>> {
474 let worktree_scans_complete = project
475 .read(cx)
476 .worktrees(cx)
477 .map(|worktree| {
478 let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
479 async move {
480 scan_complete.await;
481 }
482 })
483 .collect::<Vec<_>>();
484 let worktree_db_ids = project
485 .read(cx)
486 .worktrees(cx)
487 .map(|worktree| {
488 self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
489 })
490 .collect::<Vec<_>>();
491
492 let language_registry = self.language_registry.clone();
493 let db_update_tx = self.db_update_tx.clone();
494 let parsing_files_tx = self.parsing_files_tx.clone();
495
496 cx.spawn(|this, mut cx| async move {
497 futures::future::join_all(worktree_scans_complete).await;
498
499 let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;
500
501 let worktrees = project.read_with(&cx, |project, cx| {
502 project
503 .worktrees(cx)
504 .map(|worktree| worktree.read(cx).snapshot())
505 .collect::<Vec<_>>()
506 });
507
508 let mut worktree_file_mtimes = HashMap::new();
509 let mut db_ids_by_worktree_id = HashMap::new();
510 for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
511 let db_id = db_id?;
512 db_ids_by_worktree_id.insert(worktree.id(), db_id);
513 worktree_file_mtimes.insert(
514 worktree.id(),
515 this.read_with(&cx, |this, _| this.get_file_mtimes(db_id))
516 .await?,
517 );
518 }
519
520 let (job_count_tx, job_count_rx) = watch::channel_with(0);
521 let job_count_tx = Arc::new(Mutex::new(job_count_tx));
522 this.update(&mut cx, |this, _| {
523 this.projects.insert(
524 project.downgrade(),
525 ProjectState {
526 worktree_db_ids: db_ids_by_worktree_id
527 .iter()
528 .map(|(a, b)| (*a, *b))
529 .collect(),
530 outstanding_job_count_rx: job_count_rx.clone(),
531 outstanding_job_count_tx: job_count_tx.clone(),
532 },
533 );
534 });
535
536 cx.background()
537 .spawn(async move {
538 let mut count = 0;
539 for worktree in worktrees.into_iter() {
540 let mut file_mtimes = worktree_file_mtimes.remove(&worktree.id()).unwrap();
541 for file in worktree.files(false, 0) {
542 let absolute_path = worktree.absolutize(&file.path);
543
544 if let Ok(language) = language_registry
545 .language_for_file(&absolute_path, None)
546 .await
547 {
548 if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
549 && language
550 .grammar()
551 .and_then(|grammar| grammar.embedding_config.as_ref())
552 .is_none()
553 {
554 continue;
555 }
556
557 let path_buf = file.path.to_path_buf();
558 let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
559 let already_stored = stored_mtime
560 .map_or(false, |existing_mtime| existing_mtime == file.mtime);
561
562 if !already_stored {
563 count += 1;
564 *job_count_tx.lock().borrow_mut() += 1;
565 let job_handle = JobHandle {
566 tx: Arc::downgrade(&job_count_tx),
567 };
568 parsing_files_tx
569 .try_send(PendingFile {
570 worktree_db_id: db_ids_by_worktree_id[&worktree.id()],
571 relative_path: path_buf,
572 absolute_path,
573 language,
574 job_handle,
575 modified_time: file.mtime,
576 })
577 .unwrap();
578 }
579 }
580 }
581 for file in file_mtimes.keys() {
582 db_update_tx
583 .try_send(DbOperation::Delete {
584 worktree_id: db_ids_by_worktree_id[&worktree.id()],
585 path: file.to_owned(),
586 })
587 .unwrap();
588 }
589 }
590
591 anyhow::Ok((count, job_count_rx))
592 })
593 .await
594 })
595 }
596
597 pub fn outstanding_job_count_rx(
598 &self,
599 project: &ModelHandle<Project>,
600 ) -> Option<watch::Receiver<usize>> {
601 Some(
602 self.projects
603 .get(&project.downgrade())?
604 .outstanding_job_count_rx
605 .clone(),
606 )
607 }
608
609 pub fn search_project(
610 &mut self,
611 project: ModelHandle<Project>,
612 phrase: String,
613 limit: usize,
614 cx: &mut ModelContext<Self>,
615 ) -> Task<Result<Vec<SearchResult>>> {
616 let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
617 state
618 } else {
619 return Task::ready(Err(anyhow!("project not added")));
620 };
621
622 let worktree_db_ids = project
623 .read(cx)
624 .worktrees(cx)
625 .filter_map(|worktree| {
626 let worktree_id = worktree.read(cx).id();
627 project_state.db_id_for_worktree_id(worktree_id)
628 })
629 .collect::<Vec<_>>();
630
631 let embedding_provider = self.embedding_provider.clone();
632 let database_url = self.database_url.clone();
633 let fs = self.fs.clone();
634 cx.spawn(|this, cx| async move {
635 let documents = cx
636 .background()
637 .spawn(async move {
638 let database = VectorDatabase::new(fs, database_url).await?;
639
640 let phrase_embedding = embedding_provider
641 .embed_batch(vec![&phrase])
642 .await?
643 .into_iter()
644 .next()
645 .unwrap();
646
647 database.top_k_search(&worktree_db_ids, &phrase_embedding, limit)
648 })
649 .await?;
650
651 this.read_with(&cx, |this, _| {
652 let project_state = if let Some(state) = this.projects.get(&project.downgrade()) {
653 state
654 } else {
655 return Err(anyhow!("project not added"));
656 };
657
658 Ok(documents
659 .into_iter()
660 .filter_map(|(worktree_db_id, file_path, byte_range, name)| {
661 let worktree_id = project_state.worktree_id_for_db_id(worktree_db_id)?;
662 Some(SearchResult {
663 worktree_id,
664 name,
665 byte_range,
666 file_path,
667 })
668 })
669 .collect())
670 })
671 })
672 }
673}
674
/// Makes `SemanticIndex` usable as a gpui model; its event type is `()`.
impl Entity for SemanticIndex {
    type Event = ();
}
678
679impl Drop for JobHandle {
680 fn drop(&mut self) {
681 if let Some(tx) = self.tx.upgrade() {
682 let mut tx = tx.lock();
683 *tx.borrow_mut() -= 1;
684 }
685 }
686}