// semantic_index.rs

  1mod chunking;
  2mod embedding;
  3mod embedding_index;
  4mod indexing;
  5mod project_index;
  6mod project_index_debug_view;
  7mod summary_backlog;
  8mod summary_index;
  9mod worktree_index;
 10
 11use anyhow::{Context as _, Result};
 12use collections::HashMap;
 13use fs::Fs;
 14use gpui::{AppContext, AsyncAppContext, BorrowAppContext, Context, Global, Model, WeakModel};
 15use project::Project;
 16use std::{path::PathBuf, sync::Arc};
 17use ui::ViewContext;
 18use util::ResultExt as _;
 19use workspace::Workspace;
 20
 21pub use embedding::*;
 22pub use project_index::{LoadedSearchResult, ProjectIndex, SearchResult, Status};
 23pub use project_index_debug_view::ProjectIndexDebugView;
 24pub use summary_index::FileSummary;
 25
/// Holds the per-project [`ProjectIndex`]es together with the database environment and the
/// embedding provider they are created with.
///
/// Implements gpui's [`Global`] so a single instance can be installed and looked up via
/// `cx.has_global::<SemanticDb>()` / `cx.update_global::<SemanticDb, _>(..)`.
pub struct SemanticDb {
    /// Provider handed to every [`ProjectIndex`] created by this database.
    embedding_provider: Arc<dyn EmbeddingProvider>,
    /// heed (LMDB) environment; a clone is handed to every [`ProjectIndex`].
    db_connection: heed::Env,
    /// Project indices keyed by a weak handle to their project. `create_project_index` inserts
    /// entries and removes them again when the project is released.
    project_indices: HashMap<WeakModel<Project>, Model<ProjectIndex>>,
}

