// semantic_index.rs

  1mod chunking;
  2mod embedding;
  3mod embedding_index;
  4mod indexing;
  5mod project_index;
  6mod project_index_debug_view;
  7mod summary_backlog;
  8mod summary_index;
  9mod worktree_index;
 10
 11use anyhow::{Context as _, Result};
 12use collections::HashMap;
 13use gpui::{AppContext, AsyncAppContext, BorrowAppContext, Context, Global, Model, WeakModel};
 14use project::Project;
 15use project_index::ProjectIndex;
 16use std::{path::PathBuf, sync::Arc};
 17use ui::ViewContext;
 18use workspace::Workspace;
 19
 20pub use embedding::*;
 21pub use project_index_debug_view::ProjectIndexDebugView;
 22pub use summary_index::FileSummary;
 23
/// Root state for semantic indexing: the embedding backend, the on-disk
/// database handle, and one in-memory index per open project.
pub struct SemanticDb {
    // Backend used to compute embeddings for indexed text (see
    // `ProjectIndex::new`, which receives a clone of it).
    embedding_provider: Arc<dyn EmbeddingProvider>,
    // Handle to the heed environment backing all per-project databases;
    // cloned into each `ProjectIndex`.
    db_connection: heed::Env,
    // One `ProjectIndex` per project, keyed by a weak handle so entries can
    // be removed when the owning workspace is released.
    project_indices: HashMap<WeakModel<Project>, Model<ProjectIndex>>,
}

