1mod chunking;
2mod embedding;
3mod embedding_index;
4mod indexing;
5mod project_index;
6mod project_index_debug_view;
7mod summary_backlog;
8mod summary_index;
9mod worktree_index;
10
11use anyhow::{Context as _, Result};
12use collections::HashMap;
13use fs::Fs;
14use gpui::{AppContext, AsyncAppContext, BorrowAppContext, Context, Global, Model, WeakModel};
15use project::Project;
16use std::{path::PathBuf, sync::Arc};
17use ui::ViewContext;
18use util::ResultExt as _;
19use workspace::Workspace;
20
21pub use embedding::*;
22pub use project_index::{LoadedSearchResult, ProjectIndex, SearchResult, Status};
23pub use project_index_debug_view::ProjectIndexDebugView;
24pub use summary_index::FileSummary;
25
/// Global entry point for semantic search: owns the embedding provider, the
/// on-disk heed/LMDB environment, and one `ProjectIndex` per open project.
pub struct SemanticDb {
    // Provider used to turn text into embedding vectors.
    embedding_provider: Arc<dyn EmbeddingProvider>,
    // Database environment backing the indices. `Some` for the struct's whole
    // usable lifetime; taken (set to `None`) only in `Drop` so the
    // environment can be closed cleanly.
    db_connection: Option<heed::Env>,
    // Registry of per-project indices, keyed by weak project handles so
    // entries can be removed when a project is released.
    project_indices: HashMap<WeakModel<Project>, Model<ProjectIndex>>,
}

// Registered as a GPUI global so it can be reached from any app context.
impl Global for SemanticDb {}
33
impl SemanticDb {
    /// Opens (creating the directory if needed) the LMDB environment at
    /// `db_path` on the background executor, and registers a view observer
    /// that creates a `ProjectIndex` for each newly opened workspace's
    /// project.
    ///
    /// # Errors
    /// Returns an error if the database directory can't be created or the
    /// heed environment fails to open.
    pub async fn new(
        db_path: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        cx: &mut AsyncAppContext,
    ) -> Result<Self> {
        let db_connection = cx
            .background_executor()
            .spawn(async move {
                std::fs::create_dir_all(&db_path)?;
                // `EnvOpenOptions::open` is `unsafe` because heed requires
                // the caller to uphold LMDB's process/locking assumptions for
                // the opened path.
                unsafe {
                    heed::EnvOpenOptions::new()
                        .map_size(1024 * 1024 * 1024) // 1 GiB memory map
                        .max_dbs(3000)
                        .open(db_path)
                }
            })
            .await
            .context("opening database connection")?;

        cx.update(|cx| {
            // For every newly created workspace view, build an index for its
            // project — but only while the `SemanticDb` global is installed.
            cx.observe_new_views(
                |workspace: &mut Workspace, cx: &mut ViewContext<Workspace>| {
                    let project = workspace.project().clone();

                    if cx.has_global::<SemanticDb>() {
                        cx.update_global::<SemanticDb, _>(|this, cx| {
                            this.create_project_index(project, cx);
                        })
                    } else {
                        log::info!("No SemanticDb, skipping project index")
                    }
                },
            )
            .detach();
        })
        .ok();

        Ok(SemanticDb {
            db_connection: Some(db_connection),
            embedding_provider,
            project_indices: HashMap::default(),
        })
    }