impl Global for SemanticDb {}
 33
 34impl SemanticDb {
 35    pub async fn new(
 36        db_path: PathBuf,
 37        embedding_provider: Arc<dyn EmbeddingProvider>,
 38        cx: &mut AsyncAppContext,
 39    ) -> Result<Self> {
 40        let db_connection = cx
 41            .background_executor()
 42            .spawn(async move {
 43                std::fs::create_dir_all(&db_path)?;
 44                unsafe {
 45                    heed::EnvOpenOptions::new()
 46                        .map_size(1024 * 1024 * 1024)
 47                        .max_dbs(3000)
 48                        .open(db_path)
 49                }
 50            })
 51            .await
 52            .context("opening database connection")?;
 53
 54        cx.update(|cx| {
 55            cx.observe_new_views(
 56                |workspace: &mut Workspace, cx: &mut ViewContext<Workspace>| {
 57                    let project = workspace.project().clone();
 58
 59                    if cx.has_global::<SemanticDb>() {
 60                        cx.update_global::<SemanticDb, _>(|this, cx| {
 61                            this.create_project_index(project, cx);
 62                        })
 63                    } else {
 64                        log::info!("No SemanticDb, skipping project index")
 65                    }
 66                },
 67            )
 68            .detach();
 69        })
 70        .ok();
 71
 72        Ok(SemanticDb {
 73            db_connection,
 74            embedding_provider,
 75            project_indices: HashMap::default(),
 76        })
 77    }
 78
 79    pub async fn load_results(
 80        results: Vec<SearchResult>,
 81        fs: &Arc<dyn Fs>,
 82        cx: &AsyncAppContext,
 83    ) -> Result<Vec<LoadedSearchResult>> {
 84        let mut loaded_results = Vec::new();
 85        for result in results {
 86            let (full_path, file_content) = result.worktree.read_with(cx, |worktree, _cx| {
 87                let entry_abs_path = worktree.abs_path().join(&result.path);
 88                let mut entry_full_path = PathBuf::from(worktree.root_name());
 89                entry_full_path.push(&result.path);
 90                let file_content = async {
 91                    let entry_abs_path = entry_abs_path;
 92                    fs.load(&entry_abs_path).await
 93                };
 94                (entry_full_path, file_content)
 95            })?;
 96            if let Some(file_content) = file_content.await.log_err() {
 97                let range_start = result.range.start.min(file_content.len());
 98                let range_end = result.range.end.min(file_content.len());
 99
100                let start_row = file_content[0..range_start].matches('\n').count() as u32;
101                let end_row = file_content[0..range_end].matches('\n').count() as u32;
102                let start_line_byte_offset = file_content[0..range_start]
103                    .rfind('\n')
104                    .map(|pos| pos + 1)
105                    .unwrap_or_default();
106                let end_line_byte_offset = file_content[range_end..]
107                    .find('\n')
108                    .map(|pos| range_end + pos)
109                    .unwrap_or_else(|| file_content.len());
110
111                loaded_results.push(LoadedSearchResult {
112                    path: result.path,
113                    range: start_line_byte_offset..end_line_byte_offset,
114                    full_path,
115                    file_content,
116                    row_range: start_row..=end_row,
117                });
118            }
119        }
120        Ok(loaded_results)
121    }
122
123    pub fn project_index(
124        &mut self,
125        project: Model<Project>,
126        _cx: &mut AppContext,
127    ) -> Option<Model<ProjectIndex>> {
128        self.project_indices.get(&project.downgrade()).cloned()
129    }
130
131    pub fn remaining_summaries(
132        &self,
133        project: &WeakModel<Project>,
134        cx: &mut AppContext,
135    ) -> Option<usize> {
136        self.project_indices.get(project).map(|project_index| {
137            project_index.update(cx, |project_index, cx| {
138                project_index.remaining_summaries(cx)
139            })
140        })
141    }
142
143    pub fn create_project_index(
144        &mut self,
145        project: Model<Project>,
146        cx: &mut AppContext,
147    ) -> Model<ProjectIndex> {
148        let project_index = cx.new_model(|cx| {
149            ProjectIndex::new(
150                project.clone(),
151                self.db_connection.clone(),
152                self.embedding_provider.clone(),
153                cx,
154            )
155        });
156
157        let project_weak = project.downgrade();
158        self.project_indices
159            .insert(project_weak.clone(), project_index.clone());
160
161        cx.observe_release(&project, move |_, cx| {
162            if cx.has_global::<SemanticDb>() {
163                cx.update_global::<SemanticDb, _>(|this, _| {
164                    this.project_indices.remove(&project_weak);
165                })
166            }
167        })
168        .detach();
169
170        project_index
171    }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use chunking::Chunk;
    use embedding_index::{ChunkedFile, EmbeddingIndex};
    use feature_flags::FeatureFlagAppExt;
    use fs::FakeFs;
    use futures::{future::BoxFuture, FutureExt};
    use gpui::TestAppContext;
    use indexing::IndexingEntrySet;
    use language::language_settings::AllLanguageSettings;
    use project::{Project, ProjectEntryId};
    use serde_json::json;
    use settings::SettingsStore;
    use smol::{channel, stream::StreamExt};
    use std::{future, path::Path, sync::Arc};

    /// Installs the settings store, language support, feature flags and project settings
    /// that the tests below rely on.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();

        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            language::init(cx);
            cx.update_flags(false, vec![]);
            Project::init_settings(cx);
            SettingsStore::update(cx, |store, cx| {
                store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
            });
        });
    }

    /// Embedding provider that computes each embedding synchronously with a user-supplied
    /// function and reports a fixed batch size.
    pub struct TestEmbeddingProvider {
        batch_size: usize,
        compute_embedding: Box<dyn Fn(&str) -> Result<Embedding> + Send + Sync>,
    }

    impl TestEmbeddingProvider {
        pub fn new(
            batch_size: usize,
            compute_embedding: impl 'static + Fn(&str) -> Result<Embedding> + Send + Sync,
        ) -> Self {
            Self {
                batch_size,
                compute_embedding: Box::new(compute_embedding),
            }
        }
    }

    impl EmbeddingProvider for TestEmbeddingProvider {
        /// Embeds every text with `compute_embedding`; the first error fails the whole batch.
        fn embed<'a>(
            &'a self,
            texts: &'a [TextToEmbed<'a>],
        ) -> BoxFuture<'a, Result<Vec<Embedding>>> {
            let embeddings = texts
                .iter()
                .map(|to_embed| (self.compute_embedding)(to_embed.text))
                .collect();
            future::ready(embeddings).boxed()
        }

        fn batch_size(&self) -> usize {
            self.batch_size
        }
    }

    /// Indexes a small fixture project and checks that searching for
    /// "garbage in, garbage out" finds `fixture/needle.md` with a score above 0.9.
    #[gpui::test]
    async fn test_search(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        init_test(cx);

        let temp_dir = tempfile::tempdir().unwrap();

        let mut semantic_index = SemanticDb::new(
            temp_dir.path().into(),
            Arc::new(TestEmbeddingProvider::new(16, |text| {
                // Two-dimensional fake embedding: dimension 0 is 0.9 when the text mentions
                // "garbage in" (-0.9 otherwise), dimension 1 is 0.9 when it mentions
                // "garbage out" (-0.9 otherwise). Text containing both phrases therefore gets
                // the same embedding as the query used below.
                let mut embedding = vec![0f32; 2];
                if text.contains("garbage in") {
                    embedding[0] = 0.9;
                } else {
                    embedding[0] = -0.9;
                }

                if text.contains("garbage out") {
                    embedding[1] = 0.9;
                } else {
                    embedding[1] = -0.9;
                }

                Ok(Embedding::new(embedding))
            })),
            &mut cx.to_async(),
        )
        .await
        .unwrap();

        let fs = FakeFs::new(cx.executor());
        let project_path = Path::new("/fake_project");

        fs.insert_tree(
            project_path,
            json!({
                "fixture": {
                    "main.rs": include_str!("../fixture/main.rs"),
                    "needle.md": include_str!("../fixture/needle.md"),
                }
            }),
        )
        .await;

        let project = Project::test(fs, [project_path], cx).await;

        let project_index = cx.update(|cx| {
            let language_registry = project.read(cx).languages().clone();
            let node_runtime = project.read(cx).node_runtime().unwrap().clone();
            languages::init(language_registry, node_runtime, cx);
            semantic_index.create_project_index(project.clone(), cx)
        });

        // Keep driving the executor until the index reports no remaining summaries.
        cx.run_until_parked();
        while cx
            .update(|cx| semantic_index.remaining_summaries(&project.downgrade(), cx))
            .unwrap()
            > 0
        {
            cx.run_until_parked();
        }

        let results = cx
            .update(|cx| {
                let project_index = project_index.read(cx);
                let query = "garbage in, garbage out";
                project_index.search(query.into(), 4, cx)
            })
            .await
            .unwrap();

        assert!(
            results.len() > 1,
            "should have found some results, but only found {:?}",
            results
        );

        for result in &results {
            println!("result: {:?}", result.path);
            println!("score: {:?}", result.score);
        }

        // Find a result scoring above 0.9; it is expected to be the needle file.
        let search_result = results.iter().find(|result| result.score > 0.9).unwrap();

        assert_eq!(search_result.path.to_string_lossy(), "fixture/needle.md");

        let content = cx
            .update(|cx| {
                let worktree = search_result.worktree.read(cx);
                let entry_abs_path = worktree.abs_path().join(&search_result.path);
                let fs = project.read(cx).fs().clone();
                cx.background_executor()
                    .spawn(async move { fs.load(&entry_abs_path).await.unwrap() })
            })
            .await;

        let content = &content[search_result.range.clone()];

        assert!(content.contains("garbage in, garbage out"));
    }

    /// Feeds two chunked files through `EmbeddingIndex::embed_files` with a provider that
    /// rejects any text containing 'g'. `test1.md` has a chunk ("efgh") containing 'g', so
    /// only `test2.md` is expected to come out, with one embedding per chunk.
    #[gpui::test]
    async fn test_embed_files(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        let provider = Arc::new(TestEmbeddingProvider::new(3, |text| {
            if text.contains('g') {
                Err(anyhow!("cannot embed text containing a 'g' character"))
            } else {
                Ok(Embedding::new(
                    ('a'..='z')
                        .map(|char| text.chars().filter(|c| *c == char).count() as f32)
                        .collect(),
                ))
            }
        }));

        let (indexing_progress_tx, _) = channel::unbounded();
        let indexing_entries = Arc::new(IndexingEntrySet::new(indexing_progress_tx));

        let (chunked_files_tx, chunked_files_rx) = channel::unbounded::<ChunkedFile>();
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test1.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(0)),
                text: "abcdefghijklmnop".to_string(),
                chunks: [0..4, 4..8, 8..12, 12..16]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test2.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(1)),
                text: "qrstuvwxyz".to_string(),
                chunks: [0..4, 4..8, 8..10]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        chunked_files_tx.close();

        let embed_files_task =
            cx.update(|cx| EmbeddingIndex::embed_files(provider.clone(), chunked_files_rx, cx));
        embed_files_task.task.await.unwrap();

        let mut embedded_files_rx = embed_files_task.files;
        let mut embedded_files = Vec::new();
        while let Some((embedded_file, _)) = embedded_files_rx.next().await {
            embedded_files.push(embedded_file);
        }

        assert_eq!(embedded_files.len(), 1);
        assert_eq!(embedded_files[0].path.as_ref(), Path::new("test2.md"));
        assert_eq!(
            embedded_files[0]
                .chunks
                .iter()
                .map(|embedded_chunk| { embedded_chunk.embedding.clone() })
                .collect::<Vec<Embedding>>(),
            vec![
                (provider.compute_embedding)("qrst").unwrap(),
                (provider.compute_embedding)("uvwx").unwrap(),
                (provider.compute_embedding)("yz").unwrap(),
            ],
        );
    }
}