Detailed changes
@@ -3265,6 +3265,15 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ea835d29036a4087793836fa931b08837ad5e957da9e23886b29586fb9b6650"
+[[package]]
+name = "doxygen-rs"
+version = "0.4.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "415b6ec780d34dcf624666747194393603d0373b7141eef01d12ee58881507d9"
+dependencies = [
+ "phf",
+]
+
[[package]]
name = "dwrote"
version = "0.11.0"
@@ -4085,6 +4094,17 @@ dependencies = [
"futures-util",
]
+[[package]]
+name = "futures-batch"
+version = "0.6.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6f444c45a1cb86f2a7e301469fd50a82084a60dadc25d94529a8312276ecb71a"
+dependencies = [
+ "futures 0.3.28",
+ "futures-timer",
+ "pin-utils",
+]
+
[[package]]
name = "futures-channel"
version = "0.3.30"
@@ -4180,6 +4200,12 @@ version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004"
+[[package]]
+name = "futures-timer"
+version = "3.0.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24"
+
[[package]]
name = "futures-util"
version = "0.3.30"
@@ -4659,6 +4685,41 @@ dependencies = [
"unicode-segmentation",
]
+[[package]]
+name = "heed"
+version = "0.20.0-alpha.9"
+source = "git+https://github.com/meilisearch/heed?rev=036ac23f73a021894974b9adc815bc95b3e0482a#036ac23f73a021894974b9adc815bc95b3e0482a"
+dependencies = [
+ "bitflags 2.4.2",
+ "byteorder",
+ "heed-traits",
+ "heed-types",
+ "libc",
+ "lmdb-master-sys",
+ "once_cell",
+ "page_size",
+ "serde",
+ "synchronoise",
+ "url",
+]
+
+[[package]]
+name = "heed-traits"
+version = "0.20.0-alpha.9"
+source = "git+https://github.com/meilisearch/heed?rev=036ac23f73a021894974b9adc815bc95b3e0482a#036ac23f73a021894974b9adc815bc95b3e0482a"
+
+[[package]]
+name = "heed-types"
+version = "0.20.0-alpha.9"
+source = "git+https://github.com/meilisearch/heed?rev=036ac23f73a021894974b9adc815bc95b3e0482a#036ac23f73a021894974b9adc815bc95b3e0482a"
+dependencies = [
+ "bincode",
+ "byteorder",
+ "heed-traits",
+ "serde",
+ "serde_json",
+]
+
[[package]]
name = "hermit-abi"
version = "0.1.19"
@@ -5664,6 +5725,16 @@ dependencies = [
"sha2 0.10.7",
]
+[[package]]
+name = "lmdb-master-sys"
+version = "0.1.0"
+source = "git+https://github.com/meilisearch/heed?rev=036ac23f73a021894974b9adc815bc95b3e0482a#036ac23f73a021894974b9adc815bc95b3e0482a"
+dependencies = [
+ "cc",
+ "doxygen-rs",
+ "libc",
+]
+
[[package]]
name = "lock_api"
version = "0.4.10"
@@ -6683,6 +6754,16 @@ dependencies = [
"sha2 0.10.7",
]
+[[package]]
+name = "page_size"
+version = "0.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "30d5b2194ed13191c1999ae0704b7839fb18384fa22e49b57eeaa97d79ce40da"
+dependencies = [
+ "libc",
+ "winapi",
+]
+
[[package]]
name = "palette"
version = "0.7.5"
@@ -6856,7 +6937,31 @@ version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc"
dependencies = [
+ "phf_macros",
+ "phf_shared",
+]
+
+[[package]]
+name = "phf_generator"
+version = "0.11.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "48e4cc64c2ad9ebe670cb8fd69dd50ae301650392e81c05f9bfcb2d5bdbc24b0"
+dependencies = [
+ "phf_shared",
+ "rand 0.8.5",
+]
+
+[[package]]
+name = "phf_macros"
+version = "0.11.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3444646e286606587e49f3bcf1679b8cef1dc2c5ecc29ddacaffc305180d464b"
+dependencies = [
+ "phf_generator",
"phf_shared",
+ "proc-macro2",
+ "quote",
+ "syn 2.0.48",
]
[[package]]
@@ -8473,6 +8578,35 @@ version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "58bf37232d3bb9a2c4e641ca2a11d83b5062066f88df7fed36c28772046d65ba"
+[[package]]
+name = "semantic_index"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "client",
+ "clock",
+ "collections",
+ "env_logger",
+ "fs",
+ "futures 0.3.28",
+ "futures-batch",
+ "gpui",
+ "heed",
+ "language",
+ "languages",
+ "log",
+ "open_ai",
+ "project",
+ "serde",
+ "serde_json",
+ "settings",
+ "sha2 0.10.7",
+ "smol",
+ "tempfile",
+ "util",
+ "worktree",
+]
+
[[package]]
name = "semantic_version"
version = "0.1.0"
@@ -9478,6 +9612,15 @@ version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160"
+[[package]]
+name = "synchronoise"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3dbc01390fc626ce8d1cffe3376ded2b72a11bb70e1c75f404a210e4daa4def2"
+dependencies = [
+ "crossbeam-queue",
+]
+
[[package]]
name = "sys-locale"
version = "0.3.1"
@@ -73,6 +73,7 @@ members = [
"crates/task",
"crates/tasks_ui",
"crates/search",
+ "crates/semantic_index",
"crates/semantic_version",
"crates/settings",
"crates/snippet",
@@ -253,9 +254,11 @@ derive_more = "0.99.17"
emojis = "0.6.1"
env_logger = "0.9"
futures = "0.3"
+futures-batch = "0.6.1"
futures-lite = "1.13"
git2 = { version = "0.15", default-features = false }
globset = "0.4"
+heed = { git = "https://github.com/meilisearch/heed", rev = "036ac23f73a021894974b9adc815bc95b3e0482a", features = ["read-txn-no-tls"] }
hex = "0.4.3"
ignore = "0.4.22"
indoc = "1"
@@ -264,7 +264,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
);
assert_eq!(
- channel.next_event(cx),
+ channel.next_event(cx).await,
ChannelChatEvent::MessagesUpdated {
old_range: 2..2,
new_count: 1,
@@ -317,7 +317,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
);
assert_eq!(
- channel.next_event(cx),
+ channel.next_event(cx).await,
ChannelChatEvent::MessagesUpdated {
old_range: 0..0,
new_count: 2,
@@ -0,0 +1,9 @@
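+-- Cache of embeddings computed for (model, text-digest) pairs. retrieved_at
+-- is refreshed whenever a row is read, so rarely used rows can be purged.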
+CREATE TABLE IF NOT EXISTS "embeddings" (
+ "model" TEXT,
+ "digest" BYTEA,
+ "dimensions" FLOAT4[1536],
+ "retrieved_at" TIMESTAMP NOT NULL DEFAULT now(),
+ PRIMARY KEY ("model", "digest")
+);
+
+CREATE INDEX IF NOT EXISTS "idx_retrieved_at_on_embeddings" ON "embeddings" ("retrieved_at");
@@ -6,6 +6,7 @@ pub mod channels;
pub mod contacts;
pub mod contributors;
pub mod dev_servers;
+pub mod embeddings;
pub mod extensions;
pub mod hosted_projects;
pub mod messages;
@@ -0,0 +1,94 @@
+use super::*;
+use time::Duration;
+use time::OffsetDateTime;
+
+impl Database {
+ pub async fn get_embeddings(
+ &self,
+ model: &str,
+ digests: &[Vec<u8>],
+ ) -> Result<HashMap<Vec<u8>, Vec<f32>>> {
+ self.weak_transaction(|tx| async move {
+ let embeddings = {
+ let mut db_embeddings = embedding::Entity::find()
+ .filter(
+ embedding::Column::Model.eq(model).and(
+ embedding::Column::Digest
+ .is_in(digests.iter().map(|digest| digest.as_slice())),
+ ),
+ )
+ .stream(&*tx)
+ .await?;
+
+ let mut embeddings = HashMap::default();
+ while let Some(db_embedding) = db_embeddings.next().await {
+ let db_embedding = db_embedding?;
+ embeddings.insert(db_embedding.digest, db_embedding.dimensions);
+ }
+ embeddings
+ };
+
+ if !embeddings.is_empty() {
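+                // Refresh retrieved_at on every row we return so recently
+                // used embeddings survive purge_old_embeddings.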
+ let now = OffsetDateTime::now_utc();
+ let retrieved_at = PrimitiveDateTime::new(now.date(), now.time());
+
+ embedding::Entity::update_many()
+ .filter(
+ embedding::Column::Digest
+ .is_in(embeddings.keys().map(|digest| digest.as_slice())),
+ )
+ .col_expr(embedding::Column::RetrievedAt, Expr::value(retrieved_at))
+ .exec(&*tx)
+ .await?;
+ }
+
+ Ok(embeddings)
+ })
+ .await
+ }
+
+ pub async fn save_embeddings(
+ &self,
+ model: &str,
+ embeddings: &HashMap<Vec<u8>, Vec<f32>>,
+ ) -> Result<()> {
+ self.weak_transaction(|tx| async move {
+ embedding::Entity::insert_many(embeddings.iter().map(|(digest, dimensions)| {
+ let now_offset_datetime = OffsetDateTime::now_utc();
+ let retrieved_at =
+ PrimitiveDateTime::new(now_offset_datetime.date(), now_offset_datetime.time());
+
+ embedding::ActiveModel {
+ model: ActiveValue::set(model.to_string()),
+ digest: ActiveValue::set(digest.clone()),
+ dimensions: ActiveValue::set(dimensions.clone()),
+ retrieved_at: ActiveValue::set(retrieved_at),
+ }
+ }))
+ .on_conflict(
+ OnConflict::columns([embedding::Column::Model, embedding::Column::Digest])
+ .do_nothing()
+ .to_owned(),
+ )
+ .exec_without_returning(&*tx)
+ .await?;
+ Ok(())
+ })
+ .await
+ }
+
+ pub async fn purge_old_embeddings(&self) -> Result<()> {
+ self.weak_transaction(|tx| async move {
+ embedding::Entity::delete_many()
+ .filter(
+ embedding::Column::RetrievedAt
+ .lte(OffsetDateTime::now_utc() - Duration::days(60)),
+ )
+ .exec(&*tx)
+ .await?;
+
+ Ok(())
+ })
+ .await
+ }
+}
@@ -11,6 +11,7 @@ pub mod channel_message_mention;
pub mod contact;
pub mod contributor;
pub mod dev_server;
+pub mod embedding;
pub mod extension;
pub mod extension_version;
pub mod feature_flag;
@@ -0,0 +1,18 @@
+use sea_orm::entity::prelude::*;
+use time::PrimitiveDateTime;
+
+#[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
+#[sea_orm(table_name = "embeddings")]
+pub struct Model {
+ #[sea_orm(primary_key)]
+ pub model: String,
+ #[sea_orm(primary_key)]
+ pub digest: Vec<u8>,
+ pub dimensions: Vec<f32>,
+ pub retrieved_at: PrimitiveDateTime,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {}
+
+impl ActiveModelBehavior for ActiveModel {}
@@ -2,6 +2,7 @@ mod buffer_tests;
mod channel_tests;
mod contributor_tests;
mod db_tests;
+mod embedding_tests;
mod extension_tests;
mod feature_flag_tests;
mod message_tests;
@@ -0,0 +1,84 @@
+use super::TestDb;
+use crate::db::embedding;
+use collections::HashMap;
+use sea_orm::{sea_query::Expr, ColumnTrait, EntityTrait, QueryFilter};
+use std::ops::Sub;
+use time::{Duration, OffsetDateTime, PrimitiveDateTime};
+
+// SQLite does not support array arguments, so we only test this against a real Postgres instance.
+#[gpui::test]
+async fn test_get_embeddings_postgres(cx: &mut gpui::TestAppContext) {
+ let test_db = TestDb::postgres(cx.executor().clone());
+ let db = test_db.db();
+
+ let provider = "test_model";
+ let digest1 = vec![1, 2, 3];
+ let digest2 = vec![4, 5, 6];
+ let embeddings = HashMap::from_iter([
+ (digest1.clone(), vec![0.1, 0.2, 0.3]),
+ (digest2.clone(), vec![0.4, 0.5, 0.6]),
+ ]);
+
+ // Save embeddings
+ db.save_embeddings(provider, &embeddings).await.unwrap();
+
+ // Retrieve embeddings
+ let retrieved_embeddings = db
+ .get_embeddings(provider, &[digest1.clone(), digest2.clone()])
+ .await
+ .unwrap();
+ assert_eq!(retrieved_embeddings.len(), 2);
+ assert!(retrieved_embeddings.contains_key(&digest1));
+ assert!(retrieved_embeddings.contains_key(&digest2));
+
+ // Check if the retrieved embeddings are correct
+ assert_eq!(retrieved_embeddings[&digest1], vec![0.1, 0.2, 0.3]);
+ assert_eq!(retrieved_embeddings[&digest2], vec![0.4, 0.5, 0.6]);
+}
+
+#[gpui::test]
+async fn test_purge_old_embeddings(cx: &mut gpui::TestAppContext) {
+ let test_db = TestDb::postgres(cx.executor().clone());
+ let db = test_db.db();
+
+ let model = "test_model";
+ let digest = vec![7, 8, 9];
+ let embeddings = HashMap::from_iter([(digest.clone(), vec![0.7, 0.8, 0.9])]);
+
+ // Save old embeddings
+ db.save_embeddings(model, &embeddings).await.unwrap();
+
+    // Reach into the DB and set retrieved_at to more than 60 days ago
+ db.weak_transaction(|tx| {
+ let digest = digest.clone();
+ async move {
+ let sixty_days_ago = OffsetDateTime::now_utc().sub(Duration::days(61));
+ let retrieved_at = PrimitiveDateTime::new(sixty_days_ago.date(), sixty_days_ago.time());
+
+ embedding::Entity::update_many()
+ .filter(
+ embedding::Column::Model
+ .eq(model)
+ .and(embedding::Column::Digest.eq(digest)),
+ )
+ .col_expr(embedding::Column::RetrievedAt, Expr::value(retrieved_at))
+ .exec(&*tx)
+ .await
+ .unwrap();
+
+ Ok(())
+ }
+ })
+ .await
+ .unwrap();
+
+ // Purge old embeddings
+ db.purge_old_embeddings().await.unwrap();
+
+ // Try to retrieve the purged embeddings
+ let retrieved_embeddings = db.get_embeddings(model, &[digest.clone()]).await.unwrap();
+ assert!(
+ retrieved_embeddings.is_empty(),
+ "Old embeddings should have been purged"
+ );
+}
@@ -6,8 +6,8 @@ use axum::{
Extension, Router,
};
use collab::{
- api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor, AppState,
- Config, RateLimiter, Result,
+ api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor,
+ rpc::ResultExt, AppState, Config, RateLimiter, Result,
};
use db::Database;
use std::{
@@ -23,7 +23,7 @@ use tower_http::trace::TraceLayer;
use tracing_subscriber::{
filter::EnvFilter, fmt::format::JsonFields, util::SubscriberInitExt, Layer,
};
-use util::ResultExt;
+use util::ResultExt as _;
const VERSION: &str = env!("CARGO_PKG_VERSION");
const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
@@ -90,6 +90,7 @@ async fn main() -> Result<()> {
};
if is_collab {
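+        // Drop cached embeddings that have not been retrieved in over 60 days.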
+ state.db.purge_old_embeddings().await.trace_err();
RateLimiter::save_periodically(state.rate_limiter.clone(), state.executor.clone());
}
@@ -32,6 +32,8 @@ use axum::{
use collections::{HashMap, HashSet};
pub use connection_pool::{ConnectionPool, ZedVersion};
use core::fmt::{self, Debug, Formatter};
+use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL};
+use sha2::Digest;
use futures::{
channel::oneshot,
@@ -568,6 +570,22 @@ impl Server {
app_state.config.google_ai_api_key.clone(),
)
})
+ })
+ .add_request_handler({
+ user_handler(move |request, response, session| {
+ get_cached_embeddings(request, response, session)
+ })
+ })
+ .add_request_handler({
+ let app_state = app_state.clone();
+ user_handler(move |request, response, session| {
+ compute_embeddings(
+ request,
+ response,
+ session,
+ app_state.config.openai_api_key.clone(),
+ )
+ })
});
Arc::new(server)
@@ -4021,8 +4039,6 @@ async fn complete_with_open_ai(
session: UserSession,
api_key: Arc<str>,
) -> Result<()> {
- const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
-
let mut completion_stream = open_ai::stream_completion(
&session.http_client,
OPEN_AI_API_URL,
@@ -4276,6 +4292,128 @@ async fn count_tokens_with_language_model(
Ok(())
}
+struct ComputeEmbeddingsRateLimit;
+
+impl RateLimit for ComputeEmbeddingsRateLimit {
+ fn capacity() -> usize {
+ std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
+ .ok()
+ .and_then(|v| v.parse().ok())
+ .unwrap_or(120) // Picked arbitrarily
+ }
+
+ fn refill_duration() -> chrono::Duration {
+ chrono::Duration::hours(1)
+ }
+
+ fn db_name() -> &'static str {
+ "compute-embeddings"
+ }
+}
+
+async fn compute_embeddings(
+ request: proto::ComputeEmbeddings,
+ response: Response<proto::ComputeEmbeddings>,
+ session: UserSession,
+ api_key: Option<Arc<str>>,
+) -> Result<()> {
+ let api_key = api_key.context("no OpenAI API key configured on the server")?;
+ authorize_access_to_language_models(&session).await?;
+
+ session
+ .rate_limiter
+ .check::<ComputeEmbeddingsRateLimit>(session.user_id())
+ .await?;
+
+ let embeddings = match request.model.as_str() {
+ "openai/text-embedding-3-small" => {
+ open_ai::embed(
+ &session.http_client,
+ OPEN_AI_API_URL,
+ &api_key,
+ OpenAiEmbeddingModel::TextEmbedding3Small,
+ request.texts.iter().map(|text| text.as_str()),
+ )
+ .await?
+ }
+ provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
+ };
+
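+    // Key each computed embedding by the SHA-256 digest of its source text;
+    // clients use the same digests to look up cached embeddings later.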
+ let embeddings = request
+ .texts
+ .iter()
+ .map(|text| {
+ let mut hasher = sha2::Sha256::new();
+ hasher.update(text.as_bytes());
+ let result = hasher.finalize();
+ result.to_vec()
+ })
+ .zip(
+ embeddings
+ .data
+ .into_iter()
+ .map(|embedding| embedding.embedding),
+ )
+ .collect::<HashMap<_, _>>();
+
+ let db = session.db().await;
+ db.save_embeddings(&request.model, &embeddings)
+ .await
+ .context("failed to save embeddings")
+ .trace_err();
+
+ response.send(proto::ComputeEmbeddingsResponse {
+ embeddings: embeddings
+ .into_iter()
+ .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
+ .collect(),
+ })?;
+ Ok(())
+}
+
+struct GetCachedEmbeddingsRateLimit;
+
+impl RateLimit for GetCachedEmbeddingsRateLimit {
+ fn capacity() -> usize {
+ std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
+ .ok()
+ .and_then(|v| v.parse().ok())
+ .unwrap_or(120) // Picked arbitrarily
+ }
+
+ fn refill_duration() -> chrono::Duration {
+ chrono::Duration::hours(1)
+ }
+
+ fn db_name() -> &'static str {
+ "get-cached-embeddings"
+ }
+}
+
+async fn get_cached_embeddings(
+ request: proto::GetCachedEmbeddings,
+ response: Response<proto::GetCachedEmbeddings>,
+ session: UserSession,
+) -> Result<()> {
+ authorize_access_to_language_models(&session).await?;
+
+ session
+ .rate_limiter
+ .check::<GetCachedEmbeddingsRateLimit>(session.user_id())
+ .await?;
+
+ let db = session.db().await;
+ let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
+
+ response.send(proto::GetCachedEmbeddingsResponse {
+ embeddings: embeddings
+ .into_iter()
+ .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
+ .collect(),
+ })?;
+ Ok(())
+}
+
async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
let db = session.db().await;
let flags = db.get_user_flags(session.user_id()).await?;
@@ -396,7 +396,7 @@ mod tests {
let blame = cx.new_model(|cx| GitBlame::new(buffer.clone(), project.clone(), cx));
- let event = project.next_event(cx);
+ let event = project.next_event(cx).await;
assert_eq!(
event,
project::Event::Notification(
@@ -7,7 +7,7 @@ use crate::{
TextSystem, View, ViewContext, VisualContext, WindowContext, WindowHandle, WindowOptions,
};
use anyhow::{anyhow, bail};
-use futures::{Stream, StreamExt};
+use futures::{channel::oneshot, Stream, StreamExt};
use std::{cell::RefCell, future::Future, ops::Deref, rc::Rc, sync::Arc, time::Duration};
/// A TestAppContext is provided to tests created with `#[gpui::test]`, it provides
@@ -479,31 +479,26 @@ impl TestAppContext {
impl<T: 'static> Model<T> {
-    /// Block until the next event is emitted by the model, then return it.
+    /// Returns a future that resolves with the next event emitted by the model.
- pub fn next_event<Evt>(&self, cx: &mut TestAppContext) -> Evt
+ pub fn next_event<Event>(&self, cx: &mut TestAppContext) -> impl Future<Output = Event>
where
- Evt: Send + Clone + 'static,
- T: EventEmitter<Evt>,
+ Event: Send + Clone + 'static,
+ T: EventEmitter<Event>,
{
- let (tx, mut rx) = futures::channel::mpsc::unbounded();
- let _subscription = self.update(cx, |_, cx| {
+ let (tx, mut rx) = oneshot::channel();
+ let mut tx = Some(tx);
+ let subscription = self.update(cx, |_, cx| {
cx.subscribe(self, move |_, _, event, _| {
- tx.unbounded_send(event.clone()).ok();
+ if let Some(tx) = tx.take() {
+ _ = tx.send(event.clone());
+ }
})
});
- // Run other tasks until the event is emitted.
- loop {
- match rx.try_next() {
- Ok(Some(event)) => return event,
- Ok(None) => panic!("model was dropped"),
- Err(_) => {
- if !cx.executor().tick() {
- break;
- }
- }
- }
+ async move {
+ let event = rx.await.expect("no event emitted");
+ drop(subscription);
+ event
}
- panic!("no event received")
}
/// Returns a future that resolves when the model notifies.
@@ -372,7 +372,7 @@ impl BackgroundExecutor {
self.dispatcher.as_test().unwrap().rng()
}
- /// How many CPUs are available to the dispatcher
+ /// How many CPUs are available to the dispatcher.
pub fn num_cpus(&self) -> usize {
num_cpus::get()
}
@@ -440,6 +440,11 @@ impl<'a> Scope<'a> {
}
}
+ /// How many CPUs are available to the dispatcher.
+ pub fn num_cpus(&self) -> usize {
+ self.executor.num_cpus()
+ }
+
/// Spawn a future into this scope.
pub fn spawn<F>(&mut self, f: F)
where
@@ -72,7 +72,7 @@ pub use lsp::LanguageServerId;
pub use outline::{Outline, OutlineItem};
pub use syntax_map::{OwnedSyntaxLayer, SyntaxLayer};
pub use text::LineEnding;
-pub use tree_sitter::{Parser, Tree};
+pub use tree_sitter::{Node, Parser, Tree, TreeCursor};
use crate::language_settings::SoftWrap;
@@ -91,6 +91,16 @@ thread_local! {
};
}
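+/// Runs `func` with this thread's shared tree-sitter parser, reusing the
+/// thread-local instance instead of constructing a new `Parser` on every call.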
+pub fn with_parser<F, R>(func: F) -> R
+where
+ F: FnOnce(&mut Parser) -> R,
+{
+ PARSER.with(|parser| {
+ let mut parser = parser.borrow_mut();
+ func(&mut parser)
+ })
+}
+
lazy_static! {
static ref NEXT_LANGUAGE_ID: AtomicUsize = Default::default();
static ref NEXT_GRAMMAR_ID: AtomicUsize = Default::default();
@@ -1,9 +1,11 @@
-use anyhow::{anyhow, Result};
+use anyhow::{anyhow, Context, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
use serde::{Deserialize, Serialize};
-use std::convert::TryFrom;
+use std::{convert::TryFrom, future::Future};
use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
+pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
+
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
@@ -188,3 +190,68 @@ pub async fn stream_completion(
}
}
}
+
+#[derive(Copy, Clone, Serialize, Deserialize)]
+pub enum OpenAiEmbeddingModel {
+ #[serde(rename = "text-embedding-3-small")]
+ TextEmbedding3Small,
+ #[serde(rename = "text-embedding-3-large")]
+ TextEmbedding3Large,
+}
+
+#[derive(Serialize)]
+struct OpenAiEmbeddingRequest<'a> {
+ model: OpenAiEmbeddingModel,
+ input: Vec<&'a str>,
+}
+
+#[derive(Deserialize)]
+pub struct OpenAiEmbeddingResponse {
+ pub data: Vec<OpenAiEmbedding>,
+}
+
+#[derive(Deserialize)]
+pub struct OpenAiEmbedding {
+ pub embedding: Vec<f32>,
+}
+
+pub fn embed<'a>(
+ client: &dyn HttpClient,
+ api_url: &str,
+ api_key: &str,
+ model: OpenAiEmbeddingModel,
+ texts: impl IntoIterator<Item = &'a str>,
+) -> impl 'static + Future<Output = Result<OpenAiEmbeddingResponse>> {
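+    // Build and send the request eagerly so the returned future is 'static
+    // and does not borrow `client`, `api_url`, or `api_key`.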
+ let uri = format!("{api_url}/embeddings");
+
+ let request = OpenAiEmbeddingRequest {
+ model,
+ input: texts.into_iter().collect(),
+ };
+ let body = AsyncBody::from(serde_json::to_string(&request).unwrap());
+ let request = HttpRequest::builder()
+ .method(Method::POST)
+ .uri(uri)
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {}", api_key))
+ .body(body)
+ .map(|request| client.send(request));
+
+ async move {
+ let mut response = request?.await?;
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+
+ if response.status().is_success() {
+ let response: OpenAiEmbeddingResponse =
+ serde_json::from_str(&body).context("failed to parse OpenAI embedding response")?;
+ Ok(response)
+ } else {
+ Err(anyhow!(
+ "error during embedding, status: {:?}, body: {:?}",
+ response.status(),
+ body
+ ))
+ }
+ }
+}
@@ -978,6 +978,50 @@ impl Project {
}
}
+ #[cfg(any(test, feature = "test-support"))]
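+    /// Builds a local project over the real file system, adding a worktree for
+    /// each root path and waiting for its initial scan to complete.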
+ pub async fn example(
+ root_paths: impl IntoIterator<Item = &Path>,
+ cx: &mut AsyncAppContext,
+ ) -> Model<Project> {
+ use clock::FakeSystemClock;
+
+ let fs = Arc::new(RealFs::default());
+ let languages = LanguageRegistry::test(cx.background_executor().clone());
+ let clock = Arc::new(FakeSystemClock::default());
+ let http_client = util::http::FakeHttpClient::with_404_response();
+ let client = cx
+ .update(|cx| client::Client::new(clock, http_client.clone(), cx))
+ .unwrap();
+ let user_store = cx
+ .new_model(|cx| UserStore::new(client.clone(), cx))
+ .unwrap();
+ let project = cx
+ .update(|cx| {
+ Project::local(
+ client,
+ node_runtime::FakeNodeRuntime::new(),
+ user_store,
+ Arc::new(languages),
+ fs,
+ cx,
+ )
+ })
+ .unwrap();
+ for path in root_paths {
+ let (tree, _) = project
+ .update(cx, |project, cx| {
+ project.find_or_create_local_worktree(path, true, cx)
+ })
+ .unwrap()
+ .await
+ .unwrap();
+ tree.update(cx, |tree, _| tree.as_local().unwrap().scan_complete())
+ .unwrap()
+ .await;
+ }
+ project
+ }
+
#[cfg(any(test, feature = "test-support"))]
pub async fn test(
fs: Arc<dyn Fs>,
@@ -1146,6 +1190,10 @@ impl Project {
self.user_store.clone()
}
+ pub fn node_runtime(&self) -> Option<&Arc<dyn NodeRuntime>> {
+ self.node.as_ref()
+ }
+
pub fn opened_buffers(&self) -> Vec<Model<Buffer>> {
self.opened_buffers
.values()
@@ -2661,7 +2661,7 @@ async fn test_file_changes_multiple_times_on_disk(cx: &mut gpui::TestAppContext)
)
.await
.unwrap();
- worktree.next_event(cx);
+ worktree.next_event(cx).await;
// Change the buffer's file again. Depending on the random seed, the
// previous file change may still be in progress.
@@ -2672,7 +2672,7 @@ async fn test_file_changes_multiple_times_on_disk(cx: &mut gpui::TestAppContext)
)
.await
.unwrap();
- worktree.next_event(cx);
+ worktree.next_event(cx).await;
cx.executor().run_until_parked();
let on_disk_text = fs.load(Path::new("/dir/file1")).await.unwrap();
@@ -2716,7 +2716,7 @@ async fn test_edit_buffer_while_it_reloads(cx: &mut gpui::TestAppContext) {
)
.await
.unwrap();
- worktree.next_event(cx);
+ worktree.next_event(cx).await;
cx.executor()
.spawn(cx.executor().simulate_random_delay())
@@ -204,6 +204,11 @@ message Envelope {
LanguageModelResponse language_model_response = 167;
CountTokensWithLanguageModel count_tokens_with_language_model = 168;
CountTokensResponse count_tokens_response = 169;
+ GetCachedEmbeddings get_cached_embeddings = 189;
+ GetCachedEmbeddingsResponse get_cached_embeddings_response = 190;
+ ComputeEmbeddings compute_embeddings = 191;
+ ComputeEmbeddingsResponse compute_embeddings_response = 192; // current max
+
UpdateChannelMessage update_channel_message = 170;
ChannelMessageUpdate channel_message_update = 171;
@@ -216,7 +221,7 @@ message Envelope {
MultiLspQueryResponse multi_lsp_query_response = 176;
CreateRemoteProject create_remote_project = 177;
- CreateRemoteProjectResponse create_remote_project_response = 188; // current max
+ CreateRemoteProjectResponse create_remote_project_response = 188;
CreateDevServer create_dev_server = 178;
CreateDevServerResponse create_dev_server_response = 179;
ShutdownDevServer shutdown_dev_server = 180;
@@ -1892,6 +1897,29 @@ message CountTokensResponse {
uint32 token_count = 1;
}
+message GetCachedEmbeddings {
+ string model = 1;
+ repeated bytes digests = 2;
+}
+
+message GetCachedEmbeddingsResponse {
+ repeated Embedding embeddings = 1;
+}
+
+message ComputeEmbeddings {
+ string model = 1;
+ repeated string texts = 2;
+}
+
+message ComputeEmbeddingsResponse {
+ repeated Embedding embeddings = 1;
+}
+
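+// A single embedding: the SHA-256 digest of the source text and its vector.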
+message Embedding {
+ bytes digest = 1;
+ repeated float dimensions = 2;
+}
+
message BlameBuffer {
uint64 project_id = 1;
uint64 buffer_id = 2;
@@ -151,6 +151,8 @@ messages!(
(ChannelMessageSent, Foreground),
(ChannelMessageUpdate, Foreground),
(CompleteWithLanguageModel, Background),
+ (ComputeEmbeddings, Background),
+ (ComputeEmbeddingsResponse, Background),
(CopyProjectEntry, Foreground),
(CountTokensWithLanguageModel, Background),
(CountTokensResponse, Background),
@@ -174,6 +176,8 @@ messages!(
(FormatBuffers, Foreground),
(FormatBuffersResponse, Foreground),
(FuzzySearchUsers, Foreground),
+ (GetCachedEmbeddings, Background),
+ (GetCachedEmbeddingsResponse, Background),
(GetChannelMembers, Foreground),
(GetChannelMembersResponse, Foreground),
(GetChannelMessages, Background),
@@ -325,6 +329,7 @@ request_messages!(
(CancelCall, Ack),
(CopyProjectEntry, ProjectEntryResponse),
(CompleteWithLanguageModel, LanguageModelResponse),
+ (ComputeEmbeddings, ComputeEmbeddingsResponse),
(CountTokensWithLanguageModel, CountTokensResponse),
(CreateChannel, CreateChannelResponse),
(CreateProjectEntry, ProjectEntryResponse),
@@ -336,6 +341,7 @@ request_messages!(
(Follow, FollowResponse),
(FormatBuffers, FormatBuffersResponse),
(FuzzySearchUsers, UsersResponse),
+ (GetCachedEmbeddings, GetCachedEmbeddingsResponse),
(GetChannelMembers, GetChannelMembersResponse),
(GetChannelMessages, GetChannelMessagesResponse),
(GetChannelMessagesById, GetChannelMessagesResponse),
@@ -0,0 +1,48 @@
+[package]
+name = "semantic_index"
+description = "Process, chunk, and embed text as vectors for semantic search."
+version = "0.1.0"
+edition = "2021"
+publish = false
+license = "GPL-3.0-or-later"
+
+[lib]
+path = "src/semantic_index.rs"
+
+[dependencies]
+anyhow.workspace = true
+client.workspace = true
+clock.workspace = true
+collections.workspace = true
+fs.workspace = true
+futures.workspace = true
+futures-batch.workspace = true
+gpui.workspace = true
+language.workspace = true
+log.workspace = true
+heed.workspace = true
+open_ai.workspace = true
+project.workspace = true
+settings.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+sha2.workspace = true
+smol.workspace = true
+util.workspace = true
+worktree.workspace = true
+
+[dev-dependencies]
+env_logger.workspace = true
+client = { workspace = true, features = ["test-support"] }
+fs = { workspace = true, features = ["test-support"] }
+futures.workspace = true
+gpui = { workspace = true, features = ["test-support"] }
+language = { workspace = true, features = ["test-support"] }
+languages.workspace = true
+project = { workspace = true, features = ["test-support"] }
+tempfile.workspace = true
+util = { workspace = true, features = ["test-support"] }
+worktree = { workspace = true, features = ["test-support"] }
+
+[lints]
+workspace = true
@@ -0,0 +1 @@
+../../LICENSE-GPL
@@ -0,0 +1,140 @@
+use client::Client;
+use futures::channel::oneshot;
+use gpui::{App, Global, TestAppContext};
+use language::language_settings::AllLanguageSettings;
+use project::Project;
+use semantic_index::{OpenAiEmbeddingModel, OpenAiEmbeddingProvider, SemanticIndex};
+use settings::SettingsStore;
+use std::{path::Path, sync::Arc};
+use util::http::HttpClientWithUrl;
+
+pub fn init_test(cx: &mut TestAppContext) {
+ _ = cx.update(|cx| {
+ let store = SettingsStore::test(cx);
+ cx.set_global(store);
+ language::init(cx);
+ Project::init_settings(cx);
+ SettingsStore::update(cx, |store, cx| {
+ store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
+ });
+ });
+}
+
+fn main() {
+ env_logger::init();
+
+ use clock::FakeSystemClock;
+
+ App::new().run(|cx| {
+ let store = SettingsStore::test(cx);
+ cx.set_global(store);
+ language::init(cx);
+ Project::init_settings(cx);
+ SettingsStore::update(cx, |store, cx| {
+ store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
+ });
+
+ let clock = Arc::new(FakeSystemClock::default());
+ let http = Arc::new(HttpClientWithUrl::new("http://localhost:11434"));
+
+ let client = client::Client::new(clock, http.clone(), cx);
+ Client::set_global(client.clone(), cx);
+
+ let args: Vec<String> = std::env::args().collect();
+ if args.len() < 2 {
+ eprintln!("Usage: cargo run --example index -p semantic_index -- <project_path>");
+ cx.quit();
+ return;
+ }
+
+ // let embedding_provider = semantic_index::FakeEmbeddingProvider;
+
+ let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
+ let embedding_provider = OpenAiEmbeddingProvider::new(
+ http.clone(),
+ OpenAiEmbeddingModel::TextEmbedding3Small,
+ open_ai::OPEN_AI_API_URL.to_string(),
+ api_key,
+ );
+
+ let semantic_index = SemanticIndex::new(
+ Path::new("/tmp/semantic-index-db.mdb"),
+ Arc::new(embedding_provider),
+ cx,
+ );
+
+ cx.spawn(|mut cx| async move {
+ let mut semantic_index = semantic_index.await.unwrap();
+
+ let project_path = Path::new(&args[1]);
+
+ let project = Project::example([project_path], &mut cx).await;
+
+ cx.update(|cx| {
+ let language_registry = project.read(cx).languages().clone();
+ let node_runtime = project.read(cx).node_runtime().unwrap().clone();
+ languages::init(language_registry, node_runtime, cx);
+ })
+ .unwrap();
+
+ let project_index = cx
+ .update(|cx| semantic_index.project_index(project.clone(), cx))
+ .unwrap();
+
+ let (tx, rx) = oneshot::channel();
+ let mut tx = Some(tx);
+ let subscription = cx.update(|cx| {
+ cx.subscribe(&project_index, move |_, event, _| {
+ if let Some(tx) = tx.take() {
+ _ = tx.send(*event);
+ }
+ })
+ });
+
+ let index_start = std::time::Instant::now();
+ rx.await.expect("no event emitted");
+ drop(subscription);
+ println!("Index time: {:?}", index_start.elapsed());
+
+ let results = cx
+ .update(|cx| {
+ let project_index = project_index.read(cx);
+ let query = "converting an anchor to a point";
+ project_index.search(query, 4, cx)
+ })
+ .unwrap()
+ .await;
+
+ for search_result in results {
+ let path = search_result.path.clone();
+
+ let content = cx
+ .update(|cx| {
+ let worktree = search_result.worktree.read(cx);
+ let entry_abs_path = worktree.abs_path().join(search_result.path.clone());
+ let fs = project.read(cx).fs().clone();
+ cx.spawn(|_| async move { fs.load(&entry_abs_path).await.unwrap() })
+ })
+ .unwrap()
+ .await;
+
+ let range = search_result.range.clone();
+ let content = content[search_result.range].to_owned();
+
+ println!(
+ "✄✄✄✄✄✄✄✄✄✄✄✄✄✄ {:?} @ {} ✄✄✄✄✄✄✄✄✄✄✄✄✄✄",
+ path, search_result.score
+ );
+ println!("{:?}:{:?}:{:?}", path, range.start, range.end);
+ println!("{}", content);
+ }
+
+ cx.background_executor()
+ .timer(std::time::Duration::from_secs(100000))
+ .await;
+
+ cx.update(|cx| cx.quit()).unwrap();
+ })
+ .detach();
+ });
+}
@@ -0,0 +1,3 @@
+fn main() {
+ println!("Hello Indexer!");
+}
@@ -0,0 +1,43 @@
+# Searching for a needle in a haystack
+
+When you have a large amount of text, it can be useful to search for a specific word or phrase. This is often referred to as "finding a needle in a haystack." In this markdown document, we're "hiding" a key phrase for our text search to find. Can you find it?
+
+## Instructions
+
+1. Use the search functionality in your text editor or markdown viewer to find the hidden phrase in this document.
+
+2. Once you've found the **phrase**, write it down and proceed to the next step.
+
+Honestly, I just want to fill up plenty of characters so that we chunk this markdown into several chunks.
+
+## Tips
+
+- Relax
+- Take a deep breath
+- Focus on the task at hand
+- Don't get distracted by other text
+- Use the search functionality to your advantage
+
+## Example code
+
+```python
+def search_for_needle(haystack, needle):
+ if needle in haystack:
+ return True
+ else:
+ return False
+```
+
+```javascript
+function searchForNeedle(haystack, needle) {
+ return haystack.includes(needle);
+}
+```
+
+## Background
+
+When creating an index for a book or searching for a specific term in a large document, the ability to quickly find a specific word or phrase is essential. This is where search functionality comes in handy. However, one should _remember_ that the search is only as good as the index that was built. As they say, garbage in, garbage out!
+
+## Conclusion
+
+Searching for a needle in a haystack can be a challenging task, but with the right tools and techniques, it becomes much easier. Whether you're looking for a specific word in a document or trying to find a key piece of information in a large dataset, the ability to search efficiently is a valuable skill to have.
@@ -0,0 +1,409 @@
+use language::{with_parser, Grammar, Tree};
+use serde::{Deserialize, Serialize};
+use sha2::{Digest, Sha256};
+use std::{cmp, ops::Range, sync::Arc};
+
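+/// Target maximum size of a chunk, in bytes.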
+const CHUNK_THRESHOLD: usize = 1500;
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Chunk {
+ pub range: Range<usize>,
+ pub digest: [u8; 32],
+}
+
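+/// Splits `text` into chunks of at most `CHUNK_THRESHOLD` bytes, following
+/// syntax tree boundaries when a grammar is available and line boundaries
+/// otherwise.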
+pub fn chunk_text(text: &str, grammar: Option<&Arc<Grammar>>) -> Vec<Chunk> {
+ if let Some(grammar) = grammar {
+ let tree = with_parser(|parser| {
+ parser
+ .set_language(&grammar.ts_language)
+ .expect("incompatible grammar");
+ parser.parse(&text, None).expect("invalid language")
+ });
+
+ chunk_parse_tree(tree, &text, CHUNK_THRESHOLD)
+ } else {
+ chunk_lines(&text)
+ }
+}
+
+fn chunk_parse_tree(tree: Tree, text: &str, chunk_threshold: usize) -> Vec<Chunk> {
+ let mut chunk_ranges = Vec::new();
+ let mut cursor = tree.walk();
+
+ let mut range = 0..0;
+ loop {
+ let node = cursor.node();
+
+        // If adding the node to the current chunk would exceed the threshold:
+ if node.end_byte() - range.start > chunk_threshold {
+ // Try to descend into its first child. If we can't, flush the current
+ // range and try again.
+ if cursor.goto_first_child() {
+ continue;
+ } else if !range.is_empty() {
+ chunk_ranges.push(range.clone());
+ range.start = range.end;
+ continue;
+ }
+
+ // If we get here, the node itself has no children but is larger than the threshold.
+ // Break its text into arbitrary chunks.
+ split_text(text, range.clone(), node.end_byte(), &mut chunk_ranges);
+ }
+ range.end = node.end_byte();
+
+        // If we get here, we consumed the node. Advance to the next sibling, ascending if there isn't one.
+ while !cursor.goto_next_sibling() {
+ if !cursor.goto_parent() {
+ if !range.is_empty() {
+ chunk_ranges.push(range);
+ }
+
+ return chunk_ranges
+ .into_iter()
+ .map(|range| {
+ let digest = Sha256::digest(&text[range.clone()]).into();
+ Chunk { range, digest }
+ })
+ .collect();
+ }
+ }
+ }
+}
+
+fn chunk_lines(text: &str) -> Vec<Chunk> {
+ let mut chunk_ranges = Vec::new();
+ let mut range = 0..0;
+
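+    // Greedily extend the current chunk one line at a time, flushing it once
+    // adding the next line would exceed CHUNK_THRESHOLD.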
+ let mut newlines = text.match_indices('\n').peekable();
+ while let Some((newline_ix, _)) = newlines.peek() {
+ let newline_ix = newline_ix + 1;
+ if newline_ix - range.start <= CHUNK_THRESHOLD {
+ range.end = newline_ix;
+ newlines.next();
+ } else {
+ if range.is_empty() {
+ split_text(text, range, newline_ix, &mut chunk_ranges);
+ range = newline_ix..newline_ix;
+ } else {
+ chunk_ranges.push(range.clone());
+ range.start = range.end;
+ }
+ }
+ }
+
+ if !range.is_empty() {
+ chunk_ranges.push(range);
+ }
+
+ chunk_ranges
+ .into_iter()
+ .map(|range| {
+ let mut hasher = Sha256::new();
+ hasher.update(&text[range.clone()]);
+ let mut digest = [0u8; 32];
+ digest.copy_from_slice(hasher.finalize().as_slice());
+ Chunk { range, digest }
+ })
+ .collect()
+}
+
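+/// Chops the text between `range.start` and `max_end` into pieces of at most
+/// `CHUNK_THRESHOLD` bytes, backing up where needed so that no chunk ends in
+/// the middle of a multi-byte UTF-8 character.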
+fn split_text(
+ text: &str,
+ mut range: Range<usize>,
+ max_end: usize,
+ chunk_ranges: &mut Vec<Range<usize>>,
+) {
+ while range.start < max_end {
+ range.end = cmp::min(range.start + CHUNK_THRESHOLD, max_end);
+ while !text.is_char_boundary(range.end) {
+ range.end -= 1;
+ }
+ chunk_ranges.push(range.clone());
+ range.start = range.end;
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use language::{tree_sitter_rust, Language, LanguageConfig, LanguageMatcher};
+
+    // This example comes from crates/gpui/examples/window_positioning.rs, which
+    // has the property that CHUNK_THRESHOLD < TEXT.len() < 2 * CHUNK_THRESHOLD
+ static TEXT: &str = r#"
+ use gpui::*;
+
+ struct WindowContent {
+ text: SharedString,
+ }
+
+ impl Render for WindowContent {
+ fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
+ div()
+ .flex()
+ .bg(rgb(0x1e2025))
+ .size_full()
+ .justify_center()
+ .items_center()
+ .text_xl()
+ .text_color(rgb(0xffffff))
+ .child(self.text.clone())
+ }
+ }
+
+ fn main() {
+ App::new().run(|cx: &mut AppContext| {
+ // Create several new windows, positioned in the top right corner of each screen
+
+ for screen in cx.displays() {
+ let options = {
+ let popup_margin_width = DevicePixels::from(16);
+ let popup_margin_height = DevicePixels::from(-0) - DevicePixels::from(48);
+
+ let window_size = Size {
+ width: px(400.),
+ height: px(72.),
+ };
+
+ let screen_bounds = screen.bounds();
+ let size: Size<DevicePixels> = window_size.into();
+
+ let bounds = gpui::Bounds::<DevicePixels> {
+ origin: screen_bounds.upper_right()
+ - point(size.width + popup_margin_width, popup_margin_height),
+ size: window_size.into(),
+ };
+
+ WindowOptions {
+ // Set the bounds of the window in screen coordinates
+ bounds: Some(bounds),
+ // Specify the display_id to ensure the window is created on the correct screen
+ display_id: Some(screen.id()),
+
+ titlebar: None,
+ window_background: WindowBackgroundAppearance::default(),
+ focus: false,
+ show: true,
+ kind: WindowKind::PopUp,
+ is_movable: false,
+ fullscreen: false,
+ }
+ };
+
+ cx.open_window(options, |cx| {
+ cx.new_view(|_| WindowContent {
+ text: format!("{:?}", screen.id()).into(),
+ })
+ });
+ }
+ });
+ }"#;
+
+ fn setup_rust_language() -> Language {
+ Language::new(
+ LanguageConfig {
+ name: "Rust".into(),
+ matcher: LanguageMatcher {
+ path_suffixes: vec!["rs".to_string()],
+ ..Default::default()
+ },
+ ..Default::default()
+ },
+ Some(tree_sitter_rust::language()),
+ )
+ }
+
+ #[test]
+ fn test_chunk_text() {
+ let text = "a\n".repeat(1000);
+ let chunks = chunk_text(&text, None);
+ assert_eq!(
+ chunks.len(),
+ ((2000_f64) / (CHUNK_THRESHOLD as f64)).ceil() as usize
+ );
+ }
+
+ #[test]
+ fn test_chunk_text_grammar() {
+ // Let's set up a big text with some known segments
+ // We'll then chunk it and verify that the chunks are correct
+
+ let language = setup_rust_language();
+
+ let chunks = chunk_text(TEXT, language.grammar());
+ assert_eq!(chunks.len(), 2);
+
+ assert_eq!(chunks[0].range.start, 0);
+ assert_eq!(chunks[0].range.end, 1498);
+ // The break between chunks is right before the "Specify the display_id" comment
+
+ assert_eq!(chunks[1].range.start, 1498);
+ assert_eq!(chunks[1].range.end, 2396);
+ }
+
+ #[test]
+ fn test_chunk_parse_tree() {
+ let language = setup_rust_language();
+ let grammar = language.grammar().unwrap();
+
+ let tree = with_parser(|parser| {
+ parser
+ .set_language(&grammar.ts_language)
+ .expect("incompatible grammar");
+ parser.parse(TEXT, None).expect("invalid language")
+ });
+
+ let chunks = chunk_parse_tree(tree, TEXT, 250);
+ assert_eq!(chunks.len(), 11);
+ }
+
+ #[test]
+ fn test_chunk_unparsable() {
+ // Even if a chunk is unparsable, we should still be able to chunk it
+ let language = setup_rust_language();
+ let grammar = language.grammar().unwrap();
+
+ let text = r#"fn main() {"#;
+ let tree = with_parser(|parser| {
+ parser
+ .set_language(&grammar.ts_language)
+ .expect("incompatible grammar");
+ parser.parse(text, None).expect("invalid language")
+ });
+
+ let chunks = chunk_parse_tree(tree, text, 250);
+ assert_eq!(chunks.len(), 1);
+
+ assert_eq!(chunks[0].range.start, 0);
+ assert_eq!(chunks[0].range.end, 11);
+ }
+
+ #[test]
+ fn test_empty_text() {
+ let language = setup_rust_language();
+ let grammar = language.grammar().unwrap();
+
+ let tree = with_parser(|parser| {
+ parser
+ .set_language(&grammar.ts_language)
+ .expect("incompatible grammar");
+ parser.parse("", None).expect("invalid language")
+ });
+
+ let chunks = chunk_parse_tree(tree, "", CHUNK_THRESHOLD);
+ assert!(chunks.is_empty(), "Chunks should be empty for empty text");
+ }
+
+ #[test]
+ fn test_single_large_node() {
+ let large_text = "static ".to_owned() + "a".repeat(CHUNK_THRESHOLD - 1).as_str() + " = 2";
+
+ let language = setup_rust_language();
+ let grammar = language.grammar().unwrap();
+
+ let tree = with_parser(|parser| {
+ parser
+ .set_language(&grammar.ts_language)
+ .expect("incompatible grammar");
+ parser.parse(&large_text, None).expect("invalid language")
+ });
+
+ let chunks = chunk_parse_tree(tree, &large_text, CHUNK_THRESHOLD);
+
+ assert_eq!(
+ chunks.len(),
+ 3,
+ "Large chunks are broken up according to grammar as best as possible"
+ );
+
+        // Expect the chunks to be "static", the long run of "a"s, and " = 2"
+ assert_eq!(chunks[0].range.start, 0);
+ assert_eq!(chunks[0].range.end, "static".len());
+
+ assert_eq!(chunks[1].range.start, "static".len());
+ assert_eq!(chunks[1].range.end, "static".len() + CHUNK_THRESHOLD);
+
+ assert_eq!(chunks[2].range.start, "static".len() + CHUNK_THRESHOLD);
+ assert_eq!(chunks[2].range.end, large_text.len());
+ }
+
+ #[test]
+ fn test_multiple_small_nodes() {
+ let small_text = "a b c d e f g h i j k l m n o p q r s t u v w x y z";
+ let language = setup_rust_language();
+ let grammar = language.grammar().unwrap();
+
+ let tree = with_parser(|parser| {
+ parser
+ .set_language(&grammar.ts_language)
+ .expect("incompatible grammar");
+ parser.parse(small_text, None).expect("invalid language")
+ });
+
+ let chunks = chunk_parse_tree(tree, small_text, 5);
+ assert!(
+ chunks.len() > 1,
+ "Should have multiple chunks for multiple small nodes"
+ );
+ }
+
+ #[test]
+ fn test_node_with_children() {
+ let nested_text = "fn main() { let a = 1; let b = 2; }";
+ let language = setup_rust_language();
+ let grammar = language.grammar().unwrap();
+
+ let tree = with_parser(|parser| {
+ parser
+ .set_language(&grammar.ts_language)
+ .expect("incompatible grammar");
+ parser.parse(nested_text, None).expect("invalid language")
+ });
+
+ let chunks = chunk_parse_tree(tree, nested_text, 10);
+ assert!(
+ chunks.len() > 1,
+ "Should have multiple chunks for a node with children"
+ );
+ }
+
+ #[test]
+ fn test_text_with_unparsable_sections() {
+        // This test purposely uses an awkward chunk threshold of 11 characters so that chunk boundaries fall unevenly
+ let mixed_text = "fn main() { let a = 1; let b = 2; } unparsable bits here";
+ let language = setup_rust_language();
+ let grammar = language.grammar().unwrap();
+
+ let tree = with_parser(|parser| {
+ parser
+ .set_language(&grammar.ts_language)
+ .expect("incompatible grammar");
+ parser.parse(mixed_text, None).expect("invalid language")
+ });
+
+ let chunks = chunk_parse_tree(tree, mixed_text, 11);
+ assert!(
+ chunks.len() > 1,
+ "Should handle both parsable and unparsable sections correctly"
+ );
+
+ let expected_chunks = [
+ "fn main() {",
+ " let a = 1;",
+ " let b = 2;",
+ " }",
+ " unparsable",
+ " bits here",
+ ];
+
+ for (i, chunk) in chunks.iter().enumerate() {
+ assert_eq!(
+ &mixed_text[chunk.range.clone()],
+ expected_chunks[i],
+ "Chunk {} should match",
+ i
+ );
+ }
+ }
+}
@@ -0,0 +1,125 @@
+mod cloud;
+mod ollama;
+mod open_ai;
+
+pub use cloud::*;
+pub use ollama::*;
+pub use open_ai::*;
+use sha2::{Digest, Sha256};
+
+use anyhow::Result;
+use futures::{future::BoxFuture, FutureExt};
+use serde::{Deserialize, Serialize};
+use std::{fmt, future};
+
+#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
+pub struct Embedding(Vec<f32>);
+
+impl Embedding {
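+    /// Normalizes the vector to unit length (L2 norm), so that the dot product
+    /// of two embeddings equals their cosine similarity.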
+ pub fn new(mut embedding: Vec<f32>) -> Self {
+ let len = embedding.len();
+ let mut norm = 0f32;
+
+ for i in 0..len {
+ norm += embedding[i] * embedding[i];
+ }
+
+ norm = norm.sqrt();
+ for dimension in &mut embedding {
+ *dimension /= norm;
+ }
+
+ Self(embedding)
+ }
+
+ fn len(&self) -> usize {
+ self.0.len()
+ }
+
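+    /// Dot product of two embeddings. Because both are unit-normalized in
+    /// `new`, this is their cosine similarity.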
+ pub fn similarity(self, other: &Embedding) -> f32 {
+ debug_assert_eq!(self.0.len(), other.0.len());
+ self.0
+ .iter()
+ .copied()
+ .zip(other.0.iter().copied())
+ .map(|(a, b)| a * b)
+ .sum()
+ }
+}
+
+impl fmt::Display for Embedding {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ let digits_to_display = 3;
+
+ // Start the Embedding display format
+ write!(f, "Embedding(sized: {}; values: [", self.len())?;
+
+ for (index, value) in self.0.iter().enumerate().take(digits_to_display) {
+ // Lead with comma if not the first element
+ if index != 0 {
+ write!(f, ", ")?;
+ }
+ write!(f, "{:.3}", value)?;
+ }
+ if self.len() > digits_to_display {
+ write!(f, "...")?;
+ }
+ write!(f, "])")
+ }
+}
+
+/// Trait for embedding providers. Texts in, vectors out.
+pub trait EmbeddingProvider: Sync + Send {
+ fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>>;
+ fn batch_size(&self) -> usize;
+}
+
+#[derive(Debug)]
+pub struct TextToEmbed<'a> {
+ pub text: &'a str,
+ pub digest: [u8; 32],
+}
+
+impl<'a> TextToEmbed<'a> {
+ pub fn new(text: &'a str) -> Self {
+ let digest = Sha256::digest(text.as_bytes());
+ Self {
+ text,
+ digest: digest.into(),
+ }
+ }
+}
+
+pub struct FakeEmbeddingProvider;
+
+impl EmbeddingProvider for FakeEmbeddingProvider {
+ fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
+ let embeddings = texts
+ .iter()
+ .map(|_text| {
+ let mut embedding = vec![0f32; 1536];
+ for i in 0..embedding.len() {
+ embedding[i] = i as f32;
+ }
+ Embedding::new(embedding)
+ })
+ .collect();
+ future::ready(Ok(embeddings)).boxed()
+ }
+
+ fn batch_size(&self) -> usize {
+ 16
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ #[gpui::test]
+ fn test_normalize_embedding() {
+ let normalized = Embedding::new(vec![1.0, 1.0, 1.0]);
+ let value: f32 = 1.0 / 3.0_f32.sqrt();
+ assert_eq!(normalized, Embedding(vec![value; 3]));
+ }
+}
@@ -0,0 +1,88 @@
+use crate::{Embedding, EmbeddingProvider, TextToEmbed};
+use anyhow::{anyhow, Context, Result};
+use client::{proto, Client};
+use collections::HashMap;
+use futures::{future::BoxFuture, FutureExt};
+use std::sync::Arc;
+
+pub struct CloudEmbeddingProvider {
+ model: String,
+ client: Arc<Client>,
+}
+
+impl CloudEmbeddingProvider {
+ pub fn new(client: Arc<Client>) -> Self {
+ Self {
+ model: "openai/text-embedding-3-small".into(),
+ client,
+ }
+ }
+}
+
+impl EmbeddingProvider for CloudEmbeddingProvider {
+ fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
+        // First, fetch any embeddings that are already cached, keyed by the requested texts' digests.
+ // Then compute any embeddings that are missing.
+ async move {
+ let cached_embeddings = self.client.request(proto::GetCachedEmbeddings {
+ model: self.model.clone(),
+ digests: texts
+ .iter()
+ .map(|to_embed| to_embed.digest.to_vec())
+ .collect(),
+ });
+ let mut embeddings = cached_embeddings
+ .await
+ .context("failed to fetch cached embeddings via cloud model")?
+ .embeddings
+ .into_iter()
+ .map(|embedding| {
+ let digest: [u8; 32] = embedding
+ .digest
+ .try_into()
+ .map_err(|_| anyhow!("invalid digest for cached embedding"))?;
+ Ok((digest, embedding.dimensions))
+ })
+ .collect::<Result<HashMap<_, _>>>()?;
+
+ let compute_embeddings_request = proto::ComputeEmbeddings {
+ model: self.model.clone(),
+ texts: texts
+ .iter()
+ .filter_map(|to_embed| {
+ if embeddings.contains_key(&to_embed.digest) {
+ None
+ } else {
+ Some(to_embed.text.to_string())
+ }
+ })
+ .collect(),
+ };
+ if !compute_embeddings_request.texts.is_empty() {
+ let missing_embeddings = self.client.request(compute_embeddings_request).await?;
+ for embedding in missing_embeddings.embeddings {
+ let digest: [u8; 32] = embedding
+ .digest
+ .try_into()
+ .map_err(|_| anyhow!("invalid digest for cached embedding"))?;
+ embeddings.insert(digest, embedding.dimensions);
+ }
+ }
+
+ texts
+ .iter()
+ .map(|to_embed| {
+ let dimensions = embeddings.remove(&to_embed.digest).with_context(|| {
+ format!("server did not return an embedding for {:?}", to_embed)
+ })?;
+ Ok(Embedding::new(dimensions))
+ })
+ .collect()
+ }
+ .boxed()
+ }
+
+ fn batch_size(&self) -> usize {
+ 2048
+ }
+}
@@ -0,0 +1,74 @@
+use anyhow::{Context as _, Result};
+use futures::{future::BoxFuture, AsyncReadExt, FutureExt};
+use serde::{Deserialize, Serialize};
+use std::sync::Arc;
+use util::http::HttpClient;
+
+use crate::{Embedding, EmbeddingProvider, TextToEmbed};
+
+pub enum OllamaEmbeddingModel {
+ NomicEmbedText,
+ MxbaiEmbedLarge,
+}
+
+pub struct OllamaEmbeddingProvider {
+ client: Arc<dyn HttpClient>,
+ model: OllamaEmbeddingModel,
+}
+
+#[derive(Serialize)]
+struct OllamaEmbeddingRequest {
+ model: String,
+ prompt: String,
+}
+
+#[derive(Deserialize)]
+struct OllamaEmbeddingResponse {
+ embedding: Vec<f32>,
+}
+
+impl OllamaEmbeddingProvider {
+ pub fn new(client: Arc<dyn HttpClient>, model: OllamaEmbeddingModel) -> Self {
+ Self { client, model }
+ }
+}
+
+impl EmbeddingProvider for OllamaEmbeddingProvider {
+ fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
+ let model = match self.model {
+ OllamaEmbeddingModel::NomicEmbedText => "nomic-embed-text",
+ OllamaEmbeddingModel::MxbaiEmbedLarge => "mxbai-embed-large",
+ };
+
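+        // Ollama's embeddings endpoint takes a single prompt per request, so
+        // issue one request per text and await them together.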
+ futures::future::try_join_all(texts.into_iter().map(|to_embed| {
+ let request = OllamaEmbeddingRequest {
+ model: model.to_string(),
+ prompt: to_embed.text.to_string(),
+ };
+
+ let request = serde_json::to_string(&request).unwrap();
+
+ async {
+ let response = self
+ .client
+ .post_json("http://localhost:11434/api/embeddings", request.into())
+ .await?;
+
+ let mut body = String::new();
+ response.into_body().read_to_string(&mut body).await?;
+
+ let response: OllamaEmbeddingResponse =
+ serde_json::from_str(&body).context("Unable to pull response")?;
+
+ Ok(Embedding::new(response.embedding))
+ }
+ }))
+ .boxed()
+ }
+
+ fn batch_size(&self) -> usize {
+ // TODO: Figure out decent value
+ 10
+ }
+}
@@ -0,0 +1,55 @@
+use crate::{Embedding, EmbeddingProvider, TextToEmbed};
+use anyhow::Result;
+use futures::{future::BoxFuture, FutureExt};
+pub use open_ai::OpenAiEmbeddingModel;
+use std::sync::Arc;
+use util::http::HttpClient;
+
+pub struct OpenAiEmbeddingProvider {
+ client: Arc<dyn HttpClient>,
+ model: OpenAiEmbeddingModel,
+ api_url: String,
+ api_key: String,
+}
+
+impl OpenAiEmbeddingProvider {
+ pub fn new(
+ client: Arc<dyn HttpClient>,
+ model: OpenAiEmbeddingModel,
+ api_url: String,
+ api_key: String,
+ ) -> Self {
+ Self {
+ client,
+ model,
+ api_url,
+ api_key,
+ }
+ }
+}
+
+impl EmbeddingProvider for OpenAiEmbeddingProvider {
+ fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
+ let embed = open_ai::embed(
+ self.client.as_ref(),
+ &self.api_url,
+ &self.api_key,
+ self.model,
+ texts.iter().map(|to_embed| to_embed.text),
+ );
+ async move {
+ let response = embed.await?;
+ Ok(response
+ .data
+ .into_iter()
+ .map(|data| Embedding::new(data.embedding))
+ .collect())
+ }
+ .boxed()
+ }
+
+ fn batch_size(&self) -> usize {
+ // From https://platform.openai.com/docs/api-reference/embeddings/create
+ 2048
+ }
+}
@@ -0,0 +1,954 @@
+mod chunking;
+mod embedding;
+
+use anyhow::{anyhow, Context as _, Result};
+use chunking::{chunk_text, Chunk};
+use collections::{Bound, HashMap};
+pub use embedding::*;
+use fs::Fs;
+use futures::stream::StreamExt;
+use futures_batch::ChunksTimeoutStreamExt;
+use gpui::{
+ AppContext, AsyncAppContext, Context, EntityId, EventEmitter, Global, Model, ModelContext,
+ Subscription, Task, WeakModel,
+};
+use heed::types::{SerdeBincode, Str};
+use language::LanguageRegistry;
+use project::{Entry, Project, UpdatedEntriesSet, Worktree};
+use serde::{Deserialize, Serialize};
+use smol::channel;
+use std::{
+ cmp::Ordering,
+ future::Future,
+ ops::Range,
+ path::Path,
+ sync::Arc,
+ time::{Duration, SystemTime},
+};
+use util::ResultExt;
+use worktree::LocalSnapshot;
+
+pub struct SemanticIndex {
+ embedding_provider: Arc<dyn EmbeddingProvider>,
+ db_connection: heed::Env,
+ project_indices: HashMap<WeakModel<Project>, Model<ProjectIndex>>,
+}
+
+impl Global for SemanticIndex {}
+
+impl SemanticIndex {
+ pub fn new(
+ db_path: &Path,
+ embedding_provider: Arc<dyn EmbeddingProvider>,
+ cx: &mut AppContext,
+ ) -> Task<Result<Self>> {
+ let db_path = db_path.to_path_buf();
+ cx.spawn(|cx| async move {
+ let db_connection = cx
+ .background_executor()
+ .spawn(async move {
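+                    // heed marks `open` as unsafe because LMDB memory-maps the
+                    // database file. We reserve a 1 GiB map and allow up to
+                    // 3000 named databases (one per worktree).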
+ unsafe {
+ heed::EnvOpenOptions::new()
+ .map_size(1024 * 1024 * 1024)
+ .max_dbs(3000)
+ .open(db_path)
+ }
+ })
+ .await?;
+
+ Ok(SemanticIndex {
+ db_connection,
+ embedding_provider,
+ project_indices: HashMap::default(),
+ })
+ })
+ }
+
+ pub fn project_index(
+ &mut self,
+ project: Model<Project>,
+ cx: &mut AppContext,
+ ) -> Model<ProjectIndex> {
+ self.project_indices
+ .entry(project.downgrade())
+ .or_insert_with(|| {
+ cx.new_model(|cx| {
+ ProjectIndex::new(
+ project,
+ self.db_connection.clone(),
+ self.embedding_provider.clone(),
+ cx,
+ )
+ })
+ })
+ .clone()
+ }
+}
+
+pub struct ProjectIndex {
+ db_connection: heed::Env,
+ project: Model<Project>,
+ worktree_indices: HashMap<EntityId, WorktreeIndexHandle>,
+ language_registry: Arc<LanguageRegistry>,
+ fs: Arc<dyn Fs>,
+ last_status: Status,
+ embedding_provider: Arc<dyn EmbeddingProvider>,
+ _subscription: Subscription,
+}
+
+enum WorktreeIndexHandle {
+ Loading {
+ _task: Task<Result<()>>,
+ },
+ Loaded {
+ index: Model<WorktreeIndex>,
+ _subscription: Subscription,
+ },
+}
+
+impl ProjectIndex {
+ fn new(
+ project: Model<Project>,
+ db_connection: heed::Env,
+ embedding_provider: Arc<dyn EmbeddingProvider>,
+ cx: &mut ModelContext<Self>,
+ ) -> Self {
+ let language_registry = project.read(cx).languages().clone();
+ let fs = project.read(cx).fs().clone();
+ let mut this = ProjectIndex {
+ db_connection,
+ project: project.clone(),
+ worktree_indices: HashMap::default(),
+ language_registry,
+ fs,
+ last_status: Status::Idle,
+ embedding_provider,
+ _subscription: cx.subscribe(&project, Self::handle_project_event),
+ };
+ this.update_worktree_indices(cx);
+ this
+ }
+
+ fn handle_project_event(
+ &mut self,
+ _: Model<Project>,
+ event: &project::Event,
+ cx: &mut ModelContext<Self>,
+ ) {
+ match event {
+ project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
+ self.update_worktree_indices(cx);
+ }
+ _ => {}
+ }
+ }
+
+ fn update_worktree_indices(&mut self, cx: &mut ModelContext<Self>) {
+ let worktrees = self
+ .project
+ .read(cx)
+ .visible_worktrees(cx)
+ .filter_map(|worktree| {
+ if worktree.read(cx).is_local() {
+ Some((worktree.entity_id(), worktree))
+ } else {
+ None
+ }
+ })
+ .collect::<HashMap<_, _>>();
+
+ self.worktree_indices
+ .retain(|worktree_id, _| worktrees.contains_key(worktree_id));
+ for (worktree_id, worktree) in worktrees {
+ self.worktree_indices.entry(worktree_id).or_insert_with(|| {
+ let worktree_index = WorktreeIndex::load(
+ worktree.clone(),
+ self.db_connection.clone(),
+ self.language_registry.clone(),
+ self.fs.clone(),
+ self.embedding_provider.clone(),
+ cx,
+ );
+
+ let load_worktree = cx.spawn(|this, mut cx| async move {
+ if let Some(index) = worktree_index.await.log_err() {
+ this.update(&mut cx, |this, cx| {
+ this.worktree_indices.insert(
+ worktree_id,
+ WorktreeIndexHandle::Loaded {
+ _subscription: cx
+ .observe(&index, |this, _, cx| this.update_status(cx)),
+ index,
+ },
+ );
+ })?;
+ } else {
+ this.update(&mut cx, |this, _cx| {
+ this.worktree_indices.remove(&worktree_id)
+ })?;
+ }
+
+ this.update(&mut cx, |this, cx| this.update_status(cx))
+ });
+
+ WorktreeIndexHandle::Loading {
+ _task: load_worktree,
+ }
+ });
+ }
+
+ self.update_status(cx);
+ }
+
+ fn update_status(&mut self, cx: &mut ModelContext<Self>) {
+ let mut status = Status::Idle;
+ for index in self.worktree_indices.values() {
+ match index {
+ WorktreeIndexHandle::Loading { .. } => {
+ status = Status::Scanning;
+ break;
+ }
+ WorktreeIndexHandle::Loaded { index, .. } => {
+ if index.read(cx).status == Status::Scanning {
+ status = Status::Scanning;
+ break;
+ }
+ }
+ }
+ }
+
+ if status != self.last_status {
+ self.last_status = status;
+ cx.emit(status);
+ }
+ }
+
+ pub fn search(&self, query: &str, limit: usize, cx: &AppContext) -> Task<Vec<SearchResult>> {
+ let mut worktree_searches = Vec::new();
+ for worktree_index in self.worktree_indices.values() {
+ if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
+ worktree_searches
+ .push(index.read_with(cx, |index, cx| index.search(query, limit, cx)));
+ }
+ }
+
+ cx.spawn(|_| async move {
+ let mut results = Vec::new();
+ let worktree_searches = futures::future::join_all(worktree_searches).await;
+
+ for worktree_search_results in worktree_searches {
+ if let Some(worktree_search_results) = worktree_search_results.log_err() {
+ results.extend(worktree_search_results);
+ }
+ }
+
+ results
+ .sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
+ results.truncate(limit);
+
+ results
+ })
+ }
+}
+
+pub struct SearchResult {
+ pub worktree: Model<Worktree>,
+ pub path: Arc<Path>,
+ pub range: Range<usize>,
+ pub score: f32,
+}
+
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
+pub enum Status {
+ Idle,
+ Scanning,
+}
+
+impl EventEmitter<Status> for ProjectIndex {}
+
+struct WorktreeIndex {
+ worktree: Model<Worktree>,
+ db_connection: heed::Env,
+ db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
+ language_registry: Arc<LanguageRegistry>,
+ fs: Arc<dyn Fs>,
+ embedding_provider: Arc<dyn EmbeddingProvider>,
+ status: Status,
+ _index_entries: Task<Result<()>>,
+ _subscription: Subscription,
+}
+
+impl WorktreeIndex {
+ pub fn load(
+ worktree: Model<Worktree>,
+ db_connection: heed::Env,
+ language_registry: Arc<LanguageRegistry>,
+ fs: Arc<dyn Fs>,
+ embedding_provider: Arc<dyn EmbeddingProvider>,
+ cx: &mut AppContext,
+ ) -> Task<Result<Model<Self>>> {
+ let worktree_abs_path = worktree.read(cx).abs_path();
+ cx.spawn(|mut cx| async move {
+ let db = cx
+ .background_executor()
+ .spawn({
+ let db_connection = db_connection.clone();
+ async move {
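+ // Each worktree gets its own named LMDB database within the shared
+ // environment, keyed by the worktree's absolute path.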
+ let mut txn = db_connection.write_txn()?;
+ let db_name = worktree_abs_path.to_string_lossy();
+ let db = db_connection.create_database(&mut txn, Some(&db_name))?;
+ txn.commit()?;
+ anyhow::Ok(db)
+ }
+ })
+ .await?;
+ cx.new_model(|cx| {
+ Self::new(
+ worktree,
+ db_connection,
+ db,
+ language_registry,
+ fs,
+ embedding_provider,
+ cx,
+ )
+ })
+ })
+ }
+
+ fn new(
+ worktree: Model<Worktree>,
+ db_connection: heed::Env,
+ db: heed::Database<Str, SerdeBincode<EmbeddedFile>>,
+ language_registry: Arc<LanguageRegistry>,
+ fs: Arc<dyn Fs>,
+ embedding_provider: Arc<dyn EmbeddingProvider>,
+ cx: &mut ModelContext<Self>,
+ ) -> Self {
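+ // Forward the worktree's `UpdatedEntries` events into a channel that
+ // the background `index_entries` task drains.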
+ let (updated_entries_tx, updated_entries_rx) = channel::unbounded();
+ let _subscription = cx.subscribe(&worktree, move |_this, _worktree, event, _cx| {
+ if let worktree::Event::UpdatedEntries(update) = event {
+ _ = updated_entries_tx.try_send(update.clone());
+ }
+ });
+
+ Self {
+ db_connection,
+ db,
+ worktree,
+ language_registry,
+ fs,
+ embedding_provider,
+ status: Status::Idle,
+ _index_entries: cx.spawn(|this, cx| Self::index_entries(this, updated_entries_rx, cx)),
+ _subscription,
+ }
+ }
+
+ async fn index_entries(
+ this: WeakModel<Self>,
+ updated_entries: channel::Receiver<UpdatedEntriesSet>,
+ mut cx: AsyncAppContext,
+ ) -> Result<()> {
+ let index = this.update(&mut cx, |this, cx| {
+ cx.notify();
+ this.status = Status::Scanning;
+ this.index_entries_changed_on_disk(cx)
+ })?;
+ index.await.log_err();
+ this.update(&mut cx, |this, cx| {
+ this.status = Status::Idle;
+ cx.notify();
+ })?;
+
+ while let Ok(updated_entries) = updated_entries.recv().await {
+ let index = this.update(&mut cx, |this, cx| {
+ cx.notify();
+ this.status = Status::Scanning;
+ this.index_updated_entries(updated_entries, cx)
+ })?;
+ index.await.log_err();
+ this.update(&mut cx, |this, cx| {
+ this.status = Status::Idle;
+ cx.notify();
+ })?;
+ }
+
+ Ok(())
+ }
+
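+ /// The full indexing pipeline: scan the worktree against the database,
+ /// chunk the changed files, embed the chunks, and persist the results.
+ /// The stages are connected by bounded channels, so each stage applies
+ /// backpressure to the one before it.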
+ fn index_entries_changed_on_disk(&self, cx: &AppContext) -> impl Future<Output = Result<()>> {
+ let worktree = self.worktree.read(cx).as_local().unwrap().snapshot();
+ let worktree_abs_path = worktree.abs_path().clone();
+ let scan = self.scan_entries(worktree.clone(), cx);
+ let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
+ let embed = self.embed_files(chunk.files, cx);
+ let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
+ async move {
+ futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
+ Ok(())
+ }
+ }
+
+ fn index_updated_entries(
+ &self,
+ updated_entries: UpdatedEntriesSet,
+ cx: &AppContext,
+ ) -> impl Future<Output = Result<()>> {
+ let worktree = self.worktree.read(cx).as_local().unwrap().snapshot();
+ let worktree_abs_path = worktree.abs_path().clone();
+ let scan = self.scan_updated_entries(worktree, updated_entries, cx);
+ let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
+ let embed = self.embed_files(chunk.files, cx);
+ let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
+ async move {
+ futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
+ Ok(())
+ }
+ }
+
+ fn scan_entries(&self, worktree: LocalSnapshot, cx: &AppContext) -> ScanEntries {
+ let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
+ let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
+ let db_connection = self.db_connection.clone();
+ let db = self.db;
+ let task = cx.background_executor().spawn(async move {
+ let txn = db_connection
+ .read_txn()
+ .context("failed to create read transaction")?;
+ let mut db_entries = db
+ .iter(&txn)
+ .context("failed to create iterator")?
+ .move_between_keys()
+ .peekable();
+
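+ // Merge-join the worktree's files (ordered by path) with the database
+ // entries (ordered by key): keys with no matching file are coalesced
+ // into deletion ranges, and files whose mtime differs from the saved
+ // one are sent on for re-chunking.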
+ let mut deletion_range: Option<(Bound<&str>, Bound<&str>)> = None;
+ for entry in worktree.files(false, 0) {
+ let entry_db_key = db_key_for_path(&entry.path);
+
+ let mut saved_mtime = None;
+ while let Some(db_entry) = db_entries.peek() {
+ match db_entry {
+ Ok((db_path, db_embedded_file)) => match (*db_path).cmp(&entry_db_key) {
+ Ordering::Less => {
+ if let Some(deletion_range) = deletion_range.as_mut() {
+ deletion_range.1 = Bound::Included(db_path);
+ } else {
+ deletion_range =
+ Some((Bound::Included(db_path), Bound::Included(db_path)));
+ }
+
+ db_entries.next();
+ }
+ Ordering::Equal => {
+ if let Some(deletion_range) = deletion_range.take() {
+ deleted_entry_ranges_tx
+ .send((
+ deletion_range.0.map(ToString::to_string),
+ deletion_range.1.map(ToString::to_string),
+ ))
+ .await?;
+ }
+ saved_mtime = db_embedded_file.mtime;
+ db_entries.next();
+ break;
+ }
+ Ordering::Greater => {
+ break;
+ }
+ },
+ Err(_) => return Err(db_entries.next().unwrap().unwrap_err())?,
+ }
+ }
+
+ if entry.mtime != saved_mtime {
+ updated_entries_tx.send(entry.clone()).await?;
+ }
+ }
+
+ if let Some(db_entry) = db_entries.next() {
+ let (db_path, _) = db_entry?;
+ deleted_entry_ranges_tx
+ .send((Bound::Included(db_path.to_string()), Bound::Unbounded))
+ .await?;
+ }
+
+ Ok(())
+ });
+
+ ScanEntries {
+ updated_entries: updated_entries_rx,
+ deleted_entry_ranges: deleted_entry_ranges_rx,
+ task,
+ }
+ }
+
+ fn scan_updated_entries(
+ &self,
+ worktree: LocalSnapshot,
+ updated_entries: UpdatedEntriesSet,
+ cx: &AppContext,
+ ) -> ScanEntries {
+ let (updated_entries_tx, updated_entries_rx) = channel::bounded(512);
+ let (deleted_entry_ranges_tx, deleted_entry_ranges_rx) = channel::bounded(128);
+ let task = cx.background_executor().spawn(async move {
+ for (path, entry_id, status) in updated_entries.iter() {
+ match status {
+ project::PathChange::Added
+ | project::PathChange::Updated
+ | project::PathChange::AddedOrUpdated => {
+ if let Some(entry) = worktree.entry_for_id(*entry_id) {
+ updated_entries_tx.send(entry.clone()).await?;
+ }
+ }
+ project::PathChange::Removed => {
+ let db_path = db_key_for_path(path);
+ deleted_entry_ranges_tx
+ .send((Bound::Included(db_path.clone()), Bound::Included(db_path)))
+ .await?;
+ }
+ project::PathChange::Loaded => {
+ // Do nothing.
+ }
+ }
+ }
+
+ Ok(())
+ });
+
+ ScanEntries {
+ updated_entries: updated_entries_rx,
+ deleted_entry_ranges: deleted_entry_ranges_rx,
+ task,
+ }
+ }
+
+ fn chunk_files(
+ &self,
+ worktree_abs_path: Arc<Path>,
+ entries: channel::Receiver<Entry>,
+ cx: &AppContext,
+ ) -> ChunkFiles {
+ let language_registry = self.language_registry.clone();
+ let fs = self.fs.clone();
+ let (chunked_files_tx, chunked_files_rx) = channel::bounded(2048);
+ let task = cx.spawn(|cx| async move {
+ cx.background_executor()
+ .scoped(|cx| {
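+ // Chunk files in parallel, one worker per CPU core.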
+ for _ in 0..cx.num_cpus() {
+ cx.spawn(async {
+ while let Ok(entry) = entries.recv().await {
+ let entry_abs_path = worktree_abs_path.join(&entry.path);
+ let Some(text) = fs.load(&entry_abs_path).await.log_err() else {
+ continue;
+ };
+ let language = language_registry
+ .language_for_file_path(&entry.path)
+ .await
+ .ok();
+ let grammar =
+ language.as_ref().and_then(|language| language.grammar());
+ let chunked_file = ChunkedFile {
+ worktree_root: worktree_abs_path.clone(),
+ chunks: chunk_text(&text, grammar),
+ entry,
+ text,
+ };
+
+ if chunked_files_tx.send(chunked_file).await.is_err() {
+ return;
+ }
+ }
+ });
+ }
+ })
+ .await;
+ Ok(())
+ });
+
+ ChunkFiles {
+ files: chunked_files_rx,
+ task,
+ }
+ }
+
+ fn embed_files(
+ &self,
+ chunked_files: channel::Receiver<ChunkedFile>,
+ cx: &AppContext,
+ ) -> EmbedFiles {
+ let embedding_provider = self.embedding_provider.clone();
+ let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
+ let task = cx.background_executor().spawn(async move {
+ let mut chunked_file_batches =
+ chunked_files.chunks_timeout(512, Duration::from_secs(2));
+ while let Some(chunked_files) = chunked_file_batches.next().await {
+ // Flatten the batch of files into a single list of chunks, embed the
+ // chunks in provider-sized batches, then reassemble the resulting
+ // embeddings back onto the files they came from.
+
+ let chunks = chunked_files
+ .iter()
+ .flat_map(|file| {
+ file.chunks.iter().map(|chunk| TextToEmbed {
+ text: &file.text[chunk.range.clone()],
+ digest: chunk.digest,
+ })
+ })
+ .collect::<Vec<_>>();
+
+ let mut embeddings = Vec::new();
+ for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
+ // todo!("add a retry facility")
+ embeddings.extend(embedding_provider.embed(embedding_batch).await?);
+ }
+
+ let mut embeddings = embeddings.into_iter();
+ for chunked_file in chunked_files {
+ let chunk_embeddings = embeddings
+ .by_ref()
+ .take(chunked_file.chunks.len())
+ .collect::<Vec<_>>();
+ let embedded_chunks = chunked_file
+ .chunks
+ .into_iter()
+ .zip(chunk_embeddings)
+ .map(|(chunk, embedding)| EmbeddedChunk { chunk, embedding })
+ .collect();
+ let embedded_file = EmbeddedFile {
+ path: chunked_file.entry.path.clone(),
+ mtime: chunked_file.entry.mtime,
+ chunks: embedded_chunks,
+ };
+
+ embedded_files_tx.send(embedded_file).await?;
+ }
+ }
+ Ok(())
+ });
+
+ EmbedFiles {
+ files: embedded_files_rx,
+ task,
+ }
+ }
+
+ fn persist_embeddings(
+ &self,
+ mut deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
+ embedded_files: channel::Receiver<EmbeddedFile>,
+ cx: &AppContext,
+ ) -> Task<Result<()>> {
+ let db_connection = self.db_connection.clone();
+ let db = self.db;
+ cx.background_executor().spawn(async move {
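+ // Apply all deletions first (the range stream closes once scanning
+ // finishes), then upsert new embeddings in large, batched write
+ // transactions.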
+ while let Some(deletion_range) = deleted_entry_ranges.next().await {
+ let mut txn = db_connection.write_txn()?;
+ let start = deletion_range.0.as_ref().map(|start| start.as_str());
+ let end = deletion_range.1.as_ref().map(|end| end.as_str());
+ log::debug!("deleting embeddings in range {:?}", &(start, end));
+ db.delete_range(&mut txn, &(start, end))?;
+ txn.commit()?;
+ }
+
+ let mut embedded_files = embedded_files.chunks_timeout(4096, Duration::from_secs(2));
+ while let Some(embedded_files) = embedded_files.next().await {
+ let mut txn = db_connection.write_txn()?;
+ for file in embedded_files {
+ log::debug!("saving embedding for file {:?}", file.path);
+ let key = db_key_for_path(&file.path);
+ db.put(&mut txn, &key, &file)?;
+ }
+ txn.commit()?;
+ log::debug!("committed");
+ }
+
+ Ok(())
+ })
+ }
+
+ fn search(
+ &self,
+ query: &str,
+ limit: usize,
+ cx: &AppContext,
+ ) -> Task<Result<Vec<SearchResult>>> {
+ let (chunks_tx, chunks_rx) = channel::bounded(1024);
+
+ let db_connection = self.db_connection.clone();
+ let db = self.db;
+ let scan_chunks = cx.background_executor().spawn({
+ async move {
+ let txn = db_connection
+ .read_txn()
+ .context("failed to create read transaction")?;
+ let db_entries = db.iter(&txn).context("failed to iterate database")?;
+ for db_entry in db_entries {
+ let (_, db_embedded_file) = db_entry?;
+ for chunk in db_embedded_file.chunks {
+ chunks_tx
+ .send((db_embedded_file.path.clone(), chunk))
+ .await?;
+ }
+ }
+ anyhow::Ok(())
+ }
+ });
+
+ let query = query.to_string();
+ let embedding_provider = self.embedding_provider.clone();
+ let worktree = self.worktree.clone();
+ cx.spawn(|cx| async move {
+ #[cfg(debug_assertions)]
+ let embedding_query_start = std::time::Instant::now();
+
+ let mut query_embeddings = embedding_provider
+ .embed(&[TextToEmbed::new(&query)])
+ .await?;
+ let query_embedding = query_embeddings
+ .pop()
+ .ok_or_else(|| anyhow!("no embedding for query"))?;
+ let mut workers = Vec::new();
+ for _ in 0..cx.background_executor().num_cpus() {
+ workers.push(Vec::<SearchResult>::new());
+ }
+
+ #[cfg(debug_assertions)]
+ let search_start = std::time::Instant::now();
+
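+ // One worker per CPU pulls chunks off the channel, keeping a local
+ // top-`limit` list in descending score order (hence the reversed
+ // comparison in `binary_search_by`); the per-worker lists are merged
+ // and truncated below.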
+ cx.background_executor()
+ .scoped(|cx| {
+ for worker_results in workers.iter_mut() {
+ cx.spawn(async {
+ while let Ok((path, embedded_chunk)) = chunks_rx.recv().await {
+ let score = embedded_chunk.embedding.similarity(&query_embedding);
+ let ix = match worker_results.binary_search_by(|probe| {
+ score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
+ }) {
+ Ok(ix) | Err(ix) => ix,
+ };
+ worker_results.insert(
+ ix,
+ SearchResult {
+ worktree: worktree.clone(),
+ path: path.clone(),
+ range: embedded_chunk.chunk.range.clone(),
+ score,
+ },
+ );
+ worker_results.truncate(limit);
+ }
+ });
+ }
+ })
+ .await;
+ scan_chunks.await?;
+
+ let mut search_results = Vec::with_capacity(workers.len() * limit);
+ for worker_results in workers {
+ search_results.extend(worker_results);
+ }
+ search_results
+ .sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
+ search_results.truncate(limit);
+ #[cfg(debug_assertions)]
+ {
+ let search_elapsed = search_start.elapsed();
+ log::debug!(
+ "searched {} entries in {:?}",
+ search_results.len(),
+ search_elapsed
+ );
+ let embedding_query_elapsed = embedding_query_start.elapsed();
+ log::debug!("embedding query took {:?}", embedding_query_elapsed);
+ }
+
+ Ok(search_results)
+ })
+ }
+}
+
+struct ScanEntries {
+ updated_entries: channel::Receiver<Entry>,
+ deleted_entry_ranges: channel::Receiver<(Bound<String>, Bound<String>)>,
+ task: Task<Result<()>>,
+}
+
+struct ChunkFiles {
+ files: channel::Receiver<ChunkedFile>,
+ task: Task<Result<()>>,
+}
+
+struct ChunkedFile {
+ #[allow(dead_code)]
+ pub worktree_root: Arc<Path>,
+ pub entry: Entry,
+ pub text: String,
+ pub chunks: Vec<Chunk>,
+}
+
+struct EmbedFiles {
+ files: channel::Receiver<EmbeddedFile>,
+ task: Task<Result<()>>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct EmbeddedFile {
+ path: Arc<Path>,
+ mtime: Option<SystemTime>,
+ chunks: Vec<EmbeddedChunk>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct EmbeddedChunk {
+ chunk: Chunk,
+ embedding: Embedding,
+}
+
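+/// LMDB keys sort bytewise and NUL sorts before every other byte, so
+/// replacing '/' with '\0' keeps a directory's subtree in one contiguous
+/// key range, which lets deletions be expressed as key ranges.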
+fn db_key_for_path(path: &Arc<Path>) -> String {
+ path.to_string_lossy().replace('/', "\0")
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ use futures::channel::oneshot;
+ use futures::{future::BoxFuture, FutureExt};
+
+ use gpui::{Global, TestAppContext};
+ use language::language_settings::AllLanguageSettings;
+ use project::Project;
+ use settings::SettingsStore;
+ use std::{future, path::Path, sync::Arc};
+
+ fn init_test(cx: &mut TestAppContext) {
+ _ = cx.update(|cx| {
+ let store = SettingsStore::test(cx);
+ cx.set_global(store);
+ language::init(cx);
+ Project::init_settings(cx);
+ SettingsStore::update(cx, |store, cx| {
+ store.update_user_settings::<AllLanguageSettings>(cx, |_| {});
+ });
+ });
+ }
+
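+ /// A deterministic provider for tests: embeds into two dimensions, one
+ /// per target phrase, setting a dimension high when the text contains
+ /// its phrase and low otherwise.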
+ pub struct TestEmbeddingProvider;
+
+ impl EmbeddingProvider for TestEmbeddingProvider {
+ fn embed<'a>(
+ &'a self,
+ texts: &'a [TextToEmbed<'a>],
+ ) -> BoxFuture<'a, Result<Vec<Embedding>>> {
+ let embeddings = texts
+ .iter()
+ .map(|text| {
+ let mut embedding = vec![0f32; 2];
+ // If the text contains "garbage in", give it a high score in the
+ // first dimension.
+ if text.text.contains("garbage in") {
+ embedding[0] = 0.9;
+ } else {
+ embedding[0] = -0.9;
+ }
+
+ if text.text.contains("garbage out") {
+ embedding[1] = 0.9;
+ } else {
+ embedding[1] = -0.9;
+ }
+
+ Embedding::new(embedding)
+ })
+ .collect();
+ future::ready(Ok(embeddings)).boxed()
+ }
+
+ fn batch_size(&self) -> usize {
+ 16
+ }
+ }
+
+ #[gpui::test]
+ async fn test_search(cx: &mut TestAppContext) {
+ cx.executor().allow_parking();
+
+ init_test(cx);
+
+ let temp_dir = tempfile::tempdir().unwrap();
+
+ let mut semantic_index = cx
+ .update(|cx| SemanticIndex::new(temp_dir.path(), Arc::new(TestEmbeddingProvider), cx))
+ .await
+ .unwrap();
+
+ // todo!(): use a fixture
+ let project_path = Path::new("./fixture");
+
+ let project = cx
+ .spawn(|mut cx| async move { Project::example([project_path], &mut cx).await })
+ .await;
+
+ cx.update(|cx| {
+ let language_registry = project.read(cx).languages().clone();
+ let node_runtime = project.read(cx).node_runtime().unwrap().clone();
+ languages::init(language_registry, node_runtime, cx);
+ });
+
+ let project_index = cx.update(|cx| semantic_index.project_index(project.clone(), cx));
+
+ let (tx, rx) = oneshot::channel();
+ let mut tx = Some(tx);
+ let subscription = cx.update(|cx| {
+ cx.subscribe(&project_index, move |_, event, _| {
+ if let Some(tx) = tx.take() {
+ _ = tx.send(*event);
+ }
+ })
+ });
+
+ rx.await.expect("no event emitted");
+ drop(subscription);
+
+ let results = cx
+ .update(|cx| {
+ let project_index = project_index.read(cx);
+ let query = "garbage in, garbage out";
+ project_index.search(query, 4, cx)
+ })
+ .await;
+
+ assert!(results.len() > 1, "should have found multiple results");
+
+ for result in &results {
+ println!("result: {:?}", result.path);
+ println!("score: {:?}", result.score);
+ }
+
+ // Find a result whose score exceeds 0.9.
+ let search_result = results.iter().find(|result| result.score > 0.9).unwrap();
+
+ assert_eq!(search_result.path.to_string_lossy(), "needle.md");
+
+ let content = cx
+ .update(|cx| {
+ let worktree = search_result.worktree.read(cx);
+ let entry_abs_path = worktree.abs_path().join(search_result.path.clone());
+ let fs = project.read(cx).fs().clone();
+ cx.spawn(|_| async move { fs.load(&entry_abs_path).await.unwrap() })
+ })
+ .await;
+
+ let range = search_result.range.clone();
+ let content = content[range.clone()].to_owned();
+
+ assert!(content.contains("garbage in, garbage out"));
+ }
+}
@@ -71,19 +71,28 @@ impl HttpClientWithUrl {
}
impl HttpClient for Arc<HttpClientWithUrl> {
- fn send(&self, req: Request<AsyncBody>) -> BoxFuture<Result<Response<AsyncBody>, Error>> {
+ fn send(
+ &self,
+ req: Request<AsyncBody>,
+ ) -> BoxFuture<'static, Result<Response<AsyncBody>, Error>> {
self.client.send(req)
}
}
impl HttpClient for HttpClientWithUrl {
- fn send(&self, req: Request<AsyncBody>) -> BoxFuture<Result<Response<AsyncBody>, Error>> {
+ fn send(
+ &self,
+ req: Request<AsyncBody>,
+ ) -> BoxFuture<'static, Result<Response<AsyncBody>, Error>> {
self.client.send(req)
}
}
pub trait HttpClient: Send + Sync {
- fn send(&self, req: Request<AsyncBody>) -> BoxFuture<Result<Response<AsyncBody>, Error>>;
+ fn send(
+ &self,
+ req: Request<AsyncBody>,
+ ) -> BoxFuture<'static, Result<Response<AsyncBody>, Error>>;
fn get<'a>(
&'a self,
@@ -135,8 +144,12 @@ pub fn client() -> Arc<dyn HttpClient> {
}
impl HttpClient for isahc::HttpClient {
- fn send(&self, req: Request<AsyncBody>) -> BoxFuture<Result<Response<AsyncBody>, Error>> {
- Box::pin(async move { self.send_async(req).await })
+ fn send(
+ &self,
+ req: Request<AsyncBody>,
+ ) -> BoxFuture<'static, Result<Response<AsyncBody>, Error>> {
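+ // A `'static` future can no longer borrow `self`, so clone the client
+ // and move the clone into the future instead.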
+ let client = self.clone();
+ Box::pin(async move { client.send_async(req).await })
}
}
@@ -196,7 +209,10 @@ impl fmt::Debug for FakeHttpClient {
#[cfg(feature = "test-support")]
impl HttpClient for FakeHttpClient {
- fn send(&self, req: Request<AsyncBody>) -> BoxFuture<Result<Response<AsyncBody>, Error>> {
+ fn send(
+ &self,
+ req: Request<AsyncBody>,
+ ) -> BoxFuture<'static, Result<Response<AsyncBody>, Error>> {
let future = (self.handler)(req);
Box::pin(async move { future.await.map(Into::into) })
}