vector_store.rs

  1mod db;
  2mod embedding;
  3
  4use anyhow::{anyhow, Result};
  5use db::VectorDatabase;
  6use embedding::{EmbeddingProvider, OpenAIEmbeddings};
  7use gpui::{AppContext, Entity, ModelContext, ModelHandle};
  8use language::LanguageRegistry;
  9use project::{Fs, Project};
 10use smol::channel;
 11use std::{path::PathBuf, sync::Arc, time::Instant};
 12use tree_sitter::{Parser, QueryCursor};
 13use util::{http::HttpClient, ResultExt};
 14use workspace::WorkspaceCreated;
 15
 16pub fn init(
 17    fs: Arc<dyn Fs>,
 18    http_client: Arc<dyn HttpClient>,
 19    language_registry: Arc<LanguageRegistry>,
 20    cx: &mut AppContext,
 21) {
 22    let vector_store = cx.add_model(|cx| VectorStore::new(fs, http_client, language_registry));
 23
 24    cx.subscribe_global::<WorkspaceCreated, _>({
 25        let vector_store = vector_store.clone();
 26        move |event, cx| {
 27            let workspace = &event.0;
 28            if let Some(workspace) = workspace.upgrade(cx) {
 29                let project = workspace.read(cx).project().clone();
 30                if project.read(cx).is_local() {
 31                    vector_store.update(cx, |store, cx| {
 32                        store.add_project(project, cx);
 33                    });
 34                }
 35            }
 36        }
 37    })
 38    .detach();
 39}
 40
/// A single item matched by a grammar's outline query, together with its embedding.
///
/// NOTE(review): `sqlx::FromRow` needs `Decode`/`Type` impls for every field type;
/// confirm the target database driver supports `usize` and `Vec<f32>` before
/// relying on this derive.
#[derive(Debug, sqlx::FromRow)]
struct Document {
    /// Byte offset, within the file's contents, where the outline item starts.
    offset: usize,
    /// Text of the item's name capture.
    name: String,
    /// Embedding computed from the item's full text (empty until embedded).
    embedding: Vec<f32>,
}
 47
/// The result of indexing one file: its path and the outline items found in it.
#[derive(Debug, sqlx::FromRow)]
pub struct IndexedFile {
    /// Path of the indexed file, as passed to `VectorStore::index_file`.
    path: PathBuf,
    /// Presumably a SHA-1 of the file contents — TODO confirm.
    /// NOTE(review): `VectorStore::index_file` currently stores an empty string here.
    sha1: String,
    /// One document per outline item extracted from the file.
    documents: Vec<Document>,
}
 54
/// Presumably a single hit returned by a vector search — TODO confirm.
///
/// NOTE(review): this type is not constructed or read anywhere in this file.
struct SearchResult {
    /// File containing the matched item.
    path: PathBuf,
    /// Byte offset of the matched item within the file.
    offset: usize,
    /// Name of the matched item.
    name: String,
    /// NOTE(review): presumably the distance between the query and item
    /// embeddings; metric and ordering are not defined in this file.
    distance: f32,
}
 61