// Installed as a GPUI global — `SemanticDb::new` registers an observer that
// looks it up via `cx.has_global::<SemanticDb>()`.
impl Global for SemanticDb {}
 31
 32impl SemanticDb {
 33    pub async fn new(
 34        db_path: PathBuf,
 35        embedding_provider: Arc<dyn EmbeddingProvider>,
 36        cx: &mut AsyncAppContext,
 37    ) -> Result<Self> {
 38        let db_connection = cx
 39            .background_executor()
 40            .spawn(async move {
 41                std::fs::create_dir_all(&db_path)?;
 42                unsafe {
 43                    heed::EnvOpenOptions::new()
 44                        .map_size(1024 * 1024 * 1024)
 45                        .max_dbs(3000)
 46                        .open(db_path)
 47                }
 48            })
 49            .await
 50            .context("opening database connection")?;
 51
 52        cx.update(|cx| {
 53            cx.observe_new_views(
 54                |workspace: &mut Workspace, cx: &mut ViewContext<Workspace>| {
 55                    let project = workspace.project().clone();
 56
 57                    if cx.has_global::<SemanticDb>() {
 58                        cx.update_global::<SemanticDb, _>(|this, cx| {
 59                            let project_index = cx.new_model(|cx| {
 60                                ProjectIndex::new(
 61                                    project.clone(),
 62                                    this.db_connection.clone(),
 63                                    this.embedding_provider.clone(),
 64                                    cx,
 65                                )
 66                            });
 67
 68                            let project_weak = project.downgrade();
 69                            this.project_indices
 70                                .insert(project_weak.clone(), project_index);
 71
 72                            cx.on_release(move |_, _, cx| {
 73                                if cx.has_global::<SemanticDb>() {
 74                                    cx.update_global::<SemanticDb, _>(|this, _| {
 75                                        this.project_indices.remove(&project_weak);
 76                                    })
 77                                }
 78                            })
 79                            .detach();
 80                        })
 81                    } else {
 82                        log::info!("No SemanticDb, skipping project index")
 83                    }
 84                },
 85            )
 86            .detach();
 87        })
 88        .ok();
 89
 90        Ok(SemanticDb {
 91            db_connection,
 92            embedding_provider,
 93            project_indices: HashMap::default(),
 94        })
 95    }
 96
 97    pub fn project_index(
 98        &mut self,
 99        project: Model<Project>,
100        _cx: &mut AppContext,
101    ) -> Option<Model<ProjectIndex>> {
102        self.project_indices.get(&project.downgrade()).cloned()
103    }
104
105    pub fn remaining_summaries(
106        &self,
107        project: &WeakModel<Project>,
108        cx: &mut AppContext,
109    ) -> Option<usize> {
110        self.project_indices.get(project).map(|project_index| {
111            project_index.update(cx, |project_index, cx| {
112                project_index.remaining_summaries(cx)
113            })
114        })
115    }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use chunking::Chunk;
    use embedding_index::{ChunkedFile, EmbeddingIndex};
    use feature_flags::FeatureFlagAppExt;
    use fs::FakeFs;
    use futures::{future::BoxFuture, FutureExt};
    use gpui::TestAppContext;
    use indexing::IndexingEntrySet;
    use language::language_settings::AllLanguageSettings;
    use project::{Project, ProjectEntryId};
    use serde_json::json;
    use settings::SettingsStore;
    use smol::{channel, stream::StreamExt};
    use std::{future, path::Path, sync::Arc};

    /// Installs the globals (settings store, language support, feature
    /// flags, project settings) the tests below depend on.
    fn init_test(cx: &mut TestAppContext) {
        // `try_init` so repeated test invocations don't panic on double init.
        env_logger::try_init().ok();

        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            language::init(cx);
            cx.update_flags(false, vec![]);
            Project::init_settings(cx);
            SettingsStore::update(cx, |store, cx| {
                store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
            });
        });
    }

    /// Deterministic in-process embedding provider for tests: embeddings are
    /// produced by an arbitrary closure instead of a remote model.
    pub struct TestEmbeddingProvider {
        batch_size: usize,
        // Maps chunk text to its (possibly failing) embedding.
        compute_embedding: Box<dyn Fn(&str) -> Result<Embedding> + Send + Sync>,
    }

    impl TestEmbeddingProvider {
        /// Creates a provider that reports `batch_size` and embeds each text
        /// via `compute_embedding`.
        pub fn new(
            batch_size: usize,
            compute_embedding: impl 'static + Fn(&str) -> Result<Embedding> + Send + Sync,
        ) -> Self {
            Self {
                batch_size,
                compute_embedding: Box::new(compute_embedding),
            }
        }
    }

    impl EmbeddingProvider for TestEmbeddingProvider {
        fn embed<'a>(
            &'a self,
            texts: &'a [TextToEmbed<'a>],
        ) -> BoxFuture<'a, Result<Vec<Embedding>>> {
            // Collecting into `Result<Vec<_>>` short-circuits on the first
            // text whose embedding closure fails.
            let embeddings = texts
                .iter()
                .map(|to_embed| (self.compute_embedding)(to_embed.text))
                .collect();
            future::ready(embeddings).boxed()
        }

        fn batch_size(&self) -> usize {
            self.batch_size
        }
    }

    /// End-to-end test: index a fake project, search for
    /// "garbage in, garbage out", and verify `needle.md` is the top-scoring
    /// result and that the returned byte range contains the phrase.
    #[gpui::test]
    async fn test_search(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        init_test(cx);

        let temp_dir = tempfile::tempdir().unwrap();

        let mut semantic_index = SemanticDb::new(
            temp_dir.path().into(),
            Arc::new(TestEmbeddingProvider::new(16, |text| {
                let mut embedding = vec![0f32; 2];
                // If the text contains "garbage in", score the first
                // dimension strongly positive (0.9); otherwise strongly
                // negative (-0.9).
                if text.contains("garbage in") {
                    embedding[0] = 0.9;
                } else {
                    embedding[0] = -0.9;
                }

                // Likewise for "garbage out" in the second dimension.
                if text.contains("garbage out") {
                    embedding[1] = 0.9;
                } else {
                    embedding[1] = -0.9;
                }

                Ok(Embedding::new(embedding))
            })),
            &mut cx.to_async(),
        )
        .await
        .unwrap();

        let fs = FakeFs::new(cx.executor());
        let project_path = Path::new("/fake_project");

        fs.insert_tree(
            project_path,
            json!({
                "fixture": {
                    "main.rs": include_str!("../fixture/main.rs"),
                    "needle.md": include_str!("../fixture/needle.md"),
                }
            }),
        )
        .await;

        let project = Project::test(fs, [project_path], cx).await;

        cx.update(|cx| {
            let language_registry = project.read(cx).languages().clone();
            let node_runtime = project.read(cx).node_runtime().unwrap().clone();
            languages::init(language_registry, node_runtime, cx);

            // Manually create and insert the ProjectIndex — the production
            // path registers one from the workspace observer in
            // `SemanticDb::new`, which this test doesn't set up.
            let project_index = cx.new_model(|cx| {
                ProjectIndex::new(
                    project.clone(),
                    semantic_index.db_connection.clone(),
                    semantic_index.embedding_provider.clone(),
                    cx,
                )
            });
            semantic_index
                .project_indices
                .insert(project.downgrade(), project_index);
        });

        let project_index = cx
            .update(|_cx| {
                semantic_index
                    .project_indices
                    .get(&project.downgrade())
                    .cloned()
            })
            .unwrap();

        // Drive the executor until indexing reports no remaining summaries.
        cx.run_until_parked();
        while cx
            .update(|cx| semantic_index.remaining_summaries(&project.downgrade(), cx))
            .unwrap()
            > 0
        {
            cx.run_until_parked();
        }

        let results = cx
            .update(|cx| {
                let project_index = project_index.read(cx);
                let query = "garbage in, garbage out";
                project_index.search(query.into(), 4, cx)
            })
            .await
            .unwrap();

        assert!(
            results.len() > 1,
            "should have found some results, but only found {:?}",
            results
        );

        for result in &results {
            println!("result: {:?}", result.path);
            println!("score: {:?}", result.score);
        }

        // Find the result whose score exceeds 0.9 (expected to be the
        // needle file; the original comment said 0.5, but the code checks
        // 0.9).
        let search_result = results.iter().find(|result| result.score > 0.9).unwrap();

        assert_eq!(search_result.path.to_string_lossy(), "fixture/needle.md");

        // Load the matched file from the fake fs and check that the
        // returned byte range actually covers the query phrase.
        let content = cx
            .update(|cx| {
                let worktree = search_result.worktree.read(cx);
                let entry_abs_path = worktree.abs_path().join(&search_result.path);
                let fs = project.read(cx).fs().clone();
                cx.background_executor()
                    .spawn(async move { fs.load(&entry_abs_path).await.unwrap() })
            })
            .await;

        let range = search_result.range.clone();
        let content = content[range.clone()].to_owned();

        assert!(content.contains("garbage in, garbage out"));
    }

    /// Verifies `EmbeddingIndex::embed_files` behavior when a chunk fails to
    /// embed: the file containing the failing chunk is dropped entirely,
    /// while a file whose chunks all succeed comes through with the expected
    /// embeddings.
    #[gpui::test]
    async fn test_embed_files(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        // Provider errors on any text containing 'g'; otherwise it embeds a
        // text as its 26 per-letter character counts ('a'..='z').
        let provider = Arc::new(TestEmbeddingProvider::new(3, |text| {
            if text.contains('g') {
                Err(anyhow!("cannot embed text containing a 'g' character"))
            } else {
                Ok(Embedding::new(
                    ('a'..='z')
                        .map(|char| text.chars().filter(|c| *c == char).count() as f32)
                        .collect(),
                ))
            }
        }));

        // Progress receiver is unused; only the entry set handles matter.
        let (indexing_progress_tx, _) = channel::unbounded();
        let indexing_entries = Arc::new(IndexingEntrySet::new(indexing_progress_tx));

        let (chunked_files_tx, chunked_files_rx) = channel::unbounded::<ChunkedFile>();
        // First file: its "efgh" chunk contains 'g', so embedding fails and
        // the whole file should be dropped.
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test1.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(0)),
                text: "abcdefghijklmnop".to_string(),
                chunks: [0..4, 4..8, 8..12, 12..16]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        // Second file: no 'g' anywhere, so all three chunks embed cleanly.
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test2.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(1)),
                text: "qrstuvwxyz".to_string(),
                chunks: [0..4, 4..8, 8..10]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        // Close the channel so the embedding task knows input is exhausted.
        chunked_files_tx.close();

        let embed_files_task =
            cx.update(|cx| EmbeddingIndex::embed_files(provider.clone(), chunked_files_rx, cx));
        embed_files_task.task.await.unwrap();

        // Drain the output channel; only the second file should appear.
        let mut embedded_files_rx = embed_files_task.files;
        let mut embedded_files = Vec::new();
        while let Some((embedded_file, _)) = embedded_files_rx.next().await {
            embedded_files.push(embedded_file);
        }

        assert_eq!(embedded_files.len(), 1);
        assert_eq!(embedded_files[0].path.as_ref(), Path::new("test2.md"));
        // Each chunk's embedding must match what the provider computes for
        // that chunk's text directly.
        assert_eq!(
            embedded_files[0]
                .chunks
                .iter()
                .map(|embedded_chunk| { embedded_chunk.embedding.clone() })
                .collect::<Vec<Embedding>>(),
            vec![
                (provider.compute_embedding)("qrst").unwrap(),
                (provider.compute_embedding)("uvwx").unwrap(),
                (provider.compute_embedding)("yz").unwrap(),
            ],
        );
    }
}