semantic_index.rs

mod chunking;
mod embedding;
mod embedding_index;
mod indexing;
mod project_index;
mod project_index_debug_view;
mod summary_backlog;
mod summary_index;
mod worktree_index;

use anyhow::{Context as _, Result};
use collections::HashMap;
use fs::Fs;
use gpui::{AppContext, AsyncAppContext, BorrowAppContext, Context, Global, Model, WeakModel};
use language::LineEnding;
use project::{Project, Worktree};
use std::{
    cmp::Ordering,
    path::{Path, PathBuf},
    sync::Arc,
};
use ui::ViewContext;
use util::ResultExt as _;
use workspace::Workspace;

pub use embedding::*;
pub use project_index::{LoadedSearchResult, ProjectIndex, SearchResult, Status};
pub use project_index_debug_view::ProjectIndexDebugView;
pub use summary_index::FileSummary;

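/// Global owner of the semantic index: the heed (LMDB) database connection,
/// the embedding provider, and one `ProjectIndex` per open project.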
pub struct SemanticDb {
    embedding_provider: Arc<dyn EmbeddingProvider>,
    db_connection: Option<heed::Env>,
    project_indices: HashMap<WeakModel<Project>, Model<ProjectIndex>>,
}

impl Global for SemanticDb {}

impl SemanticDb {
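    /// Opens (or creates) the index database at `db_path` on a background
    /// thread and registers an observer that builds a `ProjectIndex` for the
    /// project of every newly opened workspace.
    ///
    /// A minimal setup sketch, modeled on `test_search` below; `db_path` and
    /// `provider` are placeholders:
    ///
    /// ```ignore
    /// let db = SemanticDb::new(db_path, provider, &mut cx.to_async())
    ///     .await
    ///     .expect("failed to open the semantic index");
    /// ```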
    pub async fn new(
        db_path: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut AsyncAppContext,
    ) -> Result<Self> {
        let db_connection = cx
            .background_executor()
            .spawn(async move {
                std::fs::create_dir_all(&db_path)?;
                unsafe {
                    heed::EnvOpenOptions::new()
                        .map_size(1024 * 1024 * 1024)
                        .max_dbs(3000)
                        .open(db_path)
                }
            })
            .await
            .context("opening database connection")?;

        cx.update(|cx| {
            cx.observe_new_views(
                |workspace: &mut Workspace, cx: &mut ViewContext<Workspace>| {
                    let project = workspace.project().clone();

                    if cx.has_global::<SemanticDb>() {
                        cx.update_global::<SemanticDb, _>(|this, cx| {
                            this.create_project_index(project, cx);
                        })
                    } else {
                        log::info!("No SemanticDb, skipping project index")
                    }
                },
            )
            .detach();
        })
        .ok();

        Ok(SemanticDb {
            db_connection: Some(db_connection),
            embedding_provider,
            project_indices: HashMap::default(),
        })
    }

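    /// Loads the file contents behind raw `SearchResult`s and converts them
    /// into line-aligned `LoadedSearchResult` excerpts: results are ordered by
    /// their file's best score, byte ranges are widened to whole lines, and
    /// adjacent excerpts from the same file are merged.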
    pub async fn load_results(
        mut results: Vec<SearchResult>,
        fs: &Arc<dyn Fs>,
        cx: &AsyncAppContext,
    ) -> Result<Vec<LoadedSearchResult>> {
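        // Track the best score (and the query that produced it) for each
        // (worktree, path) pair.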
        let mut max_scores_by_path = HashMap::<_, (f32, usize)>::default();
        for result in &results {
            let (score, query_index) = max_scores_by_path
                .entry((result.worktree.clone(), result.path.clone()))
                .or_default();
            if result.score > *score {
                *score = result.score;
                *query_index = result.query_index;
            }
        }

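        // Order results by their file's best score (descending), then by
        // worktree, path, and position, so results from the same file end up
        // adjacent.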
        results.sort_by(|a, b| {
            let max_score_a = max_scores_by_path[&(a.worktree.clone(), a.path.clone())].0;
            let max_score_b = max_scores_by_path[&(b.worktree.clone(), b.path.clone())].0;
            max_score_b
                .partial_cmp(&max_score_a)
                .unwrap_or(Ordering::Equal)
                .then_with(|| a.worktree.entity_id().cmp(&b.worktree.entity_id()))
                .then_with(|| a.path.cmp(&b.path))
                .then_with(|| a.range.start.cmp(&b.range.start))
        });

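        // Results are now grouped by file, so caching just the most recently
        // loaded file is enough to load each file at most once.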
        let mut last_loaded_file: Option<(Model<Worktree>, Arc<Path>, PathBuf, String)> = None;
        let mut loaded_results = Vec::<LoadedSearchResult>::new();
        for result in results {
            let full_path;
            let file_content;
            if let Some(last_loaded_file) =
                last_loaded_file
                    .as_ref()
                    .filter(|(last_worktree, last_path, _, _)| {
                        last_worktree == &result.worktree && last_path == &result.path
                    })
            {
                full_path = last_loaded_file.2.clone();
                file_content = &last_loaded_file.3;
            } else {
                let output = result.worktree.read_with(cx, |worktree, _cx| {
                    let entry_abs_path = worktree.abs_path().join(&result.path);
                    let mut entry_full_path = PathBuf::from(worktree.root_name());
                    entry_full_path.push(&result.path);
                    let file_content = async {
                        let entry_abs_path = entry_abs_path;
                        fs.load(&entry_abs_path).await
                    };
                    (entry_full_path, file_content)
                })?;
                full_path = output.0;
                let Some(content) = output.1.await.log_err() else {
                    continue;
                };
                last_loaded_file = Some((
                    result.worktree.clone(),
                    result.path.clone(),
                    full_path.clone(),
                    content,
                ));
                file_content = &last_loaded_file.as_ref().unwrap().3;
            };

            let query_index = max_scores_by_path[&(result.worktree.clone(), result.path.clone())].1;

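            // Clamp the chunk's byte range to the file and nudge both ends
            // forward to char boundaries so slicing can't split a UTF-8
            // character.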
            let mut range_start = result.range.start.min(file_content.len());
            let mut range_end = result.range.end.min(file_content.len());
            while !file_content.is_char_boundary(range_start) {
                range_start += 1;
            }
            while !file_content.is_char_boundary(range_end) {
                range_end += 1;
            }

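            // Widen the byte range to whole lines and compute the excerpt's
            // row range within the file.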
            let start_row = file_content[0..range_start].matches('\n').count() as u32;
            let mut end_row = file_content[0..range_end].matches('\n').count() as u32;
            let start_line_byte_offset = file_content[0..range_start]
                .rfind('\n')
                .map(|pos| pos + 1)
                .unwrap_or_default();
            let mut end_line_byte_offset = range_end;
            if file_content[..end_line_byte_offset].ends_with('\n') {
                end_row -= 1;
            } else {
                end_line_byte_offset = file_content[range_end..]
                    .find('\n')
                    .map(|pos| range_end + pos + 1)
                    .unwrap_or_else(|| file_content.len());
            }
            let mut excerpt_content =
                file_content[start_line_byte_offset..end_line_byte_offset].to_string();
            LineEnding::normalize(&mut excerpt_content);

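            // If this excerpt begins on the line immediately after the
            // previous excerpt from the same file, merge the two instead of
            // emitting a separate result.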
            if let Some(prev_result) = loaded_results.last_mut() {
                if prev_result.full_path == full_path {
                    if *prev_result.row_range.end() + 1 == start_row {
                        prev_result.row_range = *prev_result.row_range.start()..=end_row;
                        prev_result.excerpt_content.push_str(&excerpt_content);
                        continue;
                    }
                }
            }

            loaded_results.push(LoadedSearchResult {
                path: result.path,
                full_path,
                excerpt_content,
                row_range: start_row..=end_row,
                query_index,
            });
        }

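        // Trim trailing blank lines from each excerpt, shrinking its row
        // range to match.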
        for result in &mut loaded_results {
            while result.excerpt_content.ends_with("\n\n") {
                result.excerpt_content.pop();
                result.row_range =
                    *result.row_range.start()..=result.row_range.end().saturating_sub(1)
            }
        }

        Ok(loaded_results)
    }

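    /// Returns the existing `ProjectIndex` for `project`, if one has been
    /// created.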
    pub fn project_index(
        &mut self,
        project: Model<Project>,
        _cx: &mut AppContext,
    ) -> Option<Model<ProjectIndex>> {
        self.project_indices.get(&project.downgrade()).cloned()
    }

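    /// Returns the number of file summaries still pending for `project`, or
    /// `None` if the project has no index.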
    pub fn remaining_summaries(
        &self,
        project: &WeakModel<Project>,
        cx: &mut AppContext,
    ) -> Option<usize> {
        self.project_indices.get(project).map(|project_index| {
            project_index.update(cx, |project_index, cx| {
                project_index.remaining_summaries(cx)
            })
        })
    }

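    /// Creates and registers a `ProjectIndex` for `project`, removing it from
    /// the registry again when the project is released.
    ///
    /// Usage sketch, as in `test_search` below:
    ///
    /// ```ignore
    /// let project_index = cx.update(|cx| db.create_project_index(project.clone(), cx));
    /// ```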
    pub fn create_project_index(
        &mut self,
        project: Model<Project>,
        cx: &mut AppContext,
    ) -> Model<ProjectIndex> {
        let project_index = cx.new_model(|cx| {
            ProjectIndex::new(
                project.clone(),
                self.db_connection.clone().unwrap(),
                self.embedding_provider.clone(),
                cx,
            )
        });

        let project_weak = project.downgrade();
        self.project_indices
            .insert(project_weak.clone(), project_index.clone());

        cx.observe_release(&project, move |_, cx| {
            if cx.has_global::<SemanticDb>() {
                cx.update_global::<SemanticDb, _>(|this, _| {
                    this.project_indices.remove(&project_weak);
                })
            }
        })
        .detach();

        project_index
    }
}

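// `db_connection` is always `Some` until drop, so the `unwrap` below cannot
// fail; taking the env out of the option hands it back to heed for a clean
// shutdown.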
impl Drop for SemanticDb {
    fn drop(&mut self) {
        self.db_connection.take().unwrap().prepare_for_closing();
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use chunking::Chunk;
    use embedding_index::{ChunkedFile, EmbeddingIndex};
    use feature_flags::FeatureFlagAppExt;
    use fs::FakeFs;
    use futures::{future::BoxFuture, FutureExt};
    use gpui::TestAppContext;
    use indexing::IndexingEntrySet;
    use language::language_settings::AllLanguageSettings;
    use project::{Project, ProjectEntryId};
    use serde_json::json;
    use settings::SettingsStore;
    use smol::{channel, stream::StreamExt};
    use std::{future, path::Path, sync::Arc};

    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();

        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            language::init(cx);
            cx.update_flags(false, vec![]);
            Project::init_settings(cx);
            SettingsStore::update(cx, |store, cx| {
                store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
            });
        });
    }

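    /// Deterministic embedding provider for tests: embeddings are produced by
    /// a caller-supplied closure, and `batch_size` is reported as the maximum
    /// number of texts to embed per batch.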
    pub struct TestEmbeddingProvider {
        batch_size: usize,
        compute_embedding: Box<dyn Fn(&str) -> Result<Embedding> + Send + Sync>,
    }

    impl TestEmbeddingProvider {
        pub fn new(
            batch_size: usize,
            compute_embedding: impl 'static + Fn(&str) -> Result<Embedding> + Send + Sync,
        ) -> Self {
            Self {
                batch_size,
                compute_embedding: Box::new(compute_embedding),
            }
        }
    }

    impl EmbeddingProvider for TestEmbeddingProvider {
        fn embed<'a>(
            &'a self,
            texts: &'a [TextToEmbed<'a>],
        ) -> BoxFuture<'a, Result<Vec<Embedding>>> {
            let embeddings = texts
                .iter()
                .map(|to_embed| (self.compute_embedding)(to_embed.text))
                .collect();
            future::ready(embeddings).boxed()
        }

        fn batch_size(&self) -> usize {
            self.batch_size
        }
    }

    #[gpui::test]
    async fn test_search(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        init_test(cx);

        let temp_dir = tempfile::tempdir().unwrap();

        let mut semantic_index = SemanticDb::new(
            temp_dir.path().into(),
            Arc::new(TestEmbeddingProvider::new(16, |text| {
                let mut embedding = vec![0f32; 2];
                // if the text contains "garbage in", set the first dimension to 0.9
                if text.contains("garbage in") {
                    embedding[0] = 0.9;
                } else {
                    embedding[0] = -0.9;
                }

                if text.contains("garbage out") {
                    embedding[1] = 0.9;
                } else {
                    embedding[1] = -0.9;
                }

                Ok(Embedding::new(embedding))
            })),
            &mut cx.to_async(),
        )
        .await
        .unwrap();

        let fs = FakeFs::new(cx.executor());
        let project_path = Path::new("/fake_project");

        fs.insert_tree(
            project_path,
            json!({
                "fixture": {
                    "main.rs": include_str!("../fixture/main.rs"),
                    "needle.md": include_str!("../fixture/needle.md"),
                }
            }),
        )
        .await;

        let project = Project::test(fs, [project_path], cx).await;

        let project_index = cx.update(|cx| {
            let language_registry = project.read(cx).languages().clone();
            let node_runtime = project.read(cx).node_runtime().unwrap().clone();
            languages::init(language_registry, node_runtime, cx);
            semantic_index.create_project_index(project.clone(), cx)
        });

        cx.run_until_parked();
        while cx
            .update(|cx| semantic_index.remaining_summaries(&project.downgrade(), cx))
            .unwrap()
            > 0
        {
            cx.run_until_parked();
        }

        let results = cx
            .update(|cx| {
                let project_index = project_index.read(cx);
                let query = "garbage in, garbage out";
                project_index.search(vec![query.into()], 4, cx)
            })
            .await
            .unwrap();

        assert!(
            results.len() > 1,
            "should have found more than one result, but found {:?}",
            results
        );

        for result in &results {
            println!("result: {:?}", result.path);
            println!("score: {:?}", result.score);
        }

        // Find a result whose score is greater than 0.9
        let search_result = results.iter().find(|result| result.score > 0.9).unwrap();

        assert_eq!(search_result.path.to_string_lossy(), "fixture/needle.md");

        let content = cx
            .update(|cx| {
                let worktree = search_result.worktree.read(cx);
                let entry_abs_path = worktree.abs_path().join(&search_result.path);
                let fs = project.read(cx).fs().clone();
                cx.background_executor()
                    .spawn(async move { fs.load(&entry_abs_path).await.unwrap() })
            })
            .await;

        let range = search_result.range.clone();
        let content = content[range.clone()].to_owned();

        assert!(content.contains("garbage in, garbage out"));
    }

    #[gpui::test]
    async fn test_embed_files(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let provider = Arc::new(TestEmbeddingProvider::new(3, |text| {
            if text.contains('g') {
                Err(anyhow!("cannot embed text containing a 'g' character"))
            } else {
                Ok(Embedding::new(
                    ('a'..='z')
                        .map(|char| text.chars().filter(|c| *c == char).count() as f32)
                        .collect(),
                ))
            }
        }));

        let (indexing_progress_tx, _) = channel::unbounded();
        let indexing_entries = Arc::new(IndexingEntrySet::new(indexing_progress_tx));

        let (chunked_files_tx, chunked_files_rx) = channel::unbounded::<ChunkedFile>();
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test1.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(0)),
                text: "abcdefghijklmnop".to_string(),
                chunks: [0..4, 4..8, 8..12, 12..16]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test2.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(1)),
                text: "qrstuvwxyz".to_string(),
                chunks: [0..4, 4..8, 8..10]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        chunked_files_tx.close();

        let embed_files_task =
            cx.update(|cx| EmbeddingIndex::embed_files(provider.clone(), chunked_files_rx, cx));
        embed_files_task.task.await.unwrap();

        let mut embedded_files_rx = embed_files_task.files;
        let mut embedded_files = Vec::new();
        while let Some((embedded_file, _)) = embedded_files_rx.next().await {
            embedded_files.push(embedded_file);
        }

        assert_eq!(embedded_files.len(), 1);
        assert_eq!(embedded_files[0].path.as_ref(), Path::new("test2.md"));
        assert_eq!(
            embedded_files[0]
                .chunks
                .iter()
                .map(|embedded_chunk| { embedded_chunk.embedding.clone() })
                .collect::<Vec<Embedding>>(),
            vec![
                (provider.compute_embedding)("qrst").unwrap(),
                (provider.compute_embedding)("uvwx").unwrap(),
                (provider.compute_embedding)("yz").unwrap(),
            ],
        );
    }

    #[gpui::test]
    async fn test_load_search_results(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project_path = Path::new("/fake_project");

        let file1_content = "one\ntwo\nthree\nfour\nfive\n";
        let file2_content = "aaa\nbbb\nccc\nddd\neee\n";

        fs.insert_tree(
            project_path,
            json!({
                "file1.txt": file1_content,
                "file2.txt": file2_content,
            }),
        )
        .await;

        let fs = fs as Arc<dyn Fs>;
        let project = Project::test(fs.clone(), [project_path], cx).await;
        let worktree = project.read_with(cx, |project, cx| project.worktrees(cx).next().unwrap());

        // chunk that is already newline-aligned
        let search_results = vec![SearchResult {
            worktree: worktree.clone(),
            path: Path::new("file1.txt").into(),
            range: 0..file1_content.find("four").unwrap(),
            score: 0.5,
            query_index: 0,
        }];
        assert_eq!(
            SemanticDb::load_results(search_results, &fs, &cx.to_async())
                .await
                .unwrap(),
            &[LoadedSearchResult {
                path: Path::new("file1.txt").into(),
                full_path: "fake_project/file1.txt".into(),
                excerpt_content: "one\ntwo\nthree\n".into(),
                row_range: 0..=2,
                query_index: 0,
            }]
        );

        // chunk that is *not* newline-aligned
        let search_results = vec![SearchResult {
            worktree: worktree.clone(),
            path: Path::new("file1.txt").into(),
            range: file1_content.find("two").unwrap() + 1..file1_content.find("four").unwrap() + 2,
            score: 0.5,
            query_index: 0,
        }];
        assert_eq!(
            SemanticDb::load_results(search_results, &fs, &cx.to_async())
                .await
                .unwrap(),
            &[LoadedSearchResult {
                path: Path::new("file1.txt").into(),
                full_path: "fake_project/file1.txt".into(),
                excerpt_content: "two\nthree\nfour\n".into(),
                row_range: 1..=3,
                query_index: 0,
            }]
        );

        // chunks that are adjacent
        let search_results = vec![
            SearchResult {
                worktree: worktree.clone(),
                path: Path::new("file1.txt").into(),
                range: file1_content.find("two").unwrap()..file1_content.len(),
                score: 0.6,
                query_index: 0,
            },
            SearchResult {
                worktree: worktree.clone(),
                path: Path::new("file1.txt").into(),
                range: 0..file1_content.find("two").unwrap(),
                score: 0.5,
                query_index: 1,
            },
            SearchResult {
                worktree: worktree.clone(),
                path: Path::new("file2.txt").into(),
                range: 0..file2_content.len(),
                score: 0.8,
                query_index: 1,
            },
        ];
        assert_eq!(
            SemanticDb::load_results(search_results, &fs, &cx.to_async())
                .await
                .unwrap(),
            &[
                LoadedSearchResult {
                    path: Path::new("file2.txt").into(),
                    full_path: "fake_project/file2.txt".into(),
                    excerpt_content: file2_content.into(),
                    row_range: 0..=4,
                    query_index: 1,
                },
                LoadedSearchResult {
                    path: Path::new("file1.txt").into(),
                    full_path: "fake_project/file1.txt".into(),
                    excerpt_content: file1_content.into(),
                    row_range: 0..=4,
                    query_index: 0,
                }
            ]
        );
    }
}