Detailed changes
@@ -77,33 +77,6 @@ dependencies = [
[[package]]
name = "ai"
version = "0.1.0"
-dependencies = [
- "anyhow",
- "async-trait",
- "bincode",
- "futures 0.3.28",
- "gpui",
- "isahc",
- "language",
- "lazy_static",
- "log",
- "matrixmultiply",
- "ordered-float 2.10.0",
- "parking_lot 0.11.2",
- "parse_duration",
- "postage",
- "rand 0.8.5",
- "regex",
- "rusqlite",
- "serde",
- "serde_json",
- "tiktoken-rs",
- "util",
-]
-
-[[package]]
-name = "ai2"
-version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
@@ -329,7 +302,7 @@ checksum = "d92bec98840b8f03a5ff5413de5293bfcd8bf96467cf5452609f939ec6f5de16"
name = "assistant"
version = "0.1.0"
dependencies = [
- "ai2",
+ "ai",
"anyhow",
"chrono",
"client2",
@@ -1301,7 +1274,7 @@ dependencies = [
"clock",
"collections",
"db2",
- "feature_flags2",
+ "feature_flags",
"futures 0.3.28",
"gpui2",
"image",
@@ -1512,7 +1485,7 @@ dependencies = [
"chrono",
"collections",
"db2",
- "feature_flags2",
+ "feature_flags",
"futures 0.3.28",
"gpui2",
"image",
@@ -1743,7 +1716,7 @@ dependencies = [
"collections",
"db2",
"editor",
- "feature_flags2",
+ "feature_flags",
"feedback",
"futures 0.3.28",
"fuzzy2",
@@ -1758,7 +1731,7 @@ dependencies = [
"pretty_assertions",
"project2",
"recent_projects",
- "rich_text2",
+ "rich_text",
"rpc2",
"schemars",
"serde",
@@ -1773,7 +1746,7 @@ dependencies = [
"util",
"vcs_menu",
"workspace",
- "zed_actions2",
+ "zed_actions",
]
[[package]]
@@ -1828,7 +1801,7 @@ dependencies = [
"ui2",
"util",
"workspace",
- "zed_actions2",
+ "zed_actions",
]
[[package]]
@@ -1947,7 +1920,7 @@ dependencies = [
"theme2",
"util",
"workspace",
- "zed_actions2",
+ "zed_actions",
]
[[package]]
@@ -2645,7 +2618,7 @@ dependencies = [
"postage",
"project2",
"rand 0.8.5",
- "rich_text2",
+ "rich_text",
"rpc2",
"schemars",
"serde",
@@ -2839,14 +2812,6 @@ checksum = "6999dc1837253364c2ebb0704ba97994bd874e8f195d665c50b7548f6ea92764"
[[package]]
name = "feature_flags"
version = "0.1.0"
-dependencies = [
- "anyhow",
- "gpui",
-]
-
-[[package]]
-name = "feature_flags2"
-version = "0.1.0"
dependencies = [
"anyhow",
"gpui2",
@@ -3991,22 +3956,10 @@ dependencies = [
[[package]]
name = "install_cli"
version = "0.1.0"
-dependencies = [
- "anyhow",
- "gpui",
- "log",
- "smol",
- "util",
-]
-
-[[package]]
-name = "install_cli2"
-version = "0.1.0"
dependencies = [
"anyhow",
"gpui2",
"log",
- "serde",
"smol",
"util",
]
@@ -5011,7 +4964,7 @@ dependencies = [
"project2",
"pulldown-cmark",
"rand 0.8.5",
- "rich_text2",
+ "rich_text",
"schemars",
"serde",
"serde_derive",
@@ -5209,7 +5162,7 @@ dependencies = [
"clock",
"collections",
"db2",
- "feature_flags2",
+ "feature_flags",
"gpui2",
"rpc2",
"settings2",
@@ -6827,24 +6780,6 @@ dependencies = [
[[package]]
name = "rich_text"
version = "0.1.0"
-dependencies = [
- "anyhow",
- "collections",
- "futures 0.3.28",
- "gpui",
- "language",
- "lazy_static",
- "pulldown-cmark",
- "smallvec",
- "smol",
- "sum_tree",
- "theme",
- "util",
-]
-
-[[package]]
-name = "rich_text2"
-version = "0.1.0"
dependencies = [
"anyhow",
"collections",
@@ -7576,7 +7511,7 @@ dependencies = [
name = "semantic_index2"
version = "0.1.0"
dependencies = [
- "ai2",
+ "ai",
"anyhow",
"async-trait",
"client2",
@@ -7778,7 +7713,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"collections",
- "feature_flags2",
+ "feature_flags",
"fs2",
"futures 0.3.28",
"gpui2",
@@ -8909,7 +8844,7 @@ version = "0.1.0"
dependencies = [
"client2",
"editor",
- "feature_flags2",
+ "feature_flags",
"fs2",
"fuzzy2",
"gpui2",
@@ -10026,7 +9961,7 @@ dependencies = [
"ui2",
"util",
"workspace",
- "zed_actions2",
+ "zed_actions",
]
[[package]]
@@ -10427,7 +10362,7 @@ dependencies = [
"fs2",
"fuzzy2",
"gpui2",
- "install_cli2",
+ "install_cli",
"log",
"picker",
"project2",
@@ -10697,7 +10632,7 @@ dependencies = [
"futures 0.3.28",
"gpui2",
"indoc",
- "install_cli2",
+ "install_cli",
"itertools 0.10.5",
"language2",
"lazy_static",
@@ -10802,7 +10737,7 @@ name = "zed"
version = "0.119.0"
dependencies = [
"activity_indicator",
- "ai2",
+ "ai",
"anyhow",
"assistant",
"async-compression",
@@ -10828,7 +10763,7 @@ dependencies = [
"diagnostics",
"editor",
"env_logger",
- "feature_flags2",
+ "feature_flags",
"feedback",
"file_finder",
"fs2",
@@ -10839,7 +10774,7 @@ dependencies = [
"ignore",
"image",
"indexmap 1.9.3",
- "install_cli2",
+ "install_cli",
"isahc",
"journal",
"language2",
@@ -10924,19 +10859,11 @@ dependencies = [
"vim",
"welcome",
"workspace",
- "zed_actions2",
-]
-
-[[package]]
-name = "zed-actions"
-version = "0.1.0"
-dependencies = [
- "gpui",
- "serde",
+ "zed_actions",
]
[[package]]
-name = "zed_actions2"
+name = "zed_actions"
version = "0.1.0"
dependencies = [
"gpui2",
@@ -32,7 +32,6 @@ members = [
"crates/drag_and_drop",
"crates/editor",
"crates/feature_flags",
- "crates/feature_flags2",
"crates/feedback",
"crates/file_finder",
"crates/fs",
@@ -47,7 +46,6 @@ members = [
"crates/gpui2",
"crates/gpui2_macros",
"crates/install_cli",
- "crates/install_cli2",
"crates/journal",
"crates/journal",
"crates/language",
@@ -108,8 +106,7 @@ members = [
"crates/welcome",
"crates/xtask",
"crates/zed",
- "crates/zed-actions",
- "crates/zed_actions2"
+ "crates/zed_actions",
]
default-members = ["crates/zed"]
resolver = "2"
@@ -12,9 +12,9 @@ doctest = false
test-support = []
[dependencies]
-gpui = { path = "../gpui" }
+gpui = { package = "gpui2", path = "../gpui2" }
util = { path = "../util" }
-language = { path = "../language" }
+language = { package = "language2", path = "../language2" }
async-trait.workspace = true
anyhow.workspace = true
futures.workspace = true
@@ -35,4 +35,4 @@ rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
bincode = "1.3.3"
[dev-dependencies]
-gpui = { path = "../gpui", features = ["test-support"] }
+gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
@@ -9,7 +9,7 @@ pub enum ProviderCredential {
pub trait CredentialProvider: Send + Sync {
fn has_credentials(&self) -> bool;
- fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential;
- fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential);
- fn delete_credentials(&self, cx: &AppContext);
+ fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential;
+ fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential);
+ fn delete_credentials(&self, cx: &mut AppContext);
}
@@ -2,7 +2,7 @@ use crate::prompts::base::{PromptArguments, PromptTemplate};
use std::fmt::Write;
use std::{ops::Range, path::PathBuf};
-use gpui::{AsyncAppContext, ModelHandle};
+use gpui::{AsyncAppContext, Model};
use language::{Anchor, Buffer};
#[derive(Clone)]
@@ -13,8 +13,12 @@ pub struct PromptCodeSnippet {
}
impl PromptCodeSnippet {
- pub fn new(buffer: ModelHandle<Buffer>, range: Range<Anchor>, cx: &AsyncAppContext) -> Self {
- let (content, language_name, file_path) = buffer.read_with(cx, |buffer, _| {
+ pub fn new(
+ buffer: Model<Buffer>,
+ range: Range<Anchor>,
+ cx: &mut AsyncAppContext,
+ ) -> anyhow::Result<Self> {
+ let (content, language_name, file_path) = buffer.update(cx, |buffer, _| {
let snapshot = buffer.snapshot();
let content = snapshot.text_for_range(range.clone()).collect::<String>();
@@ -27,13 +31,13 @@ impl PromptCodeSnippet {
.and_then(|file| Some(file.path().to_path_buf()));
(content, language_name, file_path)
- });
+ })?;
- PromptCodeSnippet {
+ anyhow::Ok(PromptCodeSnippet {
path: file_path,
language_name,
content,
- }
+ })
}
}
@@ -3,7 +3,7 @@ use futures::{
future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
Stream, StreamExt,
};
-use gpui::{executor::Background, AppContext};
+use gpui::{AppContext, BackgroundExecutor};
use isahc::{http::StatusCode, Request, RequestExt};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
@@ -104,7 +104,7 @@ pub struct OpenAIResponseStreamEvent {
pub async fn stream_completion(
credential: ProviderCredential,
- executor: Arc<Background>,
+ executor: BackgroundExecutor,
request: Box<dyn CompletionRequest>,
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
let api_key = match credential {
@@ -197,11 +197,11 @@ pub async fn stream_completion(
pub struct OpenAICompletionProvider {
model: OpenAILanguageModel,
credential: Arc<RwLock<ProviderCredential>>,
- executor: Arc<Background>,
+ executor: BackgroundExecutor,
}
impl OpenAICompletionProvider {
- pub fn new(model_name: &str, executor: Arc<Background>) -> Self {
+ pub fn new(model_name: &str, executor: BackgroundExecutor) -> Self {
let model = OpenAILanguageModel::load(model_name);
let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
Self {
@@ -219,46 +219,45 @@ impl CredentialProvider for OpenAICompletionProvider {
_ => false,
}
}
- fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
- let mut credential = self.credential.write();
- match *credential {
- ProviderCredential::Credentials { .. } => {
- return credential.clone();
- }
+
+ fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential {
+ let existing_credential = self.credential.read().clone();
+ let retrieved_credential = match existing_credential {
+ ProviderCredential::Credentials { .. } => existing_credential.clone(),
_ => {
- if let Ok(api_key) = env::var("OPENAI_API_KEY") {
- *credential = ProviderCredential::Credentials { api_key };
- } else if let Some((_, api_key)) = cx
- .platform()
- .read_credentials(OPENAI_API_URL)
- .log_err()
- .flatten()
+ if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
+ ProviderCredential::Credentials { api_key }
+ } else if let Some(Some((_, api_key))) =
+ cx.read_credentials(OPENAI_API_URL).log_err()
{
if let Some(api_key) = String::from_utf8(api_key).log_err() {
- *credential = ProviderCredential::Credentials { api_key };
+ ProviderCredential::Credentials { api_key }
+ } else {
+ ProviderCredential::NoCredentials
}
} else {
- };
+ ProviderCredential::NoCredentials
+ }
}
- }
-
- credential.clone()
+ };
+ *self.credential.write() = retrieved_credential.clone();
+ retrieved_credential
}
- fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
- match credential.clone() {
+ fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) {
+ *self.credential.write() = credential.clone();
+ let credential = credential.clone();
+ match credential {
ProviderCredential::Credentials { api_key } => {
- cx.platform()
- .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
+ cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
.log_err();
}
_ => {}
}
-
- *self.credential.write() = credential;
}
- fn delete_credentials(&self, cx: &AppContext) {
- cx.platform().delete_credentials(OPENAI_API_URL).log_err();
+
+ fn delete_credentials(&self, cx: &mut AppContext) {
+ cx.delete_credentials(OPENAI_API_URL).log_err();
*self.credential.write() = ProviderCredential::NoCredentials;
}
}
@@ -1,7 +1,7 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::AsyncReadExt;
-use gpui::executor::Background;
+use gpui::BackgroundExecutor;
use gpui::{serde_json, AppContext};
use isahc::http::StatusCode;
use isahc::prelude::Configurable;
@@ -35,7 +35,7 @@ pub struct OpenAIEmbeddingProvider {
model: OpenAILanguageModel,
credential: Arc<RwLock<ProviderCredential>>,
pub client: Arc<dyn HttpClient>,
- pub executor: Arc<Background>,
+ pub executor: BackgroundExecutor,
rate_limit_count_rx: watch::Receiver<Option<Instant>>,
rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}
@@ -66,7 +66,7 @@ struct OpenAIEmbeddingUsage {
}
impl OpenAIEmbeddingProvider {
- pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
+ pub fn new(client: Arc<dyn HttpClient>, executor: BackgroundExecutor) -> Self {
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
@@ -153,46 +153,45 @@ impl CredentialProvider for OpenAIEmbeddingProvider {
_ => false,
}
}
- fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
- let mut credential = self.credential.write();
- match *credential {
- ProviderCredential::Credentials { .. } => {
- return credential.clone();
- }
+ fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential {
+ let existing_credential = self.credential.read().clone();
+
+ let retrieved_credential = match existing_credential {
+ ProviderCredential::Credentials { .. } => existing_credential.clone(),
_ => {
- if let Ok(api_key) = env::var("OPENAI_API_KEY") {
- *credential = ProviderCredential::Credentials { api_key };
- } else if let Some((_, api_key)) = cx
- .platform()
- .read_credentials(OPENAI_API_URL)
- .log_err()
- .flatten()
+ if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
+ ProviderCredential::Credentials { api_key }
+ } else if let Some(Some((_, api_key))) =
+ cx.read_credentials(OPENAI_API_URL).log_err()
{
if let Some(api_key) = String::from_utf8(api_key).log_err() {
- *credential = ProviderCredential::Credentials { api_key };
+ ProviderCredential::Credentials { api_key }
+ } else {
+ ProviderCredential::NoCredentials
}
} else {
- };
+ ProviderCredential::NoCredentials
+ }
}
- }
+ };
- credential.clone()
+ *self.credential.write() = retrieved_credential.clone();
+ retrieved_credential
}
- fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
- match credential.clone() {
+ fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) {
+ *self.credential.write() = credential.clone();
+ match credential {
ProviderCredential::Credentials { api_key } => {
- cx.platform()
- .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
+ cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
.log_err();
}
_ => {}
}
-
- *self.credential.write() = credential;
}
- fn delete_credentials(&self, cx: &AppContext) {
- cx.platform().delete_credentials(OPENAI_API_URL).log_err();
+
+ fn delete_credentials(&self, cx: &mut AppContext) {
+ cx.delete_credentials(OPENAI_API_URL).log_err();
*self.credential.write() = ProviderCredential::NoCredentials;
}
}
@@ -104,11 +104,11 @@ impl CredentialProvider for FakeEmbeddingProvider {
fn has_credentials(&self) -> bool {
true
}
- fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
+ fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
ProviderCredential::NotNeeded
}
- fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {}
- fn delete_credentials(&self, _cx: &AppContext) {}
+ fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
+ fn delete_credentials(&self, _cx: &mut AppContext) {}
}
#[async_trait]
@@ -153,17 +153,10 @@ impl FakeCompletionProvider {
pub fn send_completion(&self, completion: impl Into<String>) {
let mut tx = self.last_completion_tx.lock();
-
- println!("COMPLETION TX: {:?}", &tx);
-
- let a = tx.as_mut().unwrap();
- a.try_send(completion.into()).unwrap();
-
- // tx.as_mut().unwrap().try_send(completion.into()).unwrap();
+ tx.as_mut().unwrap().try_send(completion.into()).unwrap();
}
pub fn finish_completion(&self) {
- println!("FINISHING COMPLETION");
self.last_completion_tx.lock().take().unwrap();
}
}
@@ -172,11 +165,11 @@ impl CredentialProvider for FakeCompletionProvider {
fn has_credentials(&self) -> bool {
true
}
- fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
+ fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
ProviderCredential::NotNeeded
}
- fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {}
- fn delete_credentials(&self, _cx: &AppContext) {}
+ fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
+ fn delete_credentials(&self, _cx: &mut AppContext) {}
}
impl CompletionProvider for FakeCompletionProvider {
@@ -188,10 +181,8 @@ impl CompletionProvider for FakeCompletionProvider {
&self,
_prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
- println!("COMPLETING");
let (tx, rx) = mpsc::channel(1);
*self.last_completion_tx.lock() = Some(tx);
- println!("TX: {:?}", *self.last_completion_tx.lock());
async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
}
fn box_clone(&self) -> Box<dyn CompletionProvider> {
@@ -1,38 +0,0 @@
-[package]
-name = "ai2"
-version = "0.1.0"
-edition = "2021"
-publish = false
-
-[lib]
-path = "src/ai2.rs"
-doctest = false
-
-[features]
-test-support = []
-
-[dependencies]
-gpui = { package = "gpui2", path = "../gpui2" }
-util = { path = "../util" }
-language = { package = "language2", path = "../language2" }
-async-trait.workspace = true
-anyhow.workspace = true
-futures.workspace = true
-lazy_static.workspace = true
-ordered-float.workspace = true
-parking_lot.workspace = true
-isahc.workspace = true
-regex.workspace = true
-serde.workspace = true
-serde_json.workspace = true
-postage.workspace = true
-rand.workspace = true
-log.workspace = true
-parse_duration = "2.1.1"
-tiktoken-rs.workspace = true
-matrixmultiply = "0.3.7"
-rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
-bincode = "1.3.3"
-
-[dev-dependencies]
-gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
@@ -1,8 +0,0 @@
-pub mod auth;
-pub mod completion;
-pub mod embedding;
-pub mod models;
-pub mod prompts;
-pub mod providers;
-#[cfg(any(test, feature = "test-support"))]
-pub mod test;
@@ -1,15 +0,0 @@
-use gpui::AppContext;
-
-#[derive(Clone, Debug)]
-pub enum ProviderCredential {
- Credentials { api_key: String },
- NoCredentials,
- NotNeeded,
-}
-
-pub trait CredentialProvider: Send + Sync {
- fn has_credentials(&self) -> bool;
- fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential;
- fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential);
- fn delete_credentials(&self, cx: &mut AppContext);
-}
@@ -1,23 +0,0 @@
-use anyhow::Result;
-use futures::{future::BoxFuture, stream::BoxStream};
-
-use crate::{auth::CredentialProvider, models::LanguageModel};
-
-pub trait CompletionRequest: Send + Sync {
- fn data(&self) -> serde_json::Result<String>;
-}
-
-pub trait CompletionProvider: CredentialProvider {
- fn base_model(&self) -> Box<dyn LanguageModel>;
- fn complete(
- &self,
- prompt: Box<dyn CompletionRequest>,
- ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
- fn box_clone(&self) -> Box<dyn CompletionProvider>;
-}
-
-impl Clone for Box<dyn CompletionProvider> {
- fn clone(&self) -> Box<dyn CompletionProvider> {
- self.box_clone()
- }
-}
@@ -1,123 +0,0 @@
-use std::time::Instant;
-
-use anyhow::Result;
-use async_trait::async_trait;
-use ordered_float::OrderedFloat;
-use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
-use rusqlite::ToSql;
-
-use crate::auth::CredentialProvider;
-use crate::models::LanguageModel;
-
-#[derive(Debug, PartialEq, Clone)]
-pub struct Embedding(pub Vec<f32>);
-
-// This is needed for semantic index functionality
-// Unfortunately it has to live wherever the "Embedding" struct is created.
-// Keeping this in here though, introduces a 'rusqlite' dependency into AI
-// which is less than ideal
-impl FromSql for Embedding {
- fn column_result(value: ValueRef) -> FromSqlResult<Self> {
- let bytes = value.as_blob()?;
- let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
- if embedding.is_err() {
- return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
- }
- Ok(Embedding(embedding.unwrap()))
- }
-}
-
-impl ToSql for Embedding {
- fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
- let bytes = bincode::serialize(&self.0)
- .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
- Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
- }
-}
-impl From<Vec<f32>> for Embedding {
- fn from(value: Vec<f32>) -> Self {
- Embedding(value)
- }
-}
-
-impl Embedding {
- pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
- let len = self.0.len();
- assert_eq!(len, other.0.len());
-
- let mut result = 0.0;
- unsafe {
- matrixmultiply::sgemm(
- 1,
- len,
- 1,
- 1.0,
- self.0.as_ptr(),
- len as isize,
- 1,
- other.0.as_ptr(),
- 1,
- len as isize,
- 0.0,
- &mut result as *mut f32,
- 1,
- 1,
- );
- }
- OrderedFloat(result)
- }
-}
-
-#[async_trait]
-pub trait EmbeddingProvider: CredentialProvider {
- fn base_model(&self) -> Box<dyn LanguageModel>;
- async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
- fn max_tokens_per_batch(&self) -> usize;
- fn rate_limit_expiration(&self) -> Option<Instant>;
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use rand::prelude::*;
-
- #[gpui::test]
- fn test_similarity(mut rng: StdRng) {
- assert_eq!(
- Embedding::from(vec![1., 0., 0., 0., 0.])
- .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
- 0.
- );
- assert_eq!(
- Embedding::from(vec![2., 0., 0., 0., 0.])
- .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
- 6.
- );
-
- for _ in 0..100 {
- let size = 1536;
- let mut a = vec![0.; size];
- let mut b = vec![0.; size];
- for (a, b) in a.iter_mut().zip(b.iter_mut()) {
- *a = rng.gen();
- *b = rng.gen();
- }
- let a = Embedding::from(a);
- let b = Embedding::from(b);
-
- assert_eq!(
- round_to_decimals(a.similarity(&b), 1),
- round_to_decimals(reference_dot(&a.0, &b.0), 1)
- );
- }
-
- fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
- let factor = (10.0 as f32).powi(decimal_places);
- (n * factor).round() / factor
- }
-
- fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
- OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
- }
- }
-}
@@ -1,16 +0,0 @@
-pub enum TruncationDirection {
- Start,
- End,
-}
-
-pub trait LanguageModel {
- fn name(&self) -> String;
- fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
- fn truncate(
- &self,
- content: &str,
- length: usize,
- direction: TruncationDirection,
- ) -> anyhow::Result<String>;
- fn capacity(&self) -> anyhow::Result<usize>;
-}
@@ -1,330 +0,0 @@
-use std::cmp::Reverse;
-use std::ops::Range;
-use std::sync::Arc;
-
-use language::BufferSnapshot;
-use util::ResultExt;
-
-use crate::models::LanguageModel;
-use crate::prompts::repository_context::PromptCodeSnippet;
-
-pub(crate) enum PromptFileType {
- Text,
- Code,
-}
-
-// TODO: Set this up to manage for defaults well
-pub struct PromptArguments {
- pub model: Arc<dyn LanguageModel>,
- pub user_prompt: Option<String>,
- pub language_name: Option<String>,
- pub project_name: Option<String>,
- pub snippets: Vec<PromptCodeSnippet>,
- pub reserved_tokens: usize,
- pub buffer: Option<BufferSnapshot>,
- pub selected_range: Option<Range<usize>>,
-}
-
-impl PromptArguments {
- pub(crate) fn get_file_type(&self) -> PromptFileType {
- if self
- .language_name
- .as_ref()
- .and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str())))
- .unwrap_or(true)
- {
- PromptFileType::Code
- } else {
- PromptFileType::Text
- }
- }
-}
-
-pub trait PromptTemplate {
- fn generate(
- &self,
- args: &PromptArguments,
- max_token_length: Option<usize>,
- ) -> anyhow::Result<(String, usize)>;
-}
-
-#[repr(i8)]
-#[derive(PartialEq, Eq, Ord)]
-pub enum PromptPriority {
- Mandatory, // Ignores truncation
- Ordered { order: usize }, // Truncates based on priority
-}
-
-impl PartialOrd for PromptPriority {
- fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
- match (self, other) {
- (Self::Mandatory, Self::Mandatory) => Some(std::cmp::Ordering::Equal),
- (Self::Mandatory, Self::Ordered { .. }) => Some(std::cmp::Ordering::Greater),
- (Self::Ordered { .. }, Self::Mandatory) => Some(std::cmp::Ordering::Less),
- (Self::Ordered { order: a }, Self::Ordered { order: b }) => b.partial_cmp(a),
- }
- }
-}
-
-pub struct PromptChain {
- args: PromptArguments,
- templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
-}
-
-impl PromptChain {
- pub fn new(
- args: PromptArguments,
- templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
- ) -> Self {
- PromptChain { args, templates }
- }
-
- pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> {
- // Argsort based on Prompt Priority
- let seperator = "\n";
- let seperator_tokens = self.args.model.count_tokens(seperator)?;
- let mut sorted_indices = (0..self.templates.len()).collect::<Vec<_>>();
- sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));
-
- // If Truncate
- let mut tokens_outstanding = if truncate {
- Some(self.args.model.capacity()? - self.args.reserved_tokens)
- } else {
- None
- };
-
- let mut prompts = vec!["".to_string(); sorted_indices.len()];
- for idx in sorted_indices {
- let (_, template) = &self.templates[idx];
-
- if let Some((template_prompt, prompt_token_count)) =
- template.generate(&self.args, tokens_outstanding).log_err()
- {
- if template_prompt != "" {
- prompts[idx] = template_prompt;
-
- if let Some(remaining_tokens) = tokens_outstanding {
- let new_tokens = prompt_token_count + seperator_tokens;
- tokens_outstanding = if remaining_tokens > new_tokens {
- Some(remaining_tokens - new_tokens)
- } else {
- Some(0)
- };
- }
- }
- }
- }
-
- prompts.retain(|x| x != "");
-
- let full_prompt = prompts.join(seperator);
- let total_token_count = self.args.model.count_tokens(&full_prompt)?;
- anyhow::Ok((prompts.join(seperator), total_token_count))
- }
-}
-
-#[cfg(test)]
-pub(crate) mod tests {
- use crate::models::TruncationDirection;
- use crate::test::FakeLanguageModel;
-
- use super::*;
-
- #[test]
- pub fn test_prompt_chain() {
- struct TestPromptTemplate {}
- impl PromptTemplate for TestPromptTemplate {
- fn generate(
- &self,
- args: &PromptArguments,
- max_token_length: Option<usize>,
- ) -> anyhow::Result<(String, usize)> {
- let mut content = "This is a test prompt template".to_string();
-
- let mut token_count = args.model.count_tokens(&content)?;
- if let Some(max_token_length) = max_token_length {
- if token_count > max_token_length {
- content = args.model.truncate(
- &content,
- max_token_length,
- TruncationDirection::End,
- )?;
- token_count = max_token_length;
- }
- }
-
- anyhow::Ok((content, token_count))
- }
- }
-
- struct TestLowPriorityTemplate {}
- impl PromptTemplate for TestLowPriorityTemplate {
- fn generate(
- &self,
- args: &PromptArguments,
- max_token_length: Option<usize>,
- ) -> anyhow::Result<(String, usize)> {
- let mut content = "This is a low priority test prompt template".to_string();
-
- let mut token_count = args.model.count_tokens(&content)?;
- if let Some(max_token_length) = max_token_length {
- if token_count > max_token_length {
- content = args.model.truncate(
- &content,
- max_token_length,
- TruncationDirection::End,
- )?;
- token_count = max_token_length;
- }
- }
-
- anyhow::Ok((content, token_count))
- }
- }
-
- let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
- let args = PromptArguments {
- model: model.clone(),
- language_name: None,
- project_name: None,
- snippets: Vec::new(),
- reserved_tokens: 0,
- buffer: None,
- selected_range: None,
- user_prompt: None,
- };
-
- let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
- (
- PromptPriority::Ordered { order: 0 },
- Box::new(TestPromptTemplate {}),
- ),
- (
- PromptPriority::Ordered { order: 1 },
- Box::new(TestLowPriorityTemplate {}),
- ),
- ];
- let chain = PromptChain::new(args, templates);
-
- let (prompt, token_count) = chain.generate(false).unwrap();
-
- assert_eq!(
- prompt,
- "This is a test prompt template\nThis is a low priority test prompt template"
- .to_string()
- );
-
- assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
-
- // Testing with Truncation Off
- // Should ignore capacity and return all prompts
- let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 20 });
- let args = PromptArguments {
- model: model.clone(),
- language_name: None,
- project_name: None,
- snippets: Vec::new(),
- reserved_tokens: 0,
- buffer: None,
- selected_range: None,
- user_prompt: None,
- };
-
- let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
- (
- PromptPriority::Ordered { order: 0 },
- Box::new(TestPromptTemplate {}),
- ),
- (
- PromptPriority::Ordered { order: 1 },
- Box::new(TestLowPriorityTemplate {}),
- ),
- ];
- let chain = PromptChain::new(args, templates);
-
- let (prompt, token_count) = chain.generate(false).unwrap();
-
- assert_eq!(
- prompt,
- "This is a test prompt template\nThis is a low priority test prompt template"
- .to_string()
- );
-
- assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
-
- // Testing with Truncation Off
- // Should ignore capacity and return all prompts
- let capacity = 20;
- let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
- let args = PromptArguments {
- model: model.clone(),
- language_name: None,
- project_name: None,
- snippets: Vec::new(),
- reserved_tokens: 0,
- buffer: None,
- selected_range: None,
- user_prompt: None,
- };
-
- let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
- (
- PromptPriority::Ordered { order: 0 },
- Box::new(TestPromptTemplate {}),
- ),
- (
- PromptPriority::Ordered { order: 1 },
- Box::new(TestLowPriorityTemplate {}),
- ),
- (
- PromptPriority::Ordered { order: 2 },
- Box::new(TestLowPriorityTemplate {}),
- ),
- ];
- let chain = PromptChain::new(args, templates);
-
- let (prompt, token_count) = chain.generate(true).unwrap();
-
- assert_eq!(prompt, "This is a test promp".to_string());
- assert_eq!(token_count, capacity);
-
- // Change Ordering of Prompts Based on Priority
- let capacity = 120;
- let reserved_tokens = 10;
- let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
- let args = PromptArguments {
- model: model.clone(),
- language_name: None,
- project_name: None,
- snippets: Vec::new(),
- reserved_tokens,
- buffer: None,
- selected_range: None,
- user_prompt: None,
- };
- let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
- (
- PromptPriority::Mandatory,
- Box::new(TestLowPriorityTemplate {}),
- ),
- (
- PromptPriority::Ordered { order: 0 },
- Box::new(TestPromptTemplate {}),
- ),
- (
- PromptPriority::Ordered { order: 1 },
- Box::new(TestLowPriorityTemplate {}),
- ),
- ];
- let chain = PromptChain::new(args, templates);
-
- let (prompt, token_count) = chain.generate(true).unwrap();
-
- assert_eq!(
- prompt,
- "This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt "
- .to_string()
- );
- assert_eq!(token_count, capacity - reserved_tokens);
- }
-}
@@ -1,164 +0,0 @@
-use anyhow::anyhow;
-use language::BufferSnapshot;
-use language::ToOffset;
-
-use crate::models::LanguageModel;
-use crate::models::TruncationDirection;
-use crate::prompts::base::PromptArguments;
-use crate::prompts::base::PromptTemplate;
-use std::fmt::Write;
-use std::ops::Range;
-use std::sync::Arc;
-
-fn retrieve_context(
- buffer: &BufferSnapshot,
- selected_range: &Option<Range<usize>>,
- model: Arc<dyn LanguageModel>,
- max_token_count: Option<usize>,
-) -> anyhow::Result<(String, usize, bool)> {
- let mut prompt = String::new();
- let mut truncated = false;
- if let Some(selected_range) = selected_range {
- let start = selected_range.start.to_offset(buffer);
- let end = selected_range.end.to_offset(buffer);
-
- let start_window = buffer.text_for_range(0..start).collect::<String>();
-
- let mut selected_window = String::new();
- if start == end {
- write!(selected_window, "<|START|>").unwrap();
- } else {
- write!(selected_window, "<|START|").unwrap();
- }
-
- write!(
- selected_window,
- "{}",
- buffer.text_for_range(start..end).collect::<String>()
- )
- .unwrap();
-
- if start != end {
- write!(selected_window, "|END|>").unwrap();
- }
-
- let end_window = buffer.text_for_range(end..buffer.len()).collect::<String>();
-
- if let Some(max_token_count) = max_token_count {
- let selected_tokens = model.count_tokens(&selected_window)?;
- if selected_tokens > max_token_count {
- return Err(anyhow!(
- "selected range is greater than model context window, truncation not possible"
- ));
- };
-
- let mut remaining_tokens = max_token_count - selected_tokens;
- let start_window_tokens = model.count_tokens(&start_window)?;
- let end_window_tokens = model.count_tokens(&end_window)?;
- let outside_tokens = start_window_tokens + end_window_tokens;
- if outside_tokens > remaining_tokens {
- let (start_goal_tokens, end_goal_tokens) =
- if start_window_tokens < end_window_tokens {
- let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens);
- remaining_tokens -= start_goal_tokens;
- let end_goal_tokens = remaining_tokens.min(end_window_tokens);
- (start_goal_tokens, end_goal_tokens)
- } else {
- let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens);
- remaining_tokens -= end_goal_tokens;
- let start_goal_tokens = remaining_tokens.min(start_window_tokens);
- (start_goal_tokens, end_goal_tokens)
- };
-
- let truncated_start_window =
- model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?;
- let truncated_end_window =
- model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?;
- writeln!(
- prompt,
- "{truncated_start_window}{selected_window}{truncated_end_window}"
- )
- .unwrap();
- truncated = true;
- } else {
- writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
- }
- } else {
- // If we dont have a selected range, include entire file.
- writeln!(prompt, "{}", &buffer.text()).unwrap();
-
- // Dumb truncation strategy
- if let Some(max_token_count) = max_token_count {
- if model.count_tokens(&prompt)? > max_token_count {
- truncated = true;
- prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?;
- }
- }
- }
- }
-
- let token_count = model.count_tokens(&prompt)?;
- anyhow::Ok((prompt, token_count, truncated))
-}
-
-pub struct FileContext {}
-
-impl PromptTemplate for FileContext {
- fn generate(
- &self,
- args: &PromptArguments,
- max_token_length: Option<usize>,
- ) -> anyhow::Result<(String, usize)> {
- if let Some(buffer) = &args.buffer {
- let mut prompt = String::new();
- // Add Initial Preamble
- // TODO: Do we want to add the path in here?
- writeln!(
- prompt,
- "The file you are currently working on has the following content:"
- )
- .unwrap();
-
- let language_name = args
- .language_name
- .clone()
- .unwrap_or("".to_string())
- .to_lowercase();
-
- let (context, _, truncated) = retrieve_context(
- buffer,
- &args.selected_range,
- args.model.clone(),
- max_token_length,
- )?;
- writeln!(prompt, "```{language_name}\n{context}\n```").unwrap();
-
- if truncated {
- writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap();
- }
-
- if let Some(selected_range) = &args.selected_range {
- let start = selected_range.start.to_offset(buffer);
- let end = selected_range.end.to_offset(buffer);
-
- if start == end {
- writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap();
- } else {
- writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
- }
- }
-
- // Really dumb truncation strategy
- if let Some(max_tokens) = max_token_length {
- prompt = args
- .model
- .truncate(&prompt, max_tokens, TruncationDirection::End)?;
- }
-
- let token_count = args.model.count_tokens(&prompt)?;
- anyhow::Ok((prompt, token_count))
- } else {
- Err(anyhow!("no buffer provided to retrieve file context from"))
- }
- }
-}
@@ -1,99 +0,0 @@
-use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
-use anyhow::anyhow;
-use std::fmt::Write;
-
-pub fn capitalize(s: &str) -> String {
- let mut c = s.chars();
- match c.next() {
- None => String::new(),
- Some(f) => f.to_uppercase().collect::<String>() + c.as_str(),
- }
-}
-
-pub struct GenerateInlineContent {}
-
-impl PromptTemplate for GenerateInlineContent {
- fn generate(
- &self,
- args: &PromptArguments,
- max_token_length: Option<usize>,
- ) -> anyhow::Result<(String, usize)> {
- let Some(user_prompt) = &args.user_prompt else {
- return Err(anyhow!("user prompt not provided"));
- };
-
- let file_type = args.get_file_type();
- let content_type = match &file_type {
- PromptFileType::Code => "code",
- PromptFileType::Text => "text",
- };
-
- let mut prompt = String::new();
-
- if let Some(selected_range) = &args.selected_range {
- if selected_range.start == selected_range.end {
- writeln!(
- prompt,
- "Assume the cursor is located where the `<|START|>` span is."
- )
- .unwrap();
- writeln!(
- prompt,
- "{} can't be replaced, so assume your answer will be inserted at the cursor.",
- capitalize(content_type)
- )
- .unwrap();
- writeln!(
- prompt,
- "Generate {content_type} based on the users prompt: {user_prompt}",
- )
- .unwrap();
- } else {
- writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap();
- writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap();
- writeln!(prompt, "Double check that you only return code and not the '<|START|' and '|END|'> spans").unwrap();
- }
- } else {
- writeln!(
- prompt,
- "Generate {content_type} based on the users prompt: {user_prompt}"
- )
- .unwrap();
- }
-
- if let Some(language_name) = &args.language_name {
- writeln!(
- prompt,
- "Your answer MUST always and only be valid {}.",
- language_name
- )
- .unwrap();
- }
- writeln!(prompt, "Never make remarks about the output.").unwrap();
- writeln!(
- prompt,
- "Do not return anything else, except the generated {content_type}."
- )
- .unwrap();
-
- match file_type {
- PromptFileType::Code => {
- // writeln!(prompt, "Always wrap your code in a Markdown block.").unwrap();
- }
- _ => {}
- }
-
- // Really dumb truncation strategy
- if let Some(max_tokens) = max_token_length {
- prompt = args.model.truncate(
- &prompt,
- max_tokens,
- crate::models::TruncationDirection::End,
- )?;
- }
-
- let token_count = args.model.count_tokens(&prompt)?;
-
- anyhow::Ok((prompt, token_count))
- }
-}
@@ -1,5 +0,0 @@
-pub mod base;
-pub mod file_context;
-pub mod generate;
-pub mod preamble;
-pub mod repository_context;
@@ -1,52 +0,0 @@
-use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
-use std::fmt::Write;
-
-pub struct EngineerPreamble {}
-
-impl PromptTemplate for EngineerPreamble {
- fn generate(
- &self,
- args: &PromptArguments,
- max_token_length: Option<usize>,
- ) -> anyhow::Result<(String, usize)> {
- let mut prompts = Vec::new();
-
- match args.get_file_type() {
- PromptFileType::Code => {
- prompts.push(format!(
- "You are an expert {}engineer.",
- args.language_name.clone().unwrap_or("".to_string()) + " "
- ));
- }
- PromptFileType::Text => {
- prompts.push("You are an expert engineer.".to_string());
- }
- }
-
- if let Some(project_name) = args.project_name.clone() {
- prompts.push(format!(
- "You are currently working inside the '{project_name}' project in code editor Zed."
- ));
- }
-
- if let Some(mut remaining_tokens) = max_token_length {
- let mut prompt = String::new();
- let mut total_count = 0;
- for prompt_piece in prompts {
- let prompt_token_count =
- args.model.count_tokens(&prompt_piece)? + args.model.count_tokens("\n")?;
- if remaining_tokens > prompt_token_count {
- writeln!(prompt, "{prompt_piece}").unwrap();
- remaining_tokens -= prompt_token_count;
- total_count += prompt_token_count;
- }
- }
-
- anyhow::Ok((prompt, total_count))
- } else {
- let prompt = prompts.join("\n");
- let token_count = args.model.count_tokens(&prompt)?;
- anyhow::Ok((prompt, token_count))
- }
- }
-}
@@ -1,98 +0,0 @@
-use crate::prompts::base::{PromptArguments, PromptTemplate};
-use std::fmt::Write;
-use std::{ops::Range, path::PathBuf};
-
-use gpui::{AsyncAppContext, Model};
-use language::{Anchor, Buffer};
-
-#[derive(Clone)]
-pub struct PromptCodeSnippet {
- path: Option<PathBuf>,
- language_name: Option<String>,
- content: String,
-}
-
-impl PromptCodeSnippet {
- pub fn new(
- buffer: Model<Buffer>,
- range: Range<Anchor>,
- cx: &mut AsyncAppContext,
- ) -> anyhow::Result<Self> {
- let (content, language_name, file_path) = buffer.update(cx, |buffer, _| {
- let snapshot = buffer.snapshot();
- let content = snapshot.text_for_range(range.clone()).collect::<String>();
-
- let language_name = buffer
- .language()
- .and_then(|language| Some(language.name().to_string().to_lowercase()));
-
- let file_path = buffer
- .file()
- .and_then(|file| Some(file.path().to_path_buf()));
-
- (content, language_name, file_path)
- })?;
-
- anyhow::Ok(PromptCodeSnippet {
- path: file_path,
- language_name,
- content,
- })
- }
-}
-
-impl ToString for PromptCodeSnippet {
- fn to_string(&self) -> String {
- let path = self
- .path
- .as_ref()
- .and_then(|path| Some(path.to_string_lossy().to_string()))
- .unwrap_or("".to_string());
- let language_name = self.language_name.clone().unwrap_or("".to_string());
- let content = self.content.clone();
-
- format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
- }
-}
-
-pub struct RepositoryContext {}
-
-impl PromptTemplate for RepositoryContext {
- fn generate(
- &self,
- args: &PromptArguments,
- max_token_length: Option<usize>,
- ) -> anyhow::Result<(String, usize)> {
- const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
- let template = "You are working inside a large repository, here are a few code snippets that may be useful.";
- let mut prompt = String::new();
-
- let mut remaining_tokens = max_token_length.clone();
- let seperator_token_length = args.model.count_tokens("\n")?;
- for snippet in &args.snippets {
- let mut snippet_prompt = template.to_string();
- let content = snippet.to_string();
- writeln!(snippet_prompt, "{content}").unwrap();
-
- let token_count = args.model.count_tokens(&snippet_prompt)?;
- if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT {
- if let Some(tokens_left) = remaining_tokens {
- if tokens_left >= token_count {
- writeln!(prompt, "{snippet_prompt}").unwrap();
- remaining_tokens = if tokens_left >= (token_count + seperator_token_length)
- {
- Some(tokens_left - token_count - seperator_token_length)
- } else {
- Some(0)
- };
- }
- } else {
- writeln!(prompt, "{snippet_prompt}").unwrap();
- }
- }
- }
-
- let total_token_count = args.model.count_tokens(&prompt)?;
- anyhow::Ok((prompt, total_token_count))
- }
-}
@@ -1 +0,0 @@
-pub mod open_ai;
@@ -1,297 +0,0 @@
-use anyhow::{anyhow, Result};
-use futures::{
- future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
- Stream, StreamExt,
-};
-use gpui::{AppContext, BackgroundExecutor};
-use isahc::{http::StatusCode, Request, RequestExt};
-use parking_lot::RwLock;
-use serde::{Deserialize, Serialize};
-use std::{
- env,
- fmt::{self, Display},
- io,
- sync::Arc,
-};
-use util::ResultExt;
-
-use crate::{
- auth::{CredentialProvider, ProviderCredential},
- completion::{CompletionProvider, CompletionRequest},
- models::LanguageModel,
-};
-
-use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL};
-
-#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
-#[serde(rename_all = "lowercase")]
-pub enum Role {
- User,
- Assistant,
- System,
-}
-
-impl Role {
- pub fn cycle(&mut self) {
- *self = match self {
- Role::User => Role::Assistant,
- Role::Assistant => Role::System,
- Role::System => Role::User,
- }
- }
-}
-
-impl Display for Role {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
- match self {
- Role::User => write!(f, "User"),
- Role::Assistant => write!(f, "Assistant"),
- Role::System => write!(f, "System"),
- }
- }
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct RequestMessage {
- pub role: Role,
- pub content: String,
-}
-
-#[derive(Debug, Default, Serialize)]
-pub struct OpenAIRequest {
- pub model: String,
- pub messages: Vec<RequestMessage>,
- pub stream: bool,
- pub stop: Vec<String>,
- pub temperature: f32,
-}
-
-impl CompletionRequest for OpenAIRequest {
- fn data(&self) -> serde_json::Result<String> {
- serde_json::to_string(self)
- }
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct ResponseMessage {
- pub role: Option<Role>,
- pub content: Option<String>,
-}
-
-#[derive(Deserialize, Debug)]
-pub struct OpenAIUsage {
- pub prompt_tokens: u32,
- pub completion_tokens: u32,
- pub total_tokens: u32,
-}
-
-#[derive(Deserialize, Debug)]
-pub struct ChatChoiceDelta {
- pub index: u32,
- pub delta: ResponseMessage,
- pub finish_reason: Option<String>,
-}
-
-#[derive(Deserialize, Debug)]
-pub struct OpenAIResponseStreamEvent {
- pub id: Option<String>,
- pub object: String,
- pub created: u32,
- pub model: String,
- pub choices: Vec<ChatChoiceDelta>,
- pub usage: Option<OpenAIUsage>,
-}
-
-pub async fn stream_completion(
- credential: ProviderCredential,
- executor: BackgroundExecutor,
- request: Box<dyn CompletionRequest>,
-) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
- let api_key = match credential {
- ProviderCredential::Credentials { api_key } => api_key,
- _ => {
- return Err(anyhow!("no credentials provider for completion"));
- }
- };
-
- let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
-
- let json_data = request.data()?;
- let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
- .header("Content-Type", "application/json")
- .header("Authorization", format!("Bearer {}", api_key))
- .body(json_data)?
- .send_async()
- .await?;
-
- let status = response.status();
- if status == StatusCode::OK {
- executor
- .spawn(async move {
- let mut lines = BufReader::new(response.body_mut()).lines();
-
- fn parse_line(
- line: Result<String, io::Error>,
- ) -> Result<Option<OpenAIResponseStreamEvent>> {
- if let Some(data) = line?.strip_prefix("data: ") {
- let event = serde_json::from_str(&data)?;
- Ok(Some(event))
- } else {
- Ok(None)
- }
- }
-
- while let Some(line) = lines.next().await {
- if let Some(event) = parse_line(line).transpose() {
- let done = event.as_ref().map_or(false, |event| {
- event
- .choices
- .last()
- .map_or(false, |choice| choice.finish_reason.is_some())
- });
- if tx.unbounded_send(event).is_err() {
- break;
- }
-
- if done {
- break;
- }
- }
- }
-
- anyhow::Ok(())
- })
- .detach();
-
- Ok(rx)
- } else {
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
-
- #[derive(Deserialize)]
- struct OpenAIResponse {
- error: OpenAIError,
- }
-
- #[derive(Deserialize)]
- struct OpenAIError {
- message: String,
- }
-
- match serde_json::from_str::<OpenAIResponse>(&body) {
- Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
- "Failed to connect to OpenAI API: {}",
- response.error.message,
- )),
-
- _ => Err(anyhow!(
- "Failed to connect to OpenAI API: {} {}",
- response.status(),
- body,
- )),
- }
- }
-}
-
-#[derive(Clone)]
-pub struct OpenAICompletionProvider {
- model: OpenAILanguageModel,
- credential: Arc<RwLock<ProviderCredential>>,
- executor: BackgroundExecutor,
-}
-
-impl OpenAICompletionProvider {
- pub fn new(model_name: &str, executor: BackgroundExecutor) -> Self {
- let model = OpenAILanguageModel::load(model_name);
- let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
- Self {
- model,
- credential,
- executor,
- }
- }
-}
-
-impl CredentialProvider for OpenAICompletionProvider {
- fn has_credentials(&self) -> bool {
- match *self.credential.read() {
- ProviderCredential::Credentials { .. } => true,
- _ => false,
- }
- }
-
- fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential {
- let existing_credential = self.credential.read().clone();
- let retrieved_credential = match existing_credential {
- ProviderCredential::Credentials { .. } => existing_credential.clone(),
- _ => {
- if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
- ProviderCredential::Credentials { api_key }
- } else if let Some(Some((_, api_key))) =
- cx.read_credentials(OPENAI_API_URL).log_err()
- {
- if let Some(api_key) = String::from_utf8(api_key).log_err() {
- ProviderCredential::Credentials { api_key }
- } else {
- ProviderCredential::NoCredentials
- }
- } else {
- ProviderCredential::NoCredentials
- }
- }
- };
- *self.credential.write() = retrieved_credential.clone();
- retrieved_credential
- }
-
- fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) {
- *self.credential.write() = credential.clone();
- let credential = credential.clone();
- match credential {
- ProviderCredential::Credentials { api_key } => {
- cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
- .log_err();
- }
- _ => {}
- }
- }
-
- fn delete_credentials(&self, cx: &mut AppContext) {
- cx.delete_credentials(OPENAI_API_URL).log_err();
- *self.credential.write() = ProviderCredential::NoCredentials;
- }
-}
-
-impl CompletionProvider for OpenAICompletionProvider {
- fn base_model(&self) -> Box<dyn LanguageModel> {
- let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
- model
- }
- fn complete(
- &self,
- prompt: Box<dyn CompletionRequest>,
- ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
- // Currently the CompletionRequest for OpenAI, includes a 'model' parameter
- // This means that the model is determined by the CompletionRequest and not the CompletionProvider,
- // which is currently model based, due to the langauge model.
- // At some point in the future we should rectify this.
- let credential = self.credential.read().clone();
- let request = stream_completion(credential, self.executor.clone(), prompt);
- async move {
- let response = request.await?;
- let stream = response
- .filter_map(|response| async move {
- match response {
- Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
- Err(error) => Some(Err(error)),
- }
- })
- .boxed();
- Ok(stream)
- }
- .boxed()
- }
- fn box_clone(&self) -> Box<dyn CompletionProvider> {
- Box::new((*self).clone())
- }
-}
@@ -1,305 +0,0 @@
-use anyhow::{anyhow, Result};
-use async_trait::async_trait;
-use futures::AsyncReadExt;
-use gpui::BackgroundExecutor;
-use gpui::{serde_json, AppContext};
-use isahc::http::StatusCode;
-use isahc::prelude::Configurable;
-use isahc::{AsyncBody, Response};
-use lazy_static::lazy_static;
-use parking_lot::{Mutex, RwLock};
-use parse_duration::parse;
-use postage::watch;
-use serde::{Deserialize, Serialize};
-use std::env;
-use std::ops::Add;
-use std::sync::Arc;
-use std::time::{Duration, Instant};
-use tiktoken_rs::{cl100k_base, CoreBPE};
-use util::http::{HttpClient, Request};
-use util::ResultExt;
-
-use crate::auth::{CredentialProvider, ProviderCredential};
-use crate::embedding::{Embedding, EmbeddingProvider};
-use crate::models::LanguageModel;
-use crate::providers::open_ai::OpenAILanguageModel;
-
-use crate::providers::open_ai::OPENAI_API_URL;
-
-lazy_static! {
- static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
-}
-
-#[derive(Clone)]
-pub struct OpenAIEmbeddingProvider {
- model: OpenAILanguageModel,
- credential: Arc<RwLock<ProviderCredential>>,
- pub client: Arc<dyn HttpClient>,
- pub executor: BackgroundExecutor,
- rate_limit_count_rx: watch::Receiver<Option<Instant>>,
- rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
-}
-
-#[derive(Serialize)]
-struct OpenAIEmbeddingRequest<'a> {
- model: &'static str,
- input: Vec<&'a str>,
-}
-
-#[derive(Deserialize)]
-struct OpenAIEmbeddingResponse {
- data: Vec<OpenAIEmbedding>,
- usage: OpenAIEmbeddingUsage,
-}
-
-#[derive(Debug, Deserialize)]
-struct OpenAIEmbedding {
- embedding: Vec<f32>,
- index: usize,
- object: String,
-}
-
-#[derive(Deserialize)]
-struct OpenAIEmbeddingUsage {
- prompt_tokens: usize,
- total_tokens: usize,
-}
-
-impl OpenAIEmbeddingProvider {
- pub fn new(client: Arc<dyn HttpClient>, executor: BackgroundExecutor) -> Self {
- let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
- let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
-
- let model = OpenAILanguageModel::load("text-embedding-ada-002");
- let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
-
- OpenAIEmbeddingProvider {
- model,
- credential,
- client,
- executor,
- rate_limit_count_rx,
- rate_limit_count_tx,
- }
- }
-
- fn get_api_key(&self) -> Result<String> {
- match self.credential.read().clone() {
- ProviderCredential::Credentials { api_key } => Ok(api_key),
- _ => Err(anyhow!("api credentials not provided")),
- }
- }
-
- fn resolve_rate_limit(&self) {
- let reset_time = *self.rate_limit_count_tx.lock().borrow();
-
- if let Some(reset_time) = reset_time {
- if Instant::now() >= reset_time {
- *self.rate_limit_count_tx.lock().borrow_mut() = None
- }
- }
-
- log::trace!(
- "resolving reset time: {:?}",
- *self.rate_limit_count_tx.lock().borrow()
- );
- }
-
- fn update_reset_time(&self, reset_time: Instant) {
- let original_time = *self.rate_limit_count_tx.lock().borrow();
-
- let updated_time = if let Some(original_time) = original_time {
- if reset_time < original_time {
- Some(reset_time)
- } else {
- Some(original_time)
- }
- } else {
- Some(reset_time)
- };
-
- log::trace!("updating rate limit time: {:?}", updated_time);
-
- *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
- }
- async fn send_request(
- &self,
- api_key: &str,
- spans: Vec<&str>,
- request_timeout: u64,
- ) -> Result<Response<AsyncBody>> {
- let request = Request::post("https://api.openai.com/v1/embeddings")
- .redirect_policy(isahc::config::RedirectPolicy::Follow)
- .timeout(Duration::from_secs(request_timeout))
- .header("Content-Type", "application/json")
- .header("Authorization", format!("Bearer {}", api_key))
- .body(
- serde_json::to_string(&OpenAIEmbeddingRequest {
- input: spans.clone(),
- model: "text-embedding-ada-002",
- })
- .unwrap()
- .into(),
- )?;
-
- Ok(self.client.send(request).await?)
- }
-}
-
-impl CredentialProvider for OpenAIEmbeddingProvider {
- fn has_credentials(&self) -> bool {
- match *self.credential.read() {
- ProviderCredential::Credentials { .. } => true,
- _ => false,
- }
- }
- fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential {
- let existing_credential = self.credential.read().clone();
-
- let retrieved_credential = match existing_credential {
- ProviderCredential::Credentials { .. } => existing_credential.clone(),
- _ => {
- if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
- ProviderCredential::Credentials { api_key }
- } else if let Some(Some((_, api_key))) =
- cx.read_credentials(OPENAI_API_URL).log_err()
- {
- if let Some(api_key) = String::from_utf8(api_key).log_err() {
- ProviderCredential::Credentials { api_key }
- } else {
- ProviderCredential::NoCredentials
- }
- } else {
- ProviderCredential::NoCredentials
- }
- }
- };
-
- *self.credential.write() = retrieved_credential.clone();
- retrieved_credential
- }
-
- fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) {
- *self.credential.write() = credential.clone();
- match credential {
- ProviderCredential::Credentials { api_key } => {
- cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
- .log_err();
- }
- _ => {}
- }
- }
-
- fn delete_credentials(&self, cx: &mut AppContext) {
- cx.delete_credentials(OPENAI_API_URL).log_err();
- *self.credential.write() = ProviderCredential::NoCredentials;
- }
-}
-
-#[async_trait]
-impl EmbeddingProvider for OpenAIEmbeddingProvider {
- fn base_model(&self) -> Box<dyn LanguageModel> {
- let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
- model
- }
-
- fn max_tokens_per_batch(&self) -> usize {
- 50000
- }
-
- fn rate_limit_expiration(&self) -> Option<Instant> {
- *self.rate_limit_count_rx.borrow()
- }
-
- async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
- const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
- const MAX_RETRIES: usize = 4;
-
- let api_key = self.get_api_key()?;
-
- let mut request_number = 0;
- let mut rate_limiting = false;
- let mut request_timeout: u64 = 15;
- let mut response: Response<AsyncBody>;
- while request_number < MAX_RETRIES {
- response = self
- .send_request(
- &api_key,
- spans.iter().map(|x| &**x).collect(),
- request_timeout,
- )
- .await?;
-
- request_number += 1;
-
- match response.status() {
- StatusCode::REQUEST_TIMEOUT => {
- request_timeout += 5;
- }
- StatusCode::OK => {
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
- let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
-
- log::trace!(
- "openai embedding completed. tokens: {:?}",
- response.usage.total_tokens
- );
-
- // If we complete a request successfully that was previously rate_limited
- // resolve the rate limit
- if rate_limiting {
- self.resolve_rate_limit()
- }
-
- return Ok(response
- .data
- .into_iter()
- .map(|embedding| Embedding::from(embedding.embedding))
- .collect());
- }
- StatusCode::TOO_MANY_REQUESTS => {
- rate_limiting = true;
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
-
- let delay_duration = {
- let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
- if let Some(time_to_reset) =
- response.headers().get("x-ratelimit-reset-tokens")
- {
- if let Ok(time_str) = time_to_reset.to_str() {
- parse(time_str).unwrap_or(delay)
- } else {
- delay
- }
- } else {
- delay
- }
- };
-
- // If we've previously rate limited, increment the duration but not the count
- let reset_time = Instant::now().add(delay_duration);
- self.update_reset_time(reset_time);
-
- log::trace!(
- "openai rate limiting: waiting {:?} until lifted",
- &delay_duration
- );
-
- self.executor.timer(delay_duration).await;
- }
- _ => {
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
- return Err(anyhow!(
- "open ai bad request: {:?} {:?}",
- &response.status(),
- body
- ));
- }
- }
- }
- Err(anyhow!("openai max retries"))
- }
-}
@@ -1,9 +0,0 @@
-pub mod completion;
-pub mod embedding;
-pub mod model;
-
-pub use completion::*;
-pub use embedding::*;
-pub use model::OpenAILanguageModel;
-
-pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
@@ -1,57 +0,0 @@
-use anyhow::anyhow;
-use tiktoken_rs::CoreBPE;
-use util::ResultExt;
-
-use crate::models::{LanguageModel, TruncationDirection};
-
-#[derive(Clone)]
-pub struct OpenAILanguageModel {
- name: String,
- bpe: Option<CoreBPE>,
-}
-
-impl OpenAILanguageModel {
- pub fn load(model_name: &str) -> Self {
- let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
- OpenAILanguageModel {
- name: model_name.to_string(),
- bpe,
- }
- }
-}
-
-impl LanguageModel for OpenAILanguageModel {
- fn name(&self) -> String {
- self.name.clone()
- }
- fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
- if let Some(bpe) = &self.bpe {
- anyhow::Ok(bpe.encode_with_special_tokens(content).len())
- } else {
- Err(anyhow!("bpe for open ai model was not retrieved"))
- }
- }
- fn truncate(
- &self,
- content: &str,
- length: usize,
- direction: TruncationDirection,
- ) -> anyhow::Result<String> {
- if let Some(bpe) = &self.bpe {
- let tokens = bpe.encode_with_special_tokens(content);
- if tokens.len() > length {
- match direction {
- TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
- TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
- }
- } else {
- bpe.decode(tokens)
- }
- } else {
- Err(anyhow!("bpe for open ai model was not retrieved"))
- }
- }
- fn capacity(&self) -> anyhow::Result<usize> {
- anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
- }
-}
@@ -1,11 +0,0 @@
-pub trait LanguageModel {
- fn name(&self) -> String;
- fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
- fn truncate(
- &self,
- content: &str,
- length: usize,
- direction: TruncationDirection,
- ) -> anyhow::Result<String>;
- fn capacity(&self) -> anyhow::Result<usize>;
-}
@@ -1,191 +0,0 @@
-use std::{
- sync::atomic::{self, AtomicUsize, Ordering},
- time::Instant,
-};
-
-use async_trait::async_trait;
-use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
-use gpui::AppContext;
-use parking_lot::Mutex;
-
-use crate::{
- auth::{CredentialProvider, ProviderCredential},
- completion::{CompletionProvider, CompletionRequest},
- embedding::{Embedding, EmbeddingProvider},
- models::{LanguageModel, TruncationDirection},
-};
-
-#[derive(Clone)]
-pub struct FakeLanguageModel {
- pub capacity: usize,
-}
-
-impl LanguageModel for FakeLanguageModel {
- fn name(&self) -> String {
- "dummy".to_string()
- }
- fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
- anyhow::Ok(content.chars().collect::<Vec<char>>().len())
- }
- fn truncate(
- &self,
- content: &str,
- length: usize,
- direction: TruncationDirection,
- ) -> anyhow::Result<String> {
- println!("TRYING TO TRUNCATE: {:?}", length.clone());
-
- if length > self.count_tokens(content)? {
- println!("NOT TRUNCATING");
- return anyhow::Ok(content.to_string());
- }
-
- anyhow::Ok(match direction {
- TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
- .into_iter()
- .collect::<String>(),
- TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
- .into_iter()
- .collect::<String>(),
- })
- }
- fn capacity(&self) -> anyhow::Result<usize> {
- anyhow::Ok(self.capacity)
- }
-}
-
-pub struct FakeEmbeddingProvider {
- pub embedding_count: AtomicUsize,
-}
-
-impl Clone for FakeEmbeddingProvider {
- fn clone(&self) -> Self {
- FakeEmbeddingProvider {
- embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
- }
- }
-}
-
-impl Default for FakeEmbeddingProvider {
- fn default() -> Self {
- FakeEmbeddingProvider {
- embedding_count: AtomicUsize::default(),
- }
- }
-}
-
-impl FakeEmbeddingProvider {
- pub fn embedding_count(&self) -> usize {
- self.embedding_count.load(atomic::Ordering::SeqCst)
- }
-
- pub fn embed_sync(&self, span: &str) -> Embedding {
- let mut result = vec![1.0; 26];
- for letter in span.chars() {
- let letter = letter.to_ascii_lowercase();
- if letter as u32 >= 'a' as u32 {
- let ix = (letter as u32) - ('a' as u32);
- if ix < 26 {
- result[ix as usize] += 1.0;
- }
- }
- }
-
- let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
- for x in &mut result {
- *x /= norm;
- }
-
- result.into()
- }
-}
-
-impl CredentialProvider for FakeEmbeddingProvider {
- fn has_credentials(&self) -> bool {
- true
- }
- fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
- ProviderCredential::NotNeeded
- }
- fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
- fn delete_credentials(&self, _cx: &mut AppContext) {}
-}
-
-#[async_trait]
-impl EmbeddingProvider for FakeEmbeddingProvider {
- fn base_model(&self) -> Box<dyn LanguageModel> {
- Box::new(FakeLanguageModel { capacity: 1000 })
- }
- fn max_tokens_per_batch(&self) -> usize {
- 1000
- }
-
- fn rate_limit_expiration(&self) -> Option<Instant> {
- None
- }
-
- async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
- self.embedding_count
- .fetch_add(spans.len(), atomic::Ordering::SeqCst);
-
- anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
- }
-}
-
-pub struct FakeCompletionProvider {
- last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
-}
-
-impl Clone for FakeCompletionProvider {
- fn clone(&self) -> Self {
- Self {
- last_completion_tx: Mutex::new(None),
- }
- }
-}
-
-impl FakeCompletionProvider {
- pub fn new() -> Self {
- Self {
- last_completion_tx: Mutex::new(None),
- }
- }
-
- pub fn send_completion(&self, completion: impl Into<String>) {
- let mut tx = self.last_completion_tx.lock();
- tx.as_mut().unwrap().try_send(completion.into()).unwrap();
- }
-
- pub fn finish_completion(&self) {
- self.last_completion_tx.lock().take().unwrap();
- }
-}
-
-impl CredentialProvider for FakeCompletionProvider {
- fn has_credentials(&self) -> bool {
- true
- }
- fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
- ProviderCredential::NotNeeded
- }
- fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
- fn delete_credentials(&self, _cx: &mut AppContext) {}
-}
-
-impl CompletionProvider for FakeCompletionProvider {
- fn base_model(&self) -> Box<dyn LanguageModel> {
- let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
- model
- }
- fn complete(
- &self,
- _prompt: Box<dyn CompletionRequest>,
- ) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
- let (tx, rx) = mpsc::channel(1);
- *self.last_completion_tx.lock() = Some(tx);
- async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
- }
- fn box_clone(&self) -> Box<dyn CompletionProvider> {
- Box::new((*self).clone())
- }
-}
@@ -9,7 +9,7 @@ path = "src/assistant.rs"
doctest = false
[dependencies]
-ai = { package = "ai2", path = "../ai2" }
+ai = { path = "../ai" }
client = { package = "client2", path = "../client2" }
collections = { path = "../collections"}
editor = { path = "../editor" }
@@ -44,7 +44,7 @@ smol.workspace = true
tiktoken-rs.workspace = true
[dev-dependencies]
-ai = { package = "ai2", path = "../ai2", features = ["test-support"]}
+ai = { path = "../ai", features = ["test-support"]}
editor = { path = "../editor", features = ["test-support"] }
project = { package = "project2", path = "../project2", features = ["test-support"] }
@@ -21,7 +21,7 @@ rpc = { package = "rpc2", path = "../rpc2" }
text = { package = "text2", path = "../text2" }
language = { package = "language2", path = "../language2" }
settings = { package = "settings2", path = "../settings2" }
-feature_flags = { package = "feature_flags2", path = "../feature_flags2" }
+feature_flags = { path = "../feature_flags" }
sum_tree = { path = "../sum_tree" }
clock = { path = "../clock" }
@@ -20,7 +20,7 @@ util = { path = "../util" }
rpc = { package = "rpc2", path = "../rpc2" }
text = { package = "text2", path = "../text2" }
settings = { package = "settings2", path = "../settings2" }
-feature_flags = { package = "feature_flags2", path = "../feature_flags2" }
+feature_flags = { path = "../feature_flags" }
sum_tree = { path = "../sum_tree" }
anyhow.workspace = true
@@ -38,20 +38,20 @@ gpui = { package = "gpui2", path = "../gpui2" }
language = { package = "language2", path = "../language2" }
menu = { package = "menu2", path = "../menu2" }
notifications = { package = "notifications2", path = "../notifications2" }
-rich_text = { package = "rich_text2", path = "../rich_text2" }
+rich_text = { path = "../rich_text" }
picker = { path = "../picker" }
project = { package = "project2", path = "../project2" }
recent_projects = { path = "../recent_projects" }
rpc = { package ="rpc2", path = "../rpc2" }
settings = { package = "settings2", path = "../settings2" }
-feature_flags = { package = "feature_flags2", path = "../feature_flags2"}
+feature_flags = { path = "../feature_flags"}
theme = { package = "theme2", path = "../theme2" }
theme_selector = { path = "../theme_selector" }
vcs_menu = { path = "../vcs_menu" }
ui = { package = "ui2", path = "../ui2" }
util = { path = "../util" }
workspace = { path = "../workspace" }
-zed-actions = { package="zed_actions2", path = "../zed_actions2"}
+zed_actions = { path = "../zed_actions"}
anyhow.workspace = true
futures.workspace = true
@@ -20,7 +20,7 @@ ui = { package = "ui2", path = "../ui2" }
util = { path = "../util" }
theme = { package = "theme2", path = "../theme2" }
workspace = { path = "../workspace" }
-zed_actions = { package = "zed_actions2", path = "../zed_actions2" }
+zed_actions = { path = "../zed_actions" }
anyhow.workspace = true
serde.workspace = true
@@ -1,36 +0,0 @@
-[package]
-name = "command_palette"
-version = "0.1.0"
-edition = "2021"
-publish = false
-
-[lib]
-path = "src/command_palette.rs"
-doctest = false
-
-[dependencies]
-collections = { path = "../collections" }
-editor = { path = "../editor" }
-fuzzy = { package = "fuzzy2", path = "../fuzzy2" }
-gpui = { package = "gpui2", path = "../gpui2" }
-picker = { path = "../picker" }
-project = { package = "project2", path = "../project2" }
-settings = { package = "settings2", path = "../settings2" }
-ui = { package = "ui2", path = "../ui2" }
-util = { path = "../util" }
-theme = { package = "theme2", path = "../theme2" }
-workspace = { package="workspace2", path = "../workspace2" }
-zed_actions = { package = "zed_actions2", path = "../zed_actions2" }
-anyhow.workspace = true
-serde.workspace = true
-[dev-dependencies]
-gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
-editor = { path = "../editor", features = ["test-support"] }
-language = { package="language2", path = "../language2", features = ["test-support"] }
-project = { package="project2", path = "../project2", features = ["test-support"] }
-menu = { package = "menu2", path = "../menu2" }
-go_to_line = { package = "go_to_line2", path = "../go_to_line2" }
-serde_json.workspace = true
-workspace = { path = "../workspace", features = ["test-support"] }
-ctor.workspace = true
-env_logger.workspace = true
@@ -12,7 +12,7 @@ doctest = false
copilot = { path = "../copilot" }
editor = { path = "../editor" }
fs = { package = "fs2", path = "../fs2" }
-zed-actions = { package="zed_actions2", path = "../zed_actions2"}
+zed_actions = { path = "../zed_actions"}
gpui = { package = "gpui2", path = "../gpui2" }
language = { package = "language2", path = "../language2" }
settings = { package = "settings2", path = "../settings2" }
@@ -37,7 +37,7 @@ lsp = { package = "lsp2", path = "../lsp2" }
multi_buffer = { path = "../multi_buffer" }
project = { package = "project2", path = "../project2" }
rpc = { package = "rpc2", path = "../rpc2" }
-rich_text = { package = "rich_text2", path = "../rich_text2" }
+rich_text = { path = "../rich_text" }
settings = { package="settings2", path = "../settings2" }
snippet = { path = "../snippet" }
sum_tree = { path = "../sum_tree" }
@@ -8,5 +8,5 @@ publish = false
path = "src/feature_flags.rs"
[dependencies]
-gpui = { path = "../gpui" }
+gpui = { package = "gpui2", path = "../gpui2" }
anyhow.workspace = true
@@ -25,15 +25,18 @@ impl FeatureFlag for ChannelsAlpha {
pub trait FeatureFlagViewExt<V: 'static> {
fn observe_flag<T: FeatureFlag, F>(&mut self, callback: F) -> Subscription
where
- F: Fn(bool, &mut V, &mut ViewContext<V>) + 'static;
+ F: Fn(bool, &mut V, &mut ViewContext<V>) + Send + Sync + 'static;
}
-impl<V: 'static> FeatureFlagViewExt<V> for ViewContext<'_, '_, V> {
+impl<V> FeatureFlagViewExt<V> for ViewContext<'_, V>
+where
+ V: 'static,
+{
fn observe_flag<T: FeatureFlag, F>(&mut self, callback: F) -> Subscription
where
F: Fn(bool, &mut V, &mut ViewContext<V>) + 'static,
{
- self.observe_global::<FeatureFlags, _>(move |v, cx| {
+ self.observe_global::<FeatureFlags>(move |v, cx| {
let feature_flags = cx.global::<FeatureFlags>();
callback(feature_flags.has_flag(<T as FeatureFlag>::NAME), v, cx);
})
@@ -49,16 +52,14 @@ pub trait FeatureFlagAppExt {
impl FeatureFlagAppExt for AppContext {
fn update_flags(&mut self, staff: bool, flags: Vec<String>) {
- self.update_default_global::<FeatureFlags, _, _>(|feature_flags, _| {
- feature_flags.staff = staff;
- feature_flags.flags = flags;
- })
+ let feature_flags = self.default_global::<FeatureFlags>();
+ feature_flags.staff = staff;
+ feature_flags.flags = flags;
}
fn set_staff(&mut self, staff: bool) {
- self.update_default_global::<FeatureFlags, _, _>(|feature_flags, _| {
- feature_flags.staff = staff;
- })
+ let feature_flags = self.default_global::<FeatureFlags>();
+ feature_flags.staff = staff;
}
fn has_flag<T: FeatureFlag>(&self) -> bool {
@@ -1,12 +0,0 @@
-[package]
-name = "feature_flags2"
-version = "0.1.0"
-edition = "2021"
-publish = false
-
-[lib]
-path = "src/feature_flags2.rs"
-
-[dependencies]
-gpui = { package = "gpui2", path = "../gpui2" }
-anyhow.workspace = true
@@ -1,80 +0,0 @@
-use gpui::{AppContext, Subscription, ViewContext};
-
-#[derive(Default)]
-struct FeatureFlags {
- flags: Vec<String>,
- staff: bool,
-}
-
-impl FeatureFlags {
- fn has_flag(&self, flag: &str) -> bool {
- self.staff || self.flags.iter().find(|f| f.as_str() == flag).is_some()
- }
-}
-
-pub trait FeatureFlag {
- const NAME: &'static str;
-}
-
-pub enum ChannelsAlpha {}
-
-impl FeatureFlag for ChannelsAlpha {
- const NAME: &'static str = "channels_alpha";
-}
-
-pub trait FeatureFlagViewExt<V: 'static> {
- fn observe_flag<T: FeatureFlag, F>(&mut self, callback: F) -> Subscription
- where
- F: Fn(bool, &mut V, &mut ViewContext<V>) + Send + Sync + 'static;
-}
-
-impl<V> FeatureFlagViewExt<V> for ViewContext<'_, V>
-where
- V: 'static,
-{
- fn observe_flag<T: FeatureFlag, F>(&mut self, callback: F) -> Subscription
- where
- F: Fn(bool, &mut V, &mut ViewContext<V>) + 'static,
- {
- self.observe_global::<FeatureFlags>(move |v, cx| {
- let feature_flags = cx.global::<FeatureFlags>();
- callback(feature_flags.has_flag(<T as FeatureFlag>::NAME), v, cx);
- })
- }
-}
-
-pub trait FeatureFlagAppExt {
- fn update_flags(&mut self, staff: bool, flags: Vec<String>);
- fn set_staff(&mut self, staff: bool);
- fn has_flag<T: FeatureFlag>(&self) -> bool;
- fn is_staff(&self) -> bool;
-}
-
-impl FeatureFlagAppExt for AppContext {
- fn update_flags(&mut self, staff: bool, flags: Vec<String>) {
- let feature_flags = self.default_global::<FeatureFlags>();
- feature_flags.staff = staff;
- feature_flags.flags = flags;
- }
-
- fn set_staff(&mut self, staff: bool) {
- let feature_flags = self.default_global::<FeatureFlags>();
- feature_flags.staff = staff;
- }
-
- fn has_flag<T: FeatureFlag>(&self) -> bool {
- if self.has_global::<FeatureFlags>() {
- self.global::<FeatureFlags>().has_flag(T::NAME)
- } else {
- false
- }
- }
-
- fn is_staff(&self) -> bool {
- if self.has_global::<FeatureFlags>() {
- return self.global::<FeatureFlags>().staff;
- } else {
- false
- }
- }
-}
@@ -14,5 +14,5 @@ test-support = []
smol.workspace = true
anyhow.workspace = true
log.workspace = true
-gpui = { path = "../gpui" }
+gpui = { path = "../gpui2", package = "gpui2" }
util = { path = "../util" }
@@ -1,13 +1,12 @@
-use std::path::Path;
-
use anyhow::{anyhow, Result};
use gpui::{actions, AsyncAppContext};
+use std::path::Path;
use util::ResultExt;
actions!(cli, [Install]);
pub async fn install_cli(cx: &AsyncAppContext) -> Result<()> {
- let cli_path = cx.platform().path_for_auxiliary_executable("cli")?;
+ let cli_path = cx.update(|cx| cx.path_for_auxiliary_executable("cli"))??;
let link_path = Path::new("/usr/local/bin/zed");
let bin_dir_path = link_path.parent().unwrap();
@@ -1,19 +0,0 @@
-[package]
-name = "install_cli2"
-version = "0.1.0"
-edition = "2021"
-publish = false
-
-[lib]
-path = "src/install_cli2.rs"
-
-[features]
-test-support = []
-
-[dependencies]
-smol.workspace = true
-anyhow.workspace = true
-log.workspace = true
-serde.workspace = true
-gpui = { package = "gpui2", path = "../gpui2" }
-util = { path = "../util" }
@@ -1,54 +0,0 @@
-use anyhow::{anyhow, Result};
-use gpui::{actions, AsyncAppContext};
-use std::path::Path;
-use util::ResultExt;
-
-actions!(cli, [Install]);
-
-pub async fn install_cli(cx: &AsyncAppContext) -> Result<()> {
- let cli_path = cx.update(|cx| cx.path_for_auxiliary_executable("cli"))??;
- let link_path = Path::new("/usr/local/bin/zed");
- let bin_dir_path = link_path.parent().unwrap();
-
- // Don't re-create symlink if it points to the same CLI binary.
- if smol::fs::read_link(link_path).await.ok().as_ref() == Some(&cli_path) {
- return Ok(());
- }
-
- // If the symlink is not there or is outdated, first try replacing it
- // without escalating.
- smol::fs::remove_file(link_path).await.log_err();
- if smol::fs::unix::symlink(&cli_path, link_path)
- .await
- .log_err()
- .is_some()
- {
- return Ok(());
- }
-
- // The symlink could not be created, so use osascript with admin privileges
- // to create it.
- let status = smol::process::Command::new("/usr/bin/osascript")
- .args([
- "-e",
- &format!(
- "do shell script \" \
- mkdir -p \'{}\' && \
- ln -sf \'{}\' \'{}\' \
- \" with administrator privileges",
- bin_dir_path.to_string_lossy(),
- cli_path.to_string_lossy(),
- link_path.to_string_lossy(),
- ),
- ])
- .stdout(smol::process::Stdio::inherit())
- .stderr(smol::process::Stdio::inherit())
- .output()
- .await?
- .status;
- if status.success() {
- Ok(())
- } else {
- Err(anyhow!("error running osascript"))
- }
-}
@@ -27,7 +27,7 @@ git = { package = "git3", path = "../git3" }
gpui = { package = "gpui2", path = "../gpui2" }
language = { package = "language2", path = "../language2" }
lsp = { package = "lsp2", path = "../lsp2" }
-rich_text = { package = "rich_text2", path = "../rich_text2" }
+rich_text = { path = "../rich_text" }
settings = { package = "settings2", path = "../settings2" }
snippet = { path = "../snippet" }
sum_tree = { path = "../sum_tree" }
@@ -22,7 +22,7 @@ client = { package = "client2", path = "../client2" }
clock = { path = "../clock" }
collections = { path = "../collections" }
db = { package = "db2", path = "../db2" }
-feature_flags = { package = "feature_flags2", path = "../feature_flags2" }
+feature_flags = { path = "../feature_flags" }
gpui = { package = "gpui2", path = "../gpui2" }
rpc = { package = "rpc2", path = "../rpc2" }
settings = { package = "settings2", path = "../settings2" }
@@ -14,13 +14,12 @@ test-support = [
"util/test-support",
]
-
[dependencies]
collections = { path = "../collections" }
-gpui = { path = "../gpui" }
+gpui = { package = "gpui2", path = "../gpui2" }
sum_tree = { path = "../sum_tree" }
-theme = { path = "../theme" }
-language = { path = "../language" }
+theme = { package = "theme2", path = "../theme2" }
+language = { package = "language2", path = "../language2" }
util = { path = "../util" }
anyhow.workspace = true
futures.workspace = true
@@ -1,19 +1,16 @@
-use std::{ops::Range, sync::Arc};
-
-use anyhow::bail;
use futures::FutureExt;
use gpui::{
- elements::Text,
- fonts::{HighlightStyle, Underline, Weight},
- platform::{CursorStyle, MouseButton},
- AnyElement, CursorRegion, Element, MouseRegion, ViewContext,
+ AnyElement, ElementId, FontStyle, FontWeight, HighlightStyle, InteractiveText, IntoElement,
+ SharedString, StyledText, UnderlineStyle, WindowContext,
};
use language::{HighlightId, Language, LanguageRegistry};
-use theme::{RichTextStyle, SyntaxTheme};
+use std::{ops::Range, sync::Arc};
+use theme::ActiveTheme;
use util::RangeExt;
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Highlight {
+ Code,
Id(HighlightId),
Highlight(HighlightStyle),
Mention,
@@ -34,24 +31,10 @@ impl From<HighlightId> for Highlight {
#[derive(Debug, Clone)]
pub struct RichText {
- pub text: String,
+ pub text: SharedString,
pub highlights: Vec<(Range<usize>, Highlight)>,
- pub region_ranges: Vec<Range<usize>>,
- pub regions: Vec<RenderedRegion>,
-}
-
-#[derive(Clone, Copy, Debug, PartialEq, Eq)]
-pub enum BackgroundKind {
- Code,
- /// A mention background for non-self user.
- Mention,
- SelfMention,
-}
-
-#[derive(Debug, Clone, PartialEq, Eq)]
-pub struct RenderedRegion {
- pub background_kind: Option<BackgroundKind>,
- pub link_url: Option<String>,
+ pub link_ranges: Vec<Range<usize>>,
+ pub link_urls: Arc<[String]>,
}
/// Allows one to specify extra links to the rendered markdown, which can be used
@@ -62,92 +45,71 @@ pub struct Mention {
}
impl RichText {
- pub fn element<V: 'static>(
- &self,
- syntax: Arc<SyntaxTheme>,
- style: RichTextStyle,
- cx: &mut ViewContext<V>,
- ) -> AnyElement<V> {
- let mut region_id = 0;
- let view_id = cx.view_id();
-
- let regions = self.regions.clone();
+ pub fn element(&self, id: ElementId, cx: &mut WindowContext) -> AnyElement {
+ let theme = cx.theme();
+ let code_background = theme.colors().surface_background;
- enum Markdown {}
- Text::new(self.text.clone(), style.text.clone())
- .with_highlights(
- self.highlights
- .iter()
- .filter_map(|(range, highlight)| {
- let style = match highlight {
- Highlight::Id(id) => id.style(&syntax)?,
- Highlight::Highlight(style) => style.clone(),
- Highlight::Mention => style.mention_highlight,
- Highlight::SelfMention => style.self_mention_highlight,
- };
- Some((range.clone(), style))
- })
- .collect::<Vec<_>>(),
- )
- .with_custom_runs(self.region_ranges.clone(), move |ix, bounds, cx| {
- region_id += 1;
- let region = regions[ix].clone();
- if let Some(url) = region.link_url {
- cx.scene().push_cursor_region(CursorRegion {
- bounds,
- style: CursorStyle::PointingHand,
- });
- cx.scene().push_mouse_region(
- MouseRegion::new::<Markdown>(view_id, region_id, bounds)
- .on_click::<V, _>(MouseButton::Left, move |_, _, cx| {
- cx.platform().open_url(&url)
- }),
- );
- }
- if let Some(region_kind) = ®ion.background_kind {
- let background = match region_kind {
- BackgroundKind::Code => style.code_background,
- BackgroundKind::Mention => style.mention_background,
- BackgroundKind::SelfMention => style.self_mention_background,
- };
- if background.is_some() {
- cx.scene().push_quad(gpui::Quad {
- bounds,
- background,
- border: Default::default(),
- corner_radii: (2.0).into(),
- });
- }
- }
- })
- .with_soft_wrap(true)
- .into_any()
+ InteractiveText::new(
+ id,
+ StyledText::new(self.text.clone()).with_highlights(
+ &cx.text_style(),
+ self.highlights.iter().map(|(range, highlight)| {
+ (
+ range.clone(),
+ match highlight {
+ Highlight::Code => HighlightStyle {
+ background_color: Some(code_background),
+ ..Default::default()
+ },
+ Highlight::Id(id) => HighlightStyle {
+ background_color: Some(code_background),
+ ..id.style(&theme.syntax()).unwrap_or_default()
+ },
+ Highlight::Highlight(highlight) => *highlight,
+ Highlight::Mention => HighlightStyle {
+ font_weight: Some(FontWeight::BOLD),
+ ..Default::default()
+ },
+ Highlight::SelfMention => HighlightStyle {
+ font_weight: Some(FontWeight::BOLD),
+ ..Default::default()
+ },
+ },
+ )
+ }),
+ ),
+ )
+ .on_click(self.link_ranges.clone(), {
+ let link_urls = self.link_urls.clone();
+ move |ix, cx| cx.open_url(&link_urls[ix])
+ })
+ .into_any_element()
}
- pub fn add_mention(
- &mut self,
- range: Range<usize>,
- is_current_user: bool,
- mention_style: HighlightStyle,
- ) -> anyhow::Result<()> {
- if range.end > self.text.len() {
- bail!(
- "Mention in range {range:?} is outside of bounds for a message of length {}",
- self.text.len()
- );
- }
+ // pub fn add_mention(
+ // &mut self,
+ // range: Range<usize>,
+ // is_current_user: bool,
+ // mention_style: HighlightStyle,
+ // ) -> anyhow::Result<()> {
+ // if range.end > self.text.len() {
+ // bail!(
+ // "Mention in range {range:?} is outside of bounds for a message of length {}",
+ // self.text.len()
+ // );
+ // }
- if is_current_user {
- self.region_ranges.push(range.clone());
- self.regions.push(RenderedRegion {
- background_kind: Some(BackgroundKind::Mention),
- link_url: None,
- });
- }
- self.highlights
- .push((range, Highlight::Highlight(mention_style)));
- Ok(())
- }
+ // if is_current_user {
+ // self.region_ranges.push(range.clone());
+ // self.regions.push(RenderedRegion {
+ // background_kind: Some(BackgroundKind::Mention),
+ // link_url: None,
+ // });
+ // }
+ // self.highlights
+ // .push((range, Highlight::Highlight(mention_style)));
+ // Ok(())
+ // }
}
pub fn render_markdown_mut(
@@ -155,7 +117,10 @@ pub fn render_markdown_mut(
mut mentions: &[Mention],
language_registry: &Arc<LanguageRegistry>,
language: Option<&Arc<Language>>,
- data: &mut RichText,
+ text: &mut String,
+ highlights: &mut Vec<(Range<usize>, Highlight)>,
+ link_ranges: &mut Vec<Range<usize>>,
+ link_urls: &mut Vec<String>,
) {
use pulldown_cmark::{CodeBlockKind, Event, Options, Parser, Tag};
@@ -167,18 +132,18 @@ pub fn render_markdown_mut(
let options = Options::all();
for (event, source_range) in Parser::new_ext(&block, options).into_offset_iter() {
- let prev_len = data.text.len();
+ let prev_len = text.len();
match event {
Event::Text(t) => {
if let Some(language) = ¤t_language {
- render_code(&mut data.text, &mut data.highlights, t.as_ref(), language);
+ render_code(text, highlights, t.as_ref(), language);
} else {
if let Some(mention) = mentions.first() {
if source_range.contains_inclusive(&mention.range) {
mentions = &mentions[1..];
let range = (prev_len + mention.range.start - source_range.start)
..(prev_len + mention.range.end - source_range.start);
- data.highlights.push((
+ highlights.push((
range.clone(),
if mention.is_self_mention {
Highlight::SelfMention
@@ -186,33 +151,21 @@ pub fn render_markdown_mut(
Highlight::Mention
},
));
- data.region_ranges.push(range);
- data.regions.push(RenderedRegion {
- background_kind: Some(if mention.is_self_mention {
- BackgroundKind::SelfMention
- } else {
- BackgroundKind::Mention
- }),
- link_url: None,
- });
}
}
- data.text.push_str(t.as_ref());
+ text.push_str(t.as_ref());
let mut style = HighlightStyle::default();
if bold_depth > 0 {
- style.weight = Some(Weight::BOLD);
+ style.font_weight = Some(FontWeight::BOLD);
}
if italic_depth > 0 {
- style.italic = Some(true);
+ style.font_style = Some(FontStyle::Italic);
}
if let Some(link_url) = link_url.clone() {
- data.region_ranges.push(prev_len..data.text.len());
- data.regions.push(RenderedRegion {
- link_url: Some(link_url),
- background_kind: None,
- });
- style.underline = Some(Underline {
+ link_ranges.push(prev_len..text.len());
+ link_urls.push(link_url);
+ style.underline = Some(UnderlineStyle {
thickness: 1.0.into(),
..Default::default()
});
@@ -220,29 +173,27 @@ pub fn render_markdown_mut(
if style != HighlightStyle::default() {
let mut new_highlight = true;
- if let Some((last_range, last_style)) = data.highlights.last_mut() {
+ if let Some((last_range, last_style)) = highlights.last_mut() {
if last_range.end == prev_len
&& last_style == &Highlight::Highlight(style)
{
- last_range.end = data.text.len();
+ last_range.end = text.len();
new_highlight = false;
}
}
if new_highlight {
- data.highlights
- .push((prev_len..data.text.len(), Highlight::Highlight(style)));
+ highlights.push((prev_len..text.len(), Highlight::Highlight(style)));
}
}
}
}
Event::Code(t) => {
- data.text.push_str(t.as_ref());
- data.region_ranges.push(prev_len..data.text.len());
+ text.push_str(t.as_ref());
if link_url.is_some() {
- data.highlights.push((
- prev_len..data.text.len(),
+ highlights.push((
+ prev_len..text.len(),
Highlight::Highlight(HighlightStyle {
- underline: Some(Underline {
+ underline: Some(UnderlineStyle {
thickness: 1.0.into(),
..Default::default()
}),
@@ -250,19 +201,19 @@ pub fn render_markdown_mut(
}),
));
}
- data.regions.push(RenderedRegion {
- background_kind: Some(BackgroundKind::Code),
- link_url: link_url.clone(),
- });
+ if let Some(link_url) = link_url.clone() {
+ link_ranges.push(prev_len..text.len());
+ link_urls.push(link_url);
+ }
}
Event::Start(tag) => match tag {
- Tag::Paragraph => new_paragraph(&mut data.text, &mut list_stack),
+ Tag::Paragraph => new_paragraph(text, &mut list_stack),
Tag::Heading(_, _, _) => {
- new_paragraph(&mut data.text, &mut list_stack);
+ new_paragraph(text, &mut list_stack);
bold_depth += 1;
}
Tag::CodeBlock(kind) => {
- new_paragraph(&mut data.text, &mut list_stack);
+ new_paragraph(text, &mut list_stack);
current_language = if let CodeBlockKind::Fenced(language) = kind {
language_registry
.language_for_name(language.as_ref())
@@ -282,18 +233,18 @@ pub fn render_markdown_mut(
let len = list_stack.len();
if let Some((list_number, has_content)) = list_stack.last_mut() {
*has_content = false;
- if !data.text.is_empty() && !data.text.ends_with('\n') {
- data.text.push('\n');
+ if !text.is_empty() && !text.ends_with('\n') {
+ text.push('\n');
}
for _ in 0..len - 1 {
- data.text.push_str(" ");
+ text.push_str(" ");
}
if let Some(number) = list_number {
- data.text.push_str(&format!("{}. ", number));
+ text.push_str(&format!("{}. ", number));
*number += 1;
*has_content = false;
} else {
- data.text.push_str("- ");
+ text.push_str("- ");
}
}
}
@@ -308,8 +259,8 @@ pub fn render_markdown_mut(
Tag::List(_) => drop(list_stack.pop()),
_ => {}
},
- Event::HardBreak => data.text.push('\n'),
- Event::SoftBreak => data.text.push(' '),
+ Event::HardBreak => text.push('\n'),
+ Event::SoftBreak => text.push(' '),
_ => {}
}
}
@@ -321,18 +272,35 @@ pub fn render_markdown(
language_registry: &Arc<LanguageRegistry>,
language: Option<&Arc<Language>>,
) -> RichText {
- let mut data = RichText {
- text: Default::default(),
- highlights: Default::default(),
- region_ranges: Default::default(),
- regions: Default::default(),
- };
+ // let mut data = RichText {
+ // text: Default::default(),
+ // highlights: Default::default(),
+ // region_ranges: Default::default(),
+ // regions: Default::default(),
+ // };
- render_markdown_mut(&block, mentions, language_registry, language, &mut data);
+ let mut text = String::new();
+ let mut highlights = Vec::new();
+ let mut link_ranges = Vec::new();
+ let mut link_urls = Vec::new();
+ render_markdown_mut(
+ &block,
+ mentions,
+ language_registry,
+ language,
+ &mut text,
+ &mut highlights,
+ &mut link_ranges,
+ &mut link_urls,
+ );
+ text.truncate(text.trim_end().len());
- data.text = data.text.trim().to_string();
-
- data
+ RichText {
+ text: SharedString::from(text),
+ link_urls: link_urls.into(),
+ link_ranges,
+ highlights,
+ }
}
pub fn render_code(
@@ -343,11 +311,19 @@ pub fn render_code(
) {
let prev_len = text.len();
text.push_str(content);
+ let mut offset = 0;
for (range, highlight_id) in language.highlight_text(&content.into(), 0..content.len()) {
+ if range.start > offset {
+ highlights.push((prev_len + offset..prev_len + range.start, Highlight::Code));
+ }
highlights.push((
prev_len + range.start..prev_len + range.end,
Highlight::Id(highlight_id),
));
+ offset = range.end;
+ }
+ if offset < content.len() {
+ highlights.push((prev_len + offset..prev_len + content.len(), Highlight::Code));
}
}
@@ -1,29 +0,0 @@
-[package]
-name = "rich_text2"
-version = "0.1.0"
-edition = "2021"
-publish = false
-
-[lib]
-path = "src/rich_text.rs"
-doctest = false
-
-[features]
-test-support = [
- "gpui/test-support",
- "util/test-support",
-]
-
-[dependencies]
-collections = { path = "../collections" }
-gpui = { package = "gpui2", path = "../gpui2" }
-sum_tree = { path = "../sum_tree" }
-theme = { package = "theme2", path = "../theme2" }
-language = { package = "language2", path = "../language2" }
-util = { path = "../util" }
-anyhow.workspace = true
-futures.workspace = true
-lazy_static.workspace = true
-pulldown-cmark = { version = "0.9.2", default-features = false }
-smallvec.workspace = true
-smol.workspace = true
@@ -1,353 +0,0 @@
-use futures::FutureExt;
-use gpui::{
- AnyElement, ElementId, FontStyle, FontWeight, HighlightStyle, InteractiveText, IntoElement,
- SharedString, StyledText, UnderlineStyle, WindowContext,
-};
-use language::{HighlightId, Language, LanguageRegistry};
-use std::{ops::Range, sync::Arc};
-use theme::ActiveTheme;
-use util::RangeExt;
-
-#[derive(Debug, Clone, PartialEq, Eq)]
-pub enum Highlight {
- Code,
- Id(HighlightId),
- Highlight(HighlightStyle),
- Mention,
- SelfMention,
-}
-
-impl From<HighlightStyle> for Highlight {
- fn from(style: HighlightStyle) -> Self {
- Self::Highlight(style)
- }
-}
-
-impl From<HighlightId> for Highlight {
- fn from(style: HighlightId) -> Self {
- Self::Id(style)
- }
-}
-
-#[derive(Debug, Clone)]
-pub struct RichText {
- pub text: SharedString,
- pub highlights: Vec<(Range<usize>, Highlight)>,
- pub link_ranges: Vec<Range<usize>>,
- pub link_urls: Arc<[String]>,
-}
-
-/// Allows one to specify extra links to the rendered markdown, which can be used
-/// for e.g. mentions.
-pub struct Mention {
- pub range: Range<usize>,
- pub is_self_mention: bool,
-}
-
-impl RichText {
- pub fn element(&self, id: ElementId, cx: &mut WindowContext) -> AnyElement {
- let theme = cx.theme();
- let code_background = theme.colors().surface_background;
-
- InteractiveText::new(
- id,
- StyledText::new(self.text.clone()).with_highlights(
- &cx.text_style(),
- self.highlights.iter().map(|(range, highlight)| {
- (
- range.clone(),
- match highlight {
- Highlight::Code => HighlightStyle {
- background_color: Some(code_background),
- ..Default::default()
- },
- Highlight::Id(id) => HighlightStyle {
- background_color: Some(code_background),
- ..id.style(&theme.syntax()).unwrap_or_default()
- },
- Highlight::Highlight(highlight) => *highlight,
- Highlight::Mention => HighlightStyle {
- font_weight: Some(FontWeight::BOLD),
- ..Default::default()
- },
- Highlight::SelfMention => HighlightStyle {
- font_weight: Some(FontWeight::BOLD),
- ..Default::default()
- },
- },
- )
- }),
- ),
- )
- .on_click(self.link_ranges.clone(), {
- let link_urls = self.link_urls.clone();
- move |ix, cx| cx.open_url(&link_urls[ix])
- })
- .into_any_element()
- }
-
- // pub fn add_mention(
- // &mut self,
- // range: Range<usize>,
- // is_current_user: bool,
- // mention_style: HighlightStyle,
- // ) -> anyhow::Result<()> {
- // if range.end > self.text.len() {
- // bail!(
- // "Mention in range {range:?} is outside of bounds for a message of length {}",
- // self.text.len()
- // );
- // }
-
- // if is_current_user {
- // self.region_ranges.push(range.clone());
- // self.regions.push(RenderedRegion {
- // background_kind: Some(BackgroundKind::Mention),
- // link_url: None,
- // });
- // }
- // self.highlights
- // .push((range, Highlight::Highlight(mention_style)));
- // Ok(())
- // }
-}
-
-pub fn render_markdown_mut(
- block: &str,
- mut mentions: &[Mention],
- language_registry: &Arc<LanguageRegistry>,
- language: Option<&Arc<Language>>,
- text: &mut String,
- highlights: &mut Vec<(Range<usize>, Highlight)>,
- link_ranges: &mut Vec<Range<usize>>,
- link_urls: &mut Vec<String>,
-) {
- use pulldown_cmark::{CodeBlockKind, Event, Options, Parser, Tag};
-
- let mut bold_depth = 0;
- let mut italic_depth = 0;
- let mut link_url = None;
- let mut current_language = None;
- let mut list_stack = Vec::new();
-
- let options = Options::all();
- for (event, source_range) in Parser::new_ext(&block, options).into_offset_iter() {
- let prev_len = text.len();
- match event {
- Event::Text(t) => {
- if let Some(language) = ¤t_language {
- render_code(text, highlights, t.as_ref(), language);
- } else {
- if let Some(mention) = mentions.first() {
- if source_range.contains_inclusive(&mention.range) {
- mentions = &mentions[1..];
- let range = (prev_len + mention.range.start - source_range.start)
- ..(prev_len + mention.range.end - source_range.start);
- highlights.push((
- range.clone(),
- if mention.is_self_mention {
- Highlight::SelfMention
- } else {
- Highlight::Mention
- },
- ));
- }
- }
-
- text.push_str(t.as_ref());
- let mut style = HighlightStyle::default();
- if bold_depth > 0 {
- style.font_weight = Some(FontWeight::BOLD);
- }
- if italic_depth > 0 {
- style.font_style = Some(FontStyle::Italic);
- }
- if let Some(link_url) = link_url.clone() {
- link_ranges.push(prev_len..text.len());
- link_urls.push(link_url);
- style.underline = Some(UnderlineStyle {
- thickness: 1.0.into(),
- ..Default::default()
- });
- }
-
- if style != HighlightStyle::default() {
- let mut new_highlight = true;
- if let Some((last_range, last_style)) = highlights.last_mut() {
- if last_range.end == prev_len
- && last_style == &Highlight::Highlight(style)
- {
- last_range.end = text.len();
- new_highlight = false;
- }
- }
- if new_highlight {
- highlights.push((prev_len..text.len(), Highlight::Highlight(style)));
- }
- }
- }
- }
- Event::Code(t) => {
- text.push_str(t.as_ref());
- if link_url.is_some() {
- highlights.push((
- prev_len..text.len(),
- Highlight::Highlight(HighlightStyle {
- underline: Some(UnderlineStyle {
- thickness: 1.0.into(),
- ..Default::default()
- }),
- ..Default::default()
- }),
- ));
- }
- if let Some(link_url) = link_url.clone() {
- link_ranges.push(prev_len..text.len());
- link_urls.push(link_url);
- }
- }
- Event::Start(tag) => match tag {
- Tag::Paragraph => new_paragraph(text, &mut list_stack),
- Tag::Heading(_, _, _) => {
- new_paragraph(text, &mut list_stack);
- bold_depth += 1;
- }
- Tag::CodeBlock(kind) => {
- new_paragraph(text, &mut list_stack);
- current_language = if let CodeBlockKind::Fenced(language) = kind {
- language_registry
- .language_for_name(language.as_ref())
- .now_or_never()
- .and_then(Result::ok)
- } else {
- language.cloned()
- }
- }
- Tag::Emphasis => italic_depth += 1,
- Tag::Strong => bold_depth += 1,
- Tag::Link(_, url, _) => link_url = Some(url.to_string()),
- Tag::List(number) => {
- list_stack.push((number, false));
- }
- Tag::Item => {
- let len = list_stack.len();
- if let Some((list_number, has_content)) = list_stack.last_mut() {
- *has_content = false;
- if !text.is_empty() && !text.ends_with('\n') {
- text.push('\n');
- }
- for _ in 0..len - 1 {
- text.push_str(" ");
- }
- if let Some(number) = list_number {
- text.push_str(&format!("{}. ", number));
- *number += 1;
- *has_content = false;
- } else {
- text.push_str("- ");
- }
- }
- }
- _ => {}
- },
- Event::End(tag) => match tag {
- Tag::Heading(_, _, _) => bold_depth -= 1,
- Tag::CodeBlock(_) => current_language = None,
- Tag::Emphasis => italic_depth -= 1,
- Tag::Strong => bold_depth -= 1,
- Tag::Link(_, _, _) => link_url = None,
- Tag::List(_) => drop(list_stack.pop()),
- _ => {}
- },
- Event::HardBreak => text.push('\n'),
- Event::SoftBreak => text.push(' '),
- _ => {}
- }
- }
-}
-
-pub fn render_markdown(
- block: String,
- mentions: &[Mention],
- language_registry: &Arc<LanguageRegistry>,
- language: Option<&Arc<Language>>,
-) -> RichText {
- // let mut data = RichText {
- // text: Default::default(),
- // highlights: Default::default(),
- // region_ranges: Default::default(),
- // regions: Default::default(),
- // };
-
- let mut text = String::new();
- let mut highlights = Vec::new();
- let mut link_ranges = Vec::new();
- let mut link_urls = Vec::new();
- render_markdown_mut(
- &block,
- mentions,
- language_registry,
- language,
- &mut text,
- &mut highlights,
- &mut link_ranges,
- &mut link_urls,
- );
- text.truncate(text.trim_end().len());
-
- RichText {
- text: SharedString::from(text),
- link_urls: link_urls.into(),
- link_ranges,
- highlights,
- }
-}
-
-pub fn render_code(
- text: &mut String,
- highlights: &mut Vec<(Range<usize>, Highlight)>,
- content: &str,
- language: &Arc<Language>,
-) {
- let prev_len = text.len();
- text.push_str(content);
- let mut offset = 0;
- for (range, highlight_id) in language.highlight_text(&content.into(), 0..content.len()) {
- if range.start > offset {
- highlights.push((prev_len + offset..prev_len + range.start, Highlight::Code));
- }
- highlights.push((
- prev_len + range.start..prev_len + range.end,
- Highlight::Id(highlight_id),
- ));
- offset = range.end;
- }
- if offset < content.len() {
- highlights.push((prev_len + offset..prev_len + content.len(), Highlight::Code));
- }
-}
-
-pub fn new_paragraph(text: &mut String, list_stack: &mut Vec<(Option<u64>, bool)>) {
- let mut is_subsequent_paragraph_of_list = false;
- if let Some((_, has_content)) = list_stack.last_mut() {
- if *has_content {
- is_subsequent_paragraph_of_list = true;
- } else {
- *has_content = true;
- return;
- }
- }
-
- if !text.is_empty() {
- if !text.ends_with('\n') {
- text.push('\n');
- }
- text.push('\n');
- }
- for _ in 0..list_stack.len().saturating_sub(1) {
- text.push_str(" ");
- }
- if is_subsequent_paragraph_of_list {
- text.push_str(" ");
- }
-}
@@ -9,7 +9,7 @@ path = "src/semantic_index.rs"
doctest = false
[dependencies]
-ai = { package = "ai2", path = "../ai2" }
+ai = { path = "../ai" }
collections = { path = "../collections" }
gpui = { package = "gpui2", path = "../gpui2" }
language = { package = "language2", path = "../language2" }
@@ -39,7 +39,7 @@ sha1 = "0.10.5"
ndarray = { version = "0.15.0" }
[dev-dependencies]
-ai = { package = "ai2", path = "../ai2", features = ["test-support"] }
+ai = { path = "../ai", features = ["test-support"] }
collections = { path = "../collections", features = ["test-support"] }
gpui = { package = "gpui2", path = "../gpui2", features = ["test-support"] }
language = { package = "language2", path = "../language2", features = ["test-support"] }
@@ -16,7 +16,7 @@ collections = { path = "../collections" }
gpui = {package = "gpui2", path = "../gpui2" }
sqlez = { path = "../sqlez" }
fs = {package = "fs2", path = "../fs2" }
-feature_flags = {package = "feature_flags2", path = "../feature_flags2" }
+feature_flags = { path = "../feature_flags" }
util = { path = "../util" }
anyhow.workspace = true
@@ -11,7 +11,7 @@ doctest = false
[dependencies]
client = { package = "client2", path = "../client2" }
editor = { path = "../editor" }
-feature_flags = { package = "feature_flags2", path = "../feature_flags2" }
+feature_flags = { path = "../feature_flags" }
fs = { package = "fs2", path = "../fs2" }
fuzzy = { package = "fuzzy2", path = "../fuzzy2" }
gpui = { package = "gpui2", path = "../gpui2" }
@@ -35,7 +35,7 @@ workspace = { path = "../workspace" }
theme = { package = "theme2", path = "../theme2" }
ui = { package = "ui2", path = "../ui2"}
diagnostics = { path = "../diagnostics" }
-zed_actions = { package = "zed_actions2", path = "../zed_actions2" }
+zed_actions = { path = "../zed_actions" }
[dev-dependencies]
indoc.workspace = true
@@ -18,7 +18,7 @@ fuzzy = { package = "fuzzy2", path = "../fuzzy2" }
gpui = { package = "gpui2", path = "../gpui2" }
ui = { package = "ui2", path = "../ui2" }
db = { package = "db2", path = "../db2" }
-install_cli = { package = "install_cli2", path = "../install_cli2" }
+install_cli = { path = "../install_cli" }
project = { package = "project2", path = "../project2" }
settings = { package = "settings2", path = "../settings2" }
theme = { package = "theme2", path = "../theme2" }
@@ -26,7 +26,7 @@ collections = { path = "../collections" }
# context_menu = { path = "../context_menu" }
fs = { path = "../fs2", package = "fs2" }
gpui = { package = "gpui2", path = "../gpui2" }
-install_cli = { path = "../install_cli2", package = "install_cli2" }
+install_cli = { path = "../install_cli" }
language = { path = "../language2", package = "language2" }
#menu = { path = "../menu" }
node_runtime = { path = "../node_runtime" }
@@ -1,11 +0,0 @@
-[package]
-name = "zed-actions"
-version = "0.1.0"
-edition = "2021"
-publish = false
-
-# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
-
-[dependencies]
-gpui = { path = "../gpui" }
-serde.workspace = true
@@ -1,41 +0,0 @@
-use std::sync::Arc;
-
-use gpui::{actions, impl_actions};
-use serde::Deserialize;
-
-actions!(
- zed,
- [
- About,
- DebugElements,
- DecreaseBufferFontSize,
- Hide,
- HideOthers,
- IncreaseBufferFontSize,
- Minimize,
- OpenDefaultKeymap,
- OpenDefaultSettings,
- OpenKeymap,
- OpenLicenses,
- OpenLocalSettings,
- OpenLog,
- OpenSettings,
- OpenTelemetryLog,
- Quit,
- ResetBufferFontSize,
- ResetDatabase,
- ShowAll,
- ToggleFullScreen,
- Zoom,
- ]
-);
-
-#[derive(Deserialize, Clone, PartialEq)]
-pub struct OpenBrowser {
- pub url: Arc<str>,
-}
-#[derive(Deserialize, Clone, PartialEq)]
-pub struct OpenZedURL {
- pub url: String,
-}
-impl_actions!(zed, [OpenBrowser, OpenZedURL]);
@@ -15,7 +15,7 @@ name = "zed"
path = "src/main.rs"
[dependencies]
-ai = { package = "ai2", path = "../ai2"}
+ai = { path = "../ai"}
audio = { package = "audio2", path = "../audio2" }
activity_indicator = { path = "../activity_indicator"}
auto_update = { path = "../auto_update" }
@@ -41,7 +41,7 @@ fs = { package = "fs2", path = "../fs2" }
fsevent = { path = "../fsevent" }
go_to_line = { path = "../go_to_line" }
gpui = { package = "gpui2", path = "../gpui2" }
-install_cli = { package = "install_cli2", path = "../install_cli2" }
+install_cli = { path = "../install_cli" }
journal = { path = "../journal" }
language = { package = "language2", path = "../language2" }
language_selector = { path = "../language_selector" }
@@ -61,7 +61,7 @@ recent_projects = { path = "../recent_projects" }
rope = { package = "rope2", path = "../rope2"}
rpc = { package = "rpc2", path = "../rpc2" }
settings = { package = "settings2", path = "../settings2" }
-feature_flags = { package = "feature_flags2", path = "../feature_flags2" }
+feature_flags = { path = "../feature_flags" }
sum_tree = { path = "../sum_tree" }
shellexpand = "2.1.0"
text = { package = "text2", path = "../text2" }
@@ -73,7 +73,7 @@ semantic_index = { package = "semantic_index2", path = "../semantic_index2" }
vim = { path = "../vim" }
workspace = { path = "../workspace" }
welcome = { path = "../welcome" }
-zed_actions = {package = "zed_actions2", path = "../zed_actions2"}
+zed_actions = {path = "../zed_actions"}
anyhow.workspace = true
async-compression.workspace = true
async-tar = "0.4.2"
@@ -1,5 +1,5 @@
[package]
-name = "zed_actions2"
+name = "zed_actions"
version = "0.1.0"
edition = "2021"
publish = false