vector_store.rs

  1mod db;
  2mod embedding;
  3mod parsing;
  4mod search;
  5
  6use anyhow::{anyhow, Result};
  7use db::VectorDatabase;
  8use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings};
  9use gpui::{AppContext, Entity, ModelContext, ModelHandle};
 10use language::LanguageRegistry;
 11use parsing::Document;
 12use project::{Fs, Project};
 13use search::{BruteForceSearch, VectorSearch};
 14use smol::channel;
 15use std::{path::PathBuf, sync::Arc, time::Instant};
 16use tree_sitter::{Parser, QueryCursor};
 17use util::{http::HttpClient, ResultExt, TryFutureExt};
 18use workspace::WorkspaceCreated;
 19
 20pub fn init(
 21    fs: Arc<dyn Fs>,
 22    http_client: Arc<dyn HttpClient>,
 23    language_registry: Arc<LanguageRegistry>,
 24    cx: &mut AppContext,
 25) {
 26    let vector_store = cx.add_model(|cx| VectorStore::new(fs, http_client, language_registry));
 27
 28    cx.subscribe_global::<WorkspaceCreated, _>({
 29        let vector_store = vector_store.clone();
 30        move |event, cx| {
 31            let workspace = &event.0;
 32            if let Some(workspace) = workspace.upgrade(cx) {
 33                let project = workspace.read(cx).project().clone();
 34                if project.read(cx).is_local() {
 35                    vector_store.update(cx, |store, cx| {
 36                        store.add_project(project, cx);
 37                    });
 38                }
 39            }
 40        }
 41    })
 42    .detach();
 43}
 44
 45#[derive(Debug)]
 46pub struct IndexedFile {
 47    path: PathBuf,
 48    sha1: String,
 49    documents: Vec<Document>,
 50}
 51
 52struct SearchResult {
 53    path: PathBuf,
 54    offset: usize,
 55    name: String,
 56    distance: f32,
 57}
 58
 59struct VectorStore {
 60    fs: Arc<dyn Fs>,
 61    http_client: Arc<dyn HttpClient>,
 62    language_registry: Arc<LanguageRegistry>,
 63}
 64
 65impl VectorStore {
 66    fn new(
 67        fs: Arc<dyn Fs>,
 68        http_client: Arc<dyn HttpClient>,
 69        language_registry: Arc<LanguageRegistry>,
 70    ) -> Self {
 71        Self {
 72            fs,
 73            http_client,
 74            language_registry,
 75        }
 76    }
 77
 78    async fn index_file(
 79        cursor: &mut QueryCursor,
 80        parser: &mut Parser,
 81        embedding_provider: &dyn EmbeddingProvider,
 82        fs: &Arc<dyn Fs>,
 83        language_registry: &Arc<LanguageRegistry>,
 84        file_path: PathBuf,
 85    ) -> Result<IndexedFile> {
 86        let language = language_registry
 87            .language_for_file(&file_path, None)
 88            .await?;
 89
 90        if language.name().as_ref() != "Rust" {
 91            Err(anyhow!("unsupported language"))?;
 92        }
 93
 94        let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
 95        let outline_config = grammar
 96            .outline_config
 97            .as_ref()
 98            .ok_or_else(|| anyhow!("no outline query"))?;
 99
100        let content = fs.load(&file_path).await?;
101        parser.set_language(grammar.ts_language).unwrap();
102        let tree = parser
103            .parse(&content, None)
104            .ok_or_else(|| anyhow!("parsing failed"))?;
105
106        let mut documents = Vec::new();
107        let mut context_spans = Vec::new();
108        for mat in cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes()) {
109            let mut item_range = None;
110            let mut name_range = None;
111            for capture in mat.captures {
112                if capture.index == outline_config.item_capture_ix {
113                    item_range = Some(capture.node.byte_range());
114                } else if capture.index == outline_config.name_capture_ix {
115                    name_range = Some(capture.node.byte_range());
116                }
117            }
118
119            if let Some((item_range, name_range)) = item_range.zip(name_range) {
120                if let Some((item, name)) =
121                    content.get(item_range.clone()).zip(content.get(name_range))
122                {
123                    context_spans.push(item);
124                    documents.push(Document {
125                        name: name.to_string(),
126                        offset: item_range.start,
127                        embedding: Vec::new(),
128                    });
129                }
130            }
131        }
132
133        let embeddings = embedding_provider.embed_batch(context_spans).await?;
134        for (document, embedding) in documents.iter_mut().zip(embeddings) {
135            document.embedding = embedding;
136        }
137
138        return Ok(IndexedFile {
139            path: file_path,
140            sha1: String::new(),
141            documents,
142        });
143    }
144
145    fn add_project(&mut self, project: ModelHandle<Project>, cx: &mut ModelContext<Self>) {
146        let worktree_scans_complete = project
147            .read(cx)
148            .worktrees(cx)
149            .map(|worktree| worktree.read(cx).as_local().unwrap().scan_complete())
150            .collect::<Vec<_>>();
151
152        let fs = self.fs.clone();
153        let language_registry = self.language_registry.clone();
154        let client = self.http_client.clone();
155
156        cx.spawn(|_, cx| async move {
157            futures::future::join_all(worktree_scans_complete).await;
158
159            let worktrees = project.read_with(&cx, |project, cx| {
160                project
161                    .worktrees(cx)
162                    .map(|worktree| worktree.read(cx).snapshot())
163                    .collect::<Vec<_>>()
164            });
165
166            let (paths_tx, paths_rx) = channel::unbounded::<PathBuf>();
167            let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<IndexedFile>();
168            cx.background()
169                .spawn(async move {
170                    for worktree in worktrees {
171                        for file in worktree.files(false, 0) {
172                            paths_tx.try_send(worktree.absolutize(&file.path)).unwrap();
173                        }
174                    }
175                })
176                .detach();
177
178            cx.background()
179                .spawn({
180                    let client = client.clone();
181                    async move {
182                    // Initialize Database, creates database and tables if not exists
183                    let db = VectorDatabase::new()?;
184                    while let Ok(indexed_file) = indexed_files_rx.recv().await {
185                        db.insert_file(indexed_file).log_err();
186                    }
187
188                    // ALL OF THE BELOW IS FOR TESTING,
189                    // This should be removed as we find and appropriate place for evaluate our search.
190
191                    let embedding_provider = OpenAIEmbeddings{ client };
192                    let queries = vec![
193                        "compute embeddings for all of the symbols in the codebase, and write them to a database",
194                            "compute an outline view of all of the symbols in a buffer",
195                            "scan a directory on the file system and load all of its children into an in-memory snapshot",
196                    ];
197                    let embeddings = embedding_provider.embed_batch(queries.clone()).await?;
198
199                    let t2 = Instant::now();
200                    let documents = db.get_documents().unwrap();
201                    let files = db.get_files().unwrap();
202                    println!("Retrieving all documents from Database: {}", t2.elapsed().as_millis());
203
204                    let t1 = Instant::now();
205                    let mut bfs = BruteForceSearch::load(&db).unwrap();
206                    println!("Loading BFS to Memory: {:?}", t1.elapsed().as_millis());
207                    for (idx, embed) in embeddings.into_iter().enumerate() {
208                        let t0 = Instant::now();
209                        println!("\nQuery: {:?}", queries[idx]);
210                        let results = bfs.top_k_search(&embed, 5).await;
211                        println!("Search Elapsed: {}", t0.elapsed().as_millis());
212                        for (id, distance) in results {
213                            println!("");
214                            println!("   distance: {:?}", distance);
215                            println!("   document: {:?}", documents[&id].name);
216                            println!("   path:     {:?}", files[&documents[&id].file_id].path);
217                        }
218
219                    }
220
221                    anyhow::Ok(())
222                }}.log_err())
223                .detach();
224
225            let provider = DummyEmbeddings {};
226            // let provider = OpenAIEmbeddings { client };
227
228            cx.background()
229                .scoped(|scope| {
230                    for _ in 0..cx.background().num_cpus() {
231                        scope.spawn(async {
232                            let mut parser = Parser::new();
233                            let mut cursor = QueryCursor::new();
234                            while let Ok(file_path) = paths_rx.recv().await {
235                                if let Some(indexed_file) = Self::index_file(
236                                    &mut cursor,
237                                    &mut parser,
238                                    &provider,
239                                    &fs,
240                                    &language_registry,
241                                    file_path,
242                                )
243                                .await
244                                .log_err()
245                                {
246                                    indexed_files_tx.try_send(indexed_file).unwrap();
247                                }
248                            }
249                        });
250                    }
251                })
252                .await;
253        })
254        .detach();
255    }
256}
257
258impl Entity for VectorStore {
259    type Event = ();
260}