OpenAI indexing on open for Rust files

KCaverly created

Change summary

Cargo.lock                              |  57 ++++++++----
crates/language/src/language.rs         |  16 +-
crates/vector_store/Cargo.toml          |  10 +
crates/vector_store/src/db.rs           |   4 
crates/vector_store/src/embedding.rs    | 100 ++++++++++++++++++++++
crates/vector_store/src/vector_store.rs | 118 +++++++++++++++++++++-----
crates/zed/src/main.rs                  |   2 
7 files changed, 252 insertions(+), 55 deletions(-)

Detailed changes

Cargo.lock

@@ -1389,15 +1389,6 @@ dependencies = [
  "theme",
 ]
 
-[[package]]
-name = "conv"
-version = "0.3.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "78ff10625fd0ac447827aa30ea8b861fead473bb60aeb73af6c1c58caf0d1299"
-dependencies = [
- "custom_derive",
-]
-
 [[package]]
 name = "copilot"
 version = "0.1.0"
@@ -1775,12 +1766,6 @@ dependencies = [
  "winapi 0.3.9",
 ]
 
-[[package]]
-name = "custom_derive"
-version = "0.1.7"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ef8ae57c4978a2acd8b869ce6b9ca1dfe817bff704c220209fdef2c0b75a01b9"
-
 [[package]]
 name = "cxx"
 version = "1.0.94"
@@ -2219,6 +2204,12 @@ version = "0.2.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7"
 
+[[package]]
+name = "fallible-streaming-iterator"
+version = "0.1.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a"
+
 [[package]]
 name = "fancy-regex"
 version = "0.11.0"
@@ -2909,6 +2900,15 @@ dependencies = [
  "ahash 0.8.3",
 ]
 
+[[package]]
+name = "hashlink"
+version = "0.7.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7249a3129cbc1ffccd74857f81464a323a152173cdb134e0fd81bc803b29facf"
+dependencies = [
+ "hashbrown 0.11.2",
+]
+
 [[package]]
 name = "hashlink"
 version = "0.8.1"
@@ -5600,6 +5600,21 @@ dependencies = [
  "zeroize",
 ]
 
+[[package]]
+name = "rusqlite"
+version = "0.27.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "85127183a999f7db96d1a976a309eebbfb6ea3b0b400ddd8340190129de6eb7a"
+dependencies = [
+ "bitflags",
+ "fallible-iterator",
+ "fallible-streaming-iterator",
+ "hashlink 0.7.0",
+ "libsqlite3-sys",
+ "memchr",
+ "smallvec",
+]
+
 [[package]]
 name = "rust-embed"
 version = "6.6.1"
@@ -6531,7 +6546,7 @@ dependencies = [
  "futures-executor",
  "futures-intrusive",
  "futures-util",
- "hashlink",
+ "hashlink 0.8.1",
  "hex",
  "hkdf",
  "hmac 0.12.1",
@@ -7898,14 +7913,20 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "async-compat",
- "conv",
+ "async-trait",
  "futures 0.3.28",
  "gpui",
+ "isahc",
  "language",
+ "lazy_static",
+ "log",
  "project",
- "rand 0.8.5",
+ "rusqlite",
+ "serde",
+ "serde_json",
  "smol",
  "sqlx",
+ "tree-sitter",
  "util",
  "workspace",
 ]

crates/language/src/language.rs

@@ -476,12 +476,12 @@ pub struct Language {
 
 pub struct Grammar {
     id: usize,
-    pub(crate) ts_language: tree_sitter::Language,
+    pub ts_language: tree_sitter::Language,
     pub(crate) error_query: Query,
     pub(crate) highlights_query: Option<Query>,
     pub(crate) brackets_config: Option<BracketConfig>,
     pub(crate) indents_config: Option<IndentConfig>,
-    pub(crate) outline_config: Option<OutlineConfig>,
+    pub outline_config: Option<OutlineConfig>,
     pub(crate) injection_config: Option<InjectionConfig>,
     pub(crate) override_config: Option<OverrideConfig>,
     pub(crate) highlight_map: Mutex<HighlightMap>,
@@ -495,12 +495,12 @@ struct IndentConfig {
     outdent_capture_ix: Option<u32>,
 }
 
-struct OutlineConfig {
-    query: Query,
-    item_capture_ix: u32,
-    name_capture_ix: u32,
-    context_capture_ix: Option<u32>,
-    extra_context_capture_ix: Option<u32>,
+pub struct OutlineConfig {
+    pub query: Query,
+    pub item_capture_ix: u32,
+    pub name_capture_ix: u32,
+    pub context_capture_ix: Option<u32>,
+    pub extra_context_capture_ix: Option<u32>,
 }
 
 struct InjectionConfig {
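
The visibility changes above (`ts_language` and `OutlineConfig` going from `pub(crate)` to `pub`) let code outside the `language` crate run the outline query itself; the vector_store changes below rely on exactly that. As a small standalone sketch, assuming a `language::Language` in hand (the helper name and the error handling are illustrative, not part of this diff):

    // Hypothetical helper, not part of this PR: collect outline item names
    // using the now-public Grammar fields.
    use anyhow::{anyhow, Result};
    use tree_sitter::{Parser, QueryCursor};

    fn outline_item_names(language: &language::Language, source: &str) -> Result<Vec<String>> {
        let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
        let outline = grammar
            .outline_config
            .as_ref()
            .ok_or_else(|| anyhow!("no outline query"))?;

        let mut parser = Parser::new();
        parser
            .set_language(grammar.ts_language)
            .map_err(|_| anyhow!("incompatible grammar"))?;
        let tree = parser
            .parse(source, None)
            .ok_or_else(|| anyhow!("parsing failed"))?;

        // Pull the text of every `name` capture (function, struct, trait names, ...).
        let mut cursor = QueryCursor::new();
        let mut names = Vec::new();
        for mat in cursor.matches(&outline.query, tree.root_node(), source.as_bytes()) {
            for capture in mat.captures {
                if capture.index == outline.name_capture_ix {
                    if let Some(name) = source.get(capture.node.byte_range()) {
                        names.push(name.to_string());
                    }
                }
            }
        }
        Ok(names)
    }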

crates/vector_store/Cargo.toml

@@ -19,8 +19,14 @@ futures.workspace = true
 smol.workspace = true
 sqlx = { version = "0.6", features = ["sqlite","runtime-tokio-rustls"] }
 async-compat = "0.2.1"
-conv = "0.3.3"
-rand.workspace = true
+rusqlite = "0.27.0"
+isahc.workspace = true
+log.workspace = true
+tree-sitter.workspace = true
+lazy_static.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+async-trait.workspace = true
 
 [dev-dependencies]
 gpui = { path = "../gpui", features = ["test-support"] }

crates/vector_store/src/db.rs

@@ -1,8 +1,6 @@
 use anyhow::Result;
 use async_compat::{Compat, CompatExt};
-use conv::ValueFrom;
-use sqlx::{migrate::MigrateDatabase, Pool, Sqlite, SqlitePool};
-use std::time::{Duration, Instant};
+use sqlx::{migrate::MigrateDatabase, Sqlite, SqlitePool};
 
 use crate::IndexedFile;
 

crates/vector_store/src/embedding.rs

@@ -0,0 +1,100 @@
+use anyhow::{anyhow, Result};
+use async_trait::async_trait;
+use futures::AsyncReadExt;
+use gpui::serde_json;
+use isahc::prelude::Configurable;
+use lazy_static::lazy_static;
+use serde::{Deserialize, Serialize};
+use std::env;
+use std::sync::Arc;
+use util::http::{HttpClient, Request};
+
+lazy_static! {
+    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
+}
+
+pub struct OpenAIEmbeddings {
+    pub client: Arc<dyn HttpClient>,
+}
+
+#[derive(Serialize)]
+struct OpenAIEmbeddingRequest<'a> {
+    model: &'static str,
+    input: Vec<&'a str>,
+}
+
+#[derive(Deserialize)]
+struct OpenAIEmbeddingResponse {
+    data: Vec<OpenAIEmbedding>,
+    usage: OpenAIEmbeddingUsage,
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenAIEmbedding {
+    embedding: Vec<f32>,
+    index: usize,
+    object: String,
+}
+
+#[derive(Deserialize)]
+struct OpenAIEmbeddingUsage {
+    prompt_tokens: usize,
+    total_tokens: usize,
+}
+
+#[async_trait]
+pub trait EmbeddingProvider: Sync {
+    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
+}
+
+#[async_trait]
+impl EmbeddingProvider for OpenAIEmbeddings {
+    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
+        let api_key = OPENAI_API_KEY
+            .as_ref()
+            .ok_or_else(|| anyhow!("no api key"))?;
+
+        let request = Request::post("https://api.openai.com/v1/embeddings")
+            .redirect_policy(isahc::config::RedirectPolicy::Follow)
+            .header("Content-Type", "application/json")
+            .header("Authorization", format!("Bearer {}", api_key))
+            .body(
+                serde_json::to_string(&OpenAIEmbeddingRequest {
+                    input: spans,
+                    model: "text-embedding-ada-002",
+                })
+                .unwrap()
+                .into(),
+            )?;
+
+        let mut response = self.client.send(request).await?;
+        if !response.status().is_success() {
+            return Err(anyhow!("openai embedding failed {}", response.status()));
+        }
+
+        let mut body = String::new();
+        response.body_mut().read_to_string(&mut body).await?;
+        let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
+
+        log::info!(
+            "openai embedding completed. tokens: {:?}",
+            response.usage.total_tokens
+        );
+
+        // do we need to re-order these based on the `index` field?
+        eprintln!(
+            "indices: {:?}",
+            response
+                .data
+                .iter()
+                .map(|embedding| embedding.index)
+                .collect::<Vec<_>>()
+        );
+
+        Ok(response
+            .data
+            .into_iter()
+            .map(|embedding| embedding.embedding)
+            .collect())
+    }
+}
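
The new `embed_batch` leaves a question open in a comment about whether the response needs re-ordering by its `index` field. If that turns out to matter, one option (a sketch under the assumption that OpenAI may return `data` out of input order, not something this diff does) is to sort before collecting:

    // Sketch only: make the output order independent of how the API ordered
    // `data` by sorting on the `index` field it reports for each embedding.
    fn embeddings_in_input_order(mut data: Vec<OpenAIEmbedding>) -> Vec<Vec<f32>> {
        data.sort_by_key(|embedding| embedding.index);
        data.into_iter()
            .map(|embedding| embedding.embedding)
            .collect()
    }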

crates/vector_store/src/vector_store.rs

@@ -1,17 +1,25 @@
 mod db;
-use anyhow::Result;
+mod embedding;
+
+use anyhow::{anyhow, Result};
 use db::VectorDatabase;
+use embedding::{EmbeddingProvider, OpenAIEmbeddings};
 use gpui::{AppContext, Entity, ModelContext, ModelHandle};
 use language::LanguageRegistry;
 use project::{Fs, Project};
-use rand::Rng;
 use smol::channel;
 use std::{path::PathBuf, sync::Arc, time::Instant};
-use util::ResultExt;
+use tree_sitter::{Parser, QueryCursor};
+use util::{http::HttpClient, ResultExt};
 use workspace::WorkspaceCreated;
 
-pub fn init(fs: Arc<dyn Fs>, language_registry: Arc<LanguageRegistry>, cx: &mut AppContext) {
-    let vector_store = cx.add_model(|cx| VectorStore::new(fs, language_registry));
+pub fn init(
+    fs: Arc<dyn Fs>,
+    http_client: Arc<dyn HttpClient>,
+    language_registry: Arc<LanguageRegistry>,
+    cx: &mut AppContext,
+) {
+    let vector_store = cx.add_model(|cx| VectorStore::new(fs, http_client, language_registry));
 
     cx.subscribe_global::<WorkspaceCreated, _>({
         let vector_store = vector_store.clone();
@@ -53,38 +61,86 @@ struct SearchResult {
 
 struct VectorStore {
     fs: Arc<dyn Fs>,
+    http_client: Arc<dyn HttpClient>,
     language_registry: Arc<LanguageRegistry>,
 }
 
 impl VectorStore {
-    fn new(fs: Arc<dyn Fs>, language_registry: Arc<LanguageRegistry>) -> Self {
+    fn new(
+        fs: Arc<dyn Fs>,
+        http_client: Arc<dyn HttpClient>,
+        language_registry: Arc<LanguageRegistry>,
+    ) -> Self {
         Self {
             fs,
+            http_client,
             language_registry,
         }
     }
 
     async fn index_file(
+        cursor: &mut QueryCursor,
+        parser: &mut Parser,
+        embedding_provider: &dyn EmbeddingProvider,
         fs: &Arc<dyn Fs>,
         language_registry: &Arc<LanguageRegistry>,
         file_path: PathBuf,
     ) -> Result<IndexedFile> {
-        // This is creating dummy documents to test the database writes.
-        let mut documents = vec![];
-        let mut rng = rand::thread_rng();
-        let rand_num_of_documents: u8 = rng.gen_range(0..200);
-        for _ in 0..rand_num_of_documents {
-            let doc = Document {
-                offset: 0,
-                name: "test symbol".to_string(),
-                embedding: vec![0.32 as f32; 768],
-            };
-            documents.push(doc);
+        let language = language_registry
+            .language_for_file(&file_path, None)
+            .await?;
+
+        if language.name().as_ref() != "Rust" {
+            Err(anyhow!("unsupported language"))?;
+        }
+
+        let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
+        let outline_config = grammar
+            .outline_config
+            .as_ref()
+            .ok_or_else(|| anyhow!("no outline query"))?;
+
+        let content = fs.load(&file_path).await?;
+        parser.set_language(grammar.ts_language).unwrap();
+        let tree = parser
+            .parse(&content, None)
+            .ok_or_else(|| anyhow!("parsing failed"))?;
+
+        let mut documents = Vec::new();
+        let mut context_spans = Vec::new();
+        for mat in cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes()) {
+            let mut item_range = None;
+            let mut name_range = None;
+            for capture in mat.captures {
+                if capture.index == outline_config.item_capture_ix {
+                    item_range = Some(capture.node.byte_range());
+                } else if capture.index == outline_config.name_capture_ix {
+                    name_range = Some(capture.node.byte_range());
+                }
+            }
+
+            if let Some((item_range, name_range)) = item_range.zip(name_range) {
+                if let Some((item, name)) =
+                    content.get(item_range.clone()).zip(content.get(name_range))
+                {
+                    context_spans.push(item);
+                    documents.push(Document {
+                        name: name.to_string(),
+                        offset: item_range.start,
+                        embedding: Vec::new(),
+                    });
+                }
+            }
+        }
+
+        let embeddings = embedding_provider.embed_batch(context_spans).await?;
+        for (document, embedding) in documents.iter_mut().zip(embeddings) {
+            document.embedding = embedding;
         }
 
         return Ok(IndexedFile {
             path: file_path,
-            sha1: "asdfasdfasdf".to_string(),
+            sha1: String::new(),
             documents,
         });
     }
@@ -98,8 +154,9 @@ impl VectorStore {
 
         let fs = self.fs.clone();
         let language_registry = self.language_registry.clone();
+        let client = self.http_client.clone();
 
-        cx.spawn(|this, cx| async move {
+        cx.spawn(|_, cx| async move {
             futures::future::join_all(worktree_scans_complete).await;
 
             let worktrees = project.read_with(&cx, |project, cx| {
@@ -131,15 +188,27 @@ impl VectorStore {
                 })
                 .detach();
 
+            let provider = OpenAIEmbeddings { client };
+
+            let t0 = Instant::now();
+
             cx.background()
                 .scoped(|scope| {
                     for _ in 0..cx.background().num_cpus() {
                         scope.spawn(async {
+                            let mut parser = Parser::new();
+                            let mut cursor = QueryCursor::new();
                             while let Ok(file_path) = paths_rx.recv().await {
-                                if let Some(indexed_file) =
-                                    Self::index_file(&fs, &language_registry, file_path)
-                                        .await
-                                        .log_err()
+                                if let Some(indexed_file) = Self::index_file(
+                                    &mut cursor,
+                                    &mut parser,
+                                    &provider,
+                                    &fs,
+                                    &language_registry,
+                                    file_path,
+                                )
+                                .await
+                                .log_err()
                                 {
                                     indexed_files_tx.try_send(indexed_file).unwrap();
                                 }
@@ -148,6 +217,9 @@ impl VectorStore {
                     }
                 })
                 .await;
+
+            let duration = t0.elapsed();
+            log::info!("indexed project in {duration:?}");
         })
         .detach();
     }
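
Since `index_file` now goes through the `EmbeddingProvider` trait rather than calling `OpenAIEmbeddings` directly, the indexing path can be exercised without network access or an `OPENAI_API_KEY`. A minimal test double might look like this (hypothetical, not part of this diff; the 1536 dimension matches text-embedding-ada-002):

    // Hypothetical test double: returns fixed zero vectors so indexing can run
    // in tests without hitting the OpenAI API.
    struct FakeEmbeddingProvider;

    #[async_trait::async_trait]
    impl EmbeddingProvider for FakeEmbeddingProvider {
        async fn embed_batch(&self, spans: Vec<&str>) -> anyhow::Result<Vec<Vec<f32>>> {
            Ok(spans.iter().map(|_| vec![0.0f32; 1536]).collect())
        }
    }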

crates/zed/src/main.rs

@@ -152,7 +152,7 @@ fn main() {
         project_panel::init(cx);
         diagnostics::init(cx);
         search::init(cx);
-        vector_store::init(fs.clone(), languages.clone(), cx);
+        vector_store::init(fs.clone(), http.clone(), languages.clone(), cx);
         vim::init(cx);
         terminal_view::init(cx);
         theme_testbench::init(cx);