mod db;
mod embedding;
mod modal;

#[cfg(test)]
mod vector_store_tests;

use anyhow::{anyhow, Result};
use db::{FileSha1, VectorDatabase};
use embedding::{EmbeddingProvider, OpenAIEmbeddings};
use gpui::{AppContext, Entity, ModelContext, ModelHandle, Task, ViewContext};
use language::{Language, LanguageRegistry};
use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
use project::{Fs, Project, WorktreeId};
use smol::channel;
use std::{
    cmp::Ordering,
    collections::{HashMap, HashSet},
    path::{Path, PathBuf},
    sync::Arc,
};
use tree_sitter::{Parser, QueryCursor};
use util::{
    channel::RELEASE_CHANNEL_NAME, http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt, TryFutureExt,
};
use workspace::{Workspace, WorkspaceCreated};

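/// A single outline item (e.g. a function or type definition) extracted from a
/// source file: its byte offset within the file, the name captured by the
/// language's outline query, and the embedding computed for its text.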
#[derive(Debug)]
pub struct Document {
    pub offset: usize,
    pub name: String,
    pub embedding: Vec<f32>,
}

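/// Wires semantic search into the app: creates the `VectorStore` model backed
/// by an on-disk embeddings database, indexes each local project when its
/// workspace is created, and registers the `Toggle` action that opens the
/// search modal.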
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    let db_file_path = EMBEDDINGS_DIR
        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
        .join("embeddings_db");

    let vector_store = cx.add_model(|_| {
        VectorStore::new(
            fs,
            db_file_path,
            Arc::new(OpenAIEmbeddings {
                client: http_client,
            }),
            language_registry,
        )
    });

    cx.subscribe_global::<WorkspaceCreated, _>({
        let vector_store = vector_store.clone();
        move |event, cx| {
            let workspace = &event.0;
            if let Some(workspace) = workspace.upgrade(cx) {
                let project = workspace.read(cx).project().clone();
                if project.read(cx).is_local() {
                    vector_store.update(cx, |store, cx| {
                        store.add_project(project, cx).detach();
                    });
                }
            }
        }
    })
    .detach();

    cx.add_action({
        move |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
            let vector_store = vector_store.clone();
            workspace.toggle_modal(cx, |workspace, cx| {
                let project = workspace.project().clone();
                let workspace = cx.weak_handle();
                cx.add_view(|cx| {
                    SemanticSearch::new(
                        SemanticSearchDelegate::new(workspace, project, vector_store),
                        cx,
                    )
                })
            })
        }
    });

    SemanticSearch::init(cx);
}

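/// A file that has been parsed and embedded: its path within the worktree, a
/// hash of its contents (used to skip unchanged files on re-indexing), and the
/// documents extracted from it.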
#[derive(Debug)]
pub struct IndexedFile {
    path: PathBuf,
    sha1: FileSha1,
    documents: Vec<Document>,
}

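/// Coordinates indexing and search. Holds the resources needed to index a
/// project (filesystem, embedding provider, language registry), plus the
/// mapping from in-memory worktree ids to their rows in the embeddings
/// database.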
pub struct VectorStore {
    fs: Arc<dyn Fs>,
    database_url: Arc<PathBuf>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    worktree_db_ids: Vec<(WorktreeId, i64)>,
}

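/// A search hit: the worktree and file containing the matching document, along
/// with the document's name and byte offset within that file.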
#[derive(Debug, Clone)]
pub struct SearchResult {
    pub worktree_id: WorktreeId,
    pub name: String,
    pub offset: usize,
    pub file_path: PathBuf,
}

impl VectorStore {
    fn new(
        fs: Arc<dyn Fs>,
        database_url: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
    ) -> Self {
        Self {
            fs,
            database_url: Arc::new(database_url),
            embedding_provider,
            language_registry,
            worktree_db_ids: Vec::new(),
        }
    }

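    /// Parses `content` with the language's tree-sitter grammar, collects every
    /// item captured by the outline query, and embeds all item spans in a
    /// single batch. Fails if the language has no grammar or no outline query.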
    async fn index_file(
        cursor: &mut QueryCursor,
        parser: &mut Parser,
        embedding_provider: &dyn EmbeddingProvider,
        language: Arc<Language>,
        file_path: PathBuf,
        content: String,
    ) -> Result<IndexedFile> {
        let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
        let outline_config = grammar
            .outline_config
            .as_ref()
            .ok_or_else(|| anyhow!("no outline query"))?;

        parser.set_language(grammar.ts_language).unwrap();
        let tree = parser
            .parse(&content, None)
            .ok_or_else(|| anyhow!("parsing failed"))?;

        let mut documents = Vec::new();
        let mut context_spans = Vec::new();
        for mat in cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes()) {
            let mut item_range = None;
            let mut name_range = None;
            for capture in mat.captures {
                if capture.index == outline_config.item_capture_ix {
                    item_range = Some(capture.node.byte_range());
                } else if capture.index == outline_config.name_capture_ix {
                    name_range = Some(capture.node.byte_range());
                }
            }

            if let Some((item_range, name_range)) = item_range.zip(name_range) {
                if let Some((item, name)) =
                    content.get(item_range.clone()).zip(content.get(name_range))
                {
                    context_spans.push(item);
                    documents.push(Document {
                        name: name.to_string(),
                        offset: item_range.start,
                        embedding: Vec::new(),
                    });
                }
            }
        }

        if !documents.is_empty() {
            let embeddings = embedding_provider.embed_batch(context_spans).await?;
            for (document, embedding) in documents.iter_mut().zip(embeddings) {
                document.embedding = embedding;
            }
        }

        let sha1 = FileSha1::from_str(content);

        Ok(IndexedFile {
            path: file_path,
            sha1,
            documents,
        })
    }

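    /// Indexes all local worktrees of `project`: waits for file scans to
    /// finish, diffs each file's content hash against what is already stored,
    /// then fans out over channels -- one parsing/embedding task per CPU feeds
    /// a single database writer that inserts new files and deletes stale ones.
    /// Currently only Rust files are indexed.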
    fn add_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        let worktree_scans_complete = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
                async move {
                    scan_complete.await;
                }
            })
            .collect::<Vec<_>>();

        let fs = self.fs.clone();
        let language_registry = self.language_registry.clone();
        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();

        cx.spawn(|this, mut cx| async move {
            futures::future::join_all(worktree_scans_complete).await;

            if let Some(db_directory) = database_url.parent() {
                fs.create_dir(db_directory).await.log_err();
            }
            let db = VectorDatabase::new(database_url.to_string_lossy().into())?;

            let worktrees = project.read_with(&cx, |project, cx| {
                project
                    .worktrees(cx)
                    .map(|worktree| worktree.read(cx).snapshot())
                    .collect::<Vec<_>>()
            });

            // Here we query the worktree ids, and yet we don't have them elsewhere.
            // We likely want to clean up these data structures.
            let (db, worktree_hashes, worktree_db_ids) = cx
                .background()
                .spawn({
                    let worktrees = worktrees.clone();
                    async move {
                        let mut worktree_db_ids: HashMap<WorktreeId, i64> = HashMap::new();
                        let mut hashes: HashMap<WorktreeId, HashMap<PathBuf, FileSha1>> =
                            HashMap::new();
                        for worktree in worktrees {
                            let worktree_db_id =
                                db.find_or_create_worktree(worktree.abs_path().as_ref())?;
                            worktree_db_ids.insert(worktree.id(), worktree_db_id);
                            hashes.insert(worktree.id(), db.get_file_hashes(worktree_db_id)?);
                        }
                        anyhow::Ok((db, hashes, worktree_db_ids))
                    }
                })
                .await?;

            let (paths_tx, paths_rx) =
                channel::unbounded::<(i64, PathBuf, String, Arc<Language>)>();
            let (delete_paths_tx, delete_paths_rx) = channel::unbounded::<(i64, PathBuf)>();
            let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<(i64, IndexedFile)>();
            cx.background()
                .spawn({
                    let fs = fs.clone();
                    let worktree_db_ids = worktree_db_ids.clone();
                    async move {
                        for worktree in worktrees.into_iter() {
                            let file_hashes = &worktree_hashes[&worktree.id()];
                            let mut files_included =
                                file_hashes.keys().collect::<HashSet<&PathBuf>>();
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path);

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    if language.name().as_ref() != "Rust" {
                                        continue;
                                    }

                                    if let Some(content) = fs.load(&absolute_path).await.log_err() {
                                        let path_buf = file.path.to_path_buf();
                                        let already_stored = file_hashes.get(&path_buf).map_or(
                                            false,
                                            |existing_hash| {
                                                files_included.remove(&path_buf);
                                                existing_hash.equals(&content)
                                            },
                                        );

                                        if !already_stored {
                                            paths_tx
                                                .try_send((
                                                    worktree_db_ids[&worktree.id()],
                                                    path_buf,
                                                    content,
                                                    language,
                                                ))
                                                .unwrap();
                                        }
                                    }
                                }
                            }
                            for file in files_included {
                                delete_paths_tx
                                    .try_send((worktree_db_ids[&worktree.id()], file.to_owned()))
                                    .unwrap();
                            }
                        }
                    }
                })
                .detach();

            let db_update_task = cx.background().spawn(
                async move {
                    // Insert all newly indexed files. This loop ends once every
                    // `indexed_files_tx` sender has been dropped.
                    while let Ok((worktree_id, indexed_file)) = indexed_files_rx.recv().await {
                        log::info!("Inserting File: {:?}", &indexed_file.path);
                        db.insert_file(worktree_id, indexed_file).log_err();
                    }

                    // Then delete previously indexed files that weren't seen during this scan.
                    while let Ok((worktree_id, delete_path)) = delete_paths_rx.recv().await {
                        log::info!("Deleting File: {:?}", &delete_path);
                        db.delete_file(worktree_id, delete_path).log_err();
                    }

                    anyhow::Ok(())
                }
                .log_err(),
            );

            cx.background()
                .scoped(|scope| {
                    for _ in 0..cx.background().num_cpus() {
                        scope.spawn(async {
                            let mut parser = Parser::new();
                            let mut cursor = QueryCursor::new();
                            while let Ok((worktree_id, file_path, content, language)) =
                                paths_rx.recv().await
                            {
                                if let Some(indexed_file) = Self::index_file(
                                    &mut cursor,
                                    &mut parser,
                                    embedding_provider.as_ref(),
                                    language,
                                    file_path,
                                    content,
                                )
                                .await
                                .log_err()
                                {
                                    indexed_files_tx
                                        .try_send((worktree_id, indexed_file))
                                        .unwrap();
                                }
                            }
                        });
                    }
                })
                .await;
            drop(indexed_files_tx);

            db_update_task.await;

            this.update(&mut cx, |this, _| {
                this.worktree_db_ids.extend(worktree_db_ids);
            });

            anyhow::Ok(())
        })
    }

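    /// Embeds `phrase`, then scans every stored document embedding belonging to
    /// the project's worktrees, keeping the `limit` most similar documents by
    /// dot product with the phrase embedding.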
    pub fn search(
        &mut self,
        project: &ModelHandle<Project>,
        phrase: String,
        limit: usize,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let project = project.read(cx);
        let worktree_db_ids = project
            .worktrees(cx)
            .filter_map(|worktree| {
                let worktree_id = worktree.read(cx).id();
                self.worktree_db_ids.iter().find_map(|(id, db_id)| {
                    if *id == worktree_id {
                        Some(*db_id)
                    } else {
                        None
                    }
                })
            })
            .collect::<Vec<_>>();

        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();
        cx.spawn(|this, cx| async move {
            let documents = cx
                .background()
                .spawn(async move {
                    let database = VectorDatabase::new(database_url.to_string_lossy().into())?;

                    let phrase_embedding = embedding_provider
                        .embed_batch(vec![&phrase])
                        .await?
                        .into_iter()
                        .next()
                        .unwrap();

                    let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
                    database.for_each_document(&worktree_db_ids, |id, embedding| {
                        let similarity = dot(&embedding.0, &phrase_embedding);
                        let ix = match results.binary_search_by(|(_, s)| {
                            similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
                        }) {
                            Ok(ix) => ix,
                            Err(ix) => ix,
                        };
                        results.insert(ix, (id, similarity));
                        results.truncate(limit);
                    })?;

                    let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
                    database.get_documents_by_ids(&ids)
                })
                .await?;

            let results = this.read_with(&cx, |this, _| {
                documents
                    .into_iter()
                    .filter_map(|(worktree_db_id, file_path, offset, name)| {
                        let worktree_id = this.worktree_db_ids.iter().find_map(|(id, db_id)| {
                            if *db_id == worktree_db_id {
                                Some(*id)
                            } else {
                                None
                            }
                        })?;
                        Some(SearchResult {
                            worktree_id,
                            name,
                            offset,
                            file_path,
                        })
                    })
                    .collect()
            });

            anyhow::Ok(results)
        })
    }
}

impl Entity for VectorStore {
    type Event = ();
}

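/// Dot product of two equal-length vectors, computed by treating them as a
/// 1 x len row and a len x 1 column and multiplying them with the BLAS-style
/// `sgemm` from the `matrixmultiply` crate.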
fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    let len = vec_a.len();
    assert_eq!(len, vec_b.len());

    let mut result = 0.0;
    unsafe {
        matrixmultiply::sgemm(
            1,
            len,
            1,
            1.0,
            vec_a.as_ptr(),
            len as isize,
            1,
            vec_b.as_ptr(),
            1,
            len as isize,
            0.0,
            &mut result as *mut f32,
            1,
            1,
        );
    }
    result
}
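
// A minimal sanity check that `dot` agrees with a naive fold; added as an
// illustrative example (the crate's broader coverage lives in
// `vector_store_tests`).
#[cfg(test)]
mod dot_tests {
    use super::dot;

    #[test]
    fn dot_matches_naive_implementation() {
        let a = [1.0f32, 2.0, 3.0, 4.0];
        let b = [5.0f32, 6.0, 7.0, 8.0];
        let naive: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        assert!((dot(&a, &b) - naive).abs() < 1e-5);
    }
}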