//! semantic_index.rs — semantic indexing for project search: chunking, embedding, and per-project indices.

  1mod chunking;
  2mod embedding;
  3mod embedding_index;
  4mod indexing;
  5mod project_index;
  6mod project_index_debug_view;
  7mod summary_backlog;
  8mod summary_index;
  9mod worktree_index;
 10
 11use anyhow::{Context as _, Result};
 12use collections::HashMap;
 13use fs::Fs;
 14use gpui::{AppContext, AsyncAppContext, BorrowAppContext, Context, Global, Model, WeakModel};
 15use project::Project;
 16use std::{path::PathBuf, sync::Arc};
 17use ui::ViewContext;
 18use util::ResultExt as _;
 19use workspace::Workspace;
 20
 21pub use embedding::*;
 22pub use project_index::{LoadedSearchResult, ProjectIndex, SearchResult, Status};
 23pub use project_index_debug_view::ProjectIndexDebugView;
 24pub use summary_index::FileSummary;
 25
/// Global that owns the semantic-search database connection and one
/// [`ProjectIndex`] per open project.
pub struct SemanticDb {
    // Provider used to compute embeddings for indexed content.
    embedding_provider: Arc<dyn EmbeddingProvider>,
    // LMDB environment (via heed). Held in an `Option` so `Drop` can `take`
    // it and call `prepare_for_closing`; it is `Some` for the entire
    // lifetime of the struct otherwise.
    db_connection: Option<heed::Env>,
    // One index per project, keyed by a weak handle so entries can be
    // removed when the project model is released.
    project_indices: HashMap<WeakModel<Project>, Model<ProjectIndex>>,
}