    /// Loads the file behind each search result and widens the matched byte
    /// range outward to whole-line boundaries, producing display-ready
    /// `LoadedSearchResult`s. Files that fail to load are logged via
    /// `log_err` and skipped rather than failing the whole batch.
    pub async fn load_results(
        results: Vec<SearchResult>,
        fs: &Arc<dyn Fs>,
        cx: &AsyncAppContext,
    ) -> Result<Vec<LoadedSearchResult>> {
        let mut loaded_results = Vec::new();
        for result in results {
            let (full_path, file_content) = result.worktree.read_with(cx, |worktree, _cx| {
                let entry_abs_path = worktree.abs_path().join(&result.path);
                let mut entry_full_path = PathBuf::from(worktree.root_name());
                entry_full_path.push(&result.path);
                // The load future is constructed inside the read closure (to
                // use the worktree's absolute path) but awaited outside it.
                let file_content = async {
                    let entry_abs_path = entry_abs_path;
                    fs.load(&entry_abs_path).await
                };
                (entry_full_path, file_content)
            })?;
            if let Some(file_content) = file_content.await.log_err() {
                // Clamp the stored range to the current file length in case
                // the file shrank since it was indexed.
                // NOTE(review): the string slicing below assumes the clamped
                // offsets fall on UTF-8 character boundaries — confirm the
                // chunker guarantees char-aligned ranges.
                let range_start = result.range.start.min(file_content.len());
                let range_end = result.range.end.min(file_content.len());

                // Row numbers come from counting newlines before each end of
                // the (clamped) range.
                let start_row = file_content[0..range_start].matches('\n').count() as u32;
                let end_row = file_content[0..range_end].matches('\n').count() as u32;
                // Expand to the start of the first matched line…
                let start_line_byte_offset = file_content[0..range_start]
                    .rfind('\n')
                    .map(|pos| pos + 1)
                    .unwrap_or_default();
                // …and to the end of the last matched line (or EOF).
                let end_line_byte_offset = file_content[range_end..]
                    .find('\n')
                    .map(|pos| range_end + pos)
                    .unwrap_or_else(|| file_content.len());

                loaded_results.push(LoadedSearchResult {
                    path: result.path,
                    range: start_line_byte_offset..end_line_byte_offset,
                    full_path,
                    file_content,
                    row_range: start_row..=end_row,
                });
            }
        }
        Ok(loaded_results)
    }

    /// Returns the existing index for `project`, if one was previously
    /// created via `create_project_index`.
    pub fn project_index(
        &mut self,
        project: Model<Project>,
        _cx: &mut AppContext,
    ) -> Option<Model<ProjectIndex>> {
        self.project_indices.get(&project.downgrade()).cloned()
    }

    /// Number of file summaries still pending for `project`'s index, or
    /// `None` if that project has no index.
    pub fn remaining_summaries(
        &self,
        project: &WeakModel<Project>,
        cx: &mut AppContext,
    ) -> Option<usize> {
        self.project_indices.get(project).map(|project_index| {
            project_index.update(cx, |project_index, cx| {
                project_index.remaining_summaries(cx)
            })
        })
    }

