mod chunking;
mod embedding;
mod embedding_index;
mod indexing;
mod project_index;
mod project_index_debug_view;
mod summary_backlog;
mod summary_index;
mod worktree_index;

use anyhow::{Context as _, Result};
use collections::HashMap;
use fs::Fs;
use gpui::{App, AppContext as _, AsyncApp, BorrowAppContext, Context, Entity, Global, WeakEntity};
use language::LineEnding;
use project::{Project, Worktree};
use std::{
    cmp::Ordering,
    path::{Path, PathBuf},
    sync::Arc,
};
use util::ResultExt as _;
use workspace::Workspace;

pub use embedding::*;
pub use project_index::{LoadedSearchResult, ProjectIndex, SearchResult, Status};
pub use project_index_debug_view::ProjectIndexDebugView;
pub use summary_index::FileSummary;

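/// The global semantic index, which owns the database connection, the
/// embedding provider, and one `ProjectIndex` per open project.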
pub struct SemanticDb {
    embedding_provider: Arc<dyn EmbeddingProvider>,
    db_connection: Option<heed::Env>,
    project_indices: HashMap<WeakEntity<Project>, Entity<ProjectIndex>>,
}

impl Global for SemanticDb {}

impl SemanticDb {
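    /// Opens (creating it if necessary) the `heed` database at `db_path` on the
    /// background executor, and registers an observer that builds a project
    /// index for every workspace created while the global is set.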
    pub async fn new(
        db_path: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut AsyncApp,
    ) -> Result<Self> {
        let db_connection = cx
            .background_spawn(async move {
                std::fs::create_dir_all(&db_path)?;
                unsafe {
                    heed::EnvOpenOptions::new()
                        .map_size(1024 * 1024 * 1024)
                        .max_dbs(3000)
                        .open(db_path)
                }
            })
            .await
            .context("opening database connection")?;

        cx.update(|cx| {
            cx.observe_new(
                |workspace: &mut Workspace, _window, cx: &mut Context<Workspace>| {
                    let project = workspace.project().clone();

                    if cx.has_global::<SemanticDb>() {
                        cx.update_global::<SemanticDb, _>(|this, cx| {
                            this.create_project_index(project, cx);
                        })
                    } else {
                        log::info!("No SemanticDb, skipping project index")
                    }
                },
            )
            .detach();
        })
        .ok();

        Ok(SemanticDb {
            db_connection: Some(db_connection),
            embedding_provider,
            project_indices: HashMap::default(),
        })
    }

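    /// Loads the file contents for a batch of search results, expands each
    /// result's byte range to whole lines, and merges adjacent excerpts from
    /// the same file into a single `LoadedSearchResult`.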
    pub async fn load_results(
        mut results: Vec<SearchResult>,
        fs: &Arc<dyn Fs>,
        cx: &AsyncApp,
    ) -> Result<Vec<LoadedSearchResult>> {
        let mut max_scores_by_path = HashMap::<_, (f32, usize)>::default();
        for result in &results {
            let (score, query_index) = max_scores_by_path
                .entry((result.worktree.clone(), result.path.clone()))
                .or_default();
            if result.score > *score {
                *score = result.score;
                *query_index = result.query_index;
            }
        }

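        // Sort the results so that files with the best-scoring matches come
        // first, and so that results within a file are grouped together and
        // ordered by their position in the file.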
        results.sort_by(|a, b| {
            let max_score_a = max_scores_by_path[&(a.worktree.clone(), a.path.clone())].0;
            let max_score_b = max_scores_by_path[&(b.worktree.clone(), b.path.clone())].0;
            max_score_b
                .partial_cmp(&max_score_a)
                .unwrap_or(Ordering::Equal)
                .then_with(|| a.worktree.entity_id().cmp(&b.worktree.entity_id()))
                .then_with(|| a.path.cmp(&b.path))
                .then_with(|| a.range.start.cmp(&b.range.start))
        });

        let mut last_loaded_file: Option<(Entity<Worktree>, Arc<Path>, PathBuf, String)> = None;
        let mut loaded_results = Vec::<LoadedSearchResult>::new();
        for result in results {
            let full_path;
            let file_content;
            if let Some(last_loaded_file) =
                last_loaded_file
                    .as_ref()
                    .filter(|(last_worktree, last_path, _, _)| {
                        last_worktree == &result.worktree && last_path == &result.path
                    })
            {
                full_path = last_loaded_file.2.clone();
                file_content = &last_loaded_file.3;
            } else {
                let output = result.worktree.read_with(cx, |worktree, _cx| {
                    let entry_abs_path = worktree.abs_path().join(&result.path);
                    let mut entry_full_path = PathBuf::from(worktree.root_name());
                    entry_full_path.push(&result.path);
                    let file_content = async {
                        let entry_abs_path = entry_abs_path;
                        fs.load(&entry_abs_path).await
                    };
                    (entry_full_path, file_content)
                })?;
                full_path = output.0;
                let Some(content) = output.1.await.log_err() else {
                    continue;
                };
                last_loaded_file = Some((
                    result.worktree.clone(),
                    result.path.clone(),
                    full_path.clone(),
                    content,
                ));
                file_content = &last_loaded_file.as_ref().unwrap().3;
            };

            let query_index = max_scores_by_path[&(result.worktree.clone(), result.path.clone())].1;

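            // The indexed range may be stale if the file changed on disk since
            // it was indexed, so clamp it to the current length and snap it
            // forward to valid char boundaries.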
            let mut range_start = result.range.start.min(file_content.len());
            let mut range_end = result.range.end.min(file_content.len());
            while !file_content.is_char_boundary(range_start) {
                range_start += 1;
            }
            while !file_content.is_char_boundary(range_end) {
                range_end += 1;
            }

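            // Expand the excerpt to cover whole lines, tracking the row range
            // it spans.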
            let start_row = file_content[0..range_start].matches('\n').count() as u32;
            let mut end_row = file_content[0..range_end].matches('\n').count() as u32;
            let start_line_byte_offset = file_content[0..range_start]
                .rfind('\n')
                .map(|pos| pos + 1)
                .unwrap_or_default();
            let mut end_line_byte_offset = range_end;
            if file_content[..end_line_byte_offset].ends_with('\n') {
                end_row -= 1;
            } else {
                end_line_byte_offset = file_content[range_end..]
                    .find('\n')
                    .map(|pos| range_end + pos + 1)
                    .unwrap_or_else(|| file_content.len());
            }
            let mut excerpt_content =
                file_content[start_line_byte_offset..end_line_byte_offset].to_string();
            LineEnding::normalize(&mut excerpt_content);

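            // If this excerpt starts on the line right after the previous
            // excerpt from the same file, extend that result instead of
            // emitting a new one.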
            if let Some(prev_result) = loaded_results.last_mut() {
                if prev_result.full_path == full_path {
                    if *prev_result.row_range.end() + 1 == start_row {
                        prev_result.row_range = *prev_result.row_range.start()..=end_row;
                        prev_result.excerpt_content.push_str(&excerpt_content);
                        continue;
                    }
                }
            }

            loaded_results.push(LoadedSearchResult {
                path: result.path,
                full_path,
                excerpt_content,
                row_range: start_row..=end_row,
                query_index,
            });
        }

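        // Trim trailing blank lines from each excerpt, shrinking its row range
        // to match.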
        for result in &mut loaded_results {
            while result.excerpt_content.ends_with("\n\n") {
                result.excerpt_content.pop();
                result.row_range =
                    *result.row_range.start()..=result.row_range.end().saturating_sub(1)
            }
        }

        Ok(loaded_results)
    }

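    /// Returns the existing index for the given project, if one has been created.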
    pub fn project_index(
        &mut self,
        project: Entity<Project>,
        _cx: &mut App,
    ) -> Option<Entity<ProjectIndex>> {
        self.project_indices.get(&project.downgrade()).cloned()
    }

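    /// Returns the number of files in the project's index that still need to
    /// be summarized, if the project has an index.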
    pub fn remaining_summaries(
        &self,
        project: &WeakEntity<Project>,
        cx: &mut App,
    ) -> Option<usize> {
        self.project_indices.get(project).map(|project_index| {
            project_index.update(cx, |project_index, cx| {
                project_index.remaining_summaries(cx)
            })
        })
    }

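    /// Creates an index for the given project, and removes it when the project
    /// is released.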
    pub fn create_project_index(
        &mut self,
        project: Entity<Project>,
        cx: &mut App,
    ) -> Entity<ProjectIndex> {
        let project_index = cx.new(|cx| {
            ProjectIndex::new(
                project.clone(),
                self.db_connection.clone().unwrap(),
                self.embedding_provider.clone(),
                cx,
            )
        });

        let project_weak = project.downgrade();
        self.project_indices
            .insert(project_weak.clone(), project_index.clone());

        cx.observe_release(&project, move |_, cx| {
            if cx.has_global::<SemanticDb>() {
                cx.update_global::<SemanticDb, _>(|this, _| {
                    this.project_indices.remove(&project_weak);
                })
            }
        })
        .detach();

        project_index
    }
}

impl Drop for SemanticDb {
    fn drop(&mut self) {
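        // Tell heed that the environment can be closed once it is no longer
        // in use.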
        self.db_connection.take().unwrap().prepare_for_closing();
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use chunking::Chunk;
    use embedding_index::{ChunkedFile, EmbeddingIndex};
    use feature_flags::FeatureFlagAppExt;
    use fs::FakeFs;
    use futures::{FutureExt, future::BoxFuture};
    use gpui::TestAppContext;
    use indexing::IndexingEntrySet;
    use language::language_settings::AllLanguageSettings;
    use project::{Project, ProjectEntryId};
    use serde_json::json;
    use settings::SettingsStore;
    use smol::channel;
    use std::{future, path::Path, sync::Arc};
    use util::separator;

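    /// Sets up the global state (settings, languages, feature flags) that the
    /// index depends on in tests.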
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();

        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            language::init(cx);
            cx.update_flags(false, vec![]);
            Project::init_settings(cx);
            SettingsStore::update(cx, |store, cx| {
                store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
            });
        });
    }

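    /// An embedding provider for tests that computes embeddings with a
    /// caller-supplied function and reports a fixed batch size.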
    pub struct TestEmbeddingProvider {
        batch_size: usize,
        compute_embedding: Box<dyn Fn(&str) -> Result<Embedding> + Send + Sync>,
    }

    impl TestEmbeddingProvider {
        pub fn new(
            batch_size: usize,
            compute_embedding: impl 'static + Fn(&str) -> Result<Embedding> + Send + Sync,
        ) -> Self {
            Self {
                batch_size,
                compute_embedding: Box::new(compute_embedding),
            }
        }
    }

    impl EmbeddingProvider for TestEmbeddingProvider {
        fn embed<'a>(
            &'a self,
            texts: &'a [TextToEmbed<'a>],
        ) -> BoxFuture<'a, Result<Vec<Embedding>>> {
            let embeddings = texts
                .iter()
                .map(|to_embed| (self.compute_embedding)(to_embed.text))
                .collect();
            future::ready(embeddings).boxed()
        }

        fn batch_size(&self) -> usize {
            self.batch_size
        }
    }

    #[gpui::test]
    async fn test_search(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        init_test(cx);

        cx.update(|cx| {
            // This functionality is staff-flagged.
            cx.update_flags(true, vec![]);
        });

        let temp_dir = tempfile::tempdir().unwrap();

        let mut semantic_index = SemanticDb::new(
            temp_dir.path().into(),
            Arc::new(TestEmbeddingProvider::new(16, |text| {
                let mut embedding = vec![0f32; 2];
                // Texts containing "garbage in" score high in the first
                // dimension, and texts containing "garbage out" score high in
                // the second.
                if text.contains("garbage in") {
                    embedding[0] = 0.9;
                } else {
                    embedding[0] = -0.9;
                }

                if text.contains("garbage out") {
                    embedding[1] = 0.9;
                } else {
                    embedding[1] = -0.9;
                }

                Ok(Embedding::new(embedding))
            })),
            &mut cx.to_async(),
        )
        .await
        .unwrap();

        let fs = FakeFs::new(cx.executor());
        let project_path = Path::new("/fake_project");

        fs.insert_tree(
            project_path,
            json!({
                "fixture": {
                    "main.rs": include_str!("../fixture/main.rs"),
                    "needle.md": include_str!("../fixture/needle.md"),
                }
            }),
        )
        .await;

        let project = Project::test(fs, [project_path], cx).await;

        let project_index = cx.update(|cx| {
            let language_registry = project.read(cx).languages().clone();
            let node_runtime = project.read(cx).node_runtime().unwrap().clone();
            languages::init(language_registry, node_runtime, cx);
            semantic_index.create_project_index(project.clone(), cx)
        });

        cx.run_until_parked();
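        // Wait for the indexer to drain its summary backlog before searching.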
        while cx
            .update(|cx| semantic_index.remaining_summaries(&project.downgrade(), cx))
            .unwrap()
            > 0
        {
            cx.run_until_parked();
        }

        let results = cx
            .update(|cx| {
                let project_index = project_index.read(cx);
                let query = "garbage in, garbage out";
                project_index.search(vec![query.into()], 4, cx)
            })
            .await
            .unwrap();

        assert!(
            results.len() > 1,
            "should have found multiple results, but only found {:?}",
            results
        );

        for result in &results {
            println!("result: {:?}", result.path);
            println!("score: {:?}", result.score);
        }

        // Find the result whose score exceeds 0.9.
        let search_result = results.iter().find(|result| result.score > 0.9).unwrap();

        assert_eq!(
            search_result.path.to_string_lossy(),
            separator!("fixture/needle.md")
        );

        let content = cx
            .update(|cx| {
                let worktree = search_result.worktree.read(cx);
                let entry_abs_path = worktree.abs_path().join(&search_result.path);
                let fs = project.read(cx).fs().clone();
                cx.background_spawn(async move { fs.load(&entry_abs_path).await.unwrap() })
            })
            .await;

        let range = search_result.range.clone();
        let content = content[range].to_owned();

        assert!(content.contains("garbage in, garbage out"));
    }

    #[gpui::test]
    async fn test_embed_files(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let provider = Arc::new(TestEmbeddingProvider::new(3, |text| {
            if text.contains('g') {
                Err(anyhow!("cannot embed text containing a 'g' character"))
            } else {
                Ok(Embedding::new(
                    ('a'..='z')
                        .map(|char| text.chars().filter(|c| *c == char).count() as f32)
                        .collect(),
                ))
            }
        }));

        let (indexing_progress_tx, _) = channel::unbounded();
        let indexing_entries = Arc::new(IndexingEntrySet::new(indexing_progress_tx));

        let (chunked_files_tx, chunked_files_rx) = channel::unbounded::<ChunkedFile>();
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test1.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(0)),
                text: "abcdefghijklmnop".to_string(),
                chunks: [0..4, 4..8, 8..12, 12..16]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test2.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(1)),
                text: "qrstuvwxyz".to_string(),
                chunks: [0..4, 4..8, 8..10]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        chunked_files_tx.close();

        let embed_files_task =
            cx.update(|cx| EmbeddingIndex::embed_files(provider.clone(), chunked_files_rx, cx));
        embed_files_task.task.await.unwrap();

        let embedded_files_rx = embed_files_task.files;
        let mut embedded_files = Vec::new();
        while let Ok((embedded_file, _)) = embedded_files_rx.recv().await {
            embedded_files.push(embedded_file);
        }

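        // The provider fails on any text containing 'g', so test1.md
        // ("abcdefghijklmnop") is dropped and only test2.md is embedded.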
        assert_eq!(embedded_files.len(), 1);
        assert_eq!(embedded_files[0].path.as_ref(), Path::new("test2.md"));
        assert_eq!(
            embedded_files[0]
                .chunks
                .iter()
                .map(|embedded_chunk| { embedded_chunk.embedding.clone() })
                .collect::<Vec<Embedding>>(),
            vec![
                (provider.compute_embedding)("qrst").unwrap(),
                (provider.compute_embedding)("uvwx").unwrap(),
                (provider.compute_embedding)("yz").unwrap(),
            ],
        );
    }

    #[gpui::test]
    async fn test_load_search_results(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project_path = Path::new("/fake_project");

        let file1_content = "one\ntwo\nthree\nfour\nfive\n";
        let file2_content = "aaa\nbbb\nccc\nddd\neee\n";

        fs.insert_tree(
            project_path,
            json!({
                "file1.txt": file1_content,
                "file2.txt": file2_content,
            }),
        )
        .await;

        let fs = fs as Arc<dyn Fs>;
        let project = Project::test(fs.clone(), [project_path], cx).await;
        let worktree = project.read_with(cx, |project, cx| project.worktrees(cx).next().unwrap());

        // chunk that is already newline-aligned
        let search_results = vec![SearchResult {
            worktree: worktree.clone(),
            path: Path::new("file1.txt").into(),
            range: 0..file1_content.find("four").unwrap(),
            score: 0.5,
            query_index: 0,
        }];
        assert_eq!(
            SemanticDb::load_results(search_results, &fs, &cx.to_async())
                .await
                .unwrap(),
            &[LoadedSearchResult {
                path: Path::new("file1.txt").into(),
                full_path: "fake_project/file1.txt".into(),
                excerpt_content: "one\ntwo\nthree\n".into(),
                row_range: 0..=2,
                query_index: 0,
            }]
        );

        // chunk that is *not* newline-aligned
        let search_results = vec![SearchResult {
            worktree: worktree.clone(),
            path: Path::new("file1.txt").into(),
            range: file1_content.find("two").unwrap() + 1..file1_content.find("four").unwrap() + 2,
            score: 0.5,
            query_index: 0,
        }];
        assert_eq!(
            SemanticDb::load_results(search_results, &fs, &cx.to_async())
                .await
                .unwrap(),
            &[LoadedSearchResult {
                path: Path::new("file1.txt").into(),
                full_path: "fake_project/file1.txt".into(),
                excerpt_content: "two\nthree\nfour\n".into(),
                row_range: 1..=3,
                query_index: 0,
            }]
        );

        // chunks that are adjacent
        let search_results = vec![
            SearchResult {
                worktree: worktree.clone(),
                path: Path::new("file1.txt").into(),
                range: file1_content.find("two").unwrap()..file1_content.len(),
                score: 0.6,
                query_index: 0,
            },
            SearchResult {
                worktree: worktree.clone(),
                path: Path::new("file1.txt").into(),
                range: 0..file1_content.find("two").unwrap(),
                score: 0.5,
                query_index: 1,
            },
            SearchResult {
                worktree: worktree.clone(),
                path: Path::new("file2.txt").into(),
                range: 0..file2_content.len(),
                score: 0.8,
                query_index: 1,
            },
        ];
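        // file2.txt has the higher maximum score, so it sorts first; the two
        // file1.txt chunks are adjacent and merge into one excerpt spanning
        // the whole file.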
        assert_eq!(
            SemanticDb::load_results(search_results, &fs, &cx.to_async())
                .await
                .unwrap(),
            &[
                LoadedSearchResult {
                    path: Path::new("file2.txt").into(),
                    full_path: "fake_project/file2.txt".into(),
                    excerpt_content: file2_content.into(),
                    row_range: 0..=4,
                    query_index: 1,
                },
                LoadedSearchResult {
                    path: Path::new("file1.txt").into(),
                    full_path: "fake_project/file1.txt".into(),
                    excerpt_content: file1_content.into(),
                    row_range: 0..=4,
                    query_index: 0,
                }
            ]
        );
    }
}