1mod db;
2mod embedding;
3mod parsing;
4mod search;
5
6use anyhow::{anyhow, Result};
7use db::VectorDatabase;
8use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings};
9use gpui::{AppContext, Entity, ModelContext, ModelHandle};
10use language::LanguageRegistry;
11use parsing::Document;
12use project::{Fs, Project};
13use search::{BruteForceSearch, VectorSearch};
14use smol::channel;
15use std::{path::PathBuf, sync::Arc, time::Instant};
16use tree_sitter::{Parser, QueryCursor};
17use util::{http::HttpClient, ResultExt, TryFutureExt};
18use workspace::WorkspaceCreated;
19
20pub fn init(
21 fs: Arc<dyn Fs>,
22 http_client: Arc<dyn HttpClient>,
23 language_registry: Arc<LanguageRegistry>,
24 cx: &mut AppContext,
25) {
26 let vector_store = cx.add_model(|cx| VectorStore::new(fs, http_client, language_registry));
27
28 cx.subscribe_global::<WorkspaceCreated, _>({
29 let vector_store = vector_store.clone();
30 move |event, cx| {
31 let workspace = &event.0;
32 if let Some(workspace) = workspace.upgrade(cx) {
33 let project = workspace.read(cx).project().clone();
34 if project.read(cx).is_local() {
35 vector_store.update(cx, |store, cx| {
36 store.add_project(project, cx);
37 });
38 }
39 }
40 }
41 })
42 .detach();
43}
44
/// The product of indexing one source file: its path plus one embedded
/// [`Document`] per outline item extracted from it.
#[derive(Debug)]
pub struct IndexedFile {
    /// Absolute path of the file that was indexed.
    path: PathBuf,
    /// Content hash of the file. `VectorStore::index_file` currently leaves
    /// this empty (`String::new()`).
    sha1: String,
    /// One entry per outline item, each carrying its embedding vector.
    documents: Vec<Document>,
}
51
/// A single nearest-neighbor search hit: where the matching document lives
/// and how far its embedding is from the query.
// NOTE(review): not constructed or read anywhere in this chunk — confirm it
// is used elsewhere in the crate before relying on these fields.
struct SearchResult {
    path: PathBuf,
    offset: usize,
    name: String,
    distance: f32,
}
58
/// gpui model that indexes a project's source files into a vector database
/// so they can be searched by embedding similarity.
struct VectorStore {
    // Used to load file contents off disk during indexing.
    fs: Arc<dyn Fs>,
    // Used by the OpenAI embedding provider for remote embedding requests.
    http_client: Arc<dyn HttpClient>,
    // Resolves each file's language, grammar, and outline query.
    language_registry: Arc<LanguageRegistry>,
}
64
65impl VectorStore {
66 fn new(
67 fs: Arc<dyn Fs>,
68 http_client: Arc<dyn HttpClient>,
69 language_registry: Arc<LanguageRegistry>,
70 ) -> Self {
71 Self {
72 fs,
73 http_client,
74 language_registry,
75 }
76 }
77
78 async fn index_file(
79 cursor: &mut QueryCursor,
80 parser: &mut Parser,
81 embedding_provider: &dyn EmbeddingProvider,
82 fs: &Arc<dyn Fs>,
83 language_registry: &Arc<LanguageRegistry>,
84 file_path: PathBuf,
85 ) -> Result<IndexedFile> {
86 let language = language_registry
87 .language_for_file(&file_path, None)
88 .await?;
89
90 if language.name().as_ref() != "Rust" {
91 Err(anyhow!("unsupported language"))?;
92 }
93
94 let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
95 let outline_config = grammar
96 .outline_config
97 .as_ref()
98 .ok_or_else(|| anyhow!("no outline query"))?;
99
100 let content = fs.load(&file_path).await?;
101 parser.set_language(grammar.ts_language).unwrap();
102 let tree = parser
103 .parse(&content, None)
104 .ok_or_else(|| anyhow!("parsing failed"))?;
105
106 let mut documents = Vec::new();
107 let mut context_spans = Vec::new();
108 for mat in cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes()) {
109 let mut item_range = None;
110 let mut name_range = None;
111 for capture in mat.captures {
112 if capture.index == outline_config.item_capture_ix {
113 item_range = Some(capture.node.byte_range());
114 } else if capture.index == outline_config.name_capture_ix {
115 name_range = Some(capture.node.byte_range());
116 }
117 }
118
119 if let Some((item_range, name_range)) = item_range.zip(name_range) {
120 if let Some((item, name)) =
121 content.get(item_range.clone()).zip(content.get(name_range))
122 {
123 context_spans.push(item);
124 documents.push(Document {
125 name: name.to_string(),
126 offset: item_range.start,
127 embedding: Vec::new(),
128 });
129 }
130 }
131 }
132
133 let embeddings = embedding_provider.embed_batch(context_spans).await?;
134 for (document, embedding) in documents.iter_mut().zip(embeddings) {
135 document.embedding = embedding;
136 }
137
138 return Ok(IndexedFile {
139 path: file_path,
140 sha1: String::new(),
141 documents,
142 });
143 }
144
145 fn add_project(&mut self, project: ModelHandle<Project>, cx: &mut ModelContext<Self>) {
146 let worktree_scans_complete = project
147 .read(cx)
148 .worktrees(cx)
149 .map(|worktree| worktree.read(cx).as_local().unwrap().scan_complete())
150 .collect::<Vec<_>>();
151
152 let fs = self.fs.clone();
153 let language_registry = self.language_registry.clone();
154 let client = self.http_client.clone();
155
156 cx.spawn(|_, cx| async move {
157 futures::future::join_all(worktree_scans_complete).await;
158
159 let worktrees = project.read_with(&cx, |project, cx| {
160 project
161 .worktrees(cx)
162 .map(|worktree| worktree.read(cx).snapshot())
163 .collect::<Vec<_>>()
164 });
165
166 let (paths_tx, paths_rx) = channel::unbounded::<PathBuf>();
167 let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<IndexedFile>();
168 cx.background()
169 .spawn(async move {
170 for worktree in worktrees {
171 for file in worktree.files(false, 0) {
172 paths_tx.try_send(worktree.absolutize(&file.path)).unwrap();
173 }
174 }
175 })
176 .detach();
177
178 cx.background()
179 .spawn({
180 let client = client.clone();
181 async move {
182 // Initialize Database, creates database and tables if not exists
183 let db = VectorDatabase::new()?;
184 while let Ok(indexed_file) = indexed_files_rx.recv().await {
185 db.insert_file(indexed_file).log_err();
186 }
187
188 // ALL OF THE BELOW IS FOR TESTING,
189 // This should be removed as we find and appropriate place for evaluate our search.
190
191 let embedding_provider = OpenAIEmbeddings{ client };
192 let queries = vec![
193 "compute embeddings for all of the symbols in the codebase, and write them to a database",
194 "compute an outline view of all of the symbols in a buffer",
195 "scan a directory on the file system and load all of its children into an in-memory snapshot",
196 ];
197 let embeddings = embedding_provider.embed_batch(queries.clone()).await?;
198
199 let t2 = Instant::now();
200 let documents = db.get_documents().unwrap();
201 let files = db.get_files().unwrap();
202 println!("Retrieving all documents from Database: {}", t2.elapsed().as_millis());
203
204 let t1 = Instant::now();
205 let mut bfs = BruteForceSearch::load(&db).unwrap();
206 println!("Loading BFS to Memory: {:?}", t1.elapsed().as_millis());
207 for (idx, embed) in embeddings.into_iter().enumerate() {
208 let t0 = Instant::now();
209 println!("\nQuery: {:?}", queries[idx]);
210 let results = bfs.top_k_search(&embed, 5).await;
211 println!("Search Elapsed: {}", t0.elapsed().as_millis());
212 for (id, distance) in results {
213 println!("");
214 println!(" distance: {:?}", distance);
215 println!(" document: {:?}", documents[&id].name);
216 println!(" path: {:?}", files[&documents[&id].file_id].path);
217 }
218
219 }
220
221 anyhow::Ok(())
222 }}.log_err())
223 .detach();
224
225 let provider = DummyEmbeddings {};
226 // let provider = OpenAIEmbeddings { client };
227
228 cx.background()
229 .scoped(|scope| {
230 for _ in 0..cx.background().num_cpus() {
231 scope.spawn(async {
232 let mut parser = Parser::new();
233 let mut cursor = QueryCursor::new();
234 while let Ok(file_path) = paths_rx.recv().await {
235 if let Some(indexed_file) = Self::index_file(
236 &mut cursor,
237 &mut parser,
238 &provider,
239 &fs,
240 &language_registry,
241 file_path,
242 )
243 .await
244 .log_err()
245 {
246 indexed_files_tx.try_send(indexed_file).unwrap();
247 }
248 }
249 });
250 }
251 })
252 .await;
253 })
254 .detach();
255 }
256}
257
/// Makes `VectorStore` usable as a gpui model; it emits no events.
impl Entity for VectorStore {
    type Event = ();
}