1mod db;
2mod embedding;
3mod modal;
4
5#[cfg(test)]
6mod vector_store_tests;
7
8use anyhow::{anyhow, Result};
9use db::{FileSha1, VectorDatabase, VECTOR_DB_URL};
10use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings};
11use gpui::{actions, AppContext, Entity, ModelContext, ModelHandle, Task, ViewContext};
12use language::{Language, LanguageRegistry};
13use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
14use project::{Fs, Project, WorktreeId};
15use smol::channel;
16use std::{cmp::Ordering, collections::HashMap, path::PathBuf, sync::Arc};
17use tree_sitter::{Parser, QueryCursor};
18use util::{http::HttpClient, ResultExt, TryFutureExt};
19use workspace::{Workspace, WorkspaceCreated};
20
/// One embeddable span extracted from a source file: an outline item
/// (captured by the language's outline query) plus its embedding vector.
#[derive(Debug)]
pub struct Document {
    /// Byte offset of the item within the file's text.
    pub offset: usize,
    /// The item's name, as captured by the outline query's name capture.
    pub name: String,
    /// Embedding for the item's text; left empty until computed in a batch.
    pub embedding: Vec<f32>,
}
27
/// Wires the vector store into the app: creates the global [`VectorStore`]
/// model, starts indexing each local project when its workspace is created,
/// and registers the semantic-search modal and its toggle action.
pub fn init(
    fs: Arc<dyn Fs>,
    http_client: Arc<dyn HttpClient>,
    language_registry: Arc<LanguageRegistry>,
    cx: &mut AppContext,
) {
    let vector_store = cx.add_model(|_| {
        VectorStore::new(
            fs,
            VECTOR_DB_URL.to_string(),
            // Arc::new(DummyEmbeddings {}),
            Arc::new(OpenAIEmbeddings {
                client: http_client,
            }),
            language_registry,
        )
    });

    // Index each newly created workspace's project, but only local ones —
    // remote projects have no local files to scan.
    cx.subscribe_global::<WorkspaceCreated, _>({
        let vector_store = vector_store.clone();
        move |event, cx| {
            let workspace = &event.0;
            if let Some(workspace) = workspace.upgrade(cx) {
                let project = workspace.read(cx).project().clone();
                if project.read(cx).is_local() {
                    vector_store.update(cx, |store, cx| {
                        // Indexing runs in the background; detach the task.
                        store.add_project(project, cx).detach();
                    });
                }
            }
        }
    })
    .detach();

    // Toggle action opens the semantic search modal over the workspace.
    cx.add_action({
        move |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
            let vector_store = vector_store.clone();
            workspace.toggle_modal(cx, |workspace, cx| {
                let project = workspace.project().clone();
                let workspace = cx.weak_handle();
                cx.add_view(|cx| {
                    SemanticSearch::new(
                        SemanticSearchDelegate::new(workspace, project, vector_store),
                        cx,
                    )
                })
            })
        }
    });

    SemanticSearch::init(cx);
}
80
/// The parsed-and-embedded representation of a single file, ready to be
/// written to the vector database.
#[derive(Debug)]
pub struct IndexedFile {
    /// Worktree-relative path of the file.
    path: PathBuf,
    /// Content hash, used to skip re-indexing files that haven't changed.
    sha1: FileSha1,
    /// Outline items extracted from the file, with their embeddings.
    documents: Vec<Document>,
}
87
/// Model that indexes project files into a vector database and answers
/// semantic (embedding-similarity) searches over them.
pub struct VectorStore {
    fs: Arc<dyn Fs>,
    /// Location of the vector database (defaults to `db::VECTOR_DB_URL`).
    database_url: Arc<str>,
    embedding_provider: Arc<dyn EmbeddingProvider>,
    language_registry: Arc<LanguageRegistry>,
    /// Mapping from in-memory worktree ids to their database row ids,
    /// populated as projects are indexed by `add_project`.
    worktree_db_ids: Vec<(WorktreeId, i64)>,
}
95
/// A single hit returned by [`VectorStore::search`], locating a matched
/// outline item within a worktree.
#[derive(Debug)]
pub struct SearchResult {
    pub worktree_id: WorktreeId,
    /// Name of the matched outline item.
    pub name: String,
    /// Byte offset of the item within its file.
    pub offset: usize,
    /// Worktree-relative path of the file containing the item.
    pub file_path: PathBuf,
}
103
104impl VectorStore {
105 fn new(
106 fs: Arc<dyn Fs>,
107 database_url: String,
108 embedding_provider: Arc<dyn EmbeddingProvider>,
109 language_registry: Arc<LanguageRegistry>,
110 ) -> Self {
111 Self {
112 fs,
113 database_url: database_url.into(),
114 embedding_provider,
115 language_registry,
116 worktree_db_ids: Vec::new(),
117 }
118 }
119
120 async fn index_file(
121 cursor: &mut QueryCursor,
122 parser: &mut Parser,
123 embedding_provider: &dyn EmbeddingProvider,
124 language: Arc<Language>,
125 file_path: PathBuf,
126 content: String,
127 ) -> Result<IndexedFile> {
128 let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
129 let outline_config = grammar
130 .outline_config
131 .as_ref()
132 .ok_or_else(|| anyhow!("no outline query"))?;
133
134 parser.set_language(grammar.ts_language).unwrap();
135 let tree = parser
136 .parse(&content, None)
137 .ok_or_else(|| anyhow!("parsing failed"))?;
138
139 let mut documents = Vec::new();
140 let mut context_spans = Vec::new();
141 for mat in cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes()) {
142 let mut item_range = None;
143 let mut name_range = None;
144 for capture in mat.captures {
145 if capture.index == outline_config.item_capture_ix {
146 item_range = Some(capture.node.byte_range());
147 } else if capture.index == outline_config.name_capture_ix {
148 name_range = Some(capture.node.byte_range());
149 }
150 }
151
152 if let Some((item_range, name_range)) = item_range.zip(name_range) {
153 if let Some((item, name)) =
154 content.get(item_range.clone()).zip(content.get(name_range))
155 {
156 context_spans.push(item);
157 documents.push(Document {
158 name: name.to_string(),
159 offset: item_range.start,
160 embedding: Vec::new(),
161 });
162 }
163 }
164 }
165
166 if !documents.is_empty() {
167 let embeddings = embedding_provider.embed_batch(context_spans).await?;
168 for (document, embedding) in documents.iter_mut().zip(embeddings) {
169 document.embedding = embedding;
170 }
171 }
172
173 let sha1 = FileSha1::from_str(content);
174
175 return Ok(IndexedFile {
176 path: file_path,
177 sha1,
178 documents,
179 });
180 }
181
182 fn add_project(
183 &mut self,
184 project: ModelHandle<Project>,
185 cx: &mut ModelContext<Self>,
186 ) -> Task<Result<()>> {
187 let worktree_scans_complete = project
188 .read(cx)
189 .worktrees(cx)
190 .map(|worktree| {
191 let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
192 async move {
193 scan_complete.await;
194 log::info!("worktree scan completed");
195 }
196 })
197 .collect::<Vec<_>>();
198
199 let fs = self.fs.clone();
200 let language_registry = self.language_registry.clone();
201 let embedding_provider = self.embedding_provider.clone();
202 let database_url = self.database_url.clone();
203
204 cx.spawn(|this, mut cx| async move {
205 futures::future::join_all(worktree_scans_complete).await;
206
207 // TODO: remove this after fixing the bug in scan_complete
208 cx.background()
209 .timer(std::time::Duration::from_secs(3))
210 .await;
211
212 let db = VectorDatabase::new(&database_url)?;
213
214 let worktrees = project.read_with(&cx, |project, cx| {
215 project
216 .worktrees(cx)
217 .map(|worktree| worktree.read(cx).snapshot())
218 .collect::<Vec<_>>()
219 });
220
221 // Here we query the worktree ids, and yet we dont have them elsewhere
222 // We likely want to clean up these datastructures
223 let (db, worktree_hashes, worktree_db_ids) = cx
224 .background()
225 .spawn({
226 let worktrees = worktrees.clone();
227 async move {
228 let mut worktree_db_ids: HashMap<WorktreeId, i64> = HashMap::new();
229 let mut hashes: HashMap<WorktreeId, HashMap<PathBuf, FileSha1>> =
230 HashMap::new();
231 for worktree in worktrees {
232 let worktree_db_id =
233 db.find_or_create_worktree(worktree.abs_path().as_ref())?;
234 worktree_db_ids.insert(worktree.id(), worktree_db_id);
235 hashes.insert(worktree.id(), db.get_file_hashes(worktree_db_id)?);
236 }
237 anyhow::Ok((db, hashes, worktree_db_ids))
238 }
239 })
240 .await?;
241
242 let (paths_tx, paths_rx) =
243 channel::unbounded::<(i64, PathBuf, String, Arc<Language>)>();
244 let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<(i64, IndexedFile)>();
245 cx.background()
246 .spawn({
247 let fs = fs.clone();
248 let worktree_db_ids = worktree_db_ids.clone();
249 async move {
250 for worktree in worktrees.into_iter() {
251 let file_hashes = &worktree_hashes[&worktree.id()];
252 for file in worktree.files(false, 0) {
253 let absolute_path = worktree.absolutize(&file.path);
254
255 if let Ok(language) = language_registry
256 .language_for_file(&absolute_path, None)
257 .await
258 {
259 if language.name().as_ref() != "Rust" {
260 continue;
261 }
262
263 if let Some(content) = fs.load(&absolute_path).await.log_err() {
264 log::info!("loaded file: {absolute_path:?}");
265
266 let path_buf = file.path.to_path_buf();
267 let already_stored = file_hashes
268 .get(&path_buf)
269 .map_or(false, |existing_hash| {
270 existing_hash.equals(&content)
271 });
272
273 if !already_stored {
274 log::info!(
275 "File Changed (Sending to Parse): {:?}",
276 &path_buf
277 );
278 paths_tx
279 .try_send((
280 worktree_db_ids[&worktree.id()],
281 path_buf,
282 content,
283 language,
284 ))
285 .unwrap();
286 }
287 }
288 }
289 }
290 }
291 }
292 })
293 .detach();
294
295 let db_write_task = cx.background().spawn(
296 async move {
297 while let Ok((worktree_id, indexed_file)) = indexed_files_rx.recv().await {
298 db.insert_file(worktree_id, indexed_file).log_err();
299 }
300
301 anyhow::Ok(())
302 }
303 .log_err(),
304 );
305
306 cx.background()
307 .scoped(|scope| {
308 for _ in 0..cx.background().num_cpus() {
309 scope.spawn(async {
310 let mut parser = Parser::new();
311 let mut cursor = QueryCursor::new();
312 while let Ok((worktree_id, file_path, content, language)) =
313 paths_rx.recv().await
314 {
315 if let Some(indexed_file) = Self::index_file(
316 &mut cursor,
317 &mut parser,
318 embedding_provider.as_ref(),
319 language,
320 file_path,
321 content,
322 )
323 .await
324 .log_err()
325 {
326 indexed_files_tx
327 .try_send((worktree_id, indexed_file))
328 .unwrap();
329 }
330 }
331 });
332 }
333 })
334 .await;
335 drop(indexed_files_tx);
336
337 db_write_task.await;
338
339 this.update(&mut cx, |this, _| {
340 this.worktree_db_ids.extend(worktree_db_ids);
341 });
342
343 anyhow::Ok(())
344 })
345 }
346
347 pub fn search(
348 &mut self,
349 project: &ModelHandle<Project>,
350 phrase: String,
351 limit: usize,
352 cx: &mut ModelContext<Self>,
353 ) -> Task<Result<Vec<SearchResult>>> {
354 let project = project.read(cx);
355 let worktree_db_ids = project
356 .worktrees(cx)
357 .filter_map(|worktree| {
358 let worktree_id = worktree.read(cx).id();
359 self.worktree_db_ids.iter().find_map(|(id, db_id)| {
360 if *id == worktree_id {
361 Some(*db_id)
362 } else {
363 None
364 }
365 })
366 })
367 .collect::<Vec<_>>();
368
369 let embedding_provider = self.embedding_provider.clone();
370 let database_url = self.database_url.clone();
371 cx.spawn(|this, cx| async move {
372 let documents = cx
373 .background()
374 .spawn(async move {
375 let database = VectorDatabase::new(database_url.as_ref())?;
376
377 let phrase_embedding = embedding_provider
378 .embed_batch(vec![&phrase])
379 .await?
380 .into_iter()
381 .next()
382 .unwrap();
383
384 let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
385 database.for_each_document(&worktree_db_ids, |id, embedding| {
386 let similarity = dot(&embedding.0, &phrase_embedding);
387 let ix = match results.binary_search_by(|(_, s)| {
388 similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
389 }) {
390 Ok(ix) => ix,
391 Err(ix) => ix,
392 };
393 results.insert(ix, (id, similarity));
394 results.truncate(limit);
395 })?;
396
397 let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
398 database.get_documents_by_ids(&ids)
399 })
400 .await?;
401
402 let results = this.read_with(&cx, |this, _| {
403 documents
404 .into_iter()
405 .filter_map(|(worktree_db_id, file_path, offset, name)| {
406 let worktree_id = this.worktree_db_ids.iter().find_map(|(id, db_id)| {
407 if *db_id == worktree_db_id {
408 Some(*id)
409 } else {
410 None
411 }
412 })?;
413 Some(SearchResult {
414 worktree_id,
415 name,
416 offset,
417 file_path,
418 })
419 })
420 .collect()
421 });
422
423 anyhow::Ok(results)
424 })
425 }
426}
427
// `VectorStore` participates in gpui's entity/model system; it emits no events.
impl Entity for VectorStore {
    type Event = ();
}
431
/// Computes the dot product of two equal-length `f32` slices.
///
/// The previous implementation routed this through an `unsafe` call to
/// `matrixmultiply::sgemm` as a 1×n · n×1 product; a safe sequential
/// `zip`/`sum` computes the same quantity without any unsafe code.
///
/// # Panics
/// Panics if the slices differ in length.
fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
    assert_eq!(
        vec_a.len(),
        vec_b.len(),
        "dot product requires equal-length vectors"
    );
    vec_a.iter().zip(vec_b).map(|(a, b)| a * b).sum()
}