vector_store.rs

  1mod db;
  2mod embedding;
  3mod search;
  4
  5use anyhow::{anyhow, Result};
  6use db::VectorDatabase;
  7use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings};
  8use gpui::{AppContext, Entity, ModelContext, ModelHandle};
  9use language::LanguageRegistry;
 10use project::{Fs, Project};
 11use smol::channel;
 12use std::{path::PathBuf, sync::Arc, time::Instant};
 13use tree_sitter::{Parser, QueryCursor};
 14use util::{http::HttpClient, ResultExt};
 15use workspace::WorkspaceCreated;
 16
 17pub fn init(
 18    fs: Arc<dyn Fs>,
 19    http_client: Arc<dyn HttpClient>,
 20    language_registry: Arc<LanguageRegistry>,
 21    cx: &mut AppContext,
 22) {
 23    let vector_store = cx.add_model(|cx| VectorStore::new(fs, http_client, language_registry));
 24
 25    cx.subscribe_global::<WorkspaceCreated, _>({
 26        let vector_store = vector_store.clone();
 27        move |event, cx| {
 28            let workspace = &event.0;
 29            if let Some(workspace) = workspace.upgrade(cx) {
 30                let project = workspace.read(cx).project().clone();
 31                if project.read(cx).is_local() {
 32                    vector_store.update(cx, |store, cx| {
 33                        store.add_project(project, cx);
 34                    });
 35                }
 36            }
 37        }
 38    })
 39    .detach();
 40}
 41
 42#[derive(Debug)]
 43pub struct Document {
 44    pub offset: usize,
 45    pub name: String,
 46    pub embedding: Vec<f32>,
 47}
 48
 49#[derive(Debug)]
 50pub struct IndexedFile {
 51    path: PathBuf,
 52    sha1: String,
 53    documents: Vec<Document>,
 54}
 55
 56struct SearchResult {
 57    path: PathBuf,
 58    offset: usize,
 59    name: String,
 60    distance: f32,
 61}
 62
 63struct VectorStore {
 64    fs: Arc<dyn Fs>,
 65    http_client: Arc<dyn HttpClient>,
 66    language_registry: Arc<LanguageRegistry>,
 67}
 68
 69impl VectorStore {
 70    fn new(
 71        fs: Arc<dyn Fs>,
 72        http_client: Arc<dyn HttpClient>,
 73        language_registry: Arc<LanguageRegistry>,
 74    ) -> Self {
 75        Self {
 76            fs,
 77            http_client,
 78            language_registry,
 79        }
 80    }
 81
 82    async fn index_file(
 83        cursor: &mut QueryCursor,
 84        parser: &mut Parser,
 85        embedding_provider: &dyn EmbeddingProvider,
 86        fs: &Arc<dyn Fs>,
 87        language_registry: &Arc<LanguageRegistry>,
 88        file_path: PathBuf,
 89    ) -> Result<IndexedFile> {
 90        let language = language_registry
 91            .language_for_file(&file_path, None)
 92            .await?;
 93
 94        if language.name().as_ref() != "Rust" {
 95            Err(anyhow!("unsupported language"))?;
 96        }
 97
 98        let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
 99        let outline_config = grammar
100            .outline_config
101            .as_ref()
102            .ok_or_else(|| anyhow!("no outline query"))?;
103
104        let content = fs.load(&file_path).await?;
105        parser.set_language(grammar.ts_language).unwrap();
106        let tree = parser
107            .parse(&content, None)
108            .ok_or_else(|| anyhow!("parsing failed"))?;
109
110        let mut documents = Vec::new();
111        let mut context_spans = Vec::new();
112        for mat in cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes()) {
113            let mut item_range = None;
114            let mut name_range = None;
115            for capture in mat.captures {
116                if capture.index == outline_config.item_capture_ix {
117                    item_range = Some(capture.node.byte_range());
118                } else if capture.index == outline_config.name_capture_ix {
119                    name_range = Some(capture.node.byte_range());
120                }
121            }
122
123            if let Some((item_range, name_range)) = item_range.zip(name_range) {
124                if let Some((item, name)) =
125                    content.get(item_range.clone()).zip(content.get(name_range))
126                {
127                    context_spans.push(item);
128                    documents.push(Document {
129                        name: name.to_string(),
130                        offset: item_range.start,
131                        embedding: Vec::new(),
132                    });
133                }
134            }
135        }
136
137        let embeddings = embedding_provider.embed_batch(context_spans).await?;
138        for (document, embedding) in documents.iter_mut().zip(embeddings) {
139            document.embedding = embedding;
140        }
141
142        return Ok(IndexedFile {
143            path: file_path,
144            sha1: String::new(),
145            documents,
146        });
147    }
148
149    fn add_project(&mut self, project: ModelHandle<Project>, cx: &mut ModelContext<Self>) {
150        let worktree_scans_complete = project
151            .read(cx)
152            .worktrees(cx)
153            .map(|worktree| worktree.read(cx).as_local().unwrap().scan_complete())
154            .collect::<Vec<_>>();
155
156        let fs = self.fs.clone();
157        let language_registry = self.language_registry.clone();
158        let client = self.http_client.clone();
159
160        cx.spawn(|_, cx| async move {
161            futures::future::join_all(worktree_scans_complete).await;
162
163            let worktrees = project.read_with(&cx, |project, cx| {
164                project
165                    .worktrees(cx)
166                    .map(|worktree| worktree.read(cx).snapshot())
167                    .collect::<Vec<_>>()
168            });
169
170            let (paths_tx, paths_rx) = channel::unbounded::<PathBuf>();
171            let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<IndexedFile>();
172            cx.background()
173                .spawn(async move {
174                    for worktree in worktrees {
175                        for file in worktree.files(false, 0) {
176                            paths_tx.try_send(worktree.absolutize(&file.path)).unwrap();
177                        }
178                    }
179                })
180                .detach();
181
182            cx.background()
183                .spawn(async move {
184                    // Initialize Database, creates database and tables if not exists
185                    VectorDatabase::initialize_database().await.log_err();
186                    while let Ok(indexed_file) = indexed_files_rx.recv().await {
187                        VectorDatabase::insert_file(indexed_file).await.log_err();
188                    }
189
190                    anyhow::Ok(())
191                })
192                .detach();
193
194            let provider = DummyEmbeddings {};
195
196            cx.background()
197                .scoped(|scope| {
198                    for _ in 0..cx.background().num_cpus() {
199                        scope.spawn(async {
200                            let mut parser = Parser::new();
201                            let mut cursor = QueryCursor::new();
202                            while let Ok(file_path) = paths_rx.recv().await {
203                                if let Some(indexed_file) = Self::index_file(
204                                    &mut cursor,
205                                    &mut parser,
206                                    &provider,
207                                    &fs,
208                                    &language_registry,
209                                    file_path,
210                                )
211                                .await
212                                .log_err()
213                                {
214                                    indexed_files_tx.try_send(indexed_file).unwrap();
215                                }
216                            }
217                        });
218                    }
219                })
220                .await;
221        })
222        .detach();
223    }
224}
225
226impl Entity for VectorStore {
227    type Event = ();
228}