mod chunking;
mod embedding;
mod embedding_index;
mod indexing;
mod project_index;
mod project_index_debug_view;
mod summary_backlog;
mod summary_index;
mod worktree_index;

use anyhow::{Context as _, Result};
use collections::HashMap;
use fs::Fs;
use gpui::{App, AppContext as _, AsyncApp, BorrowAppContext, Context, Entity, Global, WeakEntity};
use language::LineEnding;
use project::{Project, Worktree};
use std::{
    cmp::Ordering,
    path::{Path, PathBuf},
    sync::Arc,
};
use util::ResultExt as _;
use workspace::Workspace;

pub use embedding::*;
pub use project_index::{LoadedSearchResult, ProjectIndex, SearchResult, Status};
pub use project_index_debug_view::ProjectIndexDebugView;
pub use summary_index::FileSummary;

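/// Global handle to semantic search: owns the `heed` database environment,
/// the embedding provider, and one [`ProjectIndex`] per open project.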
pub struct SemanticDb {
    embedding_provider: Arc<dyn EmbeddingProvider>,
    db_connection: Option<heed::Env>,
    project_indices: HashMap<WeakEntity<Project>, Entity<ProjectIndex>>,
}

impl Global for SemanticDb {}

impl SemanticDb {
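    /// Opens (creating it if necessary) the LMDB environment at `db_path` on
    /// the background executor, then registers an observer that builds a
    /// project index for every workspace created afterwards.
    ///
    /// A minimal usage sketch (not compiled; assumes an embedding provider and
    /// an `AsyncApp` handle named `cx` are already in scope):
    ///
    /// ```ignore
    /// let db = SemanticDb::new(db_dir, provider, &mut cx).await?;
    /// cx.update(|cx| cx.set_global(db))?;
    /// ```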
    pub async fn new(
        db_path: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut AsyncApp,
    ) -> Result<Self> {
        let db_connection = cx
            .background_executor()
            .spawn(async move {
                std::fs::create_dir_all(&db_path)?;
                unsafe {
                    heed::EnvOpenOptions::new()
                        .map_size(1024 * 1024 * 1024)
                        .max_dbs(3000)
                        .open(db_path)
                }
            })
            .await
            .context("opening database connection")?;

        cx.update(|cx| {
            cx.observe_new(
                |workspace: &mut Workspace, _window, cx: &mut Context<Workspace>| {
                    let project = workspace.project().clone();

                    if cx.has_global::<SemanticDb>() {
                        cx.update_global::<SemanticDb, _>(|this, cx| {
                            this.create_project_index(project, cx);
                        })
                    } else {
                        log::info!("No SemanticDb, skipping project index")
                    }
                },
            )
            .detach();
        })
        .ok();

        Ok(SemanticDb {
            db_connection: Some(db_connection),
            embedding_provider,
            project_indices: HashMap::default(),
        })
    }

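    /// Loads the file contents behind a batch of search results. Hits are
    /// deduplicated per file (keeping each file's best score and query index),
    /// sorted by that score, snapped to line boundaries, and adjacent excerpts
    /// from the same file are merged into one result.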
    pub async fn load_results(
        mut results: Vec<SearchResult>,
        fs: &Arc<dyn Fs>,
        cx: &AsyncApp,
    ) -> Result<Vec<LoadedSearchResult>> {
        let mut max_scores_by_path = HashMap::<_, (f32, usize)>::default();
        for result in &results {
            let (score, query_index) = max_scores_by_path
                .entry((result.worktree.clone(), result.path.clone()))
                .or_default();
            if result.score > *score {
                *score = result.score;
                *query_index = result.query_index;
            }
        }

        results.sort_by(|a, b| {
            let max_score_a = max_scores_by_path[&(a.worktree.clone(), a.path.clone())].0;
            let max_score_b = max_scores_by_path[&(b.worktree.clone(), b.path.clone())].0;
            max_score_b
                .partial_cmp(&max_score_a)
                .unwrap_or(Ordering::Equal)
                .then_with(|| a.worktree.entity_id().cmp(&b.worktree.entity_id()))
                .then_with(|| a.path.cmp(&b.path))
                .then_with(|| a.range.start.cmp(&b.range.start))
        });

        let mut last_loaded_file: Option<(Entity<Worktree>, Arc<Path>, PathBuf, String)> = None;
        let mut loaded_results = Vec::<LoadedSearchResult>::new();
        for result in results {
            let full_path;
            let file_content;
            if let Some(last_loaded_file) =
                last_loaded_file
                    .as_ref()
                    .filter(|(last_worktree, last_path, _, _)| {
                        last_worktree == &result.worktree && last_path == &result.path
                    })
            {
                full_path = last_loaded_file.2.clone();
                file_content = &last_loaded_file.3;
            } else {
                let output = result.worktree.read_with(cx, |worktree, _cx| {
                    let entry_abs_path = worktree.abs_path().join(&result.path);
                    let mut entry_full_path = PathBuf::from(worktree.root_name());
                    entry_full_path.push(&result.path);
                    let file_content = async {
                        let entry_abs_path = entry_abs_path;
                        fs.load(&entry_abs_path).await
                    };
                    (entry_full_path, file_content)
                })?;
                full_path = output.0;
                let Some(content) = output.1.await.log_err() else {
                    continue;
                };
                last_loaded_file = Some((
                    result.worktree.clone(),
                    result.path.clone(),
                    full_path.clone(),
                    content,
                ));
                file_content = &last_loaded_file.as_ref().unwrap().3;
            };

            let query_index =
                max_scores_by_path[&(result.worktree.clone(), result.path.clone())].1;

            // Clamp the result's byte range to the file's length and snap it
            // to char boundaries.
            let mut range_start = result.range.start.min(file_content.len());
            let mut range_end = result.range.end.min(file_content.len());
            while !file_content.is_char_boundary(range_start) {
                range_start += 1;
            }
            while !file_content.is_char_boundary(range_end) {
                range_end += 1;
            }

            // Expand the range to cover whole lines.
            let start_row = file_content[0..range_start].matches('\n').count() as u32;
            let mut end_row = file_content[0..range_end].matches('\n').count() as u32;
            let start_line_byte_offset = file_content[0..range_start]
                .rfind('\n')
                .map(|pos| pos + 1)
                .unwrap_or_default();
            let mut end_line_byte_offset = range_end;
            if file_content[..end_line_byte_offset].ends_with('\n') {
                end_row -= 1;
            } else {
                end_line_byte_offset = file_content[range_end..]
                    .find('\n')
                    .map(|pos| range_end + pos + 1)
                    .unwrap_or_else(|| file_content.len());
            }
            let mut excerpt_content =
                file_content[start_line_byte_offset..end_line_byte_offset].to_string();
            LineEnding::normalize(&mut excerpt_content);

            // Merge with the previous excerpt when it ends on the row just
            // above this one in the same file.
            if let Some(prev_result) = loaded_results.last_mut() {
                if prev_result.full_path == full_path {
                    if *prev_result.row_range.end() + 1 == start_row {
                        prev_result.row_range = *prev_result.row_range.start()..=end_row;
                        prev_result.excerpt_content.push_str(&excerpt_content);
                        continue;
                    }
                }
            }

            loaded_results.push(LoadedSearchResult {
                path: result.path,
                full_path,
                excerpt_content,
                row_range: start_row..=end_row,
                query_index,
            });
        }

        // Trim trailing blank lines from each merged excerpt.
        for result in &mut loaded_results {
            while result.excerpt_content.ends_with("\n\n") {
                result.excerpt_content.pop();
                result.row_range =
                    *result.row_range.start()..=result.row_range.end().saturating_sub(1)
            }
        }

        Ok(loaded_results)
    }

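    /// Returns the existing [`ProjectIndex`] for `project`, if one has been
    /// created.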
    pub fn project_index(
        &mut self,
        project: Entity<Project>,
        _cx: &mut App,
    ) -> Option<Entity<ProjectIndex>> {
        self.project_indices.get(&project.downgrade()).cloned()
    }

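    /// Returns the number of files still awaiting summarization in the given
    /// project's index, or `None` if the project has no index.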
    pub fn remaining_summaries(
        &self,
        project: &WeakEntity<Project>,
        cx: &mut App,
    ) -> Option<usize> {
        self.project_indices.get(project).map(|project_index| {
            project_index.update(cx, |project_index, cx| {
                project_index.remaining_summaries(cx)
            })
        })
    }

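    /// Creates and registers a [`ProjectIndex`] for `project`. The entry is
    /// removed automatically when the project entity is released.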
    pub fn create_project_index(
        &mut self,
        project: Entity<Project>,
        cx: &mut App,
    ) -> Entity<ProjectIndex> {
        let project_index = cx.new(|cx| {
            ProjectIndex::new(
                project.clone(),
                self.db_connection.clone().unwrap(),
                self.embedding_provider.clone(),
                cx,
            )
        });

        let project_weak = project.downgrade();
        self.project_indices
            .insert(project_weak.clone(), project_index.clone());

        cx.observe_release(&project, move |_, cx| {
            if cx.has_global::<SemanticDb>() {
                cx.update_global::<SemanticDb, _>(|this, _| {
                    this.project_indices.remove(&project_weak);
                })
            }
        })
        .detach();

        project_index
    }
}

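// Explicitly prepare the LMDB environment for closing instead of relying on
// an implicit drop of the `heed::Env`.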
impl Drop for SemanticDb {
    fn drop(&mut self) {
        self.db_connection.take().unwrap().prepare_for_closing();
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use chunking::Chunk;
    use embedding_index::{ChunkedFile, EmbeddingIndex};
    use feature_flags::FeatureFlagAppExt;
    use fs::FakeFs;
    use futures::{future::BoxFuture, FutureExt};
    use gpui::TestAppContext;
    use indexing::IndexingEntrySet;
    use language::language_settings::AllLanguageSettings;
    use project::{Project, ProjectEntryId};
    use serde_json::json;
    use settings::SettingsStore;
    use smol::channel;
    use std::{future, path::Path, sync::Arc};
    use util::separator;

    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();

        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            language::init(cx);
            cx.update_flags(false, vec![]);
            Project::init_settings(cx);
            SettingsStore::update(cx, |store, cx| {
                store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
            });
        });
    }

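    /// Test-only embedding provider: embeddings come from a caller-supplied
    /// closure, so tests are deterministic and need no network access.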
    pub struct TestEmbeddingProvider {
        batch_size: usize,
        compute_embedding: Box<dyn Fn(&str) -> Result<Embedding> + Send + Sync>,
    }

    impl TestEmbeddingProvider {
        pub fn new(
            batch_size: usize,
            compute_embedding: impl 'static + Fn(&str) -> Result<Embedding> + Send + Sync,
        ) -> Self {
            Self {
                batch_size,
                compute_embedding: Box::new(compute_embedding),
            }
        }
    }

    impl EmbeddingProvider for TestEmbeddingProvider {
        fn embed<'a>(
            &'a self,
            texts: &'a [TextToEmbed<'a>],
        ) -> BoxFuture<'a, Result<Vec<Embedding>>> {
            let embeddings = texts
                .iter()
                .map(|to_embed| (self.compute_embedding)(to_embed.text))
                .collect();
            future::ready(embeddings).boxed()
        }

        fn batch_size(&self) -> usize {
            self.batch_size
        }
    }

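    // End-to-end check: index a small fixture project, then verify that a
    // semantic query surfaces the fixture file containing the queried phrase.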
    #[gpui::test]
    async fn test_search(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        init_test(cx);

        cx.update(|cx| {
            // This functionality is staff-flagged.
            cx.update_flags(true, vec![]);
        });

        let temp_dir = tempfile::tempdir().unwrap();

        let mut semantic_index = SemanticDb::new(
            temp_dir.path().into(),
            Arc::new(TestEmbeddingProvider::new(16, |text| {
                let mut embedding = vec![0f32; 2];
                // If the text contains "garbage in", score the first dimension
                // high (0.9); otherwise score it low.
                if text.contains("garbage in") {
                    embedding[0] = 0.9;
                } else {
                    embedding[0] = -0.9;
                }

                if text.contains("garbage out") {
                    embedding[1] = 0.9;
                } else {
                    embedding[1] = -0.9;
                }

                Ok(Embedding::new(embedding))
            })),
            &mut cx.to_async(),
        )
        .await
        .unwrap();

        let fs = FakeFs::new(cx.executor());
        let project_path = Path::new("/fake_project");

        fs.insert_tree(
            project_path,
            json!({
                "fixture": {
                    "main.rs": include_str!("../fixture/main.rs"),
                    "needle.md": include_str!("../fixture/needle.md"),
                }
            }),
        )
        .await;

        let project = Project::test(fs, [project_path], cx).await;

        let project_index = cx.update(|cx| {
            let language_registry = project.read(cx).languages().clone();
            let node_runtime = project.read(cx).node_runtime().unwrap().clone();
            languages::init(language_registry, node_runtime, cx);
            semantic_index.create_project_index(project.clone(), cx)
        });

        cx.run_until_parked();
        while cx
            .update(|cx| semantic_index.remaining_summaries(&project.downgrade(), cx))
            .unwrap()
            > 0
        {
            cx.run_until_parked();
        }

        let results = cx
            .update(|cx| {
                let project_index = project_index.read(cx);
                let query = "garbage in, garbage out";
                project_index.search(vec![query.into()], 4, cx)
            })
            .await
            .unwrap();

        assert!(
            results.len() > 1,
            "should have found more than one result, but only found {:?}",
            results
        );

        for result in &results {
            println!("result: {:?}", result.path);
            println!("score: {:?}", result.score);
        }

        // Find a result whose score is greater than 0.9.
        let search_result = results.iter().find(|result| result.score > 0.9).unwrap();

        assert_eq!(
            search_result.path.to_string_lossy(),
            separator!("fixture/needle.md")
        );

        let content = cx
            .update(|cx| {
                let worktree = search_result.worktree.read(cx);
                let entry_abs_path = worktree.abs_path().join(&search_result.path);
                let fs = project.read(cx).fs().clone();
                cx.background_executor()
                    .spawn(async move { fs.load(&entry_abs_path).await.unwrap() })
            })
            .await;

        let range = search_result.range.clone();
        let content = content[range.clone()].to_owned();

        assert!(content.contains("garbage in, garbage out"));
    }

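    // Verifies that embedding failures are per-file: the file whose text makes
    // the provider error ("test1.md" contains 'g') is dropped, while the other
    // file's chunks are embedded and emitted.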
    #[gpui::test]
    async fn test_embed_files(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let provider = Arc::new(TestEmbeddingProvider::new(3, |text| {
            if text.contains('g') {
                Err(anyhow!("cannot embed text containing a 'g' character"))
            } else {
                Ok(Embedding::new(
                    ('a'..='z')
                        .map(|char| text.chars().filter(|c| *c == char).count() as f32)
                        .collect(),
                ))
            }
        }));

        let (indexing_progress_tx, _) = channel::unbounded();
        let indexing_entries = Arc::new(IndexingEntrySet::new(indexing_progress_tx));

        let (chunked_files_tx, chunked_files_rx) = channel::unbounded::<ChunkedFile>();
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test1.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(0)),
                text: "abcdefghijklmnop".to_string(),
                chunks: [0..4, 4..8, 8..12, 12..16]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test2.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(1)),
                text: "qrstuvwxyz".to_string(),
                chunks: [0..4, 4..8, 8..10]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        chunked_files_tx.close();

        let embed_files_task =
            cx.update(|cx| EmbeddingIndex::embed_files(provider.clone(), chunked_files_rx, cx));
        embed_files_task.task.await.unwrap();

        let embedded_files_rx = embed_files_task.files;
        let mut embedded_files = Vec::new();
        while let Ok((embedded_file, _)) = embedded_files_rx.recv().await {
            embedded_files.push(embedded_file);
        }

        assert_eq!(embedded_files.len(), 1);
        assert_eq!(embedded_files[0].path.as_ref(), Path::new("test2.md"));
        assert_eq!(
            embedded_files[0]
                .chunks
                .iter()
                .map(|embedded_chunk| { embedded_chunk.embedding.clone() })
                .collect::<Vec<Embedding>>(),
            vec![
                (provider.compute_embedding)("qrst").unwrap(),
                (provider.compute_embedding)("uvwx").unwrap(),
                (provider.compute_embedding)("yz").unwrap(),
            ],
        );
    }

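    // Exercises SemanticDb::load_results directly: newline-aligned chunks,
    // non-aligned chunks that get expanded to whole lines, and adjacent
    // chunks that get merged into a single excerpt.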
    #[gpui::test]
    async fn test_load_search_results(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project_path = Path::new("/fake_project");

        let file1_content = "one\ntwo\nthree\nfour\nfive\n";
        let file2_content = "aaa\nbbb\nccc\nddd\neee\n";

        fs.insert_tree(
            project_path,
            json!({
                "file1.txt": file1_content,
                "file2.txt": file2_content,
            }),
        )
        .await;

        let fs = fs as Arc<dyn Fs>;
        let project = Project::test(fs.clone(), [project_path], cx).await;
        let worktree = project.read_with(cx, |project, cx| project.worktrees(cx).next().unwrap());

        // A chunk that is already newline-aligned.
        let search_results = vec![SearchResult {
            worktree: worktree.clone(),
            path: Path::new("file1.txt").into(),
            range: 0..file1_content.find("four").unwrap(),
            score: 0.5,
            query_index: 0,
        }];
        assert_eq!(
            SemanticDb::load_results(search_results, &fs, &cx.to_async())
                .await
                .unwrap(),
            &[LoadedSearchResult {
                path: Path::new("file1.txt").into(),
                full_path: "fake_project/file1.txt".into(),
                excerpt_content: "one\ntwo\nthree\n".into(),
                row_range: 0..=2,
                query_index: 0,
            }]
        );

        // A chunk that is *not* newline-aligned.
        let search_results = vec![SearchResult {
            worktree: worktree.clone(),
            path: Path::new("file1.txt").into(),
            range: file1_content.find("two").unwrap() + 1..file1_content.find("four").unwrap() + 2,
            score: 0.5,
            query_index: 0,
        }];
        assert_eq!(
            SemanticDb::load_results(search_results, &fs, &cx.to_async())
                .await
                .unwrap(),
            &[LoadedSearchResult {
                path: Path::new("file1.txt").into(),
                full_path: "fake_project/file1.txt".into(),
                excerpt_content: "two\nthree\nfour\n".into(),
                row_range: 1..=3,
                query_index: 0,
            }]
        );

        // Chunks that are adjacent.
        let search_results = vec![
            SearchResult {
                worktree: worktree.clone(),
                path: Path::new("file1.txt").into(),
                range: file1_content.find("two").unwrap()..file1_content.len(),
                score: 0.6,
                query_index: 0,
            },
            SearchResult {
                worktree: worktree.clone(),
                path: Path::new("file1.txt").into(),
                range: 0..file1_content.find("two").unwrap(),
                score: 0.5,
                query_index: 1,
            },
            SearchResult {
                worktree: worktree.clone(),
                path: Path::new("file2.txt").into(),
                range: 0..file2_content.len(),
                score: 0.8,
                query_index: 1,
            },
        ];
        assert_eq!(
            SemanticDb::load_results(search_results, &fs, &cx.to_async())
                .await
                .unwrap(),
            &[
                LoadedSearchResult {
                    path: Path::new("file2.txt").into(),
                    full_path: "fake_project/file2.txt".into(),
                    excerpt_content: file2_content.into(),
                    row_range: 0..=4,
                    query_index: 1,
                },
                LoadedSearchResult {
                    path: Path::new("file1.txt").into(),
                    full_path: "fake_project/file1.txt".into(),
                    excerpt_content: file1_content.into(),
                    row_range: 0..=4,
                    query_index: 0,
                }
            ]
        );
    }
}