1mod chunking;
2mod embedding;
3mod embedding_index;
4mod indexing;
5mod project_index;
6mod project_index_debug_view;
7mod summary_backlog;
8mod summary_index;
9mod worktree_index;
10
11use anyhow::{Context as _, Result};
12use collections::HashMap;
13use gpui::{AppContext, AsyncAppContext, BorrowAppContext, Context, Global, Model, WeakModel};
14use project::Project;
15use project_index::ProjectIndex;
16use std::{path::PathBuf, sync::Arc};
17use ui::ViewContext;
18use workspace::Workspace;
19
20pub use embedding::*;
21pub use project_index_debug_view::ProjectIndexDebugView;
22pub use summary_index::FileSummary;
23
/// Global owner of the semantic search state: the embedding backend, the
/// on-disk `heed` database environment, and one index per open project.
pub struct SemanticDb {
    /// Backend used to compute embeddings for indexed content.
    embedding_provider: Arc<dyn EmbeddingProvider>,
    /// Handle to the `heed` database environment backing all indices.
    db_connection: heed::Env,
    /// One `ProjectIndex` per open project, keyed by a weak project handle
    /// so a closed project does not keep its index alive.
    project_indices: HashMap<WeakModel<Project>, Model<ProjectIndex>>,
}
29
// Marker impl: registers `SemanticDb` as a GPUI global so it can be read and
// updated through the app context (`cx.has_global`, `cx.update_global`).
impl Global for SemanticDb {}
31
32impl SemanticDb {
33 pub async fn new(
34 db_path: PathBuf,
35 embedding_provider: Arc<dyn EmbeddingProvider>,
36 cx: &mut AsyncAppContext,
37 ) -> Result<Self> {
38 let db_connection = cx
39 .background_executor()
40 .spawn(async move {
41 std::fs::create_dir_all(&db_path)?;
42 unsafe {
43 heed::EnvOpenOptions::new()
44 .map_size(1024 * 1024 * 1024)
45 .max_dbs(3000)
46 .open(db_path)
47 }
48 })
49 .await
50 .context("opening database connection")?;
51
52 cx.update(|cx| {
53 cx.observe_new_views(
54 |workspace: &mut Workspace, cx: &mut ViewContext<Workspace>| {
55 let project = workspace.project().clone();
56
57 if cx.has_global::<SemanticDb>() {
58 cx.update_global::<SemanticDb, _>(|this, cx| {
59 let project_index = cx.new_model(|cx| {
60 ProjectIndex::new(
61 project.clone(),
62 this.db_connection.clone(),
63 this.embedding_provider.clone(),
64 cx,
65 )
66 });
67
68 let project_weak = project.downgrade();
69 this.project_indices
70 .insert(project_weak.clone(), project_index);
71
72 cx.on_release(move |_, _, cx| {
73 if cx.has_global::<SemanticDb>() {
74 cx.update_global::<SemanticDb, _>(|this, _| {
75 this.project_indices.remove(&project_weak);
76 })
77 }
78 })
79 .detach();
80 })
81 } else {
82 log::info!("No SemanticDb, skipping project index")
83 }
84 },
85 )
86 .detach();
87 })
88 .ok();
89
90 Ok(SemanticDb {
91 db_connection,
92 embedding_provider,
93 project_indices: HashMap::default(),
94 })
95 }
96
97 pub fn project_index(
98 &mut self,
99 project: Model<Project>,
100 _cx: &mut AppContext,
101 ) -> Option<Model<ProjectIndex>> {
102 self.project_indices.get(&project.downgrade()).cloned()
103 }
104
105 pub fn remaining_summaries(
106 &self,
107 project: &WeakModel<Project>,
108 cx: &mut AppContext,
109 ) -> Option<usize> {
110 self.project_indices.get(project).map(|project_index| {
111 project_index.update(cx, |project_index, cx| {
112 project_index.remaining_summaries(cx)
113 })
114 })
115 }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use chunking::Chunk;
    use embedding_index::{ChunkedFile, EmbeddingIndex};
    use feature_flags::FeatureFlagAppExt;
    use fs::FakeFs;
    use futures::{future::BoxFuture, FutureExt};
    use gpui::TestAppContext;
    use indexing::IndexingEntrySet;
    use language::language_settings::AllLanguageSettings;
    use project::{Project, ProjectEntryId};
    use serde_json::json;
    use settings::SettingsStore;
    use smol::{channel, stream::StreamExt};
    use std::{future, path::Path, sync::Arc};

    // Installs the settings, language, and feature-flag globals that
    // `Project` and the indexer expect before a test runs.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();

        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            language::init(cx);
            cx.update_flags(false, vec![]);
            Project::init_settings(cx);
            SettingsStore::update(cx, |store, cx| {
                store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
            });
        });
    }

    // Deterministic embedding provider for tests: embeddings come from the
    // injected closure instead of a remote model.
    pub struct TestEmbeddingProvider {
        batch_size: usize,
        compute_embedding: Box<dyn Fn(&str) -> Result<Embedding> + Send + Sync>,
    }

    impl TestEmbeddingProvider {
        pub fn new(
            batch_size: usize,
            compute_embedding: impl 'static + Fn(&str) -> Result<Embedding> + Send + Sync,
        ) -> Self {
            Self {
                batch_size,
                compute_embedding: Box::new(compute_embedding),
            }
        }
    }

    impl EmbeddingProvider for TestEmbeddingProvider {
        fn embed<'a>(
            &'a self,
            texts: &'a [TextToEmbed<'a>],
        ) -> BoxFuture<'a, Result<Vec<Embedding>>> {
            // Collecting into `Result<Vec<_>>` short-circuits on the first
            // failing embedding, so a batch fails as a whole.
            let embeddings = texts
                .iter()
                .map(|to_embed| (self.compute_embedding)(to_embed.text))
                .collect();
            future::ready(embeddings).boxed()
        }

        fn batch_size(&self) -> usize {
            self.batch_size
        }
    }

    #[gpui::test]
    async fn test_search(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        init_test(cx);

        let temp_dir = tempfile::tempdir().unwrap();

        // 2-dimensional embeddings: dimension 0 encodes "garbage in",
        // dimension 1 encodes "garbage out", each as +/-0.9.
        let mut semantic_index = SemanticDb::new(
            temp_dir.path().into(),
            Arc::new(TestEmbeddingProvider::new(16, |text| {
                let mut embedding = vec![0f32; 2];
                // if the text contains "garbage in", give it a high value
                // (0.9) in the first dimension
                if text.contains("garbage in") {
                    embedding[0] = 0.9;
                } else {
                    embedding[0] = -0.9;
                }

                if text.contains("garbage out") {
                    embedding[1] = 0.9;
                } else {
                    embedding[1] = -0.9;
                }

                Ok(Embedding::new(embedding))
            })),
            &mut cx.to_async(),
        )
        .await
        .unwrap();

        let fs = FakeFs::new(cx.executor());
        let project_path = Path::new("/fake_project");

        // Fixture tree: main.rs is filler; needle.md contains the phrase the
        // search below is expected to find.
        fs.insert_tree(
            project_path,
            json!({
                "fixture": {
                    "main.rs": include_str!("../fixture/main.rs"),
                    "needle.md": include_str!("../fixture/needle.md"),
                }
            }),
        )
        .await;

        let project = Project::test(fs, [project_path], cx).await;

        cx.update(|cx| {
            let language_registry = project.read(cx).languages().clone();
            let node_runtime = project.read(cx).node_runtime().unwrap().clone();
            languages::init(language_registry, node_runtime, cx);

            // Manually create and insert the ProjectIndex (the workspace
            // observer from `SemanticDb::new` isn't active in this test).
            let project_index = cx.new_model(|cx| {
                ProjectIndex::new(
                    project.clone(),
                    semantic_index.db_connection.clone(),
                    semantic_index.embedding_provider.clone(),
                    cx,
                )
            });
            semantic_index
                .project_indices
                .insert(project.downgrade(), project_index);
        });

        let project_index = cx
            .update(|_cx| {
                semantic_index
                    .project_indices
                    .get(&project.downgrade())
                    .cloned()
            })
            .unwrap();

        // Pump the executor until indexing has no summaries left to compute.
        cx.run_until_parked();
        while cx
            .update(|cx| semantic_index.remaining_summaries(&project.downgrade(), cx))
            .unwrap()
            > 0
        {
            cx.run_until_parked();
        }

        let results = cx
            .update(|cx| {
                let project_index = project_index.read(cx);
                let query = "garbage in, garbage out";
                project_index.search(query.into(), 4, cx)
            })
            .await
            .unwrap();

        assert!(
            results.len() > 1,
            "should have found some results, but only found {:?}",
            results
        );

        for result in &results {
            println!("result: {:?}", result.path);
            println!("score: {:?}", result.score);
        }

        // Find the result whose score exceeds 0.9 — only needle.md should
        // match both "garbage in" and "garbage out" dimensions.
        let search_result = results.iter().find(|result| result.score > 0.9).unwrap();

        assert_eq!(search_result.path.to_string_lossy(), "fixture/needle.md");

        // Load the matched file's content from the fake filesystem and check
        // that the returned byte range covers the query phrase.
        let content = cx
            .update(|cx| {
                let worktree = search_result.worktree.read(cx);
                let entry_abs_path = worktree.abs_path().join(&search_result.path);
                let fs = project.read(cx).fs().clone();
                cx.background_executor()
                    .spawn(async move { fs.load(&entry_abs_path).await.unwrap() })
            })
            .await;

        let range = search_result.range.clone();
        let content = content[range.clone()].to_owned();

        assert!(content.contains("garbage in, garbage out"));
    }

    #[gpui::test]
    async fn test_embed_files(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        // Provider that fails on any text containing 'g': test1.md's "efgh"
        // chunk trips it, so that whole file should be dropped from output.
        let provider = Arc::new(TestEmbeddingProvider::new(3, |text| {
            if text.contains('g') {
                Err(anyhow!("cannot embed text containing a 'g' character"))
            } else {
                Ok(Embedding::new(
                    // Letter-frequency embedding: one dimension per letter
                    // a..=z counting its occurrences in `text`.
                    ('a'..='z')
                        .map(|char| text.chars().filter(|c| *c == char).count() as f32)
                        .collect(),
                ))
            }
        }));

        // Progress receiver is discarded; only the entry handles matter here.
        let (indexing_progress_tx, _) = channel::unbounded();
        let indexing_entries = Arc::new(IndexingEntrySet::new(indexing_progress_tx));

        let (chunked_files_tx, chunked_files_rx) = channel::unbounded::<ChunkedFile>();
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test1.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(0)),
                text: "abcdefghijklmnop".to_string(),
                chunks: [0..4, 4..8, 8..12, 12..16]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test2.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(1)),
                text: "qrstuvwxyz".to_string(),
                chunks: [0..4, 4..8, 8..10]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        // Close the channel so the embedding task sees end-of-stream.
        chunked_files_tx.close();

        let embed_files_task =
            cx.update(|cx| EmbeddingIndex::embed_files(provider.clone(), chunked_files_rx, cx));
        embed_files_task.task.await.unwrap();

        let mut embedded_files_rx = embed_files_task.files;
        let mut embedded_files = Vec::new();
        while let Some((embedded_file, _)) = embedded_files_rx.next().await {
            embedded_files.push(embedded_file);
        }

        // Only test2.md survives ('g'-free); its three chunks must embed
        // exactly as the provider closure would embed them directly.
        assert_eq!(embedded_files.len(), 1);
        assert_eq!(embedded_files[0].path.as_ref(), Path::new("test2.md"));
        assert_eq!(
            embedded_files[0]
                .chunks
                .iter()
                .map(|embedded_chunk| { embedded_chunk.embedding.clone() })
                .collect::<Vec<Embedding>>(),
            vec![
                (provider.compute_embedding)("qrst").unwrap(),
                (provider.compute_embedding)("uvwx").unwrap(),
                (provider.compute_embedding)("yz").unwrap(),
            ],
        );
    }
}