1mod db;
2mod embedding;
3mod modal;
4
5#[cfg(test)]
6mod vector_store_tests;
7
8use anyhow::{anyhow, Result};
9use db::{FileSha1, VectorDatabase};
10use embedding::{EmbeddingProvider, OpenAIEmbeddings};
11use gpui::{AppContext, Entity, ModelContext, ModelHandle, Task, ViewContext};
12use language::{Language, LanguageRegistry};
13use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
14use project::{Fs, Project, WorktreeId};
15use smol::channel;
16use std::{
17 cmp::Ordering,
18 collections::HashMap,
19 path::{Path, PathBuf},
20 sync::Arc,
21};
22use tree_sitter::{Parser, QueryCursor};
23use util::{
24 channel::RELEASE_CHANNEL_NAME, http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt, TryFutureExt,
25};
26use workspace::{Workspace, WorkspaceCreated};
27
/// One embedded span extracted from a source file — typically an outline item
/// (function, type, etc.) found via the language's tree-sitter outline query.
#[derive(Debug)]
pub struct Document {
    /// Byte offset of the item's start within the file content.
    pub offset: usize,
    /// The captured name of the item (from the outline query's name capture).
    pub name: String,
    /// Embedding vector for the item's full text; empty until filled in by the
    /// embedding provider (see `VectorStore::index_file`).
    pub embedding: Vec<f32>,
}
34
/// Wires the vector store into the application: creates the `VectorStore`
/// model backed by an on-disk embeddings database, indexes every local
/// project as its workspace is created, and registers the semantic-search
/// modal and its `Toggle` action.
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    // Keep a separate database per release channel (stable/preview/dev) so
    // their embedding stores never collide.
    let db_file_path = EMBEDDINGS_DIR
        .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
        .join("embeddings_db");

    let vector_store = cx.add_model(|_| {
        VectorStore::new(
            fs,
            db_file_path,
            Arc::new(OpenAIEmbeddings {
                client: http_client,
            }),
            language_registry,
        )
    });

    // Start indexing each newly created workspace's project. Only local
    // projects are indexed — remote ones have no local files to read.
    cx.subscribe_global::<WorkspaceCreated, _>({
        let vector_store = vector_store.clone();
        move |event, cx| {
            let workspace = &event.0;
            if let Some(workspace) = workspace.upgrade(cx) {
                let project = workspace.read(cx).project().clone();
                if project.read(cx).is_local() {
                    vector_store.update(cx, |store, cx| {
                        // Detach: indexing proceeds in the background; errors
                        // are not surfaced to the caller here.
                        store.add_project(project, cx).detach();
                    });
                }
            }
        }
    })
    .detach();

    // `Toggle` opens the semantic search modal for the active workspace.
    cx.add_action({
        move |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
            let vector_store = vector_store.clone();
            workspace.toggle_modal(cx, |workspace, cx| {
                let project = workspace.project().clone();
                let workspace = cx.weak_handle();
                cx.add_view(|cx| {
                    SemanticSearch::new(
                        SemanticSearchDelegate::new(workspace, project, vector_store),
                        cx,
                    )
                })
            })
        }
    });

    SemanticSearch::init(cx);
}
90
/// The fully parsed-and-embedded representation of one file, ready to be
/// written to the vector database.
#[derive(Debug)]
pub struct IndexedFile {
    // Worktree-relative path of the file.
    path: PathBuf,
    // Content hash used to skip re-indexing unchanged files.
    sha1: FileSha1,
    // Embedded outline items extracted from the file.
    documents: Vec<Document>,
}
97
/// GPUI model that owns the embedding pipeline: it scans project worktrees,
/// embeds parsed outline items, persists them to a SQLite-backed
/// `VectorDatabase`, and answers similarity searches.
pub struct VectorStore {
    fs: Arc<dyn Fs>,
    // Filesystem path of the embeddings database (opened per task, not held open).
    database_url: Arc<PathBuf>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    // Maps each in-memory WorktreeId to its row id in the database; grows as
    // projects are indexed via `add_project`.
    worktree_db_ids: Vec<(WorktreeId, i64)>,
}
105
/// A single hit returned from `VectorStore::search`, identifying an outline
/// item by worktree, file path, and byte offset.
#[derive(Debug, Clone)]
pub struct SearchResult {
    pub worktree_id: WorktreeId,
    /// Name of the matched outline item.
    pub name: String,
    /// Byte offset of the item within the file.
    pub offset: usize,
    pub file_path: PathBuf,
}
113
impl VectorStore {
    /// Creates a store with an empty worktree-id map. The database file is not
    /// opened here; each background task opens its own connection from
    /// `database_url`.
    fn new(
        fs: Arc<dyn Fs>,
        database_url: PathBuf,
        embedding_provider: Arc<dyn EmbeddingProvider>,
        language_registry: Arc<LanguageRegistry>,
    ) -> Self {
        Self {
            fs,
            database_url: Arc::new(database_url),
            embedding_provider,
            language_registry,
            worktree_db_ids: Vec::new(),
        }
    }