    /// Creates and registers a `ProjectIndex` for `project`, removing it from
    /// the registry again when the project model is released.
    ///
    /// # Panics
    /// Panics if called after the database connection was taken in `Drop`
    /// (the connection is `Some` for the struct's entire usable lifetime).
    pub fn create_project_index(
        &mut self,
        project: Model<Project>,
        cx: &mut AppContext,
    ) -> Model<ProjectIndex> {
        let project_index = cx.new_model(|cx| {
            ProjectIndex::new(
                project.clone(),
                self.db_connection.clone().unwrap(),
                self.embedding_provider.clone(),
                cx,
            )
        });

        let project_weak = project.downgrade();
        self.project_indices
            .insert(project_weak.clone(), project_index.clone());

        // Drop the registry entry when the project goes away, guarding
        // against the global itself having been removed first.
        cx.observe_release(&project, move |_, cx| {
            if cx.has_global::<SemanticDb>() {
                cx.update_global::<SemanticDb, _>(|this, _| {
                    this.project_indices.remove(&project_weak);
                })
            }
        })
        .detach();

        project_index
    }
}
173
174impl Drop for SemanticDb {
175 fn drop(&mut self) {
176 self.db_connection.take().unwrap().prepare_for_closing();
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use chunking::Chunk;
    use embedding_index::{ChunkedFile, EmbeddingIndex};
    use feature_flags::FeatureFlagAppExt;
    use fs::FakeFs;
    use futures::{future::BoxFuture, FutureExt};
    use gpui::TestAppContext;
    use indexing::IndexingEntrySet;
    use language::language_settings::AllLanguageSettings;
    use project::{Project, ProjectEntryId};
    use serde_json::json;
    use settings::SettingsStore;
    use smol::{channel, stream::StreamExt};
    use std::{future, path::Path, sync::Arc};

    /// Installs the globals (settings store, languages, feature flags,
    /// project settings) that the indexing code expects before any test
    /// logic runs.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();

        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            language::init(cx);
            cx.update_flags(false, vec![]);
            Project::init_settings(cx);
            SettingsStore::update(cx, |store, cx| {
                store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
            });
        });
    }

    /// Deterministic embedding provider for tests: embeddings are computed by
    /// a caller-supplied closure instead of a real model.
    pub struct TestEmbeddingProvider {
        // Maximum number of texts the indexer may batch per `embed` call.
        batch_size: usize,
        // Closure that maps a chunk of text to an embedding (or an error).
        compute_embedding: Box<dyn Fn(&str) -> Result<Embedding> + Send + Sync>,
    }

    impl TestEmbeddingProvider {
        /// Creates a provider that embeds via `compute_embedding`, batching
        /// at most `batch_size` texts per request.
        pub fn new(
            batch_size: usize,
            compute_embedding: impl 'static + Fn(&str) -> Result<Embedding> + Send + Sync,
        ) -> Self {
            Self {
                batch_size,
                compute_embedding: Box::new(compute_embedding),
            }
        }
    }

    impl EmbeddingProvider for TestEmbeddingProvider {
        // Embeds synchronously via the stored closure; the first failing text
        // fails the whole batch (Result collect short-circuits).
        fn embed<'a>(
            &'a self,
            texts: &'a [TextToEmbed<'a>],
        ) -> BoxFuture<'a, Result<Vec<Embedding>>> {
            let embeddings = texts
                .iter()
                .map(|to_embed| (self.compute_embedding)(to_embed.text))
                .collect();
            future::ready(embeddings).boxed()
        }

        fn batch_size(&self) -> usize {
            self.batch_size
        }
    }

    /// End-to-end: index a fake project and verify that searching for the
    /// needle phrase ranks the fixture file containing it on top.
    #[gpui::test]
    async fn test_search(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        init_test(cx);

        let temp_dir = tempfile::tempdir().unwrap();

        // 2-D embeddings: dimension 0 encodes "garbage in", dimension 1
        // encodes "garbage out", so the needle text gets (0.9, 0.9) and
        // everything else is pushed away.
        let mut semantic_index = SemanticDb::new(
            temp_dir.path().into(),
            Arc::new(TestEmbeddingProvider::new(16, |text| {
                let mut embedding = vec![0f32; 2];
                // if the text contains garbage, give it a 1 in the first dimension
                if text.contains("garbage in") {
                    embedding[0] = 0.9;
                } else {
                    embedding[0] = -0.9;
                }

                if text.contains("garbage out") {
                    embedding[1] = 0.9;
                } else {
                    embedding[1] = -0.9;
                }

                Ok(Embedding::new(embedding))
            })),
            &mut cx.to_async(),
        )
        .await
        .unwrap();

        let fs = FakeFs::new(cx.executor());
        let project_path = Path::new("/fake_project");

        fs.insert_tree(
            project_path,
            json!({
                "fixture": {
                    "main.rs": include_str!("../fixture/main.rs"),
                    "needle.md": include_str!("../fixture/needle.md"),
                }
            }),
        )
        .await;

        let project = Project::test(fs, [project_path], cx).await;

        let project_index = cx.update(|cx| {
            let language_registry = project.read(cx).languages().clone();
            let node_runtime = project.read(cx).node_runtime().unwrap().clone();
            languages::init(language_registry, node_runtime, cx);
            semantic_index.create_project_index(project.clone(), cx)
        });

        // Pump the executor until background summarization has drained.
        cx.run_until_parked();
        while cx
            .update(|cx| semantic_index.remaining_summaries(&project.downgrade(), cx))
            .unwrap()
            > 0
        {
            cx.run_until_parked();
        }

        let results = cx
            .update(|cx| {
                let project_index = project_index.read(cx);
                let query = "garbage in, garbage out";
                project_index.search(query.into(), 4, cx)
            })
            .await
            .unwrap();

        assert!(
            results.len() > 1,
            "should have found some results, but only found {:?}",
            results
        );

        for result in &results {
            println!("result: {:?}", result.path);
            println!("score: {:?}", result.score);
        }

        // Find the result whose score is greater than 0.9 — only the needle
        // file should score that high with the embeddings above.
        let search_result = results.iter().find(|result| result.score > 0.9).unwrap();

        assert_eq!(search_result.path.to_string_lossy(), "fixture/needle.md");

        let content = cx
            .update(|cx| {
                let worktree = search_result.worktree.read(cx);
                let entry_abs_path = worktree.abs_path().join(&search_result.path);
                let fs = project.read(cx).fs().clone();
                cx.background_executor()
                    .spawn(async move { fs.load(&entry_abs_path).await.unwrap() })
            })
            .await;

        let range = search_result.range.clone();
        let content = content[range.clone()].to_owned();

        assert!(content.contains("garbage in, garbage out"));
    }

    /// Verifies that a file whose chunks fail to embed (any text containing
    /// 'g') is dropped, while the other file is embedded chunk-by-chunk.
    #[gpui::test]
    async fn test_embed_files(cx: &mut TestAppContext) {
        cx.executor().allow_parking();

        // Letter-frequency embeddings; texts containing 'g' error out so the
        // first file below ("abcdefghijklmnop") must fail.
        let provider = Arc::new(TestEmbeddingProvider::new(3, |text| {
            if text.contains('g') {
                Err(anyhow!("cannot embed text containing a 'g' character"))
            } else {
                Ok(Embedding::new(
                    ('a'..='z')
                        .map(|char| text.chars().filter(|c| *c == char).count() as f32)
                        .collect(),
                ))
            }
        }));

        let (indexing_progress_tx, _) = channel::unbounded();
        let indexing_entries = Arc::new(IndexingEntrySet::new(indexing_progress_tx));

        let (chunked_files_tx, chunked_files_rx) = channel::unbounded::<ChunkedFile>();
        // File 1: contains 'g', so embedding it must fail.
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test1.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(0)),
                text: "abcdefghijklmnop".to_string(),
                chunks: [0..4, 4..8, 8..12, 12..16]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        // File 2: no 'g', so all three chunks should embed successfully.
        chunked_files_tx
            .send_blocking(ChunkedFile {
                path: Path::new("test2.md").into(),
                mtime: None,
                handle: indexing_entries.insert(ProjectEntryId::from_proto(1)),
                text: "qrstuvwxyz".to_string(),
                chunks: [0..4, 4..8, 8..10]
                    .into_iter()
                    .map(|range| Chunk {
                        range,
                        digest: Default::default(),
                    })
                    .collect(),
            })
            .unwrap();
        chunked_files_tx.close();

        let embed_files_task =
            cx.update(|cx| EmbeddingIndex::embed_files(provider.clone(), chunked_files_rx, cx));
        embed_files_task.task.await.unwrap();

        let mut embedded_files_rx = embed_files_task.files;
        let mut embedded_files = Vec::new();
        while let Some((embedded_file, _)) = embedded_files_rx.next().await {
            embedded_files.push(embedded_file);
        }

        // Only the second file survives; its chunk embeddings match what the
        // provider closure computes for each chunk's text.
        assert_eq!(embedded_files.len(), 1);
        assert_eq!(embedded_files[0].path.as_ref(), Path::new("test2.md"));
        assert_eq!(
            embedded_files[0]
                .chunks
                .iter()
                .map(|embedded_chunk| { embedded_chunk.embedding.clone() })
                .collect::<Vec<Embedding>>(),
            vec![
                (provider.compute_embedding)("qrst").unwrap(),
                (provider.compute_embedding)("uvwx").unwrap(),
                (provider.compute_embedding)("yz").unwrap(),
            ],
        );
    }
}