semantic_index.rs

mod chunking;
mod embedding;
mod embedding_index;
mod indexing;
mod project_index;
mod project_index_debug_view;
mod summary_backlog;
mod summary_index;
mod worktree_index;

use anyhow::{Context as _, Result};
use collections::HashMap;
use fs::Fs;
use gpui::{App, AppContext as _, AsyncApp, BorrowAppContext, Context, Entity, Global, WeakEntity};
use language::LineEnding;
use project::{Project, Worktree};
use std::{
    cmp::Ordering,
    path::{Path, PathBuf},
    sync::Arc,
};
use util::ResultExt as _;
use workspace::Workspace;

pub use embedding::*;
pub use project_index::{LoadedSearchResult, ProjectIndex, SearchResult, Status};
pub use project_index_debug_view::ProjectIndexDebugView;
pub use summary_index::FileSummary;

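/// Global state for semantic search: owns the embedding provider, the
/// database connection, and one `ProjectIndex` per open project.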
pub struct SemanticDb {
    embedding_provider: Arc<dyn EmbeddingProvider>,
    db_connection: Option<heed::Env>,
    project_indices: HashMap<WeakEntity<Project>, Entity<ProjectIndex>>,
}

impl Global for SemanticDb {}

impl SemanticDb {
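    /// Opens (or creates) the database at `db_path` on the background executor
    /// and registers an observer that creates a project index for each new
    /// workspace's project.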
    pub async fn new(
        db_path: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut AsyncApp,
    ) -> Result<Self> {
        let db_connection = cx
            .background_executor()
            .spawn(async move {
                std::fs::create_dir_all(&db_path)?;
                unsafe {
                    heed::EnvOpenOptions::new()
                        .map_size(1024 * 1024 * 1024)
                        .max_dbs(3000)
                        .open(db_path)
                }
            })
            .await
            .context("opening database connection")?;

        cx.update(|cx| {
            cx.observe_new(
                |workspace: &mut Workspace, _window, cx: &mut Context<Workspace>| {
                    let project = workspace.project().clone();

                    if cx.has_global::<SemanticDb>() {
                        cx.update_global::<SemanticDb, _>(|this, cx| {
                            this.create_project_index(project, cx);
                        })
                    } else {
                        log::info!("No SemanticDb, skipping project index")
                    }
                },
            )
            .detach();
        })
        .ok();

        Ok(SemanticDb {
            db_connection: Some(db_connection),
            embedding_provider,
            project_indices: HashMap::default(),
        })
    }

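    /// Loads file contents for a batch of search results, expanding each byte
    /// range to whole lines and merging adjacent excerpts from the same file.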
    pub async fn load_results(
        mut results: Vec<SearchResult>,
        fs: &Arc<dyn Fs>,
        cx: &AsyncApp,
    ) -> Result<Vec<LoadedSearchResult>> {
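        // Track the best score (and the query that produced it) for each file,
        // so results can be ordered by their file's strongest match.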
        let mut max_scores_by_path = HashMap::<_, (f32, usize)>::default();
        for result in &results {
            let (score, query_index) = max_scores_by_path
                .entry((result.worktree.clone(), result.path.clone()))
                .or_default();
            if result.score > *score {
                *score = result.score;
                *query_index = result.query_index;
            }
        }

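        // Sort by each file's best score (descending), then break ties
        // deterministically by worktree, path, and range start.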
        results.sort_by(|a, b| {
            let max_score_a = max_scores_by_path[&(a.worktree.clone(), a.path.clone())].0;
            let max_score_b = max_scores_by_path[&(b.worktree.clone(), b.path.clone())].0;
            max_score_b
                .partial_cmp(&max_score_a)
                .unwrap_or(Ordering::Equal)
                .then_with(|| a.worktree.entity_id().cmp(&b.worktree.entity_id()))
                .then_with(|| a.path.cmp(&b.path))
                .then_with(|| a.range.start.cmp(&b.range.start))
        });

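        // Cache the most recently loaded file; the sort above guarantees that
        // chunks from the same file arrive consecutively.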
        let mut last_loaded_file: Option<(Entity<Worktree>, Arc<Path>, PathBuf, String)> = None;
        let mut loaded_results = Vec::<LoadedSearchResult>::new();
        for result in results {
            let full_path;
            let file_content;
            if let Some(last_loaded_file) =
                last_loaded_file
                    .as_ref()
                    .filter(|(last_worktree, last_path, _, _)| {
                        last_worktree == &result.worktree && last_path == &result.path
                    })
            {
                full_path = last_loaded_file.2.clone();
                file_content = &last_loaded_file.3;
            } else {
                let output = result.worktree.read_with(cx, |worktree, _cx| {
                    let entry_abs_path = worktree.abs_path().join(&result.path);
                    let mut entry_full_path = PathBuf::from(worktree.root_name());
                    entry_full_path.push(&result.path);
                    let file_content = async {
                        let entry_abs_path = entry_abs_path;
                        fs.load(&entry_abs_path).await
                    };
                    (entry_full_path, file_content)
                })?;
                full_path = output.0;
                let Some(content) = output.1.await.log_err() else {
                    continue;
                };
                last_loaded_file = Some((
                    result.worktree.clone(),
                    result.path.clone(),
                    full_path.clone(),
                    content,
                ));
                file_content = &last_loaded_file.as_ref().unwrap().3;
            };

            let query_index = max_scores_by_path[&(result.worktree.clone(), result.path.clone())].1;

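            // Clamp the chunk's byte range to the file and snap it forward to
            // UTF-8 character boundaries so the slicing below cannot panic.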
            let mut range_start = result.range.start.min(file_content.len());
            let mut range_end = result.range.end.min(file_content.len());
            while !file_content.is_char_boundary(range_start) {
                range_start += 1;
            }
            while !file_content.is_char_boundary(range_end) {
                range_end += 1;
            }

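            // Compute the rows the range spans and extend it to whole lines.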
            let start_row = file_content[0..range_start].matches('\n').count() as u32;
            let mut end_row = file_content[0..range_end].matches('\n').count() as u32;
            let start_line_byte_offset = file_content[0..range_start]
                .rfind('\n')
                .map(|pos| pos + 1)
                .unwrap_or_default();
            let mut end_line_byte_offset = range_end;
            if file_content[..end_line_byte_offset].ends_with('\n') {
                end_row -= 1;
            } else {
                end_line_byte_offset = file_content[range_end..]
                    .find('\n')
                    .map(|pos| range_end + pos + 1)
                    .unwrap_or_else(|| file_content.len());
            }
            let mut excerpt_content =
                file_content[start_line_byte_offset..end_line_byte_offset].to_string();
            LineEnding::normalize(&mut excerpt_content);

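            // If this excerpt starts on the row immediately after the previous
            // result in the same file, merge the two into one excerpt.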
            if let Some(prev_result) = loaded_results.last_mut() {
                if prev_result.full_path == full_path {
                    if *prev_result.row_range.end() + 1 == start_row {
                        prev_result.row_range = *prev_result.row_range.start()..=end_row;
                        prev_result.excerpt_content.push_str(&excerpt_content);
                        continue;
                    }
                }
            }

            loaded_results.push(LoadedSearchResult {
                path: result.path,
                full_path,
                excerpt_content,
                row_range: start_row..=end_row,
                query_index,
            });
        }

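        // Trim trailing blank lines from each excerpt, shrinking its row range
        // to match.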
        for result in &mut loaded_results {
            while result.excerpt_content.ends_with("\n\n") {
                result.excerpt_content.pop();
                result.row_range =
                    *result.row_range.start()..=result.row_range.end().saturating_sub(1)
            }
        }

        Ok(loaded_results)
    }

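    /// Returns the existing index for `project`, if one has been created.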
    pub fn project_index(
        &mut self,
        project: Entity<Project>,
        _cx: &mut App,
    ) -> Option<Entity<ProjectIndex>> {
        self.project_indices.get(&project.downgrade()).cloned()
    }

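    /// Returns the number of file summaries still pending for the given
    /// project, or `None` if the project has no index.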
    pub fn remaining_summaries(
        &self,
        project: &WeakEntity<Project>,
        cx: &mut App,
    ) -> Option<usize> {
        self.project_indices.get(project).map(|project_index| {
            project_index.update(cx, |project_index, cx| {
                project_index.remaining_summaries(cx)
            })
        })
    }

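    /// Creates and registers a `ProjectIndex` for the project. The registration
    /// is removed automatically when the project entity is released.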
    pub fn create_project_index(
        &mut self,
        project: Entity<Project>,
        cx: &mut App,
    ) -> Entity<ProjectIndex> {
        let project_index = cx.new(|cx| {
            ProjectIndex::new(
                project.clone(),
                self.db_connection.clone().unwrap(),
                self.embedding_provider.clone(),
                cx,
            )
        });

        let project_weak = project.downgrade();
        self.project_indices
            .insert(project_weak.clone(), project_index.clone());

        cx.observe_release(&project, move |_, cx| {
            if cx.has_global::<SemanticDb>() {
                cx.update_global::<SemanticDb, _>(|this, _| {
                    this.project_indices.remove(&project_weak);
                })
            }
        })
        .detach();

        project_index
    }
}

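// Close the database environment cleanly when the global is dropped.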
impl Drop for SemanticDb {
    fn drop(&mut self) {
        self.db_connection.take().unwrap().prepare_for_closing();
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use chunking::Chunk;
    use embedding_index::{ChunkedFile, EmbeddingIndex};
    use feature_flags::FeatureFlagAppExt;
    use fs::FakeFs;
    use futures::{future::BoxFuture, FutureExt};
    use gpui::TestAppContext;
    use indexing::IndexingEntrySet;
    use language::language_settings::AllLanguageSettings;
    use project::{Project, ProjectEntryId};
    use serde_json::json;
    use settings::SettingsStore;
    use smol::channel;
    use std::{future, path::Path, sync::Arc};
    use util::separator;

    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();

        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            language::init(cx);
            cx.update_flags(false, vec![]);
            Project::init_settings(cx);
            SettingsStore::update(cx, |store, cx| {
                store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
            });
        });
    }

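    /// An embedding provider driven by a caller-supplied closure, letting tests
    /// produce deterministic embeddings (or errors) for any input text.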
    pub struct TestEmbeddingProvider {
        batch_size: usize,
        compute_embedding: Box<dyn Fn(&str) -> Result<Embedding> + Send + Sync>,
    }

    impl TestEmbeddingProvider {
        pub fn new(
            batch_size: usize,
            compute_embedding: impl 'static + Fn(&str) -> Result<Embedding> + Send + Sync,
        ) -> Self {
            Self {
                batch_size,
                compute_embedding: Box::new(compute_embedding),
            }
        }
    }

    impl EmbeddingProvider for TestEmbeddingProvider {
        fn embed<'a>(
            &'a self,
            texts: &'a [TextToEmbed<'a>],
        ) -> BoxFuture<'a, Result<Vec<Embedding>>> {
            let embeddings = texts
                .iter()
                .map(|to_embed| (self.compute_embedding)(to_embed.text))
                .collect();
            future::ready(embeddings).boxed()
        }

        fn batch_size(&self) -> usize {
            self.batch_size
        }
    }

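    // End-to-end test: index a fixture project, wait for summarization to
    // finish, then run a semantic search and verify the top hit.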
    #[gpui::test]
    async fn test_search(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        init_test(cx);

        cx.update(|cx| {
            // This functionality is staff-flagged.
            cx.update_flags(true, vec![]);
        });

        let temp_dir = tempfile::tempdir().unwrap();

        let mut semantic_index = SemanticDb::new(
            temp_dir.path().into(),
            Arc::new(TestEmbeddingProvider::new(16, |text| {
                let mut embedding = vec![0f32; 2];
                // if the text mentions "garbage in", give it a high value (0.9) in the first dimension
                if text.contains("garbage in") {
                    embedding[0] = 0.9;
                } else {
                    embedding[0] = -0.9;
                }

                if text.contains("garbage out") {
                    embedding[1] = 0.9;
                } else {
                    embedding[1] = -0.9;
                }

                Ok(Embedding::new(embedding))
            })),
            &mut cx.to_async(),
        )
        .await
        .unwrap();

        let fs = FakeFs::new(cx.executor());
        let project_path = Path::new("/fake_project");

        fs.insert_tree(
            project_path,
            json!({
                "fixture": {
                    "main.rs": include_str!("../fixture/main.rs"),
                    "needle.md": include_str!("../fixture/needle.md"),
                }
            }),
        )
        .await;

        let project = Project::test(fs, [project_path], cx).await;

        let project_index = cx.update(|cx| {
            let language_registry = project.read(cx).languages().clone();
            let node_runtime = project.read(cx).node_runtime().unwrap().clone();
            languages::init(language_registry, node_runtime, cx);
            semantic_index.create_project_index(project.clone(), cx)
        });

        cx.run_until_parked();
        while cx
            .update(|cx| semantic_index.remaining_summaries(&project.downgrade(), cx))
            .unwrap()
            > 0
        {
            cx.run_until_parked();
        }

        let results = cx
            .update(|cx| {
                let project_index = project_index.read(cx);
                let query = "garbage in, garbage out";
                project_index.search(vec![query.into()], 4, cx)
            })
            .await
            .unwrap();

        assert!(
            results.len() > 1,
            "should have found more than one result, but only found {:?}",
            results
        );

        for result in &results {
            println!("result: {:?}", result.path);
            println!("score: {:?}", result.score);
        }

        // Find the result whose score is greater than 0.9
        let search_result = results.iter().find(|result| result.score > 0.9).unwrap();

        assert_eq!(
            search_result.path.to_string_lossy(),
            separator!("fixture/needle.md")
        );

        let content = cx
            .update(|cx| {
                let worktree = search_result.worktree.read(cx);
                let entry_abs_path = worktree.abs_path().join(&search_result.path);
                let fs = project.read(cx).fs().clone();
                cx.background_executor()
                    .spawn(async move { fs.load(&entry_abs_path).await.unwrap() })
            })
            .await;

        let range = search_result.range.clone();
        let content = content[range.clone()].to_owned();

        assert!(content.contains("garbage in, garbage out"));
    }

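    // Files containing a chunk that fails to embed are dropped entirely;
    // the remaining files are embedded batch by batch.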
    #[gpui::test]
    async fn test_embed_files(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let provider = Arc::new(TestEmbeddingProvider::new(3, |text| {
            if text.contains('g') {
                Err(anyhow!("cannot embed text containing a 'g' character"))
            } else {
                Ok(Embedding::new(
                    ('a'..='z')
                        .map(|char| text.chars().filter(|c| *c == char).count() as f32)
                        .collect(),
                ))
            }
        }));

        let (indexing_progress_tx, _) = channel::unbounded();
        let indexing_entries = Arc::new(IndexingEntrySet::new(indexing_progress_tx));

        let (chunked_files_tx, chunked_files_rx) = channel::unbounded::<ChunkedFile>();
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test1.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(0)),
                text: "abcdefghijklmnop".to_string(),
                chunks: [0..4, 4..8, 8..12, 12..16]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test2.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(1)),
                text: "qrstuvwxyz".to_string(),
                chunks: [0..4, 4..8, 8..10]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        chunked_files_tx.close();

        let embed_files_task =
            cx.update(|cx| EmbeddingIndex::embed_files(provider.clone(), chunked_files_rx, cx));
        embed_files_task.task.await.unwrap();

        let embedded_files_rx = embed_files_task.files;
        let mut embedded_files = Vec::new();
        while let Ok((embedded_file, _)) = embedded_files_rx.recv().await {
            embedded_files.push(embedded_file);
        }

        assert_eq!(embedded_files.len(), 1);
        assert_eq!(embedded_files[0].path.as_ref(), Path::new("test2.md"));
        assert_eq!(
            embedded_files[0]
                .chunks
                .iter()
                .map(|embedded_chunk| { embedded_chunk.embedding.clone() })
                .collect::<Vec<Embedding>>(),
            vec![
                (provider.compute_embedding)("qrst").unwrap(),
                (provider.compute_embedding)("uvwx").unwrap(),
                (provider.compute_embedding)("yz").unwrap(),
            ],
        );
    }

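    // Exercises load_results directly: newline alignment, row ranges, and
    // merging of adjacent chunks from the same file.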
    #[gpui::test]
    async fn test_load_search_results(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project_path = Path::new("/fake_project");

        let file1_content = "one\ntwo\nthree\nfour\nfive\n";
        let file2_content = "aaa\nbbb\nccc\nddd\neee\n";

        fs.insert_tree(
            project_path,
            json!({
                "file1.txt": file1_content,
                "file2.txt": file2_content,
            }),
        )
        .await;

        let fs = fs as Arc<dyn Fs>;
        let project = Project::test(fs.clone(), [project_path], cx).await;
        let worktree = project.read_with(cx, |project, cx| project.worktrees(cx).next().unwrap());

        // chunk that is already newline-aligned
        let search_results = vec![SearchResult {
            worktree: worktree.clone(),
            path: Path::new("file1.txt").into(),
            range: 0..file1_content.find("four").unwrap(),
            score: 0.5,
            query_index: 0,
        }];
        assert_eq!(
            SemanticDb::load_results(search_results, &fs, &cx.to_async())
                .await
                .unwrap(),
            &[LoadedSearchResult {
                path: Path::new("file1.txt").into(),
                full_path: "fake_project/file1.txt".into(),
                excerpt_content: "one\ntwo\nthree\n".into(),
                row_range: 0..=2,
                query_index: 0,
            }]
        );

        // chunk that is *not* newline-aligned
        let search_results = vec![SearchResult {
            worktree: worktree.clone(),
            path: Path::new("file1.txt").into(),
            range: file1_content.find("two").unwrap() + 1..file1_content.find("four").unwrap() + 2,
            score: 0.5,
            query_index: 0,
        }];
        assert_eq!(
            SemanticDb::load_results(search_results, &fs, &cx.to_async())
                .await
                .unwrap(),
            &[LoadedSearchResult {
                path: Path::new("file1.txt").into(),
                full_path: "fake_project/file1.txt".into(),
                excerpt_content: "two\nthree\nfour\n".into(),
                row_range: 1..=3,
                query_index: 0,
            }]
        );

        // chunks that are adjacent
        let search_results = vec![
            SearchResult {
                worktree: worktree.clone(),
                path: Path::new("file1.txt").into(),
                range: file1_content.find("two").unwrap()..file1_content.len(),
                score: 0.6,
                query_index: 0,
            },
            SearchResult {
                worktree: worktree.clone(),
                path: Path::new("file1.txt").into(),
                range: 0..file1_content.find("two").unwrap(),
                score: 0.5,
                query_index: 1,
            },
            SearchResult {
                worktree: worktree.clone(),
                path: Path::new("file2.txt").into(),
                range: 0..file2_content.len(),
                score: 0.8,
                query_index: 1,
            },
        ];
        assert_eq!(
            SemanticDb::load_results(search_results, &fs, &cx.to_async())
                .await
                .unwrap(),
            &[
                LoadedSearchResult {
                    path: Path::new("file2.txt").into(),
                    full_path: "fake_project/file2.txt".into(),
                    excerpt_content: file2_content.into(),
                    row_range: 0..=4,
                    query_index: 1,
                },
                LoadedSearchResult {
                    path: Path::new("file1.txt").into(),
                    full_path: "fake_project/file1.txt".into(),
                    excerpt_content: file1_content.into(),
                    row_range: 0..=4,
                    query_index: 0,
                }
            ]
        );
    }
}