mod db;
mod embedding;
mod modal;

#[cfg(test)]
mod vector_store_tests;

use anyhow::{anyhow, Result};
use db::{FileSha1, VectorDatabase};
use embedding::{EmbeddingProvider, OpenAIEmbeddings};
use gpui::{AppContext, Entity, ModelContext, ModelHandle, Task, ViewContext};
use language::{Language, LanguageRegistry};
use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
use project::{Fs, Project, WorktreeId};
use smol::channel;
use std::{
    cmp::Ordering,
    collections::{HashMap, HashSet},
    path::{Path, PathBuf},
    sync::Arc,
};
use tree_sitter::{Parser, QueryCursor};
use util::{
    channel::RELEASE_CHANNEL_NAME, http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt, TryFutureExt,
};
use workspace::{Workspace, WorkspaceCreated};
/// A single embeddable item extracted from a source file: its byte offset in
/// the file, its display name, and its embedding (filled in after a batch
/// request to the provider).
#[derive(Debug)]
pub struct Document {
    pub offset: usize,
    pub name: String,
    pub embedding: Vec<f32>,
}

pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    let db_file_path = EMBEDDINGS_DIR
        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
        .join("embeddings_db");

    let vector_store = cx.add_model(|_| {
        VectorStore::new(
            fs,
            db_file_path,
            Arc::new(OpenAIEmbeddings {
                client: http_client,
            }),
            language_registry,
        )
    });

    cx.subscribe_global::<WorkspaceCreated, _>({
        let vector_store = vector_store.clone();
        move |event, cx| {
            let workspace = &event.0;
            if let Some(workspace) = workspace.upgrade(cx) {
                let project = workspace.read(cx).project().clone();
                if project.read(cx).is_local() {
                    vector_store.update(cx, |store, cx| {
                        store.add_project(project, cx).detach();
                    });
                }
            }
        }
    })
    .detach();

    cx.add_action({
        move |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
            let vector_store = vector_store.clone();
            workspace.toggle_modal(cx, |workspace, cx| {
                let project = workspace.project().clone();
                let workspace = cx.weak_handle();
                cx.add_view(|cx| {
                    SemanticSearch::new(
                        SemanticSearchDelegate::new(workspace, project, vector_store),
                        cx,
                    )
                })
            })
        }
    });

    SemanticSearch::init(cx);
}

/// A file that has been parsed and embedded, ready to be written to the database.
#[derive(Debug)]
pub struct IndexedFile {
    path: PathBuf,
    sha1: FileSha1,
    documents: Vec<Document>,
}

/// Owns the embeddings database and coordinates indexing and searching across
/// a project's worktrees.
pub struct VectorStore {
    fs: Arc<dyn Fs>,
    database_url: Arc<PathBuf>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    worktree_db_ids: Vec<(WorktreeId, i64)>,
}

/// A single match returned by `VectorStore::search`.
#[derive(Debug, Clone)]
pub struct SearchResult {
    pub worktree_id: WorktreeId,
    pub name: String,
    pub offset: usize,
    pub file_path: PathBuf,
}

impl VectorStore {
    fn new(
        fs: Arc<dyn Fs>,
        database_url: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
    ) -> Self {
        Self {
            fs,
            database_url: Arc::new(database_url),
            embedding_provider,
            language_registry,
            worktree_db_ids: Vec::new(),
        }
    }

    /// Parses `content` with the file's grammar, runs the grammar's embedding
    /// query, and embeds every captured item in a single batch request.
    async fn index_file(
        cursor: &mut QueryCursor,
        parser: &mut Parser,
        embedding_provider: &dyn EmbeddingProvider,
        language: Arc<Language>,
        file_path: PathBuf,
        content: String,
    ) -> Result<IndexedFile> {
        let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
        let embedding_config = grammar
            .embedding_config
            .as_ref()
            .ok_or_else(|| anyhow!("no embedding query"))?;

        parser.set_language(grammar.ts_language).unwrap();
        let tree = parser
            .parse(&content, None)
            .ok_or_else(|| anyhow!("parsing failed"))?;

        let mut documents = Vec::new();
        let mut context_spans = Vec::new();
        for mat in cursor.matches(
            &embedding_config.query,
            tree.root_node(),
            content.as_bytes(),
        ) {
            // Each match should capture both the whole item and its name.
            let mut item_range = None;
            let mut name_range = None;
            for capture in mat.captures {
                if capture.index == embedding_config.item_capture_ix {
                    item_range = Some(capture.node.byte_range());
                } else if capture.index == embedding_config.name_capture_ix {
                    name_range = Some(capture.node.byte_range());
                }
            }

            if let Some((item_range, name_range)) = item_range.zip(name_range) {
                if let Some((item, name)) =
                    content.get(item_range.clone()).zip(content.get(name_range))
                {
                    context_spans.push(item);
                    documents.push(Document {
                        name: name.to_string(),
                        offset: item_range.start,
                        embedding: Vec::new(),
                    });
                }
            }
        }

        if !documents.is_empty() {
            let embeddings = embedding_provider.embed_batch(context_spans).await?;
            for (document, embedding) in documents.iter_mut().zip(embeddings) {
                document.embedding = embedding;
            }
        }

        let sha1 = FileSha1::from_str(content);

        Ok(IndexedFile {
            path: file_path,
            sha1,
            documents,
        })
    }

    fn add_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        let worktree_scans_complete = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
                async move {
                    scan_complete.await;
                }
            })
            .collect::<Vec<_>>();

        let fs = self.fs.clone();
        let language_registry = self.language_registry.clone();
        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();

        cx.spawn(|this, mut cx| async move {
            futures::future::join_all(worktree_scans_complete).await;

            if let Some(db_directory) = database_url.parent() {
                fs.create_dir(db_directory).await.log_err();
            }
            let db = VectorDatabase::new(database_url.to_string_lossy().into())?;

            let worktrees = project.read_with(&cx, |project, cx| {
                project
                    .worktrees(cx)
                    .map(|worktree| worktree.read(cx).snapshot())
                    .collect::<Vec<_>>()
            });

            // We query the worktree ids here but don't retain them anywhere else;
            // these data structures could use some cleanup.
            let (db, worktree_hashes, worktree_db_ids) = cx
                .background()
                .spawn({
                    let worktrees = worktrees.clone();
                    async move {
                        let mut worktree_db_ids: HashMap<WorktreeId, i64> = HashMap::new();
                        let mut hashes: HashMap<WorktreeId, HashMap<PathBuf, FileSha1>> =
                            HashMap::new();
                        for worktree in worktrees {
                            let worktree_db_id =
                                db.find_or_create_worktree(worktree.abs_path().as_ref())?;
                            worktree_db_ids.insert(worktree.id(), worktree_db_id);
                            hashes.insert(worktree.id(), db.get_file_hashes(worktree_db_id)?);
                        }
                        anyhow::Ok((db, hashes, worktree_db_ids))
                    }
                })
                .await?;

            let (paths_tx, paths_rx) =
                channel::unbounded::<(i64, PathBuf, String, Arc<Language>)>();
            let (delete_paths_tx, delete_paths_rx) = channel::unbounded::<(i64, PathBuf)>();
            let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<(i64, IndexedFile)>();
            cx.background()
                .spawn({
                    let fs = fs.clone();
                    let worktree_db_ids = worktree_db_ids.clone();
                    async move {
                        for worktree in worktrees.into_iter() {
                            let file_hashes = &worktree_hashes[&worktree.id()];
                            let mut files_included =
                                file_hashes.keys().collect::<HashSet<&PathBuf>>();
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path);

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    if language
                                        .grammar()
                                        .and_then(|grammar| grammar.embedding_config.as_ref())
                                        .is_none()
                                    {
                                        continue;
                                    }

                                    if let Some(content) = fs.load(&absolute_path).await.log_err() {
                                        // Mark the file as still present on disk, and skip
                                        // re-indexing it if its stored hash matches.
                                        let path_buf = file.path.to_path_buf();
                                        let already_stored = file_hashes.get(&path_buf).map_or(
                                            false,
                                            |existing_hash| {
                                                files_included.remove(&path_buf);
                                                existing_hash.equals(&content)
                                            },
                                        );

                                        if !already_stored {
                                            paths_tx
                                                .try_send((
                                                    worktree_db_ids[&worktree.id()],
                                                    path_buf,
                                                    content,
                                                    language,
                                                ))
                                                .unwrap();
                                        }
                                    }
                                }
                            }
                            // Any hashes that were never visited belong to files that no
                            // longer exist on disk, so queue them for deletion.
                            for file in files_included {
                                delete_paths_tx
                                    .try_send((worktree_db_ids[&worktree.id()], file.to_owned()))
                                    .unwrap();
                            }
                        }
                    }
                })
                .detach();

            let db_update_task = cx.background().spawn(
                async move {
                    // Insert newly indexed files until the indexing workers drop
                    // their sender.
                    while let Ok((worktree_id, indexed_file)) = indexed_files_rx.recv().await {
                        log::info!("Inserting File: {:?}", &indexed_file.path);
                        db.insert_file(worktree_id, indexed_file).log_err();
                    }

                    // Then delete the files that have been removed from disk.
                    while let Ok((worktree_id, delete_path)) = delete_paths_rx.recv().await {
                        log::info!("Deleting File: {:?}", &delete_path);
                        db.delete_file(worktree_id, delete_path).log_err();
                    }

                    anyhow::Ok(())
                }
                .log_err(),
            );

            cx.background()
                .scoped(|scope| {
                    for _ in 0..cx.background().num_cpus() {
                        scope.spawn(async {
                            let mut parser = Parser::new();
                            let mut cursor = QueryCursor::new();
                            while let Ok((worktree_id, file_path, content, language)) =
                                paths_rx.recv().await
                            {
                                if let Some(indexed_file) = Self::index_file(
                                    &mut cursor,
                                    &mut parser,
                                    embedding_provider.as_ref(),
                                    language,
                                    file_path,
                                    content,
                                )
                                .await
                                .log_err()
                                {
                                    indexed_files_tx
                                        .try_send((worktree_id, indexed_file))
                                        .unwrap();
                                }
                            }
                        });
                    }
                })
                .await;
            drop(indexed_files_tx);

            db_update_task.await;

            this.update(&mut cx, |this, _| {
                this.worktree_db_ids.extend(worktree_db_ids);
            });

            log::info!("Semantic Indexing Complete!");

            anyhow::Ok(())
        })
    }

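    /// Returns the `limit` documents whose embeddings are most similar to the
    /// embedding of `phrase`.
    ///
    /// A hypothetical usage sketch; the `vector_store` and `project` handles
    /// are illustrative, not part of this crate:
    ///
    /// ```ignore
    /// let results = vector_store.update(cx, |store, cx| {
    ///     store.search(&project, "load a file from disk".to_string(), 10, cx)
    /// });
    /// ```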
    pub fn search(
        &mut self,
        project: &ModelHandle<Project>,
        phrase: String,
        limit: usize,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let project = project.read(cx);
        let worktree_db_ids = project
            .worktrees(cx)
            .filter_map(|worktree| {
                let worktree_id = worktree.read(cx).id();
                self.worktree_db_ids.iter().find_map(|(id, db_id)| {
                    if *id == worktree_id {
                        Some(*db_id)
                    } else {
                        None
                    }
                })
            })
            .collect::<Vec<_>>();

        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();
        cx.spawn(|this, cx| async move {
            let documents = cx
                .background()
                .spawn(async move {
                    let database = VectorDatabase::new(database_url.to_string_lossy().into())?;

                    let phrase_embedding = embedding_provider
                        .embed_batch(vec![&phrase])
                        .await?
                        .into_iter()
                        .next()
                        .unwrap();

                    // Keep `results` sorted in descending similarity: the comparator
                    // flips the operands of `partial_cmp`, so the binary search yields
                    // the insertion point for a descending list, and `truncate` keeps
                    // only the top `limit` hits. (A standalone sketch of this pattern,
                    // `insert_top_k`, appears near the end of this file.)
                    let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
                    database.for_each_document(&worktree_db_ids, |id, embedding| {
                        let similarity = dot(&embedding.0, &phrase_embedding);
                        let ix = match results.binary_search_by(|(_, s)| {
                            similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
                        }) {
                            Ok(ix) => ix,
                            Err(ix) => ix,
                        };
                        results.insert(ix, (id, similarity));
                        results.truncate(limit);
                    })?;

                    let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
                    database.get_documents_by_ids(&ids)
                })
                .await?;

            let results = this.read_with(&cx, |this, _| {
                documents
                    .into_iter()
                    .filter_map(|(worktree_db_id, file_path, offset, name)| {
                        let worktree_id = this.worktree_db_ids.iter().find_map(|(id, db_id)| {
                            if *db_id == worktree_db_id {
                                Some(*id)
                            } else {
                                None
                            }
                        })?;
                        Some(SearchResult {
                            worktree_id,
                            name,
                            offset,
                            file_path,
                        })
                    })
                    .collect()
            });

            anyhow::Ok(results)
        })
    }
}

impl Entity for VectorStore {
    type Event = ();
}

/// Computes the dot product of two equal-length vectors by expressing it as a
/// 1×len by len×1 matrix product via `matrixmultiply::sgemm`. For unit-length
/// embeddings (as OpenAI's are), this ranks identically to cosine similarity.
fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    let len = vec_a.len();
    assert_eq!(len, vec_b.len());

    let mut result = 0.0;
    unsafe {
        matrixmultiply::sgemm(
            1,
            len,
            1,
            1.0,
            vec_a.as_ptr(),
            len as isize,
            1,
            vec_b.as_ptr(),
            1,
            len as isize,
            0.0,
            &mut result as *mut f32,
            1,
            1,
        );
    }
    result
}
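
// Minimal sanity checks (a sketch we added, separate from the original tests
// in `vector_store_tests`): the sgemm-based `dot` should agree with a naive
// scalar loop, and `insert_top_k` should keep the highest-similarity entries
// in descending order. The tolerance value is an assumption.
#[cfg(test)]
mod similarity_tests {
    use super::{dot, insert_top_k};

    #[test]
    fn sgemm_dot_matches_naive_dot() {
        let a = [0.5_f32, -1.25, 3.0, 0.0, 2.5];
        let b = [1.0_f32, 0.5, -2.0, 4.0, 0.25];
        let naive: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
        assert!((dot(&a, &b) - naive).abs() < 1e-5);
    }

    #[test]
    fn top_k_keeps_best_results_in_descending_order() {
        let mut results = Vec::with_capacity(3);
        for (id, similarity) in [(1, 0.2), (2, 0.9), (3, 0.5), (4, 0.7)] {
            insert_top_k(&mut results, id, similarity, 2);
        }
        assert_eq!(results, vec![(2, 0.9), (4, 0.7)]);
    }
}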