Detailed changes

The hunks below span `Cargo.lock`, the workspace `Cargo.toml`, the default keymaps and settings, and the deleted sources of the `ai` crate. Together they remove the `ai` and `semantic_index` crates (along with the `search::ActivateSemanticMode` keybinding and the `semantic_index` settings block) and introduce two new provider crates, `google_ai` and `open_ai`; the assistant settings gain a versioned `provider` block that replaces the deprecated `openai_api_url` and `default_open_ai_model` keys.
@@ -85,32 +85,6 @@ dependencies = [
"memchr",
]
-[[package]]
-name = "ai"
-version = "0.1.0"
-dependencies = [
- "anyhow",
- "async-trait",
- "bincode",
- "futures 0.3.28",
- "gpui",
- "isahc",
- "language",
- "log",
- "matrixmultiply",
- "ordered-float 2.10.0",
- "parking_lot",
- "parse_duration",
- "postage",
- "rand 0.8.5",
- "rusqlite",
- "schemars",
- "serde",
- "serde_json",
- "tiktoken-rs",
- "util",
-]
-
[[package]]
name = "alacritty_terminal"
version = "0.22.1-dev"
@@ -339,9 +313,9 @@ dependencies = [
name = "assistant"
version = "0.1.0"
dependencies = [
- "ai",
"anyhow",
"chrono",
+ "client",
"collections",
"ctor",
"editor",
@@ -354,13 +328,14 @@ dependencies = [
"log",
"menu",
"multi_buffer",
+ "open_ai",
"ordered-float 2.10.0",
+ "parking_lot",
"project",
"rand 0.8.5",
"regex",
"schemars",
"search",
- "semantic_index",
"serde",
"serde_json",
"settings",
@@ -1339,7 +1314,7 @@ version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a6773ddc0eafc0e509fb60e48dff7f450f8e674a0686ae8605e8d9901bd5eefa"
dependencies = [
- "num-bigint 0.4.4",
+ "num-bigint",
"num-integer",
"num-traits",
]
@@ -2209,11 +2184,11 @@ dependencies = [
"fs",
"futures 0.3.28",
"git",
+ "google_ai",
"gpui",
"hex",
"indoc",
"language",
- "lazy_static",
"live_kit_client",
"live_kit_server",
"log",
@@ -2222,6 +2197,7 @@ dependencies = [
"nanoid",
"node_runtime",
"notifications",
+ "open_ai",
"parking_lot",
"pretty_assertions",
"project",
@@ -3554,24 +3530,12 @@ dependencies = [
"workspace",
]
-[[package]]
-name = "fallible-iterator"
-version = "0.2.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7"
-
[[package]]
name = "fallible-iterator"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649"
-[[package]]
-name = "fallible-streaming-iterator"
-version = "0.1.9"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a"
-
[[package]]
name = "fancy-regex"
version = "0.11.0"
@@ -4183,7 +4147,7 @@ version = "0.28.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0"
dependencies = [
- "fallible-iterator 0.3.0",
+ "fallible-iterator",
"indexmap 2.0.0",
"stable_deref_trait",
]
@@ -4279,6 +4243,17 @@ dependencies = [
"workspace",
]
+[[package]]
+name = "google_ai"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "futures 0.3.28",
+ "serde",
+ "serde_json",
+ "util",
+]
+
[[package]]
name = "gpu-alloc"
version = "0.6.0"
@@ -5667,16 +5642,6 @@ version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
-[[package]]
-name = "matrixmultiply"
-version = "0.3.8"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7574c1cf36da4798ab73da5b215bbf444f50718207754cb522201d78d1cd0ff2"
-dependencies = [
- "autocfg",
- "rawpointer",
-]
-
[[package]]
name = "maybe-owned"
version = "0.3.4"
@@ -5946,19 +5911,6 @@ dependencies = [
"tempfile",
]
-[[package]]
-name = "ndarray"
-version = "0.15.6"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
-dependencies = [
- "matrixmultiply",
- "num-complex 0.4.4",
- "num-integer",
- "num-traits",
- "rawpointer",
-]
-
[[package]]
name = "ndk"
version = "0.7.0"
@@ -6111,45 +6063,20 @@ dependencies = [
"winapi",
]
-[[package]]
-name = "num"
-version = "0.2.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b8536030f9fea7127f841b45bb6243b27255787fb4eb83958aa1ef9d2fdc0c36"
-dependencies = [
- "num-bigint 0.2.6",
- "num-complex 0.2.4",
- "num-integer",
- "num-iter",
- "num-rational 0.2.4",
- "num-traits",
-]
-
[[package]]
name = "num"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b05180d69e3da0e530ba2a1dae5110317e49e3b7f3d41be227dc5f92e49ee7af"
dependencies = [
- "num-bigint 0.4.4",
- "num-complex 0.4.4",
+ "num-bigint",
+ "num-complex",
"num-integer",
"num-iter",
"num-rational 0.4.1",
"num-traits",
]
-[[package]]
-name = "num-bigint"
-version = "0.2.6"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "090c7f9998ee0ff65aa5b723e4009f7b217707f1fb5ea551329cc4d6231fb304"
-dependencies = [
- "autocfg",
- "num-integer",
- "num-traits",
-]
-
[[package]]
name = "num-bigint"
version = "0.4.4"
@@ -6196,16 +6123,6 @@ dependencies = [
"zeroize",
]
-[[package]]
-name = "num-complex"
-version = "0.2.4"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b6b19411a9719e753aff12e5187b74d60d3dc449ec3f4dc21e3989c3f554bc95"
-dependencies = [
- "autocfg",
- "num-traits",
-]
-
[[package]]
name = "num-complex"
version = "0.4.4"
@@ -6247,18 +6164,6 @@ dependencies = [
"num-traits",
]
-[[package]]
-name = "num-rational"
-version = "0.2.4"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5c000134b5dbf44adc5cb772486d335293351644b801551abe8f75c84cfa4aef"
-dependencies = [
- "autocfg",
- "num-bigint 0.2.6",
- "num-integer",
- "num-traits",
-]
-
[[package]]
name = "num-rational"
version = "0.3.2"
@@ -6277,7 +6182,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0"
dependencies = [
"autocfg",
- "num-bigint 0.4.4",
+ "num-bigint",
"num-integer",
"num-traits",
]
@@ -6436,7 +6341,7 @@ dependencies = [
"futures-util",
"hkdf",
"hmac 0.12.1",
- "num 0.4.1",
+ "num",
"num-bigint-dig 0.8.4",
"pbkdf2 0.12.2",
"rand 0.8.5",
@@ -6464,6 +6369,18 @@ dependencies = [
"pathdiff",
]
+[[package]]
+name = "open_ai"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "futures 0.3.28",
+ "schemars",
+ "serde",
+ "serde_json",
+ "util",
+]
+
[[package]]
name = "openssl"
version = "0.10.57"
@@ -6679,17 +6596,6 @@ dependencies = [
"windows-targets 0.48.5",
]
-[[package]]
-name = "parse_duration"
-version = "2.1.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7037e5e93e0172a5a96874380bf73bc6ecef022e26fa25f2be26864d6b3ba95d"
-dependencies = [
- "lazy_static",
- "num 0.2.1",
- "regex",
-]
-
[[package]]
name = "password-hash"
version = "0.2.1"
@@ -7471,12 +7377,6 @@ dependencies = [
"raw-window-handle 0.5.2",
]
-[[package]]
-name = "rawpointer"
-version = "0.2.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
-
[[package]]
name = "rayon"
version = "1.8.0"
@@ -7935,20 +7835,6 @@ dependencies = [
"zeroize",
]
-[[package]]
-name = "rusqlite"
-version = "0.29.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "549b9d036d571d42e6e85d1c1425e2ac83491075078ca9a15be021c56b1641f2"
-dependencies = [
- "bitflags 2.4.2",
- "fallible-iterator 0.2.0",
- "fallible-streaming-iterator",
- "hashlink",
- "libsqlite3-sys",
- "smallvec",
-]
-
[[package]]
name = "rust-embed"
version = "8.2.0"
@@ -8378,7 +8264,6 @@ dependencies = [
"language",
"menu",
"project",
- "semantic_index",
"serde",
"serde_json",
"settings",
@@ -8434,52 +8319,6 @@ version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "58bf37232d3bb9a2c4e641ca2a11d83b5062066f88df7fed36c28772046d65ba"
-[[package]]
-name = "semantic_index"
-version = "0.1.0"
-dependencies = [
- "ai",
- "anyhow",
- "collections",
- "ctor",
- "env_logger",
- "futures 0.3.28",
- "gpui",
- "language",
- "lazy_static",
- "log",
- "ndarray",
- "ordered-float 2.10.0",
- "parking_lot",
- "postage",
- "pretty_assertions",
- "project",
- "rand 0.8.5",
- "release_channel",
- "rpc",
- "rusqlite",
- "schemars",
- "serde",
- "serde_json",
- "settings",
- "sha1",
- "smol",
- "tempfile",
- "tree-sitter",
- "tree-sitter-cpp",
- "tree-sitter-elixir",
- "tree-sitter-json 0.20.0",
- "tree-sitter-lua",
- "tree-sitter-php",
- "tree-sitter-ruby",
- "tree-sitter-rust",
- "tree-sitter-toml",
- "tree-sitter-typescript",
- "unindent",
- "util",
- "workspace",
-]
-
[[package]]
name = "semver"
version = "1.0.18"
@@ -8766,7 +8605,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8eb4ea60fb301dc81dfc113df680571045d375ab7345d171c5dc7d7e13107a80"
dependencies = [
"chrono",
- "num-bigint 0.4.4",
+ "num-bigint",
"num-traits",
"thiserror",
]
@@ -9197,7 +9036,7 @@ dependencies = [
"log",
"md-5",
"memchr",
- "num-bigint 0.4.4",
+ "num-bigint",
"once_cell",
"rand 0.8.5",
"rust_decimal",
@@ -12729,7 +12568,6 @@ dependencies = [
"release_channel",
"rope",
"search",
- "semantic_index",
"serde",
"serde_json",
"settings",
@@ -1,7 +1,6 @@
[workspace]
members = [
"crates/activity_indicator",
- "crates/ai",
"crates/assets",
"crates/assistant",
"crates/audio",
@@ -34,6 +33,7 @@ members = [
"crates/fuzzy",
"crates/git",
"crates/go_to_line",
+ "crates/google_ai",
"crates/gpui",
"crates/gpui_macros",
"crates/image_viewer",
@@ -52,6 +52,7 @@ members = [
"crates/multi_buffer",
"crates/node_runtime",
"crates/notifications",
+ "crates/open_ai",
"crates/outline",
"crates/picker",
"crates/prettier",
@@ -69,7 +70,6 @@ members = [
"crates/task",
"crates/tasks_ui",
"crates/search",
- "crates/semantic_index",
"crates/settings",
"crates/snippet",
"crates/sqlez",
@@ -138,6 +138,7 @@ fsevent = { path = "crates/fsevent" }
fuzzy = { path = "crates/fuzzy" }
git = { path = "crates/git" }
go_to_line = { path = "crates/go_to_line" }
+google_ai = { path = "crates/google_ai" }
gpui = { path = "crates/gpui" }
gpui_macros = { path = "crates/gpui_macros" }
install_cli = { path = "crates/install_cli" }
@@ -156,6 +157,7 @@ menu = { path = "crates/menu" }
multi_buffer = { path = "crates/multi_buffer" }
node_runtime = { path = "crates/node_runtime" }
notifications = { path = "crates/notifications" }
+open_ai = { path = "crates/open_ai" }
outline = { path = "crates/outline" }
picker = { path = "crates/picker" }
plugin = { path = "crates/plugin" }
@@ -174,7 +176,6 @@ rpc = { path = "crates/rpc" }
task = { path = "crates/task" }
tasks_ui = { path = "crates/tasks_ui" }
search = { path = "crates/search" }
-semantic_index = { path = "crates/semantic_index" }
settings = { path = "crates/settings" }
snippet = { path = "crates/snippet" }
sqlez = { path = "crates/sqlez" }
@@ -251,7 +251,6 @@
"alt-tab": "search::CycleMode",
"cmd-shift-h": "search::ToggleReplace",
"alt-cmd-g": "search::ActivateRegexMode",
- "alt-cmd-s": "search::ActivateSemanticMode",
"alt-cmd-x": "search::ActivateTextMode"
}
},
@@ -276,7 +275,6 @@
"alt-tab": "search::CycleMode",
"cmd-shift-h": "search::ToggleReplace",
"alt-cmd-g": "search::ActivateRegexMode",
- "alt-cmd-s": "search::ActivateSemanticMode",
"alt-cmd-x": "search::ActivateTextMode"
}
},
@@ -302,7 +300,6 @@
"alt-tab": "search::CycleMode",
"alt-cmd-f": "project_search::ToggleFilters",
"alt-cmd-g": "search::ActivateRegexMode",
- "alt-cmd-s": "search::ActivateSemanticMode",
"alt-cmd-x": "search::ActivateTextMode"
}
},
@@ -237,6 +237,8 @@
"default_width": 380
},
"assistant": {
+ // Version of this setting.
+ "version": "1",
// Whether to show the assistant panel button in the status bar.
"button": true,
// Where to dock the assistant panel. Can be 'left', 'right' or 'bottom'.
@@ -245,28 +247,16 @@
"default_width": 640,
// Default height when the assistant is docked to the bottom.
"default_height": 320,
- // Deprecated: Please use `provider.api_url` instead.
- // The default OpenAI API endpoint to use when starting new conversations.
- "openai_api_url": "https://api.openai.com/v1",
- // Deprecated: Please use `provider.default_model` instead.
- // The default OpenAI model to use when starting new conversations. This
- // setting can take three values:
- //
- // 1. "gpt-3.5-turbo-0613""
- // 2. "gpt-4-0613""
- // 3. "gpt-4-1106-preview"
- "default_open_ai_model": "gpt-4-1106-preview",
+ // AI provider.
"provider": {
- "type": "openai",
- // The default OpenAI API endpoint to use when starting new conversations.
- "api_url": "https://api.openai.com/v1",
- // The default OpenAI model to use when starting new conversations. This
+ "name": "openai",
+ // The default model to use when starting new conversations. This
// setting can take three values:
//
- // 1. "gpt-3.5-turbo-0613""
- // 2. "gpt-4-0613""
- // 3. "gpt-4-1106-preview"
- "default_model": "gpt-4-1106-preview"
+ // 1. "gpt-3.5-turbo"
+ // 2. "gpt-4"
+ // 3. "gpt-4-turbo-preview"
+ "default_model": "gpt-4-turbo-preview"
}
},
// Whether the screen sharing icon is shown in the os status bar.
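For reference, the migrated assistant block reads as follows once the hunk above is applied. This is a sketch assembled only from keys and comments visible in this diff; the `dock` value is illustrative (the comment in the hunk says it can be 'left', 'right' or 'bottom'):

```jsonc
"assistant": {
  // Version of this setting.
  "version": "1",
  // Whether to show the assistant panel button in the status bar.
  "button": true,
  // Where to dock the assistant panel. Can be 'left', 'right' or 'bottom'.
  "dock": "right",
  // Default width when the assistant is docked to the left or right.
  "default_width": 640,
  // Default height when the assistant is docked to the bottom.
  "default_height": 320,
  // AI provider.
  "provider": {
    "name": "openai",
    // The default model to use when starting new conversations. This
    // setting can take three values:
    //
    // 1. "gpt-3.5-turbo"
    // 2. "gpt-4"
    // 3. "gpt-4-turbo-preview"
    "default_model": "gpt-4-turbo-preview"
  }
}
```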
@@ -505,10 +495,6 @@
// Existing terminals will not pick up this change until they are recreated.
// "max_scroll_history_lines": 10000,
},
- // Difference settings for semantic_index
- "semantic_index": {
- "enabled": true
- },
// Settings specific to our elixir integration
"elixir": {
// Change the LSP zed uses for elixir.
@@ -1,41 +0,0 @@
-[package]
-name = "ai"
-version = "0.1.0"
-edition = "2021"
-publish = false
-license = "GPL-3.0-or-later"
-
-[lints]
-workspace = true
-
-[lib]
-path = "src/ai.rs"
-doctest = false
-
-[features]
-test-support = []
-
-[dependencies]
-anyhow.workspace = true
-async-trait.workspace = true
-bincode = "1.3.3"
-futures.workspace = true
-gpui.workspace = true
-isahc.workspace = true
-language.workspace = true
-log.workspace = true
-matrixmultiply = "0.3.7"
-ordered-float.workspace = true
-parking_lot.workspace = true
-parse_duration = "2.1.1"
-postage.workspace = true
-rand.workspace = true
-rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
-schemars.workspace = true
-serde.workspace = true
-serde_json.workspace = true
-tiktoken-rs.workspace = true
-util.workspace = true
-
-[dev-dependencies]
-gpui = { workspace = true, features = ["test-support"] }
@@ -1 +0,0 @@
-../../LICENSE-GPL
@@ -1,8 +0,0 @@
-pub mod auth;
-pub mod completion;
-pub mod embedding;
-pub mod models;
-pub mod prompts;
-pub mod providers;
-#[cfg(any(test, feature = "test-support"))]
-pub mod test;
@@ -1,23 +0,0 @@
-use futures::future::BoxFuture;
-use gpui::AppContext;
-
-#[derive(Clone, Debug)]
-pub enum ProviderCredential {
- Credentials { api_key: String },
- NoCredentials,
- NotNeeded,
-}
-
-pub trait CredentialProvider: Send + Sync {
- fn has_credentials(&self) -> bool;
- #[must_use]
- fn retrieve_credentials(&self, cx: &mut AppContext) -> BoxFuture<ProviderCredential>;
- #[must_use]
- fn save_credentials(
- &self,
- cx: &mut AppContext,
- credential: ProviderCredential,
- ) -> BoxFuture<()>;
- #[must_use]
- fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()>;
-}
@@ -1,23 +0,0 @@
-use anyhow::Result;
-use futures::{future::BoxFuture, stream::BoxStream};
-
-use crate::{auth::CredentialProvider, models::LanguageModel};
-
-pub trait CompletionRequest: Send + Sync {
- fn data(&self) -> serde_json::Result<String>;
-}
-
-pub trait CompletionProvider: CredentialProvider {
- fn base_model(&self) -> Box<dyn LanguageModel>;
- fn complete(
- &self,
- prompt: Box<dyn CompletionRequest>,
- ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
- fn box_clone(&self) -> Box<dyn CompletionProvider>;
-}
-
-impl Clone for Box<dyn CompletionProvider> {
- fn clone(&self) -> Box<dyn CompletionProvider> {
- self.box_clone()
- }
-}
@@ -1,121 +0,0 @@
-use std::time::Instant;
-
-use anyhow::Result;
-use async_trait::async_trait;
-use ordered_float::OrderedFloat;
-use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
-use rusqlite::ToSql;
-
-use crate::auth::CredentialProvider;
-use crate::models::LanguageModel;
-
-#[derive(Debug, PartialEq, Clone)]
-pub struct Embedding(pub Vec<f32>);
-
-// This is needed for semantic index functionality
-// Unfortunately it has to live wherever the "Embedding" struct is created.
-// Keeping this in here though, introduces a 'rusqlite' dependency into AI
-// which is less than ideal
-impl FromSql for Embedding {
- fn column_result(value: ValueRef) -> FromSqlResult<Self> {
- let bytes = value.as_blob()?;
- let embedding =
- bincode::deserialize(bytes).map_err(|err| rusqlite::types::FromSqlError::Other(err))?;
- Ok(Embedding(embedding))
- }
-}
-
-impl ToSql for Embedding {
- fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
- let bytes = bincode::serialize(&self.0)
- .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
- Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
- }
-}
-impl From<Vec<f32>> for Embedding {
- fn from(value: Vec<f32>) -> Self {
- Embedding(value)
- }
-}
-
-impl Embedding {
- pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
- let len = self.0.len();
- assert_eq!(len, other.0.len());
-
- let mut result = 0.0;
- unsafe {
- matrixmultiply::sgemm(
- 1,
- len,
- 1,
- 1.0,
- self.0.as_ptr(),
- len as isize,
- 1,
- other.0.as_ptr(),
- 1,
- len as isize,
- 0.0,
- &mut result as *mut f32,
- 1,
- 1,
- );
- }
- OrderedFloat(result)
- }
-}
-
-#[async_trait]
-pub trait EmbeddingProvider: CredentialProvider {
- fn base_model(&self) -> Box<dyn LanguageModel>;
- async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
- fn max_tokens_per_batch(&self) -> usize;
- fn rate_limit_expiration(&self) -> Option<Instant>;
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use rand::prelude::*;
-
- #[gpui::test]
- fn test_similarity(mut rng: StdRng) {
- assert_eq!(
- Embedding::from(vec![1., 0., 0., 0., 0.])
- .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
- 0.
- );
- assert_eq!(
- Embedding::from(vec![2., 0., 0., 0., 0.])
- .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
- 6.
- );
-
- for _ in 0..100 {
- let size = 1536;
- let mut a = vec![0.; size];
- let mut b = vec![0.; size];
- for (a, b) in a.iter_mut().zip(b.iter_mut()) {
- *a = rng.gen();
- *b = rng.gen();
- }
- let a = Embedding::from(a);
- let b = Embedding::from(b);
-
- assert_eq!(
- round_to_decimals(a.similarity(&b), 1),
- round_to_decimals(reference_dot(&a.0, &b.0), 1)
- );
- }
-
- fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
- let factor = 10.0_f32.powi(decimal_places);
- (n * factor).round() / factor
- }
-
- fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
- OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
- }
- }
-}
@@ -1,16 +0,0 @@
-pub enum TruncationDirection {
- Start,
- End,
-}
-
-pub trait LanguageModel {
- fn name(&self) -> String;
- fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
- fn truncate(
- &self,
- content: &str,
- length: usize,
- direction: TruncationDirection,
- ) -> anyhow::Result<String>;
- fn capacity(&self) -> anyhow::Result<usize>;
-}
@@ -1,337 +0,0 @@
-use std::cmp::Reverse;
-use std::ops::Range;
-use std::sync::Arc;
-
-use language::BufferSnapshot;
-use util::ResultExt;
-
-use crate::models::LanguageModel;
-use crate::prompts::repository_context::PromptCodeSnippet;
-
-pub(crate) enum PromptFileType {
- Text,
- Code,
-}
-
-// TODO: Set this up to manage for defaults well
-pub struct PromptArguments {
- pub model: Arc<dyn LanguageModel>,
- pub user_prompt: Option<String>,
- pub language_name: Option<String>,
- pub project_name: Option<String>,
- pub snippets: Vec<PromptCodeSnippet>,
- pub reserved_tokens: usize,
- pub buffer: Option<BufferSnapshot>,
- pub selected_range: Option<Range<usize>>,
-}
-
-impl PromptArguments {
- pub(crate) fn get_file_type(&self) -> PromptFileType {
- if self
- .language_name
- .as_ref()
- .map(|name| !["Markdown", "Plain Text"].contains(&name.as_str()))
- .unwrap_or(true)
- {
- PromptFileType::Code
- } else {
- PromptFileType::Text
- }
- }
-}
-
-pub trait PromptTemplate {
- fn generate(
- &self,
- args: &PromptArguments,
- max_token_length: Option<usize>,
- ) -> anyhow::Result<(String, usize)>;
-}
-
-#[repr(i8)]
-#[derive(PartialEq, Eq)]
-pub enum PromptPriority {
- /// Ignores truncation.
- Mandatory,
- /// Truncates based on priority.
- Ordered { order: usize },
-}
-
-impl PartialOrd for PromptPriority {
- fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
- Some(self.cmp(other))
- }
-}
-
-impl Ord for PromptPriority {
- fn cmp(&self, other: &Self) -> std::cmp::Ordering {
- match (self, other) {
- (Self::Mandatory, Self::Mandatory) => std::cmp::Ordering::Equal,
- (Self::Mandatory, Self::Ordered { .. }) => std::cmp::Ordering::Greater,
- (Self::Ordered { .. }, Self::Mandatory) => std::cmp::Ordering::Less,
- (Self::Ordered { order: a }, Self::Ordered { order: b }) => b.cmp(a),
- }
- }
-}
-
-pub struct PromptChain {
- args: PromptArguments,
- templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
-}
-
-impl PromptChain {
- pub fn new(
- args: PromptArguments,
- templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
- ) -> Self {
- PromptChain { args, templates }
- }
-
- pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> {
- // Argsort based on Prompt Priority
- let separator = "\n";
- let separator_tokens = self.args.model.count_tokens(separator)?;
- let mut sorted_indices = (0..self.templates.len()).collect::<Vec<_>>();
- sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));
-
- let mut tokens_outstanding = if truncate {
- Some(self.args.model.capacity()? - self.args.reserved_tokens)
- } else {
- None
- };
-
- let mut prompts = vec!["".to_string(); sorted_indices.len()];
- for idx in sorted_indices {
- let (_, template) = &self.templates[idx];
-
- if let Some((template_prompt, prompt_token_count)) =
- template.generate(&self.args, tokens_outstanding).log_err()
- {
- if template_prompt != "" {
- prompts[idx] = template_prompt;
-
- if let Some(remaining_tokens) = tokens_outstanding {
- let new_tokens = prompt_token_count + separator_tokens;
- tokens_outstanding = if remaining_tokens > new_tokens {
- Some(remaining_tokens - new_tokens)
- } else {
- Some(0)
- };
- }
- }
- }
- }
-
- prompts.retain(|x| x != "");
-
- let full_prompt = prompts.join(separator);
- let total_token_count = self.args.model.count_tokens(&full_prompt)?;
- anyhow::Ok((prompts.join(separator), total_token_count))
- }
-}
-
-#[cfg(test)]
-pub(crate) mod tests {
- use crate::models::TruncationDirection;
- use crate::test::FakeLanguageModel;
-
- use super::*;
-
- #[test]
- pub fn test_prompt_chain() {
- struct TestPromptTemplate {}
- impl PromptTemplate for TestPromptTemplate {
- fn generate(
- &self,
- args: &PromptArguments,
- max_token_length: Option<usize>,
- ) -> anyhow::Result<(String, usize)> {
- let mut content = "This is a test prompt template".to_string();
-
- let mut token_count = args.model.count_tokens(&content)?;
- if let Some(max_token_length) = max_token_length {
- if token_count > max_token_length {
- content = args.model.truncate(
- &content,
- max_token_length,
- TruncationDirection::End,
- )?;
- token_count = max_token_length;
- }
- }
-
- anyhow::Ok((content, token_count))
- }
- }
-
- struct TestLowPriorityTemplate {}
- impl PromptTemplate for TestLowPriorityTemplate {
- fn generate(
- &self,
- args: &PromptArguments,
- max_token_length: Option<usize>,
- ) -> anyhow::Result<(String, usize)> {
- let mut content = "This is a low priority test prompt template".to_string();
-
- let mut token_count = args.model.count_tokens(&content)?;
- if let Some(max_token_length) = max_token_length {
- if token_count > max_token_length {
- content = args.model.truncate(
- &content,
- max_token_length,
- TruncationDirection::End,
- )?;
- token_count = max_token_length;
- }
- }
-
- anyhow::Ok((content, token_count))
- }
- }
-
- let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
- let args = PromptArguments {
- model: model.clone(),
- language_name: None,
- project_name: None,
- snippets: Vec::new(),
- reserved_tokens: 0,
- buffer: None,
- selected_range: None,
- user_prompt: None,
- };
-
- let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
- (
- PromptPriority::Ordered { order: 0 },
- Box::new(TestPromptTemplate {}),
- ),
- (
- PromptPriority::Ordered { order: 1 },
- Box::new(TestLowPriorityTemplate {}),
- ),
- ];
- let chain = PromptChain::new(args, templates);
-
- let (prompt, token_count) = chain.generate(false).unwrap();
-
- assert_eq!(
- prompt,
- "This is a test prompt template\nThis is a low priority test prompt template"
- .to_string()
- );
-
- assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
-
- // Testing with Truncation Off
- // Should ignore capacity and return all prompts
- let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 20 });
- let args = PromptArguments {
- model: model.clone(),
- language_name: None,
- project_name: None,
- snippets: Vec::new(),
- reserved_tokens: 0,
- buffer: None,
- selected_range: None,
- user_prompt: None,
- };
-
- let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
- (
- PromptPriority::Ordered { order: 0 },
- Box::new(TestPromptTemplate {}),
- ),
- (
- PromptPriority::Ordered { order: 1 },
- Box::new(TestLowPriorityTemplate {}),
- ),
- ];
- let chain = PromptChain::new(args, templates);
-
- let (prompt, token_count) = chain.generate(false).unwrap();
-
- assert_eq!(
- prompt,
- "This is a test prompt template\nThis is a low priority test prompt template"
- .to_string()
- );
-
- assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
-
- // Testing with Truncation Off
- // Should ignore capacity and return all prompts
- let capacity = 20;
- let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
- let args = PromptArguments {
- model: model.clone(),
- language_name: None,
- project_name: None,
- snippets: Vec::new(),
- reserved_tokens: 0,
- buffer: None,
- selected_range: None,
- user_prompt: None,
- };
-
- let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
- (
- PromptPriority::Ordered { order: 0 },
- Box::new(TestPromptTemplate {}),
- ),
- (
- PromptPriority::Ordered { order: 1 },
- Box::new(TestLowPriorityTemplate {}),
- ),
- (
- PromptPriority::Ordered { order: 2 },
- Box::new(TestLowPriorityTemplate {}),
- ),
- ];
- let chain = PromptChain::new(args, templates);
-
- let (prompt, token_count) = chain.generate(true).unwrap();
-
- assert_eq!(prompt, "This is a test promp".to_string());
- assert_eq!(token_count, capacity);
-
- // Change Ordering of Prompts Based on Priority
- let capacity = 120;
- let reserved_tokens = 10;
- let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
- let args = PromptArguments {
- model: model.clone(),
- language_name: None,
- project_name: None,
- snippets: Vec::new(),
- reserved_tokens,
- buffer: None,
- selected_range: None,
- user_prompt: None,
- };
- let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
- (
- PromptPriority::Mandatory,
- Box::new(TestLowPriorityTemplate {}),
- ),
- (
- PromptPriority::Ordered { order: 0 },
- Box::new(TestPromptTemplate {}),
- ),
- (
- PromptPriority::Ordered { order: 1 },
- Box::new(TestLowPriorityTemplate {}),
- ),
- ];
- let chain = PromptChain::new(args, templates);
-
- let (prompt, token_count) = chain.generate(true).unwrap();
-
- assert_eq!(
- prompt,
- "This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt "
- .to_string()
- );
- assert_eq!(token_count, capacity - reserved_tokens);
- }
-}
@@ -1,164 +0,0 @@
-use anyhow::anyhow;
-use language::BufferSnapshot;
-use language::ToOffset;
-
-use crate::models::LanguageModel;
-use crate::models::TruncationDirection;
-use crate::prompts::base::PromptArguments;
-use crate::prompts::base::PromptTemplate;
-use std::fmt::Write;
-use std::ops::Range;
-use std::sync::Arc;
-
-fn retrieve_context(
- buffer: &BufferSnapshot,
- selected_range: &Option<Range<usize>>,
- model: Arc<dyn LanguageModel>,
- max_token_count: Option<usize>,
-) -> anyhow::Result<(String, usize, bool)> {
- let mut prompt = String::new();
- let mut truncated = false;
- if let Some(selected_range) = selected_range {
- let start = selected_range.start.to_offset(buffer);
- let end = selected_range.end.to_offset(buffer);
-
- let start_window = buffer.text_for_range(0..start).collect::<String>();
-
- let mut selected_window = String::new();
- if start == end {
- write!(selected_window, "<|START|>").unwrap();
- } else {
- write!(selected_window, "<|START|").unwrap();
- }
-
- write!(
- selected_window,
- "{}",
- buffer.text_for_range(start..end).collect::<String>()
- )
- .unwrap();
-
- if start != end {
- write!(selected_window, "|END|>").unwrap();
- }
-
- let end_window = buffer.text_for_range(end..buffer.len()).collect::<String>();
-
- if let Some(max_token_count) = max_token_count {
- let selected_tokens = model.count_tokens(&selected_window)?;
- if selected_tokens > max_token_count {
- return Err(anyhow!(
- "selected range is greater than model context window, truncation not possible"
- ));
- };
-
- let mut remaining_tokens = max_token_count - selected_tokens;
- let start_window_tokens = model.count_tokens(&start_window)?;
- let end_window_tokens = model.count_tokens(&end_window)?;
- let outside_tokens = start_window_tokens + end_window_tokens;
- if outside_tokens > remaining_tokens {
- let (start_goal_tokens, end_goal_tokens) =
- if start_window_tokens < end_window_tokens {
- let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens);
- remaining_tokens -= start_goal_tokens;
- let end_goal_tokens = remaining_tokens.min(end_window_tokens);
- (start_goal_tokens, end_goal_tokens)
- } else {
- let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens);
- remaining_tokens -= end_goal_tokens;
- let start_goal_tokens = remaining_tokens.min(start_window_tokens);
- (start_goal_tokens, end_goal_tokens)
- };
-
- let truncated_start_window =
- model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?;
- let truncated_end_window =
- model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?;
- writeln!(
- prompt,
- "{truncated_start_window}{selected_window}{truncated_end_window}"
- )
- .unwrap();
- truncated = true;
- } else {
- writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
- }
- } else {
- // If we dont have a selected range, include entire file.
- writeln!(prompt, "{}", &buffer.text()).unwrap();
-
- // Dumb truncation strategy
- if let Some(max_token_count) = max_token_count {
- if model.count_tokens(&prompt)? > max_token_count {
- truncated = true;
- prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?;
- }
- }
- }
- }
-
- let token_count = model.count_tokens(&prompt)?;
- anyhow::Ok((prompt, token_count, truncated))
-}
-
-pub struct FileContext {}
-
-impl PromptTemplate for FileContext {
- fn generate(
- &self,
- args: &PromptArguments,
- max_token_length: Option<usize>,
- ) -> anyhow::Result<(String, usize)> {
- if let Some(buffer) = &args.buffer {
- let mut prompt = String::new();
- // Add Initial Preamble
- // TODO: Do we want to add the path in here?
- writeln!(
- prompt,
- "The file you are currently working on has the following content:"
- )
- .unwrap();
-
- let language_name = args
- .language_name
- .clone()
- .unwrap_or("".to_string())
- .to_lowercase();
-
- let (context, _, truncated) = retrieve_context(
- buffer,
- &args.selected_range,
- args.model.clone(),
- max_token_length,
- )?;
- writeln!(prompt, "```{language_name}\n{context}\n```").unwrap();
-
- if truncated {
- writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap();
- }
-
- if let Some(selected_range) = &args.selected_range {
- let start = selected_range.start.to_offset(buffer);
- let end = selected_range.end.to_offset(buffer);
-
- if start == end {
- writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap();
- } else {
- writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
- }
- }
-
- // Really dumb truncation strategy
- if let Some(max_tokens) = max_token_length {
- prompt = args
- .model
- .truncate(&prompt, max_tokens, TruncationDirection::End)?;
- }
-
- let token_count = args.model.count_tokens(&prompt)?;
- anyhow::Ok((prompt, token_count))
- } else {
- Err(anyhow!("no buffer provided to retrieve file context from"))
- }
- }
-}
@@ -1,99 +0,0 @@
-use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
-use anyhow::anyhow;
-use std::fmt::Write;
-
-pub fn capitalize(s: &str) -> String {
- let mut c = s.chars();
- match c.next() {
- None => String::new(),
- Some(f) => f.to_uppercase().collect::<String>() + c.as_str(),
- }
-}
-
-pub struct GenerateInlineContent {}
-
-impl PromptTemplate for GenerateInlineContent {
- fn generate(
- &self,
- args: &PromptArguments,
- max_token_length: Option<usize>,
- ) -> anyhow::Result<(String, usize)> {
- let Some(user_prompt) = &args.user_prompt else {
- return Err(anyhow!("user prompt not provided"));
- };
-
- let file_type = args.get_file_type();
- let content_type = match &file_type {
- PromptFileType::Code => "code",
- PromptFileType::Text => "text",
- };
-
- let mut prompt = String::new();
-
- if let Some(selected_range) = &args.selected_range {
- if selected_range.start == selected_range.end {
- writeln!(
- prompt,
- "Assume the cursor is located where the `<|START|>` span is."
- )
- .unwrap();
- writeln!(
- prompt,
- "{} can't be replaced, so assume your answer will be inserted at the cursor.",
- capitalize(content_type)
- )
- .unwrap();
- writeln!(
- prompt,
- "Generate {content_type} based on the users prompt: {user_prompt}",
- )
- .unwrap();
- } else {
- writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap();
- writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap();
- writeln!(prompt, "Double check that you only return code and not the '<|START|' and '|END|'> spans").unwrap();
- }
- } else {
- writeln!(
- prompt,
- "Generate {content_type} based on the users prompt: {user_prompt}"
- )
- .unwrap();
- }
-
- if let Some(language_name) = &args.language_name {
- writeln!(
- prompt,
- "Your answer MUST always and only be valid {}.",
- language_name
- )
- .unwrap();
- }
- writeln!(prompt, "Never make remarks about the output.").unwrap();
- writeln!(
- prompt,
- "Do not return anything else, except the generated {content_type}."
- )
- .unwrap();
-
- match file_type {
- PromptFileType::Code => {
- // writeln!(prompt, "Always wrap your code in a Markdown block.").unwrap();
- }
- _ => {}
- }
-
- // Really dumb truncation strategy
- if let Some(max_tokens) = max_token_length {
- prompt = args.model.truncate(
- &prompt,
- max_tokens,
- crate::models::TruncationDirection::End,
- )?;
- }
-
- let token_count = args.model.count_tokens(&prompt)?;
-
- anyhow::Ok((prompt, token_count))
- }
-}
@@ -1,5 +0,0 @@
-pub mod base;
-pub mod file_context;
-pub mod generate;
-pub mod preamble;
-pub mod repository_context;
@@ -1,52 +0,0 @@
-use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
-use std::fmt::Write;
-
-pub struct EngineerPreamble {}
-
-impl PromptTemplate for EngineerPreamble {
- fn generate(
- &self,
- args: &PromptArguments,
- max_token_length: Option<usize>,
- ) -> anyhow::Result<(String, usize)> {
- let mut prompts = Vec::new();
-
- match args.get_file_type() {
- PromptFileType::Code => {
- prompts.push(format!(
- "You are an expert {}engineer.",
- args.language_name.clone().unwrap_or("".to_string()) + " "
- ));
- }
- PromptFileType::Text => {
- prompts.push("You are an expert engineer.".to_string());
- }
- }
-
- if let Some(project_name) = args.project_name.clone() {
- prompts.push(format!(
- "You are currently working inside the '{project_name}' project in code editor Zed."
- ));
- }
-
- if let Some(mut remaining_tokens) = max_token_length {
- let mut prompt = String::new();
- let mut total_count = 0;
- for prompt_piece in prompts {
- let prompt_token_count =
- args.model.count_tokens(&prompt_piece)? + args.model.count_tokens("\n")?;
- if remaining_tokens > prompt_token_count {
- writeln!(prompt, "{prompt_piece}").unwrap();
- remaining_tokens -= prompt_token_count;
- total_count += prompt_token_count;
- }
- }
-
- anyhow::Ok((prompt, total_count))
- } else {
- let prompt = prompts.join("\n");
- let token_count = args.model.count_tokens(&prompt)?;
- anyhow::Ok((prompt, token_count))
- }
- }
-}
@@ -1,96 +0,0 @@
-use crate::prompts::base::{PromptArguments, PromptTemplate};
-use std::fmt::Write;
-use std::{ops::Range, path::PathBuf};
-
-use gpui::{AsyncAppContext, Model};
-use language::{Anchor, Buffer};
-
-#[derive(Clone)]
-pub struct PromptCodeSnippet {
- path: Option<PathBuf>,
- language_name: Option<String>,
- content: String,
-}
-
-impl PromptCodeSnippet {
- pub fn new(
- buffer: Model<Buffer>,
- range: Range<Anchor>,
- cx: &mut AsyncAppContext,
- ) -> anyhow::Result<Self> {
- let (content, language_name, file_path) = buffer.update(cx, |buffer, _| {
- let snapshot = buffer.snapshot();
- let content = snapshot.text_for_range(range.clone()).collect::<String>();
-
- let language_name = buffer
- .language()
- .map(|language| language.name().to_string().to_lowercase());
-
- let file_path = buffer.file().map(|file| file.path().to_path_buf());
-
- (content, language_name, file_path)
- })?;
-
- anyhow::Ok(PromptCodeSnippet {
- path: file_path,
- language_name,
- content,
- })
- }
-}
-
-impl ToString for PromptCodeSnippet {
- fn to_string(&self) -> String {
- let path = self
- .path
- .as_ref()
- .map(|path| path.to_string_lossy().to_string())
- .unwrap_or("".to_string());
- let language_name = self.language_name.clone().unwrap_or("".to_string());
- let content = self.content.clone();
-
- format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
- }
-}
-
-pub struct RepositoryContext {}
-
-impl PromptTemplate for RepositoryContext {
- fn generate(
- &self,
- args: &PromptArguments,
- max_token_length: Option<usize>,
- ) -> anyhow::Result<(String, usize)> {
- const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
- let template = "You are working inside a large repository, here are a few code snippets that may be useful.";
- let mut prompt = String::new();
-
- let mut remaining_tokens = max_token_length;
- let separator_token_length = args.model.count_tokens("\n")?;
- for snippet in &args.snippets {
- let mut snippet_prompt = template.to_string();
- let content = snippet.to_string();
- writeln!(snippet_prompt, "{content}").unwrap();
-
- let token_count = args.model.count_tokens(&snippet_prompt)?;
- if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT {
- if let Some(tokens_left) = remaining_tokens {
- if tokens_left >= token_count {
- writeln!(prompt, "{snippet_prompt}").unwrap();
- remaining_tokens = if tokens_left >= (token_count + separator_token_length)
- {
- Some(tokens_left - token_count - separator_token_length)
- } else {
- Some(0)
- };
- }
- } else {
- writeln!(prompt, "{snippet_prompt}").unwrap();
- }
- }
- }
-
- let total_token_count = args.model.count_tokens(&prompt)?;
- anyhow::Ok((prompt, total_token_count))
- }
-}
@@ -1 +0,0 @@
-pub mod open_ai;
@@ -1,9 +0,0 @@
-pub mod completion;
-pub mod embedding;
-pub mod model;
-
-pub use completion::*;
-pub use embedding::*;
-pub use model::OpenAiLanguageModel;
-
-pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
@@ -1,421 +0,0 @@
-use std::{
- env,
- fmt::{self, Display},
- io,
- sync::Arc,
-};
-
-use anyhow::{anyhow, Result};
-use futures::{
- future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
- Stream, StreamExt,
-};
-use gpui::{AppContext, BackgroundExecutor};
-use isahc::{http::StatusCode, Request, RequestExt};
-use parking_lot::RwLock;
-use schemars::JsonSchema;
-use serde::{Deserialize, Serialize};
-use util::ResultExt;
-
-use crate::providers::open_ai::{OpenAiLanguageModel, OPEN_AI_API_URL};
-use crate::{
- auth::{CredentialProvider, ProviderCredential},
- completion::{CompletionProvider, CompletionRequest},
- models::LanguageModel,
-};
-
-#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
-#[serde(rename_all = "lowercase")]
-pub enum Role {
- User,
- Assistant,
- System,
-}
-
-impl Role {
- pub fn cycle(&mut self) {
- *self = match self {
- Role::User => Role::Assistant,
- Role::Assistant => Role::System,
- Role::System => Role::User,
- }
- }
-}
-
-impl Display for Role {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
- match self {
- Role::User => write!(f, "User"),
- Role::Assistant => write!(f, "Assistant"),
- Role::System => write!(f, "System"),
- }
- }
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct RequestMessage {
- pub role: Role,
- pub content: String,
-}
-
-#[derive(Debug, Default, Serialize)]
-pub struct OpenAiRequest {
- pub model: String,
- pub messages: Vec<RequestMessage>,
- pub stream: bool,
- pub stop: Vec<String>,
- pub temperature: f32,
-}
-
-impl CompletionRequest for OpenAiRequest {
- fn data(&self) -> serde_json::Result<String> {
- serde_json::to_string(self)
- }
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct ResponseMessage {
- pub role: Option<Role>,
- pub content: Option<String>,
-}
-
-#[derive(Deserialize, Debug)]
-pub struct OpenAiUsage {
- pub prompt_tokens: u32,
- pub completion_tokens: u32,
- pub total_tokens: u32,
-}
-
-#[derive(Deserialize, Debug)]
-pub struct ChatChoiceDelta {
- pub index: u32,
- pub delta: ResponseMessage,
- pub finish_reason: Option<String>,
-}
-
-#[derive(Deserialize, Debug)]
-pub struct OpenAiResponseStreamEvent {
- pub id: Option<String>,
- pub object: String,
- pub created: u32,
- pub model: String,
- pub choices: Vec<ChatChoiceDelta>,
- pub usage: Option<OpenAiUsage>,
-}
-
-async fn stream_completion(
- api_url: String,
- kind: OpenAiCompletionProviderKind,
- credential: ProviderCredential,
- executor: BackgroundExecutor,
- request: Box<dyn CompletionRequest>,
-) -> Result<impl Stream<Item = Result<OpenAiResponseStreamEvent>>> {
- let api_key = match credential {
- ProviderCredential::Credentials { api_key } => api_key,
- _ => {
- return Err(anyhow!("no credentials provider for completion"));
- }
- };
-
- let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAiResponseStreamEvent>>();
-
- let (auth_header_name, auth_header_value) = kind.auth_header(api_key);
- let json_data = request.data()?;
- let mut response = Request::post(kind.completions_endpoint_url(&api_url))
- .header("Content-Type", "application/json")
- .header(auth_header_name, auth_header_value)
- .body(json_data)?
- .send_async()
- .await?;
-
- let status = response.status();
- if status == StatusCode::OK {
- executor
- .spawn(async move {
- let mut lines = BufReader::new(response.body_mut()).lines();
-
- fn parse_line(
- line: Result<String, io::Error>,
- ) -> Result<Option<OpenAiResponseStreamEvent>> {
- if let Some(data) = line?.strip_prefix("data: ") {
- let event = serde_json::from_str(data)?;
- Ok(Some(event))
- } else {
- Ok(None)
- }
- }
-
- while let Some(line) = lines.next().await {
- if let Some(event) = parse_line(line).transpose() {
- let done = event.as_ref().map_or(false, |event| {
- event
- .choices
- .last()
- .map_or(false, |choice| choice.finish_reason.is_some())
- });
- if tx.unbounded_send(event).is_err() {
- break;
- }
-
- if done {
- break;
- }
- }
- }
-
- anyhow::Ok(())
- })
- .detach();
-
- Ok(rx)
- } else {
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
-
- #[derive(Deserialize)]
- struct OpenAiResponse {
- error: OpenAiError,
- }
-
- #[derive(Deserialize)]
- struct OpenAiError {
- message: String,
- }
-
- match serde_json::from_str::<OpenAiResponse>(&body) {
- Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
- "Failed to connect to OpenAI API: {}",
- response.error.message,
- )),
-
- _ => Err(anyhow!(
- "Failed to connect to OpenAI API: {} {}",
- response.status(),
- body,
- )),
- }
- }
-}
-
-#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
-pub enum AzureOpenAiApiVersion {
- /// Retiring April 2, 2024.
- #[serde(rename = "2023-03-15-preview")]
- V2023_03_15Preview,
- #[serde(rename = "2023-05-15")]
- V2023_05_15,
- /// Retiring April 2, 2024.
- #[serde(rename = "2023-06-01-preview")]
- V2023_06_01Preview,
- /// Retiring April 2, 2024.
- #[serde(rename = "2023-07-01-preview")]
- V2023_07_01Preview,
- /// Retiring April 2, 2024.
- #[serde(rename = "2023-08-01-preview")]
- V2023_08_01Preview,
- /// Retiring April 2, 2024.
- #[serde(rename = "2023-09-01-preview")]
- V2023_09_01Preview,
- #[serde(rename = "2023-12-01-preview")]
- V2023_12_01Preview,
- #[serde(rename = "2024-02-15-preview")]
- V2024_02_15Preview,
-}
-
-impl fmt::Display for AzureOpenAiApiVersion {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(
- f,
- "{}",
- match self {
- Self::V2023_03_15Preview => "2023-03-15-preview",
- Self::V2023_05_15 => "2023-05-15",
- Self::V2023_06_01Preview => "2023-06-01-preview",
- Self::V2023_07_01Preview => "2023-07-01-preview",
- Self::V2023_08_01Preview => "2023-08-01-preview",
- Self::V2023_09_01Preview => "2023-09-01-preview",
- Self::V2023_12_01Preview => "2023-12-01-preview",
- Self::V2024_02_15Preview => "2024-02-15-preview",
- }
- )
- }
-}
-
-#[derive(Clone)]
-pub enum OpenAiCompletionProviderKind {
- OpenAi,
- AzureOpenAi {
- deployment_id: String,
- api_version: AzureOpenAiApiVersion,
- },
-}
-
-impl OpenAiCompletionProviderKind {
- /// Returns the chat completion endpoint URL for this [`OpenAiCompletionProviderKind`].
- fn completions_endpoint_url(&self, api_url: &str) -> String {
- match self {
- Self::OpenAi => {
- // https://platform.openai.com/docs/api-reference/chat/create
- format!("{api_url}/chat/completions")
- }
- Self::AzureOpenAi {
- deployment_id,
- api_version,
- } => {
- // https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions
- format!("{api_url}/openai/deployments/{deployment_id}/chat/completions?api-version={api_version}")
- }
- }
- }
-
- /// Returns the authentication header for this [`OpenAiCompletionProviderKind`].
- fn auth_header(&self, api_key: String) -> (&'static str, String) {
- match self {
- Self::OpenAi => ("Authorization", format!("Bearer {api_key}")),
- Self::AzureOpenAi { .. } => ("Api-Key", api_key),
- }
- }
-}
-
-#[derive(Clone)]
-pub struct OpenAiCompletionProvider {
- api_url: String,
- kind: OpenAiCompletionProviderKind,
- model: OpenAiLanguageModel,
- credential: Arc<RwLock<ProviderCredential>>,
- executor: BackgroundExecutor,
-}
-
-impl OpenAiCompletionProvider {
- pub async fn new(
- api_url: String,
- kind: OpenAiCompletionProviderKind,
- model_name: String,
- executor: BackgroundExecutor,
- ) -> Self {
- let model = executor
- .spawn(async move { OpenAiLanguageModel::load(&model_name) })
- .await;
- let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
- Self {
- api_url,
- kind,
- model,
- credential,
- executor,
- }
- }
-}
-
-impl CredentialProvider for OpenAiCompletionProvider {
- fn has_credentials(&self) -> bool {
- match *self.credential.read() {
- ProviderCredential::Credentials { .. } => true,
- _ => false,
- }
- }
-
- fn retrieve_credentials(&self, cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
- let existing_credential = self.credential.read().clone();
- let retrieved_credential = match existing_credential {
- ProviderCredential::Credentials { .. } => {
- return async move { existing_credential }.boxed()
- }
- _ => {
- if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
- async move { ProviderCredential::Credentials { api_key } }.boxed()
- } else {
- let credentials = cx.read_credentials(OPEN_AI_API_URL);
- async move {
- if let Some(Some((_, api_key))) = credentials.await.log_err() {
- if let Some(api_key) = String::from_utf8(api_key).log_err() {
- ProviderCredential::Credentials { api_key }
- } else {
- ProviderCredential::NoCredentials
- }
- } else {
- ProviderCredential::NoCredentials
- }
- }
- .boxed()
- }
- }
- };
-
- async move {
- let retrieved_credential = retrieved_credential.await;
- *self.credential.write() = retrieved_credential.clone();
- retrieved_credential
- }
- .boxed()
- }
-
- fn save_credentials(
- &self,
- cx: &mut AppContext,
- credential: ProviderCredential,
- ) -> BoxFuture<()> {
- *self.credential.write() = credential.clone();
- let credential = credential.clone();
- let write_credentials = match credential {
- ProviderCredential::Credentials { api_key } => {
- Some(cx.write_credentials(OPEN_AI_API_URL, "Bearer", api_key.as_bytes()))
- }
- _ => None,
- };
-
- async move {
- if let Some(write_credentials) = write_credentials {
- write_credentials.await.log_err();
- }
- }
- .boxed()
- }
-
- fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()> {
- *self.credential.write() = ProviderCredential::NoCredentials;
- let delete_credentials = cx.delete_credentials(OPEN_AI_API_URL);
- async move {
- delete_credentials.await.log_err();
- }
- .boxed()
- }
-}
-
-impl CompletionProvider for OpenAiCompletionProvider {
- fn base_model(&self) -> Box<dyn LanguageModel> {
- let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
- model
- }
-
- fn complete(
- &self,
- prompt: Box<dyn CompletionRequest>,
- ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
- // Currently the CompletionRequest for OpenAI, includes a 'model' parameter
- // This means that the model is determined by the CompletionRequest and not the CompletionProvider,
- // which is currently model based, due to the language model.
- // At some point in the future we should rectify this.
- let credential = self.credential.read().clone();
- let api_url = self.api_url.clone();
- let kind = self.kind.clone();
- let request = stream_completion(api_url, kind, credential, self.executor.clone(), prompt);
- async move {
- let response = request.await?;
- let stream = response
- .filter_map(|response| async move {
- match response {
- Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
- Err(error) => Some(Err(error)),
- }
- })
- .boxed();
- Ok(stream)
- }
- .boxed()
- }
-
- fn box_clone(&self) -> Box<dyn CompletionProvider> {
- Box::new((*self).clone())
- }
-}
@@ -1,345 +0,0 @@
-use anyhow::{anyhow, Result};
-use async_trait::async_trait;
-use futures::future::BoxFuture;
-use futures::AsyncReadExt;
-use futures::FutureExt;
-use gpui::AppContext;
-use gpui::BackgroundExecutor;
-use isahc::http::StatusCode;
-use isahc::prelude::Configurable;
-use isahc::{AsyncBody, Response};
-use parking_lot::{Mutex, RwLock};
-use parse_duration::parse;
-use postage::watch;
-use serde::{Deserialize, Serialize};
-use serde_json;
-use std::env;
-use std::ops::Add;
-use std::sync::{Arc, OnceLock};
-use std::time::{Duration, Instant};
-use tiktoken_rs::{cl100k_base, CoreBPE};
-use util::http::{HttpClient, Request};
-use util::ResultExt;
-
-use crate::auth::{CredentialProvider, ProviderCredential};
-use crate::embedding::{Embedding, EmbeddingProvider};
-use crate::models::LanguageModel;
-use crate::providers::open_ai::OpenAiLanguageModel;
-
-use crate::providers::open_ai::OPEN_AI_API_URL;
-
-pub(crate) fn open_ai_bpe_tokenizer() -> &'static CoreBPE {
- static OPEN_AI_BPE_TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();
- OPEN_AI_BPE_TOKENIZER.get_or_init(|| cl100k_base().unwrap())
-}
-
-#[derive(Clone)]
-pub struct OpenAiEmbeddingProvider {
- api_url: String,
- model: OpenAiLanguageModel,
- credential: Arc<RwLock<ProviderCredential>>,
- pub client: Arc<dyn HttpClient>,
- pub executor: BackgroundExecutor,
- rate_limit_count_rx: watch::Receiver<Option<Instant>>,
- rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
-}
-
-#[derive(Serialize)]
-struct OpenAiEmbeddingRequest<'a> {
- model: &'static str,
- input: Vec<&'a str>,
-}
-
-#[derive(Deserialize)]
-struct OpenAiEmbeddingResponse {
- data: Vec<OpenAiEmbedding>,
- usage: OpenAiEmbeddingUsage,
-}
-
-#[derive(Debug, Deserialize)]
-struct OpenAiEmbedding {
- embedding: Vec<f32>,
- index: usize,
- object: String,
-}
-
-#[derive(Deserialize)]
-struct OpenAiEmbeddingUsage {
- prompt_tokens: usize,
- total_tokens: usize,
-}
-
-impl OpenAiEmbeddingProvider {
- pub async fn new(
- api_url: String,
- client: Arc<dyn HttpClient>,
- executor: BackgroundExecutor,
- ) -> Self {
- let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
- let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
-
- // Loading the model is expensive, so ensure this runs off the main thread.
- let model = executor
- .spawn(async move { OpenAiLanguageModel::load("text-embedding-ada-002") })
- .await;
- let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
-
- OpenAiEmbeddingProvider {
- api_url,
- model,
- credential,
- client,
- executor,
- rate_limit_count_rx,
- rate_limit_count_tx,
- }
- }
-
- fn get_api_key(&self) -> Result<String> {
- match self.credential.read().clone() {
- ProviderCredential::Credentials { api_key } => Ok(api_key),
- _ => Err(anyhow!("api credentials not provided")),
- }
- }
-
- fn resolve_rate_limit(&self) {
- let reset_time = *self.rate_limit_count_tx.lock().borrow();
-
- if let Some(reset_time) = reset_time {
- if Instant::now() >= reset_time {
- *self.rate_limit_count_tx.lock().borrow_mut() = None
- }
- }
-
- log::trace!(
- "resolving reset time: {:?}",
- *self.rate_limit_count_tx.lock().borrow()
- );
- }
-
- fn update_reset_time(&self, reset_time: Instant) {
- let original_time = *self.rate_limit_count_tx.lock().borrow();
-
- let updated_time = if let Some(original_time) = original_time {
- if reset_time < original_time {
- Some(reset_time)
- } else {
- Some(original_time)
- }
- } else {
- Some(reset_time)
- };
-
- log::trace!("updating rate limit time: {:?}", updated_time);
-
- *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
- }
- async fn send_request(
- &self,
- api_url: &str,
- api_key: &str,
- spans: Vec<&str>,
- request_timeout: u64,
- ) -> Result<Response<AsyncBody>> {
- let request = Request::post(format!("{api_url}/embeddings"))
- .redirect_policy(isahc::config::RedirectPolicy::Follow)
- .timeout(Duration::from_secs(request_timeout))
- .header("Content-Type", "application/json")
- .header("Authorization", format!("Bearer {}", api_key))
- .body(
- serde_json::to_string(&OpenAiEmbeddingRequest {
- input: spans.clone(),
- model: "text-embedding-ada-002",
- })
- .unwrap()
- .into(),
- )?;
-
- Ok(self.client.send(request).await?)
- }
-}
-
-impl CredentialProvider for OpenAiEmbeddingProvider {
- fn has_credentials(&self) -> bool {
- match *self.credential.read() {
- ProviderCredential::Credentials { .. } => true,
- _ => false,
- }
- }
-
- fn retrieve_credentials(&self, cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
- let existing_credential = self.credential.read().clone();
- let retrieved_credential = match existing_credential {
- ProviderCredential::Credentials { .. } => {
- return async move { existing_credential }.boxed()
- }
- _ => {
- if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
- async move { ProviderCredential::Credentials { api_key } }.boxed()
- } else {
- let credentials = cx.read_credentials(OPEN_AI_API_URL);
- async move {
- if let Some(Some((_, api_key))) = credentials.await.log_err() {
- if let Some(api_key) = String::from_utf8(api_key).log_err() {
- ProviderCredential::Credentials { api_key }
- } else {
- ProviderCredential::NoCredentials
- }
- } else {
- ProviderCredential::NoCredentials
- }
- }
- .boxed()
- }
- }
- };
-
- async move {
- let retrieved_credential = retrieved_credential.await;
- *self.credential.write() = retrieved_credential.clone();
- retrieved_credential
- }
- .boxed()
- }
-
- fn save_credentials(
- &self,
- cx: &mut AppContext,
- credential: ProviderCredential,
- ) -> BoxFuture<()> {
- *self.credential.write() = credential.clone();
- let credential = credential.clone();
- let write_credentials = match credential {
- ProviderCredential::Credentials { api_key } => {
- Some(cx.write_credentials(OPEN_AI_API_URL, "Bearer", api_key.as_bytes()))
- }
- _ => None,
- };
-
- async move {
- if let Some(write_credentials) = write_credentials {
- write_credentials.await.log_err();
- }
- }
- .boxed()
- }
-
- fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()> {
- *self.credential.write() = ProviderCredential::NoCredentials;
- let delete_credentials = cx.delete_credentials(OPEN_AI_API_URL);
- async move {
- delete_credentials.await.log_err();
- }
- .boxed()
- }
-}
-
-#[async_trait]
-impl EmbeddingProvider for OpenAiEmbeddingProvider {
- fn base_model(&self) -> Box<dyn LanguageModel> {
- let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
- model
- }
-
- fn max_tokens_per_batch(&self) -> usize {
- 50000
- }
-
- fn rate_limit_expiration(&self) -> Option<Instant> {
- *self.rate_limit_count_rx.borrow()
- }
-
- async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
- const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
- const MAX_RETRIES: usize = 4;
-
- let api_url = self.api_url.as_str();
- let api_key = self.get_api_key()?;
-
- let mut request_number = 0;
- let mut rate_limiting = false;
- let mut request_timeout: u64 = 15;
- let mut response: Response<AsyncBody>;
- while request_number < MAX_RETRIES {
- response = self
- .send_request(
- &api_url,
- &api_key,
- spans.iter().map(|x| &**x).collect(),
- request_timeout,
- )
- .await?;
-
- request_number += 1;
-
- match response.status() {
- StatusCode::REQUEST_TIMEOUT => {
- request_timeout += 5;
- }
- StatusCode::OK => {
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
- let response: OpenAiEmbeddingResponse = serde_json::from_str(&body)?;
-
- log::trace!(
- "openai embedding completed. tokens: {:?}",
- response.usage.total_tokens
- );
-
-                        // If we complete a request successfully after previously being rate-limited,
-                        // resolve the rate limit.
- if rate_limiting {
- self.resolve_rate_limit()
- }
-
- return Ok(response
- .data
- .into_iter()
- .map(|embedding| Embedding::from(embedding.embedding))
- .collect());
- }
- StatusCode::TOO_MANY_REQUESTS => {
- rate_limiting = true;
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
-
- let delay_duration = {
- let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
- if let Some(time_to_reset) =
- response.headers().get("x-ratelimit-reset-tokens")
- {
- if let Ok(time_str) = time_to_reset.to_str() {
- parse(time_str).unwrap_or(delay)
- } else {
- delay
- }
- } else {
- delay
- }
- };
-
- // If we've previously rate limited, increment the duration but not the count
- let reset_time = Instant::now().add(delay_duration);
- self.update_reset_time(reset_time);
-
- log::trace!(
- "openai rate limiting: waiting {:?} until lifted",
- &delay_duration
- );
-
- self.executor.timer(delay_duration).await;
- }
- _ => {
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
- return Err(anyhow!(
- "open ai bad request: {:?} {:?}",
- &response.status(),
- body
- ));
- }
- }
- }
- Err(anyhow!("openai max retries"))
- }
-}
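
Note: the `embed_batch` retry loop deleted above combined a fixed backoff schedule with the server-provided `x-ratelimit-reset-tokens` header (parsed via the now-removed `parse_duration` dependency). A minimal, std-only sketch of that policy, with illustrative names and plain-seconds parsing standing in for the real header format:

```rust
use std::time::Duration;

const BACKOFF_SECONDS: [u64; 4] = [3, 5, 15, 45];

/// Pick the next retry delay: prefer the server's reset hint when it parses,
/// otherwise fall back to the fixed schedule (clamped to its last step).
fn backoff_delay(attempt: usize, server_reset: Option<&str>) -> Duration {
    let fallback =
        Duration::from_secs(BACKOFF_SECONDS[attempt.min(BACKOFF_SECONDS.len() - 1)]);
    server_reset
        // Simplification: the real header can look like "1m30s"; only a bare
        // "<seconds>s" is handled here, and anything else uses the fallback.
        .and_then(|s| s.strip_suffix('s')?.parse::<u64>().ok())
        .map(Duration::from_secs)
        .unwrap_or(fallback)
}
```
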
@@ -1,59 +0,0 @@
-use anyhow::anyhow;
-use tiktoken_rs::CoreBPE;
-
-use crate::models::{LanguageModel, TruncationDirection};
-
-use super::open_ai_bpe_tokenizer;
-
-#[derive(Clone)]
-pub struct OpenAiLanguageModel {
- name: String,
- bpe: Option<CoreBPE>,
-}
-
-impl OpenAiLanguageModel {
- pub fn load(model_name: &str) -> Self {
- let bpe = tiktoken_rs::get_bpe_from_model(model_name)
- .unwrap_or(open_ai_bpe_tokenizer().to_owned());
- OpenAiLanguageModel {
- name: model_name.to_string(),
- bpe: Some(bpe),
- }
- }
-}
-
-impl LanguageModel for OpenAiLanguageModel {
- fn name(&self) -> String {
- self.name.clone()
- }
- fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
- if let Some(bpe) = &self.bpe {
- anyhow::Ok(bpe.encode_with_special_tokens(content).len())
- } else {
- Err(anyhow!("bpe for open ai model was not retrieved"))
- }
- }
- fn truncate(
- &self,
- content: &str,
- length: usize,
- direction: TruncationDirection,
- ) -> anyhow::Result<String> {
- if let Some(bpe) = &self.bpe {
- let tokens = bpe.encode_with_special_tokens(content);
- if tokens.len() > length {
- match direction {
- TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
- TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
- }
- } else {
- bpe.decode(tokens)
- }
- } else {
- Err(anyhow!("bpe for open ai model was not retrieved"))
- }
- }
- fn capacity(&self) -> anyhow::Result<usize> {
- anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
- }
-}
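
The `OpenAiLanguageModel` removed here wrapped tiktoken-rs for token counting and truncation. The same calls it relied on, condensed into a freestanding sketch (the "gpt-4" model name is only an example):

```rust
use anyhow::Result;

/// Count tokens and truncate `content` to at most `max_tokens`, keeping the
/// start of the text (the removed code's `TruncationDirection::End` case).
fn truncate_to(content: &str, max_tokens: usize) -> Result<String> {
    let bpe = tiktoken_rs::get_bpe_from_model("gpt-4")?;
    let tokens = bpe.encode_with_special_tokens(content);
    if tokens.len() <= max_tokens {
        return Ok(content.to_string());
    }
    bpe.decode(tokens[..max_tokens].to_vec())
}
```
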
@@ -1,206 +0,0 @@
-use std::{
- sync::atomic::{self, AtomicUsize, Ordering},
- time::Instant,
-};
-
-use async_trait::async_trait;
-use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
-use gpui::AppContext;
-use parking_lot::Mutex;
-
-use crate::{
- auth::{CredentialProvider, ProviderCredential},
- completion::{CompletionProvider, CompletionRequest},
- embedding::{Embedding, EmbeddingProvider},
- models::{LanguageModel, TruncationDirection},
-};
-
-#[derive(Clone)]
-pub struct FakeLanguageModel {
- pub capacity: usize,
-}
-
-impl LanguageModel for FakeLanguageModel {
- fn name(&self) -> String {
- "dummy".to_string()
- }
- fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
- anyhow::Ok(content.chars().collect::<Vec<char>>().len())
- }
- fn truncate(
- &self,
- content: &str,
- length: usize,
- direction: TruncationDirection,
- ) -> anyhow::Result<String> {
- println!("TRYING TO TRUNCATE: {:?}", length.clone());
-
- if length > self.count_tokens(content)? {
- println!("NOT TRUNCATING");
- return anyhow::Ok(content.to_string());
- }
-
- anyhow::Ok(match direction {
- TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
- .into_iter()
- .collect::<String>(),
- TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
- .into_iter()
- .collect::<String>(),
- })
- }
- fn capacity(&self) -> anyhow::Result<usize> {
- anyhow::Ok(self.capacity)
- }
-}
-
-#[derive(Default)]
-pub struct FakeEmbeddingProvider {
- pub embedding_count: AtomicUsize,
-}
-
-impl Clone for FakeEmbeddingProvider {
- fn clone(&self) -> Self {
- FakeEmbeddingProvider {
- embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
- }
- }
-}
-
-impl FakeEmbeddingProvider {
- pub fn embedding_count(&self) -> usize {
- self.embedding_count.load(atomic::Ordering::SeqCst)
- }
-
- pub fn embed_sync(&self, span: &str) -> Embedding {
- let mut result = vec![1.0; 26];
- for letter in span.chars() {
- let letter = letter.to_ascii_lowercase();
- if letter as u32 >= 'a' as u32 {
- let ix = (letter as u32) - ('a' as u32);
- if ix < 26 {
- result[ix as usize] += 1.0;
- }
- }
- }
-
- let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
- for x in &mut result {
- *x /= norm;
- }
-
- result.into()
- }
-}
-
-impl CredentialProvider for FakeEmbeddingProvider {
- fn has_credentials(&self) -> bool {
- true
- }
-
- fn retrieve_credentials(&self, _cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
- async { ProviderCredential::NotNeeded }.boxed()
- }
-
- fn save_credentials(
- &self,
- _cx: &mut AppContext,
- _credential: ProviderCredential,
- ) -> BoxFuture<()> {
- async {}.boxed()
- }
-
- fn delete_credentials(&self, _cx: &mut AppContext) -> BoxFuture<()> {
- async {}.boxed()
- }
-}
-
-#[async_trait]
-impl EmbeddingProvider for FakeEmbeddingProvider {
- fn base_model(&self) -> Box<dyn LanguageModel> {
- Box::new(FakeLanguageModel { capacity: 1000 })
- }
- fn max_tokens_per_batch(&self) -> usize {
- 1000
- }
-
- fn rate_limit_expiration(&self) -> Option<Instant> {
- None
- }
-
- async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
- self.embedding_count
- .fetch_add(spans.len(), atomic::Ordering::SeqCst);
-
- anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
- }
-}
-
-pub struct FakeCompletionProvider {
- last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
-}
-
-impl Clone for FakeCompletionProvider {
- fn clone(&self) -> Self {
- Self {
- last_completion_tx: Mutex::new(None),
- }
- }
-}
-
-impl FakeCompletionProvider {
- pub fn new() -> Self {
- Self {
- last_completion_tx: Mutex::new(None),
- }
- }
-
- pub fn send_completion(&self, completion: impl Into<String>) {
- let mut tx = self.last_completion_tx.lock();
- tx.as_mut().unwrap().try_send(completion.into()).unwrap();
- }
-
- pub fn finish_completion(&self) {
- self.last_completion_tx.lock().take().unwrap();
- }
-}
-
-impl CredentialProvider for FakeCompletionProvider {
- fn has_credentials(&self) -> bool {
- true
- }
-
- fn retrieve_credentials(&self, _cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
- async { ProviderCredential::NotNeeded }.boxed()
- }
-
- fn save_credentials(
- &self,
- _cx: &mut AppContext,
- _credential: ProviderCredential,
- ) -> BoxFuture<()> {
- async {}.boxed()
- }
-
- fn delete_credentials(&self, _cx: &mut AppContext) -> BoxFuture<()> {
- async {}.boxed()
- }
-}
-
-impl CompletionProvider for FakeCompletionProvider {
- fn base_model(&self) -> Box<dyn LanguageModel> {
- let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
- model
- }
- fn complete(
- &self,
- _prompt: Box<dyn CompletionRequest>,
- ) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
- let (tx, rx) = mpsc::channel(1);
- *self.last_completion_tx.lock() = Some(tx);
- async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
- }
- fn box_clone(&self) -> Box<dyn CompletionProvider> {
- Box::new((*self).clone())
- }
-}
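
The deleted `FakeEmbeddingProvider` embedded text deterministically as unit-normalized letter frequencies, which kept similarity tests reproducible without a network call. Extracted as a standalone function:

```rust
/// Deterministic test embedding: 26 letter-frequency buckets (seeded at 1.0
/// so empty strings still normalize cleanly), scaled to unit length.
fn embed_sync(span: &str) -> Vec<f32> {
    let mut result = vec![1.0_f32; 26];
    for letter in span.chars().filter(|c| c.is_ascii_alphabetic()) {
        let ix = (letter.to_ascii_lowercase() as u32 - 'a' as u32) as usize;
        result[ix] += 1.0;
    }
    let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
    for x in &mut result {
        *x /= norm;
    }
    result
}
```
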
@@ -5,17 +5,14 @@ edition = "2021"
publish = false
license = "GPL-3.0-or-later"
-[lints]
-workspace = true
-
[lib]
path = "src/assistant.rs"
doctest = false
[dependencies]
-ai.workspace = true
anyhow.workspace = true
chrono.workspace = true
+client.workspace = true
collections.workspace = true
editor.workspace = true
fs.workspace = true
@@ -26,12 +23,13 @@ language.workspace = true
log.workspace = true
menu.workspace = true
multi_buffer.workspace = true
+open_ai = { workspace = true, features = ["schemars"] }
ordered-float.workspace = true
+parking_lot.workspace = true
project.workspace = true
regex.workspace = true
schemars.workspace = true
search.workspace = true
-semantic_index.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
@@ -45,7 +43,6 @@ uuid.workspace = true
workspace.workspace = true
[dev-dependencies]
-ai = { workspace = true, features = ["test-support"] }
ctor.workspace = true
editor = { workspace = true, features = ["test-support"] }
env_logger.workspace = true
@@ -1,22 +1,24 @@
pub mod assistant_panel;
pub mod assistant_settings;
mod codegen;
+mod completion_provider;
mod prompts;
+mod saved_conversation;
mod streaming_diff;
-use ai::providers::open_ai::Role;
-use anyhow::Result;
pub use assistant_panel::AssistantPanel;
-use assistant_settings::OpenAiModel;
+use assistant_settings::{AssistantSettings, OpenAiModel, ZedDotDevModel};
use chrono::{DateTime, Local};
-use collections::HashMap;
-use fs::Fs;
-use futures::StreamExt;
+use client::{proto, Client};
+pub(crate) use completion_provider::*;
use gpui::{actions, AppContext, SharedString};
-use regex::Regex;
+pub(crate) use saved_conversation::*;
use serde::{Deserialize, Serialize};
-use std::{cmp::Reverse, ffi::OsStr, path::PathBuf, sync::Arc};
-use util::paths::CONVERSATIONS_DIR;
+use settings::Settings;
+use std::{
+ fmt::{self, Display},
+ sync::Arc,
+};
actions!(
assistant,
@@ -30,7 +32,6 @@ actions!(
ResetKey,
InlineAssist,
ToggleIncludeConversation,
- ToggleRetrieveContext,
]
);
@@ -39,6 +40,139 @@ actions!(
)]
struct MessageId(usize);
+#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+ User,
+ Assistant,
+ System,
+}
+
+impl Role {
+ pub fn cycle(&mut self) {
+ *self = match self {
+ Role::User => Role::Assistant,
+ Role::Assistant => Role::System,
+ Role::System => Role::User,
+ }
+ }
+}
+
+impl Display for Role {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ Role::User => write!(f, "user"),
+ Role::Assistant => write!(f, "assistant"),
+ Role::System => write!(f, "system"),
+ }
+ }
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
+pub enum LanguageModel {
+ ZedDotDev(ZedDotDevModel),
+ OpenAi(OpenAiModel),
+}
+
+impl Default for LanguageModel {
+ fn default() -> Self {
+ LanguageModel::ZedDotDev(ZedDotDevModel::default())
+ }
+}
+
+impl LanguageModel {
+ pub fn telemetry_id(&self) -> String {
+ match self {
+ LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
+ LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
+ }
+ }
+
+ pub fn display_name(&self) -> String {
+ match self {
+ LanguageModel::OpenAi(model) => format!("openai/{}", model.display_name()),
+ LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.display_name()),
+ }
+ }
+
+ pub fn max_token_count(&self) -> usize {
+ match self {
+ LanguageModel::OpenAi(model) => tiktoken_rs::model::get_context_size(model.id()),
+ LanguageModel::ZedDotDev(model) => match model {
+ ZedDotDevModel::GptThreePointFiveTurbo
+ | ZedDotDevModel::GptFour
+ | ZedDotDevModel::GptFourTurbo => tiktoken_rs::model::get_context_size(model.id()),
+ ZedDotDevModel::Custom(_) => 30720, // TODO: Base this on the selected model.
+ },
+ }
+ }
+
+ pub fn id(&self) -> &str {
+ match self {
+ LanguageModel::OpenAi(model) => model.id(),
+ LanguageModel::ZedDotDev(model) => model.id(),
+ }
+ }
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct LanguageModelRequestMessage {
+ pub role: Role,
+ pub content: String,
+}
+
+impl LanguageModelRequestMessage {
+ pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
+ proto::LanguageModelRequestMessage {
+ role: match self.role {
+ Role::User => proto::LanguageModelRole::LanguageModelUser,
+ Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
+ Role::System => proto::LanguageModelRole::LanguageModelSystem,
+ } as i32,
+ content: self.content.clone(),
+ }
+ }
+}
+
+#[derive(Debug, Default, Serialize)]
+pub struct LanguageModelRequest {
+ pub model: LanguageModel,
+ pub messages: Vec<LanguageModelRequestMessage>,
+ pub stop: Vec<String>,
+ pub temperature: f32,
+}
+
+impl LanguageModelRequest {
+ pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
+ proto::CompleteWithLanguageModel {
+ model: self.model.id().to_string(),
+ messages: self.messages.iter().map(|m| m.to_proto()).collect(),
+ stop: self.stop.clone(),
+ temperature: self.temperature,
+ }
+ }
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct LanguageModelResponseMessage {
+ pub role: Option<Role>,
+ pub content: Option<String>,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct LanguageModelUsage {
+ pub prompt_tokens: u32,
+ pub completion_tokens: u32,
+ pub total_tokens: u32,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct LanguageModelChoiceDelta {
+ pub index: u32,
+ pub delta: LanguageModelResponseMessage,
+ pub finish_reason: Option<String>,
+}
+
#[derive(Clone, Debug, Serialize, Deserialize)]
struct MessageMetadata {
role: Role,
@@ -53,71 +187,9 @@ enum MessageStatus {
Error(SharedString),
}
-#[derive(Serialize, Deserialize)]
-struct SavedMessage {
- id: MessageId,
- start: usize,
-}
-
-#[derive(Serialize, Deserialize)]
-struct SavedConversation {
- id: Option<String>,
- zed: String,
- version: String,
- text: String,
- messages: Vec<SavedMessage>,
- message_metadata: HashMap<MessageId, MessageMetadata>,
- summary: String,
- api_url: Option<String>,
- model: OpenAiModel,
-}
-
-impl SavedConversation {
- const VERSION: &'static str = "0.1.0";
-}
-
-struct SavedConversationMetadata {
- title: String,
- path: PathBuf,
- mtime: chrono::DateTime<chrono::Local>,
-}
-
-impl SavedConversationMetadata {
- pub async fn list(fs: Arc<dyn Fs>) -> Result<Vec<Self>> {
- fs.create_dir(&CONVERSATIONS_DIR).await?;
-
- let mut paths = fs.read_dir(&CONVERSATIONS_DIR).await?;
- let mut conversations = Vec::<SavedConversationMetadata>::new();
- while let Some(path) = paths.next().await {
- let path = path?;
- if path.extension() != Some(OsStr::new("json")) {
- continue;
- }
-
- let pattern = r" - \d+.zed.json$";
- let re = Regex::new(pattern).unwrap();
-
- let metadata = fs.metadata(&path).await?;
- if let Some((file_name, metadata)) = path
- .file_name()
- .and_then(|name| name.to_str())
- .zip(metadata)
- {
- let title = re.replace(file_name, "");
- conversations.push(Self {
- title: title.into_owned(),
- path,
- mtime: metadata.mtime.into(),
- });
- }
- }
- conversations.sort_unstable_by_key(|conversation| Reverse(conversation.mtime));
-
- Ok(conversations)
- }
-}
-
-pub fn init(cx: &mut AppContext) {
+pub fn init(client: Arc<Client>, cx: &mut AppContext) {
+ AssistantSettings::register(cx);
+ completion_provider::init(client, cx);
assistant_panel::init(cx);
}
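
A quick usage sketch of the request types introduced in this file (assumes the new `Role`, `LanguageModel`, and request structs are in scope; the message content is illustrative):

```rust
fn example_request() -> LanguageModelRequest {
    let mut role = Role::User;
    role.cycle();
    assert_eq!(role, Role::Assistant);

    LanguageModelRequest {
        // Defaults to LanguageModel::ZedDotDev(ZedDotDevModel::GptFourTurbo).
        model: LanguageModel::default(),
        messages: vec![LanguageModelRequestMessage {
            role,
            content: "Summarize the selected text.".into(),
        }],
        stop: Vec::new(),
        temperature: 0.5,
    }
}
```
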
@@ -1,21 +1,13 @@
use crate::{
- assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAiModel},
+ assistant_settings::{AssistantDockPosition, AssistantSettings, ZedDotDevModel},
codegen::{self, Codegen, CodegenKind},
prompts::generate_content_prompt,
- Assist, CycleMessageRole, InlineAssist, MessageId, MessageMetadata, MessageStatus,
+ Assist, CompletionProvider, CycleMessageRole, InlineAssist, LanguageModel,
+ LanguageModelRequest, LanguageModelRequestMessage, MessageId, MessageMetadata, MessageStatus,
NewConversation, QuoteSelection, ResetKey, Role, SavedConversation, SavedConversationMetadata,
- SavedMessage, Split, ToggleFocus, ToggleIncludeConversation, ToggleRetrieveContext,
+ SavedMessage, Split, ToggleFocus, ToggleIncludeConversation,
};
-use ai::prompts::repository_context::PromptCodeSnippet;
-use ai::{
- auth::ProviderCredential,
- completion::{CompletionProvider, CompletionRequest},
- providers::open_ai::{
- OpenAiCompletionProvider, OpenAiCompletionProviderKind, OpenAiRequest, RequestMessage,
- OPEN_AI_API_URL,
- },
-};
-use anyhow::{anyhow, Result};
+use anyhow::Result;
use chrono::{DateTime, Local};
use collections::{hash_map, HashMap, HashSet, VecDeque};
use editor::{
@@ -24,35 +16,25 @@ use editor::{
BlockContext, BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint,
},
scroll::{Autoscroll, AutoscrollStrategy},
- Anchor, Editor, EditorElement, EditorEvent, EditorStyle, MultiBufferSnapshot, ToOffset,
+ Anchor, Editor, EditorElement, EditorEvent, EditorStyle, MultiBufferSnapshot, ToOffset as _,
ToPoint,
};
use fs::Fs;
use futures::StreamExt;
use gpui::{
- canvas, div, point, relative, rems, uniform_list, Action, AnyElement, AppContext,
- AsyncAppContext, AsyncWindowContext, ClipboardItem, Context, EventEmitter, FocusHandle,
- FocusableView, FontStyle, FontWeight, HighlightStyle, InteractiveElement, IntoElement, Model,
- ModelContext, ParentElement, Pixels, PromptLevel, Render, SharedString,
+ canvas, div, point, relative, rems, uniform_list, Action, AnyElement, AnyView, AppContext,
+ AsyncAppContext, AsyncWindowContext, AvailableSpace, ClipboardItem, Context, EventEmitter,
+ FocusHandle, FocusableView, FontStyle, FontWeight, HighlightStyle, InteractiveElement,
+ IntoElement, Model, ModelContext, ParentElement, Pixels, Render, SharedString,
StatefulInteractiveElement, Styled, Subscription, Task, TextStyle, UniformListScrollHandle,
View, ViewContext, VisualContext, WeakModel, WeakView, WhiteSpace, WindowContext,
};
use language::{language_settings::SoftWrap, Buffer, BufferId, LanguageRegistry, ToOffset as _};
+use parking_lot::Mutex;
use project::Project;
use search::{buffer_search::DivRegistrar, BufferSearchBar};
-use semantic_index::{SemanticIndex, SemanticIndexStatus};
use settings::Settings;
-use std::{
- cell::Cell,
- cmp,
- fmt::Write,
- iter,
- ops::Range,
- path::{Path, PathBuf},
- rc::Rc,
- sync::Arc,
- time::{Duration, Instant},
-};
+use std::{cmp, fmt::Write, iter, ops::Range, path::PathBuf, sync::Arc, time::Duration};
use telemetry_events::AssistantKind;
use theme::ThemeSettings;
use ui::{
@@ -69,7 +51,6 @@ use workspace::{
};
pub fn init(cx: &mut AppContext) {
- AssistantSettings::register(cx);
cx.observe_new_views(
|workspace: &mut Workspace, _cx: &mut ViewContext<Workspace>| {
workspace
@@ -88,27 +69,29 @@ pub struct AssistantPanel {
workspace: WeakView<Workspace>,
width: Option<Pixels>,
height: Option<Pixels>,
- active_editor_index: Option<usize>,
- prev_active_editor_index: Option<usize>,
- editors: Vec<View<ConversationEditor>>,
+ active_conversation_editor: Option<ActiveConversationEditor>,
+ show_saved_conversations: bool,
saved_conversations: Vec<SavedConversationMetadata>,
saved_conversations_scroll_handle: UniformListScrollHandle,
zoomed: bool,
focus_handle: FocusHandle,
toolbar: View<Toolbar>,
- completion_provider: Arc<dyn CompletionProvider>,
- api_key_editor: Option<View<Editor>>,
languages: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>,
- subscriptions: Vec<Subscription>,
+ _subscriptions: Vec<Subscription>,
next_inline_assist_id: usize,
pending_inline_assists: HashMap<usize, PendingInlineAssist>,
pending_inline_assist_ids_by_editor: HashMap<WeakView<Editor>, Vec<usize>>,
include_conversation_in_next_inline_assist: bool,
inline_prompt_history: VecDeque<String>,
_watch_saved_conversations: Task<Result<()>>,
- semantic_index: Option<Model<SemanticIndex>>,
- retrieve_context_in_next_inline_assist: bool,
+ model: LanguageModel,
+ authentication_prompt: Option<AnyView>,
+}
+
+struct ActiveConversationEditor {
+ editor: View<ConversationEditor>,
+ _subscriptions: Vec<Subscription>,
}
impl AssistantPanel {
@@ -124,22 +107,6 @@ impl AssistantPanel {
.await
.log_err()
.unwrap_or_default();
- let (provider_kind, api_url, model_name) = cx.update(|cx| {
- let settings = AssistantSettings::get_global(cx);
- anyhow::Ok((
- settings.provider_kind()?,
- settings.provider_api_url()?,
- settings.provider_model_name()?,
- ))
- })??;
-
- let completion_provider = OpenAiCompletionProvider::new(
- api_url,
- provider_kind,
- model_name,
- cx.background_executor().clone(),
- )
- .await;
// TODO: deserialize state.
let workspace_handle = workspace.clone();
@@ -168,41 +135,48 @@ impl AssistantPanel {
let toolbar = cx.new_view(|cx| {
let mut toolbar = Toolbar::new();
toolbar.set_can_navigate(false, cx);
- toolbar.add_item(cx.new_view(|cx| BufferSearchBar::new(cx)), cx);
+ toolbar.add_item(cx.new_view(BufferSearchBar::new), cx);
toolbar
});
- let semantic_index = SemanticIndex::global(cx);
-
let focus_handle = cx.focus_handle();
- cx.on_focus_in(&focus_handle, Self::focus_in).detach();
- cx.on_focus_out(&focus_handle, Self::focus_out).detach();
+ let subscriptions = vec![
+ cx.on_focus_in(&focus_handle, Self::focus_in),
+ cx.on_focus_out(&focus_handle, Self::focus_out),
+ cx.observe_global::<CompletionProvider>({
+ let mut prev_settings_version =
+ CompletionProvider::global(cx).settings_version();
+ move |this, cx| {
+ this.completion_provider_changed(prev_settings_version, cx);
+ prev_settings_version =
+ CompletionProvider::global(cx).settings_version();
+ }
+ }),
+ ];
+ let model = CompletionProvider::global(cx).default_model();
Self {
workspace: workspace_handle,
- active_editor_index: Default::default(),
- prev_active_editor_index: Default::default(),
- editors: Default::default(),
+ active_conversation_editor: None,
+ show_saved_conversations: false,
saved_conversations,
saved_conversations_scroll_handle: Default::default(),
zoomed: false,
focus_handle,
toolbar,
- completion_provider: Arc::new(completion_provider),
- api_key_editor: None,
languages: workspace.app_state().languages.clone(),
fs: workspace.app_state().fs.clone(),
width: None,
height: None,
- subscriptions: Default::default(),
+ _subscriptions: subscriptions,
next_inline_assist_id: 0,
pending_inline_assists: Default::default(),
pending_inline_assist_ids_by_editor: Default::default(),
include_conversation_in_next_inline_assist: false,
inline_prompt_history: Default::default(),
_watch_saved_conversations,
- semantic_index,
- retrieve_context_in_next_inline_assist: false,
+ model,
+ authentication_prompt: None,
}
})
})
@@ -214,14 +188,8 @@ impl AssistantPanel {
.update(cx, |toolbar, cx| toolbar.focus_changed(true, cx));
cx.notify();
if self.focus_handle.is_focused(cx) {
- if self.has_credentials() {
- if let Some(editor) = self.active_editor() {
- cx.focus_view(editor);
- }
- }
-
- if let Some(api_key_editor) = self.api_key_editor.as_ref() {
- cx.focus_view(api_key_editor);
+ if let Some(editor) = self.active_conversation_editor() {
+ cx.focus_view(editor);
}
}
}
@@ -232,6 +200,30 @@ impl AssistantPanel {
cx.notify();
}
+ fn completion_provider_changed(
+ &mut self,
+ prev_settings_version: usize,
+ cx: &mut ViewContext<Self>,
+ ) {
+ if self.is_authenticated(cx) {
+ self.authentication_prompt = None;
+
+ let model = CompletionProvider::global(cx).default_model();
+ self.set_model(model, cx);
+
+ if self.active_conversation_editor().is_none() {
+ self.new_conversation(cx);
+ }
+ } else if self.authentication_prompt.is_none()
+ || prev_settings_version != CompletionProvider::global(cx).settings_version()
+ {
+ self.authentication_prompt =
+ Some(cx.update_global::<CompletionProvider, _>(|provider, cx| {
+ provider.authentication_prompt(cx)
+ }));
+ }
+ }
+
pub fn inline_assist(
workspace: &mut Workspace,
_: &InlineAssist,
@@ -250,7 +242,7 @@ impl AssistantPanel {
};
let project = workspace.project().clone();
- if assistant.update(cx, |assistant, _| assistant.has_credentials()) {
+ if assistant.update(cx, |assistant, cx| assistant.is_authenticated(cx)) {
assistant.update(cx, |assistant, cx| {
assistant.new_inline_assist(&active_editor, cx, &project)
});
@@ -258,9 +250,9 @@ impl AssistantPanel {
let assistant = assistant.downgrade();
cx.spawn(|workspace, mut cx| async move {
assistant
- .update(&mut cx, |assistant, cx| assistant.load_credentials(cx))?
- .await;
- if assistant.update(&mut cx, |assistant, _| assistant.has_credentials())? {
+ .update(&mut cx, |assistant, cx| assistant.authenticate(cx))?
+ .await?;
+ if assistant.update(&mut cx, |assistant, cx| assistant.is_authenticated(cx))? {
assistant.update(&mut cx, |assistant, cx| {
assistant.new_inline_assist(&active_editor, cx, &project)
})?;
@@ -311,34 +303,11 @@ impl AssistantPanel {
};
let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
- let provider = self.completion_provider.clone();
-
- let codegen = cx.new_model(|cx| {
- Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx)
- });
- if let Some(semantic_index) = self.semantic_index.clone() {
- let project = project.clone();
- cx.spawn(|_, mut cx| async move {
- let previously_indexed = semantic_index
- .update(&mut cx, |index, cx| {
- index.project_previously_indexed(&project, cx)
- })?
- .await
- .unwrap_or(false);
- if previously_indexed {
- let _ = semantic_index
- .update(&mut cx, |index, cx| {
- index.index_project(project.clone(), cx)
- })?
- .await;
- }
- anyhow::Ok(())
- })
- .detach_and_log_err(cx);
- }
+ let codegen =
+ cx.new_model(|cx| Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, cx));
- let measurements = Rc::new(Cell::new(BlockMeasurements::default()));
+ let measurements = Arc::new(Mutex::new(BlockMeasurements::default()));
let inline_assistant = cx.new_view(|cx| {
InlineAssistant::new(
inline_assist_id,
@@ -348,9 +317,6 @@ impl AssistantPanel {
codegen.clone(),
self.workspace.clone(),
cx,
- self.retrieve_context_in_next_inline_assist,
- self.semantic_index.clone(),
- project.clone(),
)
});
let block_id = editor.update(cx, |editor, cx| {
@@ -365,10 +331,10 @@ impl AssistantPanel {
render: Arc::new({
let inline_assistant = inline_assistant.clone();
move |cx: &mut BlockContext| {
- measurements.set(BlockMeasurements {
+ *measurements.lock() = BlockMeasurements {
anchor_x: cx.anchor_x,
gutter_width: cx.gutter_dimensions.width,
- });
+ };
inline_assistant.clone().into_any_element()
}
}),
@@ -456,7 +422,7 @@ impl AssistantPanel {
.entry(editor.downgrade())
.or_default()
.push(inline_assist_id);
- self.update_highlights_for_editor(&editor, cx);
+ self.update_highlights_for_editor(editor, cx);
}
fn handle_inline_assistant_event(
@@ -470,15 +436,8 @@ impl AssistantPanel {
InlineAssistantEvent::Confirmed {
prompt,
include_conversation,
- retrieve_context,
} => {
- self.confirm_inline_assist(
- assist_id,
- prompt,
- *include_conversation,
- cx,
- *retrieve_context,
- );
+ self.confirm_inline_assist(assist_id, prompt, *include_conversation, cx);
}
InlineAssistantEvent::Canceled => {
self.finish_inline_assist(assist_id, true, cx);
@@ -491,9 +450,6 @@ impl AssistantPanel {
} => {
self.include_conversation_in_next_inline_assist = *include_conversation;
}
- InlineAssistantEvent::RetrieveContextToggled { retrieve_context } => {
- self.retrieve_context_in_next_inline_assist = *retrieve_context
- }
}
}
@@ -575,10 +531,9 @@ impl AssistantPanel {
user_prompt: &str,
include_conversation: bool,
cx: &mut ViewContext<Self>,
- retrieve_context: bool,
) {
let conversation = if include_conversation {
- self.active_editor()
+ self.active_conversation_editor()
.map(|editor| editor.read(cx).conversation.clone())
} else {
None
@@ -599,17 +554,13 @@ impl AssistantPanel {
let project = pending_assist.project.clone();
- let project_name = if let Some(project) = project.upgrade() {
- Some(
- project
- .read(cx)
- .worktree_root_names(cx)
- .collect::<Vec<&str>>()
- .join("/"),
- )
- } else {
- None
- };
+ let project_name = project.upgrade().map(|project| {
+ project
+ .read(cx)
+ .worktree_root_names(cx)
+ .collect::<Vec<&str>>()
+ .join("/")
+ });
self.inline_prompt_history
.retain(|prompt| prompt != user_prompt);
@@ -652,7 +603,7 @@ impl AssistantPanel {
         // If the language is Markdown or no language is known, increase the randomness for more creative output.
         // If it's code, decrease the temperature to get more deterministic output.
let temperature = if let Some(language) = language_name.clone() {
- if *language != *"Markdown" {
+ if language.as_ref() != "Markdown" {
0.5
} else {
1.0
@@ -663,61 +614,9 @@ impl AssistantPanel {
let user_prompt = user_prompt.to_string();
- let snippets = if retrieve_context {
- let Some(project) = project.upgrade() else {
- return;
- };
-
- let search_results = if let Some(semantic_index) = self.semantic_index.clone() {
- let search_results = semantic_index.update(cx, |this, cx| {
- this.search_project(project, user_prompt.to_string(), 10, vec![], vec![], cx)
- });
-
- cx.background_executor()
- .spawn(async move { search_results.await.unwrap_or_default() })
- } else {
- Task::ready(Vec::new())
- };
-
- let snippets = cx.spawn(|_, mut cx| async move {
- let mut snippets = Vec::new();
- for result in search_results.await {
- snippets.push(PromptCodeSnippet::new(
- result.buffer,
- result.range,
- &mut cx,
- )?);
- }
- anyhow::Ok(snippets)
- });
- snippets
- } else {
- Task::ready(Ok(Vec::new()))
- };
-
- let Some(mut model_name) = AssistantSettings::get_global(cx)
- .provider_model_name()
- .log_err()
- else {
- return;
- };
-
- let prompt = cx.background_executor().spawn({
- let model_name = model_name.clone();
- async move {
- let snippets = snippets.await?;
-
- let language_name = language_name.as_deref();
- generate_content_prompt(
- user_prompt,
- language_name,
- buffer,
- range,
- snippets,
- &model_name,
- project_name,
- )
- }
+ let prompt = cx.background_executor().spawn(async move {
+ let language_name = language_name.as_deref();
+ generate_content_prompt(user_prompt, language_name, buffer, range, project_name)
});
let mut messages = Vec::new();
@@ -729,25 +628,24 @@ impl AssistantPanel {
.messages(cx)
.map(|message| message.to_open_ai_message(buffer)),
);
- model_name = conversation.model.full_name().to_string();
}
+ let model = self.model.clone();
cx.spawn(|_, mut cx| async move {
         // I don't know if we want to return a ? here.
let prompt = prompt.await?;
- messages.push(RequestMessage {
+ messages.push(LanguageModelRequestMessage {
role: Role::User,
content: prompt,
});
- let request = Box::new(OpenAiRequest {
- model: model_name,
+ let request = LanguageModelRequest {
+ model,
messages,
- stream: true,
stop: vec!["|END|>".to_string()],
temperature,
- });
+ };
codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx))?;
anyhow::Ok(())
@@ -781,7 +679,7 @@ impl AssistantPanel {
} else {
editor.highlight_background::<PendingInlineAssist>(
background_ranges,
- |theme| theme.editor_active_line_background, // todo("use the appropriate color")
+ |theme| theme.editor_active_line_background, // todo!("use the appropriate color")
cx,
);
}
@@ -801,54 +699,82 @@ impl AssistantPanel {
});
}
- fn build_api_key_editor(&mut self, cx: &mut WindowContext<'_>) {
- self.api_key_editor = Some(build_api_key_editor(cx));
- }
-
fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> View<ConversationEditor> {
let editor = cx.new_view(|cx| {
ConversationEditor::new(
- self.completion_provider.clone(),
+ self.model.clone(),
self.languages.clone(),
self.fs.clone(),
self.workspace.clone(),
cx,
)
});
- self.add_conversation(editor.clone(), cx);
+ self.show_conversation(editor.clone(), cx);
editor
}
- fn add_conversation(&mut self, editor: View<ConversationEditor>, cx: &mut ViewContext<Self>) {
- self.subscriptions
- .push(cx.subscribe(&editor, Self::handle_conversation_editor_event));
+ fn show_conversation(
+ &mut self,
+ conversation_editor: View<ConversationEditor>,
+ cx: &mut ViewContext<Self>,
+ ) {
+ let mut subscriptions = Vec::new();
+ subscriptions
+ .push(cx.subscribe(&conversation_editor, Self::handle_conversation_editor_event));
- let conversation = editor.read(cx).conversation.clone();
- self.subscriptions
- .push(cx.observe(&conversation, |_, _, cx| cx.notify()));
+ let conversation = conversation_editor.read(cx).conversation.clone();
+ subscriptions.push(cx.observe(&conversation, |_, _, cx| cx.notify()));
+
+ let editor = conversation_editor.read(cx).editor.clone();
+ self.toolbar.update(cx, |toolbar, cx| {
+ toolbar.set_active_item(Some(&editor), cx);
+ });
+ if self.focus_handle.contains_focused(cx) {
+ cx.focus_view(&editor);
+ }
+ self.active_conversation_editor = Some(ActiveConversationEditor {
+ editor: conversation_editor,
+ _subscriptions: subscriptions,
+ });
+ self.show_saved_conversations = false;
- let index = self.editors.len();
- self.editors.push(editor);
- self.set_active_editor_index(Some(index), cx);
+ cx.notify();
}
- fn set_active_editor_index(&mut self, index: Option<usize>, cx: &mut ViewContext<Self>) {
- self.prev_active_editor_index = self.active_editor_index;
- self.active_editor_index = index;
- if let Some(editor) = self.active_editor() {
- let editor = editor.read(cx).editor.clone();
- self.toolbar.update(cx, |toolbar, cx| {
- toolbar.set_active_item(Some(&editor), cx);
- });
- if self.focus_handle.contains_focused(cx) {
- cx.focus_view(&editor);
- }
- } else {
- self.toolbar.update(cx, |toolbar, cx| {
- toolbar.set_active_item(None, cx);
- });
- }
+ fn cycle_model(&mut self, cx: &mut ViewContext<Self>) {
+ let next_model = match &self.model {
+ LanguageModel::OpenAi(model) => LanguageModel::OpenAi(match &model {
+ open_ai::Model::ThreePointFiveTurbo => open_ai::Model::Four,
+ open_ai::Model::Four => open_ai::Model::FourTurbo,
+ open_ai::Model::FourTurbo => open_ai::Model::ThreePointFiveTurbo,
+ }),
+ LanguageModel::ZedDotDev(model) => LanguageModel::ZedDotDev(match &model {
+ ZedDotDevModel::GptThreePointFiveTurbo => ZedDotDevModel::GptFour,
+ ZedDotDevModel::GptFour => ZedDotDevModel::GptFourTurbo,
+ ZedDotDevModel::GptFourTurbo => {
+ match CompletionProvider::global(cx).default_model() {
+ LanguageModel::ZedDotDev(custom) => custom,
+ _ => ZedDotDevModel::GptThreePointFiveTurbo,
+ }
+ }
+ ZedDotDevModel::Custom(_) => ZedDotDevModel::GptThreePointFiveTurbo,
+ }),
+ };
+
+ self.set_model(next_model, cx);
+ }
+ fn set_model(&mut self, model: LanguageModel, cx: &mut ViewContext<Self>) {
+ self.model = model.clone();
+ if let Some(editor) = self.active_conversation_editor() {
+ editor.update(cx, |active_conversation, cx| {
+ active_conversation
+ .conversation
+ .update(cx, |conversation, cx| {
+ conversation.set_model(model, cx);
+ })
+ })
+ }
cx.notify();
}
@@ -863,49 +789,6 @@ impl AssistantPanel {
}
}
- fn save_credentials(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
- if let Some(api_key) = self
- .api_key_editor
- .as_ref()
- .map(|editor| editor.read(cx).text(cx))
- {
- if !api_key.is_empty() {
- let credential = ProviderCredential::Credentials {
- api_key: api_key.clone(),
- };
-
- let completion_provider = self.completion_provider.clone();
- cx.spawn(|this, mut cx| async move {
- cx.update(|cx| completion_provider.save_credentials(cx, credential))?
- .await;
-
- this.update(&mut cx, |this, cx| {
- this.api_key_editor.take();
- this.focus_handle.focus(cx);
- cx.notify();
- })
- })
- .detach_and_log_err(cx);
- }
- } else {
- cx.propagate();
- }
- }
-
- fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
- let completion_provider = self.completion_provider.clone();
- cx.spawn(|this, mut cx| async move {
- cx.update(|cx| completion_provider.delete_credentials(cx))?
- .await;
- this.update(&mut cx, |this, cx| {
- this.build_api_key_editor(cx);
- this.focus_handle.focus(cx);
- cx.notify();
- })
- })
- .detach_and_log_err(cx);
- }
-
fn toggle_zoom(&mut self, _: &workspace::ToggleZoom, cx: &mut ViewContext<Self>) {
if self.zoomed {
cx.emit(PanelEvent::ZoomOut)
@@ -958,58 +841,27 @@ impl AssistantPanel {
}
}
- fn active_editor(&self) -> Option<&View<ConversationEditor>> {
- self.editors.get(self.active_editor_index?)
+ fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
+ CompletionProvider::global(cx)
+ .reset_credentials(cx)
+ .detach_and_log_err(cx);
}
- fn render_api_key_editor(
- &self,
- editor: &View<Editor>,
- cx: &mut ViewContext<Self>,
- ) -> impl IntoElement {
- let settings = ThemeSettings::get_global(cx);
- let text_style = TextStyle {
- color: if editor.read(cx).read_only(cx) {
- cx.theme().colors().text_disabled
- } else {
- cx.theme().colors().text
- },
- font_family: settings.ui_font.family.clone(),
- font_features: settings.ui_font.features,
- font_size: rems(0.875).into(),
- font_weight: FontWeight::NORMAL,
- font_style: FontStyle::Normal,
- line_height: relative(1.3),
- background_color: None,
- underline: None,
- strikethrough: None,
- white_space: WhiteSpace::Normal,
- };
- EditorElement::new(
- &editor,
- EditorStyle {
- background: cx.theme().colors().editor_background,
- local_player: cx.theme().players().local(),
- text: text_style,
- ..Default::default()
- },
- )
+ fn active_conversation_editor(&self) -> Option<&View<ConversationEditor>> {
+ Some(&self.active_conversation_editor.as_ref()?.editor)
}
fn render_hamburger_button(cx: &mut ViewContext<Self>) -> impl IntoElement {
IconButton::new("hamburger_button", IconName::Menu)
.on_click(cx.listener(|this, _event, cx| {
- if this.active_editor().is_some() {
- this.set_active_editor_index(None, cx);
- } else {
- this.set_active_editor_index(this.prev_active_editor_index, cx);
- }
+ this.show_saved_conversations = !this.show_saved_conversations;
+ cx.notify();
}))
.tooltip(|cx| Tooltip::text("Conversation History", cx))
}
fn render_editor_tools(&self, cx: &mut ViewContext<Self>) -> Vec<AnyElement> {
- if self.active_editor().is_some() {
+ if self.active_conversation_editor().is_some() {
vec![
Self::render_split_button(cx).into_any_element(),
Self::render_quote_button(cx).into_any_element(),
@@ -1023,7 +875,7 @@ impl AssistantPanel {
fn render_split_button(cx: &mut ViewContext<Self>) -> impl IntoElement {
IconButton::new("split_button", IconName::Snip)
.on_click(cx.listener(|this, _event, cx| {
- if let Some(active_editor) = this.active_editor() {
+ if let Some(active_editor) = this.active_conversation_editor() {
active_editor.update(cx, |editor, cx| editor.split(&Default::default(), cx));
}
}))
@@ -1034,7 +886,7 @@ impl AssistantPanel {
fn render_assist_button(cx: &mut ViewContext<Self>) -> impl IntoElement {
IconButton::new("assist_button", IconName::MagicWand)
.on_click(cx.listener(|this, _event, cx| {
- if let Some(active_editor) = this.active_editor() {
+ if let Some(active_editor) = this.active_conversation_editor() {
active_editor.update(cx, |editor, cx| editor.assist(&Default::default(), cx));
}
}))
@@ -1111,202 +963,185 @@ impl AssistantPanel {
fn open_conversation(&mut self, path: PathBuf, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
cx.focus(&self.focus_handle);
- if let Some(ix) = self.editor_index_for_path(&path, cx) {
- self.set_active_editor_index(Some(ix), cx);
- return Task::ready(Ok(()));
- }
-
let fs = self.fs.clone();
let workspace = self.workspace.clone();
let languages = self.languages.clone();
cx.spawn(|this, mut cx| async move {
- let saved_conversation = fs.load(&path).await?;
- let saved_conversation = serde_json::from_str(&saved_conversation)?;
- let conversation =
- Conversation::deserialize(saved_conversation, path.clone(), languages, &mut cx)
- .await?;
+ let saved_conversation = SavedConversation::load(&path, fs.as_ref()).await?;
+ let model = this.update(&mut cx, |this, _| this.model.clone())?;
+ let conversation = Conversation::deserialize(
+ saved_conversation,
+ model,
+ path.clone(),
+ languages,
+ &mut cx,
+ )
+ .await?;
this.update(&mut cx, |this, cx| {
- // If, by the time we've loaded the conversation, the user has already opened
- // the same conversation, we don't want to open it again.
- if let Some(ix) = this.editor_index_for_path(&path, cx) {
- this.set_active_editor_index(Some(ix), cx);
- } else {
- let editor = cx.new_view(|cx| {
- ConversationEditor::for_conversation(conversation, fs, workspace, cx)
- });
- this.add_conversation(editor, cx);
- }
+ let editor = cx.new_view(|cx| {
+ ConversationEditor::for_conversation(conversation, fs, workspace, cx)
+ });
+ this.show_conversation(editor, cx);
})?;
Ok(())
})
}
- fn editor_index_for_path(&self, path: &Path, cx: &AppContext) -> Option<usize> {
- self.editors
- .iter()
- .position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
- }
-
- fn has_credentials(&mut self) -> bool {
- self.completion_provider.has_credentials()
+ fn is_authenticated(&mut self, cx: &mut ViewContext<Self>) -> bool {
+ CompletionProvider::global(cx).is_authenticated()
}
- fn load_credentials(&mut self, cx: &mut ViewContext<Self>) -> Task<()> {
- let completion_provider = self.completion_provider.clone();
- cx.spawn(|_, mut cx| async move {
- if let Some(retrieve_credentials) = cx
- .update(|cx| completion_provider.retrieve_credentials(cx))
- .log_err()
- {
- retrieve_credentials.await;
- }
- })
+ fn authenticate(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
+ cx.update_global::<CompletionProvider, _>(|provider, cx| provider.authenticate(cx))
}
-}
-fn build_api_key_editor(cx: &mut WindowContext) -> View<Editor> {
- cx.new_view(|cx| {
- let mut editor = Editor::single_line(cx);
- editor.set_placeholder_text("sk-000000000000000000000000000000000000000000000000", cx);
- editor
- })
-}
-
-impl Render for AssistantPanel {
- fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
- if let Some(api_key_editor) = self.api_key_editor.clone() {
- const INSTRUCTIONS: [&'static str; 6] = [
- "To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
- " - You can create an API key at: platform.openai.com/api-keys",
- " - Make sure your OpenAI account has credits",
- " - Having a subscription for another service like GitHub Copilot won't work.",
- " ",
- "Paste your OpenAI API key and press Enter to use the assistant:"
- ];
-
- v_flex()
- .p_4()
- .size_full()
- .on_action(cx.listener(AssistantPanel::save_credentials))
- .track_focus(&self.focus_handle)
- .children(
- INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
- )
- .child(
- h_flex()
- .w_full()
- .my_2()
- .px_2()
- .py_1()
- .bg(cx.theme().colors().editor_background)
- .rounded_md()
- .child(self.render_api_key_editor(&api_key_editor, cx)),
- )
- .child(
+ fn render_signed_in(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+ let header = TabBar::new("assistant_header")
+ .start_child(
+ h_flex().gap_1().child(Self::render_hamburger_button(cx)), // .children(title),
+ )
+ .children(self.active_conversation_editor().map(|editor| {
+ h_flex()
+ .h(rems(Tab::CONTAINER_HEIGHT_IN_REMS))
+ .flex_1()
+ .px_2()
+ .child(Label::new(editor.read(cx).title(cx)).into_element())
+ }))
+ .when(self.focus_handle.contains_focused(cx), |this| {
+ this.end_child(
h_flex()
.gap_2()
- .child(Label::new("Click on").size(LabelSize::Small))
- .child(Icon::new(IconName::Ai).size(IconSize::XSmall))
+ .when(self.active_conversation_editor().is_some(), |this| {
+ this.child(h_flex().gap_1().children(self.render_editor_tools(cx)))
+ .child(
+ ui::Divider::vertical()
+ .inset()
+ .color(ui::DividerColor::Border),
+ )
+ })
.child(
- Label::new("in the status bar to close this panel.")
- .size(LabelSize::Small),
+ h_flex()
+ .gap_1()
+ .child(Self::render_plus_button(cx))
+ .child(self.render_zoom_button(cx)),
),
)
- } else {
- let header = TabBar::new("assistant_header")
- .start_child(
- h_flex().gap_1().child(Self::render_hamburger_button(cx)), // .children(title),
- )
- .children(self.active_editor().map(|editor| {
- h_flex()
- .h(rems(Tab::CONTAINER_HEIGHT_IN_REMS))
- .flex_1()
- .px_2()
- .child(Label::new(editor.read(cx).title(cx)).into_element())
- }))
- .when(self.focus_handle.contains_focused(cx), |this| {
- this.end_child(
- h_flex()
- .gap_2()
- .when(self.active_editor().is_some(), |this| {
- this.child(h_flex().gap_1().children(self.render_editor_tools(cx)))
- .child(
- ui::Divider::vertical()
- .inset()
- .color(ui::DividerColor::Border),
- )
- })
- .child(
- h_flex()
- .gap_1()
- .child(Self::render_plus_button(cx))
- .child(self.render_zoom_button(cx)),
- ),
- )
- });
+ });
- let contents = if self.active_editor().is_some() {
- let mut registrar = DivRegistrar::new(
- |panel, cx| panel.toolbar.read(cx).item_of_type::<BufferSearchBar>(),
- cx,
- );
- BufferSearchBar::register(&mut registrar);
- registrar.into_div()
+ let contents = if self.active_conversation_editor().is_some() {
+ let mut registrar = DivRegistrar::new(
+ |panel, cx| panel.toolbar.read(cx).item_of_type::<BufferSearchBar>(),
+ cx,
+ );
+ BufferSearchBar::register(&mut registrar);
+ registrar.into_div()
+ } else {
+ div()
+ };
+ v_flex()
+ .key_context("AssistantPanel")
+ .size_full()
+ .on_action(cx.listener(|this, _: &workspace::NewFile, cx| {
+ this.new_conversation(cx);
+ }))
+ .on_action(cx.listener(AssistantPanel::toggle_zoom))
+ .on_action(cx.listener(AssistantPanel::deploy))
+ .on_action(cx.listener(AssistantPanel::select_next_match))
+ .on_action(cx.listener(AssistantPanel::select_prev_match))
+ .on_action(cx.listener(AssistantPanel::handle_editor_cancel))
+ .on_action(cx.listener(AssistantPanel::reset_credentials))
+ .track_focus(&self.focus_handle)
+ .child(header)
+ .children(if self.toolbar.read(cx).hidden() {
+ None
} else {
- div()
- };
- v_flex()
- .key_context("AssistantPanel")
- .size_full()
- .on_action(cx.listener(|this, _: &workspace::NewFile, cx| {
- this.new_conversation(cx);
- }))
- .on_action(cx.listener(AssistantPanel::reset_credentials))
- .on_action(cx.listener(AssistantPanel::toggle_zoom))
- .on_action(cx.listener(AssistantPanel::deploy))
- .on_action(cx.listener(AssistantPanel::select_next_match))
- .on_action(cx.listener(AssistantPanel::select_prev_match))
- .on_action(cx.listener(AssistantPanel::handle_editor_cancel))
- .track_focus(&self.focus_handle)
- .child(header)
- .children(if self.toolbar.read(cx).hidden() {
- None
- } else {
- Some(self.toolbar.clone())
- })
- .child(
- contents
- .flex_1()
- .child(if let Some(editor) = self.active_editor() {
- editor.clone().into_any_element()
- } else {
- let view = cx.view().clone();
- let scroll_handle = self.saved_conversations_scroll_handle.clone();
- let conversation_count = self.saved_conversations.len();
- canvas(
- move |bounds, cx| {
- let mut list = uniform_list(
- view,
- "saved_conversations",
- conversation_count,
- |this, range, cx| {
- range
- .map(|ix| this.render_saved_conversation(ix, cx))
- .collect()
- },
- )
- .track_scroll(scroll_handle)
- .into_any_element();
- list.layout(bounds.origin, bounds.size.into(), cx);
- list
+ Some(self.toolbar.clone())
+ })
+ .child(contents.flex_1().child(
+ if self.show_saved_conversations || self.active_conversation_editor().is_none() {
+ let view = cx.view().clone();
+ let scroll_handle = self.saved_conversations_scroll_handle.clone();
+ let conversation_count = self.saved_conversations.len();
+ canvas(
+ move |bounds, cx| {
+ let mut saved_conversations = uniform_list(
+ view,
+ "saved_conversations",
+ conversation_count,
+ |this, range, cx| {
+ range
+ .map(|ix| this.render_saved_conversation(ix, cx))
+ .collect()
},
- |_bounds, mut list, cx| list.paint(cx),
)
- .size_full()
- .into_any_element()
- }),
- )
+ .track_scroll(scroll_handle)
+ .into_any_element();
+ saved_conversations.layout(
+ bounds.origin,
+ bounds.size.map(AvailableSpace::Definite),
+ cx,
+ );
+ saved_conversations
+ },
+ |_bounds, mut saved_conversations, cx| saved_conversations.paint(cx),
+ )
+ .size_full()
+ .into_any_element()
+ } else {
+ let editor = self.active_conversation_editor().unwrap();
+ let conversation = editor.read(cx).conversation.clone();
+ div()
+ .size_full()
+ .child(editor.clone())
+ .child(
+ h_flex()
+ .absolute()
+ .gap_1()
+ .top_3()
+ .right_5()
+ .child(self.render_model(&conversation, cx))
+ .children(self.render_remaining_tokens(&conversation, cx)),
+ )
+ .into_any_element()
+ },
+ ))
+ }
+
+ fn render_model(
+ &self,
+ conversation: &Model<Conversation>,
+ cx: &mut ViewContext<Self>,
+ ) -> impl IntoElement {
+ Button::new("current_model", conversation.read(cx).model.display_name())
+ .style(ButtonStyle::Filled)
+ .tooltip(move |cx| Tooltip::text("Change Model", cx))
+ .on_click(cx.listener(|this, _, cx| this.cycle_model(cx)))
+ }
+
+ fn render_remaining_tokens(
+ &self,
+ conversation: &Model<Conversation>,
+ cx: &mut ViewContext<Self>,
+ ) -> Option<impl IntoElement> {
+ let remaining_tokens = conversation.read(cx).remaining_tokens()?;
+ let remaining_tokens_color = if remaining_tokens <= 0 {
+ Color::Error
+ } else if remaining_tokens <= 500 {
+ Color::Warning
+ } else {
+ Color::Default
+ };
+ Some(Label::new(remaining_tokens.to_string()).color(remaining_tokens_color))
+ }
+}
+
+impl Render for AssistantPanel {
+ fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+ if let Some(authentication_prompt) = self.authentication_prompt.as_ref() {
+ authentication_prompt.clone().into_any()
+ } else {
+ self.render_signed_in(cx).into_any_element()
}
}
}
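
One design note on `completion_provider_changed` above: the `observe_global` subscription caches the provider's `settings_version()` so the authentication prompt is only rebuilt when the settings actually changed. Stripped of GPUI, the pattern is roughly:

```rust
/// Illustrative only: re-run `rebuild` when the observed version advances.
struct VersionGate {
    prev_version: usize,
}

impl VersionGate {
    fn observe(&mut self, current_version: usize, mut rebuild: impl FnMut()) {
        if current_version != self.prev_version {
            rebuild();
            self.prev_version = current_version;
        }
    }
}
```
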
@@ -1,169 +1,296 @@
-use ai::providers::open_ai::{
- AzureOpenAiApiVersion, OpenAiCompletionProviderKind, OPEN_AI_API_URL,
-};
-use anyhow::anyhow;
+use std::fmt;
+
use gpui::Pixels;
-use schemars::JsonSchema;
-use serde::{Deserialize, Serialize};
+pub use open_ai::Model as OpenAiModel;
+use schemars::{
+ schema::{InstanceType, Metadata, Schema, SchemaObject},
+ JsonSchema,
+};
+use serde::{
+ de::{self, Visitor},
+ Deserialize, Deserializer, Serialize, Serializer,
+};
use settings::Settings;
-#[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
-#[serde(rename_all = "snake_case")]
-pub enum OpenAiModel {
- #[serde(rename = "gpt-3.5-turbo-0613")]
- ThreePointFiveTurbo,
- #[serde(rename = "gpt-4-0613")]
- Four,
- #[serde(rename = "gpt-4-1106-preview")]
- FourTurbo,
+#[derive(Clone, Debug, Default, PartialEq)]
+pub enum ZedDotDevModel {
+ GptThreePointFiveTurbo,
+ GptFour,
+ #[default]
+ GptFourTurbo,
+ Custom(String),
}
-impl OpenAiModel {
- pub fn full_name(&self) -> &'static str {
- match self {
- Self::ThreePointFiveTurbo => "gpt-3.5-turbo-0613",
- Self::Four => "gpt-4-0613",
- Self::FourTurbo => "gpt-4-1106-preview",
+impl Serialize for ZedDotDevModel {
+ fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+ where
+ S: Serializer,
+ {
+ serializer.serialize_str(self.id())
+ }
+}
+
+impl<'de> Deserialize<'de> for ZedDotDevModel {
+ fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+ where
+ D: Deserializer<'de>,
+ {
+ struct ZedDotDevModelVisitor;
+
+ impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
+ type Value = ZedDotDevModel;
+
+ fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+ formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
+ }
+
+ fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
+ where
+ E: de::Error,
+ {
+ match value {
+ "gpt-3.5-turbo" => Ok(ZedDotDevModel::GptThreePointFiveTurbo),
+ "gpt-4" => Ok(ZedDotDevModel::GptFour),
+ "gpt-4-turbo-preview" => Ok(ZedDotDevModel::GptFourTurbo),
+ _ => Ok(ZedDotDevModel::Custom(value.to_owned())),
+ }
+ }
}
+
+ deserializer.deserialize_str(ZedDotDevModelVisitor)
}
+}
+
+impl JsonSchema for ZedDotDevModel {
+ fn schema_name() -> String {
+ "ZedDotDevModel".to_owned()
+ }
+
+ fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
+ let variants = vec![
+ "gpt-3.5-turbo".to_owned(),
+ "gpt-4".to_owned(),
+ "gpt-4-turbo-preview".to_owned(),
+ ];
+ Schema::Object(SchemaObject {
+ instance_type: Some(InstanceType::String.into()),
+ enum_values: Some(variants.into_iter().map(|s| s.into()).collect()),
+ metadata: Some(Box::new(Metadata {
+ title: Some("ZedDotDevModel".to_owned()),
+ default: Some(serde_json::json!("gpt-4-turbo-preview")),
+ examples: vec![
+ serde_json::json!("gpt-3.5-turbo"),
+ serde_json::json!("gpt-4"),
+ serde_json::json!("gpt-4-turbo-preview"),
+ serde_json::json!("custom-model-name"),
+ ],
+ ..Default::default()
+ })),
+ ..Default::default()
+ })
+ }
+}
- pub fn short_name(&self) -> &'static str {
+impl ZedDotDevModel {
+ pub fn id(&self) -> &str {
match self {
- Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
- Self::Four => "gpt-4",
- Self::FourTurbo => "gpt-4-turbo",
+ Self::GptThreePointFiveTurbo => "gpt-3.5-turbo",
+ Self::GptFour => "gpt-4",
+ Self::GptFourTurbo => "gpt-4-turbo-preview",
+ Self::Custom(id) => id,
}
}
- pub fn cycle(&self) -> Self {
+ pub fn display_name(&self) -> &str {
match self {
- Self::ThreePointFiveTurbo => Self::Four,
- Self::Four => Self::FourTurbo,
- Self::FourTurbo => Self::ThreePointFiveTurbo,
+ Self::GptThreePointFiveTurbo => "gpt-3.5-turbo",
+ Self::GptFour => "gpt-4",
+ Self::GptFourTurbo => "gpt-4-turbo",
+ Self::Custom(id) => id.as_str(),
}
}
}
-#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
+#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum AssistantDockPosition {
Left,
+ #[default]
Right,
Bottom,
}
-#[derive(Debug, Deserialize)]
+#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
+#[serde(tag = "name", rename_all = "snake_case")]
+pub enum AssistantProvider {
+ #[serde(rename = "zed.dev")]
+ ZedDotDev {
+ #[serde(default)]
+ default_model: ZedDotDevModel,
+ },
+ #[serde(rename = "openai")]
+ OpenAi {
+ #[serde(default)]
+ default_model: OpenAiModel,
+ #[serde(default = "open_ai_url")]
+ api_url: String,
+ },
+}
+
+impl Default for AssistantProvider {
+ fn default() -> Self {
+ Self::ZedDotDev {
+ default_model: ZedDotDevModel::default(),
+ }
+ }
+}
+
+fn open_ai_url() -> String {
+ "https://api.openai.com/v1".into()
+}
+
+#[derive(Default, Debug, Deserialize, Serialize)]
pub struct AssistantSettings {
- /// Whether to show the assistant panel button in the status bar.
pub button: bool,
- /// Where to dock the assistant.
pub dock: AssistantDockPosition,
- /// Default width in pixels when the assistant is docked to the left or right.
pub default_width: Pixels,
- /// Default height in pixels when the assistant is docked to the bottom.
pub default_height: Pixels,
- /// The default OpenAI model to use when starting new conversations.
- #[deprecated = "Please use `provider.default_model` instead."]
- pub default_open_ai_model: OpenAiModel,
- /// OpenAI API base URL to use when starting new conversations.
- #[deprecated = "Please use `provider.api_url` instead."]
- pub openai_api_url: String,
- /// The settings for the AI provider.
- pub provider: AiProviderSettings,
+ pub provider: AssistantProvider,
}
-impl AssistantSettings {
- pub fn provider_kind(&self) -> anyhow::Result<OpenAiCompletionProviderKind> {
- match &self.provider {
- AiProviderSettings::OpenAi(_) => Ok(OpenAiCompletionProviderKind::OpenAi),
- AiProviderSettings::AzureOpenAi(settings) => {
- let deployment_id = settings
- .deployment_id
- .clone()
- .ok_or_else(|| anyhow!("no Azure OpenAI deployment ID"))?;
- let api_version = settings
- .api_version
- .ok_or_else(|| anyhow!("no Azure OpenAI API version"))?;
-
- Ok(OpenAiCompletionProviderKind::AzureOpenAi {
- deployment_id,
- api_version,
- })
- }
- }
+/// Assistant panel settings
+#[derive(Clone, Serialize, Deserialize, Debug)]
+#[serde(untagged)]
+pub enum AssistantSettingsContent {
+ Versioned(VersionedAssistantSettingsContent),
+ Legacy(LegacyAssistantSettingsContent),
+}
+
+impl JsonSchema for AssistantSettingsContent {
+ fn schema_name() -> String {
+ VersionedAssistantSettingsContent::schema_name()
}
- pub fn provider_api_url(&self) -> anyhow::Result<String> {
- match &self.provider {
- AiProviderSettings::OpenAi(settings) => Ok(settings
- .api_url
- .clone()
- .unwrap_or_else(|| OPEN_AI_API_URL.to_string())),
- AiProviderSettings::AzureOpenAi(settings) => settings
- .api_url
- .clone()
- .ok_or_else(|| anyhow!("no Azure OpenAI API URL")),
- }
+ fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> Schema {
+ VersionedAssistantSettingsContent::json_schema(gen)
}
- pub fn provider_model(&self) -> anyhow::Result<OpenAiModel> {
- match &self.provider {
- AiProviderSettings::OpenAi(settings) => {
- Ok(settings.default_model.unwrap_or(OpenAiModel::FourTurbo))
- }
- AiProviderSettings::AzureOpenAi(settings) => {
- let deployment_id = settings
- .deployment_id
- .as_deref()
- .ok_or_else(|| anyhow!("no Azure OpenAI deployment ID"))?;
-
- match deployment_id {
- // https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-preview
- "gpt-4" | "gpt-4-32k" => Ok(OpenAiModel::Four),
- // https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-35
- "gpt-35-turbo" | "gpt-35-turbo-16k" | "gpt-35-turbo-instruct" => {
- Ok(OpenAiModel::ThreePointFiveTurbo)
+ fn is_referenceable() -> bool {
+ VersionedAssistantSettingsContent::is_referenceable()
+ }
+}
+
+impl Default for AssistantSettingsContent {
+ fn default() -> Self {
+ Self::Versioned(VersionedAssistantSettingsContent::default())
+ }
+}
+
+impl AssistantSettingsContent {
+ fn upgrade(&self) -> AssistantSettingsContentV1 {
+ match self {
+ AssistantSettingsContent::Versioned(settings) => match settings {
+ VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
+ },
+ AssistantSettingsContent::Legacy(settings) => {
+ if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
+ AssistantSettingsContentV1 {
+ button: settings.button,
+ dock: settings.dock,
+ default_width: settings.default_width,
+ default_height: settings.default_height,
+ provider: Some(AssistantProvider::OpenAi {
+ default_model: settings
+ .default_open_ai_model
+ .clone()
+ .unwrap_or_default(),
+ api_url: open_ai_api_url.clone(),
+ }),
+ }
+ } else if let Some(open_ai_model) = settings.default_open_ai_model.clone() {
+ AssistantSettingsContentV1 {
+ button: settings.button,
+ dock: settings.dock,
+ default_width: settings.default_width,
+ default_height: settings.default_height,
+ provider: Some(AssistantProvider::OpenAi {
+ default_model: open_ai_model,
+ api_url: open_ai_url(),
+ }),
+ }
+ } else {
+ AssistantSettingsContentV1 {
+ button: settings.button,
+ dock: settings.dock,
+ default_width: settings.default_width,
+ default_height: settings.default_height,
+ provider: None,
}
- _ => Err(anyhow!(
- "no matching OpenAI model found for deployment ID: '{deployment_id}'"
- )),
}
}
}
}
- pub fn provider_model_name(&self) -> anyhow::Result<String> {
- match &self.provider {
- AiProviderSettings::OpenAi(settings) => Ok(settings
- .default_model
- .unwrap_or(OpenAiModel::FourTurbo)
- .full_name()
- .to_string()),
- AiProviderSettings::AzureOpenAi(settings) => settings
- .deployment_id
- .clone()
- .ok_or_else(|| anyhow!("no Azure OpenAI deployment ID")),
+ pub fn set_dock(&mut self, dock: AssistantDockPosition) {
+ match self {
+ AssistantSettingsContent::Versioned(settings) => match settings {
+ VersionedAssistantSettingsContent::V1(settings) => {
+ settings.dock = Some(dock);
+ }
+ },
+ AssistantSettingsContent::Legacy(settings) => {
+ settings.dock = Some(dock);
+ }
}
}
}
-impl Settings for AssistantSettings {
- const KEY: Option<&'static str> = Some("assistant");
-
- type FileContent = AssistantSettingsContent;
+#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
+#[serde(tag = "version")]
+pub enum VersionedAssistantSettingsContent {
+ #[serde(rename = "1")]
+ V1(AssistantSettingsContentV1),
+}
- fn load(
- default_value: &Self::FileContent,
- user_values: &[&Self::FileContent],
- _: &mut gpui::AppContext,
- ) -> anyhow::Result<Self> {
- Self::load_via_json_merge(default_value, user_values)
+impl Default for VersionedAssistantSettingsContent {
+ fn default() -> Self {
+ Self::V1(AssistantSettingsContentV1 {
+ button: None,
+ dock: None,
+ default_width: None,
+ default_height: None,
+ provider: None,
+ })
}
}
-/// Assistant panel settings
-#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
-pub struct AssistantSettingsContent {
+#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
+pub struct AssistantSettingsContentV1 {
+ /// Whether to show the assistant panel button in the status bar.
+ ///
+ /// Default: true
+ button: Option<bool>,
+ /// Where to dock the assistant.
+ ///
+ /// Default: right
+ dock: Option<AssistantDockPosition>,
+ /// Default width in pixels when the assistant is docked to the left or right.
+ ///
+ /// Default: 640
+ default_width: Option<f32>,
+ /// Default height in pixels when the assistant is docked to the bottom.
+ ///
+ /// Default: 320
+ default_height: Option<f32>,
+ /// The provider of the assistant service.
+ ///
+ /// This can either be the internal `zed.dev` service or an external `openai` service,
+ /// each with its own default model and configuration.
+ provider: Option<AssistantProvider>,
+}
+
+#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
+pub struct LegacyAssistantSettingsContent {
/// Whether to show the assistant panel button in the status bar.
///
/// Default: true
@@ -180,88 +307,164 @@ pub struct AssistantSettingsContent {
///
/// Default: 320
pub default_height: Option<f32>,
- /// Deprecated: Please use `provider.default_model` instead.
/// The default OpenAI model to use when starting new conversations.
///
/// Default: gpt-4-1106-preview
- #[deprecated = "Please use `provider.default_model` instead."]
pub default_open_ai_model: Option<OpenAiModel>,
- /// Deprecated: Please use `provider.api_url` instead.
/// OpenAI API base URL to use when starting new conversations.
///
/// Default: https://api.openai.com/v1
- #[deprecated = "Please use `provider.api_url` instead."]
pub openai_api_url: Option<String>,
- /// The settings for the AI provider.
- #[serde(default)]
- pub provider: AiProviderSettingsContent,
}
-#[derive(Debug, Clone, Deserialize)]
-#[serde(tag = "type", rename_all = "snake_case")]
-pub enum AiProviderSettings {
- /// The settings for the OpenAI provider.
- #[serde(rename = "openai")]
- OpenAi(OpenAiProviderSettings),
- /// The settings for the Azure OpenAI provider.
- #[serde(rename = "azure_openai")]
- AzureOpenAi(AzureOpenAiProviderSettings),
-}
+impl Settings for AssistantSettings {
+ const KEY: Option<&'static str> = Some("assistant");
-/// The settings for the AI provider used by the Zed Assistant.
-#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
-#[serde(tag = "type", rename_all = "snake_case")]
-pub enum AiProviderSettingsContent {
- /// The settings for the OpenAI provider.
- #[serde(rename = "openai")]
- OpenAi(OpenAiProviderSettingsContent),
- /// The settings for the Azure OpenAI provider.
- #[serde(rename = "azure_openai")]
- AzureOpenAi(AzureOpenAiProviderSettingsContent),
-}
+ type FileContent = AssistantSettingsContent;
-impl Default for AiProviderSettingsContent {
- fn default() -> Self {
- Self::OpenAi(OpenAiProviderSettingsContent::default())
+ fn load(
+ default_value: &Self::FileContent,
+ user_values: &[&Self::FileContent],
+ _: &mut gpui::AppContext,
+ ) -> anyhow::Result<Self> {
+ let mut settings = AssistantSettings::default();
+
+ for value in [default_value].iter().chain(user_values) {
+ let value = value.upgrade();
+ merge(&mut settings.button, value.button);
+ merge(&mut settings.dock, value.dock);
+ merge(
+ &mut settings.default_width,
+ value.default_width.map(Into::into),
+ );
+ merge(
+ &mut settings.default_height,
+ value.default_height.map(Into::into),
+ );
+ if let Some(provider) = value.provider.clone() {
+ match (&mut settings.provider, provider) {
+ (
+ AssistantProvider::ZedDotDev { default_model },
+ AssistantProvider::ZedDotDev {
+ default_model: default_model_override,
+ },
+ ) => {
+ *default_model = default_model_override;
+ }
+ (
+ AssistantProvider::OpenAi {
+ default_model,
+ api_url,
+ },
+ AssistantProvider::OpenAi {
+ default_model: default_model_override,
+ api_url: api_url_override,
+ },
+ ) => {
+ *default_model = default_model_override;
+ *api_url = api_url_override;
+ }
+ (merged, provider_override) => {
+ *merged = provider_override;
+ }
+ }
+ }
+ }
+
+ Ok(settings)
}
}
-#[derive(Debug, Clone, Deserialize)]
-pub struct OpenAiProviderSettings {
- /// The OpenAI API base URL to use when starting new conversations.
- pub api_url: Option<String>,
- /// The default OpenAI model to use when starting new conversations.
- pub default_model: Option<OpenAiModel>,
+fn merge<T: Copy>(target: &mut T, value: Option<T>) {
+ if let Some(value) = value {
+ *target = value;
+ }
}
-#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)]
-pub struct OpenAiProviderSettingsContent {
- /// The OpenAI API base URL to use when starting new conversations.
- ///
- /// Default: https://api.openai.com/v1
- pub api_url: Option<String>,
- /// The default OpenAI model to use when starting new conversations.
- ///
- /// Default: gpt-4-1106-preview
- pub default_model: Option<OpenAiModel>,
-}
+#[cfg(test)]
+mod tests {
+ use gpui::AppContext;
+ use settings::SettingsStore;
-#[derive(Debug, Clone, Deserialize)]
-pub struct AzureOpenAiProviderSettings {
- /// The Azure OpenAI API base URL to use when starting new conversations.
- pub api_url: Option<String>,
- /// The Azure OpenAI API version.
- pub api_version: Option<AzureOpenAiApiVersion>,
- /// The Azure OpenAI API deployment ID.
- pub deployment_id: Option<String>,
-}
+ use super::*;
+
+ #[gpui::test]
+ fn test_deserialize_assistant_settings(cx: &mut AppContext) {
+ let store = settings::SettingsStore::test(cx);
+ cx.set_global(store);
-#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)]
-pub struct AzureOpenAiProviderSettingsContent {
- /// The Azure OpenAI API base URL to use when starting new conversations.
- pub api_url: Option<String>,
- /// The Azure OpenAI API version.
- pub api_version: Option<AzureOpenAiApiVersion>,
- /// The Azure OpenAI deployment ID.
- pub deployment_id: Option<String>,
+ // Settings default to gpt-4-turbo.
+ AssistantSettings::register(cx);
+ assert_eq!(
+ AssistantSettings::get_global(cx).provider,
+ AssistantProvider::OpenAi {
+ default_model: OpenAiModel::FourTurbo,
+ api_url: open_ai_url()
+ }
+ );
+
+ // Ensure backward compatibility.
+ cx.update_global::<SettingsStore, _>(|store, cx| {
+ store
+ .set_user_settings(
+ r#"{
+ "assistant": {
+ "openai_api_url": "test-url",
+ }
+ }"#,
+ cx,
+ )
+ .unwrap();
+ });
+ assert_eq!(
+ AssistantSettings::get_global(cx).provider,
+ AssistantProvider::OpenAi {
+ default_model: OpenAiModel::FourTurbo,
+ api_url: "test-url".into()
+ }
+ );
+ cx.update_global::<SettingsStore, _>(|store, cx| {
+ store
+ .set_user_settings(
+ r#"{
+ "assistant": {
+ "default_open_ai_model": "gpt-4-0613"
+ }
+ }"#,
+ cx,
+ )
+ .unwrap();
+ });
+ assert_eq!(
+ AssistantSettings::get_global(cx).provider,
+ AssistantProvider::OpenAi {
+ default_model: OpenAiModel::Four,
+ api_url: open_ai_url()
+ }
+ );
+
+ // The new version supports setting a custom model when using zed.dev.
+ cx.update_global::<SettingsStore, _>(|store, cx| {
+ store
+ .set_user_settings(
+ r#"{
+ "assistant": {
+ "version": "1",
+ "provider": {
+ "name": "zed.dev",
+ "default_model": "custom"
+ }
+ }
+ }"#,
+ cx,
+ )
+ .unwrap();
+ });
+ assert_eq!(
+ AssistantSettings::get_global(cx).provider,
+ AssistantProvider::ZedDotDev {
+ default_model: ZedDotDevModel::Custom("custom".into())
+ }
+ );
+ }
}
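For reference, a rough sketch of the two accepted settings shapes in a user settings file (values are illustrative; Zed's settings JSON tolerates comments):

    // Legacy, untagged shape — handled by `AssistantSettingsContent::upgrade`:
    "assistant": {
      "default_open_ai_model": "gpt-4-0613",
      "openai_api_url": "https://api.openai.com/v1"
    }

    // Versioned shape, deserialized as `VersionedAssistantSettingsContent::V1`:
    "assistant": {
      "version": "1",
      "provider": { "name": "openai", "default_model": "gpt-4-0613" }
    }
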
@@ -1,12 +1,13 @@
-use crate::streaming_diff::{Hunk, StreamingDiff};
-use ai::completion::{CompletionProvider, CompletionRequest};
+use crate::{
+ streaming_diff::{Hunk, StreamingDiff},
+ CompletionProvider, LanguageModelRequest,
+};
use anyhow::Result;
use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
use gpui::{EventEmitter, Model, ModelContext, Task};
use language::{Rope, TransactionId};
-use multi_buffer;
-use std::{cmp, future, ops::Range, sync::Arc};
+use std::{cmp, future, ops::Range};
pub enum Event {
Finished,
@@ -20,7 +21,6 @@ pub enum CodegenKind {
}
pub struct Codegen {
- provider: Arc<dyn CompletionProvider>,
buffer: Model<MultiBuffer>,
snapshot: MultiBufferSnapshot,
kind: CodegenKind,
@@ -35,15 +35,9 @@ pub struct Codegen {
impl EventEmitter<Event> for Codegen {}
impl Codegen {
- pub fn new(
- buffer: Model<MultiBuffer>,
- kind: CodegenKind,
- provider: Arc<dyn CompletionProvider>,
- cx: &mut ModelContext<Self>,
- ) -> Self {
+ pub fn new(buffer: Model<MultiBuffer>, kind: CodegenKind, cx: &mut ModelContext<Self>) -> Self {
let snapshot = buffer.read(cx).snapshot(cx);
Self {
- provider,
buffer: buffer.clone(),
snapshot,
kind,
@@ -94,7 +88,7 @@ impl Codegen {
self.error.as_ref()
}
- pub fn start(&mut self, prompt: Box<dyn CompletionRequest>, cx: &mut ModelContext<Self>) {
+ pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext<Self>) {
let range = self.range();
let snapshot = self.snapshot.clone();
let selected_text = snapshot
@@ -108,7 +102,7 @@ impl Codegen {
.next()
.unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
- let response = self.provider.complete(prompt);
+ let response = CompletionProvider::global(cx).complete(prompt);
self.generation = cx.spawn(|this, mut cx| {
async move {
let generate = async {
@@ -305,7 +299,7 @@ fn strip_invalid_spans_from_codeblock(
}
if first_line {
- if buffer == "" || buffer == "`" || buffer == "``" {
+ if buffer.is_empty() || buffer == "`" || buffer == "``" {
return future::ready(None);
} else if buffer.starts_with("```") {
starts_with_markdown_codeblock = true;
@@ -360,8 +354,9 @@ fn strip_invalid_spans_from_codeblock(
mod tests {
use std::sync::Arc;
+ use crate::FakeCompletionProvider;
+
use super::*;
- use ai::test::FakeCompletionProvider;
use futures::stream::{self};
use gpui::{Context, TestAppContext};
use indoc::indoc;
@@ -378,15 +373,11 @@ mod tests {
pub name: String,
}
- impl CompletionRequest for DummyCompletionRequest {
- fn data(&self) -> serde_json::Result<String> {
- serde_json::to_string(self)
- }
- }
-
#[gpui::test(iterations = 10)]
async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
+ let provider = FakeCompletionProvider::default();
cx.set_global(cx.update(SettingsStore::test));
+ cx.set_global(CompletionProvider::Fake(provider.clone()));
cx.update(language_settings::init);
let text = indoc! {"
@@ -405,19 +396,10 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
});
- let provider = Arc::new(FakeCompletionProvider::new());
- let codegen = cx.new_model(|cx| {
- Codegen::new(
- buffer.clone(),
- CodegenKind::Transform { range },
- provider.clone(),
- cx,
- )
- });
+ let codegen =
+ cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Transform { range }, cx));
- let request = Box::new(DummyCompletionRequest {
- name: "test".to_string(),
- });
+ let request = LanguageModelRequest::default();
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
let mut new_text = concat!(
@@ -430,8 +412,7 @@ mod tests {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
- println!("CHUNK: {:?}", &chunk);
- provider.send_completion(chunk);
+ provider.send_completion(chunk.into());
new_text = suffix;
cx.background_executor.run_until_parked();
}
@@ -456,6 +437,8 @@ mod tests {
cx: &mut TestAppContext,
mut rng: StdRng,
) {
+ let provider = FakeCompletionProvider::default();
+ cx.set_global(CompletionProvider::Fake(provider.clone()));
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_settings::init);
@@ -472,19 +455,10 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 6))
});
- let provider = Arc::new(FakeCompletionProvider::new());
- let codegen = cx.new_model(|cx| {
- Codegen::new(
- buffer.clone(),
- CodegenKind::Generate { position },
- provider.clone(),
- cx,
- )
- });
+ let codegen =
+ cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Generate { position }, cx));
- let request = Box::new(DummyCompletionRequest {
- name: "test".to_string(),
- });
+ let request = LanguageModelRequest::default();
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
let mut new_text = concat!(
@@ -497,7 +471,7 @@ mod tests {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
- provider.send_completion(chunk);
+ provider.send_completion(chunk.into());
new_text = suffix;
cx.background_executor.run_until_parked();
}
@@ -522,6 +496,8 @@ mod tests {
cx: &mut TestAppContext,
mut rng: StdRng,
) {
+ let provider = FakeCompletionProvider::default();
+ cx.set_global(CompletionProvider::Fake(provider.clone()));
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_settings::init);
@@ -538,19 +514,10 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 2))
});
- let provider = Arc::new(FakeCompletionProvider::new());
- let codegen = cx.new_model(|cx| {
- Codegen::new(
- buffer.clone(),
- CodegenKind::Generate { position },
- provider.clone(),
- cx,
- )
- });
+ let codegen =
+ cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Generate { position }, cx));
- let request = Box::new(DummyCompletionRequest {
- name: "test".to_string(),
- });
+ let request = LanguageModelRequest::default();
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
let mut new_text = concat!(
@@ -563,8 +530,7 @@ mod tests {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
- println!("{:?}", &chunk);
- provider.send_completion(chunk);
+ provider.send_completion(chunk.into());
new_text = suffix;
cx.background_executor.run_until_parked();
}
@@ -0,0 +1,188 @@
+#[cfg(test)]
+mod fake;
+mod open_ai;
+mod zed;
+
+#[cfg(test)]
+pub use fake::*;
+pub use open_ai::*;
+pub use zed::*;
+
+use crate::{
+ assistant_settings::{AssistantProvider, AssistantSettings},
+ LanguageModel, LanguageModelRequest,
+};
+use anyhow::Result;
+use client::Client;
+use futures::{future::BoxFuture, stream::BoxStream};
+use gpui::{AnyView, AppContext, Task, WindowContext};
+use settings::{Settings, SettingsStore};
+use std::sync::Arc;
+
+pub fn init(client: Arc<Client>, cx: &mut AppContext) {
+ let mut settings_version = 0;
+ let provider = match &AssistantSettings::get_global(cx).provider {
+ AssistantProvider::ZedDotDev { default_model } => {
+ CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
+ default_model.clone(),
+ client.clone(),
+ settings_version,
+ cx,
+ ))
+ }
+ AssistantProvider::OpenAi {
+ default_model,
+ api_url,
+ } => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
+ default_model.clone(),
+ api_url.clone(),
+ client.http_client(),
+ settings_version,
+ )),
+ };
+ cx.set_global(provider);
+
+ cx.observe_global::<SettingsStore>(move |cx| {
+ settings_version += 1;
+ cx.update_global::<CompletionProvider, _>(|provider, cx| {
+ match (&mut *provider, &AssistantSettings::get_global(cx).provider) {
+ (
+ CompletionProvider::OpenAi(provider),
+ AssistantProvider::OpenAi {
+ default_model,
+ api_url,
+ },
+ ) => {
+ provider.update(default_model.clone(), api_url.clone(), settings_version);
+ }
+ (
+ CompletionProvider::ZedDotDev(provider),
+ AssistantProvider::ZedDotDev { default_model },
+ ) => {
+ provider.update(default_model.clone(), settings_version);
+ }
+ (CompletionProvider::OpenAi(_), AssistantProvider::ZedDotDev { default_model }) => {
+ *provider = CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new(
+ default_model.clone(),
+ client.clone(),
+ settings_version,
+ cx,
+ ));
+ }
+ (
+ CompletionProvider::ZedDotDev(_),
+ AssistantProvider::OpenAi {
+ default_model,
+ api_url,
+ },
+ ) => {
+ *provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
+ default_model.clone(),
+ api_url.clone(),
+ client.http_client(),
+ settings_version,
+ ));
+ }
+ #[cfg(test)]
+ (CompletionProvider::Fake(_), _) => unimplemented!(),
+ }
+ })
+ })
+ .detach();
+}
+
+pub enum CompletionProvider {
+ OpenAi(OpenAiCompletionProvider),
+ ZedDotDev(ZedDotDevCompletionProvider),
+ #[cfg(test)]
+ Fake(FakeCompletionProvider),
+}
+
+impl gpui::Global for CompletionProvider {}
+
+impl CompletionProvider {
+ pub fn global(cx: &AppContext) -> &Self {
+ cx.global::<Self>()
+ }
+
+ pub fn settings_version(&self) -> usize {
+ match self {
+ CompletionProvider::OpenAi(provider) => provider.settings_version(),
+ CompletionProvider::ZedDotDev(provider) => provider.settings_version(),
+ #[cfg(test)]
+ CompletionProvider::Fake(_) => unimplemented!(),
+ }
+ }
+
+ pub fn is_authenticated(&self) -> bool {
+ match self {
+ CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
+ CompletionProvider::ZedDotDev(provider) => provider.is_authenticated(),
+ #[cfg(test)]
+ CompletionProvider::Fake(_) => true,
+ }
+ }
+
+ pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+ match self {
+ CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
+ CompletionProvider::ZedDotDev(provider) => provider.authenticate(cx),
+ #[cfg(test)]
+ CompletionProvider::Fake(_) => Task::ready(Ok(())),
+ }
+ }
+
+ pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
+ match self {
+ CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
+ CompletionProvider::ZedDotDev(provider) => provider.authentication_prompt(cx),
+ #[cfg(test)]
+ CompletionProvider::Fake(_) => unimplemented!(),
+ }
+ }
+
+ pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+ match self {
+ CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
+ CompletionProvider::ZedDotDev(_) => Task::ready(Ok(())),
+ #[cfg(test)]
+ CompletionProvider::Fake(_) => Task::ready(Ok(())),
+ }
+ }
+
+ pub fn default_model(&self) -> LanguageModel {
+ match self {
+ CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.default_model()),
+ CompletionProvider::ZedDotDev(provider) => {
+ LanguageModel::ZedDotDev(provider.default_model())
+ }
+ #[cfg(test)]
+ CompletionProvider::Fake(_) => unimplemented!(),
+ }
+ }
+
+ pub fn count_tokens(
+ &self,
+ request: LanguageModelRequest,
+ cx: &AppContext,
+ ) -> BoxFuture<'static, Result<usize>> {
+ match self {
+ CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
+ CompletionProvider::ZedDotDev(provider) => provider.count_tokens(request, cx),
+ #[cfg(test)]
+ CompletionProvider::Fake(_) => unimplemented!(),
+ }
+ }
+
+ pub fn complete(
+ &self,
+ request: LanguageModelRequest,
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+ match self {
+ CompletionProvider::OpenAi(provider) => provider.complete(request),
+ CompletionProvider::ZedDotDev(provider) => provider.complete(request),
+ #[cfg(test)]
+ CompletionProvider::Fake(provider) => provider.complete(),
+ }
+ }
+}
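A minimal usage sketch (assuming a model context like Codegen's and `futures::StreamExt` in scope): fetch the global provider, kick off a completion, and drain the chunk stream.

    let response = CompletionProvider::global(cx).complete(LanguageModelRequest::default());
    cx.spawn(|_this, _cx| async move {
        // Each item of the stream is a Result<String> chunk of completion text.
        let mut chunks = response.await?;
        while let Some(chunk) = chunks.next().await {
            let _chunk = chunk?;
            // ...apply the chunk to a buffer, the UI, etc...
        }
        anyhow::Ok(())
    })
    .detach();
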
@@ -0,0 +1,29 @@
+use anyhow::Result;
+use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
+use std::sync::Arc;
+
+#[derive(Clone, Default)]
+pub struct FakeCompletionProvider {
+ current_completion_tx: Arc<parking_lot::Mutex<Option<mpsc::UnboundedSender<String>>>>,
+}
+
+impl FakeCompletionProvider {
+ pub fn complete(&self) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+ let (tx, rx) = mpsc::unbounded();
+ *self.current_completion_tx.lock() = Some(tx);
+ async move { Ok(rx.map(Ok).boxed()) }.boxed()
+ }
+
+ pub fn send_completion(&self, chunk: String) {
+ self.current_completion_tx
+ .lock()
+ .as_ref()
+ .unwrap()
+ .unbounded_send(chunk)
+ .unwrap();
+ }
+
+ pub fn finish_completion(&self) {
+ self.current_completion_tx.lock().take();
+ }
+}
@@ -0,0 +1,301 @@
+use crate::{
+ assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
+};
+use anyhow::{anyhow, Result};
+use editor::{Editor, EditorElement, EditorStyle};
+use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
+use gpui::{AnyView, AppContext, FontStyle, FontWeight, Task, TextStyle, View, WhiteSpace};
+use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
+use settings::Settings;
+use std::{env, sync::Arc};
+use theme::ThemeSettings;
+use ui::prelude::*;
+use util::{http::HttpClient, ResultExt};
+
+pub struct OpenAiCompletionProvider {
+ api_key: Option<String>,
+ api_url: String,
+ default_model: OpenAiModel,
+ http_client: Arc<dyn HttpClient>,
+ settings_version: usize,
+}
+
+impl OpenAiCompletionProvider {
+ pub fn new(
+ default_model: OpenAiModel,
+ api_url: String,
+ http_client: Arc<dyn HttpClient>,
+ settings_version: usize,
+ ) -> Self {
+ Self {
+ api_key: None,
+ api_url,
+ default_model,
+ http_client,
+ settings_version,
+ }
+ }
+
+ pub fn update(&mut self, default_model: OpenAiModel, api_url: String, settings_version: usize) {
+ self.default_model = default_model;
+ self.api_url = api_url;
+ self.settings_version = settings_version;
+ }
+
+ pub fn settings_version(&self) -> usize {
+ self.settings_version
+ }
+
+ pub fn is_authenticated(&self) -> bool {
+ self.api_key.is_some()
+ }
+
+ pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+ if self.is_authenticated() {
+ Task::ready(Ok(()))
+ } else {
+ let api_url = self.api_url.clone();
+ cx.spawn(|mut cx| async move {
+ let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+ api_key
+ } else {
+ let (_, api_key) = cx
+ .update(|cx| cx.read_credentials(&api_url))?
+ .await?
+ .ok_or_else(|| anyhow!("credentials not found"))?;
+ String::from_utf8(api_key)?
+ };
+ cx.update_global::<CompletionProvider, _>(|provider, _cx| {
+ if let CompletionProvider::OpenAi(provider) = provider {
+ provider.api_key = Some(api_key);
+ }
+ })
+ })
+ }
+ }
+
+ pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+ let delete_credentials = cx.delete_credentials(&self.api_url);
+ cx.spawn(|mut cx| async move {
+ delete_credentials.await.log_err();
+ cx.update_global::<CompletionProvider, _>(|provider, _cx| {
+ if let CompletionProvider::OpenAi(provider) = provider {
+ provider.api_key = None;
+ }
+ })
+ })
+ }
+
+ pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
+ cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
+ .into()
+ }
+
+ pub fn default_model(&self) -> OpenAiModel {
+ self.default_model.clone()
+ }
+
+ pub fn count_tokens(
+ &self,
+ request: LanguageModelRequest,
+ cx: &AppContext,
+ ) -> BoxFuture<'static, Result<usize>> {
+ count_open_ai_tokens(request, cx.background_executor())
+ }
+
+ pub fn complete(
+ &self,
+ request: LanguageModelRequest,
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+ let request = self.to_open_ai_request(request);
+
+ let http_client = self.http_client.clone();
+ let api_key = self.api_key.clone();
+ let api_url = self.api_url.clone();
+ async move {
+ let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
+ let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
+ let response = request.await?;
+ let stream = response
+ .filter_map(|response| async move {
+ match response {
+ Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
+ Err(error) => Some(Err(error)),
+ }
+ })
+ .boxed();
+ Ok(stream)
+ }
+ .boxed()
+ }
+
+ fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
+ let model = match request.model {
+ LanguageModel::ZedDotDev(_) => self.default_model(),
+ LanguageModel::OpenAi(model) => model,
+ };
+
+ Request {
+ model,
+ messages: request
+ .messages
+ .into_iter()
+ .map(|msg| RequestMessage {
+ role: msg.role.into(),
+ content: msg.content,
+ })
+ .collect(),
+ stream: true,
+ stop: request.stop,
+ temperature: request.temperature,
+ }
+ }
+}
+
+pub fn count_open_ai_tokens(
+ request: LanguageModelRequest,
+ background_executor: &gpui::BackgroundExecutor,
+) -> BoxFuture<'static, Result<usize>> {
+ background_executor
+ .spawn(async move {
+ let messages = request
+ .messages
+ .into_iter()
+ .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
+ role: match message.role {
+ Role::User => "user".into(),
+ Role::Assistant => "assistant".into(),
+ Role::System => "system".into(),
+ },
+ content: Some(message.content),
+ name: None,
+ function_call: None,
+ })
+ .collect::<Vec<_>>();
+
+ tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages)
+ })
+ .boxed()
+}
+
+impl From<Role> for open_ai::Role {
+ fn from(val: Role) -> Self {
+ match val {
+ Role::User => OpenAiRole::User,
+ Role::Assistant => OpenAiRole::Assistant,
+ Role::System => OpenAiRole::System,
+ }
+ }
+}
+
+struct AuthenticationPrompt {
+ api_key: View<Editor>,
+ api_url: String,
+}
+
+impl AuthenticationPrompt {
+ fn new(api_url: String, cx: &mut WindowContext) -> Self {
+ Self {
+ api_key: cx.new_view(|cx| {
+ let mut editor = Editor::single_line(cx);
+ editor.set_placeholder_text(
+ "sk-000000000000000000000000000000000000000000000000",
+ cx,
+ );
+ editor
+ }),
+ api_url,
+ }
+ }
+
+ fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
+ let api_key = self.api_key.read(cx).text(cx);
+ if api_key.is_empty() {
+ return;
+ }
+
+ let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes());
+ cx.spawn(|_, mut cx| async move {
+ write_credentials.await?;
+ cx.update_global::<CompletionProvider, _>(|provider, _cx| {
+ if let CompletionProvider::OpenAi(provider) = provider {
+ provider.api_key = Some(api_key);
+ }
+ })
+ })
+ .detach_and_log_err(cx);
+ }
+
+ fn render_api_key_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+ let settings = ThemeSettings::get_global(cx);
+ let text_style = TextStyle {
+ color: cx.theme().colors().text,
+ font_family: settings.ui_font.family.clone(),
+ font_features: settings.ui_font.features,
+ font_size: rems(0.875).into(),
+ font_weight: FontWeight::NORMAL,
+ font_style: FontStyle::Normal,
+ line_height: relative(1.3),
+ background_color: None,
+ underline: None,
+ strikethrough: None,
+ white_space: WhiteSpace::Normal,
+ };
+ EditorElement::new(
+ &self.api_key,
+ EditorStyle {
+ background: cx.theme().colors().editor_background,
+ local_player: cx.theme().players().local(),
+ text: text_style,
+ ..Default::default()
+ },
+ )
+ }
+}
+
+impl Render for AuthenticationPrompt {
+ fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+ const INSTRUCTIONS: [&str; 6] = [
+ "To use the assistant panel or inline assistant, you need to add your OpenAI API key.",
+ " - You can create an API key at: platform.openai.com/api-keys",
+ " - Make sure your OpenAI account has credits",
+ " - Having a subscription for another service like GitHub Copilot won't work.",
+ "",
+ "Paste your OpenAI API key below and hit enter to use the assistant:",
+ ];
+
+ v_flex()
+ .p_4()
+ .size_full()
+ .on_action(cx.listener(Self::save_api_key))
+ .children(
+ INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)),
+ )
+ .child(
+ h_flex()
+ .w_full()
+ .my_2()
+ .px_2()
+ .py_1()
+ .bg(cx.theme().colors().editor_background)
+ .rounded_md()
+ .child(self.render_api_key_editor(cx)),
+ )
+ .child(
+ Label::new(
+ "You can also assign the OPENAI_API_KEY environment variable and restart Zed.",
+ )
+ .size(LabelSize::Small),
+ )
+ .child(
+ h_flex()
+ .gap_2()
+ .child(Label::new("Click on").size(LabelSize::Small))
+ .child(Icon::new(IconName::Ai).size(IconSize::XSmall))
+ .child(
+ Label::new("in the status bar to close this panel.").size(LabelSize::Small),
+ ),
+ )
+ .into_any()
+ }
+}
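With `stream: true` set above, the request body produced by `to_open_ai_request` serializes to roughly the following JSON (illustrative values; the exact field set depends on the `open_ai::Request` definition):

    {
      "model": "gpt-4-turbo-preview",
      "messages": [
        { "role": "system", "content": "You are an expert Rust engineer." },
        { "role": "user", "content": "Add error handling." }
      ],
      "stream": true,
      "stop": [],
      "temperature": 1.0
    }
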
@@ -0,0 +1,167 @@
+use crate::{
+ assistant_settings::ZedDotDevModel, count_open_ai_tokens, CompletionProvider,
+ LanguageModelRequest,
+};
+use anyhow::{anyhow, Result};
+use client::{proto, Client};
+use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
+use gpui::{AnyView, AppContext, Task};
+use std::{future, sync::Arc};
+use ui::prelude::*;
+
+pub struct ZedDotDevCompletionProvider {
+ client: Arc<Client>,
+ default_model: ZedDotDevModel,
+ settings_version: usize,
+ status: client::Status,
+ _maintain_client_status: Task<()>,
+}
+
+impl ZedDotDevCompletionProvider {
+ pub fn new(
+ default_model: ZedDotDevModel,
+ client: Arc<Client>,
+ settings_version: usize,
+ cx: &mut AppContext,
+ ) -> Self {
+ let mut status_rx = client.status();
+ let status = *status_rx.borrow();
+ let maintain_client_status = cx.spawn(|mut cx| async move {
+ while let Some(status) = status_rx.next().await {
+ let _ = cx.update_global::<CompletionProvider, _>(|provider, _cx| {
+ if let CompletionProvider::ZedDotDev(provider) = provider {
+ provider.status = status;
+ } else {
+ unreachable!()
+ }
+ });
+ }
+ });
+ Self {
+ client,
+ default_model,
+ settings_version,
+ status,
+ _maintain_client_status: maintain_client_status,
+ }
+ }
+
+ pub fn update(&mut self, default_model: ZedDotDevModel, settings_version: usize) {
+ self.default_model = default_model;
+ self.settings_version = settings_version;
+ }
+
+ pub fn settings_version(&self) -> usize {
+ self.settings_version
+ }
+
+ pub fn default_model(&self) -> ZedDotDevModel {
+ self.default_model.clone()
+ }
+
+ pub fn is_authenticated(&self) -> bool {
+ self.status.is_connected()
+ }
+
+ pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+ let client = self.client.clone();
+ cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
+ }
+
+ pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
+ cx.new_view(|_cx| AuthenticationPrompt).into()
+ }
+
+ pub fn count_tokens(
+ &self,
+ request: LanguageModelRequest,
+ cx: &AppContext,
+ ) -> BoxFuture<'static, Result<usize>> {
+ match request.model {
+ crate::LanguageModel::OpenAi(_) => future::ready(Err(anyhow!("invalid model"))).boxed(),
+ crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptFour)
+ | crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptFourTurbo)
+ | crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptThreePointFiveTurbo) => {
+ count_open_ai_tokens(request, cx.background_executor())
+ }
+ crate::LanguageModel::ZedDotDev(ZedDotDevModel::Custom(model)) => {
+ let request = self.client.request(proto::CountTokensWithLanguageModel {
+ model,
+ messages: request
+ .messages
+ .iter()
+ .map(|message| message.to_proto())
+ .collect(),
+ });
+ async move {
+ let response = request.await?;
+ Ok(response.token_count as usize)
+ }
+ .boxed()
+ }
+ }
+ }
+
+ pub fn complete(
+ &self,
+ request: LanguageModelRequest,
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+ let request = proto::CompleteWithLanguageModel {
+ model: request.model.id().to_string(),
+ messages: request
+ .messages
+ .iter()
+ .map(|message| message.to_proto())
+ .collect(),
+ stop: request.stop,
+ temperature: request.temperature,
+ };
+
+ self.client
+ .request_stream(request)
+ .map_ok(|stream| {
+ stream
+ .filter_map(|response| async move {
+ match response {
+ Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)),
+ Err(error) => Some(Err(error)),
+ }
+ })
+ .boxed()
+ })
+ .boxed()
+ }
+}
+
+struct AuthenticationPrompt;
+
+impl Render for AuthenticationPrompt {
+ fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
+ const LABEL: &str = "Generate and analyze code with language models. You can dialog with the assistant in this panel or transform code inline.";
+
+ v_flex().gap_6().p_4().child(Label::new(LABEL)).child(
+ v_flex()
+ .gap_2()
+ .child(
+ Button::new("sign_in", "Sign in")
+ .icon_color(Color::Muted)
+ .icon(IconName::Github)
+ .icon_position(IconPosition::Start)
+ .style(ButtonStyle::Filled)
+ .full_width()
+ .on_click(|_, cx| {
+ CompletionProvider::global(cx)
+ .authenticate(cx)
+ .detach_and_log_err(cx);
+ }),
+ )
+ .child(
+ div().flex().w_full().items_center().child(
+ Label::new("Sign in to enable collaboration.")
+ .color(Color::Muted)
+ .size(LabelSize::Small),
+ ),
+ ),
+ )
+ }
+}
@@ -1,394 +1,95 @@
-use ai::models::LanguageModel;
-use ai::prompts::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
-use ai::prompts::file_context::FileContext;
-use ai::prompts::generate::GenerateInlineContent;
-use ai::prompts::preamble::EngineerPreamble;
-use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext};
-use ai::providers::open_ai::OpenAiLanguageModel;
-use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
-use std::cmp::{self, Reverse};
-use std::ops::Range;
-use std::sync::Arc;
-
-#[allow(dead_code)]
-fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
- #[derive(Debug)]
- struct Match {
- collapse: Range<usize>,
- keep: Vec<Range<usize>>,
- }
-
- let selected_range = selected_range.to_offset(buffer);
- let mut ts_matches = buffer.matches(0..buffer.len(), |grammar| {
- Some(&grammar.embedding_config.as_ref()?.query)
- });
- let configs = ts_matches
- .grammars()
- .iter()
- .map(|g| g.embedding_config.as_ref().unwrap())
- .collect::<Vec<_>>();
- let mut matches = Vec::new();
- while let Some(mat) = ts_matches.peek() {
- let config = &configs[mat.grammar_index];
- if let Some(collapse) = mat.captures.iter().find_map(|cap| {
- if Some(cap.index) == config.collapse_capture_ix {
- Some(cap.node.byte_range())
- } else {
- None
- }
- }) {
- let mut keep = Vec::new();
- for capture in mat.captures.iter() {
- if Some(capture.index) == config.keep_capture_ix {
- keep.push(capture.node.byte_range());
- } else {
- continue;
- }
- }
- ts_matches.advance();
- matches.push(Match { collapse, keep });
- } else {
- ts_matches.advance();
- }
- }
- matches.sort_unstable_by_key(|mat| (mat.collapse.start, Reverse(mat.collapse.end)));
- let mut matches = matches.into_iter().peekable();
-
- let mut summary = String::new();
- let mut offset = 0;
- let mut flushed_selection = false;
- while let Some(mat) = matches.next() {
- // Keep extending the collapsed range if the next match surrounds
- // the current one.
- while let Some(next_mat) = matches.peek() {
- if mat.collapse.start <= next_mat.collapse.start
- && mat.collapse.end >= next_mat.collapse.end
- {
- matches.next().unwrap();
- } else {
- break;
- }
- }
-
- if offset > mat.collapse.start {
- // Skip collapsed nodes that have already been summarized.
- offset = cmp::max(offset, mat.collapse.end);
- continue;
- }
-
- if offset <= selected_range.start && selected_range.start <= mat.collapse.end {
- if !flushed_selection {
- // The collapsed node ends after the selection starts, so we'll flush the selection first.
- summary.extend(buffer.text_for_range(offset..selected_range.start));
- summary.push_str("<|S|");
- if selected_range.end == selected_range.start {
- summary.push_str(">");
- } else {
- summary.extend(buffer.text_for_range(selected_range.clone()));
- summary.push_str("|E|>");
- }
- offset = selected_range.end;
- flushed_selection = true;
- }
-
- // If the selection intersects the collapsed node, we won't collapse it.
- if selected_range.end >= mat.collapse.start {
- continue;
- }
- }
-
- summary.extend(buffer.text_for_range(offset..mat.collapse.start));
- for keep in mat.keep {
- summary.extend(buffer.text_for_range(keep));
- }
- offset = mat.collapse.end;
- }
-
- // Flush selection if we haven't already done so.
- if !flushed_selection && offset <= selected_range.start {
- summary.extend(buffer.text_for_range(offset..selected_range.start));
- summary.push_str("<|S|");
- if selected_range.end == selected_range.start {
- summary.push_str(">");
- } else {
- summary.extend(buffer.text_for_range(selected_range.clone()));
- summary.push_str("|E|>");
- }
- offset = selected_range.end;
- }
-
- summary.extend(buffer.text_for_range(offset..buffer.len()));
- summary
-}
+use language::BufferSnapshot;
+use std::{fmt::Write, ops::Range};
pub fn generate_content_prompt(
user_prompt: String,
language_name: Option<&str>,
buffer: BufferSnapshot,
range: Range<usize>,
- search_results: Vec<PromptCodeSnippet>,
- model: &str,
project_name: Option<String>,
) -> anyhow::Result<String> {
- // Using new Prompt Templates
- let openai_model: Arc<dyn LanguageModel> = Arc::new(OpenAiLanguageModel::load(model));
- let lang_name = if let Some(language_name) = language_name {
- Some(language_name.to_string())
- } else {
- None
- };
+ let mut prompt = String::new();
- let args = PromptArguments {
- model: openai_model,
- language_name: lang_name.clone(),
- project_name,
- snippets: search_results.clone(),
- reserved_tokens: 1000,
- buffer: Some(buffer),
- selected_range: Some(range),
- user_prompt: Some(user_prompt.clone()),
+ let content_type = match language_name {
+ None | Some("Markdown" | "Plain Text") => {
+ writeln!(prompt, "You are an expert engineer.")?;
+ "Text"
+ }
+ Some(language_name) => {
+ writeln!(prompt, "You are an expert {language_name} engineer.")?;
+ writeln!(
+ prompt,
+ "Your answer MUST always and only be valid {}.",
+ language_name
+ )?;
+ "Code"
+ }
};
- let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
- (PromptPriority::Mandatory, Box::new(EngineerPreamble {})),
- (
- PromptPriority::Ordered { order: 1 },
- Box::new(RepositoryContext {}),
- ),
- (
- PromptPriority::Ordered { order: 0 },
- Box::new(FileContext {}),
- ),
- (
- PromptPriority::Mandatory,
- Box::new(GenerateInlineContent {}),
- ),
- ];
- let chain = PromptChain::new(args, templates);
- let (prompt, _) = chain.generate(true)?;
-
- anyhow::Ok(prompt)
-}
+ if let Some(project_name) = project_name {
+ writeln!(
+ prompt,
+ "You are currently working inside the '{project_name}' project in code editor Zed."
+ )?;
+ }
-#[cfg(test)]
-pub(crate) mod tests {
- use super::*;
- use gpui::{AppContext, Context};
- use indoc::indoc;
- use language::{
- language_settings, tree_sitter_rust, Buffer, BufferId, Language, LanguageConfig,
- LanguageMatcher, Point,
- };
- use settings::SettingsStore;
- use std::sync::Arc;
+ // Include file content.
+ for chunk in buffer.text_for_range(0..range.start) {
+ prompt.push_str(chunk);
+ }
- pub(crate) fn rust_lang() -> Language {
- Language::new(
- LanguageConfig {
- name: "Rust".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["rs".to_string()],
- ..Default::default()
- },
- ..Default::default()
- },
- Some(tree_sitter_rust::language()),
- )
- .with_embedding_query(
- r#"
- (
- [(line_comment) (attribute_item)]* @context
- .
- [
- (struct_item
- name: (_) @name)
+ if range.is_empty() {
+ prompt.push_str("<|START|>");
+ } else {
+ prompt.push_str("<|START|");
+ }
- (enum_item
- name: (_) @name)
+ for chunk in buffer.text_for_range(range.clone()) {
+ prompt.push_str(chunk);
+ }
- (impl_item
- trait: (_)? @name
- "for"? @name
- type: (_) @name)
+ if !range.is_empty() {
+ prompt.push_str("|END|>");
+ }
- (trait_item
- name: (_) @name)
+ for chunk in buffer.text_for_range(range.end..buffer.len()) {
+ prompt.push_str(chunk);
+ }
- (function_item
- name: (_) @name
- body: (block
- "{" @keep
- "}" @keep) @collapse)
+ prompt.push('\n');
- (macro_definition
- name: (_) @name)
- ] @item
- )
- "#,
+ if range.is_empty() {
+ writeln!(
+ prompt,
+ "Assume the cursor is located where the `<|START|>` span is."
+ )
+ .unwrap();
+ writeln!(
+ prompt,
+ "{content_type} can't be replaced, so assume your answer will be inserted at the cursor.",
)
- .unwrap()
+ .unwrap();
+ writeln!(
+ prompt,
+ "Generate {content_type} based on the users prompt: {user_prompt}",
+ )
+ .unwrap();
+ } else {
+ writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap();
+ writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap();
+ writeln!(
+ prompt,
+ "Double check that you only return code and not the '<|START|' and '|END|'> spans"
+ )
+ .unwrap();
}
- #[gpui::test]
- fn test_outline_for_prompt(cx: &mut AppContext) {
- let settings_store = SettingsStore::test(cx);
- cx.set_global(settings_store);
- language_settings::init(cx);
- let text = indoc! {"
- struct X {
- a: usize,
- b: usize,
- }
-
- impl X {
-
- fn new() -> Self {
- let a = 1;
- let b = 2;
- Self { a, b }
- }
-
- pub fn a(&self, param: bool) -> usize {
- self.a
- }
-
- pub fn b(&self) -> usize {
- self.b
- }
- }
- "};
- let buffer = cx.new_model(|cx| {
- Buffer::new(0, BufferId::new(1).unwrap(), text).with_language(Arc::new(rust_lang()), cx)
- });
- let snapshot = buffer.read(cx).snapshot();
-
- assert_eq!(
- summarize(&snapshot, Point::new(1, 4)..Point::new(1, 4)),
- indoc! {"
- struct X {
- <|S|>a: usize,
- b: usize,
- }
-
- impl X {
-
- fn new() -> Self {}
-
- pub fn a(&self, param: bool) -> usize {}
-
- pub fn b(&self) -> usize {}
- }
- "}
- );
-
- assert_eq!(
- summarize(&snapshot, Point::new(8, 12)..Point::new(8, 14)),
- indoc! {"
- struct X {
- a: usize,
- b: usize,
- }
-
- impl X {
-
- fn new() -> Self {
- let <|S|a |E|>= 1;
- let b = 2;
- Self { a, b }
- }
-
- pub fn a(&self, param: bool) -> usize {}
-
- pub fn b(&self) -> usize {}
- }
- "}
- );
-
- assert_eq!(
- summarize(&snapshot, Point::new(6, 0)..Point::new(6, 0)),
- indoc! {"
- struct X {
- a: usize,
- b: usize,
- }
+ writeln!(prompt, "Never make remarks about the output.").unwrap();
+ writeln!(
+ prompt,
+ "Do not return anything else, except the generated {content_type}."
+ )
+ .unwrap();
- impl X {
- <|S|>
- fn new() -> Self {}
-
- pub fn a(&self, param: bool) -> usize {}
-
- pub fn b(&self) -> usize {}
- }
- "}
- );
-
- assert_eq!(
- summarize(&snapshot, Point::new(21, 0)..Point::new(21, 0)),
- indoc! {"
- struct X {
- a: usize,
- b: usize,
- }
-
- impl X {
-
- fn new() -> Self {}
-
- pub fn a(&self, param: bool) -> usize {}
-
- pub fn b(&self) -> usize {}
- }
- <|S|>"}
- );
-
- // Ensure nested functions get collapsed properly.
- let text = indoc! {"
- struct X {
- a: usize,
- b: usize,
- }
-
- impl X {
-
- fn new() -> Self {
- let a = 1;
- let b = 2;
- Self { a, b }
- }
-
- pub fn a(&self, param: bool) -> usize {
- let a = 30;
- fn nested() -> usize {
- 3
- }
- self.a + nested()
- }
-
- pub fn b(&self) -> usize {
- self.b
- }
- }
- "};
- buffer.update(cx, |buffer, cx| buffer.set_text(text, cx));
- let snapshot = buffer.read(cx).snapshot();
- assert_eq!(
- summarize(&snapshot, Point::new(0, 0)..Point::new(0, 0)),
- indoc! {"
- <|S|>struct X {
- a: usize,
- b: usize,
- }
-
- impl X {
-
- fn new() -> Self {}
-
- pub fn a(&self, param: bool) -> usize {}
-
- pub fn b(&self) -> usize {}
- }
- "}
- );
- }
+ Ok(prompt)
}
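For a Rust buffer with a non-empty selection, the prompt assembled above comes out roughly as follows (abbreviated; the project name is illustrative, and the file body with the marked selection sits between the preamble and the trailing instructions):

    You are an expert Rust engineer.
    Your answer MUST always and only be valid Rust.
    You are currently working inside the 'zed' project in code editor Zed.
    ...file contents, with <|START|selected text|END|> marking the selection...
    Modify the user's selected Code based on the user's prompt: '...'
    You must reply with only the adjusted Code (within the '<|START|' and '|END|>' spans) not the entire file.
    Double-check that you only return code and not the '<|START|' and '|END|>' spans
    Never make remarks about the output.
    Do not return anything else, except the generated Code.
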
@@ -0,0 +1,121 @@
+use crate::{assistant_settings::OpenAiModel, MessageId, MessageMetadata};
+use anyhow::{anyhow, Result};
+use collections::HashMap;
+use fs::Fs;
+use futures::StreamExt;
+use regex::Regex;
+use serde::{Deserialize, Serialize};
+use std::{
+ cmp::Reverse,
+ ffi::OsStr,
+ path::{Path, PathBuf},
+ sync::Arc,
+};
+use util::paths::CONVERSATIONS_DIR;
+
+#[derive(Serialize, Deserialize)]
+pub struct SavedMessage {
+ pub id: MessageId,
+ pub start: usize,
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct SavedConversation {
+ pub id: Option<String>,
+ pub zed: String,
+ pub version: String,
+ pub text: String,
+ pub messages: Vec<SavedMessage>,
+ pub message_metadata: HashMap<MessageId, MessageMetadata>,
+ pub summary: String,
+}
+
+impl SavedConversation {
+ pub const VERSION: &'static str = "0.2.0";
+
+ pub async fn load(path: &Path, fs: &dyn Fs) -> Result<Self> {
+ let saved_conversation = fs.load(path).await?;
+ let saved_conversation_json =
+ serde_json::from_str::<serde_json::Value>(&saved_conversation)?;
+ match saved_conversation_json
+ .get("version")
+ .ok_or_else(|| anyhow!("version not found"))?
+ {
+ serde_json::Value::String(version) => match version.as_str() {
+ Self::VERSION => Ok(serde_json::from_value::<Self>(saved_conversation_json)?),
+ "0.1.0" => {
+ let saved_conversation =
+ serde_json::from_value::<SavedConversationV0_1_0>(saved_conversation_json)?;
+ Ok(Self {
+ id: saved_conversation.id,
+ zed: saved_conversation.zed,
+ version: saved_conversation.version,
+ text: saved_conversation.text,
+ messages: saved_conversation.messages,
+ message_metadata: saved_conversation.message_metadata,
+ summary: saved_conversation.summary,
+ })
+ }
+ _ => Err(anyhow!(
+ "unrecognized saved conversation version: {}",
+ version
+ )),
+ },
+ _ => Err(anyhow!("version not found on saved conversation")),
+ }
+ }
+}
+
+#[derive(Serialize, Deserialize)]
+struct SavedConversationV0_1_0 {
+ id: Option<String>,
+ zed: String,
+ version: String,
+ text: String,
+ messages: Vec<SavedMessage>,
+ message_metadata: HashMap<MessageId, MessageMetadata>,
+ summary: String,
+ api_url: Option<String>,
+ model: OpenAiModel,
+}
+
+pub struct SavedConversationMetadata {
+ pub title: String,
+ pub path: PathBuf,
+ pub mtime: chrono::DateTime<chrono::Local>,
+}
+
+impl SavedConversationMetadata {
+ pub async fn list(fs: Arc<dyn Fs>) -> Result<Vec<Self>> {
+ fs.create_dir(&CONVERSATIONS_DIR).await?;
+
+ let mut paths = fs.read_dir(&CONVERSATIONS_DIR).await?;
+ let mut conversations = Vec::<SavedConversationMetadata>::new();
+ while let Some(path) = paths.next().await {
+ let path = path?;
+ if path.extension() != Some(OsStr::new("json")) {
+ continue;
+ }
+
+ let pattern = r" - \d+.zed.json$";
+ let re = Regex::new(pattern).unwrap();
+
+ let metadata = fs.metadata(&path).await?;
+ if let Some((file_name, metadata)) = path
+ .file_name()
+ .and_then(|name| name.to_str())
+ .zip(metadata)
+ {
+ let title = re.replace(file_name, "");
+ conversations.push(Self {
+ title: title.into_owned(),
+ path,
+ mtime: metadata.mtime.into(),
+ });
+ }
+ }
+ conversations.sort_unstable_by_key(|conversation| Reverse(conversation.mtime));
+
+ Ok(conversations)
+ }
+}
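A brief sketch of the intended call pattern (assuming an `Arc<dyn Fs>` named `fs` is at hand): list conversations — already sorted newest-first by `mtime` — then load one by path.

    let conversations = SavedConversationMetadata::list(fs.clone()).await?;
    if let Some(most_recent) = conversations.first() {
        // `load` transparently migrates 0.1.0 files to the current 0.2.0 shape.
        let conversation = SavedConversation::load(&most_recent.path, fs.as_ref()).await?;
        // ...restore the conversation into the panel...
    }
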
@@ -197,12 +197,10 @@ impl StreamingDiff {
} else {
hunks.push(Hunk::Remove { len: char_len })
}
+ } else if let Some(Hunk::Keep { len }) = hunks.last_mut() {
+ *len += char_len;
} else {
- if let Some(Hunk::Keep { len }) = hunks.last_mut() {
- *len += char_len;
- } else {
- hunks.push(Hunk::Keep { len: char_len })
- }
+ hunks.push(Hunk::Keep { len: char_len })
}
}
@@ -13,7 +13,7 @@ use async_tungstenite::tungstenite::{
use clock::SystemClock;
use collections::HashMap;
use futures::{
- channel::oneshot, future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt,
+ channel::oneshot, future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt,
TryFutureExt as _, TryStreamExt,
};
use gpui::{
@@ -36,7 +36,10 @@ use std::{
future::Future,
marker::PhantomData,
path::PathBuf,
- sync::{atomic::AtomicU64, Arc, Weak},
+ sync::{
+ atomic::{AtomicU64, Ordering},
+ Arc, Weak,
+ },
time::{Duration, Instant},
};
use telemetry::Telemetry;
@@ -442,7 +445,7 @@ impl Client {
}
pub fn id(&self) -> u64 {
- self.id.load(std::sync::atomic::Ordering::SeqCst)
+ self.id.load(Ordering::SeqCst)
}
pub fn http_client(&self) -> Arc<HttpClientWithUrl> {
@@ -450,7 +453,7 @@ impl Client {
}
pub fn set_id(&self, id: u64) -> &Self {
- self.id.store(id, std::sync::atomic::Ordering::SeqCst);
+ self.id.store(id, Ordering::SeqCst);
self
}
@@ -1260,6 +1263,30 @@ impl Client {
.map_ok(|envelope| envelope.payload)
}
+ pub fn request_stream<T: RequestMessage>(
+ &self,
+ request: T,
+ ) -> impl Future<Output = Result<impl Stream<Item = Result<T::Response>>>> {
+ let client_id = self.id.load(Ordering::SeqCst);
+ log::debug!(
+ "rpc request start. client_id:{}. name:{}",
+ client_id,
+ T::NAME
+ );
+ let response = self
+ .connection_id()
+ .map(|conn_id| self.peer.request_stream(conn_id, request));
+ async move {
+ let response = response?.await;
+ log::debug!(
+ "rpc request finish. client_id:{}. name:{}",
+ client_id,
+ T::NAME
+ );
+ response
+ }
+ }
+
pub fn request_envelope<T: RequestMessage>(
&self,
request: T,
@@ -261,7 +261,7 @@ impl Telemetry {
self: &Arc<Self>,
conversation_id: Option<String>,
kind: AssistantKind,
- model: &str,
+ model: String,
) {
let event = Event::Assistant(AssistantEvent {
conversation_id,
@@ -31,10 +31,12 @@ collections.workspace = true
dashmap = "5.4"
envy = "0.4.2"
futures.workspace = true
+google_ai.workspace = true
hex.workspace = true
live_kit_server.workspace = true
log.workspace = true
nanoid = "0.4"
+open_ai.workspace = true
parking_lot.workspace = true
prometheus = "0.13"
prost.workspace = true
@@ -80,7 +82,6 @@ git = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, features = ["test-support"] }
indoc.workspace = true
language = { workspace = true, features = ["test-support"] }
-lazy_static.workspace = true
live_kit_client = { workspace = true, features = ["test-support"] }
lsp = { workspace = true, features = ["test-support"] }
menu.workspace = true
@@ -379,6 +379,16 @@ CREATE TABLE extension_versions (
CREATE UNIQUE INDEX "index_extensions_external_id" ON "extensions" ("external_id");
CREATE INDEX "index_extensions_total_download_count" ON "extensions" ("total_download_count");
+CREATE TABLE rate_buckets (
+ user_id INT NOT NULL,
+ rate_limit_name VARCHAR(255) NOT NULL,
+ token_count INT NOT NULL,
+ last_refill TIMESTAMP WITHOUT TIME ZONE NOT NULL,
+ PRIMARY KEY (user_id, rate_limit_name),
+ FOREIGN KEY (user_id) REFERENCES users(id)
+);
+CREATE INDEX idx_user_id_rate_limit ON rate_buckets (user_id, rate_limit_name);
+
CREATE TABLE hosted_projects (
id INTEGER PRIMARY KEY AUTOINCREMENT,
channel_id INTEGER NOT NULL REFERENCES channels(id),
@@ -0,0 +1,11 @@
+CREATE TABLE IF NOT EXISTS rate_buckets (
+ user_id INT NOT NULL,
+ rate_limit_name VARCHAR(255) NOT NULL,
+ token_count INT NOT NULL,
+ last_refill TIMESTAMP WITHOUT TIME ZONE NOT NULL,
+ PRIMARY KEY (user_id, rate_limit_name),
+ CONSTRAINT fk_user
+ FOREIGN KEY (user_id) REFERENCES users(id)
+);
+
+CREATE INDEX idx_user_id_rate_limit ON rate_buckets (user_id, rate_limit_name);
@@ -0,0 +1,75 @@
+use anyhow::{anyhow, Result};
+use rpc::proto;
+
+pub fn language_model_request_to_open_ai(
+ request: proto::CompleteWithLanguageModel,
+) -> Result<open_ai::Request> {
+ Ok(open_ai::Request {
+ model: open_ai::Model::from_id(&request.model).unwrap_or(open_ai::Model::FourTurbo),
+ messages: request
+ .messages
+ .into_iter()
+ .map(|message| {
+ let role = proto::LanguageModelRole::from_i32(message.role)
+ .ok_or_else(|| anyhow!("invalid role {}", message.role))?;
+ Ok(open_ai::RequestMessage {
+ role: match role {
+ proto::LanguageModelRole::LanguageModelUser => open_ai::Role::User,
+ proto::LanguageModelRole::LanguageModelAssistant => {
+ open_ai::Role::Assistant
+ }
+ proto::LanguageModelRole::LanguageModelSystem => open_ai::Role::System,
+ },
+ content: message.content,
+ })
+ })
+ .collect::<Result<Vec<open_ai::RequestMessage>>>()?,
+ stream: true,
+ stop: request.stop,
+ temperature: request.temperature,
+ })
+}
+
+pub fn language_model_request_to_google_ai(
+ request: proto::CompleteWithLanguageModel,
+) -> Result<google_ai::GenerateContentRequest> {
+ Ok(google_ai::GenerateContentRequest {
+ contents: request
+ .messages
+ .into_iter()
+ .map(language_model_request_message_to_google_ai)
+ .collect::<Result<Vec<_>>>()?,
+ generation_config: None,
+ safety_settings: None,
+ })
+}
+
+pub fn language_model_request_message_to_google_ai(
+ message: proto::LanguageModelRequestMessage,
+) -> Result<google_ai::Content> {
+ let role = proto::LanguageModelRole::from_i32(message.role)
+ .ok_or_else(|| anyhow!("invalid role {}", message.role))?;
+
+ Ok(google_ai::Content {
+ parts: vec![google_ai::Part::TextPart(google_ai::TextPart {
+ text: message.content,
+ })],
+ role: match role {
+ proto::LanguageModelRole::LanguageModelUser => google_ai::Role::User,
+ proto::LanguageModelRole::LanguageModelAssistant => google_ai::Role::Model,
+ proto::LanguageModelRole::LanguageModelSystem => google_ai::Role::User,
+ },
+ })
+}
+
+pub fn count_tokens_request_to_google_ai(
+ request: proto::CountTokensWithLanguageModel,
+) -> Result<google_ai::CountTokensRequest> {
+ Ok(google_ai::CountTokensRequest {
+ contents: request
+ .messages
+ .into_iter()
+ .map(language_model_request_message_to_google_ai)
+ .collect::<Result<Vec<_>>>()?,
+ })
+}
@@ -1,6 +1,5 @@
use crate::{
db::{ExtensionMetadata, NewExtensionVersion},
- executor::Executor,
AppState, Error, Result,
};
use anyhow::{anyhow, Context as _};
@@ -136,7 +135,7 @@ async fn download_extension(
const EXTENSION_FETCH_INTERVAL: Duration = Duration::from_secs(5 * 60);
const EXTENSION_DOWNLOAD_URL_LIFETIME: Duration = Duration::from_secs(3 * 60);
-pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>, executor: Executor) {
+pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>) {
let Some(blob_store_client) = app_state.blob_store_client.clone() else {
log::info!("no blob store client");
return;
@@ -146,6 +145,7 @@ pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc<AppState>, executor: Executor) {
return;
};
+ let executor = app_state.executor.clone();
executor.spawn_detached({
let executor = executor.clone();
async move {
@@ -10,6 +10,7 @@ pub mod hosted_projects;
pub mod messages;
pub mod notifications;
pub mod projects;
+pub mod rate_buckets;
pub mod rooms;
pub mod servers;
pub mod users;
@@ -0,0 +1,58 @@
+use super::*;
+use crate::db::tables::rate_buckets;
+use sea_orm::{ColumnTrait, EntityTrait, QueryFilter};
+
+impl Database {
+    /// Saves the given rate limit buckets, upserting the token count and last
+    /// refill time for each `(user_id, rate_limit_name)` pair.
+ pub async fn save_rate_buckets(&self, buckets: &[rate_buckets::Model]) -> Result<()> {
+ if buckets.is_empty() {
+ return Ok(());
+ }
+
+ self.transaction(|tx| async move {
+ rate_buckets::Entity::insert_many(buckets.iter().map(|bucket| {
+ rate_buckets::ActiveModel {
+ user_id: ActiveValue::Set(bucket.user_id),
+ rate_limit_name: ActiveValue::Set(bucket.rate_limit_name.clone()),
+ token_count: ActiveValue::Set(bucket.token_count),
+ last_refill: ActiveValue::Set(bucket.last_refill),
+ }
+ }))
+ .on_conflict(
+ OnConflict::columns([
+ rate_buckets::Column::UserId,
+ rate_buckets::Column::RateLimitName,
+ ])
+ .update_columns([
+ rate_buckets::Column::TokenCount,
+ rate_buckets::Column::LastRefill,
+ ])
+ .to_owned(),
+ )
+ .exec(&*tx)
+ .await?;
+
+ Ok(())
+ })
+ .await
+ }
+
+ /// Retrieves the rate limit for the given user and rate limit name.
+ pub async fn get_rate_bucket(
+ &self,
+ user_id: UserId,
+ rate_limit_name: &str,
+ ) -> Result<Option<rate_buckets::Model>> {
+ self.transaction(|tx| async move {
+ let rate_limit = rate_buckets::Entity::find()
+ .filter(rate_buckets::Column::UserId.eq(user_id))
+ .filter(rate_buckets::Column::RateLimitName.eq(rate_limit_name))
+ .one(&*tx)
+ .await?;
+
+ Ok(rate_limit)
+ })
+ .await
+ }
+}
@@ -22,6 +22,7 @@ pub mod observed_buffer_edits;
pub mod observed_channel_messages;
pub mod project;
pub mod project_collaborator;
+pub mod rate_buckets;
pub mod room;
pub mod room_participant;
pub mod server;
@@ -0,0 +1,31 @@
+use crate::db::UserId;
+use sea_orm::entity::prelude::*;
+
+#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
+#[sea_orm(table_name = "rate_buckets")]
+pub struct Model {
+ #[sea_orm(primary_key, auto_increment = false)]
+ pub user_id: UserId,
+ #[sea_orm(primary_key, auto_increment = false)]
+ pub rate_limit_name: String,
+ pub token_count: i32,
+ pub last_refill: DateTime,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {
+ #[sea_orm(
+ belongs_to = "super::user::Entity",
+ from = "Column::UserId",
+ to = "super::user::Column::Id"
+ )]
+ User,
+}
+
+impl Related<super::user::Entity> for Entity {
+ fn to() -> RelationDef {
+ Relation::User.def()
+ }
+}
+
+impl ActiveModelBehavior for ActiveModel {}
@@ -1,8 +1,10 @@
+pub mod ai;
pub mod api;
pub mod auth;
pub mod db;
pub mod env;
pub mod executor;
+mod rate_limiter;
pub mod rpc;
#[cfg(test)]
@@ -13,6 +15,7 @@ use aws_config::{BehaviorVersion, Region};
use axum::{http::StatusCode, response::IntoResponse};
use db::{ChannelId, Database};
use executor::Executor;
+pub use rate_limiter::*;
use serde::Deserialize;
use std::{path::PathBuf, sync::Arc};
use util::ResultExt;
@@ -126,6 +129,8 @@ pub struct Config {
pub blob_store_secret_key: Option<String>,
pub blob_store_bucket: Option<String>,
pub zed_environment: Arc<str>,
+ pub openai_api_key: Option<Arc<str>>,
+ pub google_ai_api_key: Option<Arc<str>>,
pub zed_client_checksum_seed: Option<String>,
pub slack_panics_webhook: Option<String>,
pub auto_join_channel_id: Option<ChannelId>,
@@ -147,12 +152,14 @@ pub struct AppState {
pub db: Arc<Database>,
pub live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
pub blob_store_client: Option<aws_sdk_s3::Client>,
+ pub rate_limiter: Arc<RateLimiter>,
+ pub executor: Executor,
pub clickhouse_client: Option<clickhouse::Client>,
pub config: Config,
}
impl AppState {
- pub async fn new(config: Config) -> Result<Arc<Self>> {
+ pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
let mut db_options = db::ConnectOptions::new(config.database_url.clone());
db_options.max_connections(config.database_max_connections);
let mut db = Database::new(db_options, Executor::Production).await?;
@@ -173,10 +180,13 @@ impl AppState {
None
};
+ let db = Arc::new(db);
let this = Self {
- db: Arc::new(db),
+ db: db.clone(),
live_kit_client,
blob_store_client: build_blob_store_client(&config).await.log_err(),
+ rate_limiter: Arc::new(RateLimiter::new(db)),
+ executor,
clickhouse_client: config
.clickhouse_url
.as_ref()
@@ -7,7 +7,7 @@ use axum::{
};
use collab::{
api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor, AppState,
- Config, MigrateConfig, Result,
+ Config, MigrateConfig, RateLimiter, Result,
};
use db::Database;
use std::{
@@ -62,18 +62,27 @@ async fn main() -> Result<()> {
run_migrations().await?;
- let state = AppState::new(config).await?;
+ let state = AppState::new(config, Executor::Production).await?;
let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port))
.expect("failed to bind TCP listener");
+
+ RateLimiter::save_periodically(state.rate_limiter.clone(), state.executor.clone());
+
let rpc_server = if is_collab {
let epoch = state
.db
.create_server(&state.config.zed_environment)
.await?;
- let rpc_server =
- collab::rpc::Server::new(epoch, state.clone(), Executor::Production);
+ let rpc_server = collab::rpc::Server::new(epoch, state.clone());
rpc_server.start().await?;
Some(rpc_server)
@@ -82,7 +91,7 @@ async fn main() -> Result<()> {
};
if is_api {
- fetch_extensions_from_blob_store_periodically(state.clone(), Executor::Production);
+ fetch_extensions_from_blob_store_periodically(state.clone());
}
let mut app = collab::api::routes(rpc_server.clone(), state.clone());
@@ -0,0 +1,274 @@
+use crate::{db::UserId, executor::Executor, Database, Error, Result};
+use anyhow::anyhow;
+use chrono::{DateTime, Duration, Utc};
+use dashmap::{DashMap, DashSet};
+use sea_orm::prelude::DateTimeUtc;
+use std::sync::Arc;
+use util::ResultExt;
+
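+/// Describes one named token bucket: `capacity` tokens that refill evenly
+/// over `refill_duration`, persisted in the database under `db_name`.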
+pub trait RateLimit: 'static {
+ fn capacity() -> usize;
+ fn refill_duration() -> Duration;
+ fn db_name() -> &'static str;
+}
+
+/// Used to enforce per-user rate limits
+pub struct RateLimiter {
+ buckets: DashMap<(UserId, String), RateBucket>,
+ dirty_buckets: DashSet<(UserId, String)>,
+ db: Arc<Database>,
+}
+
+impl RateLimiter {
+ pub fn new(db: Arc<Database>) -> Self {
+ RateLimiter {
+ buckets: DashMap::new(),
+ dirty_buckets: DashSet::new(),
+ db,
+ }
+ }
+
+ /// Spawns a new task that periodically saves rate limit data to the database.
+ pub fn save_periodically(rate_limiter: Arc<Self>, executor: Executor) {
+ const RATE_LIMITER_SAVE_INTERVAL: std::time::Duration = std::time::Duration::from_secs(10);
+
+ executor.clone().spawn_detached(async move {
+ loop {
+ executor.sleep(RATE_LIMITER_SAVE_INTERVAL).await;
+ rate_limiter.save().await.log_err();
+ }
+ });
+ }
+
+ /// Returns an error if the user has exceeded the specified `RateLimit`.
+    /// Attempts to read the bucket from the database if no cached `RateBucket` currently exists.
+ pub async fn check<T: RateLimit>(&self, user_id: UserId) -> Result<()> {
+ self.check_internal::<T>(user_id, Utc::now()).await
+ }
+
+ async fn check_internal<T: RateLimit>(&self, user_id: UserId, now: DateTimeUtc) -> Result<()> {
+ let bucket_key = (user_id, T::db_name().to_string());
+
+        // Attempt to fetch the bucket from the database if it hasn't been cached.
+        // Buckets stay in memory for the lifetime of the process rather than expiring,
+        // and loading them from the database on first use enforces limits across restarts.
+ if !self.buckets.contains_key(&bucket_key) {
+ if let Some(bucket) = self.load_bucket::<T>(user_id).await.log_err().flatten() {
+ self.buckets.insert(bucket_key.clone(), bucket);
+ self.dirty_buckets.insert(bucket_key.clone());
+ }
+ }
+
+ let mut bucket = self
+ .buckets
+ .entry(bucket_key.clone())
+ .or_insert_with(|| RateBucket::new(T::capacity(), T::refill_duration(), now));
+
+ if bucket.value_mut().allow(now) {
+ self.dirty_buckets.insert(bucket_key);
+ Ok(())
+ } else {
+ Err(anyhow!("rate limit exceeded"))?
+ }
+ }
+
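+    /// Loads the persisted bucket for this user and limit, combining the stored
+    /// token count and refill time with the compile-time parameters of `K`.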
+ async fn load_bucket<K: RateLimit>(
+ &self,
+ user_id: UserId,
+ ) -> Result<Option<RateBucket>, Error> {
+ Ok(self
+ .db
+ .get_rate_bucket(user_id, K::db_name())
+ .await?
+ .map(|saved_bucket| RateBucket {
+ capacity: K::capacity(),
+ refill_time_per_token: K::refill_duration(),
+ token_count: saved_bucket.token_count as usize,
+ last_refill: DateTime::from_naive_utc_and_offset(saved_bucket.last_refill, Utc),
+ }))
+ }
+
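+    /// Writes all dirty buckets to the database, marking them dirty again on
+    /// failure so that the next save retries them.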
+ pub async fn save(&self) -> Result<()> {
+ let mut buckets = Vec::new();
+ self.dirty_buckets.retain(|key| {
+ if let Some(bucket) = self.buckets.get(&key) {
+ buckets.push(crate::db::rate_buckets::Model {
+ user_id: key.0,
+ rate_limit_name: key.1.clone(),
+ token_count: bucket.token_count as i32,
+ last_refill: bucket.last_refill.naive_utc(),
+ });
+ }
+ false
+ });
+
+ match self.db.save_rate_buckets(&buckets).await {
+ Ok(()) => Ok(()),
+ Err(err) => {
+ for bucket in buckets {
+ self.dirty_buckets
+ .insert((bucket.user_id, bucket.rate_limit_name));
+ }
+ Err(err)
+ }
+ }
+ }
+}
+
+#[derive(Clone)]
+struct RateBucket {
+ capacity: usize,
+ token_count: usize,
+ refill_time_per_token: Duration,
+ last_refill: DateTimeUtc,
+}
+
+impl RateBucket {
+ fn new(capacity: usize, refill_duration: Duration, now: DateTimeUtc) -> Self {
+ RateBucket {
+ capacity,
+ token_count: capacity,
+ refill_time_per_token: refill_duration / capacity as i32,
+ last_refill: now,
+ }
+ }
+
+ fn allow(&mut self, now: DateTimeUtc) -> bool {
+ self.refill(now);
+ if self.token_count > 0 {
+ self.token_count -= 1;
+ true
+ } else {
+ false
+ }
+ }
+
+ fn refill(&mut self, now: DateTimeUtc) {
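+        // Add back however many whole tokens have accrued since the last
+        // refill, without letting the bucket exceed its capacity.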
+ let elapsed = now - self.last_refill;
+ if elapsed >= self.refill_time_per_token {
+ let new_tokens =
+ elapsed.num_milliseconds() / self.refill_time_per_token.num_milliseconds();
+
+ self.token_count = (self.token_count + new_tokens as usize).min(self.capacity);
+ self.last_refill = now;
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::db::{NewUserParams, TestDb};
+ use gpui::TestAppContext;
+
+ #[gpui::test]
+ async fn test_rate_limiter(cx: &mut TestAppContext) {
+ let test_db = TestDb::sqlite(cx.executor().clone());
+ let db = test_db.db().clone();
+ let user_1 = db
+ .create_user(
+ "user-1@zed.dev",
+ false,
+ NewUserParams {
+ github_login: "user-1".into(),
+ github_user_id: 1,
+ },
+ )
+ .await
+ .unwrap()
+ .user_id;
+ let user_2 = db
+ .create_user(
+ "user-2@zed.dev",
+ false,
+ NewUserParams {
+ github_login: "user-2".into(),
+ github_user_id: 2,
+ },
+ )
+ .await
+ .unwrap()
+ .user_id;
+
+ let mut now = Utc::now();
+
+ let rate_limiter = RateLimiter::new(db.clone());
+
+ // User 1 can access resource A two times before being rate-limited.
+ rate_limiter
+ .check_internal::<RateLimitA>(user_1, now)
+ .await
+ .unwrap();
+ rate_limiter
+ .check_internal::<RateLimitA>(user_1, now)
+ .await
+ .unwrap();
+ rate_limiter
+ .check_internal::<RateLimitA>(user_1, now)
+ .await
+ .unwrap_err();
+
+        // User 2 can access resource B, and so can user 1, despite being rate-limited on A.
+ rate_limiter
+ .check_internal::<RateLimitB>(user_2, now)
+ .await
+ .unwrap();
+ rate_limiter
+ .check_internal::<RateLimitB>(user_1, now)
+ .await
+ .unwrap();
+
+ // After one second, user 1 can make another request before being rate-limited again.
+ now += Duration::seconds(1);
+ rate_limiter
+ .check_internal::<RateLimitA>(user_1, now)
+ .await
+ .unwrap();
+ rate_limiter
+ .check_internal::<RateLimitA>(user_1, now)
+ .await
+ .unwrap_err();
+
+ rate_limiter.save().await.unwrap();
+
+        // Rate limits are reloaded from the database, so user 1 is still rate-limited
+        // for resource A.
+ let rate_limiter = RateLimiter::new(db.clone());
+ rate_limiter
+ .check_internal::<RateLimitA>(user_1, now)
+ .await
+ .unwrap_err();
+ }
+
+ struct RateLimitA;
+
+ impl RateLimit for RateLimitA {
+ fn capacity() -> usize {
+ 2
+ }
+
+ fn refill_duration() -> Duration {
+ Duration::seconds(2)
+ }
+
+ fn db_name() -> &'static str {
+ "rate-limit-a"
+ }
+ }
+
+ struct RateLimitB;
+
+ impl RateLimit for RateLimitB {
+ fn capacity() -> usize {
+ 10
+ }
+
+ fn refill_duration() -> Duration {
+ Duration::seconds(3)
+ }
+
+ fn db_name() -> &'static str {
+ "rate-limit-b"
+ }
+ }
+}
@@ -9,9 +9,9 @@ use crate::{
User, UserId,
},
executor::Executor,
- AppState, Error, Result,
+ AppState, Error, RateLimit, RateLimiter, Result,
};
-use anyhow::anyhow;
+use anyhow::{anyhow, Context as _};
use async_tungstenite::tungstenite::{
protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage,
};
@@ -30,6 +30,8 @@ use axum::{
};
use collections::{HashMap, HashSet};
pub use connection_pool::{ConnectionPool, ZedVersion};
+use core::fmt::{self, Debug, Formatter};
+
use futures::{
channel::oneshot,
future::{self, BoxFuture},
@@ -39,15 +41,14 @@ use futures::{
use prometheus::{register_int_gauge, IntGauge};
use rpc::{
proto::{
- self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
- RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
+ self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole,
+ LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
},
Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope,
};
use serde::{Serialize, Serializer};
use std::{
any::TypeId,
- fmt,
future::Future,
marker::PhantomData,
mem,
@@ -64,7 +65,7 @@ use time::OffsetDateTime;
use tokio::sync::{watch, Semaphore};
use tower::ServiceBuilder;
use tracing::{field, info_span, instrument, Instrument};
-use util::SemanticVersion;
+use util::{http::IsahcHttpClient, SemanticVersion};
pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30);
@@ -92,6 +93,18 @@ impl<R: RequestMessage> Response<R> {
}
}
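+/// Like `Response`, but may send multiple payloads for a single request before
+/// the server closes the stream.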
+struct StreamingResponse<R: RequestMessage> {
+ peer: Arc<Peer>,
+ receipt: Receipt<R>,
+}
+
+impl<R: RequestMessage> StreamingResponse<R> {
+ fn send(&self, payload: R::Response) -> Result<()> {
+ self.peer.respond(self.receipt, payload)?;
+ Ok(())
+ }
+}
+
#[derive(Clone)]
struct Session {
user_id: UserId,
@@ -100,6 +113,8 @@ struct Session {
peer: Arc<Peer>,
connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
+ http_client: IsahcHttpClient,
+ rate_limiter: Arc<RateLimiter>,
_executor: Executor,
}
@@ -124,8 +139,8 @@ impl Session {
}
}
-impl fmt::Debug for Session {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+impl Debug for Session {
+ fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_struct("Session")
.field("user_id", &self.user_id)
.field("connection_id", &self.connection_id)
@@ -148,7 +163,6 @@ pub struct Server {
peer: Arc<Peer>,
pub(crate) connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
app_state: Arc<AppState>,
- executor: Executor,
handlers: HashMap<TypeId, MessageHandler>,
teardown: watch::Sender<bool>,
}
@@ -175,12 +189,11 @@ where
}
impl Server {
- pub fn new(id: ServerId, app_state: Arc<AppState>, executor: Executor) -> Arc<Self> {
+ pub fn new(id: ServerId, app_state: Arc<AppState>) -> Arc<Self> {
let mut server = Self {
id: parking_lot::Mutex::new(id),
peer: Peer::new(id.0 as u32),
- app_state,
- executor,
+ app_state: app_state.clone(),
connection_pool: Default::default(),
handlers: Default::default(),
teardown: watch::channel(false).0,
@@ -280,7 +293,30 @@ impl Server {
.add_message_handler(update_followers)
.add_request_handler(get_private_user_info)
.add_message_handler(acknowledge_channel_message)
- .add_message_handler(acknowledge_buffer_version);
+ .add_message_handler(acknowledge_buffer_version)
+ .add_streaming_request_handler({
+ let app_state = app_state.clone();
+ move |request, response, session| {
+ complete_with_language_model(
+ request,
+ response,
+ session,
+ app_state.config.openai_api_key.clone(),
+ app_state.config.google_ai_api_key.clone(),
+ )
+ }
+ })
+ .add_request_handler({
+ let app_state = app_state.clone();
+ move |request, response, session| {
+ count_tokens_with_language_model(
+ request,
+ response,
+ session,
+ app_state.config.google_ai_api_key.clone(),
+ )
+ }
+ });
Arc::new(server)
}
@@ -289,12 +325,12 @@ impl Server {
let server_id = *self.id.lock();
let app_state = self.app_state.clone();
let peer = self.peer.clone();
- let timeout = self.executor.sleep(CLEANUP_TIMEOUT);
+ let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT);
let pool = self.connection_pool.clone();
let live_kit_client = self.app_state.live_kit_client.clone();
let span = info_span!("start server");
- self.executor.spawn_detached(
+ self.app_state.executor.spawn_detached(
async move {
tracing::info!("waiting for cleanup timeout");
timeout.await;
@@ -536,6 +572,40 @@ impl Server {
})
}
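+    /// Registers a handler that can send multiple response payloads. On success
+    /// the stream is terminated with an `EndStream` message; on failure, an
+    /// error response is sent instead.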
+ fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
+ where
+ F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
+ Fut: Send + Future<Output = Result<()>>,
+ M: RequestMessage,
+ {
+ let handler = Arc::new(handler);
+ self.add_handler(move |envelope, session| {
+ let receipt = envelope.receipt();
+ let handler = handler.clone();
+ async move {
+ let peer = session.peer.clone();
+ let response = StreamingResponse {
+ peer: peer.clone(),
+ receipt,
+ };
+ match (handler)(envelope.payload, response, session).await {
+ Ok(()) => {
+ peer.end_stream(receipt)?;
+ Ok(())
+ }
+ Err(error) => {
+ let proto_err = match &error {
+ Error::Internal(err) => err.to_proto(),
+ _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
+ };
+ peer.respond_with_error(receipt, proto_err)?;
+ Err(error)
+ }
+ }
+ }
+ })
+ }
+
#[allow(clippy::too_many_arguments)]
pub fn handle_connection(
self: &Arc<Self>,
@@ -569,6 +639,14 @@ impl Server {
tracing::Span::current().record("connection_id", format!("{}", connection_id));
tracing::info!("connection opened");
+ let http_client = match IsahcHttpClient::new() {
+ Ok(http_client) => http_client,
+ Err(error) => {
+ tracing::error!(?error, "failed to create HTTP client");
+ return;
+ }
+ };
+
let session = Session {
user_id,
connection_id,
@@ -576,7 +654,9 @@ impl Server {
peer: this.peer.clone(),
connection_pool: this.connection_pool.clone(),
live_kit_client: this.app_state.live_kit_client.clone(),
- _executor: executor.clone()
+ http_client,
+ rate_limiter: this.app_state.rate_limiter.clone(),
+ _executor: executor.clone(),
};
if let Err(error) = this.send_initial_client_update(connection_id, user, zed_version, send_connection_id, &session).await {
@@ -3220,6 +3300,207 @@ async fn acknowledge_buffer_version(
Ok(())
}
+struct CompleteWithLanguageModelRateLimit;
+
+impl RateLimit for CompleteWithLanguageModelRateLimit {
+ fn capacity() -> usize {
+ std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
+ .ok()
+ .and_then(|v| v.parse().ok())
+ .unwrap_or(120) // Picked arbitrarily
+ }
+
+ fn refill_duration() -> chrono::Duration {
+ chrono::Duration::hours(1)
+ }
+
+ fn db_name() -> &'static str {
+ "complete-with-language-model"
+ }
+}
+
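+/// Routes the completion to OpenAI or Google AI based on the model name's
+/// prefix ("gpt" or "gemini"); any other model name completes with no output.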
+async fn complete_with_language_model(
+ request: proto::CompleteWithLanguageModel,
+ response: StreamingResponse<proto::CompleteWithLanguageModel>,
+ session: Session,
+ open_ai_api_key: Option<Arc<str>>,
+ google_ai_api_key: Option<Arc<str>>,
+) -> Result<()> {
+ authorize_access_to_language_models(&session).await?;
+ session
+ .rate_limiter
+ .check::<CompleteWithLanguageModelRateLimit>(session.user_id)
+ .await?;
+
+ if request.model.starts_with("gpt") {
+ let api_key =
+ open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
+ complete_with_open_ai(request, response, session, api_key).await?;
+ } else if request.model.starts_with("gemini") {
+ let api_key = google_ai_api_key
+ .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
+ complete_with_google_ai(request, response, session, api_key).await?;
+ }
+
+ Ok(())
+}
+
+async fn complete_with_open_ai(
+ request: proto::CompleteWithLanguageModel,
+ response: StreamingResponse<proto::CompleteWithLanguageModel>,
+ session: Session,
+ api_key: Arc<str>,
+) -> Result<()> {
+ const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
+
+ let mut completion_stream = open_ai::stream_completion(
+ &session.http_client,
+ OPEN_AI_API_URL,
+ &api_key,
+ crate::ai::language_model_request_to_open_ai(request)?,
+ )
+ .await
+ .context("open_ai::stream_completion request failed")?;
+
+ while let Some(event) = completion_stream.next().await {
+ let event = event?;
+ response.send(proto::LanguageModelResponse {
+ choices: event
+ .choices
+ .into_iter()
+ .map(|choice| proto::LanguageModelChoiceDelta {
+ index: choice.index,
+ delta: Some(proto::LanguageModelResponseMessage {
+ role: choice.delta.role.map(|role| match role {
+ open_ai::Role::User => LanguageModelRole::LanguageModelUser,
+ open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant,
+ open_ai::Role::System => LanguageModelRole::LanguageModelSystem,
+ } as i32),
+ content: choice.delta.content,
+ }),
+ finish_reason: choice.finish_reason,
+ })
+ .collect(),
+ })?;
+ }
+
+ Ok(())
+}
+
+async fn complete_with_google_ai(
+ request: proto::CompleteWithLanguageModel,
+ response: StreamingResponse<proto::CompleteWithLanguageModel>,
+ session: Session,
+ api_key: Arc<str>,
+) -> Result<()> {
+ let mut stream = google_ai::stream_generate_content(
+ &session.http_client,
+ google_ai::API_URL,
+ api_key.as_ref(),
+ crate::ai::language_model_request_to_google_ai(request)?,
+ )
+ .await
+ .context("google_ai::stream_generate_content request failed")?;
+
+ while let Some(event) = stream.next().await {
+ let event = event?;
+ response.send(proto::LanguageModelResponse {
+ choices: event
+ .candidates
+ .unwrap_or_default()
+ .into_iter()
+ .map(|candidate| proto::LanguageModelChoiceDelta {
+ index: candidate.index as u32,
+ delta: Some(proto::LanguageModelResponseMessage {
+ role: Some(match candidate.content.role {
+ google_ai::Role::User => LanguageModelRole::LanguageModelUser,
+ google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant,
+ } as i32),
+ content: Some(
+ candidate
+ .content
+ .parts
+ .into_iter()
+ .filter_map(|part| match part {
+ google_ai::Part::TextPart(part) => Some(part.text),
+ google_ai::Part::InlineDataPart(_) => None,
+ })
+ .collect(),
+ ),
+ }),
+ finish_reason: candidate.finish_reason.map(|reason| reason.to_string()),
+ })
+ .collect(),
+ })?;
+ }
+
+ Ok(())
+}
+
+struct CountTokensWithLanguageModelRateLimit;
+
+impl RateLimit for CountTokensWithLanguageModelRateLimit {
+ fn capacity() -> usize {
+ std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
+ .ok()
+ .and_then(|v| v.parse().ok())
+ .unwrap_or(600) // Picked arbitrarily
+ }
+
+ fn refill_duration() -> chrono::Duration {
+ chrono::Duration::hours(1)
+ }
+
+ fn db_name() -> &'static str {
+ "count-tokens-with-language-model"
+ }
+}
+
+async fn count_tokens_with_language_model(
+ request: proto::CountTokensWithLanguageModel,
+ response: Response<proto::CountTokensWithLanguageModel>,
+ session: Session,
+ google_ai_api_key: Option<Arc<str>>,
+) -> Result<()> {
+ authorize_access_to_language_models(&session).await?;
+
+ if !request.model.starts_with("gemini") {
+ return Err(anyhow!(
+ "counting tokens for model: {:?} is not supported",
+ request.model
+ ))?;
+ }
+
+ session
+ .rate_limiter
+ .check::<CountTokensWithLanguageModelRateLimit>(session.user_id)
+ .await?;
+
+ let api_key = google_ai_api_key
+ .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
+ let tokens_response = google_ai::count_tokens(
+ &session.http_client,
+ google_ai::API_URL,
+ &api_key,
+ crate::ai::count_tokens_request_to_google_ai(request)?,
+ )
+ .await?;
+ response.send(proto::CountTokensResponse {
+ token_count: tokens_response.total_tokens as u32,
+ })?;
+ Ok(())
+}
+
+async fn authorize_access_to_language_models(session: &Session) -> Result<(), Error> {
+ let db = session.db().await;
+ let flags = db.get_user_flags(session.user_id).await?;
+ if flags.iter().any(|flag| flag == "language-models") {
+ Ok(())
+ } else {
+ Err(anyhow!("permission denied"))?
+ }
+}
+
/// Start receiving chat updates for a channel
async fn join_channel_chat(
request: proto::JoinChannelChat,
@@ -2,7 +2,7 @@ use crate::{
db::{tests::TestDb, NewUserParams, UserId},
executor::Executor,
rpc::{Server, ZedVersion, CLEANUP_TIMEOUT, RECONNECT_TIMEOUT},
- AppState, Config,
+ AppState, Config, RateLimiter,
};
use anyhow::anyhow;
use call::ActiveCall;
@@ -93,17 +93,14 @@ impl TestServer {
deterministic.clone(),
)
.unwrap();
- let app_state = Self::build_app_state(&test_db, &live_kit_server).await;
+ let executor = Executor::Deterministic(deterministic.clone());
+ let app_state = Self::build_app_state(&test_db, &live_kit_server, executor.clone()).await;
let epoch = app_state
.db
.create_server(&app_state.config.zed_environment)
.await
.unwrap();
- let server = Server::new(
- epoch,
- app_state.clone(),
- Executor::Deterministic(deterministic.clone()),
- );
+ let server = Server::new(epoch, app_state.clone());
server.start().await.unwrap();
// Advance clock to ensure the server's cleanup task is finished.
deterministic.advance_clock(CLEANUP_TIMEOUT);
@@ -482,12 +479,15 @@ impl TestServer {
pub async fn build_app_state(
test_db: &TestDb,
- fake_server: &live_kit_client::TestServer,
+ live_kit_test_server: &live_kit_client::TestServer,
+ executor: Executor,
) -> Arc<AppState> {
Arc::new(AppState {
db: test_db.db().clone(),
- live_kit_client: Some(Arc::new(fake_server.create_api_client())),
+ live_kit_client: Some(Arc::new(live_kit_test_server.create_api_client())),
blob_store_client: None,
+ rate_limiter: Arc::new(RateLimiter::new(test_db.db().clone())),
+ executor,
clickhouse_client: None,
config: Config {
http_port: 0,
@@ -506,6 +506,8 @@ impl TestServer {
blob_store_access_key: None,
blob_store_secret_key: None,
blob_store_bucket: None,
+ openai_api_key: None,
+ google_ai_api_key: None,
clickhouse_url: None,
clickhouse_user: None,
clickhouse_password: None,
@@ -0,0 +1,14 @@
+[package]
+name = "google_ai"
+version = "0.1.0"
+edition = "2021"
+
+[lib]
+path = "src/google_ai.rs"
+
+[dependencies]
+anyhow.workspace = true
+futures.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+util.workspace = true
@@ -0,0 +1,266 @@
+use anyhow::{anyhow, Result};
+use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
+use serde::{Deserialize, Serialize};
+use util::http::HttpClient;
+
+pub const API_URL: &str = "https://generativelanguage.googleapis.com";
+
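+/// Streams a generateContent response as server-sent events, yielding one
+/// parsed `GenerateContentResponse` per `data:` line.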
+pub async fn stream_generate_content<T: HttpClient>(
+ client: &T,
+ api_url: &str,
+ api_key: &str,
+ request: GenerateContentRequest,
+) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
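+    // Note that the request path currently pins the model to gemini-pro.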
+ let uri = format!(
+ "{}/v1beta/models/gemini-pro:streamGenerateContent?alt=sse&key={}",
+ api_url, api_key
+ );
+
+ let request = serde_json::to_string(&request)?;
+ let mut response = client.post_json(&uri, request.into()).await?;
+ if response.status().is_success() {
+ let reader = BufReader::new(response.into_body());
+ Ok(reader
+ .lines()
+ .filter_map(|line| async move {
+ match line {
+ Ok(line) => {
+ if let Some(line) = line.strip_prefix("data: ") {
+ match serde_json::from_str(line) {
+ Ok(response) => Some(Ok(response)),
+ Err(error) => Some(Err(anyhow!(error))),
+ }
+ } else {
+ None
+ }
+ }
+ Err(error) => Some(Err(anyhow!(error))),
+ }
+ })
+ .boxed())
+ } else {
+ let mut text = String::new();
+ response.body_mut().read_to_string(&mut text).await?;
+ Err(anyhow!(
+ "error during streamGenerateContent, status code: {:?}, body: {}",
+ response.status(),
+ text
+ ))
+ }
+}
+
+pub async fn count_tokens<T: HttpClient>(
+ client: &T,
+ api_url: &str,
+ api_key: &str,
+ request: CountTokensRequest,
+) -> Result<CountTokensResponse> {
+ let uri = format!(
+ "{}/v1beta/models/gemini-pro:countTokens?key={}",
+ api_url, api_key
+ );
+ let request = serde_json::to_string(&request)?;
+ let mut response = client.post_json(&uri, request.into()).await?;
+ let mut text = String::new();
+ response.body_mut().read_to_string(&mut text).await?;
+ if response.status().is_success() {
+ Ok(serde_json::from_str::<CountTokensResponse>(&text)?)
+ } else {
+ Err(anyhow!(
+ "error during countTokens, status code: {:?}, body: {}",
+ response.status(),
+ text
+ ))
+ }
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub enum Task {
+ #[serde(rename = "generateContent")]
+ GenerateContent,
+ #[serde(rename = "streamGenerateContent")]
+ StreamGenerateContent,
+ #[serde(rename = "countTokens")]
+ CountTokens,
+ #[serde(rename = "embedContent")]
+ EmbedContent,
+ #[serde(rename = "batchEmbedContents")]
+ BatchEmbedContents,
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct GenerateContentRequest {
+ pub contents: Vec<Content>,
+ pub generation_config: Option<GenerationConfig>,
+ pub safety_settings: Option<Vec<SafetySetting>>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct GenerateContentResponse {
+ pub candidates: Option<Vec<GenerateContentCandidate>>,
+ pub prompt_feedback: Option<PromptFeedback>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct GenerateContentCandidate {
+ pub index: usize,
+ pub content: Content,
+ pub finish_reason: Option<String>,
+ pub finish_message: Option<String>,
+ pub safety_ratings: Option<Vec<SafetyRating>>,
+ pub citation_metadata: Option<CitationMetadata>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct Content {
+ pub parts: Vec<Part>,
+ pub role: Role,
+}
+
+#[derive(Debug, Deserialize, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub enum Role {
+ User,
+ Model,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(untagged)]
+pub enum Part {
+ TextPart(TextPart),
+ InlineDataPart(InlineDataPart),
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct TextPart {
+ pub text: String,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct InlineDataPart {
+ pub inline_data: GenerativeContentBlob,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct GenerativeContentBlob {
+ pub mime_type: String,
+ pub data: String,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CitationSource {
+ pub start_index: Option<usize>,
+ pub end_index: Option<usize>,
+ pub uri: Option<String>,
+ pub license: Option<String>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CitationMetadata {
+ pub citation_sources: Vec<CitationSource>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PromptFeedback {
+ pub block_reason: Option<String>,
+ pub safety_ratings: Vec<SafetyRating>,
+ pub block_reason_message: Option<String>,
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct GenerationConfig {
+ pub candidate_count: Option<usize>,
+ pub stop_sequences: Option<Vec<String>>,
+ pub max_output_tokens: Option<usize>,
+ pub temperature: Option<f64>,
+ pub top_p: Option<f64>,
+ pub top_k: Option<usize>,
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SafetySetting {
+ pub category: HarmCategory,
+ pub threshold: HarmBlockThreshold,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub enum HarmCategory {
+ #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
+ Unspecified,
+ #[serde(rename = "HARM_CATEGORY_DEROGATORY")]
+ Derogatory,
+ #[serde(rename = "HARM_CATEGORY_TOXICITY")]
+ Toxicity,
+ #[serde(rename = "HARM_CATEGORY_VIOLENCE")]
+ Violence,
+ #[serde(rename = "HARM_CATEGORY_SEXUAL")]
+ Sexual,
+ #[serde(rename = "HARM_CATEGORY_MEDICAL")]
+ Medical,
+ #[serde(rename = "HARM_CATEGORY_DANGEROUS")]
+ Dangerous,
+ #[serde(rename = "HARM_CATEGORY_HARASSMENT")]
+ Harassment,
+ #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
+ HateSpeech,
+ #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
+ SexuallyExplicit,
+ #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
+ DangerousContent,
+}
+
+#[derive(Debug, Serialize)]
+pub enum HarmBlockThreshold {
+ #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
+ Unspecified,
+ #[serde(rename = "BLOCK_LOW_AND_ABOVE")]
+ BlockLowAndAbove,
+ #[serde(rename = "BLOCK_MEDIUM_AND_ABOVE")]
+ BlockMediumAndAbove,
+ #[serde(rename = "BLOCK_ONLY_HIGH")]
+ BlockOnlyHigh,
+ #[serde(rename = "BLOCK_NONE")]
+ BlockNone,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
+pub enum HarmProbability {
+ #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")]
+ Unspecified,
+ Negligible,
+ Low,
+ Medium,
+ High,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SafetyRating {
+ pub category: HarmCategory,
+ pub probability: HarmProbability,
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CountTokensRequest {
+ pub contents: Vec<Content>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CountTokensResponse {
+ pub total_tokens: usize,
+}
@@ -0,0 +1,19 @@
+[package]
+name = "open_ai"
+version = "0.1.0"
+edition = "2021"
+
+[lib]
+path = "src/open_ai.rs"
+
+[features]
+default = []
+schemars = ["dep:schemars"]
+
+[dependencies]
+anyhow.workspace = true
+futures.workspace = true
+schemars = { workspace = true, optional = true }
+serde.workspace = true
+serde_json.workspace = true
+util.workspace = true
@@ -0,0 +1,182 @@
+use anyhow::{anyhow, Result};
+use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
+use serde::{Deserialize, Serialize};
+use std::convert::TryFrom;
+use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
+
+#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+ User,
+ Assistant,
+ System,
+}
+
+impl TryFrom<String> for Role {
+ type Error = anyhow::Error;
+
+ fn try_from(value: String) -> Result<Self> {
+ match value.as_str() {
+ "user" => Ok(Self::User),
+ "assistant" => Ok(Self::Assistant),
+ "system" => Ok(Self::System),
+ _ => Err(anyhow!("invalid role '{value}'")),
+ }
+ }
+}
+
+impl From<Role> for String {
+ fn from(val: Role) -> Self {
+ match val {
+ Role::User => "user".to_owned(),
+ Role::Assistant => "assistant".to_owned(),
+ Role::System => "system".to_owned(),
+ }
+ }
+}
+
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
+#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
+pub enum Model {
+ #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo-0613")]
+ ThreePointFiveTurbo,
+ #[serde(rename = "gpt-4", alias = "gpt-4-0613")]
+ Four,
+ #[serde(rename = "gpt-4-turbo-preview", alias = "gpt-4-1106-preview")]
+ #[default]
+ FourTurbo,
+}
+
+impl Model {
+ pub fn from_id(id: &str) -> Result<Self> {
+ match id {
+ "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo),
+ "gpt-4" => Ok(Self::Four),
+ "gpt-4-turbo-preview" => Ok(Self::FourTurbo),
+ _ => Err(anyhow!("invalid model id")),
+ }
+ }
+
+ pub fn id(&self) -> &'static str {
+ match self {
+ Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
+ Self::Four => "gpt-4",
+ Self::FourTurbo => "gpt-4-turbo-preview",
+ }
+ }
+
+ pub fn display_name(&self) -> &'static str {
+ match self {
+ Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
+ Self::Four => "gpt-4",
+ Self::FourTurbo => "gpt-4-turbo",
+ }
+ }
+}
+
+#[derive(Debug, Serialize)]
+pub struct Request {
+ pub model: Model,
+ pub messages: Vec<RequestMessage>,
+ pub stream: bool,
+ pub stop: Vec<String>,
+ pub temperature: f32,
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct RequestMessage {
+ pub role: Role,
+ pub content: String,
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct ResponseMessage {
+ pub role: Option<Role>,
+ pub content: Option<String>,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct Usage {
+ pub prompt_tokens: u32,
+ pub completion_tokens: u32,
+ pub total_tokens: u32,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct ChoiceDelta {
+ pub index: u32,
+ pub delta: ResponseMessage,
+ pub finish_reason: Option<String>,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct ResponseStreamEvent {
+ pub created: u32,
+ pub model: String,
+ pub choices: Vec<ChoiceDelta>,
+ pub usage: Option<Usage>,
+}
+
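+/// Streams chat completions as server-sent events, ending the stream when the
+/// "[DONE]" sentinel arrives.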
+pub async fn stream_completion(
+ client: &dyn HttpClient,
+ api_url: &str,
+ api_key: &str,
+ request: Request,
+) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
+ let uri = format!("{api_url}/chat/completions");
+ let request = HttpRequest::builder()
+ .method(Method::POST)
+ .uri(uri)
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {}", api_key))
+ .body(AsyncBody::from(serde_json::to_string(&request)?))?;
+ let mut response = client.send(request).await?;
+ if response.status().is_success() {
+ let reader = BufReader::new(response.into_body());
+ Ok(reader
+ .lines()
+ .filter_map(|line| async move {
+ match line {
+ Ok(line) => {
+ let line = line.strip_prefix("data: ")?;
+ if line == "[DONE]" {
+ None
+ } else {
+ match serde_json::from_str(line) {
+ Ok(response) => Some(Ok(response)),
+ Err(error) => Some(Err(anyhow!(error))),
+ }
+ }
+ }
+ Err(error) => Some(Err(anyhow!(error))),
+ }
+ })
+ .boxed())
+ } else {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+
+ #[derive(Deserialize)]
+ struct OpenAiResponse {
+ error: OpenAiError,
+ }
+
+ #[derive(Deserialize)]
+ struct OpenAiError {
+ message: String,
+ }
+
+ match serde_json::from_str::<OpenAiResponse>(&body) {
+ Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
+ "Failed to connect to OpenAI API: {}",
+ response.error.message,
+ )),
+
+ _ => Err(anyhow!(
+ "Failed to connect to OpenAI API: {} {}",
+ response.status(),
+ body,
+ )),
+ }
+ }
+}
@@ -1,7 +1,7 @@
syntax = "proto3";
package zed.messages;
-// Looking for a number? Search "// Current max"
+// Looking for a number? Search "// current max"
message PeerId {
uint32 owner_id = 1;
@@ -26,6 +26,7 @@ message Envelope {
Error error = 6;
Ping ping = 7;
Test test = 8;
+ EndStream end_stream = 165;
CreateRoom create_room = 9;
CreateRoomResponse create_room_response = 10;
@@ -198,6 +199,11 @@ message Envelope {
GetImplementationResponse get_implementation_response = 163;
JoinHostedProject join_hosted_project = 164;
+
+ CompleteWithLanguageModel complete_with_language_model = 166;
+ LanguageModelResponse language_model_response = 167;
+ CountTokensWithLanguageModel count_tokens_with_language_model = 168;
+ CountTokensResponse count_tokens_response = 169; // current max
}
reserved 158 to 161;
@@ -236,6 +242,8 @@ enum ErrorCode {
reserved 6;
}
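+// Sent by a server in place of a final response to terminate a streaming RPC.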
+message EndStream {}
+
message Test {
uint64 id = 1;
}
@@ -1718,3 +1726,45 @@ message SetRoomParticipantRole {
uint64 user_id = 2;
ChannelRole role = 3;
}
+
+message CompleteWithLanguageModel {
+ string model = 1;
+ repeated LanguageModelRequestMessage messages = 2;
+ repeated string stop = 3;
+ float temperature = 4;
+}
+
+message LanguageModelRequestMessage {
+ LanguageModelRole role = 1;
+ string content = 2;
+}
+
+enum LanguageModelRole {
+ LanguageModelUser = 0;
+ LanguageModelAssistant = 1;
+ LanguageModelSystem = 2;
+}
+
+message LanguageModelResponseMessage {
+ optional LanguageModelRole role = 1;
+ optional string content = 2;
+}
+
+message LanguageModelResponse {
+ repeated LanguageModelChoiceDelta choices = 1;
+}
+
+message LanguageModelChoiceDelta {
+ uint32 index = 1;
+ LanguageModelResponseMessage delta = 2;
+ optional string finish_reason = 3;
+}
+
+message CountTokensWithLanguageModel {
+ string model = 1;
+ repeated LanguageModelRequestMessage messages = 2;
+}
+
+message CountTokensResponse {
+ uint32 token_count = 1;
+}
@@ -80,7 +80,7 @@ pub trait ErrorExt {
fn error_tag(&self, k: &str) -> Option<&str>;
/// to_proto() converts the error into a proto::Error
fn to_proto(&self) -> proto::Error;
- ///
+    /// Clones the error and turns it into an [anyhow::Error].
fn cloned(&self) -> anyhow::Error;
}
@@ -9,19 +9,21 @@ use collections::HashMap;
use futures::{
channel::{mpsc, oneshot},
stream::BoxStream,
- FutureExt, SinkExt, StreamExt, TryFutureExt,
+ FutureExt, SinkExt, Stream, StreamExt, TryFutureExt,
};
use parking_lot::{Mutex, RwLock};
use serde::{ser::SerializeStruct, Serialize};
-use std::{fmt, sync::atomic::Ordering::SeqCst, time::Instant};
use std::{
+ fmt, future,
future::Future,
marker::PhantomData,
+ sync::atomic::Ordering::SeqCst,
sync::{
atomic::{self, AtomicU32},
Arc,
},
time::Duration,
+ time::Instant,
};
use tracing::instrument;
@@ -118,6 +120,15 @@ pub struct ConnectionState {
>,
>,
>,
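+    // Maps outstanding streaming-request ids to senders that receive every
+    // response envelope until an EndStream arrives or the connection closes.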
+ #[allow(clippy::type_complexity)]
+ #[serde(skip)]
+ stream_response_channels: Arc<
+ Mutex<
+ Option<
+ HashMap<u32, mpsc::UnboundedSender<(Result<proto::Envelope>, oneshot::Sender<()>)>>,
+ >,
+ >,
+ >,
}
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
@@ -171,17 +182,28 @@ impl Peer {
outgoing_tx,
next_message_id: Default::default(),
response_channels: Arc::new(Mutex::new(Some(Default::default()))),
+ stream_response_channels: Arc::new(Mutex::new(Some(Default::default()))),
};
let mut writer = MessageStream::new(connection.tx);
let mut reader = MessageStream::new(connection.rx);
let this = self.clone();
let response_channels = connection_state.response_channels.clone();
+ let stream_response_channels = connection_state.stream_response_channels.clone();
+
let handle_io = async move {
tracing::trace!(%connection_id, "handle io future: start");
let _end_connection = util::defer(|| {
response_channels.lock().take();
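+            // Notify any in-flight streaming requests that the connection closed.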
+ if let Some(channels) = stream_response_channels.lock().take() {
+ for channel in channels.values() {
+ let _ = channel.unbounded_send((
+ Err(anyhow!("connection closed")),
+ oneshot::channel().0,
+ ));
+ }
+ }
this.connections.write().remove(&connection_id);
tracing::trace!(%connection_id, "handle io future: end");
});
@@ -273,12 +295,14 @@ impl Peer {
};
let response_channels = connection_state.response_channels.clone();
+ let stream_response_channels = connection_state.stream_response_channels.clone();
self.connections
.write()
.insert(connection_id, connection_state);
let incoming_rx = incoming_rx.filter_map(move |(incoming, received_at)| {
let response_channels = response_channels.clone();
+ let stream_response_channels = stream_response_channels.clone();
async move {
let message_id = incoming.id;
tracing::trace!(?incoming, "incoming message future: start");
@@ -293,8 +317,15 @@ impl Peer {
responding_to,
"incoming response: received"
);
- let channel = response_channels.lock().as_mut()?.remove(&responding_to);
- if let Some(tx) = channel {
+ let response_channel =
+ response_channels.lock().as_mut()?.remove(&responding_to);
+ let stream_response_channel = stream_response_channels
+ .lock()
+ .as_ref()?
+ .get(&responding_to)
+ .cloned();
+
+ if let Some(tx) = response_channel {
let requester_resumed = oneshot::channel();
if let Err(error) = tx.send((incoming, received_at, requester_resumed.0)) {
tracing::trace!(
@@ -319,6 +350,31 @@ impl Peer {
responding_to,
"incoming response: requester resumed"
);
+ } else if let Some(tx) = stream_response_channel {
+ let requester_resumed = oneshot::channel();
+ if let Err(error) = tx.unbounded_send((Ok(incoming), requester_resumed.0)) {
+ tracing::debug!(
+ %connection_id,
+ message_id,
+ responding_to = responding_to,
+ ?error,
+ "incoming stream response: request future dropped",
+ );
+ }
+
+ tracing::debug!(
+ %connection_id,
+ message_id,
+ responding_to,
+ "incoming stream response: waiting to resume requester"
+ );
+ let _ = requester_resumed.1.await;
+ tracing::debug!(
+ %connection_id,
+ message_id,
+ responding_to,
+ "incoming stream response: requester resumed"
+ );
} else {
let message_type =
proto::build_typed_envelope(connection_id, received_at, incoming)
@@ -451,6 +507,66 @@ impl Peer {
}
}
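+    /// Sends a request and returns a stream of responses. The stream ends when
+    /// the remote peer sends an `EndStream` message or the connection closes.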
+ pub fn request_stream<T: RequestMessage>(
+ &self,
+ receiver_id: ConnectionId,
+ request: T,
+ ) -> impl Future<Output = Result<impl Unpin + Stream<Item = Result<T::Response>>>> {
+ let (tx, rx) = mpsc::unbounded();
+ let send = self.connection_state(receiver_id).and_then(|connection| {
+ let message_id = connection.next_message_id.fetch_add(1, SeqCst);
+ let stream_response_channels = connection.stream_response_channels.clone();
+ stream_response_channels
+ .lock()
+ .as_mut()
+ .ok_or_else(|| anyhow!("connection was closed"))?
+ .insert(message_id, tx);
+ connection
+ .outgoing_tx
+ .unbounded_send(proto::Message::Envelope(
+ request.into_envelope(message_id, None, None),
+ ))
+ .map_err(|_| anyhow!("connection was closed"))?;
+ Ok((message_id, stream_response_channels))
+ });
+
+ async move {
+ let (message_id, stream_response_channels) = send?;
+ let stream_response_channels = Arc::downgrade(&stream_response_channels);
+
+ Ok(rx.filter_map(move |(response, _barrier)| {
+ let stream_response_channels = stream_response_channels.clone();
+ future::ready(match response {
+ Ok(response) => {
+ if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
+ Some(Err(anyhow!(
+ "RPC request {} failed - {}",
+ T::NAME,
+ error.message
+ )))
+ } else if let Some(proto::envelope::Payload::EndStream(_)) =
+ &response.payload
+ {
+ // Remove the transmitting end of the response channel to end the stream.
+ if let Some(channels) = stream_response_channels.upgrade() {
+ if let Some(channels) = channels.lock().as_mut() {
+ channels.remove(&message_id);
+ }
+ }
+ None
+ } else {
+ Some(
+ T::Response::from_envelope(response)
+ .ok_or_else(|| anyhow!("received response of the wrong type")),
+ )
+ }
+ }
+ Err(error) => Some(Err(error)),
+ })
+ }))
+ }
+ }
+
pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
let connection = self.connection_state(receiver_id)?;
let message_id = connection
@@ -503,6 +619,24 @@ impl Peer {
Ok(())
}
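+    /// Responds to a streaming request with an `EndStream` message, signalling
+    /// the requester that no further payloads will arrive.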
+ pub fn end_stream<T: RequestMessage>(&self, receipt: Receipt<T>) -> Result<()> {
+ let connection = self.connection_state(receipt.sender_id)?;
+ let message_id = connection
+ .next_message_id
+ .fetch_add(1, atomic::Ordering::SeqCst);
+
+ let message = proto::EndStream {};
+
+ connection
+ .outgoing_tx
+ .unbounded_send(proto::Message::Envelope(message.into_envelope(
+ message_id,
+ Some(receipt.message_id),
+ None,
+ )))?;
+ Ok(())
+ }
+
pub fn respond_with_error<T: RequestMessage>(
&self,
receipt: Receipt<T>,
@@ -149,7 +149,10 @@ messages!(
(CallCanceled, Foreground),
(CancelCall, Foreground),
(ChannelMessageSent, Foreground),
+ (CompleteWithLanguageModel, Background),
(CopyProjectEntry, Foreground),
+ (CountTokensWithLanguageModel, Background),
+ (CountTokensResponse, Background),
(CreateBufferForPeer, Foreground),
(CreateChannel, Foreground),
(CreateChannelResponse, Foreground),
@@ -160,6 +163,7 @@ messages!(
(DeleteChannel, Foreground),
(DeleteNotification, Foreground),
(DeleteProjectEntry, Foreground),
+ (EndStream, Foreground),
(Error, Foreground),
(ExpandProjectEntry, Foreground),
(ExpandProjectEntryResponse, Foreground),
@@ -211,6 +215,7 @@ messages!(
(JoinProjectResponse, Foreground),
(JoinRoom, Foreground),
(JoinRoomResponse, Foreground),
+ (LanguageModelResponse, Background),
(LeaveChannelBuffer, Background),
(LeaveChannelChat, Foreground),
(LeaveProject, Foreground),
@@ -300,6 +305,8 @@ request_messages!(
(Call, Ack),
(CancelCall, Ack),
(CopyProjectEntry, ProjectEntryResponse),
+ (CompleteWithLanguageModel, LanguageModelResponse),
+ (CountTokensWithLanguageModel, CountTokensResponse),
(CreateChannel, CreateChannelResponse),
(CreateProjectEntry, ProjectEntryResponse),
(CreateRoom, CreateRoomResponse),
@@ -22,7 +22,6 @@ gpui.workspace = true
language.workspace = true
menu.workspace = true
project.workspace = true
-semantic_index.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
@@ -705,11 +705,6 @@ impl BufferSearchBar {
option.as_button(is_active, action)
}
pub fn activate_search_mode(&mut self, mode: SearchMode, cx: &mut ViewContext<Self>) {
- assert_ne!(
- mode,
- SearchMode::Semantic,
- "Semantic search is not supported in buffer search"
- );
if mode == self.current_mode {
return;
}
@@ -1022,7 +1017,7 @@ impl BufferSearchBar {
}
}
fn cycle_mode(&mut self, _: &CycleMode, cx: &mut ViewContext<Self>) {
- self.activate_search_mode(next_mode(&self.current_mode, false), cx);
+ self.activate_search_mode(next_mode(&self.current_mode), cx);
}
fn toggle_replace(&mut self, _: &ToggleReplace, cx: &mut ViewContext<Self>) {
if let Some(_) = &self.active_searchable_item {
@@ -1,13 +1,12 @@
use gpui::{Action, SharedString};
-use crate::{ActivateRegexMode, ActivateSemanticMode, ActivateTextMode};
+use crate::{ActivateRegexMode, ActivateTextMode};
// TODO: Update the default search mode to get from config
#[derive(Copy, Clone, Debug, Default, PartialEq)]
pub enum SearchMode {
#[default]
Text,
- Semantic,
Regex,
}
@@ -15,7 +14,6 @@ impl SearchMode {
pub(crate) fn label(&self) -> &'static str {
match self {
SearchMode::Text => "Text",
- SearchMode::Semantic => "Semantic",
SearchMode::Regex => "Regex",
}
}
@@ -25,22 +23,14 @@ impl SearchMode {
pub(crate) fn action(&self) -> Box<dyn Action> {
match self {
SearchMode::Text => ActivateTextMode.boxed_clone(),
- SearchMode::Semantic => ActivateSemanticMode.boxed_clone(),
SearchMode::Regex => ActivateRegexMode.boxed_clone(),
}
}
}
-pub(crate) fn next_mode(mode: &SearchMode, semantic_enabled: bool) -> SearchMode {
+pub(crate) fn next_mode(mode: &SearchMode) -> SearchMode {
match mode {
SearchMode::Text => SearchMode::Regex,
- SearchMode::Regex => {
- if semantic_enabled {
- SearchMode::Semantic
- } else {
- SearchMode::Text
- }
- }
- SearchMode::Semantic => SearchMode::Text,
+ SearchMode::Regex => SearchMode::Text,
}
}
@@ -1,33 +1,26 @@
use crate::{
- history::SearchHistory, mode::SearchMode, ActivateRegexMode, ActivateSemanticMode,
- ActivateTextMode, CycleMode, NextHistoryQuery, PreviousHistoryQuery, ReplaceAll, ReplaceNext,
- SearchOptions, SelectNextMatch, SelectPrevMatch, ToggleCaseSensitive, ToggleIncludeIgnored,
- ToggleReplace, ToggleWholeWord,
+ history::SearchHistory, mode::SearchMode, ActivateRegexMode, ActivateTextMode, CycleMode,
+ NextHistoryQuery, PreviousHistoryQuery, ReplaceAll, ReplaceNext, SearchOptions,
+ SelectNextMatch, SelectPrevMatch, ToggleCaseSensitive, ToggleIncludeIgnored, ToggleReplace,
+ ToggleWholeWord,
};
-use anyhow::{Context as _, Result};
-use collections::HashMap;
+use anyhow::Context as _;
+use collections::{HashMap, HashSet};
use editor::{
actions::SelectAll,
items::active_match_index,
scroll::{Autoscroll, Axis},
- Anchor, Editor, EditorEvent, MultiBuffer, MAX_TAB_TITLE_LEN,
+ Anchor, Editor, EditorElement, EditorEvent, EditorStyle, MultiBuffer, MAX_TAB_TITLE_LEN,
};
-use editor::{EditorElement, EditorStyle};
use gpui::{
actions, div, Action, AnyElement, AnyView, AppContext, Context as _, Element, EntityId,
EventEmitter, FocusHandle, FocusableView, FontStyle, FontWeight, Global, Hsla,
- InteractiveElement, IntoElement, KeyContext, Model, ModelContext, ParentElement, Point,
- PromptLevel, Render, SharedString, Styled, Subscription, Task, TextStyle, View, ViewContext,
- VisualContext, WeakModel, WeakView, WhiteSpace, WindowContext,
+ InteractiveElement, IntoElement, KeyContext, Model, ModelContext, ParentElement, Point, Render,
+ SharedString, Styled, Subscription, Task, TextStyle, View, ViewContext, VisualContext,
+ WeakModel, WeakView, WhiteSpace, WindowContext,
};
use menu::Confirm;
-use project::{
- search::{SearchInputs, SearchQuery},
- Project,
-};
-use semantic_index::{SemanticIndex, SemanticIndexStatus};
-
-use collections::HashSet;
+use project::{search::SearchQuery, Project};
use settings::Settings;
use smol::stream::StreamExt;
use std::{
@@ -35,22 +28,20 @@ use std::{
mem,
ops::{Not, Range},
path::{Path, PathBuf},
- time::{Duration, Instant},
};
use theme::ThemeSettings;
-use workspace::{DeploySearch, NewSearch};
-
use ui::{
h_flex, prelude::*, v_flex, Icon, IconButton, IconName, Label, LabelCommon, LabelSize,
Selectable, ToggleButton, Tooltip,
};
-use util::{paths::PathMatcher, ResultExt as _};
+use util::paths::PathMatcher;
use workspace::{
item::{BreadcrumbText, Item, ItemEvent, ItemHandle},
searchable::{Direction, SearchableItem, SearchableItemHandle},
ItemNavHistory, Pane, ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView, Workspace,
WorkspaceId,
};
+use workspace::{DeploySearch, NewSearch};
const MIN_INPUT_WIDTH_REMS: f32 = 15.;
const MAX_INPUT_WIDTH_REMS: f32 = 30.;
@@ -86,12 +77,6 @@ pub fn init(cx: &mut AppContext) {
register_workspace_action(workspace, move |search_bar, _: &ActivateTextMode, cx| {
search_bar.activate_search_mode(SearchMode::Text, cx)
});
- register_workspace_action(
- workspace,
- move |search_bar, _: &ActivateSemanticMode, cx| {
- search_bar.activate_search_mode(SearchMode::Semantic, cx)
- },
- );
register_workspace_action(workspace, move |search_bar, action: &CycleMode, cx| {
search_bar.cycle_mode(action, cx)
});
@@ -159,8 +144,6 @@ pub struct ProjectSearchView {
query_editor: View<Editor>,
replacement_editor: View<Editor>,
results_editor: View<Editor>,
- semantic_state: Option<SemanticState>,
- semantic_permissioned: Option<bool>,
search_options: SearchOptions,
panels_with_errors: HashSet<InputPanel>,
active_match_index: Option<usize>,
@@ -174,12 +157,6 @@ pub struct ProjectSearchView {
_subscriptions: Vec<Subscription>,
}
-struct SemanticState {
- index_status: SemanticIndexStatus,
- maintain_rate_limit: Option<Task<()>>,
- _subscription: Subscription,
-}
-
#[derive(Debug, Clone)]
struct ProjectSearchSettings {
search_options: SearchOptions,
@@ -282,68 +259,6 @@ impl ProjectSearch {
}));
cx.notify();
}
-
- fn semantic_search(&mut self, inputs: &SearchInputs, cx: &mut ModelContext<Self>) {
- let search = SemanticIndex::global(cx).map(|index| {
- index.update(cx, |semantic_index, cx| {
- semantic_index.search_project(
- self.project.clone(),
- inputs.as_str().to_owned(),
- 10,
- inputs.files_to_include().to_vec(),
- inputs.files_to_exclude().to_vec(),
- cx,
- )
- })
- });
- self.search_id += 1;
- self.match_ranges.clear();
- self.search_history.add(inputs.as_str().to_string());
- self.no_results = None;
- self.pending_search = Some(cx.spawn(|this, mut cx| async move {
- let results = search?.await.log_err()?;
- let matches = results
- .into_iter()
- .map(|result| (result.buffer, vec![result.range.start..result.range.start]));
-
- this.update(&mut cx, |this, cx| {
- this.no_results = Some(true);
- this.excerpts.update(cx, |excerpts, cx| {
- excerpts.clear(cx);
- });
- })
- .ok()?;
- for (buffer, ranges) in matches {
- let mut match_ranges = this
- .update(&mut cx, |this, cx| {
- this.no_results = Some(false);
- this.excerpts.update(cx, |excerpts, cx| {
- excerpts.stream_excerpts_with_context_lines(buffer, ranges, 3, cx)
- })
- })
- .ok()?;
- while let Some(match_range) = match_ranges.next().await {
- this.update(&mut cx, |this, cx| {
- this.match_ranges.push(match_range);
- while let Ok(Some(match_range)) = match_ranges.try_next() {
- this.match_ranges.push(match_range);
- }
- cx.notify();
- })
- .ok()?;
- }
- }
-
- this.update(&mut cx, |this, cx| {
- this.pending_search.take();
- cx.notify();
- })
- .ok()?;
-
- None
- }));
- cx.notify();
- }
}
#[derive(Clone, Debug, PartialEq, Eq)]
@@ -358,8 +273,6 @@ impl EventEmitter<ViewEvent> for ProjectSearchView {}
impl Render for ProjectSearchView {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
- const PLEASE_AUTHENTICATE: &str = "API Key Missing: Please set 'OPENAI_API_KEY' in Environment Variables. If you authenticated using the Assistant Panel, please restart Zed to Authenticate.";
-
if self.has_matches() {
div()
.flex_1()
@@ -370,7 +283,7 @@ impl Render for ProjectSearchView {
let model = self.model.read(cx);
let has_no_results = model.no_results.unwrap_or(false);
let is_search_underway = model.pending_search.is_some();
- let mut major_text = if is_search_underway {
+ let major_text = if is_search_underway {
Label::new("Searching...")
} else if has_no_results {
Label::new("No results")
@@ -378,43 +291,6 @@ impl Render for ProjectSearchView {
Label::new(format!("{} search all files", self.current_mode.label()))
};
- let mut show_minor_text = true;
- let semantic_status = self.semantic_state.as_ref().and_then(|semantic| {
- let status = semantic.index_status;
- match status {
- SemanticIndexStatus::NotAuthenticated => {
- major_text = Label::new("Not Authenticated");
- show_minor_text = false;
- Some(PLEASE_AUTHENTICATE.to_string())
- }
- SemanticIndexStatus::Indexed => Some("Indexing complete".to_string()),
- SemanticIndexStatus::Indexing {
- remaining_files,
- rate_limit_expiry,
- } => {
- if remaining_files == 0 {
- Some("Indexing...".to_string())
- } else {
- if let Some(rate_limit_expiry) = rate_limit_expiry {
- let remaining_seconds =
- rate_limit_expiry.duration_since(Instant::now());
- if remaining_seconds > Duration::from_secs(0) {
- Some(format!(
- "Remaining files to index (rate limit resets in {}s): {}",
- remaining_seconds.as_secs(),
- remaining_files
- ))
- } else {
- Some(format!("Remaining files to index: {}", remaining_files))
- }
- } else {
- Some(format!("Remaining files to index: {}", remaining_files))
- }
- }
- }
- SemanticIndexStatus::NotIndexed => None,
- }
- });
let major_text = div().justify_center().max_w_96().child(major_text);
let minor_text: Option<SharedString> = if let Some(no_results) = model.no_results {
@@ -424,12 +300,7 @@ impl Render for ProjectSearchView {
None
}
} else {
- if let Some(mut semantic_status) = semantic_status {
- semantic_status.extend(self.landing_text_minor().chars());
- Some(semantic_status.into())
- } else {
- Some(self.landing_text_minor())
- }
+ Some(self.landing_text_minor())
};
let minor_text = minor_text.map(|text| {
div()
@@ -676,58 +547,6 @@ impl ProjectSearchView {
});
}
- fn index_project(&mut self, cx: &mut ViewContext<Self>) {
- if let Some(semantic_index) = SemanticIndex::global(cx) {
- // Semantic search uses no options
- self.search_options = SearchOptions::none();
-
- let project = self.model.read(cx).project.clone();
-
- semantic_index.update(cx, |semantic_index, cx| {
- semantic_index
- .index_project(project.clone(), cx)
- .detach_and_log_err(cx);
- });
-
- self.semantic_state = Some(SemanticState {
- index_status: semantic_index.read(cx).status(&project),
- maintain_rate_limit: None,
- _subscription: cx.observe(&semantic_index, Self::semantic_index_changed),
- });
- self.semantic_index_changed(semantic_index, cx);
- }
- }
-
- fn semantic_index_changed(
- &mut self,
- semantic_index: Model<SemanticIndex>,
- cx: &mut ViewContext<Self>,
- ) {
- let project = self.model.read(cx).project.clone();
- if let Some(semantic_state) = self.semantic_state.as_mut() {
- cx.notify();
- semantic_state.index_status = semantic_index.read(cx).status(&project);
- if let SemanticIndexStatus::Indexing {
- rate_limit_expiry: Some(_),
- ..
- } = &semantic_state.index_status
- {
- if semantic_state.maintain_rate_limit.is_none() {
- semantic_state.maintain_rate_limit =
- Some(cx.spawn(|this, mut cx| async move {
- loop {
- cx.background_executor().timer(Duration::from_secs(1)).await;
- this.update(&mut cx, |_, cx| cx.notify()).log_err();
- }
- }));
- return;
- }
- } else {
- semantic_state.maintain_rate_limit = None;
- }
- }
- }
-
fn clear_search(&mut self, cx: &mut ViewContext<Self>) {
self.model.update(cx, |model, cx| {
model.pending_search = None;
@@ -750,63 +569,7 @@ impl ProjectSearchView {
self.clear_search(cx);
self.current_mode = mode;
self.active_match_index = None;
-
- match mode {
- SearchMode::Semantic => {
- let has_permission = self.semantic_permissioned(cx);
- self.active_match_index = None;
- cx.spawn(|this, mut cx| async move {
- let has_permission = has_permission.await?;
-
- if !has_permission {
- let answer = this.update(&mut cx, |this, cx| {
- let project = this.model.read(cx).project.clone();
- let project_name = project
- .read(cx)
- .worktree_root_names(cx)
- .collect::<Vec<&str>>()
- .join("/");
- let is_plural =
- project_name.chars().filter(|letter| *letter == '/').count() > 0;
- let prompt_text = format!("Would you like to index the '{}' project{} for semantic search? This requires sending code to the OpenAI API", project_name,
- if is_plural {
- "s"
- } else {""});
- cx.prompt(
- PromptLevel::Info,
- prompt_text.as_str(),
- None,
- &["Continue", "Cancel"],
- )
- })?;
-
- if answer.await? == 0 {
- this.update(&mut cx, |this, _| {
- this.semantic_permissioned = Some(true);
- })?;
- } else {
- this.update(&mut cx, |this, cx| {
- this.semantic_permissioned = Some(false);
- debug_assert_ne!(previous_mode, SearchMode::Semantic, "Tried to re-enable semantic search mode after user modal was rejected");
- this.activate_search_mode(previous_mode, cx);
- })?;
- return anyhow::Ok(());
- }
- }
-
- this.update(&mut cx, |this, cx| {
- this.index_project(cx);
- })?;
-
- anyhow::Ok(())
- }).detach_and_log_err(cx);
- }
- SearchMode::Regex | SearchMode::Text => {
- self.semantic_state = None;
- self.active_match_index = None;
- self.search(cx);
- }
- }
+ self.search(cx);
cx.update_global(|state: &mut ActiveSettings, cx| {
state.0.insert(
@@ -973,8 +736,6 @@ impl ProjectSearchView {
model,
query_editor,
results_editor,
- semantic_state: None,
- semantic_permissioned: None,
search_options: options,
panels_with_errors: HashSet::default(),
active_match_index: None,
@@ -990,19 +751,6 @@ impl ProjectSearchView {
this
}
- fn semantic_permissioned(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<bool>> {
- if let Some(value) = self.semantic_permissioned {
- return Task::ready(Ok(value));
- }
-
- SemanticIndex::global(cx)
- .map(|semantic| {
- let project = self.model.read(cx).project.clone();
- semantic.update(cx, |this, cx| this.project_previously_indexed(&project, cx))
- })
- .unwrap_or(Task::ready(Ok(false)))
- }
-
pub fn new_search_in_directory(
workspace: &mut Workspace,
dir_path: &Path,
@@ -1126,22 +874,8 @@ impl ProjectSearchView {
}
fn search(&mut self, cx: &mut ViewContext<Self>) {
- let mode = self.current_mode;
- match mode {
- SearchMode::Semantic => {
- if self.semantic_state.is_some() {
- if let Some(query) = self.build_search_query(cx) {
- self.model
- .update(cx, |model, cx| model.semantic_search(query.as_inner(), cx));
- }
- }
- }
-
- _ => {
- if let Some(query) = self.build_search_query(cx) {
- self.model.update(cx, |model, cx| model.search(query, cx));
- }
- }
+ if let Some(query) = self.build_search_query(cx) {
+ self.model.update(cx, |model, cx| model.search(query, cx));
}
}
@@ -1356,7 +1090,6 @@ impl ProjectSearchView {
fn landing_text_minor(&self) -> SharedString {
match self.current_mode {
SearchMode::Text | SearchMode::Regex => "Include/exclude specific paths with the filter option. Matching exact word and/or casing is available too.".into(),
- SearchMode::Semantic => "\nSimply explain the code you are looking to find. ex. 'prompt user for permissions to index their project'".into()
}
}
fn border_color_for(&self, panel: InputPanel, cx: &WindowContext) -> Hsla {
@@ -1387,8 +1120,7 @@ impl ProjectSearchBar {
fn cycle_mode(&self, _: &CycleMode, cx: &mut ViewContext<Self>) {
if let Some(view) = self.active_project_search.as_ref() {
view.update(cx, |this, cx| {
- let new_mode =
- crate::mode::next_mode(&this.current_mode, SemanticIndex::enabled(cx));
+ let new_mode = crate::mode::next_mode(&this.current_mode);
this.activate_search_mode(new_mode, cx);
let editor_handle = this.query_editor.focus_handle(cx);
cx.focus(&editor_handle);
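With the `Semantic` variant gone, `next_mode` no longer needs an enablement flag. Its body lives in `crate::mode` and is not shown in this diff; a hypothetical sketch of the two-mode cycle after this change:

```rust
// Hypothetical sketch only; the real definitions live in crate::mode
// and are not part of this diff.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum SearchMode {
    Text,
    Regex,
}

pub fn next_mode(mode: &SearchMode) -> SearchMode {
    match mode {
        SearchMode::Text => SearchMode::Regex,
        // With Semantic removed, Regex wraps back around to Text.
        SearchMode::Regex => SearchMode::Text,
    }
}
```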
@@ -1681,7 +1413,6 @@ impl Render for ProjectSearchBar {
});
}
let search = search.read(cx);
- let semantic_is_available = SemanticIndex::enabled(cx);
let query_column = h_flex()
.flex_1()
@@ -1711,12 +1442,8 @@ impl Render for ProjectSearchBar {
.unwrap_or_default(),
),
)
- .when(search.current_mode != SearchMode::Semantic, |this| {
- this.child(
- IconButton::new(
- "project-search-case-sensitive",
- IconName::CaseSensitive,
- )
+ .child(
+ IconButton::new("project-search-case-sensitive", IconName::CaseSensitive)
.tooltip(|cx| {
Tooltip::for_action(
"Toggle case sensitive",
@@ -1728,18 +1455,17 @@ impl Render for ProjectSearchBar {
.on_click(cx.listener(|this, _, cx| {
this.toggle_search_option(SearchOptions::CASE_SENSITIVE, cx);
})),
- )
- .child(
- IconButton::new("project-search-whole-word", IconName::WholeWord)
- .tooltip(|cx| {
- Tooltip::for_action("Toggle whole word", &ToggleWholeWord, cx)
- })
- .selected(self.is_option_enabled(SearchOptions::WHOLE_WORD, cx))
- .on_click(cx.listener(|this, _, cx| {
- this.toggle_search_option(SearchOptions::WHOLE_WORD, cx);
- })),
- )
- }),
+ )
+ .child(
+ IconButton::new("project-search-whole-word", IconName::WholeWord)
+ .tooltip(|cx| {
+ Tooltip::for_action("Toggle whole word", &ToggleWholeWord, cx)
+ })
+ .selected(self.is_option_enabled(SearchOptions::WHOLE_WORD, cx))
+ .on_click(cx.listener(|this, _, cx| {
+ this.toggle_search_option(SearchOptions::WHOLE_WORD, cx);
+ })),
+ ),
);
let mode_column = v_flex().items_start().justify_start().child(
@@ -1775,33 +1501,8 @@ impl Render for ProjectSearchBar {
cx,
)
})
- .map(|this| {
- if semantic_is_available {
- this.middle()
- } else {
- this.last()
- }
- }),
- )
- .when(semantic_is_available, |this| {
- this.child(
- ToggleButton::new("project-search-semantic-button", "Semantic")
- .style(ButtonStyle::Filled)
- .size(ButtonSize::Large)
- .selected(search.current_mode == SearchMode::Semantic)
- .on_click(cx.listener(|this, _, cx| {
- this.activate_search_mode(SearchMode::Semantic, cx)
- }))
- .tooltip(|cx| {
- Tooltip::for_action(
- "Toggle semantic search",
- &ActivateSemanticMode,
- cx,
- )
- })
- .last(),
- )
- }),
+ .last(),
+ ),
)
.child(
IconButton::new("project-search-toggle-replace", IconName::Replace)
@@ -1929,21 +1630,16 @@ impl Render for ProjectSearchBar {
.border_color(search.border_color_for(InputPanel::Include, cx))
.rounded_lg()
.child(self.render_text_input(&search.included_files_editor, cx))
- .when(search.current_mode != SearchMode::Semantic, |this| {
- this.child(
- SearchOptions::INCLUDE_IGNORED.as_button(
- search
- .search_options
- .contains(SearchOptions::INCLUDE_IGNORED),
- cx.listener(|this, _, cx| {
- this.toggle_search_option(
- SearchOptions::INCLUDE_IGNORED,
- cx,
- );
- }),
- ),
- )
- }),
+ .child(
+ SearchOptions::INCLUDE_IGNORED.as_button(
+ search
+ .search_options
+ .contains(SearchOptions::INCLUDE_IGNORED),
+ cx.listener(|this, _, cx| {
+ this.toggle_search_option(SearchOptions::INCLUDE_IGNORED, cx);
+ }),
+ ),
+ ),
)
.child(
h_flex()
@@ -1972,9 +1668,6 @@ impl Render for ProjectSearchBar {
.on_action(cx.listener(|this, _: &ActivateRegexMode, cx| {
this.activate_search_mode(SearchMode::Regex, cx)
}))
- .on_action(cx.listener(|this, _: &ActivateSemanticMode, cx| {
- this.activate_search_mode(SearchMode::Semantic, cx)
- }))
.capture_action(cx.listener(|this, action, cx| {
this.tab(action, cx);
cx.stop_propagation();
@@ -1987,35 +1680,33 @@ impl Render for ProjectSearchBar {
.on_action(cx.listener(|this, action, cx| {
this.cycle_mode(action, cx);
}))
- .when(search.current_mode != SearchMode::Semantic, |this| {
- this.on_action(cx.listener(|this, action, cx| {
- this.toggle_replace(action, cx);
- }))
- .on_action(cx.listener(|this, _: &ToggleWholeWord, cx| {
- this.toggle_search_option(SearchOptions::WHOLE_WORD, cx);
- }))
- .on_action(cx.listener(|this, _: &ToggleCaseSensitive, cx| {
- this.toggle_search_option(SearchOptions::CASE_SENSITIVE, cx);
- }))
- .on_action(cx.listener(|this, action, cx| {
- if let Some(search) = this.active_project_search.as_ref() {
- search.update(cx, |this, cx| {
- this.replace_next(action, cx);
- })
- }
- }))
- .on_action(cx.listener(|this, action, cx| {
- if let Some(search) = this.active_project_search.as_ref() {
- search.update(cx, |this, cx| {
- this.replace_all(action, cx);
- })
- }
+ .on_action(cx.listener(|this, action, cx| {
+ this.toggle_replace(action, cx);
+ }))
+ .on_action(cx.listener(|this, _: &ToggleWholeWord, cx| {
+ this.toggle_search_option(SearchOptions::WHOLE_WORD, cx);
+ }))
+ .on_action(cx.listener(|this, _: &ToggleCaseSensitive, cx| {
+ this.toggle_search_option(SearchOptions::CASE_SENSITIVE, cx);
+ }))
+ .on_action(cx.listener(|this, action, cx| {
+ if let Some(search) = this.active_project_search.as_ref() {
+ search.update(cx, |this, cx| {
+ this.replace_next(action, cx);
+ })
+ }
+ }))
+ .on_action(cx.listener(|this, action, cx| {
+ if let Some(search) = this.active_project_search.as_ref() {
+ search.update(cx, |this, cx| {
+ this.replace_all(action, cx);
+ })
+ }
+ }))
+ .when(search.filters_enabled, |this| {
+ this.on_action(cx.listener(|this, _: &ToggleIncludeIgnored, cx| {
+ this.toggle_search_option(SearchOptions::INCLUDE_IGNORED, cx);
}))
- .when(search.filters_enabled, |this| {
- this.on_action(cx.listener(|this, _: &ToggleIncludeIgnored, cx| {
- this.toggle_search_option(SearchOptions::INCLUDE_IGNORED, cx);
- }))
- })
})
.on_action(cx.listener(Self::select_next_match))
.on_action(cx.listener(Self::select_prev_match))
@@ -2039,12 +1730,6 @@ impl ToolbarItemView for ProjectSearchBar {
self.subscription = None;
self.active_project_search = None;
if let Some(search) = active_pane_item.and_then(|i| i.downcast::<ProjectSearchView>()) {
- search.update(cx, |search, cx| {
- if search.current_mode == SearchMode::Semantic {
- search.index_project(cx);
- }
- });
-
self.subscription = Some(cx.observe(&search, |_, _, cx| cx.notify()));
self.active_project_search = Some(search);
ToolbarItemLocation::PrimaryLeft {}
@@ -2123,9 +1808,8 @@ pub mod tests {
use editor::DisplayPoint;
use gpui::{Action, TestAppContext, WindowHandle};
use project::FakeFs;
- use semantic_index::semantic_index_settings::SemanticIndexSettings;
use serde_json::json;
- use settings::{Settings, SettingsStore};
+ use settings::SettingsStore;
use std::sync::Arc;
use workspace::DeploySearch;
@@ -3446,8 +3130,6 @@ pub mod tests {
let settings = SettingsStore::test(cx);
cx.set_global(settings);
- SemanticIndexSettings::register(cx);
-
theme::init(theme::LoadThemes::JustBase, cx);
language::init(cx);
@@ -33,7 +33,6 @@ actions!(
NextHistoryQuery,
PreviousHistoryQuery,
ActivateTextMode,
- ActivateSemanticMode,
ActivateRegexMode,
ReplaceAll,
ReplaceNext,
@@ -1,66 +0,0 @@
-[package]
-name = "semantic_index"
-version = "0.1.0"
-edition = "2021"
-publish = false
-license = "GPL-3.0-or-later"
-
-[lints]
-workspace = true
-
-[lib]
-path = "src/semantic_index.rs"
-doctest = false
-
-[dependencies]
-ai.workspace = true
-anyhow.workspace = true
-collections.workspace = true
-futures.workspace = true
-gpui.workspace = true
-language.workspace = true
-lazy_static.workspace = true
-log.workspace = true
-ndarray = { version = "0.15.0" }
-ordered-float.workspace = true
-parking_lot.workspace = true
-postage.workspace = true
-project.workspace = true
-rand.workspace = true
-release_channel.workspace = true
-rpc.workspace = true
-rusqlite.workspace = true
-schemars.workspace = true
-serde.workspace = true
-serde_json.workspace = true
-settings.workspace = true
-sha1 = "0.10.5"
-smol.workspace = true
-tree-sitter.workspace = true
-util.workspace = true
-workspace.workspace = true
-
-[dev-dependencies]
-ai = { workspace = true, features = ["test-support"] }
-collections = { workspace = true, features = ["test-support"] }
-ctor.workspace = true
-env_logger.workspace = true
-gpui = { workspace = true, features = ["test-support"] }
-language = { workspace = true, features = ["test-support"] }
-pretty_assertions.workspace = true
-project = { workspace = true, features = ["test-support"] }
-rand.workspace = true
-rpc = { workspace = true, features = ["test-support"] }
-settings = { workspace = true, features = ["test-support"]}
-tempfile.workspace = true
-tree-sitter-cpp.workspace = true
-tree-sitter-elixir.workspace = true
-tree-sitter-json.workspace = true
-tree-sitter-lua.workspace = true
-tree-sitter-php.workspace = true
-tree-sitter-ruby.workspace = true
-tree-sitter-rust.workspace = true
-tree-sitter-toml.workspace = true
-tree-sitter-typescript.workspace = true
-unindent.workspace = true
-workspace = { workspace = true, features = ["test-support"] }
@@ -1 +0,0 @@
-../../LICENSE-GPL
@@ -1,20 +0,0 @@
-
-# Semantic Index
-
-## Evaluation
-
-### Metrics
-
-nDCG@k:
-- "The value of NDCG is determined by comparing the relevance of the items returned by the search engine to the relevance of the item that a hypothetical "ideal" search engine would return.
-- "The relevance of result is represented by a score (also known as a 'grade') that is assigned to the search query. The scores of these results are then discounted based on their position in the search results -- did they get recommended first or last?"
-
-MRR@k:
-- "Mean reciprocal rank quantifies the rank of the first relevant item found in the recommendation list."
-
-MAP@k:
-- "Mean average precision averages the precision@k metric at each relevant item position in the recommendation list.
-
-Resources:
-- [Evaluating recommendation metrics](https://www.shaped.ai/blog/evaluating-recommendation-systems-map-mmr-ndcg)
-- [Math Walkthrough](https://towardsdatascience.com/demystifying-ndcg-bee3be58cfe0)
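The three metrics quoted in the deleted README correspond to the standard formulas; for reference (added context, not part of the removed file):

```latex
% DCG@k discounts each result's relevance grade rel_i by its rank i;
% nDCG@k normalizes by the DCG of an ideal ordering.
\mathrm{DCG@k} = \sum_{i=1}^{k} \frac{rel_i}{\log_2(i + 1)},
\qquad
\mathrm{nDCG@k} = \frac{\mathrm{DCG@k}}{\mathrm{IDCG@k}}

% MRR@k averages, over the query set Q, the reciprocal rank of the
% first relevant result for each query (0 if none appears in the top k).
\mathrm{MRR@k} = \frac{1}{|Q|} \sum_{q \in Q} \frac{1}{\mathrm{rank}_q}

% AP@k averages precision at each position i holding a relevant result
% (rel(i) is 0 or 1, m is the number of relevant items for the query);
% MAP@k averages AP@k over all queries.
\mathrm{AP@k} = \frac{1}{\min(m, k)} \sum_{i=1}^{k} P(i) \cdot rel(i),
\qquad
\mathrm{MAP@k} = \frac{1}{|Q|} \sum_{q \in Q} \mathrm{AP@k}(q)
```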
@@ -1,114 +0,0 @@
-{
- "repo": "https://github.com/AntonOsika/gpt-engineer.git",
- "commit": "7735a6445bae3611c62f521e6464c67c957f87c2",
- "assertions": [
- {
- "query": "How do I contribute to this project?",
- "matches": [
- ".github/CONTRIBUTING.md:1",
- "ROADMAP.md:48"
- ]
- },
- {
- "query": "What version of the openai package is active?",
- "matches": [
- "pyproject.toml:14"
- ]
- },
- {
- "query": "Ask user for clarification",
- "matches": [
- "gpt_engineer/steps.py:69"
- ]
- },
- {
- "query": "generate tests for python code",
- "matches": [
- "gpt_engineer/steps.py:153"
- ]
- },
- {
- "query": "get item from database based on key",
- "matches": [
- "gpt_engineer/db.py:42",
- "gpt_engineer/db.py:68"
- ]
- },
- {
- "query": "prompt user to select files",
- "matches": [
- "gpt_engineer/file_selector.py:171",
- "gpt_engineer/file_selector.py:306",
- "gpt_engineer/file_selector.py:289",
- "gpt_engineer/file_selector.py:234"
- ]
- },
- {
- "query": "send to rudderstack",
- "matches": [
- "gpt_engineer/collect.py:11",
- "gpt_engineer/collect.py:38"
- ]
- },
- {
- "query": "parse code blocks from chat messages",
- "matches": [
- "gpt_engineer/chat_to_files.py:10",
- "docs/intro/chat_parsing.md:1"
- ]
- },
- {
- "query": "how do I use the docker cli?",
- "matches": [
- "docker/README.md:1"
- ]
- },
- {
- "query": "ask the user if the code ran successfully?",
- "matches": [
- "gpt_engineer/learning.py:54"
- ]
- },
- {
- "query": "how is consent granted by the user?",
- "matches": [
- "gpt_engineer/learning.py:107",
- "gpt_engineer/learning.py:130",
- "gpt_engineer/learning.py:152"
- ]
- },
- {
- "query": "what are all the different steps the agent can take?",
- "matches": [
- "docs/intro/steps_module.md:1",
- "gpt_engineer/steps.py:391"
- ]
- },
- {
- "query": "ask the user for clarification?",
- "matches": [
- "gpt_engineer/steps.py:69"
- ]
- },
- {
- "query": "what models are available?",
- "matches": [
- "gpt_engineer/ai.py:315",
- "gpt_engineer/ai.py:341",
- "docs/open-models.md:1"
- ]
- },
- {
- "query": "what is the current focus of the project?",
- "matches": [
- "ROADMAP.md:11"
- ]
- },
- {
- "query": "does the agent know how to fix code?",
- "matches": [
- "gpt_engineer/steps.py:367"
- ]
- }
- ]
-}
@@ -1,104 +0,0 @@
-{
- "repo": "https://github.com/tree-sitter/tree-sitter.git",
- "commit": "46af27796a76c72d8466627d499f2bca4af958ee",
- "assertions": [
- {
- "query": "What attributes are available for the tags configuration struct?",
- "matches": [
- "tags/src/lib.rs:24"
- ]
- },
- {
- "query": "create a new tag configuration",
- "matches": [
- "tags/src/lib.rs:119"
- ]
- },
- {
- "query": "generate tags based on config",
- "matches": [
- "tags/src/lib.rs:261"
- ]
- },
- {
- "query": "match on ts quantifier in rust",
- "matches": [
- "lib/binding_rust/lib.rs:139"
- ]
- },
- {
- "query": "cli command to generate tags",
- "matches": [
- "cli/src/tags.rs:10"
- ]
- },
- {
- "query": "what version of the tree-sitter-tags package is active?",
- "matches": [
- "tags/Cargo.toml:4"
- ]
- },
- {
- "query": "Insert a new parse state",
- "matches": [
- "cli/src/generate/build_tables/build_parse_table.rs:153"
- ]
- },
- {
- "query": "Handle conflict when numerous actions occur on the same symbol",
- "matches": [
- "cli/src/generate/build_tables/build_parse_table.rs:363",
- "cli/src/generate/build_tables/build_parse_table.rs:442"
- ]
- },
- {
- "query": "Match based on associativity of actions",
- "matches": [
- "cri/src/generate/build_tables/build_parse_table.rs:542"
- ]
- },
- {
- "query": "Format token set display",
- "matches": [
- "cli/src/generate/build_tables/item.rs:246"
- ]
- },
- {
- "query": "extract choices from rule",
- "matches": [
- "cli/src/generate/prepare_grammar/flatten_grammar.rs:124"
- ]
- },
- {
- "query": "How do we identify if a symbol is being used?",
- "matches": [
- "cli/src/generate/prepare_grammar/flatten_grammar.rs:175"
- ]
- },
- {
- "query": "How do we launch the playground?",
- "matches": [
- "cli/src/playground.rs:46"
- ]
- },
- {
- "query": "How do we test treesitter query matches in rust?",
- "matches": [
- "cli/src/query_testing.rs:152",
- "cli/src/tests/query_test.rs:781",
- "cli/src/tests/query_test.rs:2163",
- "cli/src/tests/query_test.rs:3781",
- "cli/src/tests/query_test.rs:887"
- ]
- },
- {
- "query": "What does the CLI do?",
- "matches": [
- "cli/README.md:10",
- "cli/loader/README.md:3",
- "docs/section-5-implementation.md:14",
- "docs/section-5-implementation.md:18"
- ]
- }
- ]
-}
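The two deleted fixtures above share one shape: a repo pin plus a list of query-to-expected-matches assertions, where each match is a `path:line` string. A hedged serde sketch of that shape (struct and field names mirror the JSON keys; nothing here is from the deleted crate):

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct EvalFixture {
    repo: String,
    commit: String,
    assertions: Vec<Assertion>,
}

#[derive(Debug, Deserialize)]
struct Assertion {
    query: String,
    /// Entries look like "path/to/file.rs:LINE".
    matches: Vec<String>,
}

fn main() -> Result<(), serde_json::Error> {
    let json = r#"{
        "repo": "https://example.com/repo.git",
        "commit": "abc123",
        "assertions": [
            { "query": "how do I contribute?", "matches": ["CONTRIBUTING.md:1"] }
        ]
    }"#;
    let fixture: EvalFixture = serde_json::from_str(json)?;
    assert_eq!(fixture.assertions.len(), 1);
    println!(
        "{} assertions for {} @ {}",
        fixture.assertions.len(),
        fixture.repo,
        fixture.commit
    );
    Ok(())
}
```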
@@ -1,594 +0,0 @@
-use crate::{
- parsing::{Span, SpanDigest},
- SEMANTIC_INDEX_VERSION,
-};
-use ai::embedding::Embedding;
-use anyhow::{anyhow, Context, Result};
-use collections::HashMap;
-use futures::channel::oneshot;
-use gpui::BackgroundExecutor;
-use ndarray::{Array1, Array2};
-use ordered_float::OrderedFloat;
-use project::Fs;
-use rpc::proto::Timestamp;
-use rusqlite::params;
-use rusqlite::types::Value;
-use std::{
- future::Future,
- ops::Range,
- path::{Path, PathBuf},
- rc::Rc,
- sync::Arc,
- time::SystemTime,
-};
-use util::{paths::PathMatcher, TryFutureExt};
-
-pub fn argsort<T: Ord>(data: &[T]) -> Vec<usize> {
- let mut indices = (0..data.len()).collect::<Vec<_>>();
- indices.sort_by_key(|&i| &data[i]);
- indices.reverse();
- indices
-}
-
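A usage note on `argsort` above: the indices are sorted ascending by value and then reversed, so the result orders `data` from largest to smallest, meaning element 0 points at the best similarity score. A self-contained check (the function body is copied verbatim from above):

```rust
fn argsort<T: Ord>(data: &[T]) -> Vec<usize> {
    let mut indices = (0..data.len()).collect::<Vec<_>>();
    indices.sort_by_key(|&i| &data[i]);
    indices.reverse();
    indices
}

fn main() {
    let scores = [10, 40, 20, 30];
    // Largest value (40) is at index 1, then 30 at 3, 20 at 2, 10 at 0.
    assert_eq!(argsort(&scores), vec![1, 3, 2, 0]);
}
```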
-#[derive(Debug)]
-pub struct FileRecord {
- pub id: usize,
- pub relative_path: String,
- pub mtime: Timestamp,
-}
-
-#[derive(Clone)]
-pub struct VectorDatabase {
- path: Arc<Path>,
- transactions:
- smol::channel::Sender<Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>>,
-}
-
-impl VectorDatabase {
- pub async fn new(
- fs: Arc<dyn Fs>,
- path: Arc<Path>,
- executor: BackgroundExecutor,
- ) -> Result<Self> {
- if let Some(db_directory) = path.parent() {
- fs.create_dir(db_directory).await?;
- }
-
- let (transactions_tx, transactions_rx) = smol::channel::unbounded::<
- Box<dyn 'static + Send + FnOnce(&mut rusqlite::Connection)>,
- >();
- executor
- .spawn({
- let path = path.clone();
- async move {
- let mut connection = rusqlite::Connection::open(&path)?;
-
- connection.pragma_update(None, "journal_mode", "wal")?;
- connection.pragma_update(None, "synchronous", "normal")?;
- connection.pragma_update(None, "cache_size", 1000000)?;
- connection.pragma_update(None, "temp_store", "MEMORY")?;
-
- while let Ok(transaction) = transactions_rx.recv().await {
- transaction(&mut connection);
- }
-
- anyhow::Ok(())
- }
- .log_err()
- })
- .detach();
- let this = Self {
- transactions: transactions_tx,
- path,
- };
- this.initialize_database().await?;
- Ok(this)
- }
-
- pub fn path(&self) -> &Arc<Path> {
- &self.path
- }
-
- fn transact<F, T>(&self, f: F) -> impl Future<Output = Result<T>>
- where
- F: 'static + Send + FnOnce(&rusqlite::Transaction) -> Result<T>,
- T: 'static + Send,
- {
- let (tx, rx) = oneshot::channel();
- let transactions = self.transactions.clone();
- async move {
- if transactions
- .send(Box::new(|connection| {
- let result = connection
- .transaction()
- .map_err(|err| anyhow!(err))
- .and_then(|transaction| {
- let result = f(&transaction)?;
- transaction.commit()?;
- Ok(result)
- });
- let _ = tx.send(result);
- }))
- .await
- .is_err()
- {
- return Err(anyhow!("connection was dropped"))?;
- }
- rx.await?
- }
- }
-
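`transact` above is the load-bearing part of `VectorDatabase`: each caller's closure is boxed, shipped over a channel to the single task that owns the `rusqlite::Connection`, run inside a transaction, and its result is returned on a oneshot channel. A minimal std-only sketch of that single-owner pattern (a `Vec` stands in for the connection; none of this is Zed's code):

```rust
use std::sync::mpsc;
use std::thread;

type Job = Box<dyn FnOnce(&mut Vec<String>) + Send>;

fn main() {
    let (tx, rx) = mpsc::channel::<Job>();

    // The "connection" (here just a Vec) lives on exactly one thread.
    let worker = thread::spawn(move || {
        let mut connection = Vec::new();
        while let Ok(job) = rx.recv() {
            job(&mut connection);
        }
    });

    // A caller sends a closure and waits for its result.
    let (result_tx, result_rx) = mpsc::channel();
    tx.send(Box::new(move |conn: &mut Vec<String>| {
        conn.push("row".to_string());
        let _ = result_tx.send(conn.len());
    }))
    .unwrap();
    assert_eq!(result_rx.recv().unwrap(), 1);

    drop(tx); // Close the channel so the worker loop exits.
    worker.join().unwrap();
}
```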
- fn initialize_database(&self) -> impl Future<Output = Result<()>> {
- self.transact(|db| {
- rusqlite::vtab::array::load_module(&db)?;
-
- // Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped
- let version_query = db.prepare("SELECT version from semantic_index_config");
- let version = version_query
- .and_then(|mut query| query.query_row([], |row| row.get::<_, i64>(0)));
- if version.map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64) {
- log::trace!("vector database schema up to date");
- return Ok(());
- }
-
- log::trace!("vector database schema out of date. updating...");
- // We renamed the `documents` table to `spans`, so we want to drop
- // `documents` without recreating it if it exists.
- db.execute("DROP TABLE IF EXISTS documents", [])
- .context("failed to drop 'documents' table")?;
- db.execute("DROP TABLE IF EXISTS spans", [])
- .context("failed to drop 'spans' table")?;
- db.execute("DROP TABLE IF EXISTS files", [])
- .context("failed to drop 'files' table")?;
- db.execute("DROP TABLE IF EXISTS worktrees", [])
- .context("failed to drop 'worktrees' table")?;
- db.execute("DROP TABLE IF EXISTS semantic_index_config", [])
- .context("failed to drop 'semantic_index_config' table")?;
-
- // Initialize Vector Databasing Tables
- db.execute(
- "CREATE TABLE semantic_index_config (
- version INTEGER NOT NULL
- )",
- [],
- )?;
-
- db.execute(
- "INSERT INTO semantic_index_config (version) VALUES (?1)",
- params![SEMANTIC_INDEX_VERSION],
- )?;
-
- db.execute(
- "CREATE TABLE worktrees (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- absolute_path VARCHAR NOT NULL
- );
- CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path);
- ",
- [],
- )?;
-
- db.execute(
- "CREATE TABLE files (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- worktree_id INTEGER NOT NULL,
- relative_path VARCHAR NOT NULL,
- mtime_seconds INTEGER NOT NULL,
- mtime_nanos INTEGER NOT NULL,
- FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
- )",
- [],
- )?;
-
- db.execute(
- "CREATE UNIQUE INDEX files_worktree_id_and_relative_path ON files (worktree_id, relative_path)",
- [],
- )?;
-
- db.execute(
- "CREATE TABLE spans (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- file_id INTEGER NOT NULL,
- start_byte INTEGER NOT NULL,
- end_byte INTEGER NOT NULL,
- name VARCHAR NOT NULL,
- embedding BLOB NOT NULL,
- digest BLOB NOT NULL,
- FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
- )",
- [],
- )?;
- db.execute(
- "CREATE INDEX spans_digest ON spans (digest)",
- [],
- )?;
-
- log::trace!("vector database initialized with updated schema.");
- Ok(())
- })
- }
-
- pub fn delete_file(
- &self,
- worktree_id: i64,
- delete_path: Arc<Path>,
- ) -> impl Future<Output = Result<()>> {
- self.transact(move |db| {
- db.execute(
- "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2",
- params![worktree_id, delete_path.to_str()],
- )?;
- Ok(())
- })
- }
-
- pub fn insert_file(
- &self,
- worktree_id: i64,
- path: Arc<Path>,
- mtime: SystemTime,
- spans: Vec<Span>,
- ) -> impl Future<Output = Result<()>> {
- self.transact(move |db| {
- // Return the existing ID, if both the file and mtime match
- let mtime = Timestamp::from(mtime);
-
- db.execute(
- "
- REPLACE INTO files
- (worktree_id, relative_path, mtime_seconds, mtime_nanos)
- VALUES (?1, ?2, ?3, ?4)
- ",
- params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos],
- )?;
-
- let file_id = db.last_insert_rowid();
-
- let mut query = db.prepare(
- "
- INSERT INTO spans
- (file_id, start_byte, end_byte, name, embedding, digest)
- VALUES (?1, ?2, ?3, ?4, ?5, ?6)
- ",
- )?;
-
- for span in spans {
- query.execute(params![
- file_id,
- span.range.start.to_string(),
- span.range.end.to_string(),
- span.name,
- span.embedding,
- span.digest
- ])?;
- }
-
- Ok(())
- })
- }
-
- pub fn worktree_previously_indexed(
- &self,
- worktree_root_path: &Path,
- ) -> impl Future<Output = Result<bool>> {
- let worktree_root_path = worktree_root_path.to_string_lossy().into_owned();
- self.transact(move |db| {
- let mut worktree_query =
- db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
- let worktree_id =
- worktree_query.query_row(params![worktree_root_path], |row| row.get::<_, i64>(0));
-
- Ok(worktree_id.is_ok())
- })
- }
-
- pub fn embeddings_for_digests(
- &self,
- digests: Vec<SpanDigest>,
- ) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
- self.transact(move |db| {
- let mut query = db.prepare(
- "
- SELECT digest, embedding
- FROM spans
- WHERE digest IN rarray(?)
- ",
- )?;
- let mut embeddings_by_digest = HashMap::default();
- let digests = Rc::new(
- digests
- .into_iter()
- .map(|digest| Value::Blob(digest.0.to_vec()))
- .collect::<Vec<_>>(),
- );
- let rows = query.query_map(params![digests], |row| {
- Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
- })?;
-
- for (digest, embedding) in rows.flatten() {
- embeddings_by_digest.insert(digest, embedding);
- }
-
- Ok(embeddings_by_digest)
- })
- }
-
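The `WHERE digest IN rarray(?)` queries above rely on rusqlite's `array` virtual-table support: `load_module` registers `rarray`, and the parameter is bound as an `Rc<Vec<Value>>`. A hedged, self-contained sketch of that mechanism (assumes the `array` feature of `rusqlite` is enabled):

```rust
use rusqlite::{params, types::Value, Connection, Result};
use std::rc::Rc;

fn main() -> Result<()> {
    let db = Connection::open_in_memory()?;
    // Registers the rarray() table-valued function on this connection.
    rusqlite::vtab::array::load_module(&db)?;
    db.execute_batch("CREATE TABLE t (id INTEGER); INSERT INTO t VALUES (1), (2), (3);")?;

    // The bound Rc<Vec<Value>> is treated as a one-column table.
    let wanted = Rc::new(vec![Value::Integer(1), Value::Integer(3)]);
    let mut stmt = db.prepare("SELECT id FROM t WHERE id IN rarray(?)")?;
    let ids: Vec<i64> = stmt
        .query_map(params![wanted], |row| row.get(0))?
        .collect::<Result<_>>()?;
    assert_eq!(ids, vec![1, 3]);
    Ok(())
}
```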
- pub fn embeddings_for_files(
- &self,
- worktree_id_file_paths: HashMap<i64, Vec<Arc<Path>>>,
- ) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
- self.transact(move |db| {
- let mut query = db.prepare(
- "
- SELECT digest, embedding
- FROM spans
- LEFT JOIN files ON files.id = spans.file_id
- WHERE files.worktree_id = ? AND files.relative_path IN rarray(?)
- ",
- )?;
- let mut embeddings_by_digest = HashMap::default();
- for (worktree_id, file_paths) in worktree_id_file_paths {
- let file_paths = Rc::new(
- file_paths
- .into_iter()
- .map(|p| Value::Text(p.to_string_lossy().into_owned()))
- .collect::<Vec<_>>(),
- );
- let rows = query.query_map(params![worktree_id, file_paths], |row| {
- Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
- })?;
-
- for (digest, embedding) in rows.flatten() {
- embeddings_by_digest.insert(digest, embedding);
- }
- }
-
- Ok(embeddings_by_digest)
- })
- }
-
- pub fn find_or_create_worktree(
- &self,
- worktree_root_path: Arc<Path>,
- ) -> impl Future<Output = Result<i64>> {
- self.transact(move |db| {
- let mut worktree_query =
- db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
- let worktree_id = worktree_query
- .query_row(params![worktree_root_path.to_string_lossy()], |row| {
- row.get::<_, i64>(0)
- });
-
- if worktree_id.is_ok() {
- return Ok(worktree_id?);
- }
-
- // If worktree_id is Err, insert new worktree
- db.execute(
- "INSERT into worktrees (absolute_path) VALUES (?1)",
- params![worktree_root_path.to_string_lossy()],
- )?;
- Ok(db.last_insert_rowid())
- })
- }
-
- pub fn get_file_mtimes(
- &self,
- worktree_id: i64,
- ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
- self.transact(move |db| {
- let mut statement = db.prepare(
- "
- SELECT relative_path, mtime_seconds, mtime_nanos
- FROM files
- WHERE worktree_id = ?1
- ORDER BY relative_path",
- )?;
- let mut result: HashMap<PathBuf, SystemTime> = HashMap::default();
- for row in statement.query_map(params![worktree_id], |row| {
- Ok((
- row.get::<_, String>(0)?.into(),
- Timestamp {
- seconds: row.get(1)?,
- nanos: row.get(2)?,
- }
- .into(),
- ))
- })? {
- let row = row?;
- result.insert(row.0, row.1);
- }
- Ok(result)
- })
- }
-
- pub fn top_k_search(
- &self,
- query_embedding: &Embedding,
- limit: usize,
- file_ids: &[i64],
- ) -> impl Future<Output = Result<Vec<(i64, OrderedFloat<f32>)>>> {
- let file_ids = file_ids.to_vec();
- let query = query_embedding.clone().0;
- let query = Array1::from_vec(query);
- self.transact(move |db| {
- let mut query_statement = db.prepare(
- "
- SELECT
- id, embedding
- FROM
- spans
- WHERE
- file_id IN rarray(?)
- ",
- )?;
-
- let deserialized_rows = query_statement
- .query_map(params![ids_to_sql(&file_ids)], |row| {
- Ok((row.get::<_, usize>(0)?, row.get::<_, Embedding>(1)?))
- })?
- .filter_map(|row| row.ok())
- .collect::<Vec<(usize, Embedding)>>();
-
- if deserialized_rows.len() == 0 {
- return Ok(Vec::new());
- }
-
- // Get Length of Embeddings Returned
- let embedding_len = deserialized_rows[0].1 .0.len();
-
- let batch_n = 1000;
- let mut batches = Vec::new();
- let mut batch_ids = Vec::new();
- let mut batch_embeddings: Vec<f32> = Vec::new();
- deserialized_rows.iter().for_each(|(id, embedding)| {
- batch_ids.push(id);
- batch_embeddings.extend(&embedding.0);
-
- if batch_ids.len() == batch_n {
- let embeddings = std::mem::take(&mut batch_embeddings);
- let ids = std::mem::take(&mut batch_ids);
- let array = Array2::from_shape_vec((ids.len(), embedding_len), embeddings);
- match array {
- Ok(array) => {
- batches.push((ids, array));
- }
- Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
- }
- }
- });
-
- if batch_ids.len() > 0 {
- let array = Array2::from_shape_vec(
- (batch_ids.len(), embedding_len),
- batch_embeddings.clone(),
- );
- match array {
- Ok(array) => {
- batches.push((batch_ids.clone(), array));
- }
- Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
- }
- }
-
- let mut ids: Vec<usize> = Vec::new();
- let mut results = Vec::new();
- for (batch_ids, array) in batches {
- let scores = array
- .dot(&query.t())
- .to_vec()
- .iter()
- .map(|score| OrderedFloat(*score))
- .collect::<Vec<OrderedFloat<f32>>>();
- results.extend(scores);
- ids.extend(batch_ids);
- }
-
- let sorted_idx = argsort(&results);
- let mut sorted_results = Vec::new();
- let last_idx = limit.min(sorted_idx.len());
- for idx in &sorted_idx[0..last_idx] {
- sorted_results.push((ids[*idx] as i64, results[*idx]))
- }
-
- Ok(sorted_results)
- })
- }
-
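`top_k_search` above scores candidates by stacking stored embeddings into 1000-row `Array2` batches and taking one matrix-vector product per batch against the query. Assuming embeddings are unit-normalized (so a dot product equals cosine similarity), the core scoring step looks like this sketch (ndarray 0.15, as in the deleted Cargo.toml):

```rust
use ndarray::{Array1, Array2};

fn main() {
    // Three stored 4-dimensional embeddings, one row per span.
    let stored = Array2::from_shape_vec(
        (3, 4),
        vec![
            1.0, 0.0, 0.0, 0.0, //
            0.0, 1.0, 0.0, 0.0, //
            0.6, 0.8, 0.0, 0.0,
        ],
    )
    .unwrap();
    let query = Array1::from_vec(vec![1.0, 0.0, 0.0, 0.0]);

    // One matrix-vector product scores the whole batch at once.
    let scores = stored.dot(&query);
    assert_eq!(scores.len(), 3);
    assert!(scores[0] > scores[2] && scores[2] > scores[1]);
}
```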
- pub fn retrieve_included_file_ids(
- &self,
- worktree_ids: &[i64],
- includes: &[PathMatcher],
- excludes: &[PathMatcher],
- ) -> impl Future<Output = Result<Vec<i64>>> {
- let worktree_ids = worktree_ids.to_vec();
- let includes = includes.to_vec();
- let excludes = excludes.to_vec();
- self.transact(move |db| {
- let mut file_query = db.prepare(
- "
- SELECT
- id, relative_path
- FROM
- files
- WHERE
- worktree_id IN rarray(?)
- ",
- )?;
-
- let mut file_ids = Vec::<i64>::new();
- let mut rows = file_query.query([ids_to_sql(&worktree_ids)])?;
-
- while let Some(row) = rows.next()? {
- let file_id = row.get(0)?;
- let relative_path = row.get_ref(1)?.as_str()?;
- let included =
- includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path));
- let excluded = excludes.iter().any(|glob| glob.is_match(relative_path));
- if included && !excluded {
- file_ids.push(file_id);
- }
- }
-
- anyhow::Ok(file_ids)
- })
- }
-
- pub fn spans_for_ids(
- &self,
- ids: &[i64],
- ) -> impl Future<Output = Result<Vec<(i64, PathBuf, Range<usize>)>>> {
- let ids = ids.to_vec();
- self.transact(move |db| {
- let mut statement = db.prepare(
- "
- SELECT
- spans.id,
- files.worktree_id,
- files.relative_path,
- spans.start_byte,
- spans.end_byte
- FROM
- spans, files
- WHERE
- spans.file_id = files.id AND
- spans.id in rarray(?)
- ",
- )?;
-
- let result_iter = statement.query_map(params![ids_to_sql(&ids)], |row| {
- Ok((
- row.get::<_, i64>(0)?,
- row.get::<_, i64>(1)?,
- row.get::<_, String>(2)?.into(),
- row.get(3)?..row.get(4)?,
- ))
- })?;
-
- let mut values_by_id = HashMap::<i64, (i64, PathBuf, Range<usize>)>::default();
- for row in result_iter {
- let (id, worktree_id, path, range) = row?;
- values_by_id.insert(id, (worktree_id, path, range));
- }
-
- let mut results = Vec::with_capacity(ids.len());
- for id in &ids {
- let value = values_by_id
- .remove(id)
- .ok_or(anyhow!("missing span id {}", id))?;
- results.push(value);
- }
-
- Ok(results)
- })
- }
-}
-
-fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
- Rc::new(
- ids.iter()
- .copied()
- .map(|v| rusqlite::types::Value::from(v))
- .collect::<Vec<_>>(),
- )
-}
@@ -1,169 +0,0 @@
-use crate::{parsing::Span, JobHandle};
-use ai::embedding::EmbeddingProvider;
-use gpui::BackgroundExecutor;
-use parking_lot::Mutex;
-use smol::channel;
-use std::{mem, ops::Range, path::Path, sync::Arc, time::SystemTime};
-
-#[derive(Clone)]
-pub struct FileToEmbed {
- pub worktree_id: i64,
- pub path: Arc<Path>,
- pub mtime: SystemTime,
- pub spans: Vec<Span>,
- pub job_handle: JobHandle,
-}
-
-impl std::fmt::Debug for FileToEmbed {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- f.debug_struct("FileToEmbed")
- .field("worktree_id", &self.worktree_id)
- .field("path", &self.path)
- .field("mtime", &self.mtime)
- .field("spans", &self.spans)
- .finish_non_exhaustive()
- }
-}
-
-impl PartialEq for FileToEmbed {
- fn eq(&self, other: &Self) -> bool {
- self.worktree_id == other.worktree_id
- && self.path == other.path
- && self.mtime == other.mtime
- && self.spans == other.spans
- }
-}
-
-pub struct EmbeddingQueue {
- embedding_provider: Arc<dyn EmbeddingProvider>,
- pending_batch: Vec<FileFragmentToEmbed>,
- executor: BackgroundExecutor,
- pending_batch_token_count: usize,
- finished_files_tx: channel::Sender<FileToEmbed>,
- finished_files_rx: channel::Receiver<FileToEmbed>,
-}
-
-#[derive(Clone)]
-pub struct FileFragmentToEmbed {
- file: Arc<Mutex<FileToEmbed>>,
- span_range: Range<usize>,
-}
-
-impl EmbeddingQueue {
- pub fn new(
- embedding_provider: Arc<dyn EmbeddingProvider>,
- executor: BackgroundExecutor,
- ) -> Self {
- let (finished_files_tx, finished_files_rx) = channel::unbounded();
- Self {
- embedding_provider,
- executor,
- pending_batch: Vec::new(),
- pending_batch_token_count: 0,
- finished_files_tx,
- finished_files_rx,
- }
- }
-
- pub fn push(&mut self, file: FileToEmbed) {
- if file.spans.is_empty() {
- self.finished_files_tx.try_send(file).unwrap();
- return;
- }
-
- let file = Arc::new(Mutex::new(file));
-
- self.pending_batch.push(FileFragmentToEmbed {
- file: file.clone(),
- span_range: 0..0,
- });
-
- let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range;
- for (ix, span) in file.lock().spans.iter().enumerate() {
- let span_token_count = if span.embedding.is_none() {
- span.token_count
- } else {
- 0
- };
-
- let next_token_count = self.pending_batch_token_count + span_token_count;
- if next_token_count > self.embedding_provider.max_tokens_per_batch() {
- let range_end = fragment_range.end;
- self.flush();
- self.pending_batch.push(FileFragmentToEmbed {
- file: file.clone(),
- span_range: range_end..range_end,
- });
- fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range;
- }
-
- fragment_range.end = ix + 1;
- self.pending_batch_token_count += span_token_count;
- }
- }
-
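`push` above implements a token-budget batcher: span token counts accumulate until adding the next span would exceed `max_tokens_per_batch`, at which point the pending batch is flushed and a new fragment starts. A simplified, self-contained sketch of that rule (indices stand in for spans; the real code also splits a single file across fragments):

```rust
// Greedy batching under a token budget. A batch is cut only when it is
// non-empty, so a single oversized item still forms its own batch.
fn batch_by_tokens(token_counts: &[usize], max_per_batch: usize) -> Vec<Vec<usize>> {
    let mut batches = vec![Vec::new()];
    let mut budget = 0;
    for (ix, &tokens) in token_counts.iter().enumerate() {
        if budget + tokens > max_per_batch && !batches.last().unwrap().is_empty() {
            batches.push(Vec::new());
            budget = 0;
        }
        batches.last_mut().unwrap().push(ix);
        budget += tokens;
    }
    batches
}

fn main() {
    // With a 10-token budget: 4+5 fit, then 7+2, then 3.
    assert_eq!(
        batch_by_tokens(&[4, 5, 7, 2, 3], 10),
        vec![vec![0, 1], vec![2, 3], vec![4]]
    );
}
```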
- pub fn flush(&mut self) {
- let batch = mem::take(&mut self.pending_batch);
- self.pending_batch_token_count = 0;
- if batch.is_empty() {
- return;
- }
-
- let finished_files_tx = self.finished_files_tx.clone();
- let embedding_provider = self.embedding_provider.clone();
-
- self.executor
- .spawn(async move {
- let mut spans = Vec::new();
- for fragment in &batch {
- let file = fragment.file.lock();
- spans.extend(
- file.spans[fragment.span_range.clone()]
- .iter()
- .filter(|d| d.embedding.is_none())
- .map(|d| d.content.clone()),
- );
- }
-
-            // If no spans need embedding, just send each file to the finished channel once its last fragment is dropped.
- if spans.is_empty() {
- for fragment in batch.clone() {
- if let Some(file) = Arc::into_inner(fragment.file) {
- finished_files_tx.try_send(file.into_inner()).unwrap();
- }
- }
- return;
- };
-
- match embedding_provider.embed_batch(spans).await {
- Ok(embeddings) => {
- let mut embeddings = embeddings.into_iter();
- for fragment in batch {
- for span in &mut fragment.file.lock().spans[fragment.span_range.clone()]
- .iter_mut()
- .filter(|d| d.embedding.is_none())
- {
- if let Some(embedding) = embeddings.next() {
- span.embedding = Some(embedding);
- } else {
- log::error!("number of embeddings != number of documents");
- }
- }
-
- if let Some(file) = Arc::into_inner(fragment.file) {
- finished_files_tx.try_send(file.into_inner()).unwrap();
- }
- }
- }
- Err(error) => {
- log::error!("{:?}", error);
- }
- }
- })
- .detach();
- }
-
- pub fn finished_files(&self) -> channel::Receiver<FileToEmbed> {
- self.finished_files_rx.clone()
- }
-}
@@ -1,414 +0,0 @@
-use ai::{
- embedding::{Embedding, EmbeddingProvider},
- models::TruncationDirection,
-};
-use anyhow::{anyhow, Result};
-use collections::HashSet;
-use language::{Grammar, Language};
-use rusqlite::{
- types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
- ToSql,
-};
-use sha1::{Digest, Sha1};
-use std::{
- borrow::Cow,
- cmp::{self, Reverse},
- ops::Range,
- path::Path,
- sync::Arc,
-};
-use tree_sitter::{Parser, QueryCursor};
-
-#[derive(Debug, PartialEq, Eq, Clone, Hash)]
-pub struct SpanDigest(pub [u8; 20]);
-
-impl FromSql for SpanDigest {
- fn column_result(value: ValueRef) -> FromSqlResult<Self> {
- let blob = value.as_blob()?;
- let bytes =
- blob.try_into()
- .map_err(|_| rusqlite::types::FromSqlError::InvalidBlobSize {
- expected_size: 20,
- blob_size: blob.len(),
- })?;
- return Ok(SpanDigest(bytes));
- }
-}
-
-impl ToSql for SpanDigest {
- fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
- self.0.to_sql()
- }
-}
-
-impl From<&'_ str> for SpanDigest {
- fn from(value: &'_ str) -> Self {
- let mut sha1 = Sha1::new();
- sha1.update(value);
- Self(sha1.finalize().into())
- }
-}
-
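`SpanDigest` above is plain content addressing: the SHA-1 of a span's text is its identity, which is what lets `embeddings_for_digests` reuse stored embeddings for unchanged spans instead of re-embedding them. A small sketch (sha1 = "0.10", matching the deleted Cargo.toml):

```rust
use sha1::{Digest, Sha1};

// Equal text yields an equal 20-byte digest; any edit changes it.
fn digest(text: &str) -> [u8; 20] {
    let mut sha1 = Sha1::new();
    sha1.update(text);
    sha1.finalize().into()
}

fn main() {
    assert_eq!(digest("fn main() {}"), digest("fn main() {}"));
    assert_ne!(digest("fn main() {}"), digest("fn main() { }"));
}
```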
-#[derive(Debug, PartialEq, Clone)]
-pub struct Span {
- pub name: String,
- pub range: Range<usize>,
- pub content: String,
- pub embedding: Option<Embedding>,
- pub digest: SpanDigest,
- pub token_count: usize,
-}
-
-const CODE_CONTEXT_TEMPLATE: &str =
- "The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
-const ENTIRE_FILE_TEMPLATE: &str =
- "The below snippet is from file '<path>'\n\n```<language>\n<item>\n```";
-const MARKDOWN_CONTEXT_TEMPLATE: &str = "The below file contents is from file '<path>'\n\n<item>";
-pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] = &[
- "TOML", "YAML", "CSS", "HEEX", "ERB", "SVELTE", "HTML", "Scheme",
-];
-
-pub struct CodeContextRetriever {
- pub parser: Parser,
- pub cursor: QueryCursor,
- pub embedding_provider: Arc<dyn EmbeddingProvider>,
-}
-
-// Every match has an item, this represents the fundamental treesitter symbol and anchors the search
-// Every match has one or more 'name' captures. These indicate the display range of the item for deduplication.
-// If there are preceding comments, we track this with a context capture
-// If there is a piece that should be collapsed in hierarchical queries, we capture it with a collapse capture
-// If there is a piece that should be kept inside a collapsed node, we capture it with a keep capture
-#[derive(Debug, Clone)]
-pub struct CodeContextMatch {
- pub start_col: usize,
- pub item_range: Option<Range<usize>>,
- pub name_range: Option<Range<usize>>,
- pub context_ranges: Vec<Range<usize>>,
- pub collapse_ranges: Vec<Range<usize>>,
-}
-
-impl CodeContextRetriever {
- pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
- Self {
- parser: Parser::new(),
- cursor: QueryCursor::new(),
- embedding_provider,
- }
- }
-
- fn parse_entire_file(
- &self,
- relative_path: Option<&Path>,
- language_name: Arc<str>,
- content: &str,
- ) -> Result<Vec<Span>> {
- let document_span = ENTIRE_FILE_TEMPLATE
- .replace(
- "<path>",
- &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
- )
- .replace("<language>", language_name.as_ref())
- .replace("<item>", &content);
- let digest = SpanDigest::from(document_span.as_str());
- let model = self.embedding_provider.base_model();
- let document_span = model.truncate(
- &document_span,
- model.capacity()?,
- ai::models::TruncationDirection::End,
- )?;
- let token_count = model.count_tokens(&document_span)?;
-
- Ok(vec![Span {
- range: 0..content.len(),
- content: document_span,
- embedding: Default::default(),
- name: language_name.to_string(),
- digest,
- token_count,
- }])
- }
-
- fn parse_markdown_file(
- &self,
- relative_path: Option<&Path>,
- content: &str,
- ) -> Result<Vec<Span>> {
- let document_span = MARKDOWN_CONTEXT_TEMPLATE
- .replace(
- "<path>",
- &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
- )
- .replace("<item>", &content);
- let digest = SpanDigest::from(document_span.as_str());
-
- let model = self.embedding_provider.base_model();
- let document_span = model.truncate(
- &document_span,
- model.capacity()?,
- ai::models::TruncationDirection::End,
- )?;
- let token_count = model.count_tokens(&document_span)?;
-
- Ok(vec![Span {
- range: 0..content.len(),
- content: document_span,
- embedding: None,
- name: "Markdown".to_string(),
- digest,
- token_count,
- }])
- }
-
- fn get_matches_in_file(
- &mut self,
- content: &str,
- grammar: &Arc<Grammar>,
- ) -> Result<Vec<CodeContextMatch>> {
- let embedding_config = grammar
- .embedding_config
- .as_ref()
- .ok_or_else(|| anyhow!("no embedding queries"))?;
- self.parser.set_language(&grammar.ts_language).unwrap();
-
- let tree = self
- .parser
- .parse(&content, None)
- .ok_or_else(|| anyhow!("parsing failed"))?;
-
- let mut captures: Vec<CodeContextMatch> = Vec::new();
- let mut collapse_ranges: Vec<Range<usize>> = Vec::new();
- let mut keep_ranges: Vec<Range<usize>> = Vec::new();
- for mat in self.cursor.matches(
- &embedding_config.query,
- tree.root_node(),
- content.as_bytes(),
- ) {
- let mut start_col = 0;
- let mut item_range: Option<Range<usize>> = None;
- let mut name_range: Option<Range<usize>> = None;
- let mut context_ranges: Vec<Range<usize>> = Vec::new();
- collapse_ranges.clear();
- keep_ranges.clear();
- for capture in mat.captures {
- if capture.index == embedding_config.item_capture_ix {
- item_range = Some(capture.node.byte_range());
- start_col = capture.node.start_position().column;
- } else if Some(capture.index) == embedding_config.name_capture_ix {
- name_range = Some(capture.node.byte_range());
- } else if Some(capture.index) == embedding_config.context_capture_ix {
- context_ranges.push(capture.node.byte_range());
- } else if Some(capture.index) == embedding_config.collapse_capture_ix {
- collapse_ranges.push(capture.node.byte_range());
- } else if Some(capture.index) == embedding_config.keep_capture_ix {
- keep_ranges.push(capture.node.byte_range());
- }
- }
-
- captures.push(CodeContextMatch {
- start_col,
- item_range,
- name_range,
- context_ranges,
- collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges),
- });
- }
- Ok(captures)
- }
-
- pub fn parse_file_with_template(
- &mut self,
- relative_path: Option<&Path>,
- content: &str,
- language: Arc<Language>,
- ) -> Result<Vec<Span>> {
- let language_name = language.name();
-
- if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
- return self.parse_entire_file(relative_path, language_name, &content);
- } else if ["Markdown", "Plain Text"].contains(&language_name.as_ref()) {
- return self.parse_markdown_file(relative_path, &content);
- }
-
- let mut spans = self.parse_file(content, language)?;
- for span in &mut spans {
- let document_content = CODE_CONTEXT_TEMPLATE
- .replace(
- "<path>",
- &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()),
- )
- .replace("<language>", language_name.as_ref())
- .replace("item", &span.content);
-
- let model = self.embedding_provider.base_model();
- let document_content = model.truncate(
- &document_content,
- model.capacity()?,
- TruncationDirection::End,
- )?;
- let token_count = model.count_tokens(&document_content)?;
-
- span.content = document_content;
- span.token_count = token_count;
- }
- Ok(spans)
- }
-
- pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Span>> {
- let grammar = language
- .grammar()
- .ok_or_else(|| anyhow!("no grammar for language"))?;
-
- // Iterate through query matches
- let matches = self.get_matches_in_file(content, grammar)?;
-
- let language_scope = language.default_scope();
- let placeholder = language_scope.collapsed_placeholder();
-
- let mut spans = Vec::new();
- let mut collapsed_ranges_within = Vec::new();
- let mut parsed_name_ranges = HashSet::default();
- for (i, context_match) in matches.iter().enumerate() {
- // Items which are collapsible but not embeddable have no item range
- let item_range = if let Some(item_range) = context_match.item_range.clone() {
- item_range
- } else {
- continue;
- };
-
- // Checks for deduplication
- let name;
- if let Some(name_range) = context_match.name_range.clone() {
- name = content
- .get(name_range.clone())
- .map_or(String::new(), |s| s.to_string());
- if parsed_name_ranges.contains(&name_range) {
- continue;
- }
- parsed_name_ranges.insert(name_range);
- } else {
- name = String::new();
- }
-
- collapsed_ranges_within.clear();
- 'outer: for remaining_match in &matches[(i + 1)..] {
- for collapsed_range in &remaining_match.collapse_ranges {
- if item_range.start <= collapsed_range.start
- && item_range.end >= collapsed_range.end
- {
- collapsed_ranges_within.push(collapsed_range.clone());
- } else {
- break 'outer;
- }
- }
- }
-
- collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end)));
-
- let mut span_content = String::new();
- for context_range in &context_match.context_ranges {
- add_content_from_range(
- &mut span_content,
- content,
- context_range.clone(),
- context_match.start_col,
- );
- span_content.push_str("\n");
- }
-
- let mut offset = item_range.start;
- for collapsed_range in &collapsed_ranges_within {
- if collapsed_range.start > offset {
- add_content_from_range(
- &mut span_content,
- content,
- offset..collapsed_range.start,
- context_match.start_col,
- );
- offset = collapsed_range.start;
- }
-
- if collapsed_range.end > offset {
- span_content.push_str(placeholder);
- offset = collapsed_range.end;
- }
- }
-
- if offset < item_range.end {
- add_content_from_range(
- &mut span_content,
- content,
- offset..item_range.end,
- context_match.start_col,
- );
- }
-
- let sha1 = SpanDigest::from(span_content.as_str());
- spans.push(Span {
- name,
- content: span_content,
- range: item_range.clone(),
- embedding: None,
- digest: sha1,
- token_count: 0,
- })
- }
-
- return Ok(spans);
- }
-}
-
-pub(crate) fn subtract_ranges(
- ranges: &[Range<usize>],
- ranges_to_subtract: &[Range<usize>],
-) -> Vec<Range<usize>> {
- let mut result = Vec::new();
-
- let mut ranges_to_subtract = ranges_to_subtract.iter().peekable();
-
- for range in ranges {
- let mut offset = range.start;
-
- while offset < range.end {
- if let Some(range_to_subtract) = ranges_to_subtract.peek() {
- if offset < range_to_subtract.start {
- let next_offset = cmp::min(range_to_subtract.start, range.end);
- result.push(offset..next_offset);
- offset = next_offset;
- } else {
- let next_offset = cmp::min(range_to_subtract.end, range.end);
- offset = next_offset;
- }
-
- if offset >= range_to_subtract.end {
- ranges_to_subtract.next();
- }
- } else {
- result.push(offset..range.end);
- offset = range.end;
- }
- }
- }
-
- result
-}
-
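`subtract_ranges` above assumes both inputs are sorted and non-overlapping, and punches the "keep" ranges out of the "collapse" ranges. A self-contained copy with a worked example:

```rust
use std::cmp;
use std::ops::Range;

// Copied from the deleted parsing.rs above.
fn subtract_ranges(
    ranges: &[Range<usize>],
    ranges_to_subtract: &[Range<usize>],
) -> Vec<Range<usize>> {
    let mut result = Vec::new();
    let mut ranges_to_subtract = ranges_to_subtract.iter().peekable();

    for range in ranges {
        let mut offset = range.start;
        while offset < range.end {
            if let Some(range_to_subtract) = ranges_to_subtract.peek() {
                if offset < range_to_subtract.start {
                    // Keep the gap before the next subtracted range.
                    let next_offset = cmp::min(range_to_subtract.start, range.end);
                    result.push(offset..next_offset);
                    offset = next_offset;
                } else {
                    // Skip over the subtracted range.
                    offset = cmp::min(range_to_subtract.end, range.end);
                }
                if offset >= range_to_subtract.end {
                    ranges_to_subtract.next();
                }
            } else {
                result.push(offset..range.end);
                offset = range.end;
            }
        }
    }
    result
}

fn main() {
    let collapse = vec![0..10];
    let keep = vec![3..5, 7..8];
    assert_eq!(subtract_ranges(&collapse, &keep), vec![0..3, 5..7, 8..10]);
}
```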
-fn add_content_from_range(
- output: &mut String,
- content: &str,
- range: Range<usize>,
- start_col: usize,
-) {
- for mut line in content.get(range.clone()).unwrap_or("").lines() {
- for _ in 0..start_col {
- if line.starts_with(' ') {
- line = &line[1..];
- } else {
- break;
- }
- }
- output.push_str(line);
- output.push('\n');
- }
- output.pop();
-}
@@ -1,1308 +0,0 @@
-mod db;
-mod embedding_queue;
-mod parsing;
-pub mod semantic_index_settings;
-
-#[cfg(test)]
-mod semantic_index_tests;
-
-use crate::semantic_index_settings::SemanticIndexSettings;
-use ai::embedding::{Embedding, EmbeddingProvider};
-use ai::providers::open_ai::{OpenAiEmbeddingProvider, OPEN_AI_API_URL};
-use anyhow::{anyhow, Context as _, Result};
-use collections::{BTreeMap, HashMap, HashSet};
-use db::VectorDatabase;
-use embedding_queue::{EmbeddingQueue, FileToEmbed};
-use futures::{future, FutureExt, StreamExt};
-use gpui::{
- AppContext, AsyncAppContext, BorrowWindow, Context, Global, Model, ModelContext, Task,
- ViewContext, WeakModel,
-};
-use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
-use lazy_static::lazy_static;
-use ordered_float::OrderedFloat;
-use parking_lot::Mutex;
-use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
-use postage::watch;
-use project::{Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
-use release_channel::ReleaseChannel;
-use settings::Settings;
-use smol::channel;
-use std::{
- cmp::Reverse,
- env,
- future::Future,
- mem,
- ops::Range,
- path::{Path, PathBuf},
- sync::{Arc, Weak},
- time::{Duration, Instant, SystemTime},
-};
-use util::paths::PathMatcher;
-use util::{http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt};
-use workspace::Workspace;
-
-const SEMANTIC_INDEX_VERSION: usize = 11;
-const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
-const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);
-
-lazy_static! {
- static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
-}
-
-pub fn init(
- fs: Arc<dyn Fs>,
- http_client: Arc<dyn HttpClient>,
- language_registry: Arc<LanguageRegistry>,
- cx: &mut AppContext,
-) {
- SemanticIndexSettings::register(cx);
-
- let db_file_path = EMBEDDINGS_DIR
- .join(Path::new(ReleaseChannel::global(cx).dev_name()))
- .join("embeddings_db");
-
- cx.observe_new_views(
- |workspace: &mut Workspace, cx: &mut ViewContext<Workspace>| {
- let Some(semantic_index) = SemanticIndex::global(cx) else {
- return;
- };
- let project = workspace.project().clone();
-
- if project.read(cx).is_local() {
- cx.app_mut()
- .spawn(|mut cx| async move {
- let previously_indexed = semantic_index
- .update(&mut cx, |index, cx| {
- index.project_previously_indexed(&project, cx)
- })?
- .await?;
- if previously_indexed {
- semantic_index
- .update(&mut cx, |index, cx| index.index_project(project, cx))?
- .await?;
- }
- anyhow::Ok(())
- })
- .detach_and_log_err(cx);
- }
- },
- )
- .detach();
-
- cx.spawn(move |cx| async move {
- let embedding_provider = OpenAiEmbeddingProvider::new(
- // TODO: We should read it from config, but I'm not sure whether to reuse `openai_api_url` in assistant settings or not
- OPEN_AI_API_URL.to_string(),
- http_client,
- cx.background_executor().clone(),
- )
- .await;
- let semantic_index = SemanticIndex::new(
- fs,
- db_file_path,
- Arc::new(embedding_provider),
- language_registry,
- cx.clone(),
- )
- .await?;
-
- cx.update(|cx| cx.set_global(GlobalSemanticIndex(semantic_index.clone())))?;
-
- anyhow::Ok(())
- })
- .detach();
-}
-
-#[derive(Copy, Clone, Debug)]
-pub enum SemanticIndexStatus {
- NotAuthenticated,
- NotIndexed,
- Indexed,
- Indexing {
- remaining_files: usize,
- rate_limit_expiry: Option<Instant>,
- },
-}
-
-pub struct SemanticIndex {
- fs: Arc<dyn Fs>,
- db: VectorDatabase,
- embedding_provider: Arc<dyn EmbeddingProvider>,
- language_registry: Arc<LanguageRegistry>,
- parsing_files_tx: channel::Sender<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>,
- _embedding_task: Task<()>,
- _parsing_files_tasks: Vec<Task<()>>,
- projects: HashMap<WeakModel<Project>, ProjectState>,
-}
-
-struct GlobalSemanticIndex(Model<SemanticIndex>);
-
-impl Global for GlobalSemanticIndex {}
-
-struct ProjectState {
- worktrees: HashMap<WorktreeId, WorktreeState>,
- pending_file_count_rx: watch::Receiver<usize>,
- pending_file_count_tx: Arc<Mutex<watch::Sender<usize>>>,
- pending_index: usize,
- _subscription: gpui::Subscription,
- _observe_pending_file_count: Task<()>,
-}
-
-enum WorktreeState {
- Registering(RegisteringWorktreeState),
- Registered(RegisteredWorktreeState),
-}
-
-impl WorktreeState {
- fn is_registered(&self) -> bool {
- matches!(self, Self::Registered(_))
- }
-
- fn paths_changed(
- &mut self,
- changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
- worktree: &Worktree,
- ) {
- let changed_paths = match self {
- Self::Registering(state) => &mut state.changed_paths,
- Self::Registered(state) => &mut state.changed_paths,
- };
-
- for (path, entry_id, change) in changes.iter() {
- let Some(entry) = worktree.entry_for_id(*entry_id) else {
- continue;
- };
- let Some(mtime) = entry.mtime else {
- continue;
- };
- if entry.is_ignored || entry.is_symlink || entry.is_external || entry.is_dir() {
- continue;
- }
- changed_paths.insert(
- path.clone(),
- ChangedPathInfo {
- mtime,
- is_deleted: *change == PathChange::Removed,
- },
- );
- }
- }
-}
-
-struct RegisteringWorktreeState {
- changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
- done_rx: watch::Receiver<Option<()>>,
- _registration: Task<()>,
-}
-
-impl RegisteringWorktreeState {
- fn done(&self) -> impl Future<Output = ()> {
- let mut done_rx = self.done_rx.clone();
- async move {
- while let Some(result) = done_rx.next().await {
- if result.is_some() {
- break;
- }
- }
- }
- }
-}
-
-struct RegisteredWorktreeState {
- db_id: i64,
- changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
-}
-
-struct ChangedPathInfo {
- mtime: SystemTime,
- is_deleted: bool,
-}
-
-#[derive(Clone)]
-pub struct JobHandle {
-    /// The outer Arc counts the clones of a single JobHandle instance; when the
-    /// last handle to a given job is dropped, the pending-file counter is decremented exactly once.
- tx: Arc<Weak<Mutex<watch::Sender<usize>>>>,
-}
-
-impl JobHandle {
- fn new(tx: &Arc<Mutex<watch::Sender<usize>>>) -> Self {
- *tx.lock().borrow_mut() += 1;
- Self {
-            tx: Arc::new(Arc::downgrade(tx)),
- }
- }
-}
-
-impl ProjectState {
- fn new(subscription: gpui::Subscription, cx: &mut ModelContext<SemanticIndex>) -> Self {
- let (pending_file_count_tx, pending_file_count_rx) = watch::channel_with(0);
- let pending_file_count_tx = Arc::new(Mutex::new(pending_file_count_tx));
- Self {
- worktrees: Default::default(),
- pending_file_count_rx: pending_file_count_rx.clone(),
- pending_file_count_tx,
- pending_index: 0,
- _subscription: subscription,
- _observe_pending_file_count: cx.spawn({
- let mut pending_file_count_rx = pending_file_count_rx.clone();
- |this, mut cx| async move {
-                    while pending_file_count_rx.next().await.is_some() {
- if this.update(&mut cx, |_, cx| cx.notify()).is_err() {
- break;
- }
- }
- }
- }),
- }
- }
-
- fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
- self.worktrees
- .iter()
- .find_map(|(worktree_id, worktree_state)| match worktree_state {
- WorktreeState::Registered(state) if state.db_id == id => Some(*worktree_id),
- _ => None,
- })
- }
-}
-
-#[derive(Clone)]
-pub struct PendingFile {
- worktree_db_id: i64,
- relative_path: Arc<Path>,
- absolute_path: PathBuf,
- language: Option<Arc<Language>>,
- modified_time: SystemTime,
- job_handle: JobHandle,
-}
-
-#[derive(Clone)]
-pub struct SearchResult {
- pub buffer: Model<Buffer>,
- pub range: Range<Anchor>,
- pub similarity: OrderedFloat<f32>,
-}
-
-impl SemanticIndex {
- pub fn global(cx: &mut AppContext) -> Option<Model<SemanticIndex>> {
- cx.try_global::<GlobalSemanticIndex>()
- .map(|semantic_index| semantic_index.0.clone())
- }
-
- pub fn authenticate(&mut self, cx: &mut AppContext) -> Task<bool> {
- if !self.embedding_provider.has_credentials() {
- let embedding_provider = self.embedding_provider.clone();
- cx.spawn(|cx| async move {
- if let Some(retrieve_credentials) = cx
- .update(|cx| embedding_provider.retrieve_credentials(cx))
- .log_err()
- {
- retrieve_credentials.await;
- }
-
- embedding_provider.has_credentials()
- })
- } else {
- Task::ready(true)
- }
- }
-
- pub fn is_authenticated(&self) -> bool {
- self.embedding_provider.has_credentials()
- }
-
- pub fn enabled(cx: &AppContext) -> bool {
- SemanticIndexSettings::get_global(cx).enabled
- }
-
- pub fn status(&self, project: &Model<Project>) -> SemanticIndexStatus {
- if !self.is_authenticated() {
- return SemanticIndexStatus::NotAuthenticated;
- }
-
- if let Some(project_state) = self.projects.get(&project.downgrade()) {
- if project_state
- .worktrees
- .values()
- .all(|worktree| worktree.is_registered())
- && project_state.pending_index == 0
- {
- SemanticIndexStatus::Indexed
- } else {
- SemanticIndexStatus::Indexing {
- remaining_files: *project_state.pending_file_count_rx.borrow(),
- rate_limit_expiry: self.embedding_provider.rate_limit_expiration(),
- }
- }
- } else {
- SemanticIndexStatus::NotIndexed
- }
- }
-
- pub async fn new(
- fs: Arc<dyn Fs>,
- database_path: PathBuf,
- embedding_provider: Arc<dyn EmbeddingProvider>,
- language_registry: Arc<LanguageRegistry>,
- mut cx: AsyncAppContext,
- ) -> Result<Model<Self>> {
- let t0 = Instant::now();
- let database_path = Arc::from(database_path);
- let db = VectorDatabase::new(fs.clone(), database_path, cx.background_executor().clone())
- .await?;
-
- log::trace!(
- "db initialization took {:?} milliseconds",
- t0.elapsed().as_millis()
- );
-
- cx.new_model(|cx| {
- let t0 = Instant::now();
- let embedding_queue =
- EmbeddingQueue::new(embedding_provider.clone(), cx.background_executor().clone());
- let _embedding_task = cx.background_executor().spawn({
- let embedded_files = embedding_queue.finished_files();
- let db = db.clone();
- async move {
- while let Ok(file) = embedded_files.recv().await {
- db.insert_file(file.worktree_id, file.path, file.mtime, file.spans)
- .await
- .log_err();
- }
- }
- });
-
- // Parse files into embeddable spans.
- let (parsing_files_tx, parsing_files_rx) =
- channel::unbounded::<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>();
- let embedding_queue = Arc::new(Mutex::new(embedding_queue));
- let mut _parsing_files_tasks = Vec::new();
- for _ in 0..cx.background_executor().num_cpus() {
- let fs = fs.clone();
- let mut parsing_files_rx = parsing_files_rx.clone();
- let embedding_provider = embedding_provider.clone();
- let embedding_queue = embedding_queue.clone();
- let background = cx.background_executor().clone();
- _parsing_files_tasks.push(cx.background_executor().spawn(async move {
- let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
- loop {
- let mut timer = background.timer(EMBEDDING_QUEUE_FLUSH_TIMEOUT).fuse();
- let mut next_file_to_parse = parsing_files_rx.next().fuse();
- futures::select_biased! {
- next_file_to_parse = next_file_to_parse => {
- if let Some((embeddings_for_digest, pending_file)) = next_file_to_parse {
- Self::parse_file(
- &fs,
- pending_file,
- &mut retriever,
- &embedding_queue,
- &embeddings_for_digest,
- )
- .await
- } else {
- break;
- }
- },
- _ = timer => {
- embedding_queue.lock().flush();
- }
- }
- }
- }));
- }
-
- log::trace!(
- "semantic index task initialization took {:?} milliseconds",
- t0.elapsed().as_millis()
- );
- Self {
- fs,
- db,
- embedding_provider,
- language_registry,
- parsing_files_tx,
- _embedding_task,
- _parsing_files_tasks,
- projects: Default::default(),
- }
- })
- }
-
- async fn parse_file(
- fs: &Arc<dyn Fs>,
- pending_file: PendingFile,
- retriever: &mut CodeContextRetriever,
- embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
- embeddings_for_digest: &HashMap<SpanDigest, Embedding>,
- ) {
- let Some(language) = pending_file.language else {
- return;
- };
-
- if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
- if let Some(mut spans) = retriever
- .parse_file_with_template(Some(&pending_file.relative_path), &content, language)
- .log_err()
- {
- log::trace!(
- "parsed path {:?}: {} spans",
- pending_file.relative_path,
- spans.len()
- );
-
- for span in &mut spans {
- if let Some(embedding) = embeddings_for_digest.get(&span.digest) {
- span.embedding = Some(embedding.to_owned());
- }
- }
-
- embedding_queue.lock().push(FileToEmbed {
- worktree_id: pending_file.worktree_db_id,
- path: pending_file.relative_path,
- mtime: pending_file.modified_time,
- job_handle: pending_file.job_handle,
- spans,
- });
- }
- }
- }
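
The `embeddings_for_digest` lookup in `parse_file` is a content-addressed cache: any span whose digest already has a stored embedding is filled in locally and never re-sent to the provider. A toy version of that reuse step, keyed on raw content where the real code uses `SpanDigest`:

```rust
use std::collections::HashMap;

struct Span {
    content: String,
    embedding: Option<Vec<f32>>,
}

/// Fill spans from the cache; returns how many provider calls were avoided.
fn fill_from_cache(spans: &mut [Span], cache: &HashMap<String, Vec<f32>>) -> usize {
    let mut hits = 0;
    for span in spans.iter_mut() {
        if let Some(embedding) = cache.get(&span.content) {
            span.embedding = Some(embedding.clone());
            hits += 1;
        }
    }
    hits
}

fn main() {
    let mut cache = HashMap::new();
    cache.insert("fn a() {}".to_string(), vec![0.1, 0.2]);
    let mut spans = vec![
        Span { content: "fn a() {}".into(), embedding: None },
        Span { content: "fn b() {}".into(), embedding: None },
    ];
    assert_eq!(fill_from_cache(&mut spans, &cache), 1);
    assert!(spans[0].embedding.is_some() && spans[1].embedding.is_none());
}
```
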
-
- pub fn project_previously_indexed(
- &mut self,
- project: &Model<Project>,
- cx: &mut ModelContext<Self>,
- ) -> Task<Result<bool>> {
- let worktrees_indexed_previously = project
- .read(cx)
- .worktrees()
- .map(|worktree| {
- self.db
- .worktree_previously_indexed(&worktree.read(cx).abs_path())
- })
- .collect::<Vec<_>>();
- cx.spawn(|_, _cx| async move {
-            let results = futures::future::join_all(worktrees_indexed_previously).await;
-
-            Ok(results.iter().flatten().all(|&was_indexed| was_indexed))
- })
- }
-
- fn project_entries_changed(
- &mut self,
- project: Model<Project>,
- worktree_id: WorktreeId,
- changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
- cx: &mut ModelContext<Self>,
- ) {
- let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) else {
- return;
- };
- let project = project.downgrade();
- let Some(project_state) = self.projects.get_mut(&project) else {
- return;
- };
-
- let worktree = worktree.read(cx);
- let worktree_state =
- if let Some(worktree_state) = project_state.worktrees.get_mut(&worktree_id) {
- worktree_state
- } else {
- return;
- };
- worktree_state.paths_changed(changes, worktree);
- if let WorktreeState::Registered(_) = worktree_state {
- cx.spawn(|this, mut cx| async move {
- cx.background_executor()
- .timer(BACKGROUND_INDEXING_DELAY)
- .await;
- if let Some((this, project)) = this.upgrade().zip(project.upgrade()) {
- this.update(&mut cx, |this, cx| {
- this.index_project(project, cx).detach_and_log_err(cx)
- })?;
- }
- anyhow::Ok(())
- })
- .detach_and_log_err(cx);
- }
- }
-
- fn register_worktree(
- &mut self,
- project: Model<Project>,
- worktree: Model<Worktree>,
- cx: &mut ModelContext<Self>,
- ) {
- let project = project.downgrade();
- let project_state = if let Some(project_state) = self.projects.get_mut(&project) {
- project_state
- } else {
- return;
- };
- let worktree = if let Some(worktree) = worktree.read(cx).as_local() {
- worktree
- } else {
- return;
- };
- let worktree_abs_path = worktree.abs_path().clone();
- let scan_complete = worktree.scan_complete();
- let worktree_id = worktree.id();
- let db = self.db.clone();
- let language_registry = self.language_registry.clone();
- let (mut done_tx, done_rx) = watch::channel();
- let registration = cx.spawn(|this, mut cx| {
- async move {
- let register = async {
- scan_complete.await;
- let db_id = db.find_or_create_worktree(worktree_abs_path).await?;
- let mut file_mtimes = db.get_file_mtimes(db_id).await?;
- let worktree = if let Some(project) = project.upgrade() {
- project
- .read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx))
- .ok()
- .flatten()
- .context("worktree not found")?
- } else {
- return anyhow::Ok(());
- };
- let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot())?;
- let mut changed_paths = cx
- .background_executor()
- .spawn(async move {
- let mut changed_paths = BTreeMap::new();
- for file in worktree.files(false, 0) {
- let absolute_path = worktree.absolutize(&file.path)?;
-
- if file.is_external || file.is_ignored || file.is_symlink {
- continue;
- }
-
- if let Ok(language) = language_registry
- .language_for_file_path(&absolute_path)
- .await
- {
-                                // Skip files that can't be parsed into embeddable spans.
-                                if !PARSEABLE_ENTIRE_FILE_TYPES
-                                    .contains(&language.name().as_ref())
-                                    && language.name().as_ref() != "Markdown"
- && language
- .grammar()
- .and_then(|grammar| grammar.embedding_config.as_ref())
- .is_none()
- {
- continue;
- }
- let Some(new_mtime) = file.mtime else {
- continue;
- };
-
- let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
- let already_stored = stored_mtime == Some(new_mtime);
-
- if !already_stored {
- changed_paths.insert(
- file.path.clone(),
- ChangedPathInfo {
- mtime: new_mtime,
- is_deleted: false,
- },
- );
- }
- }
- }
-
- // Clean up entries from database that are no longer in the worktree.
- for (path, mtime) in file_mtimes {
- changed_paths.insert(
- path.into(),
- ChangedPathInfo {
- mtime,
- is_deleted: true,
- },
- );
- }
-
- anyhow::Ok(changed_paths)
- })
- .await?;
- this.update(&mut cx, |this, cx| {
- let project_state = this
- .projects
- .get_mut(&project)
- .context("project not registered")?;
- let project = project.upgrade().context("project was dropped")?;
-
- if let Some(WorktreeState::Registering(state)) =
- project_state.worktrees.remove(&worktree_id)
- {
- changed_paths.extend(state.changed_paths);
- }
- project_state.worktrees.insert(
- worktree_id,
- WorktreeState::Registered(RegisteredWorktreeState {
- db_id,
- changed_paths,
- }),
- );
- this.index_project(project, cx).detach_and_log_err(cx);
-
- anyhow::Ok(())
- })??;
-
- anyhow::Ok(())
- };
-
- if register.await.log_err().is_none() {
- // Stop tracking this worktree if the registration failed.
- this.update(&mut cx, |this, _| {
- if let Some(project_state) = this.projects.get_mut(&project) {
- project_state.worktrees.remove(&worktree_id);
- }
- })
- .ok();
- }
-
- *done_tx.borrow_mut() = Some(());
- }
- });
- project_state.worktrees.insert(
- worktree_id,
- WorktreeState::Registering(RegisteringWorktreeState {
- changed_paths: Default::default(),
- done_rx,
- _registration: registration,
- }),
- );
- }
-
- fn project_worktrees_changed(&mut self, project: Model<Project>, cx: &mut ModelContext<Self>) {
- let project_state = if let Some(project_state) = self.projects.get_mut(&project.downgrade())
- {
- project_state
- } else {
- return;
- };
-
- let mut worktrees = project
- .read(cx)
- .worktrees()
- .filter(|worktree| worktree.read(cx).is_local())
- .collect::<Vec<_>>();
- let worktree_ids = worktrees
- .iter()
- .map(|worktree| worktree.read(cx).id())
- .collect::<HashSet<_>>();
-
- // Remove worktrees that are no longer present
- project_state
- .worktrees
- .retain(|worktree_id, _| worktree_ids.contains(worktree_id));
-
- // Register new worktrees
- worktrees.retain(|worktree| {
- let worktree_id = worktree.read(cx).id();
- !project_state.worktrees.contains_key(&worktree_id)
- });
- for worktree in worktrees {
- self.register_worktree(project.clone(), worktree, cx);
- }
- }
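
The retain-then-register dance above is plain set reconciliation. The same logic with the GPUI models stripped out; the returned ids are the worktrees that would still need `register_worktree`:

```rust
use std::collections::{HashMap, HashSet};

fn reconcile(tracked: &mut HashMap<u64, &'static str>, current: &[u64]) -> Vec<u64> {
    let current_ids: HashSet<u64> = current.iter().copied().collect();
    // Drop state for worktrees that no longer exist...
    tracked.retain(|id, _| current_ids.contains(id));
    // ...and report the ones we have not registered yet.
    current
        .iter()
        .copied()
        .filter(|id| !tracked.contains_key(id))
        .collect()
}

fn main() {
    let mut tracked = HashMap::from([(1, "registered"), (2, "registered")]);
    let to_register = reconcile(&mut tracked, &[2, 3]);
    assert_eq!(to_register, vec![3]); // worktree 1 dropped, 3 is new
    assert!(tracked.contains_key(&2) && !tracked.contains_key(&1));
}
```
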
-
- pub fn pending_file_count(&self, project: &Model<Project>) -> Option<watch::Receiver<usize>> {
- Some(
- self.projects
- .get(&project.downgrade())?
- .pending_file_count_rx
- .clone(),
- )
- }
-
- pub fn search_project(
- &mut self,
- project: Model<Project>,
- query: String,
- limit: usize,
- includes: Vec<PathMatcher>,
- excludes: Vec<PathMatcher>,
- cx: &mut ModelContext<Self>,
- ) -> Task<Result<Vec<SearchResult>>> {
- if query.is_empty() {
- return Task::ready(Ok(Vec::new()));
- }
-
- let index = self.index_project(project.clone(), cx);
- let embedding_provider = self.embedding_provider.clone();
-
- cx.spawn(|this, mut cx| async move {
- index.await?;
- let t0 = Instant::now();
-
- let query = embedding_provider
- .embed_batch(vec![query])
- .await?
- .pop()
- .context("could not embed query")?;
- log::trace!("Embedding Search Query: {:?}ms", t0.elapsed().as_millis());
-
- let search_start = Instant::now();
- let modified_buffer_results = this.update(&mut cx, |this, cx| {
- this.search_modified_buffers(
- &project,
- query.clone(),
- limit,
- &includes,
- &excludes,
- cx,
- )
- })?;
- let file_results = this.update(&mut cx, |this, cx| {
- this.search_files(project, query, limit, includes, excludes, cx)
- })?;
- let (modified_buffer_results, file_results) =
- futures::join!(modified_buffer_results, file_results);
-
- // Weave together the results from modified buffers and files.
- let mut results = Vec::new();
- let mut modified_buffers = HashSet::default();
- for result in modified_buffer_results.log_err().unwrap_or_default() {
- modified_buffers.insert(result.buffer.clone());
- results.push(result);
- }
- for result in file_results.log_err().unwrap_or_default() {
- if !modified_buffers.contains(&result.buffer) {
- results.push(result);
- }
- }
- results.sort_by_key(|result| Reverse(result.similarity));
- results.truncate(limit);
- log::trace!("Semantic search took {:?}", search_start.elapsed());
- Ok(results)
- })
- }
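
The weave at the end of `search_project` prefers in-memory hits from dirty buffers over their stale on-disk embeddings, then re-ranks and truncates. A sketch of just that merge, with plain ids and integer scores standing in for `Model<Buffer>` and `OrderedFloat<f32>`:

```rust
use std::cmp::Reverse;
use std::collections::HashSet;

// (buffer_id, score) stand in for (Model<Buffer>, OrderedFloat<f32>).
fn weave(buffer_hits: Vec<(u32, u32)>, file_hits: Vec<(u32, u32)>, limit: usize) -> Vec<(u32, u32)> {
    let mut results = buffer_hits;
    let modified: HashSet<u32> = results.iter().map(|(id, _)| *id).collect();
    // A dirty buffer's on-disk hit is stale; skip it.
    results.extend(file_hits.into_iter().filter(|(id, _)| !modified.contains(id)));
    results.sort_by_key(|&(_, score)| Reverse(score));
    results.truncate(limit);
    results
}

fn main() {
    let merged = weave(vec![(7, 90)], vec![(7, 40), (3, 80)], 2);
    assert_eq!(merged, vec![(7, 90), (3, 80)]); // stale (7, 40) was dropped
}
```
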
-
- pub fn search_files(
- &mut self,
- project: Model<Project>,
- query: Embedding,
- limit: usize,
- includes: Vec<PathMatcher>,
- excludes: Vec<PathMatcher>,
- cx: &mut ModelContext<Self>,
- ) -> Task<Result<Vec<SearchResult>>> {
- let db_path = self.db.path().clone();
- let fs = self.fs.clone();
- cx.spawn(|this, mut cx| async move {
- let database = VectorDatabase::new(
- fs.clone(),
- db_path.clone(),
- cx.background_executor().clone(),
- )
- .await?;
-
- let worktree_db_ids = this.read_with(&cx, |this, _| {
- let project_state = this
- .projects
- .get(&project.downgrade())
- .context("project was not indexed")?;
- let worktree_db_ids = project_state
- .worktrees
- .values()
- .filter_map(|worktree| {
- if let WorktreeState::Registered(worktree) = worktree {
- Some(worktree.db_id)
- } else {
- None
- }
- })
- .collect::<Vec<i64>>();
- anyhow::Ok(worktree_db_ids)
- })??;
-
- let file_ids = database
- .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
- .await?;
-
- let batch_n = cx.background_executor().num_cpus();
-            let ids_len = file_ids.len();
-            let minimum_batch_size = 50;
-            let batch_size = (ids_len / batch_n).max(minimum_batch_size);
-
- let mut batch_results = Vec::new();
- for batch in file_ids.chunks(batch_size) {
-                let batch = batch.to_vec();
- let fs = fs.clone();
- let db_path = db_path.clone();
- let query = query.clone();
- if let Some(db) =
- VectorDatabase::new(fs, db_path.clone(), cx.background_executor().clone())
- .await
- .log_err()
- {
- batch_results.push(async move {
- db.top_k_search(&query, limit, batch.as_slice()).await
- });
- }
- }
-
- let batch_results = futures::future::join_all(batch_results).await;
-
- let mut results = Vec::new();
- for batch_result in batch_results {
-                if let Ok(batch_result) = batch_result {
-                    for (id, similarity) in batch_result {
- let ix = match results
- .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
- {
- Ok(ix) => ix,
- Err(ix) => ix,
- };
-
- results.insert(ix, (id, similarity));
- results.truncate(limit);
- }
- }
- }
-
- let ids = results.iter().map(|(id, _)| *id).collect::<Vec<i64>>();
- let scores = results
- .into_iter()
- .map(|(_, score)| score)
- .collect::<Vec<_>>();
- let spans = database.spans_for_ids(ids.as_slice()).await?;
-
- let mut tasks = Vec::new();
- let mut ranges = Vec::new();
- let weak_project = project.downgrade();
- project.update(&mut cx, |project, cx| {
- let this = this.upgrade().context("index was dropped")?;
- for (worktree_db_id, file_path, byte_range) in spans {
- let project_state =
- if let Some(state) = this.read(cx).projects.get(&weak_project) {
- state
- } else {
- return Err(anyhow!("project not added"));
- };
- if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) {
- tasks.push(project.open_buffer((worktree_id, file_path), cx));
- ranges.push(byte_range);
- }
- }
-
- Ok(())
- })??;
-
- let buffers = futures::future::join_all(tasks).await;
- Ok(buffers
- .into_iter()
- .zip(ranges)
- .zip(scores)
- .filter_map(|((buffer, range), similarity)| {
- let buffer = buffer.log_err()?;
- let range = buffer
- .read_with(&cx, |buffer, _| {
- let start = buffer.clip_offset(range.start, Bias::Left);
- let end = buffer.clip_offset(range.end, Bias::Right);
- buffer.anchor_before(start)..buffer.anchor_after(end)
- })
- .log_err()?;
- Some(SearchResult {
- buffer,
- range,
- similarity,
- })
- })
- .collect())
- })
- }
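
Each database batch above feeds a running top-k list: a binary search finds the slot that keeps the list sorted by descending similarity, and `truncate` caps it at `limit`. The same insertion isolated, with `f32::total_cmp` standing in for `OrderedFloat`:

```rust
fn merge_top_k(results: &mut Vec<(i64, f32)>, batch: Vec<(i64, f32)>, limit: usize) {
    for (id, similarity) in batch {
        // Comparator keeps `results` sorted by descending similarity.
        let ix = results
            .binary_search_by(|(_, s)| similarity.total_cmp(s))
            .unwrap_or_else(|ix| ix);
        results.insert(ix, (id, similarity));
        results.truncate(limit);
    }
}

fn main() {
    let mut results = Vec::new();
    merge_top_k(&mut results, vec![(1, 0.2), (2, 0.9)], 3);
    merge_top_k(&mut results, vec![(3, 0.5), (4, 0.1)], 3);
    let ids: Vec<i64> = results.iter().map(|(id, _)| *id).collect();
    assert_eq!(ids, vec![2, 3, 1]); // (4, 0.1) fell off the end
}
```
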
-
- fn search_modified_buffers(
- &self,
- project: &Model<Project>,
- query: Embedding,
- limit: usize,
- includes: &[PathMatcher],
- excludes: &[PathMatcher],
- cx: &mut ModelContext<Self>,
- ) -> Task<Result<Vec<SearchResult>>> {
- let modified_buffers = project
- .read(cx)
- .opened_buffers()
- .into_iter()
- .filter_map(|buffer_handle| {
- let buffer = buffer_handle.read(cx);
- let snapshot = buffer.snapshot();
- let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| {
- excludes.iter().any(|matcher| matcher.is_match(&path))
- });
-
-                let included = if includes.is_empty() {
- true
- } else {
- snapshot.resolve_file_path(cx, false).map_or(false, |path| {
- includes.iter().any(|matcher| matcher.is_match(&path))
- })
- };
-
- if buffer.is_dirty() && !excluded && included {
- Some((buffer_handle, snapshot))
- } else {
- None
- }
- })
- .collect::<HashMap<_, _>>();
-
- let embedding_provider = self.embedding_provider.clone();
- let fs = self.fs.clone();
- let db_path = self.db.path().clone();
- let background = cx.background_executor().clone();
- cx.background_executor().spawn(async move {
- let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
- let mut results = Vec::<SearchResult>::new();
-
- let mut retriever = CodeContextRetriever::new(embedding_provider.clone());
- for (buffer, snapshot) in modified_buffers {
- let language = snapshot
- .language_at(0)
- .cloned()
- .unwrap_or_else(|| language::PLAIN_TEXT.clone());
- let mut spans = retriever
- .parse_file_with_template(None, &snapshot.text(), language)
- .log_err()
- .unwrap_or_default();
- if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
- .await
- .log_err()
- .is_some()
- {
- for span in spans {
- let similarity = span.embedding.unwrap().similarity(&query);
- let ix = match results
- .binary_search_by_key(&Reverse(similarity), |result| {
- Reverse(result.similarity)
- }) {
- Ok(ix) => ix,
- Err(ix) => ix,
- };
-
- let range = {
- let start = snapshot.clip_offset(span.range.start, Bias::Left);
- let end = snapshot.clip_offset(span.range.end, Bias::Right);
- snapshot.anchor_before(start)..snapshot.anchor_after(end)
- };
-
- results.insert(
- ix,
- SearchResult {
- buffer: buffer.clone(),
- range,
- similarity,
- },
- );
- results.truncate(limit);
- }
- }
- }
-
- Ok(results)
- })
- }
-
- pub fn index_project(
- &mut self,
- project: Model<Project>,
- cx: &mut ModelContext<Self>,
- ) -> Task<Result<()>> {
- if self.is_authenticated() {
- self.index_project_internal(project, cx)
- } else {
- let authenticate = self.authenticate(cx);
- cx.spawn(|this, mut cx| async move {
- if authenticate.await {
- this.update(&mut cx, |this, cx| this.index_project_internal(project, cx))?
- .await
- } else {
- Err(anyhow!("user is not authenticated"))
- }
- })
- }
- }
-
- fn index_project_internal(
- &mut self,
- project: Model<Project>,
- cx: &mut ModelContext<Self>,
- ) -> Task<Result<()>> {
- if !self.projects.contains_key(&project.downgrade()) {
- let subscription = cx.subscribe(&project, |this, project, event, cx| match event {
- project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
- this.project_worktrees_changed(project.clone(), cx);
- }
- project::Event::WorktreeUpdatedEntries(worktree_id, changes) => {
- this.project_entries_changed(project, *worktree_id, changes.clone(), cx);
- }
- _ => {}
- });
- let project_state = ProjectState::new(subscription, cx);
- self.projects.insert(project.downgrade(), project_state);
- self.project_worktrees_changed(project.clone(), cx);
- }
- let project_state = self.projects.get_mut(&project.downgrade()).unwrap();
- project_state.pending_index += 1;
- cx.notify();
-
- let mut pending_file_count_rx = project_state.pending_file_count_rx.clone();
- let db = self.db.clone();
- let language_registry = self.language_registry.clone();
- let parsing_files_tx = self.parsing_files_tx.clone();
- let worktree_registration = self.wait_for_worktree_registration(&project, cx);
-
- cx.spawn(|this, mut cx| async move {
- worktree_registration.await?;
-
- let mut pending_files = Vec::new();
- let mut files_to_delete = Vec::new();
- this.update(&mut cx, |this, cx| {
- let project_state = this
- .projects
- .get_mut(&project.downgrade())
- .context("project was dropped")?;
- let pending_file_count_tx = &project_state.pending_file_count_tx;
-
- project_state
- .worktrees
- .retain(|worktree_id, worktree_state| {
- let worktree = if let Some(worktree) =
- project.read(cx).worktree_for_id(*worktree_id, cx)
- {
- worktree
- } else {
- return false;
- };
- let worktree_state =
- if let WorktreeState::Registered(worktree_state) = worktree_state {
- worktree_state
- } else {
- return true;
- };
-
- for (path, info) in &worktree_state.changed_paths {
- if info.is_deleted {
- files_to_delete.push((worktree_state.db_id, path.clone()));
- } else if let Ok(absolute_path) = worktree.read(cx).absolutize(path) {
- let job_handle = JobHandle::new(pending_file_count_tx);
- pending_files.push(PendingFile {
- absolute_path,
- relative_path: path.clone(),
- language: None,
- job_handle,
- modified_time: info.mtime,
- worktree_db_id: worktree_state.db_id,
- });
- }
- }
- worktree_state.changed_paths.clear();
- true
- });
-
- anyhow::Ok(())
- })??;
-
- cx.background_executor()
- .spawn(async move {
- for (worktree_db_id, path) in files_to_delete {
- db.delete_file(worktree_db_id, path).await.log_err();
- }
-
- let embeddings_for_digest = {
- let mut files = HashMap::default();
- for pending_file in &pending_files {
- files
- .entry(pending_file.worktree_db_id)
- .or_insert(Vec::new())
- .push(pending_file.relative_path.clone());
- }
- Arc::new(
- db.embeddings_for_files(files)
- .await
- .log_err()
- .unwrap_or_default(),
- )
- };
-
- for mut pending_file in pending_files {
- if let Ok(language) = language_registry
- .language_for_file_path(&pending_file.relative_path)
- .await
- {
- if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
- && &language.name().as_ref() != &"Markdown"
- && language
- .grammar()
- .and_then(|grammar| grammar.embedding_config.as_ref())
- .is_none()
- {
- continue;
- }
- pending_file.language = Some(language);
- }
- parsing_files_tx
- .try_send((embeddings_for_digest.clone(), pending_file))
- .ok();
- }
-
- // Wait until we're done indexing.
- while let Some(count) = pending_file_count_rx.next().await {
- if count == 0 {
- break;
- }
- }
- })
- .await;
-
- this.update(&mut cx, |this, cx| {
- let project_state = this
- .projects
- .get_mut(&project.downgrade())
- .context("project was dropped")?;
- project_state.pending_index -= 1;
- cx.notify();
- anyhow::Ok(())
- })??;
-
- Ok(())
- })
- }
-
- fn wait_for_worktree_registration(
- &self,
- project: &Model<Project>,
- cx: &mut ModelContext<Self>,
- ) -> Task<Result<()>> {
- let project = project.downgrade();
- cx.spawn(|this, cx| async move {
- loop {
- let mut pending_worktrees = Vec::new();
- this.upgrade()
- .context("semantic index dropped")?
- .read_with(&cx, |this, _| {
- if let Some(project) = this.projects.get(&project) {
- for worktree in project.worktrees.values() {
- if let WorktreeState::Registering(worktree) = worktree {
- pending_worktrees.push(worktree.done());
- }
- }
- }
- })?;
-
- if pending_worktrees.is_empty() {
- break;
- } else {
- future::join_all(pending_worktrees).await;
- }
- }
- Ok(())
- })
- }
-
- async fn embed_spans(
- spans: &mut [Span],
- embedding_provider: &dyn EmbeddingProvider,
- db: &VectorDatabase,
- ) -> Result<()> {
- let mut batch = Vec::new();
- let mut batch_tokens = 0;
- let mut embeddings = Vec::new();
-
- let digests = spans
- .iter()
- .map(|span| span.digest.clone())
- .collect::<Vec<_>>();
- let embeddings_for_digests = db
- .embeddings_for_digests(digests)
- .await
- .log_err()
- .unwrap_or_default();
-
- for span in &*spans {
- if embeddings_for_digests.contains_key(&span.digest) {
- continue;
- };
-
- if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
- let batch_embeddings = embedding_provider
- .embed_batch(mem::take(&mut batch))
- .await?;
- embeddings.extend(batch_embeddings);
- batch_tokens = 0;
- }
-
- batch_tokens += span.token_count;
- batch.push(span.content.clone());
- }
-
- if !batch.is_empty() {
- let batch_embeddings = embedding_provider
- .embed_batch(mem::take(&mut batch))
- .await?;
-
- embeddings.extend(batch_embeddings);
- }
-
- let mut embeddings = embeddings.into_iter();
- for span in spans {
- let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) {
- Some(embedding.clone())
- } else {
- embeddings.next()
- };
- let embedding = embedding.context("failed to embed spans")?;
- span.embedding = Some(embedding);
- }
- Ok(())
- }
-}
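
`embed_spans` batches by token budget rather than span count: the current batch is flushed as soon as the next span would push it past `max_tokens_per_batch`. The batching alone, with the provider stubbed out (unlike the original, this sketch also guards against flushing an empty batch when a single span exceeds the budget):

```rust
struct Span {
    content: String,
    token_count: usize,
}

fn batch_by_tokens(spans: &[Span], max_tokens_per_batch: usize) -> Vec<Vec<String>> {
    let mut batches = Vec::new();
    let mut batch: Vec<String> = Vec::new();
    let mut batch_tokens = 0;
    for span in spans {
        if batch_tokens + span.token_count > max_tokens_per_batch && !batch.is_empty() {
            batches.push(std::mem::take(&mut batch));
            batch_tokens = 0;
        }
        batch_tokens += span.token_count;
        batch.push(span.content.clone());
    }
    if !batch.is_empty() {
        batches.push(batch);
    }
    batches
}

fn main() {
    let spans: Vec<Span> = [("a", 60), ("b", 50), ("c", 10)]
        .map(|(content, token_count)| Span { content: content.into(), token_count })
        .into();
    // Budget of 100 tokens: "a" flushes alone, "b" and "c" share a batch.
    assert_eq!(
        batch_by_tokens(&spans, 100),
        vec![vec!["a".to_string()], vec!["b".into(), "c".into()]]
    );
}
```
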
-
-impl Drop for JobHandle {
- fn drop(&mut self) {
- if let Some(inner) = Arc::get_mut(&mut self.tx) {
-            // This is the last remaining instance of this JobHandle, whether or not it was ever cloned.
- if let Some(tx) = inner.upgrade() {
- let mut tx = tx.lock();
- *tx.borrow_mut() -= 1;
- }
- }
- }
-}
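
The `Arc<Weak<…>>` layering is what makes the decrement fire exactly once: clones share the outer `Arc`, so only the final `drop` sees a unique reference, and the inner `Weak` lets the counter disappear if the project is dropped first. The same scheme sketched with std primitives in place of `parking_lot` and `postage::watch`:

```rust
use std::sync::{Arc, Mutex, Weak};

#[derive(Clone)]
struct Handle {
    tx: Arc<Weak<Mutex<usize>>>,
}

impl Handle {
    fn new(count: &Arc<Mutex<usize>>) -> Self {
        *count.lock().unwrap() += 1;
        Self { tx: Arc::new(Arc::downgrade(count)) }
    }
}

impl Drop for Handle {
    fn drop(&mut self) {
        // Only the last clone sees a unique outer Arc, so this runs once.
        if let Some(inner) = Arc::get_mut(&mut self.tx) {
            if let Some(count) = inner.upgrade() {
                *count.lock().unwrap() -= 1;
            }
        }
    }
}

fn main() {
    let count = Arc::new(Mutex::new(0usize));
    let handle = Handle::new(&count);
    let clone = handle.clone();
    drop(handle);
    assert_eq!(*count.lock().unwrap(), 1); // still one live job
    drop(clone);
    assert_eq!(*count.lock().unwrap(), 0); // decremented exactly once
}
```
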
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
- #[test]
- fn test_job_handle() {
- let (job_count_tx, job_count_rx) = watch::channel_with(0);
- let tx = Arc::new(Mutex::new(job_count_tx));
- let job_handle = JobHandle::new(&tx);
-
- assert_eq!(1, *job_count_rx.borrow());
- let new_job_handle = job_handle.clone();
- assert_eq!(1, *job_count_rx.borrow());
- drop(job_handle);
- assert_eq!(1, *job_count_rx.borrow());
- drop(new_job_handle);
- assert_eq!(0, *job_count_rx.borrow());
- }
-}
@@ -1,33 +0,0 @@
-use anyhow;
-use schemars::JsonSchema;
-use serde::{Deserialize, Serialize};
-use settings::Settings;
-
-#[derive(Deserialize, Debug)]
-pub struct SemanticIndexSettings {
- pub enabled: bool,
-}
-
-/// Configuration of the semantic index, an alternate search engine available in
-/// project search.
-#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
-pub struct SemanticIndexSettingsContent {
- /// Whether or not to display the Semantic mode in project search.
- ///
- /// Default: true
- pub enabled: Option<bool>,
-}
-
-impl Settings for SemanticIndexSettings {
- const KEY: Option<&'static str> = Some("semantic_index");
-
- type FileContent = SemanticIndexSettingsContent;
-
- fn load(
- default_value: &Self::FileContent,
- user_values: &[&Self::FileContent],
- _: &mut gpui::AppContext,
- ) -> anyhow::Result<Self> {
- Self::load_via_json_merge(default_value, user_values)
- }
-}
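
For reference, the `semantic_index` key this impl reads, as it would sit in a user's `settings.json` (sketched with `serde_json`; the key and field come from the impl above):

```rust
use serde_json::json;

fn main() {
    // Disables the Semantic mode in project search.
    let user_settings = json!({
        "semantic_index": {
            "enabled": false
        }
    });
    println!("{}", serde_json::to_string_pretty(&user_settings).unwrap());
}
```
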
@@ -1,1725 +0,0 @@
-use crate::{
- embedding_queue::EmbeddingQueue,
- parsing::{subtract_ranges, CodeContextRetriever, Span, SpanDigest},
- semantic_index_settings::SemanticIndexSettings,
- FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
-};
-use ai::test::FakeEmbeddingProvider;
-use gpui::TestAppContext;
-use language::{Language, LanguageConfig, LanguageMatcher, LanguageRegistry, ToOffset};
-use parking_lot::Mutex;
-use pretty_assertions::assert_eq;
-use project::{FakeFs, Fs, Project};
-use rand::{rngs::StdRng, Rng};
-use serde_json::json;
-use settings::{Settings, SettingsStore};
-use std::{path::Path, sync::Arc, time::SystemTime};
-use unindent::Unindent;
-use util::{paths::PathMatcher, RandomCharIter};
-
-#[ctor::ctor]
-fn init_logger() {
- if std::env::var("RUST_LOG").is_ok() {
- env_logger::init();
- }
-}
-
-#[gpui::test]
-async fn test_semantic_index(cx: &mut TestAppContext) {
- init_test(cx);
-
- let fs = FakeFs::new(cx.background_executor.clone());
- fs.insert_tree(
- "/the-root",
- json!({
- "src": {
- "file1.rs": "
- fn aaa() {
- println!(\"aaaaaaaaaaaa!\");
- }
-
- fn zzzzz() {
- println!(\"SLEEPING\");
- }
- ".unindent(),
- "file2.rs": "
- fn bbb() {
- println!(\"bbbbbbbbbbbbb!\");
- }
- struct pqpqpqp {}
- ".unindent(),
- "file3.toml": "
- ZZZZZZZZZZZZZZZZZZ = 5
- ".unindent(),
- }
- }),
- )
- .await;
-
- let languages = Arc::new(LanguageRegistry::test(cx.executor().clone()));
- let rust_language = rust_lang();
- let toml_language = toml_lang();
- languages.add(rust_language);
- languages.add(toml_language);
-
- let db_dir = tempfile::Builder::new()
- .prefix("vector-store")
- .tempdir()
- .unwrap();
- let db_path = db_dir.path().join("db.sqlite");
-
- let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
- let semantic_index = SemanticIndex::new(
- fs.clone(),
- db_path,
- embedding_provider.clone(),
- languages,
- cx.to_async(),
- )
- .await
- .unwrap();
-
- let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
-
- let search_results = semantic_index.update(cx, |store, cx| {
- store.search_project(
- project.clone(),
- "aaaaaabbbbzz".to_string(),
- 5,
- vec![],
- vec![],
- cx,
- )
- });
- let pending_file_count =
- semantic_index.read_with(cx, |index, _| index.pending_file_count(&project).unwrap());
- cx.background_executor.run_until_parked();
- assert_eq!(*pending_file_count.borrow(), 3);
- cx.background_executor
- .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
- assert_eq!(*pending_file_count.borrow(), 0);
-
- let search_results = search_results.await.unwrap();
- assert_search_results(
- &search_results,
- &[
- (Path::new("src/file1.rs").into(), 0),
- (Path::new("src/file2.rs").into(), 0),
- (Path::new("src/file3.toml").into(), 0),
- (Path::new("src/file1.rs").into(), 45),
- (Path::new("src/file2.rs").into(), 45),
- ],
- cx,
- );
-
-    // Test include- and exclude-file filtering.
- let include_files = vec![PathMatcher::new("*.rs").unwrap()];
- let exclude_files = vec![PathMatcher::new("*.rs").unwrap()];
- let rust_only_search_results = semantic_index
- .update(cx, |store, cx| {
- store.search_project(
- project.clone(),
- "aaaaaabbbbzz".to_string(),
- 5,
- include_files,
- vec![],
- cx,
- )
- })
- .await
- .unwrap();
-
- assert_search_results(
- &rust_only_search_results,
- &[
- (Path::new("src/file1.rs").into(), 0),
- (Path::new("src/file2.rs").into(), 0),
- (Path::new("src/file1.rs").into(), 45),
- (Path::new("src/file2.rs").into(), 45),
- ],
- cx,
- );
-
- let no_rust_search_results = semantic_index
- .update(cx, |store, cx| {
- store.search_project(
- project.clone(),
- "aaaaaabbbbzz".to_string(),
- 5,
- vec![],
- exclude_files,
- cx,
- )
- })
- .await
- .unwrap();
-
- assert_search_results(
- &no_rust_search_results,
- &[(Path::new("src/file3.toml").into(), 0)],
- cx,
- );
-
- fs.save(
- "/the-root/src/file2.rs".as_ref(),
- &"
- fn dddd() { println!(\"ddddd!\"); }
- struct pqpqpqp {}
- "
- .unindent()
- .into(),
- Default::default(),
- )
- .await
- .unwrap();
-
- cx.background_executor
- .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
-
- let prev_embedding_count = embedding_provider.embedding_count();
- let index = semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx));
- cx.background_executor.run_until_parked();
- assert_eq!(*pending_file_count.borrow(), 1);
- cx.background_executor
- .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
- assert_eq!(*pending_file_count.borrow(), 0);
- index.await.unwrap();
-
- assert_eq!(
- embedding_provider.embedding_count() - prev_embedding_count,
- 1
- );
-}
-
-#[gpui::test(iterations = 10)]
-async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
- let (outstanding_job_count, _) = postage::watch::channel_with(0);
- let outstanding_job_count = Arc::new(Mutex::new(outstanding_job_count));
-
- let files = (1..=3)
- .map(|file_ix| FileToEmbed {
- worktree_id: 5,
- path: Path::new(&format!("path-{file_ix}")).into(),
- mtime: SystemTime::now(),
- spans: (0..rng.gen_range(4..22))
- .map(|document_ix| {
- let content_len = rng.gen_range(10..100);
- let content = RandomCharIter::new(&mut rng)
- .with_simple_text()
- .take(content_len)
- .collect::<String>();
- let digest = SpanDigest::from(content.as_str());
- Span {
- range: 0..10,
- embedding: None,
- name: format!("document {document_ix}"),
- content,
- digest,
- token_count: rng.gen_range(10..30),
- }
- })
- .collect(),
- job_handle: JobHandle::new(&outstanding_job_count),
- })
- .collect::<Vec<_>>();
-
- let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
-
- let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background_executor.clone());
- for file in &files {
- queue.push(file.clone());
- }
- queue.flush();
-
- cx.background_executor.run_until_parked();
- let finished_files = queue.finished_files();
- let mut embedded_files: Vec<_> = files
- .iter()
- .map(|_| finished_files.try_recv().expect("no finished file"))
- .collect();
-
- let expected_files: Vec<_> = files
- .iter()
- .map(|file| {
- let mut file = file.clone();
- for doc in &mut file.spans {
- doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref()));
- }
- file
- })
- .collect();
-
- embedded_files.sort_by_key(|f| f.path.clone());
-
- assert_eq!(embedded_files, expected_files);
-}
-
-#[track_caller]
-fn assert_search_results(
- actual: &[SearchResult],
- expected: &[(Arc<Path>, usize)],
- cx: &TestAppContext,
-) {
- let actual = actual
- .iter()
- .map(|search_result| {
- search_result.buffer.read_with(cx, |buffer, _cx| {
- (
- buffer.file().unwrap().path().clone(),
- search_result.range.start.to_offset(buffer),
- )
- })
- })
- .collect::<Vec<_>>();
- assert_eq!(actual, expected);
-}
-
-#[gpui::test]
-async fn test_code_context_retrieval_rust() {
- let language = rust_lang();
- let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
- let mut retriever = CodeContextRetriever::new(embedding_provider);
-
- let text = "
- /// A doc comment
- /// that spans multiple lines
- #[gpui::test]
- fn a() {
- b
- }
-
- impl C for D {
- }
-
- impl E {
- // This is also a preceding comment
- pub fn function_1() -> Option<()> {
- unimplemented!();
- }
-
- // This is a preceding comment
- fn function_2() -> Result<()> {
- unimplemented!();
- }
- }
-
- #[derive(Clone)]
- struct D {
- name: String
- }
- "
- .unindent();
-
- let documents = retriever.parse_file(&text, language).unwrap();
-
- assert_documents_eq(
- &documents,
- &[
- (
- "
- /// A doc comment
- /// that spans multiple lines
- #[gpui::test]
- fn a() {
- b
- }"
- .unindent(),
- text.find("fn a").unwrap(),
- ),
- (
- "
- impl C for D {
- }"
- .unindent(),
- text.find("impl C").unwrap(),
- ),
- (
- "
- impl E {
- // This is also a preceding comment
- pub fn function_1() -> Option<()> { /* ... */ }
-
- // This is a preceding comment
- fn function_2() -> Result<()> { /* ... */ }
- }"
- .unindent(),
- text.find("impl E").unwrap(),
- ),
- (
- "
- // This is also a preceding comment
- pub fn function_1() -> Option<()> {
- unimplemented!();
- }"
- .unindent(),
- text.find("pub fn function_1").unwrap(),
- ),
- (
- "
- // This is a preceding comment
- fn function_2() -> Result<()> {
- unimplemented!();
- }"
- .unindent(),
- text.find("fn function_2").unwrap(),
- ),
- (
- "
- #[derive(Clone)]
- struct D {
- name: String
- }"
- .unindent(),
- text.find("struct D").unwrap(),
- ),
- ],
- );
-}
-
-#[gpui::test]
-async fn test_code_context_retrieval_json() {
- let language = json_lang();
- let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
- let mut retriever = CodeContextRetriever::new(embedding_provider);
-
- let text = r#"
- {
- "array": [1, 2, 3, 4],
- "string": "abcdefg",
- "nested_object": {
- "array_2": [5, 6, 7, 8],
- "string_2": "hijklmnop",
- "boolean": true,
- "none": null
- }
- }
- "#
- .unindent();
-
- let documents = retriever.parse_file(&text, language.clone()).unwrap();
-
- assert_documents_eq(
- &documents,
- &[(
- r#"
- {
- "array": [],
- "string": "",
- "nested_object": {
- "array_2": [],
- "string_2": "",
- "boolean": true,
- "none": null
- }
- }"#
- .unindent(),
- text.find('{').unwrap(),
- )],
- );
-
- let text = r#"
- [
- {
- "name": "somebody",
- "age": 42
- },
- {
- "name": "somebody else",
- "age": 43
- }
- ]
- "#
- .unindent();
-
- let documents = retriever.parse_file(&text, language.clone()).unwrap();
-
- assert_documents_eq(
- &documents,
- &[(
- r#"
- [{
- "name": "",
- "age": 42
- }]"#
- .unindent(),
- text.find('[').unwrap(),
- )],
- );
-}
-
-fn assert_documents_eq(
- documents: &[Span],
- expected_contents_and_start_offsets: &[(String, usize)],
-) {
- assert_eq!(
- documents
- .iter()
- .map(|document| (document.content.clone(), document.range.start))
- .collect::<Vec<_>>(),
- expected_contents_and_start_offsets
- );
-}
-
-#[gpui::test]
-async fn test_code_context_retrieval_javascript() {
- let language = js_lang();
- let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
- let mut retriever = CodeContextRetriever::new(embedding_provider);
-
- let text = "
- /* globals importScripts, backend */
- function _authorize() {}
-
- /**
- * Sometimes the frontend build is way faster than backend.
- */
- export async function authorizeBank() {
- _authorize(pushModal, upgradingAccountId, {});
- }
-
- export class SettingsPage {
- /* This is a test setting */
- constructor(page) {
- this.page = page;
- }
- }
-
- /* This is a test comment */
- class TestClass {}
-
- /* Schema for editor_events in Clickhouse. */
- export interface ClickhouseEditorEvent {
- installation_id: string
- operation: string
- }
- "
- .unindent();
-
- let documents = retriever.parse_file(&text, language.clone()).unwrap();
-
- assert_documents_eq(
- &documents,
- &[
- (
- "
- /* globals importScripts, backend */
- function _authorize() {}"
- .unindent(),
- 37,
- ),
- (
- "
- /**
- * Sometimes the frontend build is way faster than backend.
- */
- export async function authorizeBank() {
- _authorize(pushModal, upgradingAccountId, {});
- }"
- .unindent(),
- 131,
- ),
- (
- "
- export class SettingsPage {
- /* This is a test setting */
- constructor(page) {
- this.page = page;
- }
- }"
- .unindent(),
- 225,
- ),
- (
- "
- /* This is a test setting */
- constructor(page) {
- this.page = page;
- }"
- .unindent(),
- 290,
- ),
- (
- "
- /* This is a test comment */
- class TestClass {}"
- .unindent(),
- 374,
- ),
- (
- "
- /* Schema for editor_events in Clickhouse. */
- export interface ClickhouseEditorEvent {
- installation_id: string
- operation: string
- }"
- .unindent(),
- 440,
- ),
- ],
- )
-}
-
-#[gpui::test]
-async fn test_code_context_retrieval_lua() {
- let language = lua_lang();
- let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
- let mut retriever = CodeContextRetriever::new(embedding_provider);
-
- let text = r#"
- -- Creates a new class
- -- @param baseclass The Baseclass of this class, or nil.
- -- @return A new class reference.
- function classes.class(baseclass)
- -- Create the class definition and metatable.
- local classdef = {}
- -- Find the super class, either Object or user-defined.
- baseclass = baseclass or classes.Object
- -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
- setmetatable(classdef, { __index = baseclass })
- -- All class instances have a reference to the class object.
- classdef.class = classdef
- --- Recursively allocates the inheritance tree of the instance.
- -- @param mastertable The 'root' of the inheritance tree.
- -- @return Returns the instance with the allocated inheritance tree.
- function classdef.alloc(mastertable)
- -- All class instances have a reference to a superclass object.
- local instance = { super = baseclass.alloc(mastertable) }
- -- Any functions this instance does not know of will 'look up' to the superclass definition.
- setmetatable(instance, { __index = classdef, __newindex = mastertable })
- return instance
- end
- end
- "#.unindent();
-
- let documents = retriever.parse_file(&text, language.clone()).unwrap();
-
- assert_documents_eq(
- &documents,
- &[
- (r#"
- -- Creates a new class
- -- @param baseclass The Baseclass of this class, or nil.
- -- @return A new class reference.
- function classes.class(baseclass)
- -- Create the class definition and metatable.
- local classdef = {}
- -- Find the super class, either Object or user-defined.
- baseclass = baseclass or classes.Object
- -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable.
- setmetatable(classdef, { __index = baseclass })
- -- All class instances have a reference to the class object.
- classdef.class = classdef
- --- Recursively allocates the inheritance tree of the instance.
- -- @param mastertable The 'root' of the inheritance tree.
- -- @return Returns the instance with the allocated inheritance tree.
- function classdef.alloc(mastertable)
- --[ ... ]--
- --[ ... ]--
- end
- end"#.unindent(),
- 114),
- (r#"
- --- Recursively allocates the inheritance tree of the instance.
- -- @param mastertable The 'root' of the inheritance tree.
- -- @return Returns the instance with the allocated inheritance tree.
- function classdef.alloc(mastertable)
- -- All class instances have a reference to a superclass object.
- local instance = { super = baseclass.alloc(mastertable) }
- -- Any functions this instance does not know of will 'look up' to the superclass definition.
- setmetatable(instance, { __index = classdef, __newindex = mastertable })
- return instance
- end"#.unindent(), 810),
- ]
- );
-}
-
-#[gpui::test]
-async fn test_code_context_retrieval_elixir() {
- let language = elixir_lang();
- let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
- let mut retriever = CodeContextRetriever::new(embedding_provider);
-
- let text = r#"
- defmodule File.Stream do
- @moduledoc """
- Defines a `File.Stream` struct returned by `File.stream!/3`.
-
- The following fields are public:
-
- * `path` - the file path
- * `modes` - the file modes
- * `raw` - a boolean indicating if bin functions should be used
- * `line_or_bytes` - if reading should read lines or a given number of bytes
- * `node` - the node the file belongs to
-
- """
-
- defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
-
- @type t :: %__MODULE__{}
-
- @doc false
- def __build__(path, modes, line_or_bytes) do
- raw = :lists.keyfind(:encoding, 1, modes) == false
-
- modes =
- case raw do
- true ->
- case :lists.keyfind(:read_ahead, 1, modes) do
- {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
- {:read_ahead, _} -> [:raw | modes]
- false -> [:raw, :read_ahead | modes]
- end
-
- false ->
- modes
- end
-
- %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
-
- end"#
- .unindent();
-
- let documents = retriever.parse_file(&text, language.clone()).unwrap();
-
- assert_documents_eq(
- &documents,
- &[(
- r#"
- defmodule File.Stream do
- @moduledoc """
- Defines a `File.Stream` struct returned by `File.stream!/3`.
-
- The following fields are public:
-
- * `path` - the file path
- * `modes` - the file modes
- * `raw` - a boolean indicating if bin functions should be used
- * `line_or_bytes` - if reading should read lines or a given number of bytes
- * `node` - the node the file belongs to
-
- """
-
- defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil
-
- @type t :: %__MODULE__{}
-
- @doc false
- def __build__(path, modes, line_or_bytes) do
- raw = :lists.keyfind(:encoding, 1, modes) == false
-
- modes =
- case raw do
- true ->
- case :lists.keyfind(:read_ahead, 1, modes) do
- {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
- {:read_ahead, _} -> [:raw | modes]
- false -> [:raw, :read_ahead | modes]
- end
-
- false ->
- modes
- end
-
- %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
-
- end"#
- .unindent(),
- 0,
- ),(r#"
- @doc false
- def __build__(path, modes, line_or_bytes) do
- raw = :lists.keyfind(:encoding, 1, modes) == false
-
- modes =
- case raw do
- true ->
- case :lists.keyfind(:read_ahead, 1, modes) do
- {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)]
- {:read_ahead, _} -> [:raw | modes]
- false -> [:raw, :read_ahead | modes]
- end
-
- false ->
- modes
- end
-
- %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()}
-
- end"#.unindent(), 574)],
- );
-}
-
-#[gpui::test]
-async fn test_code_context_retrieval_cpp() {
- let language = cpp_lang();
- let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
- let mut retriever = CodeContextRetriever::new(embedding_provider);
-
- let text = "
- /**
- * @brief Main function
- * @returns 0 on exit
- */
- int main() { return 0; }
-
- /**
- * This is a test comment
- */
- class MyClass { // The class
- public: // Access specifier
- int myNum; // Attribute (int variable)
- string myString; // Attribute (string variable)
- };
-
- // This is a test comment
- enum Color { red, green, blue };
-
- /** This is a preceding block comment
- * This is the second line
- */
- struct { // Structure declaration
- int myNum; // Member (int variable)
- string myString; // Member (string variable)
- } myStructure;
-
- /**
- * @brief Matrix class.
- */
- template <typename T,
- typename = typename std::enable_if<
- std::is_integral<T>::value || std::is_floating_point<T>::value,
- bool>::type>
- class Matrix2 {
- std::vector<std::vector<T>> _mat;
-
- public:
- /**
- * @brief Constructor
- * @tparam Integer ensuring integers are being evaluated and not other
- * data types.
- * @param size denoting the size of Matrix as size x size
- */
- template <typename Integer,
- typename = typename std::enable_if<std::is_integral<Integer>::value,
- Integer>::type>
- explicit Matrix(const Integer size) {
- for (size_t i = 0; i < size; ++i) {
- _mat.emplace_back(std::vector<T>(size, 0));
- }
- }
- }"
- .unindent();
-
- let documents = retriever.parse_file(&text, language.clone()).unwrap();
-
- assert_documents_eq(
- &documents,
- &[
- (
- "
- /**
- * @brief Main function
- * @returns 0 on exit
- */
- int main() { return 0; }"
- .unindent(),
- 54,
- ),
- (
- "
- /**
- * This is a test comment
- */
- class MyClass { // The class
- public: // Access specifier
- int myNum; // Attribute (int variable)
- string myString; // Attribute (string variable)
- }"
- .unindent(),
- 112,
- ),
- (
- "
- // This is a test comment
- enum Color { red, green, blue }"
- .unindent(),
- 322,
- ),
- (
- "
- /** This is a preceding block comment
- * This is the second line
- */
- struct { // Structure declaration
- int myNum; // Member (int variable)
- string myString; // Member (string variable)
- } myStructure;"
- .unindent(),
- 425,
- ),
- (
- "
- /**
- * @brief Matrix class.
- */
- template <typename T,
- typename = typename std::enable_if<
- std::is_integral<T>::value || std::is_floating_point<T>::value,
- bool>::type>
- class Matrix2 {
- std::vector<std::vector<T>> _mat;
-
- public:
- /**
- * @brief Constructor
- * @tparam Integer ensuring integers are being evaluated and not other
- * data types.
- * @param size denoting the size of Matrix as size x size
- */
- template <typename Integer,
- typename = typename std::enable_if<std::is_integral<Integer>::value,
- Integer>::type>
- explicit Matrix(const Integer size) {
- for (size_t i = 0; i < size; ++i) {
- _mat.emplace_back(std::vector<T>(size, 0));
- }
- }
- }"
- .unindent(),
- 612,
- ),
- (
- "
- explicit Matrix(const Integer size) {
- for (size_t i = 0; i < size; ++i) {
- _mat.emplace_back(std::vector<T>(size, 0));
- }
- }"
- .unindent(),
- 1226,
- ),
- ],
- );
-}
-
-#[gpui::test]
-async fn test_code_context_retrieval_ruby() {
- let language = ruby_lang();
- let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
- let mut retriever = CodeContextRetriever::new(embedding_provider);
-
- let text = r#"
- # This concern is inspired by "sudo mode" on GitHub. It
- # is a way to re-authenticate a user before allowing them
- # to see or perform an action.
- #
- # Add `before_action :require_challenge!` to actions you
- # want to protect.
- #
- # The user will be shown a page to enter the challenge (which
- # is either the password, or just the username when no
- # password exists). Upon passing, there is a grace period
- # during which no challenge will be asked from the user.
- #
- # Accessing challenge-protected resources during the grace
- # period will refresh the grace period.
- module ChallengableConcern
- extend ActiveSupport::Concern
-
- CHALLENGE_TIMEOUT = 1.hour.freeze
-
- def require_challenge!
- return if skip_challenge?
-
- if challenge_passed_recently?
- session[:challenge_passed_at] = Time.now.utc
- return
- end
-
- @challenge = Form::Challenge.new(return_to: request.url)
-
- if params.key?(:form_challenge)
- if challenge_passed?
- session[:challenge_passed_at] = Time.now.utc
- else
- flash.now[:alert] = I18n.t('challenge.invalid_password')
- render_challenge
- end
- else
- render_challenge
- end
- end
-
- def challenge_passed?
- current_user.valid_password?(challenge_params[:current_password])
- end
- end
-
- class Animal
- include Comparable
-
- attr_reader :legs
-
- def initialize(name, legs)
- @name, @legs = name, legs
- end
-
- def <=>(other)
- legs <=> other.legs
- end
- end
-
- # Singleton method for car object
- def car.wheels
- puts "There are four wheels"
- end"#
- .unindent();
-
- let documents = retriever.parse_file(&text, language.clone()).unwrap();
-
- assert_documents_eq(
- &documents,
- &[
- (
- r#"
- # This concern is inspired by "sudo mode" on GitHub. It
- # is a way to re-authenticate a user before allowing them
- # to see or perform an action.
- #
- # Add `before_action :require_challenge!` to actions you
- # want to protect.
- #
- # The user will be shown a page to enter the challenge (which
- # is either the password, or just the username when no
- # password exists). Upon passing, there is a grace period
- # during which no challenge will be asked from the user.
- #
- # Accessing challenge-protected resources during the grace
- # period will refresh the grace period.
- module ChallengableConcern
- extend ActiveSupport::Concern
-
- CHALLENGE_TIMEOUT = 1.hour.freeze
-
- def require_challenge!
- # ...
- end
-
- def challenge_passed?
- # ...
- end
- end"#
- .unindent(),
- 558,
- ),
- (
- r#"
- def require_challenge!
- return if skip_challenge?
-
- if challenge_passed_recently?
- session[:challenge_passed_at] = Time.now.utc
- return
- end
-
- @challenge = Form::Challenge.new(return_to: request.url)
-
- if params.key?(:form_challenge)
- if challenge_passed?
- session[:challenge_passed_at] = Time.now.utc
- else
- flash.now[:alert] = I18n.t('challenge.invalid_password')
- render_challenge
- end
- else
- render_challenge
- end
- end"#
- .unindent(),
- 663,
- ),
- (
- r#"
- def challenge_passed?
- current_user.valid_password?(challenge_params[:current_password])
- end"#
- .unindent(),
- 1254,
- ),
- (
- r#"
- class Animal
- include Comparable
-
- attr_reader :legs
-
- def initialize(name, legs)
- # ...
- end
-
- def <=>(other)
- # ...
- end
- end"#
- .unindent(),
- 1363,
- ),
- (
- r#"
- def initialize(name, legs)
- @name, @legs = name, legs
- end"#
- .unindent(),
- 1427,
- ),
- (
- r#"
- def <=>(other)
- legs <=> other.legs
- end"#
- .unindent(),
- 1501,
- ),
- (
- r#"
- # Singleton method for car object
- def car.wheels
- puts "There are four wheels"
- end"#
- .unindent(),
- 1591,
- ),
- ],
- );
-}
-
-#[gpui::test]
-async fn test_code_context_retrieval_php() {
- let language = php_lang();
- let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
- let mut retriever = CodeContextRetriever::new(embedding_provider);
-
- let text = r#"
- <?php
-
- namespace LevelUp\Experience\Concerns;
-
- /*
- This is a multiple-lines comment block
- that spans over multiple
- lines
- */
- function functionName() {
- echo "Hello world!";
- }
-
- trait HasAchievements
- {
- /**
- * @throws \Exception
- */
- public function grantAchievement(Achievement $achievement, $progress = null): void
- {
- if ($progress > 100) {
- throw new Exception(message: 'Progress cannot be greater than 100');
- }
-
- if ($this->achievements()->find($achievement->id)) {
- throw new Exception(message: 'User already has this Achievement');
- }
-
- $this->achievements()->attach($achievement, [
- 'progress' => $progress ?? null,
- ]);
-
- $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
- }
-
- public function achievements(): BelongsToMany
- {
- return $this->belongsToMany(related: Achievement::class)
- ->withPivot(columns: 'progress')
- ->where('is_secret', false)
- ->using(AchievementUser::class);
- }
- }
-
- interface Multiplier
- {
- public function qualifies(array $data): bool;
-
- public function setMultiplier(): int;
- }
-
- enum AuditType: string
- {
- case Add = 'add';
- case Remove = 'remove';
- case Reset = 'reset';
- case LevelUp = 'level_up';
- }
-
- ?>"#
- .unindent();
-
- let documents = retriever.parse_file(&text, language.clone()).unwrap();
-
- assert_documents_eq(
- &documents,
- &[
- (
- r#"
- /*
- This is a multiple-lines comment block
- that spans over multiple
- lines
- */
- function functionName() {
- echo "Hello world!";
- }"#
- .unindent(),
- 123,
- ),
- (
- r#"
- trait HasAchievements
- {
- /**
- * @throws \Exception
- */
- public function grantAchievement(Achievement $achievement, $progress = null): void
- {/* ... */}
-
- public function achievements(): BelongsToMany
- {/* ... */}
- }"#
- .unindent(),
- 177,
- ),
- (r#"
- /**
- * @throws \Exception
- */
- public function grantAchievement(Achievement $achievement, $progress = null): void
- {
- if ($progress > 100) {
- throw new Exception(message: 'Progress cannot be greater than 100');
- }
-
- if ($this->achievements()->find($achievement->id)) {
- throw new Exception(message: 'User already has this Achievement');
- }
-
- $this->achievements()->attach($achievement, [
- 'progress' => $progress ?? null,
- ]);
-
- $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this)));
- }"#.unindent(), 245),
- (r#"
- public function achievements(): BelongsToMany
- {
- return $this->belongsToMany(related: Achievement::class)
- ->withPivot(columns: 'progress')
- ->where('is_secret', false)
- ->using(AchievementUser::class);
- }"#.unindent(), 902),
- (r#"
- interface Multiplier
- {
- public function qualifies(array $data): bool;
-
- public function setMultiplier(): int;
- }"#.unindent(),
- 1146),
- (r#"
- enum AuditType: string
- {
- case Add = 'add';
- case Remove = 'remove';
- case Reset = 'reset';
- case LevelUp = 'level_up';
- }"#.unindent(), 1265)
- ],
- );
-}
-
-fn js_lang() -> Arc<Language> {
- Arc::new(
- Language::new(
- LanguageConfig {
- name: "Javascript".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["js".into()],
- ..Default::default()
- },
- ..Default::default()
- },
- Some(tree_sitter_typescript::language_tsx()),
- )
- .with_embedding_query(
- &r#"
-
- (
- (comment)* @context
- .
- [
- (export_statement
- (function_declaration
- "async"? @name
- "function" @name
- name: (_) @name))
- (function_declaration
- "async"? @name
- "function" @name
- name: (_) @name)
- ] @item
- )
-
- (
- (comment)* @context
- .
- [
- (export_statement
- (class_declaration
- "class" @name
- name: (_) @name))
- (class_declaration
- "class" @name
- name: (_) @name)
- ] @item
- )
-
- (
- (comment)* @context
- .
- [
- (export_statement
- (interface_declaration
- "interface" @name
- name: (_) @name))
- (interface_declaration
- "interface" @name
- name: (_) @name)
- ] @item
- )
-
- (
- (comment)* @context
- .
- [
- (export_statement
- (enum_declaration
- "enum" @name
- name: (_) @name))
- (enum_declaration
- "enum" @name
- name: (_) @name)
- ] @item
- )
-
- (
- (comment)* @context
- .
- (method_definition
- [
- "get"
- "set"
- "async"
- "*"
- "static"
- ]* @name
- name: (_) @name) @item
- )
-
- "#
- .unindent(),
- )
- .unwrap(),
- )
-}
-
-fn rust_lang() -> Arc<Language> {
- Arc::new(
- Language::new(
- LanguageConfig {
- name: "Rust".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["rs".into()],
- ..Default::default()
- },
- collapsed_placeholder: " /* ... */ ".to_string(),
- ..Default::default()
- },
- Some(tree_sitter_rust::language()),
- )
- .with_embedding_query(
- r#"
- (
- [(line_comment) (attribute_item)]* @context
- .
- [
- (struct_item
- name: (_) @name)
-
- (enum_item
- name: (_) @name)
-
- (impl_item
- trait: (_)? @name
- "for"? @name
- type: (_) @name)
-
- (trait_item
- name: (_) @name)
-
- (function_item
- name: (_) @name
- body: (block
- "{" @keep
- "}" @keep) @collapse)
-
- (macro_definition
- name: (_) @name)
- ] @item
- )
-
- (attribute_item) @collapse
- (use_declaration) @collapse
- "#,
- )
- .unwrap(),
- )
-}
-
-fn json_lang() -> Arc<Language> {
- Arc::new(
- Language::new(
- LanguageConfig {
- name: "JSON".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["json".into()],
- ..Default::default()
- },
- ..Default::default()
- },
- Some(tree_sitter_json::language()),
- )
- .with_embedding_query(
- r#"
- (document) @item
-
- (array
- "[" @keep
- .
- (object)? @keep
- "]" @keep) @collapse
-
- (pair value: (string
- "\"" @keep
- "\"" @keep) @collapse)
- "#,
- )
- .unwrap(),
- )
-}
-
-fn toml_lang() -> Arc<Language> {
- Arc::new(Language::new(
- LanguageConfig {
- name: "TOML".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["toml".into()],
- ..Default::default()
- },
- ..Default::default()
- },
- Some(tree_sitter_toml::language()),
- ))
-}
-
-fn cpp_lang() -> Arc<Language> {
- Arc::new(
- Language::new(
- LanguageConfig {
- name: "CPP".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["cpp".into()],
- ..Default::default()
- },
- ..Default::default()
- },
- Some(tree_sitter_cpp::language()),
- )
- .with_embedding_query(
- r#"
- (
- (comment)* @context
- .
- (function_definition
- (type_qualifier)? @name
- type: (_)? @name
- declarator: [
- (function_declarator
- declarator: (_) @name)
- (pointer_declarator
- "*" @name
- declarator: (function_declarator
- declarator: (_) @name))
- (pointer_declarator
- "*" @name
- declarator: (pointer_declarator
- "*" @name
- declarator: (function_declarator
- declarator: (_) @name)))
- (reference_declarator
- ["&" "&&"] @name
- (function_declarator
- declarator: (_) @name))
- ]
- (type_qualifier)? @name) @item
- )
-
- (
- (comment)* @context
- .
- (template_declaration
- (class_specifier
- "class" @name
- name: (_) @name)
- ) @item
- )
-
- (
- (comment)* @context
- .
- (class_specifier
- "class" @name
- name: (_) @name) @item
- )
-
- (
- (comment)* @context
- .
- (enum_specifier
- "enum" @name
- name: (_) @name) @item
- )
-
- (
- (comment)* @context
- .
- (declaration
- type: (struct_specifier
- "struct" @name)
- declarator: (_) @name) @item
- )
-
- "#,
- )
- .unwrap(),
- )
-}
-
-fn lua_lang() -> Arc<Language> {
- Arc::new(
- Language::new(
- LanguageConfig {
- name: "Lua".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["lua".into()],
- ..Default::default()
- },
- collapsed_placeholder: "--[ ... ]--".to_string(),
- ..Default::default()
- },
- Some(tree_sitter_lua::language()),
- )
- .with_embedding_query(
- r#"
- (
- (comment)* @context
- .
- (function_declaration
- "function" @name
- name: (_) @name
- (comment)* @collapse
- body: (block) @collapse
- ) @item
- )
- "#,
- )
- .unwrap(),
- )
-}
-
-fn php_lang() -> Arc<Language> {
- Arc::new(
- Language::new(
- LanguageConfig {
- name: "PHP".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["php".into()],
- ..Default::default()
- },
- collapsed_placeholder: "/* ... */".into(),
- ..Default::default()
- },
- Some(tree_sitter_php::language_php()),
- )
- .with_embedding_query(
- r#"
- (
- (comment)* @context
- .
- [
- (function_definition
- "function" @name
- name: (_) @name
- body: (_
- "{" @keep
- "}" @keep) @collapse
- )
-
- (trait_declaration
- "trait" @name
- name: (_) @name)
-
- (method_declaration
- "function" @name
- name: (_) @name
- body: (_
- "{" @keep
- "}" @keep) @collapse
- )
-
- (interface_declaration
- "interface" @name
- name: (_) @name
- )
-
- (enum_declaration
- "enum" @name
- name: (_) @name
- )
-
- ] @item
- )
- "#,
- )
- .unwrap(),
- )
-}
-
-fn ruby_lang() -> Arc<Language> {
- Arc::new(
- Language::new(
- LanguageConfig {
- name: "Ruby".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["rb".into()],
- ..Default::default()
- },
- collapsed_placeholder: "# ...".to_string(),
- ..Default::default()
- },
- Some(tree_sitter_ruby::language()),
- )
- .with_embedding_query(
- r#"
- (
- (comment)* @context
- .
- [
- (module
- "module" @name
- name: (_) @name)
- (method
- "def" @name
- name: (_) @name
- body: (body_statement) @collapse)
- (class
- "class" @name
- name: (_) @name)
- (singleton_method
- "def" @name
- object: (_) @name
- "." @name
- name: (_) @name
- body: (body_statement) @collapse)
- ] @item
- )
- "#,
- )
- .unwrap(),
- )
-}
-
-fn elixir_lang() -> Arc<Language> {
- Arc::new(
- Language::new(
- LanguageConfig {
- name: "Elixir".into(),
- matcher: LanguageMatcher {
- path_suffixes: vec!["rs".into()],
- ..Default::default()
- },
- ..Default::default()
- },
- Some(tree_sitter_elixir::language()),
- )
- .with_embedding_query(
- r#"
- (
- (unary_operator
- operator: "@"
- operand: (call
- target: (identifier) @unary
- (#match? @unary "^(doc)$"))
- ) @context
- .
- (call
- target: (identifier) @name
- (arguments
- [
- (identifier) @name
- (call
- target: (identifier) @name)
- (binary_operator
- left: (call
- target: (identifier) @name)
- operator: "when")
- ])
- (#any-match? @name "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp)$")) @item
- )
-
- (call
- target: (identifier) @name
- (arguments (alias) @name)
- (#any-match? @name "^(defmodule|defprotocol)$")) @item
- "#,
- )
- .unwrap(),
- )
-}
-
-#[gpui::test]
-fn test_subtract_ranges() {
- assert_eq!(
- subtract_ranges(&[0..5, 10..21], &[0..1, 4..5]),
- vec![1..4, 10..21]
- );
-
- assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]);
-}
-
-fn init_test(cx: &mut TestAppContext) {
- cx.update(|cx| {
- let settings_store = SettingsStore::test(cx);
- cx.set_global(settings_store);
- SemanticIndexSettings::register(cx);
- language::init(cx);
- Project::init_settings(cx);
- });
-}
@@ -479,7 +479,28 @@ impl SettingsStore {
merge_schema(target_schema, setting_schema.schema);
}
- fn merge_schema(target: &mut SchemaObject, source: SchemaObject) {
+ fn merge_schema(target: &mut SchemaObject, mut source: SchemaObject) {
+ let source_subschemas = source.subschemas();
+ let target_subschemas = target.subschemas();
+ if let Some(all_of) = source_subschemas.all_of.take() {
+ target_subschemas
+ .all_of
+ .get_or_insert(Vec::new())
+ .extend(all_of);
+ }
+ if let Some(any_of) = source_subschemas.any_of.take() {
+ target_subschemas
+ .any_of
+ .get_or_insert(Vec::new())
+ .extend(any_of);
+ }
+ if let Some(one_of) = source_subschemas.one_of.take() {
+ target_subschemas
+ .one_of
+ .get_or_insert(Vec::new())
+ .extend(one_of);
+ }
+
if let Some(source) = source.object {
let target_properties = &mut target.object().properties;
for (key, value) in source.properties {
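
The new arm prepended to `merge_schema` concatenates `allOf`/`anyOf`/`oneOf` subschema lists instead of letting the source schema overwrite the target's. A minimal sketch of that merge rule, with a plain struct standing in for schemars' `SchemaObject` (hypothetical stand-in types, for illustration only):

```rust
// Minimal sketch of the subschema merge rule above, using Option<Vec<String>>
// in place of schemars' Option<Vec<Schema>> (hypothetical stand-in types).
#[derive(Default, Debug)]
struct Subschemas {
    all_of: Option<Vec<String>>,
    any_of: Option<Vec<String>>,
    one_of: Option<Vec<String>>,
}

fn merge_subschemas(target: &mut Subschemas, mut source: Subschemas) {
    // Append the source's variants instead of overwriting the target's.
    if let Some(all_of) = source.all_of.take() {
        target.all_of.get_or_insert(Vec::new()).extend(all_of);
    }
    if let Some(any_of) = source.any_of.take() {
        target.any_of.get_or_insert(Vec::new()).extend(any_of);
    }
    if let Some(one_of) = source.one_of.take() {
        target.one_of.get_or_insert(Vec::new()).extend(one_of);
    }
}

fn main() {
    let mut target = Subschemas {
        any_of: Some(vec!["existing variant".to_string()]),
        ..Default::default()
    };
    let source = Subschemas {
        any_of: Some(vec!["new variant".to_string()]),
        ..Default::default()
    };
    merge_subschemas(&mut target, source);
    // Both variants survive the merge rather than the last writer's winning.
    assert_eq!(
        target.any_of,
        Some(vec!["existing variant".to_string(), "new variant".to_string()])
    );
}
```
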
@@ -5,9 +5,8 @@ use futures_lite::FutureExt;
use isahc::config::{Configurable, RedirectPolicy};
pub use isahc::{
http::{Method, StatusCode, Uri},
- Error,
+ AsyncBody, Error, HttpClient as IsahcHttpClient, Request, Response,
};
-pub use isahc::{AsyncBody, Request, Response};
#[cfg(feature = "test-support")]
use std::fmt;
use std::{
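
Folding the second `pub use isahc::…` item into the first keeps every re-exported HTTP type behind a single import. A hedged sketch of what a caller can build from those names (a standalone illustration, not code from this change):

```rust
// Illustrative only: construct a request from the types re-exported above.
use isahc::{http::Method, AsyncBody, Request};

fn build_request(uri: &str) -> Request<AsyncBody> {
    Request::builder()
        .method(Method::GET)
        .uri(uri)
        // An empty async body; a real caller would attach JSON or form data here.
        .body(AsyncBody::empty())
        .expect("a well-formed request")
}

fn main() {
    let request = build_request("https://example.com");
    println!("{} {}", request.method(), request.uri());
}
```
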
@@ -71,7 +71,6 @@ recent_projects.workspace = true
release_channel.workspace = true
rope.workspace = true
search.workspace = true
-semantic_index.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
@@ -174,7 +174,7 @@ fn main() {
node_runtime.clone(),
cx,
);
- assistant::init(cx);
+ assistant::init(client.clone(), cx);
extension::init(
fs.clone(),
@@ -247,7 +247,6 @@ fn main() {
tasks_ui::init(cx);
channel::init(&client, user_store.clone(), cx);
search::init(cx);
- semantic_index::init(fs.clone(), http.clone(), languages.clone(), cx);
vim::init(cx);
terminal_view::init(cx);
@@ -3060,7 +3060,7 @@ mod tests {
collab_ui::init(&app_state, cx);
project_panel::init((), cx);
terminal_view::init(cx);
- assistant::init(cx);
+ assistant::init(app_state.client.clone(), cx);
initialize_workspace(app_state.clone(), cx);
app_state
})
@@ -606,28 +606,6 @@ These values take in the same options as the root-level settings with the same n
`boolean` values
-## Semantic Index
-
-- Description: Settings related to semantic index.
-- Setting: `semantic_index`
-- Default:
-
-```json
-"semantic_index": {
- "enabled": false
-},
-```
-
-### Enabled
-
-- Description: Whether or not to display the `Semantic` mode in project search.
-- Setting: `enabled`
-- Default: `true`
-
-**Options**
-
-`boolean` values
-
## Show Call Status Icon
- Description: Whether or not to show the call status icon in the status bar.
@@ -11,3 +11,8 @@ cargo run -p collab -- migrate
echo "seeding database..."
script/seed-db
+
+if [[ "$OSTYPE" == "linux-gnu"* ]]; then
+ echo "Linux dependencies..."
+ script/linux
+fi
@@ -1,3 +0,0 @@
-#!/bin/bash
-
-RUST_LOG=semantic_index=trace cargo run --example semantic_index_eval --release
@@ -0,0 +1,91 @@
+import subprocess
+import json
+import http.client
+import os
+
+def get_text_files():
+ text_files = []
+ # List all files tracked by Git
+ git_files_proc = subprocess.run(['git', 'ls-files'], stdout=subprocess.PIPE, text=True)
+ for file in git_files_proc.stdout.strip().split('\n'):
+ # Check MIME type for each file
+ mime_check_proc = subprocess.run(['file', '--mime', file], stdout=subprocess.PIPE, text=True)
+ if 'text' in mime_check_proc.stdout:
+ text_files.append(file)
+
+ print(f"File count: {len(text_files)}")
+
+ return text_files
+
+def get_file_contents(file):
+ # Read file content
+ with open(file, 'r') as f:
+ return f.read()
+
+
+def main():
+ GEMINI_API_KEY = os.environ.get('GEMINI_API_KEY')
+
+ # Your prompt
+ prompt = "Document the data types and dataflow in this codebase in preparation to port a streaming implementation to rust:\n\n"
+ # Fetch all text files
+ text_files = get_text_files()
+ code_blocks = []
+ for file in text_files:
+ file_contents = get_file_contents(file)
+ # Create a code block for each text file
+ code_blocks.append(f"\n`{file}`\n\n```{file_contents}```\n")
+
+ # Construct the JSON payload
+ payload = json.dumps({
+ "contents": [{
+ "parts": [{
+ "text": prompt + "".join(code_blocks)
+ }]
+ }]
+ })
+
+ # Prepare the HTTP connection
+ conn = http.client.HTTPSConnection("generativelanguage.googleapis.com")
+
+ # Define headers
+ headers = {
+ 'Content-Type': 'application/json',
+ 'Content-Length': str(len(payload))
+ }
+
+    # Output the payload size in kilobytes
+    print(f"Content length in kilobytes: {len(payload.encode('utf-8')) / 1024:.2f} KB")
+
+
+ # Send a request to count the tokens
+ conn.request("POST", f"/v1beta/models/gemini-1.5-pro-latest:countTokens?key={GEMINI_API_KEY}", body=payload, headers=headers)
+ # Get the response
+ response = conn.getresponse()
+ if response.status == 200:
+ token_count = json.loads(response.read().decode('utf-8')).get('totalTokens')
+ print(f"Token count: {token_count}")
+ else:
+ print(f"Failed to get token count. Status code: {response.status}, Response body: {response.read().decode('utf-8')}")
+
+
+ # Prepare the HTTP connection
+ conn = http.client.HTTPSConnection("generativelanguage.googleapis.com")
+    conn.request("POST", f"/v1beta/models/gemini-1.5-pro-latest:streamGenerateContent?key={GEMINI_API_KEY}", body=payload, headers=headers)
+
+ # Get the response in a streaming manner
+ response = conn.getresponse()
+ if response.status == 200:
+ print("Successfully sent the data to the API.")
+ # Read the response in chunks
+ while chunk := response.read(4096):
+ print(chunk.decode('utf-8'))
+ else:
+ print(f"Failed to send the data to the API. Status code: {response.status}, Response body: {response.read().decode('utf-8')}")
+
+ # Close the connection
+ conn.close()
+
+if __name__ == "__main__":
+ main()
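
The prompt in this script literally asks Gemini to scope porting the streaming path to Rust, which is where the new `google_ai` crate comes in. As a hedged sketch of that receiving end: the `streamGenerateContent` endpoint streams back a JSON array of chunks, each nesting its text under `candidates → content → parts`. The types below mirror that wire shape and are illustrative, not the `google_ai` crate's actual API:

```rust
// Hedged sketch of deserializing one streamed Gemini chunk. The struct names
// mirror the wire format (candidates -> content -> parts -> text); they are
// illustrative, not the google_ai crate's types.
use serde::Deserialize;

#[derive(Deserialize)]
struct GenerateContentResponse {
    candidates: Vec<Candidate>,
}

#[derive(Deserialize)]
struct Candidate {
    content: Content,
}

#[derive(Deserialize)]
struct Content {
    parts: Vec<Part>,
}

#[derive(Deserialize)]
struct Part {
    text: String,
}

fn main() -> serde_json::Result<()> {
    // One element of the JSON array that streamGenerateContent streams back.
    let chunk = r#"{"candidates":[{"content":{"parts":[{"text":"Hello"}],"role":"model"}}]}"#;
    let response: GenerateContentResponse = serde_json::from_str(chunk)?;
    for part in &response.candidates[0].content.parts {
        print!("{}", part.text);
    }
    Ok(())
}
```

Note that each `chunk` the script prints is only a fragment of that array, so a real client would need to frame the stream (or request a server-sent-events encoding) before deserializing individual chunks.
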
@@ -1,4 +1,6 @@
-#!/usr/bin/bash -e
+#!/usr/bin/bash
+
+set -e
# if sudo is not installed, define an empty alias
maysudo=$(command -v sudo || command -v doas || true)
@@ -0,0 +1 @@
+
@@ -3,12 +3,15 @@
set -e
# Install sqlx-cli if needed
-[[ "$(sqlx --version)" == "sqlx-cli 0.5.7" ]] || cargo install sqlx-cli --version 0.5.7
+if [[ "$(sqlx --version)" != "sqlx-cli 0.5.7" ]]; then
+ echo "sqlx-cli not found or not the required version, installing version 0.5.7..."
+ cargo install sqlx-cli --version 0.5.7
+fi
cd crates/collab
# Export contents of .env.toml
-eval "$(cargo run --quiet --bin dotenv)"
+eval "$(cargo run --bin dotenv)"
# Run sqlx command
sqlx $@