    /// Parses `content` with the language's tree-sitter grammar, extracts every
    /// outline item (function, type, etc.), embeds the items in one batch, and
    /// returns the resulting `IndexedFile`.
    ///
    /// Errors if the language has no grammar or no outline query, or if
    /// parsing or the embedding request fails.
    async fn index_file(
        cursor: &mut QueryCursor,
        parser: &mut Parser,
        embedding_provider: &dyn EmbeddingProvider,
        language: Arc<Language>,
        file_path: PathBuf,
        content: String,
    ) -> Result<IndexedFile> {
        let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
        let outline_config = grammar
            .outline_config
            .as_ref()
            .ok_or_else(|| anyhow!("no outline query"))?;

        parser.set_language(grammar.ts_language).unwrap();
        let tree = parser
            .parse(&content, None)
            .ok_or_else(|| anyhow!("parsing failed"))?;

        let mut documents = Vec::new();
        let mut context_spans = Vec::new();
        for mat in cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes()) {
            // A match yields at most one item range and one name range; only
            // matches that produce both are embedded.
            let mut item_range = None;
            let mut name_range = None;
            for capture in mat.captures {
                if capture.index == outline_config.item_capture_ix {
                    item_range = Some(capture.node.byte_range());
                } else if capture.index == outline_config.name_capture_ix {
                    name_range = Some(capture.node.byte_range());
                }
            }

            if let Some((item_range, name_range)) = item_range.zip(name_range) {
                // `get` guards against ranges that are not valid UTF-8
                // boundaries; such matches are silently skipped.
                if let Some((item, name)) =
                    content.get(item_range.clone()).zip(content.get(name_range))
                {
                    // The full item text is what gets embedded; the document
                    // records only the name, offset, and (later) embedding.
                    context_spans.push(item);
                    documents.push(Document {
                        name: name.to_string(),
                        offset: item_range.start,
                        embedding: Vec::new(),
                    });
                }
            }
        }

        if !documents.is_empty() {
            // One batched request for all spans in the file; results come back
            // in the same order as `context_spans`, paired with `documents`.
            let embeddings = embedding_provider.embed_batch(context_spans).await?;
            for (document, embedding) in documents.iter_mut().zip(embeddings) {
                document.embedding = embedding;
            }
        }

        let sha1 = FileSha1::from_str(content);

