semantic_index.rs

mod chunking;
mod embedding;
mod embedding_index;
mod indexing;
mod project_index;
mod project_index_debug_view;
mod summary_backlog;
mod summary_index;
mod worktree_index;

use anyhow::{Context as _, Result};
use collections::HashMap;
use fs::Fs;
use gpui::{App, AppContext as _, AsyncApp, BorrowAppContext, Context, Entity, Global, WeakEntity};
use language::LineEnding;
use project::{Project, Worktree};
use std::{
    cmp::Ordering,
    path::{Path, PathBuf},
    sync::Arc,
};
use util::ResultExt as _;
use workspace::Workspace;

pub use embedding::*;
pub use project_index::{LoadedSearchResult, ProjectIndex, SearchResult, Status};
pub use project_index_debug_view::ProjectIndexDebugView;
pub use summary_index::FileSummary;

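/// Global registry for semantic search: holds the embedding database
/// connection, the embedding provider, and one `ProjectIndex` per open project.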
pub struct SemanticDb {
    embedding_provider: Arc<dyn EmbeddingProvider>,
    db_connection: Option<heed::Env>,
    project_indices: HashMap<WeakEntity<Project>, Entity<ProjectIndex>>,
}

impl Global for SemanticDb {}

impl SemanticDb {
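    /// Opens (or creates) the heed database environment at `db_path` and
    /// registers an observer that creates a project index for every new
    /// workspace's project.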
    pub async fn new(
        db_path: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut AsyncApp,
    ) -> Result<Self> {
        let db_connection = cx
            .background_spawn(async move {
                std::fs::create_dir_all(&db_path)?;
                unsafe {
                    heed::EnvOpenOptions::new()
                        .map_size(1024 * 1024 * 1024)
                        .max_dbs(3000)
                        .open(db_path)
                }
            })
            .await
            .context("opening database connection")?;

        cx.update(|cx| {
            cx.observe_new(
                |workspace: &mut Workspace, _window, cx: &mut Context<Workspace>| {
                    let project = workspace.project().clone();

                    if cx.has_global::<SemanticDb>() {
                        cx.update_global::<SemanticDb, _>(|this, cx| {
                            this.create_project_index(project, cx);
                        })
                    } else {
                        log::info!("No SemanticDb, skipping project index")
                    }
                },
            )
            .detach();
        })
        .ok();

        Ok(SemanticDb {
            db_connection: Some(db_connection),
            embedding_provider,
            project_indices: HashMap::default(),
        })
    }

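    /// Loads file contents for raw `SearchResult`s, snaps each result's byte
    /// range to line boundaries, and merges adjacent excerpts from the same
    /// file into a single `LoadedSearchResult`.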
    pub async fn load_results(
        mut results: Vec<SearchResult>,
        fs: &Arc<dyn Fs>,
        cx: &AsyncApp,
    ) -> Result<Vec<LoadedSearchResult>> {
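        // Record the highest score (and its query index) seen for each file, so
        // results can be ordered and attributed per file below.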
        let mut max_scores_by_path = HashMap::<_, (f32, usize)>::default();
        for result in &results {
            let (score, query_index) = max_scores_by_path
                .entry((result.worktree.clone(), result.path.clone()))
                .or_default();
            if result.score > *score {
                *score = result.score;
                *query_index = result.query_index;
            }
        }

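        // Sort by descending per-file max score, grouping results from the same
        // file together in ascending range order so excerpts can be merged.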
        results.sort_by(|a, b| {
            let max_score_a = max_scores_by_path[&(a.worktree.clone(), a.path.clone())].0;
            let max_score_b = max_scores_by_path[&(b.worktree.clone(), b.path.clone())].0;
            max_score_b
                .partial_cmp(&max_score_a)
                .unwrap_or(Ordering::Equal)
                .then_with(|| a.worktree.entity_id().cmp(&b.worktree.entity_id()))
                .then_with(|| a.path.cmp(&b.path))
                .then_with(|| a.range.start.cmp(&b.range.start))
        });

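        // Load each file's contents at most once per run of consecutive results
        // by caching the most recently loaded file.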
        let mut last_loaded_file: Option<(Entity<Worktree>, Arc<Path>, PathBuf, String)> = None;
        let mut loaded_results = Vec::<LoadedSearchResult>::new();
        for result in results {
            let full_path;
            let file_content;
            if let Some(last_loaded_file) =
                last_loaded_file
                    .as_ref()
                    .filter(|(last_worktree, last_path, _, _)| {
                        last_worktree == &result.worktree && last_path == &result.path
                    })
            {
                full_path = last_loaded_file.2.clone();
                file_content = &last_loaded_file.3;
            } else {
                let output = result.worktree.read_with(cx, |worktree, _cx| {
                    let entry_abs_path = worktree.abs_path().join(&result.path);
                    let mut entry_full_path = PathBuf::from(worktree.root_name());
                    entry_full_path.push(&result.path);
                    let file_content = async {
                        let entry_abs_path = entry_abs_path;
                        fs.load(&entry_abs_path).await
                    };
                    (entry_full_path, file_content)
                })?;
                full_path = output.0;
                let Some(content) = output.1.await.log_err() else {
                    continue;
                };
                last_loaded_file = Some((
                    result.worktree.clone(),
                    result.path.clone(),
                    full_path.clone(),
                    content,
                ));
                file_content = &last_loaded_file.as_ref().unwrap().3;
            };

            let query_index = max_scores_by_path[&(result.worktree.clone(), result.path.clone())].1;

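            // Clamp the chunk's byte range to the file and advance both ends to
            // the nearest UTF-8 character boundary before slicing.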
            let mut range_start = result.range.start.min(file_content.len());
            let mut range_end = result.range.end.min(file_content.len());
            while !file_content.is_char_boundary(range_start) {
                range_start += 1;
            }
            while !file_content.is_char_boundary(range_end) {
                range_end += 1;
            }

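            // Expand the range to whole lines: back to the start of the first
            // line, and forward to the end of the last line in the range.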
            let start_row = file_content[0..range_start].matches('\n').count() as u32;
            let mut end_row = file_content[0..range_end].matches('\n').count() as u32;
            let start_line_byte_offset = file_content[0..range_start]
                .rfind('\n')
                .map(|pos| pos + 1)
                .unwrap_or_default();
            let mut end_line_byte_offset = range_end;
            if file_content[..end_line_byte_offset].ends_with('\n') {
                end_row -= 1;
            } else {
                end_line_byte_offset = file_content[range_end..]
                    .find('\n')
                    .map(|pos| range_end + pos + 1)
                    .unwrap_or_else(|| file_content.len());
            }
            let mut excerpt_content =
                file_content[start_line_byte_offset..end_line_byte_offset].to_string();
            LineEnding::normalize(&mut excerpt_content);

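            // Merge with the previous excerpt when it ends on the line directly
            // above this one in the same file.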
            if let Some(prev_result) = loaded_results.last_mut()
                && prev_result.full_path == full_path
                && *prev_result.row_range.end() + 1 == start_row
            {
                prev_result.row_range = *prev_result.row_range.start()..=end_row;
                prev_result.excerpt_content.push_str(&excerpt_content);
                continue;
            }

            loaded_results.push(LoadedSearchResult {
                path: result.path,
                full_path,
                excerpt_content,
                row_range: start_row..=end_row,
                query_index,
            });
        }

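        // Trim trailing blank lines from each excerpt and shrink its row range
        // to match.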
        for result in &mut loaded_results {
            while result.excerpt_content.ends_with("\n\n") {
                result.excerpt_content.pop();
                result.row_range =
                    *result.row_range.start()..=result.row_range.end().saturating_sub(1)
            }
        }

        Ok(loaded_results)
    }

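    /// Returns the existing `ProjectIndex` for `project`, if one has been created.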
    pub fn project_index(
        &mut self,
        project: Entity<Project>,
        _cx: &mut App,
    ) -> Option<Entity<ProjectIndex>> {
        self.project_indices.get(&project.downgrade()).cloned()
    }

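    /// Returns the number of files still awaiting summarization for `project`,
    /// or `None` if the project has no index.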
    pub fn remaining_summaries(
        &self,
        project: &WeakEntity<Project>,
        cx: &mut App,
    ) -> Option<usize> {
        self.project_indices.get(project).map(|project_index| {
            project_index.update(cx, |project_index, cx| {
                project_index.remaining_summaries(cx)
            })
        })
    }

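    /// Creates and registers a `ProjectIndex` for `project`, removing it from
    /// the registry when the project entity is released.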
    pub fn create_project_index(
        &mut self,
        project: Entity<Project>,
        cx: &mut App,
    ) -> Entity<ProjectIndex> {
        let project_index = cx.new(|cx| {
            ProjectIndex::new(
                project.clone(),
                self.db_connection.clone().unwrap(),
                self.embedding_provider.clone(),
                cx,
            )
        });

        let project_weak = project.downgrade();
        self.project_indices
            .insert(project_weak.clone(), project_index.clone());

        cx.observe_release(&project, move |_, cx| {
            if cx.has_global::<SemanticDb>() {
                cx.update_global::<SemanticDb, _>(|this, _| {
                    this.project_indices.remove(&project_weak);
                })
            }
        })
        .detach();

        project_index
    }
}

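// Signal heed to close the database environment once all outstanding
// references to it are dropped.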
impl Drop for SemanticDb {
    fn drop(&mut self) {
        self.db_connection.take().unwrap().prepare_for_closing();
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use chunking::Chunk;
    use embedding_index::{ChunkedFile, EmbeddingIndex};
    use feature_flags::FeatureFlagAppExt;
    use fs::FakeFs;
    use futures::{FutureExt, future::BoxFuture};
    use gpui::TestAppContext;
    use indexing::IndexingEntrySet;
    use language::language_settings::AllLanguageSettings;
    use project::{Project, ProjectEntryId};
    use serde_json::json;
    use settings::SettingsStore;
    use smol::channel;
    use std::{future, path::Path, sync::Arc};
    use util::path;

    fn init_test(cx: &mut TestAppContext) {
        zlog::init_test();

        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            language::init(cx);
            cx.update_flags(false, vec![]);
            Project::init_settings(cx);
            SettingsStore::update(cx, |store, cx| {
                store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
            });
        });
    }

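    /// Test-only embedding provider that delegates to a caller-supplied closure
    /// instead of calling a real model.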
    pub struct TestEmbeddingProvider {
        batch_size: usize,
        compute_embedding: Box<dyn Fn(&str) -> Result<Embedding> + Send + Sync>,
    }

    impl TestEmbeddingProvider {
        pub fn new(
            batch_size: usize,
            compute_embedding: impl 'static + Fn(&str) -> Result<Embedding> + Send + Sync,
        ) -> Self {
            Self {
                batch_size,
                compute_embedding: Box::new(compute_embedding),
            }
        }
    }

    impl EmbeddingProvider for TestEmbeddingProvider {
        fn embed<'a>(
            &'a self,
            texts: &'a [TextToEmbed<'a>],
        ) -> BoxFuture<'a, Result<Vec<Embedding>>> {
            let embeddings = texts
                .iter()
                .map(|to_embed| (self.compute_embedding)(to_embed.text))
                .collect();
            future::ready(embeddings).boxed()
        }

        fn batch_size(&self) -> usize {
            self.batch_size
        }
    }

    #[gpui::test]
    async fn test_search(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        init_test(cx);

        cx.update(|cx| {
            // This functionality is staff-flagged.
            cx.update_flags(true, vec![]);
        });

        let temp_dir = tempfile::tempdir().unwrap();

        let mut semantic_index = SemanticDb::new(
            temp_dir.path().into(),
            Arc::new(TestEmbeddingProvider::new(16, |text| {
                let mut embedding = vec![0f32; 2];
                // Score the first dimension high if the text mentions "garbage in",
                // and the second dimension high if it mentions "garbage out".
                if text.contains("garbage in") {
                    embedding[0] = 0.9;
                } else {
                    embedding[0] = -0.9;
                }

                if text.contains("garbage out") {
                    embedding[1] = 0.9;
                } else {
                    embedding[1] = -0.9;
                }

                Ok(Embedding::new(embedding))
            })),
            &mut cx.to_async(),
        )
        .await
        .unwrap();

        let fs = FakeFs::new(cx.executor());
        let project_path = Path::new("/fake_project");

        fs.insert_tree(
            project_path,
            json!({
                "fixture": {
                    "main.rs": include_str!("../fixture/main.rs"),
                    "needle.md": include_str!("../fixture/needle.md"),
                }
            }),
        )
        .await;

        let project = Project::test(fs, [project_path], cx).await;

        let project_index = cx.update(|cx| {
            let language_registry = project.read(cx).languages().clone();
            let node_runtime = project.read(cx).node_runtime().unwrap().clone();
            languages::init(language_registry, node_runtime, cx);
            semantic_index.create_project_index(project.clone(), cx)
        });

        cx.run_until_parked();
        while cx
            .update(|cx| semantic_index.remaining_summaries(&project.downgrade(), cx))
            .unwrap()
            > 0
        {
            cx.run_until_parked();
        }

        let results = cx
            .update(|cx| {
                let project_index = project_index.read(cx);
                let query = "garbage in, garbage out";
                project_index.search(vec![query.into()], 4, cx)
            })
            .await
            .unwrap();

        assert!(
            results.len() > 1,
            "should have found multiple results, but only found {:?}",
            results
        );

        for result in &results {
            println!("result: {:?}", result.path);
            println!("score: {:?}", result.score);
        }

        // Find the result whose score is greater than 0.9.
        let search_result = results.iter().find(|result| result.score > 0.9).unwrap();

        assert_eq!(
            search_result.path.to_string_lossy(),
            path!("fixture/needle.md")
        );

        let content = cx
            .update(|cx| {
                let worktree = search_result.worktree.read(cx);
                let entry_abs_path = worktree.abs_path().join(&search_result.path);
                let fs = project.read(cx).fs().clone();
                cx.background_spawn(async move { fs.load(&entry_abs_path).await.unwrap() })
            })
            .await;

        let range = search_result.range.clone();
        let content = content[range.clone()].to_owned();

        assert!(content.contains("garbage in, garbage out"));
    }

    #[gpui::test]
    async fn test_embed_files(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let provider = Arc::new(TestEmbeddingProvider::new(3, |text| {
            anyhow::ensure!(
                !text.contains('g'),
                "cannot embed text containing a 'g' character"
            );
            Ok(Embedding::new(
                ('a'..='z')
                    .map(|char| text.chars().filter(|c| *c == char).count() as f32)
                    .collect(),
            ))
        }));

        let (indexing_progress_tx, _) = channel::unbounded();
        let indexing_entries = Arc::new(IndexingEntrySet::new(indexing_progress_tx));

        let (chunked_files_tx, chunked_files_rx) = channel::unbounded::<ChunkedFile>();
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test1.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(0)),
                text: "abcdefghijklmnop".to_string(),
                chunks: [0..4, 4..8, 8..12, 12..16]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test2.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(1)),
                text: "qrstuvwxyz".to_string(),
                chunks: [0..4, 4..8, 8..10]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        chunked_files_tx.close();

        let embed_files_task =
            cx.update(|cx| EmbeddingIndex::embed_files(provider.clone(), chunked_files_rx, cx));
        embed_files_task.task.await.unwrap();

        let embedded_files_rx = embed_files_task.files;
        let mut embedded_files = Vec::new();
        while let Ok((embedded_file, _)) = embedded_files_rx.recv().await {
            embedded_files.push(embedded_file);
        }

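        // Only the second file should embed successfully: the provider rejects
        // any text containing 'g', which covers "abcdefghijklmnop" in test1.md.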
        assert_eq!(embedded_files.len(), 1);
        assert_eq!(embedded_files[0].path.as_ref(), Path::new("test2.md"));
        assert_eq!(
            embedded_files[0]
                .chunks
                .iter()
                .map(|embedded_chunk| embedded_chunk.embedding.clone())
                .collect::<Vec<Embedding>>(),
            vec![
                (provider.compute_embedding)("qrst").unwrap(),
                (provider.compute_embedding)("uvwx").unwrap(),
                (provider.compute_embedding)("yz").unwrap(),
            ],
        );
    }

    #[gpui::test]
    async fn test_load_search_results(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project_path = Path::new("/fake_project");

        let file1_content = "one\ntwo\nthree\nfour\nfive\n";
        let file2_content = "aaa\nbbb\nccc\nddd\neee\n";

        fs.insert_tree(
            project_path,
            json!({
                "file1.txt": file1_content,
                "file2.txt": file2_content,
            }),
        )
        .await;

        let fs = fs as Arc<dyn Fs>;
        let project = Project::test(fs.clone(), [project_path], cx).await;
        let worktree = project.read_with(cx, |project, cx| project.worktrees(cx).next().unwrap());

        // chunk that is already newline-aligned
        let search_results = vec![SearchResult {
            worktree: worktree.clone(),
            path: Path::new("file1.txt").into(),
            range: 0..file1_content.find("four").unwrap(),
            score: 0.5,
            query_index: 0,
        }];
        assert_eq!(
            SemanticDb::load_results(search_results, &fs, &cx.to_async())
                .await
                .unwrap(),
            &[LoadedSearchResult {
                path: Path::new("file1.txt").into(),
                full_path: "fake_project/file1.txt".into(),
                excerpt_content: "one\ntwo\nthree\n".into(),
                row_range: 0..=2,
                query_index: 0,
            }]
        );

        // chunk that is *not* newline-aligned
        let search_results = vec![SearchResult {
            worktree: worktree.clone(),
            path: Path::new("file1.txt").into(),
            range: file1_content.find("two").unwrap() + 1..file1_content.find("four").unwrap() + 2,
            score: 0.5,
            query_index: 0,
        }];
        assert_eq!(
            SemanticDb::load_results(search_results, &fs, &cx.to_async())
                .await
                .unwrap(),
            &[LoadedSearchResult {
                path: Path::new("file1.txt").into(),
                full_path: "fake_project/file1.txt".into(),
                excerpt_content: "two\nthree\nfour\n".into(),
                row_range: 1..=3,
                query_index: 0,
            }]
        );

        // chunks that are adjacent
        let search_results = vec![
            SearchResult {
                worktree: worktree.clone(),
                path: Path::new("file1.txt").into(),
                range: file1_content.find("two").unwrap()..file1_content.len(),
                score: 0.6,
                query_index: 0,
            },
            SearchResult {
                worktree: worktree.clone(),
                path: Path::new("file1.txt").into(),
                range: 0..file1_content.find("two").unwrap(),
                score: 0.5,
                query_index: 1,
            },
            SearchResult {
                worktree: worktree.clone(),
                path: Path::new("file2.txt").into(),
                range: 0..file2_content.len(),
                score: 0.8,
                query_index: 1,
            },
        ];
        assert_eq!(
            SemanticDb::load_results(search_results, &fs, &cx.to_async())
                .await
                .unwrap(),
            &[
                LoadedSearchResult {
                    path: Path::new("file2.txt").into(),
                    full_path: "fake_project/file2.txt".into(),
                    excerpt_content: file2_content.into(),
                    row_range: 0..=4,
                    query_index: 1,
                },
                LoadedSearchResult {
                    path: Path::new("file1.txt").into(),
                    full_path: "fake_project/file1.txt".into(),
                    excerpt_content: file1_content.into(),
                    row_range: 0..=4,
                    query_index: 0,
                }
            ]
        );
    }
}