impl Global for SemanticDb {}
 33
 34impl SemanticDb {
    /// Opens (creating if needed) the database at `db_path` and registers an
    /// observer that creates a project index for each newly-opened workspace.
    ///
    /// # Errors
    /// Fails if the database directory cannot be created or the heed
    /// environment cannot be opened.
    pub async fn new(
        db_path: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut AsyncAppContext,
    ) -> Result<Self> {
        // Open the LMDB environment on the background executor so the
        // blocking filesystem work stays off the foreground thread.
        let db_connection = cx
            .background_executor()
            .spawn(async move {
                std::fs::create_dir_all(&db_path)?;
                // SAFETY: heed's `open` is unsafe because mapping a database
                // that another process mutates is undefined behavior; this
                // process is assumed to own the database directory.
                unsafe {
                    heed::EnvOpenOptions::new()
                        .map_size(1024 * 1024 * 1024)
                        .max_dbs(3000)
                        .open(db_path)
                }
            })
            .await
            .context("opening database connection")?;

        // Best-effort (`.ok()`): if the app context is already gone we still
        // return a usable `SemanticDb`, just without the workspace observer.
        cx.update(|cx| {
            cx.observe_new_views(
                |workspace: &mut Workspace, cx: &mut ViewContext<Workspace>| {
                    let project = workspace.project().clone();

                    // The global may not exist (e.g. during shutdown or in
                    // tests that never registered it).
                    if cx.has_global::<SemanticDb>() {
                        cx.update_global::<SemanticDb, _>(|this, cx| {
                            this.create_project_index(project, cx);
                        })
                    } else {
                        log::info!("No SemanticDb, skipping project index")
                    }
                },
            )
            .detach();
        })
        .ok();

        Ok(SemanticDb {
            db_connection: Some(db_connection),
            embedding_provider,
            project_indices: HashMap::default(),
        })
    }
 78
    /// Loads the file contents for each search result and widens each
    /// result's byte range to whole-line boundaries.
    ///
    /// Results whose file fails to load are logged (`log_err`) and skipped
    /// rather than failing the whole batch.
    pub async fn load_results(
        results: Vec<SearchResult>,
        fs: &Arc<dyn Fs>,
        cx: &AsyncAppContext,
    ) -> Result<Vec<LoadedSearchResult>> {
        let mut loaded_results = Vec::new();
        for result in results {
            let (full_path, file_content) = result.worktree.read_with(cx, |worktree, _cx| {
                let entry_abs_path = worktree.abs_path().join(&result.path);
                // Worktree-relative path prefixed with the worktree's root
                // name, used for display.
                let mut entry_full_path = PathBuf::from(worktree.root_name());
                entry_full_path.push(&result.path);
                // Construct the load future inside the read closure but
                // await it outside, after the worktree lock is released.
                let file_content = async {
                    let entry_abs_path = entry_abs_path;
                    fs.load(&entry_abs_path).await
                };
                (entry_full_path, file_content)
            })?;
            if let Some(file_content) = file_content.await.log_err() {
                // Clamp the stored range in case the file shrank since it
                // was indexed.
                let range_start = result.range.start.min(file_content.len());
                let range_end = result.range.end.min(file_content.len());

                // 0-based rows containing the start and end of the range.
                let start_row = file_content[0..range_start].matches('\n').count() as u32;
                let end_row = file_content[0..range_end].matches('\n').count() as u32;
                // Snap the byte range back to the start of the first line…
                let start_line_byte_offset = file_content[0..range_start]
                    .rfind('\n')
                    .map(|pos| pos + 1)
                    .unwrap_or_default();
                // …and forward to the end of the last line (or EOF).
                let end_line_byte_offset = file_content[range_end..]
                    .find('\n')
                    .map(|pos| range_end + pos)
                    .unwrap_or_else(|| file_content.len());

                loaded_results.push(LoadedSearchResult {
                    path: result.path,
                    range: start_line_byte_offset..end_line_byte_offset,
                    full_path,
                    file_content,
                    row_range: start_row..=end_row,
                });
            }
        }
        Ok(loaded_results)
    }
122
123    pub fn project_index(
124        &mut self,
125        project: Model<Project>,
126        _cx: &mut AppContext,
127    ) -> Option<Model<ProjectIndex>> {
128        self.project_indices.get(&project.downgrade()).cloned()
129    }
130
131    pub fn remaining_summaries(
132        &self,
133        project: &WeakModel<Project>,
134        cx: &mut AppContext,
135    ) -> Option<usize> {
136        self.project_indices.get(project).map(|project_index| {
137            project_index.update(cx, |project_index, cx| {
138                project_index.remaining_summaries(cx)
139            })
140        })
141    }
142
    /// Creates and registers a new index for `project`, replacing any
    /// existing entry for the same project.
    ///
    /// The registry entry is removed automatically when the project model is
    /// released.
    pub fn create_project_index(
        &mut self,
        project: Model<Project>,
        cx: &mut AppContext,
    ) -> Model<ProjectIndex> {
        let project_index = cx.new_model(|cx| {
            ProjectIndex::new(
                project.clone(),
                // Invariant: `db_connection` is `Some` until `Drop` runs.
                self.db_connection.clone().unwrap(),
                self.embedding_provider.clone(),
                cx,
            )
        });

        let project_weak = project.downgrade();
        self.project_indices
            .insert(project_weak.clone(), project_index.clone());

        // Drop the registry entry once the project goes away; the global
        // itself may already be gone (e.g. on shutdown), so check first.
        cx.observe_release(&project, move |_, cx| {
            if cx.has_global::<SemanticDb>() {
                cx.update_global::<SemanticDb, _>(|this, _| {
                    this.project_indices.remove(&project_weak);
                })
            }
        })
        .detach();

        project_index
    }
172}
173
174impl Drop for SemanticDb {
175    fn drop(&mut self) {
176        self.db_connection.take().unwrap().prepare_for_closing();
177    }
178}
179
180#[cfg(test)]
181mod tests {
182    use super::*;
183    use anyhow::anyhow;
184    use chunking::Chunk;
185    use embedding_index::{ChunkedFile, EmbeddingIndex};
186    use feature_flags::FeatureFlagAppExt;
187    use fs::FakeFs;
188    use futures::{future::BoxFuture, FutureExt};
189    use gpui::TestAppContext;
190    use indexing::IndexingEntrySet;
191    use language::language_settings::AllLanguageSettings;
192    use project::{Project, ProjectEntryId};
193    use serde_json::json;
194    use settings::SettingsStore;
195    use smol::{channel, stream::StreamExt};
196    use std::{future, path::Path, sync::Arc};
197
    /// Installs the minimal global state (settings store, languages, feature
    /// flags, project settings) that `SemanticDb` and `Project` need in tests.
    fn init_test(cx: &mut TestAppContext) {
        // Ignore the error if another test already installed a logger.
        env_logger::try_init().ok();

        cx.update(|cx| {
            // The settings store must be registered before the other `init`
            // calls, which read settings from it.
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            language::init(cx);
            cx.update_flags(false, vec![]);
            Project::init_settings(cx);
            SettingsStore::update(cx, |store, cx| {
                store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
            });
        });
    }
212
    /// Deterministic embedding provider for tests: delegates to a
    /// caller-supplied closure instead of calling a real model.
    pub struct TestEmbeddingProvider {
        // Maximum number of texts the provider reports per batch.
        batch_size: usize,
        // Maps a chunk of text to its embedding (or an error).
        compute_embedding: Box<dyn Fn(&str) -> Result<Embedding> + Send + Sync>,
    }

    impl TestEmbeddingProvider {
        /// Creates a provider that embeds with `compute_embedding`, batching
        /// at most `batch_size` texts per call.
        pub fn new(
            batch_size: usize,
            compute_embedding: impl 'static + Fn(&str) -> Result<Embedding> + Send + Sync,
        ) -> Self {
            Self {
                batch_size,
                compute_embedding: Box::new(compute_embedding),
            }
        }
    }
229
230    impl EmbeddingProvider for TestEmbeddingProvider {
231        fn embed<'a>(
232            &'a self,
233            texts: &'a [TextToEmbed<'a>],
234        ) -> BoxFuture<'a, Result<Vec<Embedding>>> {
235            let embeddings = texts
236                .iter()
237                .map(|to_embed| (self.compute_embedding)(to_embed.text))
238                .collect();
239            future::ready(embeddings).boxed()
240        }
241
242        fn batch_size(&self) -> usize {
243            self.batch_size
244        }
245    }
246
    #[gpui::test]
    async fn test_search(cx: &mut TestAppContext) {
        // The real LMDB database does blocking I/O on the background pool.
        cx.executor().allow_parking();

        init_test(cx);

        let temp_dir = tempfile::tempdir().unwrap();

        // Two-dimensional embeddings: dimension 0 encodes "garbage in",
        // dimension 1 encodes "garbage out".
        let mut semantic_index = SemanticDb::new(
            temp_dir.path().into(),
            Arc::new(TestEmbeddingProvider::new(16, |text| {
                let mut embedding = vec![0f32; 2];
                // if the text contains "garbage in", give it 0.9 in the
                // first dimension; otherwise -0.9
                if text.contains("garbage in") {
                    embedding[0] = 0.9;
                } else {
                    embedding[0] = -0.9;
                }

                if text.contains("garbage out") {
                    embedding[1] = 0.9;
                } else {
                    embedding[1] = -0.9;
                }

                Ok(Embedding::new(embedding))
            })),
            &mut cx.to_async(),
        )
        .await
        .unwrap();

        let fs = FakeFs::new(cx.executor());
        let project_path = Path::new("/fake_project");

        fs.insert_tree(
            project_path,
            json!({
                "fixture": {
                    "main.rs": include_str!("../fixture/main.rs"),
                    "needle.md": include_str!("../fixture/needle.md"),
                }
            }),
        )
        .await;

        let project = Project::test(fs, [project_path], cx).await;

        let project_index = cx.update(|cx| {
            let language_registry = project.read(cx).languages().clone();
            let node_runtime = project.read(cx).node_runtime().unwrap().clone();
            languages::init(language_registry, node_runtime, cx);
            semantic_index.create_project_index(project.clone(), cx)
        });

        // Pump the executor until indexing and summarization have settled.
        cx.run_until_parked();
        while cx
            .update(|cx| semantic_index.remaining_summaries(&project.downgrade(), cx))
            .unwrap()
            > 0
        {
            cx.run_until_parked();
        }

        let results = cx
            .update(|cx| {
                let project_index = project_index.read(cx);
                let query = "garbage in, garbage out";
                project_index.search(query.into(), 4, cx)
            })
            .await
            .unwrap();

        assert!(
            results.len() > 1,
            "should have found some results, but only found {:?}",
            results
        );

        for result in &results {
            println!("result: {:?}", result.path);
            println!("score: {:?}", result.score);
        }

        // Find a result whose score exceeds 0.9 — only the needle file
        // matches the query on both embedding dimensions.
        let search_result = results.iter().find(|result| result.score > 0.9).unwrap();

        assert_eq!(search_result.path.to_string_lossy(), "fixture/needle.md");

        // Load the matched file off the foreground thread and verify the
        // returned byte range actually contains the query text.
        let content = cx
            .update(|cx| {
                let worktree = search_result.worktree.read(cx);
                let entry_abs_path = worktree.abs_path().join(&search_result.path);
                let fs = project.read(cx).fs().clone();
                cx.background_executor()
                    .spawn(async move { fs.load(&entry_abs_path).await.unwrap() })
            })
            .await;

        let range = search_result.range.clone();
        let content = content[range.clone()].to_owned();

        assert!(content.contains("garbage in, garbage out"));
    }