        return Ok(IndexedFile {
            path: file_path,
            sha1,
            documents,
        });
    }

    /// Indexes every worktree in `project`: waits for worktree scans to
    /// finish, then fans file contents out to one parser/embedder worker per
    /// CPU and streams results into the database. Unchanged files (same
    /// content hash as stored) are skipped. Currently restricted to Rust
    /// files only.
    fn add_project(
        &mut self,
        project: ModelHandle<Project>,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<()>> {
        let worktree_scans_complete = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| {
                // NOTE(review): `as_local().unwrap()` assumes the caller only
                // passes local projects (init() checks `is_local()` first).
                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
                async move {
                    scan_complete.await;
                    log::info!("worktree scan completed");
                }
            })
            .collect::<Vec<_>>();

        let fs = self.fs.clone();
        let language_registry = self.language_registry.clone();
        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();

        cx.spawn(|this, mut cx| async move {
            futures::future::join_all(worktree_scans_complete).await;

            // TODO: remove this after fixing the bug in scan_complete
            cx.background()
                .timer(std::time::Duration::from_secs(3))
                .await;

            // Best-effort: the directory may already exist.
            if let Some(db_directory) = database_url.parent() {
                fs.create_dir(db_directory).await.log_err();
            }
            let db = VectorDatabase::new(database_url.to_string_lossy().into())?;

            let worktrees = project.read_with(&cx, |project, cx| {
                project
                    .worktrees(cx)
                    .map(|worktree| worktree.read(cx).snapshot())
                    .collect::<Vec<_>>()
            });

            // Here we query the worktree ids, and yet we dont have them elsewhere
            // We likely want to clean up these datastructures
            // The db handle is moved into the background task and returned so
            // it can be reused by the writer task below.
            let (db, worktree_hashes, worktree_db_ids) = cx
                .background()
                .spawn({
                    let worktrees = worktrees.clone();
                    async move {
                        let mut worktree_db_ids: HashMap<WorktreeId, i64> = HashMap::new();
                        let mut hashes: HashMap<WorktreeId, HashMap<PathBuf, FileSha1>> =
                            HashMap::new();
                        for worktree in worktrees {
                            let worktree_db_id =
                                db.find_or_create_worktree(worktree.abs_path().as_ref())?;
                            worktree_db_ids.insert(worktree.id(), worktree_db_id);
                            hashes.insert(worktree.id(), db.get_file_hashes(worktree_db_id)?);
                        }
                        anyhow::Ok((db, hashes, worktree_db_ids))
                    }
                })
                .await?;

            // Pipeline: producer (file loader) -> paths channel -> N parser
            // workers -> indexed_files channel -> single DB writer.
            let (paths_tx, paths_rx) =
                channel::unbounded::<(i64, PathBuf, String, Arc<Language>)>();
            let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<(i64, IndexedFile)>();
            cx.background()
                .spawn({
                    let fs = fs.clone();
                    let worktree_db_ids = worktree_db_ids.clone();
                    async move {
                        for worktree in worktrees.into_iter() {
                            let file_hashes = &worktree_hashes[&worktree.id()];
                            // files(false, 0): visible files only, from the start.
                            // TODO confirm the exact semantics of these arguments.
                            for file in worktree.files(false, 0) {
                                let absolute_path = worktree.absolutize(&file.path);

                                if let Ok(language) = language_registry
                                    .language_for_file(&absolute_path, None)
                                    .await
                                {
                                    // Temporary restriction: only Rust files
                                    // are indexed for now.
                                    if language.name().as_ref() != "Rust" {
                                        continue;
                                    }

                                    if let Some(content) = fs.load(&absolute_path).await.log_err() {
                                        log::info!("loaded file: {absolute_path:?}");

                                        let path_buf = file.path.to_path_buf();
                                        // Skip files whose stored hash matches
                                        // the current content.
                                        let already_stored = file_hashes
                                            .get(&path_buf)
                                            .map_or(false, |existing_hash| {
                                                existing_hash.equals(&content)
                                            });

                                        if !already_stored {
                                            log::info!(
                                                "File Changed (Sending to Parse): {:?}",
                                                &path_buf
                                            );
                                            // Unbounded channel: try_send only
                                            // fails if the receiver is gone.
                                            paths_tx
                                                .try_send((
                                                    worktree_db_ids[&worktree.id()],
                                                    path_buf,
                                                    content,
                                                    language,
                                                ))
                                                .unwrap();
                                        }
                                    }
                                }
                            }
                        }
                        // Dropping paths_tx here closes the channel, letting
                        // the parser workers' recv loops terminate.
                    }
                })
                .detach();

            // Single writer: serializes all DB inserts on one task.
            let db_write_task = cx.background().spawn(
                async move {
                    while let Ok((worktree_id, indexed_file)) = indexed_files_rx.recv().await {
                        db.insert_file(worktree_id, indexed_file).log_err();
                    }

                    anyhow::Ok(())
                }
                .log_err(),
            );

            // One parser/embedder worker per CPU; each owns its own tree-sitter
            // Parser and QueryCursor. `scoped` waits for all workers to finish.
            cx.background()
                .scoped(|scope| {
                    for _ in 0..cx.background().num_cpus() {
                        scope.spawn(async {
                            let mut parser = Parser::new();
                            let mut cursor = QueryCursor::new();
                            while let Ok((worktree_id, file_path, content, language)) =
                                paths_rx.recv().await
                            {
                                if let Some(indexed_file) = Self::index_file(
                                    &mut cursor,
                                    &mut parser,
                                    embedding_provider.as_ref(),
                                    language,
                                    file_path,
                                    content,
                                )
                                .await
                                .log_err()
                                {
                                    indexed_files_tx
                                        .try_send((worktree_id, indexed_file))
                                        .unwrap();
                                }
                            }
                        });
                    }
                })
                .await;
            // All workers done; close the channel so the writer task drains
            // and exits.
            drop(indexed_files_tx);

            db_write_task.await;

            this.update(&mut cx, |this, _| {
                this.worktree_db_ids.extend(worktree_db_ids);
            });

            anyhow::Ok(())
        })
    }

    /// Embeds `phrase` and returns up to `limit` results across the project's
    /// indexed worktrees, ranked by dot-product similarity (highest first).
    pub fn search(
        &mut self,
        project: &ModelHandle<Project>,
        phrase: String,
        limit: usize,
        cx: &mut ModelContext<Self>,
    ) -> Task<Result<Vec<SearchResult>>> {
        let project = project.read(cx);
        // Translate the project's worktree ids to DB row ids; worktrees that
        // were never indexed are silently excluded from the search.
        let worktree_db_ids = project
            .worktrees(cx)
            .filter_map(|worktree| {
                let worktree_id = worktree.read(cx).id();
                self.worktree_db_ids.iter().find_map(|(id, db_id)| {
                    if *id == worktree_id {
                        Some(*db_id)
                    } else {
                        None
                    }
                })
            })
            .collect::<Vec<_>>();

        let embedding_provider = self.embedding_provider.clone();
        let database_url = self.database_url.clone();
        cx.spawn(|this, cx| async move {
            let documents = cx
                .background()
                .spawn(async move {
                    let database = VectorDatabase::new(database_url.to_string_lossy().into())?;

                    // embed_batch on a single phrase; unwrap is safe only if
                    // the provider returns one embedding per input.
                    let phrase_embedding = embedding_provider
                        .embed_batch(vec![&phrase])
                        .await?
                        .into_iter()
                        .next()
                        .unwrap();

                    // Maintain a top-`limit` list, sorted by similarity
                    // descending. The comparator intentionally compares
                    // `similarity` against the element (reversed arguments)
                    // to get descending order from binary_search_by.
                    let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
                    database.for_each_document(&worktree_db_ids, |id, embedding| {
                        let similarity = dot(&embedding.0, &phrase_embedding);
                        let ix = match results.binary_search_by(|(_, s)| {
                            similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
                        }) {
                            Ok(ix) => ix,
                            Err(ix) => ix,
                        };
                        results.insert(ix, (id, similarity));
                        results.truncate(limit);
                    })?;

                    let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
                    database.get_documents_by_ids(&ids)
                })
                .await?;

            // Map DB worktree ids back to in-memory WorktreeIds; rows whose
            // worktree is no longer known are dropped.
            let results = this.read_with(&cx, |this, _| {
                documents
                    .into_iter()
                    .filter_map(|(worktree_db_id, file_path, offset, name)| {
                        let worktree_id = this.worktree_db_ids.iter().find_map(|(id, db_id)| {
                            if *db_id == worktree_db_id {
                                Some(*id)
                            } else {
                                None
                            }
                        })?;
                        Some(SearchResult {
                            worktree_id,
                            name,
                            offset,
                            file_path,
                        })
                    })
                    .collect()
            });

            anyhow::Ok(results)
        })
    }
}
440
// Required so VectorStore can live in the GPUI model system; it emits no events.
impl Entity for VectorStore {
    type Event = ();
}
444
/// Dot product of two equal-length `f32` slices.
///
/// Used as the similarity measure between a query embedding and stored
/// document embeddings in `VectorStore::search`.
///
/// # Panics
///
/// Panics if the slices have different lengths.
fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    // Asserting equal lengths up front both preserves the original contract
    // and lets the compiler elide per-element bounds checks in the zip below.
    assert_eq!(vec_a.len(), vec_b.len());

    // A plain fused iterator multiply-accumulate. The previous implementation
    // routed this through an unsafe `matrixmultiply::sgemm` call configured as
    // a degenerate 1xN * Nx1 product with hand-computed strides; the safe
    // iterator form expresses the same reduction without unsafe code and
    // auto-vectorizes.
    vec_a.iter().zip(vec_b).map(|(a, b)| a * b).sum()
}