/// Model that indexes the files of local projects and hands the results to the
/// vector database.
struct VectorStore {
    /// File system used to load file contents while indexing.
    fs: Arc<dyn Fs>,
    /// HTTP client given to the OpenAI embedding provider.
    http_client: Arc<dyn HttpClient>,
    /// Registry used to detect each file's language and grammar.
    language_registry: Arc<LanguageRegistry>,
}
 67
 68impl VectorStore {
 69    fn new(
 70        fs: Arc<dyn Fs>,
 71        http_client: Arc<dyn HttpClient>,
 72        language_registry: Arc<LanguageRegistry>,
 73    ) -> Self {
 74        Self {
 75            fs,
 76            http_client,
 77            language_registry,
 78        }
 79    }
 80
 81    async fn index_file(
 82        cursor: &mut QueryCursor,
 83        parser: &mut Parser,
 84        embedding_provider: &dyn EmbeddingProvider,
 85        fs: &Arc<dyn Fs>,
 86        language_registry: &Arc<LanguageRegistry>,
 87        file_path: PathBuf,
 88    ) -> Result<IndexedFile> {
 89        let language = language_registry
 90            .language_for_file(&file_path, None)
 91            .await?;
 92
 93        if language.name().as_ref() != "Rust" {
 94            Err(anyhow!("unsupported language"))?;
 95        }
 96
 97        let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
 98        let outline_config = grammar
 99            .outline_config
100            .as_ref()
101            .ok_or_else(|| anyhow!("no outline query"))?;
102
103        let content = fs.load(&file_path).await?;
104        parser.set_language(grammar.ts_language).unwrap();
105        let tree = parser
106            .parse(&content, None)
107            .ok_or_else(|| anyhow!("parsing failed"))?;
108
109        let mut documents = Vec::new();
110        let mut context_spans = Vec::new();
111        for mat in cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes()) {
112            let mut item_range = None;
113            let mut name_range = None;
114            for capture in mat.captures {
115                if capture.index == outline_config.item_capture_ix {
116                    item_range = Some(capture.node.byte_range());
117                } else if capture.index == outline_config.name_capture_ix {
118                    name_range = Some(capture.node.byte_range());
119                }
120            }
121
122            if let Some((item_range, name_range)) = item_range.zip(name_range) {
123                if let Some((item, name)) =
124                    content.get(item_range.clone()).zip(content.get(name_range))
125                {
126                    context_spans.push(item);
127                    documents.push(Document {
128                        name: name.to_string(),
129                        offset: item_range.start,
130                        embedding: Vec::new(),
131                    });
132                }
133            }
134        }
135
136        let embeddings = embedding_provider.embed_batch(context_spans).await?;
137        for (document, embedding) in documents.iter_mut().zip(embeddings) {
138            document.embedding = embedding;
139        }
140
141        return Ok(IndexedFile {
142            path: file_path,
143            sha1: String::new(),
144            documents,
145        });
146    }
147
148    fn add_project(&mut self, project: ModelHandle<Project>, cx: &mut ModelContext<Self>) {
149        let worktree_scans_complete = project
150            .read(cx)
151            .worktrees(cx)
152            .map(|worktree| worktree.read(cx).as_local().unwrap().scan_complete())
153            .collect::<Vec<_>>();
154
155        let fs = self.fs.clone();
156        let language_registry = self.language_registry.clone();
157        let client = self.http_client.clone();
158
159        cx.spawn(|_, cx| async move {
160            futures::future::join_all(worktree_scans_complete).await;
161
162            let worktrees = project.read_with(&cx, |project, cx| {
163                project
164                    .worktrees(cx)
165                    .map(|worktree| worktree.read(cx).snapshot())
166                    .collect::<Vec<_>>()
167            });
168
169            let (paths_tx, paths_rx) = channel::unbounded::<PathBuf>();
170            let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<IndexedFile>();
171            cx.background()
172                .spawn(async move {
173                    for worktree in worktrees {
174                        for file in worktree.files(false, 0) {
175                            paths_tx.try_send(worktree.absolutize(&file.path)).unwrap();
176                        }
177                    }
178                })
179                .detach();
180
181            cx.background()
182                .spawn(async move {
183                    // Initialize Database, creates database and tables if not exists
184                    VectorDatabase::initialize_database().await.log_err();
185                    while let Ok(indexed_file) = indexed_files_rx.recv().await {
186                        VectorDatabase::insert_file(indexed_file).await.log_err();
187                    }
188                })
189                .detach();
190
191            let provider = OpenAIEmbeddings { client };
192
193            let t0 = Instant::now();
194
195            cx.background()
196                .scoped(|scope| {
197                    for _ in 0..cx.background().num_cpus() {
198                        scope.spawn(async {
199                            let mut parser = Parser::new();
200                            let mut cursor = QueryCursor::new();
201                            while let Ok(file_path) = paths_rx.recv().await {
202                                if let Some(indexed_file) = Self::index_file(
203                                    &mut cursor,
204                                    &mut parser,
205                                    &provider,
206                                    &fs,
207                                    &language_registry,
208                                    file_path,
209                                )
210                                .await
211                                .log_err()
212                                {
213                                    indexed_files_tx.try_send(indexed_file).unwrap();
214                                }
215                            }
216                        });
217                    }
218                })
219                .await;
220
221            let duration = t0.elapsed();
222            log::info!("indexed project in {duration:?}");
223        })
224        .detach();
225    }
226}
227
/// Lets `VectorStore` be created as a gpui model (see `cx.add_model` in `init`).
/// Its event type is `()`, and no events are emitted anywhere in this file.
impl Entity for VectorStore {
    type Event = ();
}