Detailed changes
@@ -1389,15 +1389,6 @@ dependencies = [
"theme",
]
-[[package]]
-name = "conv"
-version = "0.3.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "78ff10625fd0ac447827aa30ea8b861fead473bb60aeb73af6c1c58caf0d1299"
-dependencies = [
- "custom_derive",
-]
-
[[package]]
name = "copilot"
version = "0.1.0"
@@ -1775,12 +1766,6 @@ dependencies = [
"winapi 0.3.9",
]
-[[package]]
-name = "custom_derive"
-version = "0.1.7"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ef8ae57c4978a2acd8b869ce6b9ca1dfe817bff704c220209fdef2c0b75a01b9"
-
[[package]]
name = "cxx"
version = "1.0.94"
@@ -2219,6 +2204,12 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7"
+[[package]]
+name = "fallible-streaming-iterator"
+version = "0.1.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a"
+
[[package]]
name = "fancy-regex"
version = "0.11.0"
@@ -2909,6 +2900,15 @@ dependencies = [
"ahash 0.8.3",
]
+[[package]]
+name = "hashlink"
+version = "0.7.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7249a3129cbc1ffccd74857f81464a323a152173cdb134e0fd81bc803b29facf"
+dependencies = [
+ "hashbrown 0.11.2",
+]
+
[[package]]
name = "hashlink"
version = "0.8.1"
@@ -5600,6 +5600,21 @@ dependencies = [
"zeroize",
]
+[[package]]
+name = "rusqlite"
+version = "0.27.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "85127183a999f7db96d1a976a309eebbfb6ea3b0b400ddd8340190129de6eb7a"
+dependencies = [
+ "bitflags",
+ "fallible-iterator",
+ "fallible-streaming-iterator",
+ "hashlink 0.7.0",
+ "libsqlite3-sys",
+ "memchr",
+ "smallvec",
+]
+
[[package]]
name = "rust-embed"
version = "6.6.1"
@@ -6531,7 +6546,7 @@ dependencies = [
"futures-executor",
"futures-intrusive",
"futures-util",
- "hashlink",
+ "hashlink 0.8.1",
"hex",
"hkdf",
"hmac 0.12.1",
@@ -7898,14 +7913,20 @@ version = "0.1.0"
dependencies = [
"anyhow",
"async-compat",
- "conv",
+ "async-trait",
"futures 0.3.28",
"gpui",
+ "isahc",
"language",
+ "lazy_static",
+ "log",
"project",
- "rand 0.8.5",
+ "rusqlite",
+ "serde",
+ "serde_json",
"smol",
"sqlx",
+ "tree-sitter",
"util",
"workspace",
]
@@ -476,12 +476,12 @@ pub struct Language {
pub struct Grammar {
id: usize,
- pub(crate) ts_language: tree_sitter::Language,
+ pub ts_language: tree_sitter::Language,
pub(crate) error_query: Query,
pub(crate) highlights_query: Option<Query>,
pub(crate) brackets_config: Option<BracketConfig>,
pub(crate) indents_config: Option<IndentConfig>,
- pub(crate) outline_config: Option<OutlineConfig>,
+ pub outline_config: Option<OutlineConfig>,
pub(crate) injection_config: Option<InjectionConfig>,
pub(crate) override_config: Option<OverrideConfig>,
pub(crate) highlight_map: Mutex<HighlightMap>,
@@ -495,12 +495,12 @@ struct IndentConfig {
outdent_capture_ix: Option<u32>,
}
-struct OutlineConfig {
- query: Query,
- item_capture_ix: u32,
- name_capture_ix: u32,
- context_capture_ix: Option<u32>,
- extra_context_capture_ix: Option<u32>,
+pub struct OutlineConfig {
+ pub query: Query,
+ pub item_capture_ix: u32,
+ pub name_capture_ix: u32,
+ pub context_capture_ix: Option<u32>,
+ pub extra_context_capture_ix: Option<u32>,
}
struct InjectionConfig {
@@ -19,8 +19,14 @@ futures.workspace = true
smol.workspace = true
sqlx = { version = "0.6", features = ["sqlite","runtime-tokio-rustls"] }
async-compat = "0.2.1"
-conv = "0.3.3"
-rand.workspace = true
+rusqlite = "0.27.0"
+isahc.workspace = true
+log.workspace = true
+tree-sitter.workspace = true
+lazy_static.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+async-trait.workspace = true
[dev-dependencies]
gpui = { path = "../gpui", features = ["test-support"] }
@@ -1,8 +1,6 @@
use anyhow::Result;
use async_compat::{Compat, CompatExt};
-use conv::ValueFrom;
-use sqlx::{migrate::MigrateDatabase, Pool, Sqlite, SqlitePool};
-use std::time::{Duration, Instant};
+use sqlx::{migrate::MigrateDatabase, Sqlite, SqlitePool};
use crate::IndexedFile;
@@ -0,0 +1,100 @@
+use anyhow::{anyhow, Result};
+use async_trait::async_trait;
+use futures::AsyncReadExt;
+use gpui::serde_json;
+use isahc::prelude::Configurable;
+use lazy_static::lazy_static;
+use serde::{Deserialize, Serialize};
+use std::env;
+use std::sync::Arc;
+use util::http::{HttpClient, Request};
+
+lazy_static! {
+ static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
+}
+
+pub struct OpenAIEmbeddings {
+ pub client: Arc<dyn HttpClient>,
+}
+
+#[derive(Serialize)]
+struct OpenAIEmbeddingRequest<'a> {
+ model: &'static str,
+ input: Vec<&'a str>,
+}
+
+#[derive(Deserialize)]
+struct OpenAIEmbeddingResponse {
+ data: Vec<OpenAIEmbedding>,
+ usage: OpenAIEmbeddingUsage,
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenAIEmbedding {
+ embedding: Vec<f32>,
+ index: usize,
+ object: String,
+}
+
+#[derive(Deserialize)]
+struct OpenAIEmbeddingUsage {
+ prompt_tokens: usize,
+ total_tokens: usize,
+}
+
+#[async_trait]
+pub trait EmbeddingProvider: Sync {
+ async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>>;
+}
+
+#[async_trait]
+impl EmbeddingProvider for OpenAIEmbeddings {
+ async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
+ let api_key = OPENAI_API_KEY
+ .as_ref()
+ .ok_or_else(|| anyhow!("no api key"))?;
+
+ let request = Request::post("https://api.openai.com/v1/embeddings")
+ .redirect_policy(isahc::config::RedirectPolicy::Follow)
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {}", api_key))
+ .body(
+ serde_json::to_string(&OpenAIEmbeddingRequest {
+ input: spans,
+ model: "text-embedding-ada-002",
+ })
+ .unwrap()
+ .into(),
+ )?;
+
+ let mut response = self.client.send(request).await?;
+ if !response.status().is_success() {
+ return Err(anyhow!("openai embedding failed {}", response.status()));
+ }
+
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
+
+ log::info!(
+ "openai embedding completed. tokens: {:?}",
+ response.usage.total_tokens
+ );
+
+ // do we need to re-order these based on the `index` field?
+ eprintln!(
+ "indices: {:?}",
+ response
+ .data
+ .iter()
+ .map(|embedding| embedding.index)
+ .collect::<Vec<_>>()
+ );
+
+ Ok(response
+ .data
+ .into_iter()
+ .map(|embedding| embedding.embedding)
+ .collect())
+ }
+}
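
(Not part of the diff above — an illustrative sketch.) The new embedding module defines an async `EmbeddingProvider` trait with a single `embed_batch` method, and `index_file` later in this diff consumes it through `&dyn EmbeddingProvider`, so any other backend can be swapped in. A minimal sketch of an alternate implementation, assuming it lives inside the vector_store crate where the private `embedding` module is visible; the `FakeEmbeddings` name and its `dimension` field are hypothetical, not part of the PR:

use anyhow::Result;
use async_trait::async_trait;

use crate::embedding::EmbeddingProvider; // assumes this sits inside the vector_store crate

// Hypothetical no-network provider, e.g. for exercising `index_file` in tests.
// Returns zero vectors instead of calling the OpenAI endpoint.
struct FakeEmbeddings {
    // Arbitrary embedding width; the real width is whatever the model returns.
    dimension: usize,
}

#[async_trait]
impl EmbeddingProvider for FakeEmbeddings {
    async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
        // One vector per input span, preserving order, so the zip in
        // `index_file` lines up each document with its embedding.
        Ok(spans.iter().map(|_| vec![0.0; self.dimension]).collect())
    }
}

Keeping the trait object-safe and batch-oriented like this matches how the real `OpenAIEmbeddings` is used: `index_file` collects all item spans for a file first and makes a single `embed_batch` call, rather than one request per symbol.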
@@ -1,17 +1,25 @@
mod db;
-use anyhow::Result;
+mod embedding;
+
+use anyhow::{anyhow, Result};
use db::VectorDatabase;
+use embedding::{EmbeddingProvider, OpenAIEmbeddings};
use gpui::{AppContext, Entity, ModelContext, ModelHandle};
use language::LanguageRegistry;
use project::{Fs, Project};
-use rand::Rng;
use smol::channel;
use std::{path::PathBuf, sync::Arc, time::Instant};
-use util::ResultExt;
+use tree_sitter::{Parser, QueryCursor};
+use util::{http::HttpClient, ResultExt};
use workspace::WorkspaceCreated;
-pub fn init(fs: Arc<dyn Fs>, language_registry: Arc<LanguageRegistry>, cx: &mut AppContext) {
- let vector_store = cx.add_model(|cx| VectorStore::new(fs, language_registry));
+pub fn init(
+ fs: Arc<dyn Fs>,
+ http_client: Arc<dyn HttpClient>,
+ language_registry: Arc<LanguageRegistry>,
+ cx: &mut AppContext,
+) {
+ let vector_store = cx.add_model(|cx| VectorStore::new(fs, http_client, language_registry));
cx.subscribe_global::<WorkspaceCreated, _>({
let vector_store = vector_store.clone();
@@ -53,38 +61,86 @@ struct SearchResult {
struct VectorStore {
fs: Arc<dyn Fs>,
+ http_client: Arc<dyn HttpClient>,
language_registry: Arc<LanguageRegistry>,
}
impl VectorStore {
- fn new(fs: Arc<dyn Fs>, language_registry: Arc<LanguageRegistry>) -> Self {
+ fn new(
+ fs: Arc<dyn Fs>,
+ http_client: Arc<dyn HttpClient>,
+ language_registry: Arc<LanguageRegistry>,
+ ) -> Self {
Self {
fs,
+ http_client,
language_registry,
}
}
async fn index_file(
+ cursor: &mut QueryCursor,
+ parser: &mut Parser,
+ embedding_provider: &dyn EmbeddingProvider,
fs: &Arc<dyn Fs>,
language_registry: &Arc<LanguageRegistry>,
file_path: PathBuf,
) -> Result<IndexedFile> {
- // This is creating dummy documents to test the database writes.
- let mut documents = vec![];
- let mut rng = rand::thread_rng();
- let rand_num_of_documents: u8 = rng.gen_range(0..200);
- for _ in 0..rand_num_of_documents {
- let doc = Document {
- offset: 0,
- name: "test symbol".to_string(),
- embedding: vec![0.32 as f32; 768],
- };
- documents.push(doc);
+ let language = language_registry
+ .language_for_file(&file_path, None)
+ .await?;
+
+ if language.name().as_ref() != "Rust" {
+ Err(anyhow!("unsupported language"))?;
+ }
+
+ let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
+ let outline_config = grammar
+ .outline_config
+ .as_ref()
+ .ok_or_else(|| anyhow!("no outline query"))?;
+
+ let content = fs.load(&file_path).await?;
+ parser.set_language(grammar.ts_language).unwrap();
+ let tree = parser
+ .parse(&content, None)
+ .ok_or_else(|| anyhow!("parsing failed"))?;
+
+ let mut documents = Vec::new();
+ let mut context_spans = Vec::new();
+ for mat in cursor.matches(&outline_config.query, tree.root_node(), content.as_bytes()) {
+ let mut item_range = None;
+ let mut name_range = None;
+ for capture in mat.captures {
+ if capture.index == outline_config.item_capture_ix {
+ item_range = Some(capture.node.byte_range());
+ } else if capture.index == outline_config.name_capture_ix {
+ name_range = Some(capture.node.byte_range());
+ }
+ }
+
+ if let Some((item_range, name_range)) = item_range.zip(name_range) {
+ if let Some((item, name)) =
+ content.get(item_range.clone()).zip(content.get(name_range))
+ {
+ context_spans.push(item);
+ documents.push(Document {
+ name: name.to_string(),
+ offset: item_range.start,
+ embedding: Vec::new(),
+ });
+ }
+ }
+ }
+
+ let embeddings = embedding_provider.embed_batch(context_spans).await?;
+ for (document, embedding) in documents.iter_mut().zip(embeddings) {
+ document.embedding = embedding;
}
return Ok(IndexedFile {
path: file_path,
- sha1: "asdfasdfasdf".to_string(),
+ sha1: String::new(),
documents,
});
}
@@ -98,8 +154,9 @@ impl VectorStore {
let fs = self.fs.clone();
let language_registry = self.language_registry.clone();
+ let client = self.http_client.clone();
- cx.spawn(|this, cx| async move {
+ cx.spawn(|_, cx| async move {
futures::future::join_all(worktree_scans_complete).await;
let worktrees = project.read_with(&cx, |project, cx| {
@@ -131,15 +188,27 @@ impl VectorStore {
})
.detach();
+ let provider = OpenAIEmbeddings { client };
+
+ let t0 = Instant::now();
+
cx.background()
.scoped(|scope| {
for _ in 0..cx.background().num_cpus() {
scope.spawn(async {
+ let mut parser = Parser::new();
+ let mut cursor = QueryCursor::new();
while let Ok(file_path) = paths_rx.recv().await {
- if let Some(indexed_file) =
- Self::index_file(&fs, &language_registry, file_path)
- .await
- .log_err()
+ if let Some(indexed_file) = Self::index_file(
+ &mut cursor,
+ &mut parser,
+ &provider,
+ &fs,
+ &language_registry,
+ file_path,
+ )
+ .await
+ .log_err()
{
indexed_files_tx.try_send(indexed_file).unwrap();
}
@@ -148,6 +217,9 @@ impl VectorStore {
}
})
.await;
+
+ let duration = t0.elapsed();
+ log::info!("indexed project in {duration:?}");
})
.detach();
}
@@ -152,7 +152,7 @@ fn main() {
project_panel::init(cx);
diagnostics::init(cx);
search::init(cx);
- vector_store::init(fs.clone(), languages.clone(), cx);
+ vector_store::init(fs.clone(), http.clone(), languages.clone(), cx);
vim::init(cx);
terminal_view::init(cx);
theme_testbench::init(cx);