351
    #[gpui::test]
    async fn test_embed_files(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        // Provider that fails on any text containing 'g', so the first file
        // below ("abcdefghijklmnop") cannot be embedded while the second can.
        let provider = Arc::new(TestEmbeddingProvider::new(3, |text| {
            if text.contains('g') {
                Err(anyhow!("cannot embed text containing a 'g' character"))
            } else {
                // Embedding = per-letter frequency histogram over 'a'..='z'.
                Ok(Embedding::new(
                    ('a'..='z')
                        .map(|char| text.chars().filter(|c| *c == char).count() as f32)
                        .collect(),
                ))
            }
        }));

        // Progress receiver is dropped; this test only checks the output.
        let (indexing_progress_tx, _) = channel::unbounded();
        let indexing_entries = Arc::new(IndexingEntrySet::new(indexing_progress_tx));

        let (chunked_files_tx, chunked_files_rx) = channel::unbounded::<ChunkedFile>();
        // File 1: contains 'g', so the provider rejects one of its chunks
        // and the whole file should be dropped from the output.
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test1.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(0)),
                text: "abcdefghijklmnop".to_string(),
                chunks: [0..4, 4..8, 8..12, 12..16]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        // File 2: no 'g', so all three chunks embed successfully.
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test2.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(1)),
                text: "qrstuvwxyz".to_string(),
                chunks: [0..4, 4..8, 8..10]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        // Close the channel so the embedding task terminates after draining.
        chunked_files_tx.close();

        let embed_files_task =
            cx.update(|cx| EmbeddingIndex::embed_files(provider.clone(), chunked_files_rx, cx));
        embed_files_task.task.await.unwrap();

        let mut embedded_files_rx = embed_files_task.files;
        let mut embedded_files = Vec::new();
        while let Some((embedded_file, _)) = embedded_files_rx.next().await {
            embedded_files.push(embedded_file);
        }

        // Only the second file survives, with one embedding per chunk.
        assert_eq!(embedded_files.len(), 1);
        assert_eq!(embedded_files[0].path.as_ref(), Path::new("test2.md"));
        assert_eq!(
            embedded_files[0]
                .chunks
                .iter()
                .map(|embedded_chunk| { embedded_chunk.embedding.clone() })
                .collect::<Vec<Embedding>>(),
            vec![
                (provider.compute_embedding)("qrst").unwrap(),
                (provider.compute_embedding)("uvwx").unwrap(),
                (provider.compute_embedding)("yz").unwrap(),
            ],
        );
    }
429}