From dd309070eb03dd51041d412ecce553ab43450342 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Thu, 22 Jun 2023 16:50:07 -0400 Subject: [PATCH] open ai indexing on open for rust files --- Cargo.lock | 57 ++++++++---- crates/language/src/language.rs | 16 ++-- crates/vector_store/Cargo.toml | 10 +- crates/vector_store/src/db.rs | 4 +- crates/vector_store/src/embedding.rs | 100 ++++++++++++++++++++ crates/vector_store/src/vector_store.rs | 118 +++++++++++++++++++----- crates/zed/src/main.rs | 2 +- 7 files changed, 252 insertions(+), 55 deletions(-) create mode 100644 crates/vector_store/src/embedding.rs diff --git a/Cargo.lock b/Cargo.lock index beb84e04bd838062ddad097a2dcb2fc3a678136f..5a93ce77af4801ccb9a5fd0f0a8b4a97ad6c8bc1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1389,15 +1389,6 @@ dependencies = [ "theme", ] -[[package]] -name = "conv" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78ff10625fd0ac447827aa30ea8b861fead473bb60aeb73af6c1c58caf0d1299" -dependencies = [ - "custom_derive", -] - [[package]] name = "copilot" version = "0.1.0" @@ -1775,12 +1766,6 @@ dependencies = [ "winapi 0.3.9", ] -[[package]] -name = "custom_derive" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef8ae57c4978a2acd8b869ce6b9ca1dfe817bff704c220209fdef2c0b75a01b9" - [[package]] name = "cxx" version = "1.0.94" @@ -2219,6 +2204,12 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7" +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + [[package]] name = "fancy-regex" version = "0.11.0" @@ -2909,6 +2900,15 @@ dependencies = [ "ahash 0.8.3", ] +[[package]] +name = "hashlink" +version = "0.7.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "7249a3129cbc1ffccd74857f81464a323a152173cdb134e0fd81bc803b29facf" +dependencies = [ + "hashbrown 0.11.2", +] + [[package]] name = "hashlink" version = "0.8.1" @@ -5600,6 +5600,21 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rusqlite" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85127183a999f7db96d1a976a309eebbfb6ea3b0b400ddd8340190129de6eb7a" +dependencies = [ + "bitflags", + "fallible-iterator", + "fallible-streaming-iterator", + "hashlink 0.7.0", + "libsqlite3-sys", + "memchr", + "smallvec", +] + [[package]] name = "rust-embed" version = "6.6.1" @@ -6531,7 +6546,7 @@ dependencies = [ "futures-executor", "futures-intrusive", "futures-util", - "hashlink", + "hashlink 0.8.1", "hex", "hkdf", "hmac 0.12.1", @@ -7898,14 +7913,20 @@ version = "0.1.0" dependencies = [ "anyhow", "async-compat", - "conv", + "async-trait", "futures 0.3.28", "gpui", + "isahc", "language", + "lazy_static", + "log", "project", - "rand 0.8.5", + "rusqlite", + "serde", + "serde_json", "smol", "sqlx", + "tree-sitter", "util", "workspace", ] diff --git a/crates/language/src/language.rs b/crates/language/src/language.rs index 5a4d604ce349e107c282b8ab2e981b07e94261b9..4c6f709f38fb2b44eba6b6562e2dd7d1e10c545a 100644 --- a/crates/language/src/language.rs +++ b/crates/language/src/language.rs @@ -476,12 +476,12 @@ pub struct Language { pub struct Grammar { id: usize, - pub(crate) ts_language: tree_sitter::Language, + pub ts_language: tree_sitter::Language, pub(crate) error_query: Query, pub(crate) highlights_query: Option, pub(crate) brackets_config: Option, pub(crate) indents_config: Option, - pub(crate) outline_config: Option, + pub outline_config: Option, pub(crate) injection_config: Option, pub(crate) override_config: Option, pub(crate) highlight_map: Mutex, @@ -495,12 +495,12 @@ struct IndentConfig { outdent_capture_ix: Option, } -struct OutlineConfig { - query: 
Query, - item_capture_ix: u32, - name_capture_ix: u32, - context_capture_ix: Option<u32>, - extra_context_capture_ix: Option<u32>, +pub struct OutlineConfig { + pub query: Query, + pub item_capture_ix: u32, + pub name_capture_ix: u32, + pub context_capture_ix: Option<u32>, + pub extra_context_capture_ix: Option<u32>, } struct InjectionConfig { diff --git a/crates/vector_store/Cargo.toml b/crates/vector_store/Cargo.toml index 74ad23740e43dc7d9dabb4c475dbe1a612a1df49..2db672ed255180eab6f9187ec6cd685d3592a770 100644 --- a/crates/vector_store/Cargo.toml +++ b/crates/vector_store/Cargo.toml @@ -19,8 +19,14 @@ futures.workspace = true smol.workspace = true sqlx = { version = "0.6", features = ["sqlite","runtime-tokio-rustls"] } async-compat = "0.2.1" -conv = "0.3.3" -rand.workspace = true +rusqlite = "0.27.0" +isahc.workspace = true +log.workspace = true +tree-sitter.workspace = true +lazy_static.workspace = true +serde.workspace = true +serde_json.workspace = true +async-trait.workspace = true [dev-dependencies] gpui = { path = "../gpui", features = ["test-support"] } diff --git a/crates/vector_store/src/db.rs b/crates/vector_store/src/db.rs index dfa85044d69e0a807e226adc91df33f52ad1a32b..d335d327b8637ffb0e349d506787a322622879c7 100644 --- a/crates/vector_store/src/db.rs +++ b/crates/vector_store/src/db.rs @@ -1,8 +1,6 @@ use anyhow::Result; use async_compat::{Compat, CompatExt}; -use conv::ValueFrom; -use sqlx::{migrate::MigrateDatabase, Pool, Sqlite, SqlitePool}; -use std::time::{Duration, Instant}; +use sqlx::{migrate::MigrateDatabase, Sqlite, SqlitePool}; use crate::IndexedFile; diff --git a/crates/vector_store/src/embedding.rs b/crates/vector_store/src/embedding.rs new file mode 100644 index 0000000000000000000000000000000000000000..f1ae5479ee23ebe50a4233e5fb3cf3bb8248f55e --- /dev/null +++ b/crates/vector_store/src/embedding.rs @@ -0,0 +1,100 @@ +use anyhow::{anyhow, Result}; +use async_trait::async_trait; +use futures::AsyncReadExt; +use gpui::serde_json; +use 
isahc::prelude::Configurable; +use lazy_static::lazy_static; +use serde::{Deserialize, Serialize}; +use std::env; +use std::sync::Arc; +use util::http::{HttpClient, Request}; + +lazy_static! { + static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok(); +} + +pub struct OpenAIEmbeddings { + pub client: Arc<dyn HttpClient>, +} + +#[derive(Serialize)] +struct OpenAIEmbeddingRequest<'a> { + model: &'static str, + input: Vec<&'a str>, +} + +#[derive(Deserialize)] +struct OpenAIEmbeddingResponse { + data: Vec<OpenAIEmbedding>, + usage: OpenAIEmbeddingUsage, +} + +#[derive(Debug, Deserialize)] +struct OpenAIEmbedding { + embedding: Vec<f32>, + index: usize, + object: String, +} + +#[derive(Deserialize)] +struct OpenAIEmbeddingUsage { + prompt_tokens: usize, + total_tokens: usize, +} + +#[async_trait] +pub trait EmbeddingProvider: Sync { + async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>; +} + +#[async_trait] +impl EmbeddingProvider for OpenAIEmbeddings { + async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> { + let api_key = OPENAI_API_KEY + .as_ref() + .ok_or_else(|| anyhow!("no api key"))?; + + let request = Request::post("https://api.openai.com/v1/embeddings") + .redirect_policy(isahc::config::RedirectPolicy::Follow) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key)) + .body( + serde_json::to_string(&OpenAIEmbeddingRequest { + input: spans, + model: "text-embedding-ada-002", + }) + .unwrap() + .into(), + )?; + + let mut response = self.client.send(request).await?; + if !response.status().is_success() { + return Err(anyhow!("openai embedding failed {}", response.status())); + } + + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?; + + log::info!( + "openai embedding completed. tokens: {:?}", + response.usage.total_tokens + ); + + // do we need to re-order these based on the `index` field? 
+ eprintln!( + "indices: {:?}", + response + .data + .iter() + .map(|embedding| embedding.index) + .collect::<Vec<_>>() + ); + + Ok(response + .data + .into_iter() + .map(|embedding| embedding.embedding) + .collect()) + } +} diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index 93f9fbe06dd0e79768739d40090a071c85928563..f4d5baca8095c4076212ab2d9624881db651ef69 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -1,17 +1,25 @@ mod db; -use anyhow::Result; +mod embedding; + +use anyhow::{anyhow, Result}; use db::VectorDatabase; +use embedding::{EmbeddingProvider, OpenAIEmbeddings}; use gpui::{AppContext, Entity, ModelContext, ModelHandle}; use language::LanguageRegistry; use project::{Fs, Project}; -use rand::Rng; use smol::channel; use std::{path::PathBuf, sync::Arc, time::Instant}; -use util::ResultExt; +use tree_sitter::{Parser, QueryCursor}; +use util::{http::HttpClient, ResultExt}; use workspace::WorkspaceCreated; -pub fn init(fs: Arc<dyn Fs>, language_registry: Arc<LanguageRegistry>, cx: &mut AppContext) { - let vector_store = cx.add_model(|cx| VectorStore::new(fs, language_registry)); +pub fn init( + fs: Arc<dyn Fs>, + http_client: Arc<dyn HttpClient>, + language_registry: Arc<LanguageRegistry>, + cx: &mut AppContext, +) { + let vector_store = cx.add_model(|cx| VectorStore::new(fs, http_client, language_registry)); cx.subscribe_global::<WorkspaceCreated, _>({ let vector_store = vector_store.clone(); @@ -53,38 +61,86 @@ struct SearchResult { struct VectorStore { fs: Arc<dyn Fs>, + http_client: Arc<dyn HttpClient>, language_registry: Arc<LanguageRegistry>, } impl VectorStore { - fn new(fs: Arc<dyn Fs>, language_registry: Arc<LanguageRegistry>) -> Self { + fn new( + fs: Arc<dyn Fs>, + http_client: Arc<dyn HttpClient>, + language_registry: Arc<LanguageRegistry>, + ) -> Self { Self { fs, + http_client, language_registry, } } async fn index_file( + cursor: &mut QueryCursor, + parser: &mut Parser, + embedding_provider: &dyn EmbeddingProvider, fs: &Arc<dyn Fs>, language_registry: &Arc<LanguageRegistry>, file_path: PathBuf, ) -> Result<IndexedFile> { - // This is creating dummy documents to test the database 
writes. - let mut documents = vec![]; - let mut rng = rand::thread_rng(); - let rand_num_of_documents: u8 = rng.gen_range(0..200); - for _ in 0..rand_num_of_documents { - let doc = Document { - offset: 0, - name: "test symbol".to_string(), - embedding: vec![0.32 as f32; 768], - }; - documents.push(doc); + let language = language_registry + .language_for_file(&file_path, None) + .await?; + + if language.name().as_ref() != "Rust" { + Err(anyhow!("unsupported language"))?; + } + + let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?; + let outline_config = grammar + .outline_config + .as_ref() + .ok_or_else(|| anyhow!("no outline query"))?; + + let content = fs.load(&file_path).await?; + parser.set_language(grammar.ts_language).unwrap(); + let tree = parser + .parse(&content, None) + .ok_or_else(|| anyhow!("parsing failed"))?; + + let mut documents = Vec::new(); + let mut context_spans = Vec::new(); + for mat in cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes()) { + let mut item_range = None; + let mut name_range = None; + for capture in mat.captures { + if capture.index == outline_config.item_capture_ix { + item_range = Some(capture.node.byte_range()); + } else if capture.index == outline_config.name_capture_ix { + name_range = Some(capture.node.byte_range()); + } + } + + if let Some((item_range, name_range)) = item_range.zip(name_range) { + if let Some((item, name)) = + content.get(item_range.clone()).zip(content.get(name_range)) + { + context_spans.push(item); + documents.push(Document { + name: name.to_string(), + offset: item_range.start, + embedding: Vec::new(), + }); + } + } + } + + let embeddings = embedding_provider.embed_batch(context_spans).await?; + for (document, embedding) in documents.iter_mut().zip(embeddings) { + document.embedding = embedding; } return Ok(IndexedFile { path: file_path, - sha1: "asdfasdfasdf".to_string(), + sha1: String::new(), documents, }); } @@ -98,8 +154,9 @@ impl VectorStore { let fs = 
self.fs.clone(); let language_registry = self.language_registry.clone(); + let client = self.http_client.clone(); - cx.spawn(|this, cx| async move { + cx.spawn(|_, cx| async move { futures::future::join_all(worktree_scans_complete).await; let worktrees = project.read_with(&cx, |project, cx| { @@ -131,15 +188,27 @@ impl VectorStore { }) .detach(); + let provider = OpenAIEmbeddings { client }; + + let t0 = Instant::now(); + cx.background() .scoped(|scope| { for _ in 0..cx.background().num_cpus() { scope.spawn(async { + let mut parser = Parser::new(); + let mut cursor = QueryCursor::new(); while let Ok(file_path) = paths_rx.recv().await { - if let Some(indexed_file) = - Self::index_file(&fs, &language_registry, file_path) - .await - .log_err() + if let Some(indexed_file) = Self::index_file( + &mut cursor, + &mut parser, + &provider, + &fs, + &language_registry, + file_path, + ) + .await + .log_err() { indexed_files_tx.try_send(indexed_file).unwrap(); } @@ -148,6 +217,9 @@ impl VectorStore { } }) .await; + + let duration = t0.elapsed(); + log::info!("indexed project in {duration:?}"); }) .detach(); } diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 76d02307f6b1e6e355ca65f201f4927f77221ae7..8a59bbde41af19e41ba31e65bed3350343be6899 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -152,7 +152,7 @@ fn main() { project_panel::init(cx); diagnostics::init(cx); search::init(cx); - vector_store::init(fs.clone(), languages.clone(), cx); + vector_store::init(fs.clone(), http.clone(), languages.clone(), cx); vim::init(cx); terminal_view::init(cx); theme_testbench::init(cx);