From 0b57ab730332dbf0033d652b4b531b2898c88039 Mon Sep 17 00:00:00 2001
From: KCaverly
Date: Sun, 22 Oct 2023 13:34:22 +0200
Subject: [PATCH 01/25] cleaned up truncate vs truncate start

---
 crates/ai/src/models.rs                 | 37 ++++++++++++++-----------
 crates/ai/src/templates/base.rs         | 33 ++++++++++++++--------
 crates/ai/src/templates/file_context.rs | 12 +++++---
 crates/ai/src/templates/generate.rs     |  6 +++-
 4 files changed, 56 insertions(+), 32 deletions(-)

diff --git a/crates/ai/src/models.rs b/crates/ai/src/models.rs
index d0206cc41c526f171fef8521a120f8f4ff70aa74..afb8496156f6521eb3125b6a0ba6d703d5d0fe50 100644
--- a/crates/ai/src/models.rs
+++ b/crates/ai/src/models.rs
@@ -2,11 +2,20 @@ use anyhow::anyhow;
 use tiktoken_rs::CoreBPE;
 use util::ResultExt;
 
+pub enum TruncationDirection {
+    Start,
+    End,
+}
+
 pub trait LanguageModel {
     fn name(&self) -> String;
     fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
-    fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String>;
-    fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String>;
+    fn truncate(
+        &self,
+        content: &str,
+        length: usize,
+        direction: TruncationDirection,
+    ) -> anyhow::Result<String>;
     fn capacity(&self) -> anyhow::Result<usize>;
 }
 
@@ -36,23 +45,19 @@ impl LanguageModel for OpenAILanguageModel {
             Err(anyhow!("bpe for open ai model was not retrieved"))
         }
     }
-    fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
-        if let Some(bpe) = &self.bpe {
-            let tokens = bpe.encode_with_special_tokens(content);
-            if tokens.len() > length {
-                bpe.decode(tokens[..length].to_vec())
-            } else {
-                bpe.decode(tokens)
-            }
-        } else {
-            Err(anyhow!("bpe for open ai model was not retrieved"))
-        }
-    }
-    fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
+    fn truncate(
+        &self,
+        content: &str,
+        length: usize,
+        direction: TruncationDirection,
+    ) -> anyhow::Result<String> {
         if let Some(bpe) = &self.bpe {
             let tokens = bpe.encode_with_special_tokens(content);
             if tokens.len() > length {
-                bpe.decode(tokens[length..].to_vec())
+                match direction {
+                    TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
+                    TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
+                }
             } else {
                 bpe.decode(tokens)
             }
diff --git a/crates/ai/src/templates/base.rs b/crates/ai/src/templates/base.rs
index bda1d6c30e61a9e2fd3808fa45a34cbe041cf2b6..e5ac414bc1691b02090361aa19bd0c56ee1557f5 100644
--- a/crates/ai/src/templates/base.rs
+++ b/crates/ai/src/templates/base.rs
@@ -125,6 +125,8 @@ impl PromptChain {
 
 #[cfg(test)]
 pub(crate) mod tests {
+    use crate::models::TruncationDirection;
+
     use super::*;
 
     #[test]
@@ -141,7 +143,11 @@ pub(crate) mod tests {
             let mut token_count = args.model.count_tokens(&content)?;
             if let Some(max_token_length) = max_token_length {
                 if token_count > max_token_length {
-                    content = args.model.truncate(&content, max_token_length)?;
+                    content = args.model.truncate(
+                        &content,
+                        max_token_length,
+                        TruncationDirection::Start,
+                    )?;
                     token_count = max_token_length;
                 }
             }
@@ -162,7 +168,11 @@ pub(crate) mod tests {
             let mut token_count = args.model.count_tokens(&content)?;
             if let Some(max_token_length) = max_token_length {
                 if token_count > max_token_length {
-                    content = args.model.truncate(&content, max_token_length)?;
+                    content = args.model.truncate(
+                        &content,
+                        max_token_length,
+                        TruncationDirection::Start,
+                    )?;
                     token_count = max_token_length;
                 }
             }
@@ -183,19 +193,20 @@ pub(crate) mod tests {
         fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
             anyhow::Ok(content.chars().collect::<Vec<char>>().len())
         }
-        fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
-            anyhow::Ok(
-                content.chars().collect::<Vec<char>>()[..length]
+        fn truncate(
+            &self,
+            content: &str,
+            length: usize,
+            direction: TruncationDirection,
+        ) -> anyhow::Result<String> {
+            anyhow::Ok(match direction {
+                TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
                     .into_iter()
                     .collect::<String>(),
-            )
-        }
-        fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
-            anyhow::Ok(
-                content.chars().collect::<Vec<char>>()[length..]
+                TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
                     .into_iter()
                     .collect::<String>(),
-            )
+            })
         }
         fn capacity(&self) -> anyhow::Result<usize> {
             anyhow::Ok(self.capacity)
diff --git a/crates/ai/src/templates/file_context.rs b/crates/ai/src/templates/file_context.rs
index 1afd61192edc02b153abe8cd00836d67caa42f02..1517134abb97c05866c007c7072175bc2f7f6aca 100644
--- a/crates/ai/src/templates/file_context.rs
+++ b/crates/ai/src/templates/file_context.rs
@@ -3,6 +3,7 @@ use language::BufferSnapshot;
 use language::ToOffset;
 
 use crate::models::LanguageModel;
+use crate::models::TruncationDirection;
 use crate::templates::base::PromptArguments;
 use crate::templates::base::PromptTemplate;
 use std::fmt::Write;
@@ -70,8 +71,9 @@ fn retrieve_context(
         };
 
         let truncated_start_window =
-            model.truncate_start(&start_window, start_goal_tokens)?;
-        let truncated_end_window = model.truncate(&end_window, end_goal_tokens)?;
+            model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?;
+        let truncated_end_window =
+            model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?;
         writeln!(
             prompt,
             "{truncated_start_window}{selected_window}{truncated_end_window}"
@@ -89,7 +91,7 @@
     if let Some(max_token_count) = max_token_count {
        if model.count_tokens(&prompt)? > max_token_count {
             truncated = true;
-            prompt = model.truncate(&prompt, max_token_count)?;
+            prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?;
         }
     }
 }
@@ -148,7 +150,9 @@ impl PromptTemplate for FileContext {
 
         // Really dumb truncation strategy
         if let Some(max_tokens) = max_token_length {
-            prompt = args.model.truncate(&prompt, max_tokens)?;
+            prompt = args
+                .model
+                .truncate(&prompt, max_tokens, TruncationDirection::End)?;
         }
 
         let token_count = args.model.count_tokens(&prompt)?;
diff --git a/crates/ai/src/templates/generate.rs b/crates/ai/src/templates/generate.rs
index 1eeb197f932db0dc13963982e7e8bc983c338db7..c9541c6b44a076fdd87b491669b34616fec04e24 100644
--- a/crates/ai/src/templates/generate.rs
+++ b/crates/ai/src/templates/generate.rs
@@ -85,7 +85,11 @@ impl PromptTemplate for GenerateInlineContent {
 
         // Really dumb truncation strategy
         if let Some(max_tokens) = max_token_length {
-            prompt = args.model.truncate(&prompt, max_tokens)?;
+            prompt = args.model.truncate(
+                &prompt,
+                max_tokens,
+                crate::models::TruncationDirection::End,
+            )?;
         }
 
         let token_count = args.model.count_tokens(&prompt)?;

From a62baf34f2e3bef619ba57e557cb30baa6356b29 Mon Sep 17 00:00:00 2001
From: KCaverly
Date: Sun, 22 Oct 2023 13:46:49 +0200
Subject: [PATCH 02/25] rename templates to prompts in ai crate

---
 crates/ai/src/ai.rs                                   |  2 +-
 crates/ai/src/{templates => prompts}/base.rs          |  2 +-
 crates/ai/src/{templates => prompts}/file_context.rs  |  4 ++--
 crates/ai/src/{templates => prompts}/generate.rs      |  2 +-
 crates/ai/src/{templates => prompts}/mod.rs           |  0
 crates/ai/src/{templates => prompts}/preamble.rs      |  2 +-
 .../src/{templates => prompts}/repository_context.rs  |  2 +-
 crates/assistant/src/assistant_panel.rs               |  2 +-
 crates/assistant/src/prompts.rs                       | 10 +++++-----
 9 files changed, 13 insertions(+), 13 deletions(-)
 rename crates/ai/src/{templates => prompts}/base.rs (99%)
 rename crates/ai/src/{templates => prompts}/file_context.rs (98%)
 rename crates/ai/src/{templates => prompts}/generate.rs (97%)
 rename crates/ai/src/{templates => prompts}/mod.rs (100%)
 rename crates/ai/src/{templates => prompts}/preamble.rs (95%)
 rename crates/ai/src/{templates => prompts}/repository_context.rs (98%)

diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs
index f168c157934f6b70be775f7e17e9ba27ef9b3103..c0b78b74cf2a24dded65671303ccade5cc4261e3 100644
--- a/crates/ai/src/ai.rs
+++ b/crates/ai/src/ai.rs
@@ -1,4 +1,4 @@
 pub mod completion;
 pub mod embedding;
 pub mod models;
-pub mod templates;
+pub mod prompts;
diff --git a/crates/ai/src/templates/base.rs b/crates/ai/src/prompts/base.rs
similarity index 99%
rename from crates/ai/src/templates/base.rs
rename to crates/ai/src/prompts/base.rs
index e5ac414bc1691b02090361aa19bd0c56ee1557f5..f0ff597e635702329964fe18efc4b2300c4c26c6 100644
--- a/crates/ai/src/templates/base.rs
+++ b/crates/ai/src/prompts/base.rs
@@ -6,7 +6,7 @@ use language::BufferSnapshot;
 use util::ResultExt;
 
 use crate::models::LanguageModel;
-use crate::templates::repository_context::PromptCodeSnippet;
+use crate::prompts::repository_context::PromptCodeSnippet;
 
 pub(crate) enum PromptFileType {
     Text,
diff --git a/crates/ai/src/templates/file_context.rs b/crates/ai/src/prompts/file_context.rs
similarity index 98%
rename from crates/ai/src/templates/file_context.rs
rename to crates/ai/src/prompts/file_context.rs
index 1517134abb97c05866c007c7072175bc2f7f6aca..f108a62f6f0f82e1d25f01cb2e2ae2a755d69fda 100644
--- a/crates/ai/src/templates/file_context.rs
+++ b/crates/ai/src/prompts/file_context.rs
@@ -4,8 +4,8 @@
 use language::ToOffset;
 
 use crate::models::LanguageModel;
 use crate::models::TruncationDirection;
-use crate::templates::base::PromptArguments;
-use crate::templates::base::PromptTemplate;
+use crate::prompts::base::PromptArguments;
+use crate::prompts::base::PromptTemplate;
 use std::fmt::Write;
 use std::ops::Range;
 use std::sync::Arc;
diff --git a/crates/ai/src/templates/generate.rs b/crates/ai/src/prompts/generate.rs
similarity index 97%
rename from crates/ai/src/templates/generate.rs
rename to crates/ai/src/prompts/generate.rs
index c9541c6b44a076fdd87b491669b34616fec04e24..c7be620107ee4d6daca06a8cb38019aceedc40a4 100644
--- a/crates/ai/src/templates/generate.rs
+++ b/crates/ai/src/prompts/generate.rs
@@ -1,4 +1,4 @@
-use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate};
+use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
 use anyhow::anyhow;
 use std::fmt::Write;
 
diff --git a/crates/ai/src/templates/mod.rs b/crates/ai/src/prompts/mod.rs
similarity index 100%
rename from crates/ai/src/templates/mod.rs
rename to crates/ai/src/prompts/mod.rs
diff --git a/crates/ai/src/templates/preamble.rs b/crates/ai/src/prompts/preamble.rs
similarity index 95%
rename from crates/ai/src/templates/preamble.rs
rename to crates/ai/src/prompts/preamble.rs
index 9eabaaeb97fe4216c6bac44cf4eabfb7c129ecf2..92e0edeb78b48169379aae2e88e81f62463a1057 100644
--- a/crates/ai/src/templates/preamble.rs
+++ b/crates/ai/src/prompts/preamble.rs
@@ -1,4 +1,4 @@
-use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate};
+use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
 use std::fmt::Write;
 
 pub struct EngineerPreamble {}
diff --git a/crates/ai/src/templates/repository_context.rs b/crates/ai/src/prompts/repository_context.rs
similarity index 98%
rename from crates/ai/src/templates/repository_context.rs
rename to crates/ai/src/prompts/repository_context.rs
index a8e7f4b5af7bee4d3f29d70c665965dc7e05ed4b..c21b0f995c19361ed89243bf49696f3b0ba3c865 100644
--- a/crates/ai/src/templates/repository_context.rs
+++ b/crates/ai/src/prompts/repository_context.rs
@@ -1,4 +1,4 @@
-use crate::templates::base::{PromptArguments, PromptTemplate};
+use crate::prompts::base::{PromptArguments, PromptTemplate};
 use std::fmt::Write;
 use std::{ops::Range, path::PathBuf};
 
diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs
index ca8c54a285d70d3eaa9f1aee09437994708ebdfb..64eff04b8dfb9d66164851ab2513db48f68572f9 100644
--- a/crates/assistant/src/assistant_panel.rs
+++ b/crates/assistant/src/assistant_panel.rs
@@ -9,7 +9,7 @@ use ai::{
     completion::{
         stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
     },
-    templates::repository_context::PromptCodeSnippet,
+    prompts::repository_context::PromptCodeSnippet,
 };
 use anyhow::{anyhow, Result};
 use chrono::{DateTime, Local};
diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs
index dffcbc29234d3f24174d1d9a6610045105eae890..8fff232fdbf9358de9a97892bdd58d070871f64a 100644
--- a/crates/assistant/src/prompts.rs
+++ b/crates/assistant/src/prompts.rs
@@ -1,9 +1,9 @@
 use ai::models::{LanguageModel, OpenAILanguageModel};
-use ai::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
-use ai::templates::file_context::FileContext;
-use ai::templates::generate::GenerateInlineContent;
-use ai::templates::preamble::EngineerPreamble;
-use ai::templates::repository_context::{PromptCodeSnippet, RepositoryContext};
+use ai::prompts::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
+use ai::prompts::file_context::FileContext;
+use ai::prompts::generate::GenerateInlineContent;
+use ai::prompts::preamble::EngineerPreamble;
+use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext};
 use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
 use std::cmp::{self, Reverse};
 use std::ops::Range;

From 3712794e561b7aa6068ee3de0fc411d5cb311566 Mon Sep 17 00:00:00 2001
From: KCaverly
Date: Sun, 22 Oct 2023 13:47:28 +0200
Subject: [PATCH 03/25] move OpenAILanguageModel to providers folder

---
 crates/ai/src/ai.rs                      |  1 +
 crates/ai/src/models.rs                  | 55 -----------------------
 crates/ai/src/providers/mod.rs           |  1 +
 crates/ai/src/providers/open_ai/mod.rs   |  2 +
 crates/ai/src/providers/open_ai/model.rs | 56 ++++++++++++++++++++++++
 crates/assistant/src/prompts.rs          |  3 +-
 6 files changed, 62 insertions(+), 56 deletions(-)
 create mode 100644 crates/ai/src/providers/mod.rs
 create mode 100644 crates/ai/src/providers/open_ai/mod.rs
 create mode 100644 crates/ai/src/providers/open_ai/model.rs

diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs
index c0b78b74cf2a24dded65671303ccade5cc4261e3..a3ae2fcf7ffb5075c70b607f1cdf34279cd063a3 100644
--- a/crates/ai/src/ai.rs
+++ b/crates/ai/src/ai.rs
@@ -2,3 +2,4 @@ pub mod completion;
 pub mod embedding;
 pub mod models;
 pub mod prompts;
+pub mod providers;
diff --git a/crates/ai/src/models.rs b/crates/ai/src/models.rs
index afb8496156f6521eb3125b6a0ba6d703d5d0fe50..1db3d58c6f54ad613cb98fc3f425df3d47e5e97f 100644
--- a/crates/ai/src/models.rs
+++ b/crates/ai/src/models.rs
@@ -1,7 +1,3 @@
-use anyhow::anyhow;
-use tiktoken_rs::CoreBPE;
-use util::ResultExt;
-
 pub enum TruncationDirection {
     Start,
     End,
@@ -18,54 +14,3 @@ pub trait LanguageModel {
     ) -> anyhow::Result<String>;
     fn capacity(&self) -> anyhow::Result<usize>;
 }
-
-pub struct OpenAILanguageModel {
-    name: String,
-    bpe: Option<CoreBPE>,
-}
-
-impl OpenAILanguageModel {
-    pub fn load(model_name: &str) -> Self {
-        let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
-        OpenAILanguageModel {
-            name: model_name.to_string(),
-            bpe,
-        }
-    }
-}
-
-impl LanguageModel for OpenAILanguageModel {
-    fn name(&self) -> String {
-        self.name.clone()
-    }
-    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
-        if let Some(bpe) = &self.bpe {
-            anyhow::Ok(bpe.encode_with_special_tokens(content).len())
-        } else {
-            Err(anyhow!("bpe for open ai model was not retrieved"))
-        }
-    }
-    fn truncate(
-        &self,
-        content: &str,
-        length: usize,
-        direction: TruncationDirection,
-    ) -> anyhow::Result<String> {
-        if let Some(bpe) = &self.bpe {
-            let tokens = bpe.encode_with_special_tokens(content);
-            if tokens.len() > length {
-                match direction {
-                    TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
-                    TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
-                }
-            } else {
-                bpe.decode(tokens)
-            }
-        } else {
-            Err(anyhow!("bpe for open ai model was not retrieved"))
-        }
-    }
-    fn capacity(&self) -> anyhow::Result<usize> {
-        anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
-    }
-}
diff --git a/crates/ai/src/providers/mod.rs b/crates/ai/src/providers/mod.rs
new file mode 100644
index 0000000000000000000000000000000000000000..acd0f9d91053869e3999ef0c1a23326480a7cbdd
--- /dev/null
+++ b/crates/ai/src/providers/mod.rs
@@ -0,0 +1 @@
+pub mod open_ai;
diff --git a/crates/ai/src/providers/open_ai/mod.rs b/crates/ai/src/providers/open_ai/mod.rs
new file mode 100644
index 0000000000000000000000000000000000000000..8d8489e187b26c092cd2043b42f7339b4a43d794
--- /dev/null
+++ b/crates/ai/src/providers/open_ai/mod.rs
@@ -0,0 +1,2 @@
+pub mod model;
+pub use model::OpenAILanguageModel;
diff --git a/crates/ai/src/providers/open_ai/model.rs b/crates/ai/src/providers/open_ai/model.rs
new file mode 100644
index 0000000000000000000000000000000000000000..42523f3df48951d8674b33409105f8d802fd6c25
--- /dev/null
+++ b/crates/ai/src/providers/open_ai/model.rs
@@ -0,0 +1,56 @@
+use anyhow::anyhow;
+use tiktoken_rs::CoreBPE;
+use util::ResultExt;
+
+use crate::models::{LanguageModel, TruncationDirection};
+
+pub struct OpenAILanguageModel {
+    name: String,
+    bpe: Option<CoreBPE>,
+}
+
+impl OpenAILanguageModel {
+    pub fn load(model_name: &str) -> Self {
+        let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
+        OpenAILanguageModel {
+            name: model_name.to_string(),
+            bpe,
+        }
+    }
+}
+
+impl LanguageModel for OpenAILanguageModel {
+    fn name(&self) -> String {
+        self.name.clone()
+    }
+    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
+        if let Some(bpe) = &self.bpe {
+            anyhow::Ok(bpe.encode_with_special_tokens(content).len())
+        } else {
+            Err(anyhow!("bpe for open ai model was not retrieved"))
+        }
+    }
+    fn truncate(
+        &self,
+        content: &str,
+        length: usize,
+        direction: TruncationDirection,
+    ) -> anyhow::Result<String> {
+        if let Some(bpe) = &self.bpe {
+            let tokens = bpe.encode_with_special_tokens(content);
+            if tokens.len() > length {
+                match direction {
+                    TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
+                    TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
+                }
+            } else {
+                bpe.decode(tokens)
+            }
+        } else {
+            Err(anyhow!("bpe for open ai model was not retrieved"))
+        }
+    }
+    fn capacity(&self) -> anyhow::Result<usize> {
+        anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
+    }
+}
diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs
index 8fff232fdbf9358de9a97892bdd58d070871f64a..25af023c40072cbc56ba2865658544999e0fa7c7 100644
--- a/crates/assistant/src/prompts.rs
+++ b/crates/assistant/src/prompts.rs
@@ -1,9 +1,10 @@
-use ai::models::{LanguageModel, OpenAILanguageModel};
+use ai::models::LanguageModel;
 use ai::prompts::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
 use ai::prompts::file_context::FileContext;
 use ai::prompts::generate::GenerateInlineContent;
 use ai::prompts::preamble::EngineerPreamble;
 use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext};
+use ai::providers::open_ai::OpenAILanguageModel;
 use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
 use std::cmp::{self, Reverse};
 use std::ops::Range;

From 05ae978cb773978fcb16c81d14e8b4cd4907decd Mon Sep 17 00:00:00 2001
From: KCaverly
Date: Sun, 22 Oct 2023 13:57:13 +0200
Subject: [PATCH 04/25] move OpenAICompletionProvider to providers location

---
 crates/ai/src/completion.rs                   | 209 +-----------------
 crates/ai/src/providers/open_ai/completion.rs | 209 ++++++++++++++++++
 crates/ai/src/providers/open_ai/mod.rs        |   2 +
 crates/assistant/src/assistant.rs             |   2 +-
 crates/assistant/src/assistant_panel.rs       |  10 +-
 crates/assistant/src/codegen.rs               |   3 +-
 6 files changed, 222 insertions(+), 213 deletions(-)
 create mode 100644 crates/ai/src/providers/open_ai/completion.rs

diff --git a/crates/ai/src/completion.rs b/crates/ai/src/completion.rs
index de6ce9da711ee17f9fc072276a499d1769b874ce..f45893898fccaf54e794d151ce4c6c64eff34bc5 100644
--- a/crates/ai/src/completion.rs
+++ b/crates/ai/src/completion.rs
@@ -1,177 +1,7 @@
-use anyhow::{anyhow, Result};
-use futures::{
-    future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
-    Stream, StreamExt,
-};
-use gpui::executor::Background;
-use isahc::{http::StatusCode, Request, RequestExt};
-use serde::{Deserialize, Serialize};
-use std::{
-    fmt::{self, Display},
-    io,
-    sync::Arc,
-};
+use anyhow::Result;
+use futures::{future::BoxFuture, stream::BoxStream};
 
-pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
-
-#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
-#[serde(rename_all = "lowercase")]
-pub enum Role {
-    User,
-    Assistant,
-    System,
-}
-
-impl Role {
-    pub fn cycle(&mut self) {
-        *self = match self {
-            Role::User => Role::Assistant,
-            Role::Assistant => Role::System,
-            Role::System => Role::User,
-        }
-    }
-}
-
-impl Display for Role {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
-        match self {
-            Role::User => write!(f, "User"),
-            Role::Assistant => write!(f, "Assistant"),
-            Role::System => write!(f, "System"),
-        }
-    }
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct RequestMessage {
-    pub role: Role,
-    pub content: String,
-}
-
-#[derive(Debug, Default, Serialize)]
-pub struct OpenAIRequest {
-    pub model: String,
-    pub messages: Vec<RequestMessage>,
-    pub stream: bool,
-    pub stop: Vec<String>,
-    pub temperature: f32,
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct ResponseMessage {
-    pub role: Option<Role>,
-    pub content: Option<String>,
-}
-
-#[derive(Deserialize, Debug)]
-pub struct OpenAIUsage {
-    pub prompt_tokens: u32,
-    pub completion_tokens: u32,
-    pub total_tokens: u32,
-}
-
-#[derive(Deserialize, Debug)]
-pub struct ChatChoiceDelta {
-    pub index: u32,
-    pub delta: ResponseMessage,
-    pub finish_reason: Option<String>,
-}
-
-#[derive(Deserialize, Debug)]
-pub struct OpenAIResponseStreamEvent {
-    pub id: Option<String>,
-    pub object: String,
-    pub created: u32,
-    pub model: String,
-    pub choices: Vec<ChatChoiceDelta>,
-    pub usage: Option<OpenAIUsage>,
-}
-
-pub async fn stream_completion(
-    api_key: String,
-    executor: Arc<Background>,
-    mut request: OpenAIRequest,
-) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
-    request.stream = true;
-
-    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
-
-    let json_data = serde_json::to_string(&request)?;
-    let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
-        .header("Content-Type", "application/json")
-        .header("Authorization", format!("Bearer {}", api_key))
-        .body(json_data)?
-        .send_async()
-        .await?;
-
-    let status = response.status();
-    if status == StatusCode::OK {
-        executor
-            .spawn(async move {
-                let mut lines = BufReader::new(response.body_mut()).lines();
-
-                fn parse_line(
-                    line: Result<String, io::Error>,
-                ) -> Result<Option<OpenAIResponseStreamEvent>> {
-                    if let Some(data) = line?.strip_prefix("data: ") {
-                        let event = serde_json::from_str(&data)?;
-                        Ok(Some(event))
-                    } else {
-                        Ok(None)
-                    }
-                }
-
-                while let Some(line) = lines.next().await {
-                    if let Some(event) = parse_line(line).transpose() {
-                        let done = event.as_ref().map_or(false, |event| {
-                            event
-                                .choices
-                                .last()
-                                .map_or(false, |choice| choice.finish_reason.is_some())
-                        });
-                        if tx.unbounded_send(event).is_err() {
-                            break;
-                        }
-
-                        if done {
-                            break;
-                        }
-                    }
-                }
-
-                anyhow::Ok(())
-            })
-            .detach();
-
-        Ok(rx)
-    } else {
-        let mut body = String::new();
-        response.body_mut().read_to_string(&mut body).await?;
-
-        #[derive(Deserialize)]
-        struct OpenAIResponse {
-            error: OpenAIError,
-        }
-
-        #[derive(Deserialize)]
-        struct OpenAIError {
-            message: String,
-        }
-
-        match serde_json::from_str::<OpenAIResponse>(&body) {
-            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
-                "Failed to connect to OpenAI API: {}",
-                response.error.message,
-            )),
-
-            _ => Err(anyhow!(
-                "Failed to connect to OpenAI API: {} {}",
-                response.status(),
-                body,
-            )),
-        }
-    }
-}
+use crate::providers::open_ai::completion::OpenAIRequest;
 
 pub trait CompletionProvider {
     fn complete(
@@ -179,36 +9,3 @@ pub trait CompletionProvider {
         prompt: OpenAIRequest,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
 }
-
-pub struct OpenAICompletionProvider {
-    api_key: String,
-    executor: Arc<Background>,
-}
-
-impl OpenAICompletionProvider {
-    pub fn new(api_key: String, executor: Arc<Background>) -> Self {
-        Self { api_key, executor }
-    }
-}
-
-impl CompletionProvider for OpenAICompletionProvider {
-    fn complete(
-        &self,
-        prompt: OpenAIRequest,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-        let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
-        async move {
-            let response = request.await?;
-            let stream = response
-                .filter_map(|response| async move {
-                    match response {
-                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
-                        Err(error) => Some(Err(error)),
-                    }
-                })
-                .boxed();
-            Ok(stream)
-        }
-        .boxed()
-    }
-}
diff --git a/crates/ai/src/providers/open_ai/completion.rs b/crates/ai/src/providers/open_ai/completion.rs
new file mode 100644
index 0000000000000000000000000000000000000000..bb6138eee3b74d1471ae03d2f8352d3906391087
--- /dev/null
+++ b/crates/ai/src/providers/open_ai/completion.rs
@@ -0,0 +1,209 @@
+use anyhow::{anyhow, Result};
+use futures::{
+    future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
+    Stream, StreamExt,
+};
+use gpui::executor::Background;
+use isahc::{http::StatusCode, Request, RequestExt};
+use serde::{Deserialize, Serialize};
+use std::{
+    fmt::{self, Display},
+    io,
+    sync::Arc,
+};
+
+use crate::completion::CompletionProvider;
+
+pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
+
+#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+    User,
+    Assistant,
+    System,
+}
+
+impl Role {
+    pub fn cycle(&mut self) {
+        *self = match self {
+            Role::User => Role::Assistant,
+            Role::Assistant => Role::System,
+            Role::System => Role::User,
+        }
+    }
+}
+
+impl Display for Role {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Role::User => write!(f, "User"),
"Assistant"), + Role::System => write!(f, "System"), + } + } +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct RequestMessage { + pub role: Role, + pub content: String, +} + +#[derive(Debug, Default, Serialize)] +pub struct OpenAIRequest { + pub model: String, + pub messages: Vec, + pub stream: bool, + pub stop: Vec, + pub temperature: f32, +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct ResponseMessage { + pub role: Option, + pub content: Option, +} + +#[derive(Deserialize, Debug)] +pub struct OpenAIUsage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, +} + +#[derive(Deserialize, Debug)] +pub struct ChatChoiceDelta { + pub index: u32, + pub delta: ResponseMessage, + pub finish_reason: Option, +} + +#[derive(Deserialize, Debug)] +pub struct OpenAIResponseStreamEvent { + pub id: Option, + pub object: String, + pub created: u32, + pub model: String, + pub choices: Vec, + pub usage: Option, +} + +pub async fn stream_completion( + api_key: String, + executor: Arc, + mut request: OpenAIRequest, +) -> Result>> { + request.stream = true; + + let (tx, rx) = futures::channel::mpsc::unbounded::>(); + + let json_data = serde_json::to_string(&request)?; + let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions")) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key)) + .body(json_data)? + .send_async() + .await?; + + let status = response.status(); + if status == StatusCode::OK { + executor + .spawn(async move { + let mut lines = BufReader::new(response.body_mut()).lines(); + + fn parse_line( + line: Result, + ) -> Result> { + if let Some(data) = line?.strip_prefix("data: ") { + let event = serde_json::from_str(&data)?; + Ok(Some(event)) + } else { + Ok(None) + } + } + + while let Some(line) = lines.next().await { + if let Some(event) = parse_line(line).transpose() { + let done = event.as_ref().map_or(false, |event| { + event + .choices + .last() + .map_or(false, |choice| choice.finish_reason.is_some()) + }); + if tx.unbounded_send(event).is_err() { + break; + } + + if done { + break; + } + } + } + + anyhow::Ok(()) + }) + .detach(); + + Ok(rx) + } else { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + #[derive(Deserialize)] + struct OpenAIResponse { + error: OpenAIError, + } + + #[derive(Deserialize)] + struct OpenAIError { + message: String, + } + + match serde_json::from_str::(&body) { + Ok(response) if !response.error.message.is_empty() => Err(anyhow!( + "Failed to connect to OpenAI API: {}", + response.error.message, + )), + + _ => Err(anyhow!( + "Failed to connect to OpenAI API: {} {}", + response.status(), + body, + )), + } + } +} + +pub struct OpenAICompletionProvider { + api_key: String, + executor: Arc, +} + +impl OpenAICompletionProvider { + pub fn new(api_key: String, executor: Arc) -> Self { + Self { api_key, executor } + } +} + +impl CompletionProvider for OpenAICompletionProvider { + fn complete( + &self, + prompt: OpenAIRequest, + ) -> BoxFuture<'static, Result>>> { + let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt); + async move { + let response = request.await?; + let stream = response + .filter_map(|response| async move { + match response { + Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)), + Err(error) => Some(Err(error)), + } + }) + .boxed(); + Ok(stream) + } + .boxed() + } +} diff --git a/crates/ai/src/providers/open_ai/mod.rs 
b/crates/ai/src/providers/open_ai/mod.rs index 8d8489e187b26c092cd2043b42f7339b4a43d794..26f3068ca17723a8ae838d935ac7596972b7c51a 100644 --- a/crates/ai/src/providers/open_ai/mod.rs +++ b/crates/ai/src/providers/open_ai/mod.rs @@ -1,2 +1,4 @@ +pub mod completion; pub mod model; +pub use completion::*; pub use model::OpenAILanguageModel; diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index 6c9b14333e34cbf5fd49d8299ba7bd891b607526..91d61a19f98b4cf4a61257a68d8f7212c6a33586 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -4,7 +4,7 @@ mod codegen; mod prompts; mod streaming_diff; -use ai::completion::Role; +use ai::providers::open_ai::Role; use anyhow::Result; pub use assistant_panel::AssistantPanel; use assistant_settings::OpenAIModel; diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 64eff04b8dfb9d66164851ab2513db48f68572f9..9b749e5091fd394eab07739e509b2a574bc43640 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -5,12 +5,12 @@ use crate::{ MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata, SavedMessage, }; -use ai::{ - completion::{ - stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL, - }, - prompts::repository_context::PromptCodeSnippet, + +use ai::providers::open_ai::{ + stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL, }; + +use ai::prompts::repository_context::PromptCodeSnippet; use anyhow::{anyhow, Result}; use chrono::{DateTime, Local}; use client::{telemetry::AssistantKind, ClickhouseEvent, TelemetrySettings}; diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs index b6ef6b5cfa7fef58936828e0f121946290bc8b48..66d2f60690d23628bfdc7d62f84f869c24347119 100644 --- a/crates/assistant/src/codegen.rs +++ b/crates/assistant/src/codegen.rs @@ -1,5 +1,6 @@ use crate::streaming_diff::{Hunk, StreamingDiff}; -use ai::completion::{CompletionProvider, OpenAIRequest}; +use ai::completion::CompletionProvider; +use ai::providers::open_ai::OpenAIRequest; use anyhow::Result; use editor::{multi_buffer, Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint}; use futures::{channel::mpsc, SinkExt, Stream, StreamExt}; From d813ae88458ed5e14899c5ebdd4437daa033ae6e Mon Sep 17 00:00:00 2001 From: KCaverly Date: Sun, 22 Oct 2023 14:33:19 +0200 Subject: [PATCH 05/25] replace OpenAIRequest with more generalized Box --- crates/ai/src/completion.rs | 6 +++-- crates/ai/src/providers/dummy.rs | 13 ++++++++++ crates/ai/src/providers/mod.rs | 1 + crates/ai/src/providers/open_ai/completion.rs | 16 +++++++----- crates/assistant/src/assistant_panel.rs | 20 +++++++++------ crates/assistant/src/codegen.rs | 25 +++++++++++++------ 6 files changed, 58 insertions(+), 23 deletions(-) create mode 100644 crates/ai/src/providers/dummy.rs diff --git a/crates/ai/src/completion.rs b/crates/ai/src/completion.rs index f45893898fccaf54e794d151ce4c6c64eff34bc5..ba89c869d214c8772430e30b3177af502a41031e 100644 --- a/crates/ai/src/completion.rs +++ b/crates/ai/src/completion.rs @@ -1,11 +1,13 @@ use anyhow::Result; use futures::{future::BoxFuture, stream::BoxStream}; -use crate::providers::open_ai::completion::OpenAIRequest; +pub trait CompletionRequest: Send + Sync { + fn data(&self) -> serde_json::Result; +} pub trait CompletionProvider { fn complete( &self, - prompt: OpenAIRequest, + prompt: Box, ) -> 
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
 }
diff --git a/crates/ai/src/providers/dummy.rs b/crates/ai/src/providers/dummy.rs
new file mode 100644
index 0000000000000000000000000000000000000000..be42b13f2fe53dd1568ef5f59fc24e394bef2fef
--- /dev/null
+++ b/crates/ai/src/providers/dummy.rs
@@ -0,0 +1,13 @@
+use crate::completion::CompletionRequest;
+use serde::Serialize;
+
+#[derive(Serialize)]
+pub struct DummyCompletionRequest {
+    pub name: String,
+}
+
+impl CompletionRequest for DummyCompletionRequest {
+    fn data(&self) -> serde_json::Result<String> {
+        serde_json::to_string(self)
+    }
+}
diff --git a/crates/ai/src/providers/mod.rs b/crates/ai/src/providers/mod.rs
index acd0f9d91053869e3999ef0c1a23326480a7cbdd..7a7092baf39d40da8b09ffb8085463a4a8b2efdf 100644
--- a/crates/ai/src/providers/mod.rs
+++ b/crates/ai/src/providers/mod.rs
@@ -1 +1,2 @@
+pub mod dummy;
 pub mod open_ai;
diff --git a/crates/ai/src/providers/open_ai/completion.rs b/crates/ai/src/providers/open_ai/completion.rs
index bb6138eee3b74d1471ae03d2f8352d3906391087..95ed13c0dd7f1d519c8279ea75fd3980dd4f5d50 100644
--- a/crates/ai/src/providers/open_ai/completion.rs
+++ b/crates/ai/src/providers/open_ai/completion.rs
@@ -12,7 +12,7 @@ use std::{
     sync::Arc,
 };
 
-use crate::completion::CompletionProvider;
+use crate::completion::{CompletionProvider, CompletionRequest};
 
 pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
 
@@ -59,6 +59,12 @@ pub struct OpenAIRequest {
     pub temperature: f32,
 }
 
+impl CompletionRequest for OpenAIRequest {
+    fn data(&self) -> serde_json::Result<String> {
+        serde_json::to_string(self)
+    }
+}
+
 #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 pub struct ResponseMessage {
     pub role: Option<Role>,
@@ -92,13 +98,11 @@ pub struct OpenAIResponseStreamEvent {
 pub async fn stream_completion(
     api_key: String,
     executor: Arc<Background>,
-    mut request: OpenAIRequest,
+    request: Box<dyn CompletionRequest>,
 ) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
-    request.stream = true;
-
     let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
 
-    let json_data = serde_json::to_string(&request)?;
+    let json_data = request.data()?;
     let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
         .header("Content-Type", "application/json")
         .header("Authorization", format!("Bearer {}", api_key))
@@ -189,7 +193,7 @@ impl OpenAICompletionProvider {
 impl CompletionProvider for OpenAICompletionProvider {
     fn complete(
         &self,
-        prompt: OpenAIRequest,
+        prompt: Box<dyn CompletionRequest>,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
         let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
         async move {
diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs
index 9b749e5091fd394eab07739e509b2a574bc43640..ec16c8fd046b48b8ef8dc0fc52afbdc6a0da1b3e 100644
--- a/crates/assistant/src/assistant_panel.rs
+++ b/crates/assistant/src/assistant_panel.rs
@@ -6,8 +6,11 @@ use crate::{
     SavedMessage,
 };
 
-use ai::providers::open_ai::{
-    stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
+use ai::{
+    completion::CompletionRequest,
+    providers::open_ai::{
+        stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
+    },
 };
 
 use ai::prompts::repository_context::PromptCodeSnippet;
@@ -745,13 +748,14 @@ impl AssistantPanel {
             content: prompt,
         });
 
-        let request = OpenAIRequest {
+        let request = Box::new(OpenAIRequest {
             model: model.full_name().into(),
             messages,
             stream: true,
            stop: vec!["|END|>".to_string()],
             temperature,
-        };
+        });
+
         codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx));
         anyhow::Ok(())
     })
@@ -1735,7 +1739,7 @@ impl Conversation {
             return Default::default();
         };
 
-        let request = OpenAIRequest {
+        let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
             model: self.model.full_name().to_string(),
             messages: self
                 .messages(cx)
@@ -1745,7 +1749,7 @@ impl Conversation {
             stream: true,
             stop: vec![],
             temperature: 1.0,
-        };
+        });
 
         let stream = stream_completion(api_key, cx.background().clone(), request);
         let assistant_message = self
@@ -2025,13 +2029,13 @@ impl Conversation {
                 "Summarize the conversation into a short title without punctuation"
                     .into(),
             }));
-        let request = OpenAIRequest {
+        let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
             model: self.model.full_name().to_string(),
             messages: messages.collect(),
             stream: true,
             stop: vec![],
             temperature: 1.0,
-        };
+        });
 
         let stream = stream_completion(api_key, cx.background().clone(), request);
         self.pending_summary = cx.spawn(|this, mut cx| {
diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs
index 66d2f60690d23628bfdc7d62f84f869c24347119..e535eca1441b4c8137d043c304a68cb2866e38fc 100644
--- a/crates/assistant/src/codegen.rs
+++ b/crates/assistant/src/codegen.rs
@@ -1,6 +1,5 @@
 use crate::streaming_diff::{Hunk, StreamingDiff};
-use ai::completion::CompletionProvider;
-use ai::providers::open_ai::OpenAIRequest;
+use ai::completion::{CompletionProvider, CompletionRequest};
 use anyhow::Result;
 use editor::{multi_buffer, Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
 use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
@@ -96,7 +95,7 @@ impl Codegen {
         self.error.as_ref()
     }
 
-    pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
+    pub fn start(&mut self, prompt: Box<dyn CompletionRequest>, cx: &mut ModelContext<Self>) {
         let range = self.range();
         let snapshot = self.snapshot.clone();
         let selected_text = snapshot
@@ -336,6 +335,7 @@ fn strip_markdown_codeblock(
 #[cfg(test)]
 mod tests {
     use super::*;
+    use ai::providers::dummy::DummyCompletionRequest;
     use futures::{
         future::BoxFuture,
         stream::{self, BoxStream},
@@ -381,7 +381,10 @@ mod tests {
                 cx,
             )
         });
-        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
+        let request = Box::new(DummyCompletionRequest {
+            name: "test".to_string(),
+        });
+        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
 
         let mut new_text = concat!(
             " let mut x = 0;\n",
@@ -443,7 +446,11 @@ mod tests {
                 cx,
             )
         });
-        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
+
+        let request = Box::new(DummyCompletionRequest {
+            name: "test".to_string(),
+        });
+        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
 
         let mut new_text = concat!(
             "t mut x = 0;\n",
@@ -505,7 +512,11 @@ mod tests {
                 cx,
             )
         });
-        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
+
+        let request = Box::new(DummyCompletionRequest {
+            name: "test".to_string(),
+        });
+        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
 
         let mut new_text = concat!(
             "let mut x = 0;\n",
@@ -617,7 +628,7 @@ mod tests {
     impl CompletionProvider for TestCompletionProvider {
         fn complete(
             &self,
-            _prompt: OpenAIRequest,
+            _prompt: Box<dyn CompletionRequest>,
         ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
             let (tx, rx) = mpsc::channel(1);
             *self.last_completion_tx.lock() = Some(tx);

From d1dec8314adb8be628912642efeffb572ef83b71 Mon Sep 17 00:00:00 2001
From: KCaverly
Date: Sun, 22 Oct 2023 14:46:22 +0200
Subject: [PATCH 06/25] move OpenAIEmbeddings to OpenAIEmbeddingProvider in
 providers folder

---
 crates/ai/src/embedding.rs                    | 287 +-----------------
 crates/ai/src/providers/dummy.rs              |  37 ++-
 crates/ai/src/providers/open_ai/embedding.rs  | 252 +++++++++++++++
 crates/ai/src/providers/open_ai/mod.rs        |   3 +
 crates/semantic_index/src/semantic_index.rs   |   5 +-
 .../src/semantic_index_tests.rs               |  19 +-
 crates/zed/examples/semantic_index_eval.rs    |   4 +-
 7 files changed, 308 insertions(+), 299 deletions(-)
 create mode 100644 crates/ai/src/providers/open_ai/embedding.rs

diff --git a/crates/ai/src/embedding.rs b/crates/ai/src/embedding.rs
index 4587ece0a23d116c55f07405e009497486d583d7..05798c3f5d1a3ed2a88739fd3a5a911ed708d560 100644
--- a/crates/ai/src/embedding.rs
+++ b/crates/ai/src/embedding.rs
@@ -1,30 +1,9 @@
-use anyhow::{anyhow, Result};
+use anyhow::Result;
 use async_trait::async_trait;
-use futures::AsyncReadExt;
-use gpui::executor::Background;
-use gpui::serde_json;
-use isahc::http::StatusCode;
-use isahc::prelude::Configurable;
-use isahc::{AsyncBody, Response};
-use lazy_static::lazy_static;
 use ordered_float::OrderedFloat;
-use parking_lot::Mutex;
-use parse_duration::parse;
-use postage::watch;
 use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
 use rusqlite::ToSql;
-use serde::{Deserialize, Serialize};
-use std::env;
-use std::ops::Add;
-use std::sync::Arc;
-use std::time::{Duration, Instant};
-use tiktoken_rs::{cl100k_base, CoreBPE};
-use util::http::{HttpClient, Request};
-
-lazy_static! {
-    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
-    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
-}
+use std::time::Instant;
 
 #[derive(Debug, PartialEq, Clone)]
 pub struct Embedding(pub Vec<f32>);
@@ -85,39 +64,6 @@ impl Embedding {
     }
 }
 
-#[derive(Clone)]
-pub struct OpenAIEmbeddings {
-    pub client: Arc<dyn HttpClient>,
-    pub executor: Arc<Background>,
-    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
-    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
-}
-
-#[derive(Serialize)]
-struct OpenAIEmbeddingRequest<'a> {
-    model: &'static str,
-    input: Vec<&'a str>,
-}
-
-#[derive(Deserialize)]
-struct OpenAIEmbeddingResponse {
-    data: Vec<OpenAIEmbedding>,
-    usage: OpenAIEmbeddingUsage,
-}
-
-#[derive(Debug, Deserialize)]
-struct OpenAIEmbedding {
-    embedding: Vec<f32>,
-    index: usize,
-    object: String,
-}
-
-#[derive(Deserialize)]
-struct OpenAIEmbeddingUsage {
-    prompt_tokens: usize,
-    total_tokens: usize,
-}
-
 #[async_trait]
 pub trait EmbeddingProvider: Sync + Send {
     fn is_authenticated(&self) -> bool;
     async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
     fn max_tokens_per_batch(&self) -> usize;
@@ -127,235 +73,6 @@ pub trait EmbeddingProvider: Sync + Send {
     fn rate_limit_expiration(&self) -> Option<Instant>;
 }
 
-pub struct DummyEmbeddings {}
-
-#[async_trait]
-impl EmbeddingProvider for DummyEmbeddings {
-    fn is_authenticated(&self) -> bool {
-        true
-    }
-    fn rate_limit_expiration(&self) -> Option<Instant> {
-        None
-    }
-    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
-        // 1024 is the OpenAI Embeddings size for ada models.
-        // the model we will likely be starting with.
-        let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
-        return Ok(vec![dummy_vec; spans.len()]);
-    }
-
-    fn max_tokens_per_batch(&self) -> usize {
-        OPENAI_INPUT_LIMIT
-    }
-
-    fn truncate(&self, span: &str) -> (String, usize) {
-        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
-        let token_count = tokens.len();
-        let output = if token_count > OPENAI_INPUT_LIMIT {
-            tokens.truncate(OPENAI_INPUT_LIMIT);
-            let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
-            new_input.ok().unwrap_or_else(|| span.to_string())
-        } else {
-            span.to_string()
-        };
-
-        (output, tokens.len())
-    }
-}
-
-const OPENAI_INPUT_LIMIT: usize = 8190;
-
-impl OpenAIEmbeddings {
-    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
-        let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
-        let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
-
-        OpenAIEmbeddings {
-            client,
-            executor,
-            rate_limit_count_rx,
-            rate_limit_count_tx,
-        }
-    }
-
-    fn resolve_rate_limit(&self) {
-        let reset_time = *self.rate_limit_count_tx.lock().borrow();
-
-        if let Some(reset_time) = reset_time {
-            if Instant::now() >= reset_time {
-                *self.rate_limit_count_tx.lock().borrow_mut() = None
-            }
-        }
-
-        log::trace!(
-            "resolving reset time: {:?}",
-            *self.rate_limit_count_tx.lock().borrow()
-        );
-    }
-
-    fn update_reset_time(&self, reset_time: Instant) {
-        let original_time = *self.rate_limit_count_tx.lock().borrow();
-
-        let updated_time = if let Some(original_time) = original_time {
-            if reset_time < original_time {
-                Some(reset_time)
-            } else {
-                Some(original_time)
-            }
-        } else {
-            Some(reset_time)
-        };
-
-        log::trace!("updating rate limit time: {:?}", updated_time);
-
-        *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
-    }
-    async fn send_request(
-        &self,
-        api_key: &str,
-        spans: Vec<&str>,
-        request_timeout: u64,
-    ) -> Result<Response<AsyncBody>> {
-        let request = Request::post("https://api.openai.com/v1/embeddings")
-            .redirect_policy(isahc::config::RedirectPolicy::Follow)
-            .timeout(Duration::from_secs(request_timeout))
-            .header("Content-Type", "application/json")
-            .header("Authorization", format!("Bearer {}", api_key))
-            .body(
-                serde_json::to_string(&OpenAIEmbeddingRequest {
-                    input: spans.clone(),
-                    model: "text-embedding-ada-002",
-                })
-                .unwrap()
-                .into(),
-            )?;
-
-        Ok(self.client.send(request).await?)
-    }
-}
-
-#[async_trait]
-impl EmbeddingProvider for OpenAIEmbeddings {
-    fn is_authenticated(&self) -> bool {
-        OPENAI_API_KEY.as_ref().is_some()
-    }
-    fn max_tokens_per_batch(&self) -> usize {
-        50000
-    }
-
-    fn rate_limit_expiration(&self) -> Option<Instant> {
-        *self.rate_limit_count_rx.borrow()
-    }
-    fn truncate(&self, span: &str) -> (String, usize) {
-        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
-        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
-            tokens.truncate(OPENAI_INPUT_LIMIT);
-            OPENAI_BPE_TOKENIZER
-                .decode(tokens.clone())
-                .ok()
-                .unwrap_or_else(|| span.to_string())
-        } else {
-            span.to_string()
-        };
-
-        (output, tokens.len())
-    }
-
-    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
-        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
-        const MAX_RETRIES: usize = 4;
-
-        let api_key = OPENAI_API_KEY
-            .as_ref()
-            .ok_or_else(|| anyhow!("no api key"))?;
-
-        let mut request_number = 0;
-        let mut rate_limiting = false;
-        let mut request_timeout: u64 = 15;
-        let mut response: Response<AsyncBody>;
-        while request_number < MAX_RETRIES {
-            response = self
-                .send_request(
-                    api_key,
-                    spans.iter().map(|x| &**x).collect(),
-                    request_timeout,
-                )
-                .await?;
-
-            request_number += 1;
-
-            match response.status() {
-                StatusCode::REQUEST_TIMEOUT => {
-                    request_timeout += 5;
-                }
-                StatusCode::OK => {
-                    let mut body = String::new();
-                    response.body_mut().read_to_string(&mut body).await?;
-                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
-
-                    log::trace!(
-                        "openai embedding completed. tokens: {:?}",
-                        response.usage.total_tokens
-                    );
-
-                    // If we complete a request successfully that was previously rate_limited
-                    // resolve the rate limit
-                    if rate_limiting {
-                        self.resolve_rate_limit()
-                    }
-
-                    return Ok(response
-                        .data
-                        .into_iter()
-                        .map(|embedding| Embedding::from(embedding.embedding))
-                        .collect());
-                }
-                StatusCode::TOO_MANY_REQUESTS => {
-                    rate_limiting = true;
-                    let mut body = String::new();
-                    response.body_mut().read_to_string(&mut body).await?;
-
-                    let delay_duration = {
-                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
-                        if let Some(time_to_reset) =
-                            response.headers().get("x-ratelimit-reset-tokens")
-                        {
-                            if let Ok(time_str) = time_to_reset.to_str() {
-                                parse(time_str).unwrap_or(delay)
-                            } else {
-                                delay
-                            }
-                        } else {
-                            delay
-                        }
-                    };
-
-                    // If we've previously rate limited, increment the duration but not the count
-                    let reset_time = Instant::now().add(delay_duration);
-                    self.update_reset_time(reset_time);
-
-                    log::trace!(
-                        "openai rate limiting: waiting {:?} until lifted",
-                        &delay_duration
-                    );
-
-                    self.executor.timer(delay_duration).await;
-                }
-                _ => {
-                    let mut body = String::new();
-                    response.body_mut().read_to_string(&mut body).await?;
-                    return Err(anyhow!(
-                        "open ai bad request: {:?} {:?}",
-                        &response.status(),
-                        body
-                    ));
-                }
-            }
-        }
-        Err(anyhow!("openai max retries"))
-    }
-}
diff --git a/crates/ai/src/providers/dummy.rs b/crates/ai/src/providers/dummy.rs
index be42b13f2fe53dd1568ef5f59fc24e394bef2fef..8061a2ca6b1c21d853db152fa431d9a2674c8740 100644
--- a/crates/ai/src/providers/dummy.rs
+++ b/crates/ai/src/providers/dummy.rs
@@ -1,4 +1,10 @@
-use crate::completion::CompletionRequest;
+use std::time::Instant;
+
+use crate::{
+    completion::CompletionRequest,
+    embedding::{Embedding, EmbeddingProvider},
+};
+use async_trait::async_trait;
 use serde::Serialize;
 
 #[derive(Serialize)]
 pub struct DummyCompletionRequest {
     pub name: String,
@@ -11,3 +17,32 @@ impl CompletionRequest for DummyCompletionRequest {
         serde_json::to_string(self)
     }
 }
+
+pub struct DummyEmbeddingProvider {}
+
+#[async_trait]
+impl EmbeddingProvider for DummyEmbeddingProvider {
+    fn is_authenticated(&self) -> bool {
+        true
+    }
+    fn rate_limit_expiration(&self) -> Option<Instant> {
+        None
+    }
+    async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
+        // 1024 is the OpenAI Embeddings size for ada models.
+        // the model we will likely be starting with.
+        let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
+        return Ok(vec![dummy_vec; spans.len()]);
+    }
+
+    fn max_tokens_per_batch(&self) -> usize {
+        8190
+    }
+
+    fn truncate(&self, span: &str) -> (String, usize) {
+        // Clamp to the span's length so spans shorter than the limit do not panic.
+        let length = span.chars().count().min(8190);
+        let truncated = span.chars().collect::<Vec<char>>()[..length]
+            .iter()
+            .collect::<String>();
+        (truncated, length)
+    }
+}
diff --git a/crates/ai/src/providers/open_ai/embedding.rs b/crates/ai/src/providers/open_ai/embedding.rs
new file mode 100644
index 0000000000000000000000000000000000000000..35398394dc8fdbef09f74dab5a82307d7b4b0aaf
--- /dev/null
+++ b/crates/ai/src/providers/open_ai/embedding.rs
@@ -0,0 +1,252 @@
+use anyhow::{anyhow, Result};
+use async_trait::async_trait;
+use futures::AsyncReadExt;
+use gpui::executor::Background;
+use gpui::serde_json;
+use isahc::http::StatusCode;
+use isahc::prelude::Configurable;
+use isahc::{AsyncBody, Response};
+use lazy_static::lazy_static;
+use parking_lot::Mutex;
+use parse_duration::parse;
+use postage::watch;
+use serde::{Deserialize, Serialize};
+use std::env;
+use std::ops::Add;
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+use tiktoken_rs::{cl100k_base, CoreBPE};
+use util::http::{HttpClient, Request};
+
+use crate::embedding::{Embedding, EmbeddingProvider};
+
+lazy_static! {
+    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
+    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
+}
+
+#[derive(Clone)]
+pub struct OpenAIEmbeddingProvider {
+    pub client: Arc<dyn HttpClient>,
+    pub executor: Arc<Background>,
+    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
+    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
+}
+
+#[derive(Serialize)]
+struct OpenAIEmbeddingRequest<'a> {
+    model: &'static str,
+    input: Vec<&'a str>,
+}
+
+#[derive(Deserialize)]
+struct OpenAIEmbeddingResponse {
+    data: Vec<OpenAIEmbedding>,
+    usage: OpenAIEmbeddingUsage,
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenAIEmbedding {
+    embedding: Vec<f32>,
+    index: usize,
+    object: String,
+}
+
+#[derive(Deserialize)]
+struct OpenAIEmbeddingUsage {
+    prompt_tokens: usize,
+    total_tokens: usize,
+}
+
+const OPENAI_INPUT_LIMIT: usize = 8190;
+
+impl OpenAIEmbeddingProvider {
+    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
+        let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
+        let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
+
+        OpenAIEmbeddingProvider {
+            client,
+            executor,
+            rate_limit_count_rx,
+            rate_limit_count_tx,
+        }
+    }
+
+    fn resolve_rate_limit(&self) {
+        let reset_time = *self.rate_limit_count_tx.lock().borrow();
+
+        if let Some(reset_time) = reset_time {
+            if Instant::now() >= reset_time {
+                *self.rate_limit_count_tx.lock().borrow_mut() = None
+            }
+        }
+
+        log::trace!(
+            "resolving reset time: {:?}",
+            *self.rate_limit_count_tx.lock().borrow()
+        );
+    }
+
+    fn update_reset_time(&self, reset_time: Instant) {
+        let original_time = *self.rate_limit_count_tx.lock().borrow();
+
+        let updated_time = if let Some(original_time) = original_time {
+            if reset_time < original_time {
+                Some(reset_time)
+            } else {
+                Some(original_time)
+            }
+        } else {
+            Some(reset_time)
+        };
+
+        log::trace!("updating rate limit time: {:?}", updated_time);
+
+        *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
+    }
+    async fn send_request(
+        &self,
+        api_key: &str,
+        spans: Vec<&str>,
+        request_timeout: u64,
+    ) -> Result<Response<AsyncBody>> {
+        let request = Request::post("https://api.openai.com/v1/embeddings")
+            .redirect_policy(isahc::config::RedirectPolicy::Follow)
+            .timeout(Duration::from_secs(request_timeout))
+            .header("Content-Type", "application/json")
+            .header("Authorization", format!("Bearer {}", api_key))
+            .body(
+                serde_json::to_string(&OpenAIEmbeddingRequest {
+                    input: spans.clone(),
+                    model: "text-embedding-ada-002",
+                })
+                .unwrap()
+                .into(),
+            )?;
+
+        Ok(self.client.send(request).await?)
+    }
+}
+
+#[async_trait]
+impl EmbeddingProvider for OpenAIEmbeddingProvider {
+    fn is_authenticated(&self) -> bool {
+        OPENAI_API_KEY.as_ref().is_some()
+    }
+    fn max_tokens_per_batch(&self) -> usize {
+        50000
+    }
+
+    fn rate_limit_expiration(&self) -> Option<Instant> {
+        *self.rate_limit_count_rx.borrow()
+    }
+    fn truncate(&self, span: &str) -> (String, usize) {
+        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
+        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
+            tokens.truncate(OPENAI_INPUT_LIMIT);
+            OPENAI_BPE_TOKENIZER
+                .decode(tokens.clone())
+                .ok()
+                .unwrap_or_else(|| span.to_string())
+        } else {
+            span.to_string()
+        };
+
+        (output, tokens.len())
+    }
+
+    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
+        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
+        const MAX_RETRIES: usize = 4;
+
+        let api_key = OPENAI_API_KEY
+            .as_ref()
+            .ok_or_else(|| anyhow!("no api key"))?;
+
+        let mut request_number = 0;
+        let mut rate_limiting = false;
+        let mut request_timeout: u64 = 15;
+        let mut response: Response<AsyncBody>;
+        while request_number < MAX_RETRIES {
+            response = self
+                .send_request(
+                    api_key,
+                    spans.iter().map(|x| &**x).collect(),
+                    request_timeout,
+                )
+                .await?;
+
+            request_number += 1;
+
+            match response.status() {
+                StatusCode::REQUEST_TIMEOUT => {
+                    request_timeout += 5;
+                }
+                StatusCode::OK => {
+                    let mut body = String::new();
+                    response.body_mut().read_to_string(&mut body).await?;
+                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
+
+                    log::trace!(
+                        "openai embedding completed. tokens: {:?}",
+                        response.usage.total_tokens
+                    );
+
+                    // If we complete a request successfully that was previously rate_limited
+                    // resolve the rate limit
+                    if rate_limiting {
+                        self.resolve_rate_limit()
+                    }
+
+                    return Ok(response
+                        .data
+                        .into_iter()
+                        .map(|embedding| Embedding::from(embedding.embedding))
+                        .collect());
+                }
+                StatusCode::TOO_MANY_REQUESTS => {
+                    rate_limiting = true;
+                    let mut body = String::new();
+                    response.body_mut().read_to_string(&mut body).await?;
+
+                    let delay_duration = {
+                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
+                        if let Some(time_to_reset) =
+                            response.headers().get("x-ratelimit-reset-tokens")
+                        {
+                            if let Ok(time_str) = time_to_reset.to_str() {
+                                parse(time_str).unwrap_or(delay)
+                            } else {
+                                delay
+                            }
+                        } else {
+                            delay
+                        }
+                    };
+
+                    // If we've previously rate limited, increment the duration but not the count
+                    let reset_time = Instant::now().add(delay_duration);
+                    self.update_reset_time(reset_time);
+
+                    log::trace!(
+                        "openai rate limiting: waiting {:?} until lifted",
+                        &delay_duration
+                    );
+
+                    self.executor.timer(delay_duration).await;
+                }
+                _ => {
+                    let mut body = String::new();
+                    response.body_mut().read_to_string(&mut body).await?;
+                    return Err(anyhow!(
+                        "open ai bad request: {:?} {:?}",
+                        &response.status(),
+                        body
+                    ));
+                }
+            }
+        }
+        Err(anyhow!("openai max retries"))
+    }
+}
diff --git a/crates/ai/src/providers/open_ai/mod.rs b/crates/ai/src/providers/open_ai/mod.rs
index 26f3068ca17723a8ae838d935ac7596972b7c51a..67cb2b53151e93560ea38429ef17dc2b57dc3e86 100644
--- a/crates/ai/src/providers/open_ai/mod.rs
+++ b/crates/ai/src/providers/open_ai/mod.rs
@@ -1,4 +1,7 @@
 pub mod completion;
+pub mod embedding;
 pub mod model;
 
 pub use completion::*;
+pub use embedding::*;
 pub use model::OpenAILanguageModel;
diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs
index ecdba4364315eb8f0a4ed2cf579fcc3149e56e67..926eb3045c826fc0bf0db6c7e34d968f1cf9fe8d 100644
--- a/crates/semantic_index/src/semantic_index.rs
+++ b/crates/semantic_index/src/semantic_index.rs
@@ -7,7 +7,8 @@ pub mod semantic_index_settings;
 mod semantic_index_tests;
 
 use crate::semantic_index_settings::SemanticIndexSettings;
-use ai::embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
+use ai::embedding::{Embedding, EmbeddingProvider};
+use ai::providers::open_ai::OpenAIEmbeddingProvider;
 use anyhow::{anyhow, Result};
 use collections::{BTreeMap, HashMap, HashSet};
 use db::VectorDatabase;
@@ -88,7 +89,7 @@ pub fn init(
     let semantic_index = SemanticIndex::new(
         fs,
         db_file_path,
-        Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
+        Arc::new(OpenAIEmbeddingProvider::new(http_client, cx.background())),
         language_registry,
         cx.clone(),
     )
diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs
index 182010ca8339e9cc8ec1ff06ac31741eb4fb78ae..6842ce5c5d45711c45883c348028463e8429aa64 100644
--- a/crates/semantic_index/src/semantic_index_tests.rs
+++ b/crates/semantic_index/src/semantic_index_tests.rs
@@ -4,7 +4,8 @@ use crate::{
     semantic_index_settings::SemanticIndexSettings, FileToEmbed, JobHandle, SearchResult,
     SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
 };
-use ai::embedding::{DummyEmbeddings, Embedding, EmbeddingProvider};
+use ai::embedding::{Embedding, EmbeddingProvider};
+use ai::providers::dummy::DummyEmbeddingProvider;
 use anyhow::Result;
 use async_trait::async_trait;
 use gpui::{executor::Deterministic, Task, TestAppContext};
@@ -280,7 +281,7 @@ fn assert_search_results(
 #[gpui::test]
 async fn test_code_context_retrieval_rust() {
     let language = rust_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = "
@@ -382,7 +383,7 @@ async fn test_code_context_retrieval_rust() {
 #[gpui::test]
 async fn test_code_context_retrieval_json() {
     let language = json_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
@@ -466,7 +467,7 @@ fn assert_documents_eq(
 #[gpui::test]
 async fn test_code_context_retrieval_javascript() {
     let language = js_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = "
@@ -565,7 +566,7 @@ async fn test_code_context_retrieval_javascript() {
 #[gpui::test]
 async fn test_code_context_retrieval_lua() {
     let language = lua_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
@@ -639,7 +640,7 @@ async fn test_code_context_retrieval_lua() {
 #[gpui::test]
 async fn test_code_context_retrieval_elixir() {
     let language = elixir_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
@@ -756,7 +757,7 @@ async fn test_code_context_retrieval_elixir() {
 #[gpui::test]
 async fn test_code_context_retrieval_cpp() {
     let language = cpp_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = "
@@ -909,7 +910,7 @@ async fn test_code_context_retrieval_cpp() {
 #[gpui::test]
 async fn test_code_context_retrieval_ruby() {
     let language = ruby_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
@@ -1100,7 +1101,7 @@ async fn test_code_context_retrieval_ruby() {
 #[gpui::test]
 async fn test_code_context_retrieval_php() {
     let language = php_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(DummyEmbeddingProvider {});
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
diff --git a/crates/zed/examples/semantic_index_eval.rs b/crates/zed/examples/semantic_index_eval.rs
index 33d6b3689c1f617a96d5fab00d41e51fa28d63f4..0bada4750222d68bd398af824afbd00f30ce2877 100644
--- a/crates/zed/examples/semantic_index_eval.rs
+++ b/crates/zed/examples/semantic_index_eval.rs
@@ -1,4 +1,4 @@
-use ai::embedding::OpenAIEmbeddings;
+use ai::providers::open_ai::OpenAIEmbeddingProvider;
 use anyhow::{anyhow, Result};
 use client::{self, UserStore};
 use gpui::{AsyncAppContext, ModelHandle, Task};
@@ -474,7 +474,7 @@ fn main() {
     let semantic_index = SemanticIndex::new(
         fs.clone(),
         db_file_path,
-        Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
cx.background())), languages.clone(), cx.clone(), ) From 2b780ee7b2e63d990400c19c002eb1fc2f7bdfd7 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Sun, 22 Oct 2023 15:00:09 +0200 Subject: [PATCH 07/25] add base model to EmbeddingProvider, not yet leveraged for truncation --- crates/ai/src/embedding.rs | 3 ++ crates/ai/src/providers/dummy.rs | 35 +++++++++++++++++++ crates/ai/src/providers/open_ai/embedding.rs | 10 ++++++ crates/ai/src/providers/open_ai/model.rs | 1 + .../src/semantic_index_tests.rs | 10 ++++-- 5 files changed, 57 insertions(+), 2 deletions(-) diff --git a/crates/ai/src/embedding.rs b/crates/ai/src/embedding.rs index 05798c3f5d1a3ed2a88739fd3a5a911ed708d560..f792406c8b9367baf825af50b01f077f119b3860 100644 --- a/crates/ai/src/embedding.rs +++ b/crates/ai/src/embedding.rs @@ -5,6 +5,8 @@ use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef}; use rusqlite::ToSql; use std::time::Instant; +use crate::models::LanguageModel; + #[derive(Debug, PartialEq, Clone)] pub struct Embedding(pub Vec); @@ -66,6 +68,7 @@ impl Embedding { #[async_trait] pub trait EmbeddingProvider: Sync + Send { + fn base_model(&self) -> Box; fn is_authenticated(&self) -> bool; async fn embed_batch(&self, spans: Vec) -> Result>; fn max_tokens_per_batch(&self) -> usize; diff --git a/crates/ai/src/providers/dummy.rs b/crates/ai/src/providers/dummy.rs index 8061a2ca6b1c21d853db152fa431d9a2674c8740..9df5547da1fcdd0f0473f2b5371ff85a046b46a3 100644 --- a/crates/ai/src/providers/dummy.rs +++ b/crates/ai/src/providers/dummy.rs @@ -3,10 +3,42 @@ use std::time::Instant; use crate::{ completion::CompletionRequest, embedding::{Embedding, EmbeddingProvider}, + models::{LanguageModel, TruncationDirection}, }; use async_trait::async_trait; use serde::Serialize; +pub struct DummyLanguageModel {} + +impl LanguageModel for DummyLanguageModel { + fn name(&self) -> String { + "dummy".to_string() + } + fn capacity(&self) -> anyhow::Result { + anyhow::Ok(1000) + } + fn truncate( + &self, + content: &str, + length: usize, + direction: crate::models::TruncationDirection, + ) -> anyhow::Result { + let truncated = match direction { + TruncationDirection::End => content.chars().collect::>()[..length] + .iter() + .collect::(), + TruncationDirection::Start => content.chars().collect::>()[..length] + .iter() + .collect::(), + }; + + anyhow::Ok(truncated) + } + fn count_tokens(&self, content: &str) -> anyhow::Result { + anyhow::Ok(content.chars().collect::>().len()) + } +} + #[derive(Serialize)] pub struct DummyCompletionRequest { pub name: String, @@ -22,6 +54,9 @@ pub struct DummyEmbeddingProvider {} #[async_trait] impl EmbeddingProvider for DummyEmbeddingProvider { + fn base_model(&self) -> Box { + Box::new(DummyLanguageModel {}) + } fn is_authenticated(&self) -> bool { true } diff --git a/crates/ai/src/providers/open_ai/embedding.rs b/crates/ai/src/providers/open_ai/embedding.rs index 35398394dc8fdbef09f74dab5a82307d7b4b0aaf..ed028177f68d96bf84576c7dbcba7d4bb4888907 100644 --- a/crates/ai/src/providers/open_ai/embedding.rs +++ b/crates/ai/src/providers/open_ai/embedding.rs @@ -19,6 +19,8 @@ use tiktoken_rs::{cl100k_base, CoreBPE}; use util::http::{HttpClient, Request}; use crate::embedding::{Embedding, EmbeddingProvider}; +use crate::models::LanguageModel; +use crate::providers::open_ai::OpenAILanguageModel; lazy_static! { static ref OPENAI_API_KEY: Option = env::var("OPENAI_API_KEY").ok(); @@ -27,6 +29,7 @@ lazy_static! 
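// With `base_model()` on EmbeddingProvider (introduced in this commit),
// callers can budget tokens without knowing the concrete provider. An
// illustrative helper, not part of the patch:

use ai::embedding::EmbeddingProvider;
use ai::models::LanguageModel;

fn fits_in_one_batch(
    provider: &dyn EmbeddingProvider,
    spans: &[String],
) -> anyhow::Result<bool> {
    let model = provider.base_model();
    let mut total = 0;
    for span in spans {
        total += model.count_tokens(span)?;
    }
    Ok(total <= provider.max_tokens_per_batch())
}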
{ #[derive(Clone)] pub struct OpenAIEmbeddingProvider { + model: OpenAILanguageModel, pub client: Arc, pub executor: Arc, rate_limit_count_rx: watch::Receiver>, @@ -65,7 +68,10 @@ impl OpenAIEmbeddingProvider { let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None); let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx)); + let model = OpenAILanguageModel::load("text-embedding-ada-002"); + OpenAIEmbeddingProvider { + model, client, executor, rate_limit_count_rx, @@ -131,6 +137,10 @@ impl OpenAIEmbeddingProvider { #[async_trait] impl EmbeddingProvider for OpenAIEmbeddingProvider { + fn base_model(&self) -> Box { + let model: Box = Box::new(self.model.clone()); + model + } fn is_authenticated(&self) -> bool { OPENAI_API_KEY.as_ref().is_some() } diff --git a/crates/ai/src/providers/open_ai/model.rs b/crates/ai/src/providers/open_ai/model.rs index 42523f3df48951d8674b33409105f8d802fd6c25..6e306c80b905865c011c9064934827085ca126d6 100644 --- a/crates/ai/src/providers/open_ai/model.rs +++ b/crates/ai/src/providers/open_ai/model.rs @@ -4,6 +4,7 @@ use util::ResultExt; use crate::models::{LanguageModel, TruncationDirection}; +#[derive(Clone)] pub struct OpenAILanguageModel { name: String, bpe: Option, diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index 6842ce5c5d45711c45883c348028463e8429aa64..43779f5b6ccf23cf18fad232a6a4db2f33ce0b2c 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -4,8 +4,11 @@ use crate::{ semantic_index_settings::SemanticIndexSettings, FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT, }; -use ai::embedding::{Embedding, EmbeddingProvider}; -use ai::providers::dummy::DummyEmbeddingProvider; +use ai::providers::dummy::{DummyEmbeddingProvider, DummyLanguageModel}; +use ai::{ + embedding::{Embedding, EmbeddingProvider}, + models::LanguageModel, +}; use anyhow::Result; use async_trait::async_trait; use gpui::{executor::Deterministic, Task, TestAppContext}; @@ -1282,6 +1285,9 @@ impl FakeEmbeddingProvider { #[async_trait] impl EmbeddingProvider for FakeEmbeddingProvider { + fn base_model(&self) -> Box { + Box::new(DummyLanguageModel {}) + } fn is_authenticated(&self) -> bool { true } From 4e90e4599973b2016370b99a5406bb1a49ca21f4 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Mon, 23 Oct 2023 14:07:45 +0200 Subject: [PATCH 08/25] move embedding truncation to base model --- crates/ai/src/embedding.rs | 1 - crates/ai/src/providers/dummy.rs | 11 +++---- crates/ai/src/providers/open_ai/embedding.rs | 30 ++++++++--------- crates/semantic_index/src/parsing.rs | 33 ++++++++++++++++--- .../src/semantic_index_tests.rs | 9 ++--- 5 files changed, 49 insertions(+), 35 deletions(-) diff --git a/crates/ai/src/embedding.rs b/crates/ai/src/embedding.rs index f792406c8b9367baf825af50b01f077f119b3860..4e67f44cae72b9b7778e05895964f42f8f78e535 100644 --- a/crates/ai/src/embedding.rs +++ b/crates/ai/src/embedding.rs @@ -72,7 +72,6 @@ pub trait EmbeddingProvider: Sync + Send { fn is_authenticated(&self) -> bool; async fn embed_batch(&self, spans: Vec) -> Result>; fn max_tokens_per_batch(&self) -> usize; - fn truncate(&self, span: &str) -> (String, usize); fn rate_limit_expiration(&self) -> Option; } diff --git a/crates/ai/src/providers/dummy.rs b/crates/ai/src/providers/dummy.rs index 9df5547da1fcdd0f0473f2b5371ff85a046b46a3..7eef16111d9789975128dee1e2183094908e84f2 100644 --- 
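// Note: `DummyLanguageModel::truncate` above slices `[..length]` for both
// directions; elsewhere in this series the Start direction drops the
// first `length` items instead (see OpenAILanguageModel). A direction-aware
// character truncation consistent with that convention (sketch):

use ai::models::TruncationDirection;

fn truncate_chars(content: &str, length: usize, direction: TruncationDirection) -> String {
    let chars: Vec<char> = content.chars().collect();
    if chars.len() <= length {
        return content.to_string();
    }
    match direction {
        // End: keep the first `length` characters.
        TruncationDirection::End => chars[..length].iter().collect(),
        // Start: drop the first `length` characters, keep the rest.
        TruncationDirection::Start => chars[length..].iter().collect(),
    }
}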
a/crates/ai/src/providers/dummy.rs +++ b/crates/ai/src/providers/dummy.rs @@ -23,6 +23,10 @@ impl LanguageModel for DummyLanguageModel { length: usize, direction: crate::models::TruncationDirection, ) -> anyhow::Result { + if content.len() < length { + return anyhow::Ok(content.to_string()); + } + let truncated = match direction { TruncationDirection::End => content.chars().collect::>()[..length] .iter() @@ -73,11 +77,4 @@ impl EmbeddingProvider for DummyEmbeddingProvider { fn max_tokens_per_batch(&self) -> usize { 8190 } - - fn truncate(&self, span: &str) -> (String, usize) { - let truncated = span.chars().collect::>()[..8190] - .iter() - .collect::(); - (truncated, 8190) - } } diff --git a/crates/ai/src/providers/open_ai/embedding.rs b/crates/ai/src/providers/open_ai/embedding.rs index ed028177f68d96bf84576c7dbcba7d4bb4888907..3689cb36f41d34ca51d39478eba14ebef21b5c00 100644 --- a/crates/ai/src/providers/open_ai/embedding.rs +++ b/crates/ai/src/providers/open_ai/embedding.rs @@ -61,8 +61,6 @@ struct OpenAIEmbeddingUsage { total_tokens: usize, } -const OPENAI_INPUT_LIMIT: usize = 8190; - impl OpenAIEmbeddingProvider { pub fn new(client: Arc, executor: Arc) -> Self { let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None); @@ -151,20 +149,20 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider { fn rate_limit_expiration(&self) -> Option { *self.rate_limit_count_rx.borrow() } - fn truncate(&self, span: &str) -> (String, usize) { - let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); - let output = if tokens.len() > OPENAI_INPUT_LIMIT { - tokens.truncate(OPENAI_INPUT_LIMIT); - OPENAI_BPE_TOKENIZER - .decode(tokens.clone()) - .ok() - .unwrap_or_else(|| span.to_string()) - } else { - span.to_string() - }; - - (output, tokens.len()) - } + // fn truncate(&self, span: &str) -> (String, usize) { + // let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); + // let output = if tokens.len() > OPENAI_INPUT_LIMIT { + // tokens.truncate(OPENAI_INPUT_LIMIT); + // OPENAI_BPE_TOKENIZER + // .decode(tokens.clone()) + // .ok() + // .unwrap_or_else(|| span.to_string()) + // } else { + // span.to_string() + // }; + + // (output, tokens.len()) + // } async fn embed_batch(&self, spans: Vec) -> Result> { const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; diff --git a/crates/semantic_index/src/parsing.rs b/crates/semantic_index/src/parsing.rs index f9b8bac9a484bfae48c62683ee096e2e49420622..cb15ca453b2c0640739bd44a95482ca527b8d91b 100644 --- a/crates/semantic_index/src/parsing.rs +++ b/crates/semantic_index/src/parsing.rs @@ -1,4 +1,7 @@ -use ai::embedding::{Embedding, EmbeddingProvider}; +use ai::{ + embedding::{Embedding, EmbeddingProvider}, + models::TruncationDirection, +}; use anyhow::{anyhow, Result}; use language::{Grammar, Language}; use rusqlite::{ @@ -108,7 +111,14 @@ impl CodeContextRetriever { .replace("", language_name.as_ref()) .replace("", &content); let digest = SpanDigest::from(document_span.as_str()); - let (document_span, token_count) = self.embedding_provider.truncate(&document_span); + let model = self.embedding_provider.base_model(); + let document_span = model.truncate( + &document_span, + model.capacity()?, + ai::models::TruncationDirection::End, + )?; + let token_count = model.count_tokens(&document_span)?; + Ok(vec![Span { range: 0..content.len(), content: document_span, @@ -131,7 +141,15 @@ impl CodeContextRetriever { ) .replace("", &content); let digest = SpanDigest::from(document_span.as_str()); - let (document_span, token_count) = 
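// The pattern this commit introduces in parsing.rs — cap a document at the
// base model's context size, then record the post-truncation token count —
// in isolation (sketch; `model` is any `LanguageModel`):

use ai::models::{LanguageModel, TruncationDirection};

fn fit_to_model(model: &dyn LanguageModel, document: &str) -> anyhow::Result<(String, usize)> {
    let capacity = model.capacity()?;
    let truncated = model.truncate(document, capacity, TruncationDirection::End)?;
    let token_count = model.count_tokens(&truncated)?;
    Ok((truncated, token_count))
}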
self.embedding_provider.truncate(&document_span); + + let model = self.embedding_provider.base_model(); + let document_span = model.truncate( + &document_span, + model.capacity()?, + ai::models::TruncationDirection::End, + )?; + let token_count = model.count_tokens(&document_span)?; + Ok(vec![Span { range: 0..content.len(), content: document_span, @@ -222,8 +240,13 @@ impl CodeContextRetriever { .replace("", language_name.as_ref()) .replace("item", &span.content); - let (document_content, token_count) = - self.embedding_provider.truncate(&document_content); + let model = self.embedding_provider.base_model(); + let document_content = model.truncate( + &document_content, + model.capacity()?, + TruncationDirection::End, + )?; + let token_count = model.count_tokens(&document_content)?; span.content = document_content; span.token_count = token_count; diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index 43779f5b6ccf23cf18fad232a6a4db2f33ce0b2c..002dee33e33c9a253ad9c2baf51c6c4dcdb6f2a4 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -1291,12 +1291,8 @@ impl EmbeddingProvider for FakeEmbeddingProvider { fn is_authenticated(&self) -> bool { true } - fn truncate(&self, span: &str) -> (String, usize) { - (span.to_string(), 1) - } - fn max_tokens_per_batch(&self) -> usize { - 200 + 1000 } fn rate_limit_expiration(&self) -> Option { @@ -1306,7 +1302,8 @@ impl EmbeddingProvider for FakeEmbeddingProvider { async fn embed_batch(&self, spans: Vec) -> Result> { self.embedding_count .fetch_add(spans.len(), atomic::Ordering::SeqCst); - Ok(spans.iter().map(|span| self.embed_sync(span)).collect()) + + anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect()) } } From 3447a9478c62476728f1e0131d708699dad2bcd1 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Thu, 26 Oct 2023 11:18:16 +0200 Subject: [PATCH 09/25] updated authentication for embedding provider --- crates/ai/Cargo.toml | 3 + crates/ai/src/ai.rs | 3 + crates/ai/src/auth.rs | 20 +++ crates/ai/src/embedding.rs | 8 +- crates/ai/src/prompts/base.rs | 41 +----- crates/ai/src/providers/dummy.rs | 85 ------------ crates/ai/src/providers/mod.rs | 1 - crates/ai/src/providers/open_ai/auth.rs | 33 +++++ crates/ai/src/providers/open_ai/embedding.rs | 46 ++----- crates/ai/src/providers/open_ai/mod.rs | 1 + crates/ai/src/test.rs | 123 ++++++++++++++++++ crates/assistant/src/codegen.rs | 14 +- crates/semantic_index/Cargo.toml | 1 + crates/semantic_index/src/embedding_queue.rs | 16 +-- crates/semantic_index/src/semantic_index.rs | 52 +++++--- .../src/semantic_index_tests.rs | 101 +++----------- 16 files changed, 277 insertions(+), 271 deletions(-) create mode 100644 crates/ai/src/auth.rs delete mode 100644 crates/ai/src/providers/dummy.rs create mode 100644 crates/ai/src/providers/open_ai/auth.rs create mode 100644 crates/ai/src/test.rs diff --git a/crates/ai/Cargo.toml b/crates/ai/Cargo.toml index b24c4e5ece5b02eac003a6c18f186faa1eaef7ef..fb49a4b515540836a757610db5c268321f9f068b 100644 --- a/crates/ai/Cargo.toml +++ b/crates/ai/Cargo.toml @@ -8,6 +8,9 @@ publish = false path = "src/ai.rs" doctest = false +[features] +test-support = [] + [dependencies] gpui = { path = "../gpui" } util = { path = "../util" } diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs index a3ae2fcf7ffb5075c70b607f1cdf34279cd063a3..dda22d2a1d04dd6083fb1ae9879f49e74c8b4627 100644 --- a/crates/ai/src/ai.rs +++ b/crates/ai/src/ai.rs @@ -1,5 +1,8 
@@
+pub mod auth;
 pub mod completion;
 pub mod embedding;
 pub mod models;
 pub mod prompts;
 pub mod providers;
+#[cfg(any(test, feature = "test-support"))]
+pub mod test;
diff --git a/crates/ai/src/auth.rs b/crates/ai/src/auth.rs
new file mode 100644
index 0000000000000000000000000000000000000000..a3ce8aece1f97131183d193447bd9e0b848795cc
--- /dev/null
+++ b/crates/ai/src/auth.rs
@@ -0,0 +1,20 @@
+use gpui::AppContext;
+
+#[derive(Clone)]
+pub enum ProviderCredential {
+    Credentials { api_key: String },
+    NoCredentials,
+    NotNeeded,
+}
+
+pub trait CredentialProvider: Send + Sync {
+    fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential;
+}
+
+#[derive(Clone)]
+pub struct NullCredentialProvider;
+impl CredentialProvider for NullCredentialProvider {
+    fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
+        ProviderCredential::NotNeeded
+    }
+}
diff --git a/crates/ai/src/embedding.rs b/crates/ai/src/embedding.rs
index 8cfc901525978307da15e5db47512d7ce26cf597..50f04232ab64f4f16571efb9c11fd4f1bccf89e2 100644
--- a/crates/ai/src/embedding.rs
+++ b/crates/ai/src/embedding.rs
@@ -7,6 +7,7 @@
 use ordered_float::OrderedFloat;
 use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
 use rusqlite::ToSql;
 
+use crate::auth::{CredentialProvider, ProviderCredential};
 use crate::models::LanguageModel;
 
 #[derive(Debug, PartialEq, Clone)]
@@ -71,11 +72,14 @@ impl Embedding {
 #[async_trait]
 pub trait EmbeddingProvider: Sync + Send {
     fn base_model(&self) -> Box<dyn LanguageModel>;
-    fn retrieve_credentials(&self, cx: &AppContext) -> Option<String>;
+    fn credential_provider(&self) -> Box<dyn CredentialProvider>;
+    fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
+        self.credential_provider().retrieve_credentials(cx)
+    }
     async fn embed_batch(
         &self,
         spans: Vec<String>,
-        api_key: Option<String>,
+        credential: ProviderCredential,
     ) -> Result<Vec<Embedding>>;
     fn max_tokens_per_batch(&self) -> usize;
     fn rate_limit_expiration(&self) -> Option<Instant>;
diff --git a/crates/ai/src/prompts/base.rs b/crates/ai/src/prompts/base.rs
index f0ff597e635702329964fe18efc4b2300c4c26c6..a2106c74106b0b976cc70717b4535ccb557573bf 100644
--- a/crates/ai/src/prompts/base.rs
+++ b/crates/ai/src/prompts/base.rs
@@ -126,6 +126,7 @@ impl PromptChain {
 #[cfg(test)]
 pub(crate) mod tests {
     use crate::models::TruncationDirection;
+    use crate::test::FakeLanguageModel;
 
     use super::*;
 
@@ -181,39 +182,7 @@ pub(crate) mod tests {
             }
         }
 
-        #[derive(Clone)]
-        struct DummyLanguageModel {
-            capacity: usize,
-        }
-
-        impl LanguageModel for DummyLanguageModel {
-            fn name(&self) -> String {
-                "dummy".to_string()
-            }
-            fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
-                anyhow::Ok(content.chars().collect::<Vec<char>>().len())
-            }
-            fn truncate(
-                &self,
-                content: &str,
-                length: usize,
-                direction: TruncationDirection,
-            ) -> anyhow::Result<String> {
-                anyhow::Ok(match direction {
-                    TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
-                        .into_iter()
-                        .collect::<String>(),
-                    TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
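// A custom implementation of the new `CredentialProvider` trait above,
// reading a hypothetical MY_PROVIDER_API_KEY environment variable
// (a sketch, not part of the patch):

use gpui::AppContext;
use ai::auth::{CredentialProvider, ProviderCredential};

#[derive(Clone)]
struct EnvCredentialProvider;

impl CredentialProvider for EnvCredentialProvider {
    fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
        match std::env::var("MY_PROVIDER_API_KEY") {
            Ok(api_key) => ProviderCredential::Credentials { api_key },
            Err(_) => ProviderCredential::NoCredentials,
        }
    }
}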
- .into_iter() - .collect::(), - }) - } - fn capacity(&self) -> anyhow::Result { - anyhow::Ok(self.capacity) - } - } - - let model: Arc = Arc::new(DummyLanguageModel { capacity: 100 }); + let model: Arc = Arc::new(FakeLanguageModel { capacity: 100 }); let args = PromptArguments { model: model.clone(), language_name: None, @@ -249,7 +218,7 @@ pub(crate) mod tests { // Testing with Truncation Off // Should ignore capacity and return all prompts - let model: Arc = Arc::new(DummyLanguageModel { capacity: 20 }); + let model: Arc = Arc::new(FakeLanguageModel { capacity: 20 }); let args = PromptArguments { model: model.clone(), language_name: None, @@ -286,7 +255,7 @@ pub(crate) mod tests { // Testing with Truncation Off // Should ignore capacity and return all prompts let capacity = 20; - let model: Arc = Arc::new(DummyLanguageModel { capacity }); + let model: Arc = Arc::new(FakeLanguageModel { capacity }); let args = PromptArguments { model: model.clone(), language_name: None, @@ -322,7 +291,7 @@ pub(crate) mod tests { // Change Ordering of Prompts Based on Priority let capacity = 120; let reserved_tokens = 10; - let model: Arc = Arc::new(DummyLanguageModel { capacity }); + let model: Arc = Arc::new(FakeLanguageModel { capacity }); let args = PromptArguments { model: model.clone(), language_name: None, diff --git a/crates/ai/src/providers/dummy.rs b/crates/ai/src/providers/dummy.rs deleted file mode 100644 index 2ee26488bd99a46d16459e1d215a23adcb4d2e71..0000000000000000000000000000000000000000 --- a/crates/ai/src/providers/dummy.rs +++ /dev/null @@ -1,85 +0,0 @@ -use std::time::Instant; - -use crate::{ - completion::CompletionRequest, - embedding::{Embedding, EmbeddingProvider}, - models::{LanguageModel, TruncationDirection}, -}; -use async_trait::async_trait; -use gpui::AppContext; -use serde::Serialize; - -pub struct DummyLanguageModel {} - -impl LanguageModel for DummyLanguageModel { - fn name(&self) -> String { - "dummy".to_string() - } - fn capacity(&self) -> anyhow::Result { - anyhow::Ok(1000) - } - fn truncate( - &self, - content: &str, - length: usize, - direction: crate::models::TruncationDirection, - ) -> anyhow::Result { - if content.len() < length { - return anyhow::Ok(content.to_string()); - } - - let truncated = match direction { - TruncationDirection::End => content.chars().collect::>()[..length] - .iter() - .collect::(), - TruncationDirection::Start => content.chars().collect::>()[..length] - .iter() - .collect::(), - }; - - anyhow::Ok(truncated) - } - fn count_tokens(&self, content: &str) -> anyhow::Result { - anyhow::Ok(content.chars().collect::>().len()) - } -} - -#[derive(Serialize)] -pub struct DummyCompletionRequest { - pub name: String, -} - -impl CompletionRequest for DummyCompletionRequest { - fn data(&self) -> serde_json::Result { - serde_json::to_string(self) - } -} - -pub struct DummyEmbeddingProvider {} - -#[async_trait] -impl EmbeddingProvider for DummyEmbeddingProvider { - fn retrieve_credentials(&self, _cx: &AppContext) -> Option { - Some("Dummy Credentials".to_string()) - } - fn base_model(&self) -> Box { - Box::new(DummyLanguageModel {}) - } - fn rate_limit_expiration(&self) -> Option { - None - } - async fn embed_batch( - &self, - spans: Vec, - api_key: Option, - ) -> anyhow::Result> { - // 1024 is the OpenAI Embeddings size for ada models. - // the model we will likely be starting with. 
- let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]); - return Ok(vec![dummy_vec; spans.len()]); - } - - fn max_tokens_per_batch(&self) -> usize { - 8190 - } -} diff --git a/crates/ai/src/providers/mod.rs b/crates/ai/src/providers/mod.rs index 7a7092baf39d40da8b09ffb8085463a4a8b2efdf..acd0f9d91053869e3999ef0c1a23326480a7cbdd 100644 --- a/crates/ai/src/providers/mod.rs +++ b/crates/ai/src/providers/mod.rs @@ -1,2 +1 @@ -pub mod dummy; pub mod open_ai; diff --git a/crates/ai/src/providers/open_ai/auth.rs b/crates/ai/src/providers/open_ai/auth.rs new file mode 100644 index 0000000000000000000000000000000000000000..c817ffea0056424397dcd66d32a2e4e0548302b3 --- /dev/null +++ b/crates/ai/src/providers/open_ai/auth.rs @@ -0,0 +1,33 @@ +use std::env; + +use gpui::AppContext; +use util::ResultExt; + +use crate::auth::{CredentialProvider, ProviderCredential}; +use crate::providers::open_ai::OPENAI_API_URL; + +#[derive(Clone)] +pub struct OpenAICredentialProvider {} + +impl CredentialProvider for OpenAICredentialProvider { + fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential { + let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") { + Some(api_key) + } else if let Some((_, api_key)) = cx + .platform() + .read_credentials(OPENAI_API_URL) + .log_err() + .flatten() + { + String::from_utf8(api_key).log_err() + } else { + None + }; + + if let Some(api_key) = api_key { + ProviderCredential::Credentials { api_key } + } else { + ProviderCredential::NoCredentials + } + } +} diff --git a/crates/ai/src/providers/open_ai/embedding.rs b/crates/ai/src/providers/open_ai/embedding.rs index 805a906dda97ea356f73b2278ca5992892b6c206..1385b32b4da6762b7dc6dac29eb0db8de692a6b3 100644 --- a/crates/ai/src/providers/open_ai/embedding.rs +++ b/crates/ai/src/providers/open_ai/embedding.rs @@ -2,7 +2,7 @@ use anyhow::{anyhow, Result}; use async_trait::async_trait; use futures::AsyncReadExt; use gpui::executor::Background; -use gpui::{serde_json, AppContext}; +use gpui::serde_json; use isahc::http::StatusCode; use isahc::prelude::Configurable; use isahc::{AsyncBody, Response}; @@ -17,13 +17,13 @@ use std::sync::Arc; use std::time::{Duration, Instant}; use tiktoken_rs::{cl100k_base, CoreBPE}; use util::http::{HttpClient, Request}; -use util::ResultExt; +use crate::auth::{CredentialProvider, ProviderCredential}; use crate::embedding::{Embedding, EmbeddingProvider}; use crate::models::LanguageModel; use crate::providers::open_ai::OpenAILanguageModel; -use super::OPENAI_API_URL; +use crate::providers::open_ai::auth::OpenAICredentialProvider; lazy_static! { static ref OPENAI_API_KEY: Option = env::var("OPENAI_API_KEY").ok(); @@ -33,6 +33,7 @@ lazy_static! 
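// The lookup order implemented by `OpenAICredentialProvider` above,
// restated as a plain function. `read_keychain` is a hypothetical
// stand-in for `cx.platform().read_credentials` (sketch):

fn resolve_api_key(read_keychain: impl Fn(&str) -> Option<Vec<u8>>) -> Option<String> {
    // 1) An explicit environment variable wins.
    if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
        return Some(api_key);
    }
    // 2) Otherwise fall back to credentials stored for the API URL.
    read_keychain("https://api.openai.com/v1").and_then(|bytes| String::from_utf8(bytes).ok())
}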
{ #[derive(Clone)] pub struct OpenAIEmbeddingProvider { model: OpenAILanguageModel, + credential_provider: OpenAICredentialProvider, pub client: Arc, pub executor: Arc, rate_limit_count_rx: watch::Receiver>, @@ -73,6 +74,7 @@ impl OpenAIEmbeddingProvider { OpenAIEmbeddingProvider { model, + credential_provider: OpenAICredentialProvider {}, client, executor, rate_limit_count_rx, @@ -138,25 +140,17 @@ impl OpenAIEmbeddingProvider { #[async_trait] impl EmbeddingProvider for OpenAIEmbeddingProvider { - fn retrieve_credentials(&self, cx: &AppContext) -> Option { - let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") { - Some(api_key) - } else if let Some((_, api_key)) = cx - .platform() - .read_credentials(OPENAI_API_URL) - .log_err() - .flatten() - { - String::from_utf8(api_key).log_err() - } else { - None - }; - api_key - } fn base_model(&self) -> Box { let model: Box = Box::new(self.model.clone()); model } + + fn credential_provider(&self) -> Box { + let credential_provider: Box = + Box::new(self.credential_provider.clone()); + credential_provider + } + fn max_tokens_per_batch(&self) -> usize { 50000 } @@ -164,25 +158,11 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider { fn rate_limit_expiration(&self) -> Option { *self.rate_limit_count_rx.borrow() } - // fn truncate(&self, span: &str) -> (String, usize) { - // let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); - // let output = if tokens.len() > OPENAI_INPUT_LIMIT { - // tokens.truncate(OPENAI_INPUT_LIMIT); - // OPENAI_BPE_TOKENIZER - // .decode(tokens.clone()) - // .ok() - // .unwrap_or_else(|| span.to_string()) - // } else { - // span.to_string() - // }; - - // (output, tokens.len()) - // } async fn embed_batch( &self, spans: Vec, - api_key: Option, + _credential: ProviderCredential, ) -> Result> { const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; const MAX_RETRIES: usize = 4; diff --git a/crates/ai/src/providers/open_ai/mod.rs b/crates/ai/src/providers/open_ai/mod.rs index 67cb2b53151e93560ea38429ef17dc2b57dc3e86..49e29fbc8c184e9bc39e923563778e3678ed7c9e 100644 --- a/crates/ai/src/providers/open_ai/mod.rs +++ b/crates/ai/src/providers/open_ai/mod.rs @@ -1,3 +1,4 @@ +pub mod auth; pub mod completion; pub mod embedding; pub mod model; diff --git a/crates/ai/src/test.rs b/crates/ai/src/test.rs new file mode 100644 index 0000000000000000000000000000000000000000..d8805bad1a6903064d92075bbf68f4dc06928da5 --- /dev/null +++ b/crates/ai/src/test.rs @@ -0,0 +1,123 @@ +use std::{ + sync::atomic::{self, AtomicUsize, Ordering}, + time::Instant, +}; + +use async_trait::async_trait; + +use crate::{ + auth::{CredentialProvider, NullCredentialProvider, ProviderCredential}, + embedding::{Embedding, EmbeddingProvider}, + models::{LanguageModel, TruncationDirection}, +}; + +#[derive(Clone)] +pub struct FakeLanguageModel { + pub capacity: usize, +} + +impl LanguageModel for FakeLanguageModel { + fn name(&self) -> String { + "dummy".to_string() + } + fn count_tokens(&self, content: &str) -> anyhow::Result { + anyhow::Ok(content.chars().collect::>().len()) + } + fn truncate( + &self, + content: &str, + length: usize, + direction: TruncationDirection, + ) -> anyhow::Result { + anyhow::Ok(match direction { + TruncationDirection::End => content.chars().collect::>()[..length] + .into_iter() + .collect::(), + TruncationDirection::Start => content.chars().collect::>()[length..] 
+ .into_iter() + .collect::(), + }) + } + fn capacity(&self) -> anyhow::Result { + anyhow::Ok(self.capacity) + } +} + +pub struct FakeEmbeddingProvider { + pub embedding_count: AtomicUsize, + pub credential_provider: NullCredentialProvider, +} + +impl Clone for FakeEmbeddingProvider { + fn clone(&self) -> Self { + FakeEmbeddingProvider { + embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)), + credential_provider: self.credential_provider.clone(), + } + } +} + +impl Default for FakeEmbeddingProvider { + fn default() -> Self { + FakeEmbeddingProvider { + embedding_count: AtomicUsize::default(), + credential_provider: NullCredentialProvider {}, + } + } +} + +impl FakeEmbeddingProvider { + pub fn embedding_count(&self) -> usize { + self.embedding_count.load(atomic::Ordering::SeqCst) + } + + pub fn embed_sync(&self, span: &str) -> Embedding { + let mut result = vec![1.0; 26]; + for letter in span.chars() { + let letter = letter.to_ascii_lowercase(); + if letter as u32 >= 'a' as u32 { + let ix = (letter as u32) - ('a' as u32); + if ix < 26 { + result[ix as usize] += 1.0; + } + } + } + + let norm = result.iter().map(|x| x * x).sum::().sqrt(); + for x in &mut result { + *x /= norm; + } + + result.into() + } +} + +#[async_trait] +impl EmbeddingProvider for FakeEmbeddingProvider { + fn base_model(&self) -> Box { + Box::new(FakeLanguageModel { capacity: 1000 }) + } + fn credential_provider(&self) -> Box { + let credential_provider: Box = + Box::new(self.credential_provider.clone()); + credential_provider + } + fn max_tokens_per_batch(&self) -> usize { + 1000 + } + + fn rate_limit_expiration(&self) -> Option { + None + } + + async fn embed_batch( + &self, + spans: Vec, + _credential: ProviderCredential, + ) -> anyhow::Result> { + self.embedding_count + .fetch_add(spans.len(), atomic::Ordering::SeqCst); + + anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect()) + } +} diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs index e535eca1441b4c8137d043c304a68cb2866e38fc..e71b1ae2cb36aa85eea932a1fc6fefc634e50c09 100644 --- a/crates/assistant/src/codegen.rs +++ b/crates/assistant/src/codegen.rs @@ -335,7 +335,6 @@ fn strip_markdown_codeblock( #[cfg(test)] mod tests { use super::*; - use ai::providers::dummy::DummyCompletionRequest; use futures::{ future::BoxFuture, stream::{self, BoxStream}, @@ -345,9 +344,21 @@ mod tests { use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point}; use parking_lot::Mutex; use rand::prelude::*; + use serde::Serialize; use settings::SettingsStore; use smol::future::FutureExt; + #[derive(Serialize)] + pub struct DummyCompletionRequest { + pub name: String, + } + + impl CompletionRequest for DummyCompletionRequest { + fn data(&self) -> serde_json::Result { + serde_json::to_string(self) + } + } + #[gpui::test(iterations = 10)] async fn test_transform_autoindent( cx: &mut TestAppContext, @@ -381,6 +392,7 @@ mod tests { cx, ) }); + let request = Box::new(DummyCompletionRequest { name: "test".to_string(), }); diff --git a/crates/semantic_index/Cargo.toml b/crates/semantic_index/Cargo.toml index 1febb2af786e205e19764e90ca1c78d954d8a1bb..875440ef3fa866734c734830ff5c7b95550fd33c 100644 --- a/crates/semantic_index/Cargo.toml +++ b/crates/semantic_index/Cargo.toml @@ -42,6 +42,7 @@ sha1 = "0.10.5" ndarray = { version = "0.15.0" } [dev-dependencies] +ai = { path = "../ai", features = ["test-support"] } collections = { path = "../collections", features = ["test-support"] } gpui = { path = 
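// `embed_sync` above L2-normalizes a letter-frequency vector, so the dot
// product of two fake embeddings behaves like cosine similarity — which is
// what similarity search over these test embeddings relies on. A
// self-contained check (sketch mirroring `embed_sync`):

fn letter_embedding(span: &str) -> Vec<f32> {
    // 26 letter counts with a 1.0 baseline, L2-normalized.
    let mut result = vec![1.0_f32; 26];
    for letter in span.chars().filter(|c| c.is_ascii_alphabetic()) {
        let ix = (letter.to_ascii_lowercase() as u32 - 'a' as u32) as usize;
        result[ix] += 1.0;
    }
    let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
    result.iter().map(|x| x / norm).collect()
}

fn main() {
    let dot = |a: &[f32], b: &[f32]| a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    let (a, b, c) = (
        letter_embedding("fn main() {}"),
        letter_embedding("fn main() {}"),
        letter_embedding("zzzzzz"),
    );
    assert!((dot(&a, &b) - 1.0).abs() < 1e-5); // identical text: similarity ~1.0
    assert!(dot(&a, &c) < dot(&a, &b)); // unrelated text scores lower
}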
"../gpui", features = ["test-support"] } language = { path = "../language", features = ["test-support"] } diff --git a/crates/semantic_index/src/embedding_queue.rs b/crates/semantic_index/src/embedding_queue.rs index d57d5c7bbe43130d8df135116adeaa76d0a4e357..9ca6d8a0d945a4eae80010a40f31f1a1c39d3d8f 100644 --- a/crates/semantic_index/src/embedding_queue.rs +++ b/crates/semantic_index/src/embedding_queue.rs @@ -1,5 +1,5 @@ use crate::{parsing::Span, JobHandle}; -use ai::embedding::EmbeddingProvider; +use ai::{auth::ProviderCredential, embedding::EmbeddingProvider}; use gpui::executor::Background; use parking_lot::Mutex; use smol::channel; @@ -41,7 +41,7 @@ pub struct EmbeddingQueue { pending_batch_token_count: usize, finished_files_tx: channel::Sender, finished_files_rx: channel::Receiver, - api_key: Option, + provider_credential: ProviderCredential, } #[derive(Clone)] @@ -54,7 +54,7 @@ impl EmbeddingQueue { pub fn new( embedding_provider: Arc, executor: Arc, - api_key: Option, + provider_credential: ProviderCredential, ) -> Self { let (finished_files_tx, finished_files_rx) = channel::unbounded(); Self { @@ -64,12 +64,12 @@ impl EmbeddingQueue { pending_batch_token_count: 0, finished_files_tx, finished_files_rx, - api_key, + provider_credential, } } - pub fn set_api_key(&mut self, api_key: Option) { - self.api_key = api_key + pub fn set_credential(&mut self, credential: ProviderCredential) { + self.provider_credential = credential } pub fn push(&mut self, file: FileToEmbed) { @@ -118,7 +118,7 @@ impl EmbeddingQueue { let finished_files_tx = self.finished_files_tx.clone(); let embedding_provider = self.embedding_provider.clone(); - let api_key = self.api_key.clone(); + let credential = self.provider_credential.clone(); self.executor .spawn(async move { @@ -143,7 +143,7 @@ impl EmbeddingQueue { return; }; - match embedding_provider.embed_batch(spans, api_key).await { + match embedding_provider.embed_batch(spans, credential).await { Ok(embeddings) => { let mut embeddings = embeddings.into_iter(); for fragment in batch { diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 6863918d5d545894dd8858134ea1f874ccadcc3a..5be3d6ccf5302d94a3c4d8a12674f3c6e17421d0 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -7,6 +7,7 @@ pub mod semantic_index_settings; mod semantic_index_tests; use crate::semantic_index_settings::SemanticIndexSettings; +use ai::auth::ProviderCredential; use ai::embedding::{Embedding, EmbeddingProvider}; use ai::providers::open_ai::OpenAIEmbeddingProvider; use anyhow::{anyhow, Result}; @@ -124,7 +125,7 @@ pub struct SemanticIndex { _embedding_task: Task<()>, _parsing_files_tasks: Vec>, projects: HashMap, ProjectState>, - api_key: Option, + provider_credential: ProviderCredential, embedding_queue: Arc>, } @@ -279,18 +280,27 @@ impl SemanticIndex { } } - pub fn authenticate(&mut self, cx: &AppContext) { - if self.api_key.is_none() { - self.api_key = self.embedding_provider.retrieve_credentials(cx); - - self.embedding_queue - .lock() - .set_api_key(self.api_key.clone()); + pub fn authenticate(&mut self, cx: &AppContext) -> bool { + let credential = self.provider_credential.clone(); + match credential { + ProviderCredential::NoCredentials => { + let credential = self.embedding_provider.retrieve_credentials(cx); + self.provider_credential = credential; + } + _ => {} } + + self.embedding_queue.lock().set_credential(credential); + + self.is_authenticated() } pub fn 
is_authenticated(&self) -> bool { - self.api_key.is_some() + let credential = &self.provider_credential; + match credential { + &ProviderCredential::Credentials { .. } => true, + _ => false, + } } pub fn enabled(cx: &AppContext) -> bool { @@ -340,7 +350,7 @@ impl SemanticIndex { Ok(cx.add_model(|cx| { let t0 = Instant::now(); let embedding_queue = - EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), None); + EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), ProviderCredential::NoCredentials); let _embedding_task = cx.background().spawn({ let embedded_files = embedding_queue.finished_files(); let db = db.clone(); @@ -405,7 +415,7 @@ impl SemanticIndex { _embedding_task, _parsing_files_tasks, projects: Default::default(), - api_key: None, + provider_credential: ProviderCredential::NoCredentials, embedding_queue } })) @@ -721,13 +731,14 @@ impl SemanticIndex { let index = self.index_project(project.clone(), cx); let embedding_provider = self.embedding_provider.clone(); - let api_key = self.api_key.clone(); + let credential = self.provider_credential.clone(); cx.spawn(|this, mut cx| async move { index.await?; let t0 = Instant::now(); + let query = embedding_provider - .embed_batch(vec![query], api_key) + .embed_batch(vec![query], credential) .await? .pop() .ok_or_else(|| anyhow!("could not embed query"))?; @@ -945,7 +956,7 @@ impl SemanticIndex { let fs = self.fs.clone(); let db_path = self.db.path().clone(); let background = cx.background().clone(); - let api_key = self.api_key.clone(); + let credential = self.provider_credential.clone(); cx.background().spawn(async move { let db = VectorDatabase::new(fs, db_path.clone(), background).await?; let mut results = Vec::::new(); @@ -964,7 +975,7 @@ impl SemanticIndex { &mut spans, embedding_provider.as_ref(), &db, - api_key.clone(), + credential.clone(), ) .await .log_err() @@ -1008,9 +1019,8 @@ impl SemanticIndex { project: ModelHandle, cx: &mut ModelContext, ) -> Task> { - if self.api_key.is_none() { - self.authenticate(cx); - if self.api_key.is_none() { + if !self.is_authenticated() { + if !self.authenticate(cx) { return Task::ready(Err(anyhow!("user is not authenticated"))); } } @@ -1193,7 +1203,7 @@ impl SemanticIndex { spans: &mut [Span], embedding_provider: &dyn EmbeddingProvider, db: &VectorDatabase, - api_key: Option, + credential: ProviderCredential, ) -> Result<()> { let mut batch = Vec::new(); let mut batch_tokens = 0; @@ -1216,7 +1226,7 @@ impl SemanticIndex { if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() { let batch_embeddings = embedding_provider - .embed_batch(mem::take(&mut batch), api_key.clone()) + .embed_batch(mem::take(&mut batch), credential.clone()) .await?; embeddings.extend(batch_embeddings); batch_tokens = 0; @@ -1228,7 +1238,7 @@ impl SemanticIndex { if !batch.is_empty() { let batch_embeddings = embedding_provider - .embed_batch(mem::take(&mut batch), api_key) + .embed_batch(mem::take(&mut batch), credential) .await?; embeddings.extend(batch_embeddings); diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index 1c117c9ea20cecd1ea389064a67d6ae08c7a22aa..7d5a4e22e80530b4d22314beb12582bb300ff406 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -4,14 +4,9 @@ use crate::{ semantic_index_settings::SemanticIndexSettings, FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT, }; -use 
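// `authenticate` above caches the credential on first use: retrieve only
// when nothing is stored, then propagate the value to the embedding queue.
// Note the value pushed to the queue must be the refreshed one — a later
// commit in this series adjusts the flow accordingly. The cache-or-retrieve
// step in isolation (sketch):

use ai::auth::ProviderCredential;

fn refresh_credential(
    current: ProviderCredential,
    retrieve: impl FnOnce() -> ProviderCredential,
) -> ProviderCredential {
    match current {
        // Only hit the credential provider when nothing is cached yet.
        ProviderCredential::NoCredentials => retrieve(),
        cached => cached,
    }
}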
ai::providers::dummy::{DummyEmbeddingProvider, DummyLanguageModel}; -use ai::{ - embedding::{Embedding, EmbeddingProvider}, - models::LanguageModel, -}; -use anyhow::Result; -use async_trait::async_trait; -use gpui::{executor::Deterministic, AppContext, Task, TestAppContext}; +use ai::test::FakeEmbeddingProvider; + +use gpui::{executor::Deterministic, Task, TestAppContext}; use language::{Language, LanguageConfig, LanguageRegistry, ToOffset}; use parking_lot::Mutex; use pretty_assertions::assert_eq; @@ -19,14 +14,7 @@ use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs use rand::{rngs::StdRng, Rng}; use serde_json::json; use settings::SettingsStore; -use std::{ - path::Path, - sync::{ - atomic::{self, AtomicUsize}, - Arc, - }, - time::{Instant, SystemTime}, -}; +use std::{path::Path, sync::Arc, time::SystemTime}; use unindent::Unindent; use util::RandomCharIter; @@ -232,7 +220,11 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) { let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); - let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background(), None); + let mut queue = EmbeddingQueue::new( + embedding_provider.clone(), + cx.background(), + ai::auth::ProviderCredential::NoCredentials, + ); for file in &files { queue.push(file.clone()); } @@ -284,7 +276,7 @@ fn assert_search_results( #[gpui::test] async fn test_code_context_retrieval_rust() { let language = rust_lang(); - let embedding_provider = Arc::new(DummyEmbeddingProvider {}); + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); let mut retriever = CodeContextRetriever::new(embedding_provider); let text = " @@ -386,7 +378,7 @@ async fn test_code_context_retrieval_rust() { #[gpui::test] async fn test_code_context_retrieval_json() { let language = json_lang(); - let embedding_provider = Arc::new(DummyEmbeddingProvider {}); + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); let mut retriever = CodeContextRetriever::new(embedding_provider); let text = r#" @@ -470,7 +462,7 @@ fn assert_documents_eq( #[gpui::test] async fn test_code_context_retrieval_javascript() { let language = js_lang(); - let embedding_provider = Arc::new(DummyEmbeddingProvider {}); + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); let mut retriever = CodeContextRetriever::new(embedding_provider); let text = " @@ -569,7 +561,7 @@ async fn test_code_context_retrieval_javascript() { #[gpui::test] async fn test_code_context_retrieval_lua() { let language = lua_lang(); - let embedding_provider = Arc::new(DummyEmbeddingProvider {}); + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); let mut retriever = CodeContextRetriever::new(embedding_provider); let text = r#" @@ -643,7 +635,7 @@ async fn test_code_context_retrieval_lua() { #[gpui::test] async fn test_code_context_retrieval_elixir() { let language = elixir_lang(); - let embedding_provider = Arc::new(DummyEmbeddingProvider {}); + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); let mut retriever = CodeContextRetriever::new(embedding_provider); let text = r#" @@ -760,7 +752,7 @@ async fn test_code_context_retrieval_elixir() { #[gpui::test] async fn test_code_context_retrieval_cpp() { let language = cpp_lang(); - let embedding_provider = Arc::new(DummyEmbeddingProvider {}); + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); let mut retriever = CodeContextRetriever::new(embedding_provider); let text = " @@ -913,7 
+905,7 @@ async fn test_code_context_retrieval_cpp() { #[gpui::test] async fn test_code_context_retrieval_ruby() { let language = ruby_lang(); - let embedding_provider = Arc::new(DummyEmbeddingProvider {}); + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); let mut retriever = CodeContextRetriever::new(embedding_provider); let text = r#" @@ -1104,7 +1096,7 @@ async fn test_code_context_retrieval_ruby() { #[gpui::test] async fn test_code_context_retrieval_php() { let language = php_lang(); - let embedding_provider = Arc::new(DummyEmbeddingProvider {}); + let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); let mut retriever = CodeContextRetriever::new(embedding_provider); let text = r#" @@ -1252,65 +1244,6 @@ async fn test_code_context_retrieval_php() { ); } -#[derive(Default)] -struct FakeEmbeddingProvider { - embedding_count: AtomicUsize, -} - -impl FakeEmbeddingProvider { - fn embedding_count(&self) -> usize { - self.embedding_count.load(atomic::Ordering::SeqCst) - } - - fn embed_sync(&self, span: &str) -> Embedding { - let mut result = vec![1.0; 26]; - for letter in span.chars() { - let letter = letter.to_ascii_lowercase(); - if letter as u32 >= 'a' as u32 { - let ix = (letter as u32) - ('a' as u32); - if ix < 26 { - result[ix as usize] += 1.0; - } - } - } - - let norm = result.iter().map(|x| x * x).sum::().sqrt(); - for x in &mut result { - *x /= norm; - } - - result.into() - } -} - -#[async_trait] -impl EmbeddingProvider for FakeEmbeddingProvider { - fn base_model(&self) -> Box { - Box::new(DummyLanguageModel {}) - } - fn retrieve_credentials(&self, _cx: &AppContext) -> Option { - Some("Fake Credentials".to_string()) - } - fn max_tokens_per_batch(&self) -> usize { - 1000 - } - - fn rate_limit_expiration(&self) -> Option { - None - } - - async fn embed_batch( - &self, - spans: Vec, - _api_key: Option, - ) -> Result> { - self.embedding_count - .fetch_add(spans.len(), atomic::Ordering::SeqCst); - - anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect()) - } -} - fn js_lang() -> Arc { Arc::new( Language::new( From ca82ec8e8e2484099de3020233804671387ba9de Mon Sep 17 00:00:00 2001 From: KCaverly Date: Thu, 26 Oct 2023 14:05:55 +0200 Subject: [PATCH 10/25] fixed truncation error in fake language model --- crates/ai/src/auth.rs | 2 +- crates/ai/src/test.rs | 4 ++++ crates/semantic_index/src/embedding_queue.rs | 2 +- crates/semantic_index/src/semantic_index.rs | 5 ++++- 4 files changed, 10 insertions(+), 3 deletions(-) diff --git a/crates/ai/src/auth.rs b/crates/ai/src/auth.rs index a3ce8aece1f97131183d193447bd9e0b848795cc..c188c30797d4734e45899dd2e3c7c5f3c3c5a378 100644 --- a/crates/ai/src/auth.rs +++ b/crates/ai/src/auth.rs @@ -1,6 +1,6 @@ use gpui::AppContext; -#[derive(Clone)] +#[derive(Clone, Debug)] pub enum ProviderCredential { Credentials { api_key: String }, NoCredentials, diff --git a/crates/ai/src/test.rs b/crates/ai/src/test.rs index d8805bad1a6903064d92075bbf68f4dc06928da5..bc143e3c210049cd0afdb588880d44958dbbed75 100644 --- a/crates/ai/src/test.rs +++ b/crates/ai/src/test.rs @@ -29,6 +29,10 @@ impl LanguageModel for FakeLanguageModel { length: usize, direction: TruncationDirection, ) -> anyhow::Result { + if length > self.count_tokens(content)? 
{ + return anyhow::Ok(content.to_string()); + } + anyhow::Ok(match direction { TruncationDirection::End => content.chars().collect::>()[..length] .into_iter() diff --git a/crates/semantic_index/src/embedding_queue.rs b/crates/semantic_index/src/embedding_queue.rs index 9ca6d8a0d945a4eae80010a40f31f1a1c39d3d8f..299aa328b52e4893d92f6fa644209850a83792b2 100644 --- a/crates/semantic_index/src/embedding_queue.rs +++ b/crates/semantic_index/src/embedding_queue.rs @@ -69,7 +69,7 @@ impl EmbeddingQueue { } pub fn set_credential(&mut self, credential: ProviderCredential) { - self.provider_credential = credential + self.provider_credential = credential; } pub fn push(&mut self, file: FileToEmbed) { diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 5be3d6ccf5302d94a3c4d8a12674f3c6e17421d0..f420e0503bd9c481c74cc94ed19594ca5dedce02 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -291,7 +291,6 @@ impl SemanticIndex { } self.embedding_queue.lock().set_credential(credential); - self.is_authenticated() } @@ -299,6 +298,7 @@ impl SemanticIndex { let credential = &self.provider_credential; match credential { &ProviderCredential::Credentials { .. } => true, + &ProviderCredential::NotNeeded => true, _ => false, } } @@ -1020,11 +1020,14 @@ impl SemanticIndex { cx: &mut ModelContext, ) -> Task> { if !self.is_authenticated() { + println!("Authenticating"); if !self.authenticate(cx) { return Task::ready(Err(anyhow!("user is not authenticated"))); } } + println!("SHOULD NOW BE AUTHENTICATED"); + if !self.projects.contains_key(&project.downgrade()) { let subscription = cx.subscribe(&project, |this, project, event, cx| match event { project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => { From 6c8bb4b05e62aab88d5487cecb215c2fe8863a49 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Fri, 27 Oct 2023 08:33:35 +0200 Subject: [PATCH 11/25] ensure OpenAIEmbeddingProvider is using the provider credentials --- crates/ai/src/providers/open_ai/embedding.rs | 11 ++++++----- crates/semantic_index/src/embedding_queue.rs | 2 +- crates/semantic_index/src/semantic_index.rs | 17 ++++++----------- 3 files changed, 13 insertions(+), 17 deletions(-) diff --git a/crates/ai/src/providers/open_ai/embedding.rs b/crates/ai/src/providers/open_ai/embedding.rs index 1385b32b4da6762b7dc6dac29eb0db8de692a6b3..980687766029ecf1c5e49d338586f471e29f35c2 100644 --- a/crates/ai/src/providers/open_ai/embedding.rs +++ b/crates/ai/src/providers/open_ai/embedding.rs @@ -162,14 +162,15 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider { async fn embed_batch( &self, spans: Vec, - _credential: ProviderCredential, + credential: ProviderCredential, ) -> Result> { const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; const MAX_RETRIES: usize = 4; - let api_key = OPENAI_API_KEY - .as_ref() - .ok_or_else(|| anyhow!("no api key"))?; + let api_key = match credential { + ProviderCredential::Credentials { api_key } => anyhow::Ok(api_key), + _ => Err(anyhow!("no api key provided")), + }?; let mut request_number = 0; let mut rate_limiting = false; @@ -178,7 +179,7 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider { while request_number < MAX_RETRIES { response = self .send_request( - api_key, + &api_key, spans.iter().map(|x| &**x).collect(), request_timeout, ) diff --git a/crates/semantic_index/src/embedding_queue.rs b/crates/semantic_index/src/embedding_queue.rs index 
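// The credential check at the top of the reworked `embed_batch` above,
// factored into a helper (sketch):

use anyhow::anyhow;
use ai::auth::ProviderCredential;

fn require_api_key(credential: &ProviderCredential) -> anyhow::Result<String> {
    match credential {
        ProviderCredential::Credentials { api_key } => Ok(api_key.clone()),
        _ => Err(anyhow!("no api key provided")),
    }
}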
299aa328b52e4893d92f6fa644209850a83792b2..6f792c78e21dc91f333c22a918f6d13047c506c9 100644 --- a/crates/semantic_index/src/embedding_queue.rs +++ b/crates/semantic_index/src/embedding_queue.rs @@ -41,7 +41,7 @@ pub struct EmbeddingQueue { pending_batch_token_count: usize, finished_files_tx: channel::Sender, finished_files_rx: channel::Receiver, - provider_credential: ProviderCredential, + pub provider_credential: ProviderCredential, } #[derive(Clone)] diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index f420e0503bd9c481c74cc94ed19594ca5dedce02..7fb5f749b4f6851b9201ddac6cd7fc59ce6f782a 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -281,15 +281,13 @@ impl SemanticIndex { } pub fn authenticate(&mut self, cx: &AppContext) -> bool { - let credential = self.provider_credential.clone(); - match credential { - ProviderCredential::NoCredentials => { - let credential = self.embedding_provider.retrieve_credentials(cx); - self.provider_credential = credential; - } - _ => {} - } + let existing_credential = self.provider_credential.clone(); + let credential = match existing_credential { + ProviderCredential::NoCredentials => self.embedding_provider.retrieve_credentials(cx), + _ => existing_credential, + }; + self.provider_credential = credential.clone(); self.embedding_queue.lock().set_credential(credential); self.is_authenticated() } @@ -1020,14 +1018,11 @@ impl SemanticIndex { cx: &mut ModelContext, ) -> Task> { if !self.is_authenticated() { - println!("Authenticating"); if !self.authenticate(cx) { return Task::ready(Err(anyhow!("user is not authenticated"))); } } - println!("SHOULD NOW BE AUTHENTICATED"); - if !self.projects.contains_key(&project.downgrade()) { let subscription = cx.subscribe(&project, |this, project, event, cx| match event { project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => { From ec9d79b6fec4e90f34367bd3a855ef11c58f75fd Mon Sep 17 00:00:00 2001 From: KCaverly Date: Fri, 27 Oct 2023 08:51:30 +0200 Subject: [PATCH 12/25] add concept of LanguageModel to CompletionProvider --- crates/ai/src/completion.rs | 3 +++ crates/ai/src/providers/open_ai/completion.rs | 21 ++++++++++++++++--- crates/ai/src/providers/open_ai/embedding.rs | 1 - crates/assistant/src/assistant_panel.rs | 1 + crates/assistant/src/codegen.rs | 5 +++++ 5 files changed, 27 insertions(+), 4 deletions(-) diff --git a/crates/ai/src/completion.rs b/crates/ai/src/completion.rs index ba89c869d214c8772430e30b3177af502a41031e..da9ebd5a1d7c5f7e651ccdb65f02df47794c45e9 100644 --- a/crates/ai/src/completion.rs +++ b/crates/ai/src/completion.rs @@ -1,11 +1,14 @@ use anyhow::Result; use futures::{future::BoxFuture, stream::BoxStream}; +use crate::models::LanguageModel; + pub trait CompletionRequest: Send + Sync { fn data(&self) -> serde_json::Result; } pub trait CompletionProvider { + fn base_model(&self) -> Box; fn complete( &self, prompt: Box, diff --git a/crates/ai/src/providers/open_ai/completion.rs b/crates/ai/src/providers/open_ai/completion.rs index 95ed13c0dd7f1d519c8279ea75fd3980dd4f5d50..20f72c0ff761a691dfc7e70b44be9d537a1ba9ee 100644 --- a/crates/ai/src/providers/open_ai/completion.rs +++ b/crates/ai/src/providers/open_ai/completion.rs @@ -12,7 +12,12 @@ use std::{ sync::Arc, }; -use crate::completion::{CompletionProvider, CompletionRequest}; +use crate::{ + completion::{CompletionProvider, CompletionRequest}, + models::LanguageModel, +}; + +use super::OpenAILanguageModel; pub const 
OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; @@ -180,17 +185,27 @@ pub async fn stream_completion( } pub struct OpenAICompletionProvider { + model: OpenAILanguageModel, api_key: String, executor: Arc, } impl OpenAICompletionProvider { - pub fn new(api_key: String, executor: Arc) -> Self { - Self { api_key, executor } + pub fn new(model_name: &str, api_key: String, executor: Arc) -> Self { + let model = OpenAILanguageModel::load(model_name); + Self { + model, + api_key, + executor, + } } } impl CompletionProvider for OpenAICompletionProvider { + fn base_model(&self) -> Box { + let model: Box = Box::new(self.model.clone()); + model + } fn complete( &self, prompt: Box, diff --git a/crates/ai/src/providers/open_ai/embedding.rs b/crates/ai/src/providers/open_ai/embedding.rs index 980687766029ecf1c5e49d338586f471e29f35c2..64f568da1ab6a39c5fa038e27d6962fd65a6567d 100644 --- a/crates/ai/src/providers/open_ai/embedding.rs +++ b/crates/ai/src/providers/open_ai/embedding.rs @@ -26,7 +26,6 @@ use crate::providers::open_ai::OpenAILanguageModel; use crate::providers::open_ai::auth::OpenAICredentialProvider; lazy_static! { - static ref OPENAI_API_KEY: Option = env::var("OPENAI_API_KEY").ok(); static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap(); } diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index ec16c8fd046b48b8ef8dc0fc52afbdc6a0da1b3e..c899465ed2655d317e4f8e06d09240d07d90ea14 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -328,6 +328,7 @@ impl AssistantPanel { let inline_assist_id = post_inc(&mut self.next_inline_assist_id); let provider = Arc::new(OpenAICompletionProvider::new( + "gpt-4", api_key, cx.background().clone(), )); diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs index e71b1ae2cb36aa85eea932a1fc6fefc634e50c09..33adb2e5706611bc810214f7b77cd2330ed57196 100644 --- a/crates/assistant/src/codegen.rs +++ b/crates/assistant/src/codegen.rs @@ -335,6 +335,7 @@ fn strip_markdown_codeblock( #[cfg(test)] mod tests { use super::*; + use ai::{models::LanguageModel, test::FakeLanguageModel}; use futures::{ future::BoxFuture, stream::{self, BoxStream}, @@ -638,6 +639,10 @@ mod tests { } impl CompletionProvider for TestCompletionProvider { + fn base_model(&self) -> Box { + let model: Box = Box::new(FakeLanguageModel { capacity: 8190 }); + model + } fn complete( &self, _prompt: Box, From 7af77b1cf95da45314092aa35f7bcc04fa4fd3bc Mon Sep 17 00:00:00 2001 From: KCaverly Date: Fri, 27 Oct 2023 12:26:01 +0200 Subject: [PATCH 13/25] moved TestCompletionProvider into ai --- crates/ai/src/test.rs | 39 +++++++++++++++++++++++++++++++++ crates/assistant/Cargo.toml | 1 + crates/assistant/src/codegen.rs | 38 +------------------------------- 3 files changed, 41 insertions(+), 37 deletions(-) diff --git a/crates/ai/src/test.rs b/crates/ai/src/test.rs index bc143e3c210049cd0afdb588880d44958dbbed75..2c78027b62781368aaceea862777d412d6d29e3f 100644 --- a/crates/ai/src/test.rs +++ b/crates/ai/src/test.rs @@ -4,9 +4,12 @@ use std::{ }; use async_trait::async_trait; +use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; +use parking_lot::Mutex; use crate::{ auth::{CredentialProvider, NullCredentialProvider, ProviderCredential}, + completion::{CompletionProvider, CompletionRequest}, embedding::{Embedding, EmbeddingProvider}, models::{LanguageModel, TruncationDirection}, }; @@ -125,3 +128,39 @@ impl EmbeddingProvider for 
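// With `base_model()` now on `CompletionProvider` as well, callers can size
// a prompt against the model's context window before issuing a request. An
// illustrative helper, not part of the patch:

use ai::completion::CompletionProvider;
use ai::models::TruncationDirection;

fn fit_prompt(provider: &dyn CompletionProvider, prompt: &str) -> anyhow::Result<String> {
    let model = provider.base_model();
    let capacity = model.capacity()?;
    if model.count_tokens(prompt)? > capacity {
        // Keep the leading `capacity` tokens.
        model.truncate(prompt, capacity, TruncationDirection::End)
    } else {
        Ok(prompt.to_string())
    }
}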
FakeEmbeddingProvider { anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect()) } } + +pub struct TestCompletionProvider { + last_completion_tx: Mutex>>, +} + +impl TestCompletionProvider { + pub fn new() -> Self { + Self { + last_completion_tx: Mutex::new(None), + } + } + + pub fn send_completion(&self, completion: impl Into) { + let mut tx = self.last_completion_tx.lock(); + tx.as_mut().unwrap().try_send(completion.into()).unwrap(); + } + + pub fn finish_completion(&self) { + self.last_completion_tx.lock().take().unwrap(); + } +} + +impl CompletionProvider for TestCompletionProvider { + fn base_model(&self) -> Box { + let model: Box = Box::new(FakeLanguageModel { capacity: 8190 }); + model + } + fn complete( + &self, + _prompt: Box, + ) -> BoxFuture<'static, anyhow::Result>>> { + let (tx, rx) = mpsc::channel(1); + *self.last_completion_tx.lock() = Some(tx); + async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed() + } +} diff --git a/crates/assistant/Cargo.toml b/crates/assistant/Cargo.toml index 9cfdd3301ad712333f7d9775be2d411664d4d369..6b0ce659e3079094615f629b5216fa92c144f6d2 100644 --- a/crates/assistant/Cargo.toml +++ b/crates/assistant/Cargo.toml @@ -44,6 +44,7 @@ tiktoken-rs = "0.5" [dev-dependencies] editor = { path = "../editor", features = ["test-support"] } project = { path = "../project", features = ["test-support"] } +ai = { path = "../ai", features = ["test-support"]} ctor.workspace = true env_logger.workspace = true diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs index 33adb2e5706611bc810214f7b77cd2330ed57196..3516fc3708bb4e3059c622db6292a709d2a8113e 100644 --- a/crates/assistant/src/codegen.rs +++ b/crates/assistant/src/codegen.rs @@ -335,7 +335,7 @@ fn strip_markdown_codeblock( #[cfg(test)] mod tests { use super::*; - use ai::{models::LanguageModel, test::FakeLanguageModel}; + use ai::test::TestCompletionProvider; use futures::{ future::BoxFuture, stream::{self, BoxStream}, @@ -617,42 +617,6 @@ mod tests { } } - struct TestCompletionProvider { - last_completion_tx: Mutex>>, - } - - impl TestCompletionProvider { - fn new() -> Self { - Self { - last_completion_tx: Mutex::new(None), - } - } - - fn send_completion(&self, completion: impl Into) { - let mut tx = self.last_completion_tx.lock(); - tx.as_mut().unwrap().try_send(completion.into()).unwrap(); - } - - fn finish_completion(&self) { - self.last_completion_tx.lock().take().unwrap(); - } - } - - impl CompletionProvider for TestCompletionProvider { - fn base_model(&self) -> Box { - let model: Box = Box::new(FakeLanguageModel { capacity: 8190 }); - model - } - fn complete( - &self, - _prompt: Box, - ) -> BoxFuture<'static, Result>>> { - let (tx, rx) = mpsc::channel(1); - *self.last_completion_tx.lock() = Some(tx); - async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed() - } - } - fn rust_lang() -> Language { Language::new( LanguageConfig { From 558f54c424a1b0f7ccaa317af62b50fd5c467fc0 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Sat, 28 Oct 2023 16:35:43 -0400 Subject: [PATCH 14/25] added credential provider to completion provider --- crates/ai/src/completion.rs | 10 +++++++++- crates/ai/src/providers/open_ai/completion.rs | 10 +++++++++- crates/ai/src/providers/open_ai/embedding.rs | 1 - crates/ai/src/test.rs | 3 +++ crates/assistant/src/codegen.rs | 7 +------ 5 files changed, 22 insertions(+), 9 deletions(-) diff --git a/crates/ai/src/completion.rs b/crates/ai/src/completion.rs index da9ebd5a1d7c5f7e651ccdb65f02df47794c45e9..5b9bad48704528e1ca38e2a23152bfcffa5aa7ba 100644 --- 
a/crates/ai/src/completion.rs +++ b/crates/ai/src/completion.rs @@ -1,7 +1,11 @@ use anyhow::Result; use futures::{future::BoxFuture, stream::BoxStream}; +use gpui::AppContext; -use crate::models::LanguageModel; +use crate::{ + auth::{CredentialProvider, ProviderCredential}, + models::LanguageModel, +}; pub trait CompletionRequest: Send + Sync { fn data(&self) -> serde_json::Result; @@ -9,6 +13,10 @@ pub trait CompletionRequest: Send + Sync { pub trait CompletionProvider { fn base_model(&self) -> Box; + fn credential_provider(&self) -> Box; + fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential { + self.credential_provider().retrieve_credentials(cx) + } fn complete( &self, prompt: Box, diff --git a/crates/ai/src/providers/open_ai/completion.rs b/crates/ai/src/providers/open_ai/completion.rs index 20f72c0ff761a691dfc7e70b44be9d537a1ba9ee..9c9d205ff726d1d37952afb6e6918644cbf5f3b9 100644 --- a/crates/ai/src/providers/open_ai/completion.rs +++ b/crates/ai/src/providers/open_ai/completion.rs @@ -13,11 +13,12 @@ use std::{ }; use crate::{ + auth::CredentialProvider, completion::{CompletionProvider, CompletionRequest}, models::LanguageModel, }; -use super::OpenAILanguageModel; +use super::{auth::OpenAICredentialProvider, OpenAILanguageModel}; pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; @@ -186,6 +187,7 @@ pub async fn stream_completion( pub struct OpenAICompletionProvider { model: OpenAILanguageModel, + credential_provider: OpenAICredentialProvider, api_key: String, executor: Arc, } @@ -193,8 +195,10 @@ pub struct OpenAICompletionProvider { impl OpenAICompletionProvider { pub fn new(model_name: &str, api_key: String, executor: Arc) -> Self { let model = OpenAILanguageModel::load(model_name); + let credential_provider = OpenAICredentialProvider {}; Self { model, + credential_provider, api_key, executor, } @@ -206,6 +210,10 @@ impl CompletionProvider for OpenAICompletionProvider { let model: Box = Box::new(self.model.clone()); model } + fn credential_provider(&self) -> Box { + let provider: Box = Box::new(self.credential_provider.clone()); + provider + } fn complete( &self, prompt: Box, diff --git a/crates/ai/src/providers/open_ai/embedding.rs b/crates/ai/src/providers/open_ai/embedding.rs index 64f568da1ab6a39c5fa038e27d6962fd65a6567d..dafc94580d16d3c79987311ee0e6558fb45d3669 100644 --- a/crates/ai/src/providers/open_ai/embedding.rs +++ b/crates/ai/src/providers/open_ai/embedding.rs @@ -11,7 +11,6 @@ use parking_lot::Mutex; use parse_duration::parse; use postage::watch; use serde::{Deserialize, Serialize}; -use std::env; use std::ops::Add; use std::sync::Arc; use std::time::{Duration, Instant}; diff --git a/crates/ai/src/test.rs b/crates/ai/src/test.rs index 2c78027b62781368aaceea862777d412d6d29e3f..b8f99af400f213f96ef7a1911d06f419fb990c54 100644 --- a/crates/ai/src/test.rs +++ b/crates/ai/src/test.rs @@ -155,6 +155,9 @@ impl CompletionProvider for TestCompletionProvider { let model: Box = Box::new(FakeLanguageModel { capacity: 8190 }); model } + fn credential_provider(&self) -> Box { + Box::new(NullCredentialProvider {}) + } fn complete( &self, _prompt: Box, diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs index 3516fc3708bb4e3059c622db6292a709d2a8113e..7f4c95f655e7bf1ea222b109865fbf4a98ad2d0d 100644 --- a/crates/assistant/src/codegen.rs +++ b/crates/assistant/src/codegen.rs @@ -336,18 +336,13 @@ fn strip_markdown_codeblock( mod tests { use super::*; use ai::test::TestCompletionProvider; - use futures::{ - future::BoxFuture, 
- stream::{self, BoxStream}, - }; + use futures::stream::{self}; use gpui::{executor::Deterministic, TestAppContext}; use indoc::indoc; use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point}; - use parking_lot::Mutex; use rand::prelude::*; use serde::Serialize; use settings::SettingsStore; - use smol::future::FutureExt; #[derive(Serialize)] pub struct DummyCompletionRequest { From 1e8b23d8fb9cf231de56ed25ebd56ea04190fc55 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Sat, 28 Oct 2023 18:16:45 -0400 Subject: [PATCH 15/25] replace api_key with ProviderCredential throughout the AssistantPanel --- crates/ai/src/auth.rs | 4 + crates/ai/src/completion.rs | 6 + crates/ai/src/providers/open_ai/auth.rs | 13 + crates/ai/src/providers/open_ai/completion.rs | 24 +- crates/assistant/src/assistant_panel.rs | 276 +++++++++++------- 5 files changed, 205 insertions(+), 118 deletions(-) diff --git a/crates/ai/src/auth.rs b/crates/ai/src/auth.rs index c188c30797d4734e45899dd2e3c7c5f3c3c5a378..cb3f2beabb00ef79a52c5bbf10a19ef8d1a5cc48 100644 --- a/crates/ai/src/auth.rs +++ b/crates/ai/src/auth.rs @@ -9,6 +9,8 @@ pub enum ProviderCredential { pub trait CredentialProvider: Send + Sync { fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential; + fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential); + fn delete_credentials(&self, cx: &AppContext); } #[derive(Clone)] @@ -17,4 +19,6 @@ impl CredentialProvider for NullCredentialProvider { fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential { ProviderCredential::NotNeeded } + fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {} + fn delete_credentials(&self, cx: &AppContext) {} } diff --git a/crates/ai/src/completion.rs b/crates/ai/src/completion.rs index 5b9bad48704528e1ca38e2a23152bfcffa5aa7ba..6a2806a5cb30b702952e90362a72df41a5c80786 100644 --- a/crates/ai/src/completion.rs +++ b/crates/ai/src/completion.rs @@ -17,6 +17,12 @@ pub trait CompletionProvider { fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential { self.credential_provider().retrieve_credentials(cx) } + fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) { + self.credential_provider().save_credentials(cx, credential); + } + fn delete_credentials(&self, cx: &AppContext) { + self.credential_provider().delete_credentials(cx); + } fn complete( &self, prompt: Box, diff --git a/crates/ai/src/providers/open_ai/auth.rs b/crates/ai/src/providers/open_ai/auth.rs index c817ffea0056424397dcd66d32a2e4e0548302b3..7cb51ab449a066d683ab7e5921b3713bc5482cde 100644 --- a/crates/ai/src/providers/open_ai/auth.rs +++ b/crates/ai/src/providers/open_ai/auth.rs @@ -30,4 +30,17 @@ impl CredentialProvider for OpenAICredentialProvider { ProviderCredential::NoCredentials } } + fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) { + match credential { + ProviderCredential::Credentials { api_key } => { + cx.platform() + .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()) + .log_err(); + } + _ => {} + } + } + fn delete_credentials(&self, cx: &AppContext) { + cx.platform().delete_credentials(OPENAI_API_URL).log_err(); + } } diff --git a/crates/ai/src/providers/open_ai/completion.rs b/crates/ai/src/providers/open_ai/completion.rs index 9c9d205ff726d1d37952afb6e6918644cbf5f3b9..febe491123fe848e74b49a50c6b18d0fafc3a702 100644 --- a/crates/ai/src/providers/open_ai/completion.rs +++ b/crates/ai/src/providers/open_ai/completion.rs @@ -13,7 +13,7 @@ 
use std::{ }; use crate::{ - auth::CredentialProvider, + auth::{CredentialProvider, ProviderCredential}, completion::{CompletionProvider, CompletionRequest}, models::LanguageModel, }; @@ -102,10 +102,17 @@ pub struct OpenAIResponseStreamEvent { } pub async fn stream_completion( - api_key: String, + credential: ProviderCredential, executor: Arc, request: Box, ) -> Result>> { + let api_key = match credential { + ProviderCredential::Credentials { api_key } => api_key, + _ => { + return Err(anyhow!("no credentials provided for completion")); + } + }; + let (tx, rx) = futures::channel::mpsc::unbounded::>(); let json_data = request.data()?; @@ -186,18 +195,22 @@ pub async fn stream_completion( pub struct OpenAICompletionProvider { model: OpenAILanguageModel, credential_provider: OpenAICredentialProvider, - api_key: String, + credential: ProviderCredential, executor: Arc, } impl OpenAICompletionProvider { - pub fn new(model_name: &str, api_key: String, executor: Arc) -> Self { + pub fn new( + model_name: &str, + credential: ProviderCredential, + executor: Arc, + ) -> Self { let model = OpenAILanguageModel::load(model_name); let credential_provider = OpenAICredentialProvider {}; Self { model, credential_provider, - api_key, + credential, executor, } } } @@ -218,7 +229,8 @@ impl CompletionProvider for OpenAICompletionProvider { &self, prompt: Box, ) -> BoxFuture<'static, Result>>> { - let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt); + let credential = self.credential.clone(); + let request = stream_completion(credential, self.executor.clone(), prompt); async move { let response = request.await?; let stream = response diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index c899465ed2655d317e4f8e06d09240d07d90ea14..f9187b87855b748b536aecb908c0c44ce3377b5f 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -7,7 +7,8 @@ use crate::{ }; use ai::{ - completion::CompletionRequest, + auth::ProviderCredential, + completion::{CompletionProvider, CompletionRequest}, providers::open_ai::{ stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL, }, }; @@ -100,8 +101,8 @@ pub fn init(cx: &mut AppContext) { cx.capture_action(ConversationEditor::copy); cx.add_action(ConversationEditor::split); cx.capture_action(ConversationEditor::cycle_message_role); - cx.add_action(AssistantPanel::save_api_key); - cx.add_action(AssistantPanel::reset_api_key); + cx.add_action(AssistantPanel::save_credentials); + cx.add_action(AssistantPanel::reset_credentials); cx.add_action(AssistantPanel::toggle_zoom); cx.add_action(AssistantPanel::deploy); cx.add_action(AssistantPanel::select_next_match); @@ -143,7 +144,8 @@ pub struct AssistantPanel { zoomed: bool, has_focus: bool, toolbar: ViewHandle, - api_key: Rc>>, + credential: Rc>, + completion_provider: Box, api_key_editor: Option>, has_read_credentials: bool, languages: Arc, @@ -205,6 +207,12 @@ impl AssistantPanel { }); let semantic_index = SemanticIndex::global(cx); + // Defaulting currently to GPT-4; allow for this to be set via config. 
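The construction that follows is the crux of this patch: the panel now builds its completion provider up front with ProviderCredential::NoCredentials and resolves real credentials later. A minimal, self-contained sketch of that resolution order, assuming the enum shape from crates/ai/src/auth.rs — the resolve helper, main, and the OPENAI_API_KEY fallback shown here are illustrative reconstructions, not lines from the patch:

    // Sketch only: local stand-in for the ai crate's credential enum.
    #[allow(dead_code)]
    #[derive(Debug)]
    enum ProviderCredential {
        Credentials { api_key: String },
        NoCredentials,
        NotNeeded,
    }

    // Hypothetical helper mirroring the panel's load_credentials order: keep a
    // stored credential if present, otherwise fall back to the environment,
    // otherwise stay NoCredentials so the UI can prompt for an API key.
    fn resolve(stored: ProviderCredential, env_key: Option<String>) -> ProviderCredential {
        match stored {
            ProviderCredential::NoCredentials => match env_key {
                Some(api_key) => ProviderCredential::Credentials { api_key },
                None => ProviderCredential::NoCredentials,
            },
            other => other,
        }
    }

    fn main() {
        let resolved = resolve(
            ProviderCredential::NoCredentials,
            std::env::var("OPENAI_API_KEY").ok(),
        );
        // As in the panel's has_credentials check: only NoCredentials blocks use.
        println!("{resolved:?}");
    }
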
+ let completion_provider = Box::new(OpenAICompletionProvider::new( + "gpt-4", + ProviderCredential::NoCredentials, + cx.background().clone(), + )); let mut this = Self { workspace: workspace_handle, @@ -216,7 +224,8 @@ impl AssistantPanel { zoomed: false, has_focus: false, toolbar, - api_key: Rc::new(RefCell::new(None)), + credential: Rc::new(RefCell::new(ProviderCredential::NoCredentials)), + completion_provider, api_key_editor: None, has_read_credentials: false, languages: workspace.app_state().languages.clone(), @@ -257,10 +266,7 @@ impl AssistantPanel { cx: &mut ViewContext, ) { let this = if let Some(this) = workspace.panel::(cx) { - if this - .update(cx, |assistant, cx| assistant.load_api_key(cx)) - .is_some() - { + if this.update(cx, |assistant, cx| assistant.has_credentials(cx)) { this } else { workspace.focus_panel::(cx); @@ -292,12 +298,7 @@ impl AssistantPanel { cx: &mut ViewContext, project: &ModelHandle, ) { - let api_key = if let Some(api_key) = self.api_key.borrow().clone() { - api_key - } else { - return; - }; - + let credential = self.credential.borrow().clone(); let selection = editor.read(cx).selections.newest_anchor().clone(); if selection.start.excerpt_id() != selection.end.excerpt_id() { return; @@ -329,7 +330,7 @@ impl AssistantPanel { let inline_assist_id = post_inc(&mut self.next_inline_assist_id); let provider = Arc::new(OpenAICompletionProvider::new( "gpt-4", - api_key, + credential, cx.background().clone(), )); @@ -816,7 +817,7 @@ impl AssistantPanel { fn new_conversation(&mut self, cx: &mut ViewContext) -> ViewHandle { let editor = cx.add_view(|cx| { ConversationEditor::new( - self.api_key.clone(), + self.credential.clone(), self.languages.clone(), self.fs.clone(), self.workspace.clone(), @@ -875,17 +876,20 @@ impl AssistantPanel { } } - fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext) { + fn save_credentials(&mut self, _: &menu::Confirm, cx: &mut ViewContext) { if let Some(api_key) = self .api_key_editor .as_ref() .map(|editor| editor.read(cx).text(cx)) { if !api_key.is_empty() { - cx.platform() - .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()) - .log_err(); - *self.api_key.borrow_mut() = Some(api_key); + let credential = ProviderCredential::Credentials { + api_key: api_key.clone(), + }; + self.completion_provider + .save_credentials(cx, credential.clone()); + *self.credential.borrow_mut() = credential; + self.api_key_editor.take(); cx.focus_self(); cx.notify(); @@ -895,9 +899,9 @@ impl AssistantPanel { } } - fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext) { - cx.platform().delete_credentials(OPENAI_API_URL).log_err(); - self.api_key.take(); + fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext) { + self.completion_provider.delete_credentials(cx); + *self.credential.borrow_mut() = ProviderCredential::NoCredentials; self.api_key_editor = Some(build_api_key_editor(cx)); cx.focus_self(); cx.notify(); @@ -1156,13 +1160,19 @@ impl AssistantPanel { let fs = self.fs.clone(); let workspace = self.workspace.clone(); - let api_key = self.api_key.clone(); + let credential = self.credential.clone(); let languages = self.languages.clone(); cx.spawn(|this, mut cx| async move { let saved_conversation = fs.load(&path).await?; let saved_conversation = serde_json::from_str(&saved_conversation)?; let conversation = cx.add_model(|cx| { - Conversation::deserialize(saved_conversation, path.clone(), api_key, languages, cx) + Conversation::deserialize( + saved_conversation, + path.clone(), + credential, + languages, 
+ cx, + ) }); this.update(&mut cx, |this, cx| { // If, by the time we've loaded the conversation, the user has already opened @@ -1186,30 +1196,39 @@ impl AssistantPanel { .position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path)) } - fn load_api_key(&mut self, cx: &mut ViewContext) -> Option { - if self.api_key.borrow().is_none() && !self.has_read_credentials { - self.has_read_credentials = true; - let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") { - Some(api_key) - } else if let Some((_, api_key)) = cx - .platform() - .read_credentials(OPENAI_API_URL) - .log_err() - .flatten() - { - String::from_utf8(api_key).log_err() - } else { - None - }; - if let Some(api_key) = api_key { - *self.api_key.borrow_mut() = Some(api_key); - } else if self.api_key_editor.is_none() { - self.api_key_editor = Some(build_api_key_editor(cx)); - cx.notify(); + fn has_credentials(&mut self, cx: &mut ViewContext) -> bool { + let credential = self.load_credentials(cx); + match credential { + ProviderCredential::Credentials { .. } => true, + ProviderCredential::NotNeeded => true, + ProviderCredential::NoCredentials => false, + } + } + + fn load_credentials(&mut self, cx: &mut ViewContext) -> ProviderCredential { + let existing_credential = self.credential.clone(); + let existing_credential = existing_credential.borrow().clone(); + match existing_credential { + ProviderCredential::NoCredentials => { + if !self.has_read_credentials { + self.has_read_credentials = true; + let retrieved_credentials = self.completion_provider.retrieve_credentials(cx); + + match retrieved_credentials { + ProviderCredential::NoCredentials {} => { + self.api_key_editor = Some(build_api_key_editor(cx)); + cx.notify(); + } + _ => { + *self.credential.borrow_mut() = retrieved_credentials; + } + } + } } + _ => {} } - self.api_key.borrow().clone() + self.credential.borrow().clone() } } @@ -1394,7 +1413,7 @@ impl Panel for AssistantPanel { fn set_active(&mut self, active: bool, cx: &mut ViewContext) { if active { - self.load_api_key(cx); + self.load_credentials(cx); if self.editors.is_empty() { self.new_conversation(cx); @@ -1459,7 +1478,7 @@ struct Conversation { token_count: Option, max_token_count: usize, pending_token_count: Task>, - api_key: Rc>>, + credential: Rc>, pending_save: Task>, path: Option, _subscriptions: Vec, @@ -1471,7 +1490,8 @@ impl Entity for Conversation { impl Conversation { fn new( - api_key: Rc>>, + credential: Rc>, + language_registry: Arc, cx: &mut ModelContext, ) -> Self { @@ -1512,7 +1532,7 @@ impl Conversation { _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)], pending_save: Task::ready(Ok(())), path: None, - api_key, + credential, buffer, }; let message = MessageAnchor { @@ -1559,7 +1579,7 @@ impl Conversation { fn deserialize( saved_conversation: SavedConversation, path: PathBuf, - api_key: Rc>>, + credential: Rc>, language_registry: Arc, cx: &mut ModelContext, ) -> Self { @@ -1614,7 +1634,7 @@ impl Conversation { _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)], pending_save: Task::ready(Ok(())), path: Some(path), - api_key, + credential, buffer, }; this.count_remaining_tokens(cx); @@ -1736,9 +1756,13 @@ impl Conversation { } if should_assist { - let Some(api_key) = self.api_key.borrow().clone() else { - return Default::default(); - }; + let credential = self.credential.borrow().clone(); + match credential { + ProviderCredential::NoCredentials => { + return Default::default(); + } + _ => {} + } let request: Box = 
Box::new(OpenAIRequest { model: self.model.full_name().to_string(), @@ -1752,7 +1776,7 @@ impl Conversation { temperature: 1.0, }); - let stream = stream_completion(api_key, cx.background().clone(), request); + let stream = stream_completion(credential, cx.background().clone(), request); let assistant_message = self .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx) .unwrap(); @@ -2018,57 +2042,62 @@ impl Conversation { fn summarize(&mut self, cx: &mut ModelContext) { if self.message_anchors.len() >= 2 && self.summary.is_none() { - let api_key = self.api_key.borrow().clone(); - if let Some(api_key) = api_key { - let messages = self - .messages(cx) - .take(2) - .map(|message| message.to_open_ai_message(self.buffer.read(cx))) - .chain(Some(RequestMessage { - role: Role::User, - content: - "Summarize the conversation into a short title without punctuation" - .into(), - })); - let request: Box = Box::new(OpenAIRequest { - model: self.model.full_name().to_string(), - messages: messages.collect(), - stream: true, - stop: vec![], - temperature: 1.0, - }); + let credential = self.credential.borrow().clone(); - let stream = stream_completion(api_key, cx.background().clone(), request); - self.pending_summary = cx.spawn(|this, mut cx| { - async move { - let mut messages = stream.await?; + match credential { + ProviderCredential::NoCredentials => { + return; + } + _ => {} + } - while let Some(message) = messages.next().await { - let mut message = message?; - if let Some(choice) = message.choices.pop() { - let text = choice.delta.content.unwrap_or_default(); - this.update(&mut cx, |this, cx| { - this.summary - .get_or_insert(Default::default()) - .text - .push_str(&text); - cx.emit(ConversationEvent::SummaryChanged); - }); - } - } + let messages = self + .messages(cx) + .take(2) + .map(|message| message.to_open_ai_message(self.buffer.read(cx))) + .chain(Some(RequestMessage { + role: Role::User, + content: "Summarize the conversation into a short title without punctuation" + .into(), + })); + let request: Box = Box::new(OpenAIRequest { + model: self.model.full_name().to_string(), + messages: messages.collect(), + stream: true, + stop: vec![], + temperature: 1.0, + }); - this.update(&mut cx, |this, cx| { - if let Some(summary) = this.summary.as_mut() { - summary.done = true; - cx.emit(ConversationEvent::SummaryChanged); - } - }); + let stream = stream_completion(credential, cx.background().clone(), request); + self.pending_summary = cx.spawn(|this, mut cx| { + async move { + let mut messages = stream.await?; - anyhow::Ok(()) + while let Some(message) = messages.next().await { + let mut message = message?; + if let Some(choice) = message.choices.pop() { + let text = choice.delta.content.unwrap_or_default(); + this.update(&mut cx, |this, cx| { + this.summary + .get_or_insert(Default::default()) + .text + .push_str(&text); + cx.emit(ConversationEvent::SummaryChanged); + }); + } } - .log_err() - }); - } + + this.update(&mut cx, |this, cx| { + if let Some(summary) = this.summary.as_mut() { + summary.done = true; + cx.emit(ConversationEvent::SummaryChanged); + } + }); + + anyhow::Ok(()) + } + .log_err() + }); } } @@ -2229,13 +2258,13 @@ struct ConversationEditor { impl ConversationEditor { fn new( - api_key: Rc>>, + credential: Rc>, language_registry: Arc, fs: Arc, workspace: WeakViewHandle, cx: &mut ViewContext, ) -> Self { - let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx)); + let conversation = cx.add_model(|cx| 
Conversation::new(credential, language_registry, cx)); Self::for_conversation(conversation, fs, workspace, cx) } @@ -3431,7 +3460,13 @@ mod tests { cx.set_global(SettingsStore::test(cx)); init(cx); let registry = Arc::new(LanguageRegistry::test()); - let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx)); + let conversation = cx.add_model(|cx| { + Conversation::new( + Rc::new(RefCell::new(ProviderCredential::NotNeeded)), + registry, + cx, + ) + }); let buffer = conversation.read(cx).buffer.clone(); let message_1 = conversation.read(cx).message_anchors[0].clone(); @@ -3559,7 +3594,13 @@ mod tests { cx.set_global(SettingsStore::test(cx)); init(cx); let registry = Arc::new(LanguageRegistry::test()); - let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx)); + let conversation = cx.add_model(|cx| { + Conversation::new( + Rc::new(RefCell::new(ProviderCredential::NotNeeded)), + registry, + cx, + ) + }); let buffer = conversation.read(cx).buffer.clone(); let message_1 = conversation.read(cx).message_anchors[0].clone(); @@ -3655,7 +3696,13 @@ mod tests { cx.set_global(SettingsStore::test(cx)); init(cx); let registry = Arc::new(LanguageRegistry::test()); - let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx)); + let conversation = cx.add_model(|cx| { + Conversation::new( + Rc::new(RefCell::new(ProviderCredential::NotNeeded)), + registry, + cx, + ) + }); let buffer = conversation.read(cx).buffer.clone(); let message_1 = conversation.read(cx).message_anchors[0].clone(); @@ -3737,8 +3784,13 @@ mod tests { cx.set_global(SettingsStore::test(cx)); init(cx); let registry = Arc::new(LanguageRegistry::test()); - let conversation = - cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx)); + let conversation = cx.add_model(|cx| { + Conversation::new( + Rc::new(RefCell::new(ProviderCredential::NotNeeded)), + registry.clone(), + cx, + ) + }); let buffer = conversation.read(cx).buffer.clone(); let message_0 = conversation.read(cx).message_anchors[0].id; let message_1 = conversation.update(cx, |conversation, cx| { @@ -3775,7 +3827,7 @@ mod tests { Conversation::deserialize( conversation.read(cx).serialize(cx), Default::default(), - Default::default(), + Rc::new(RefCell::new(ProviderCredential::NotNeeded)), registry.clone(), cx, ) From 96bbb5cdea41d537324294bf025cfc6cca7ea51e Mon Sep 17 00:00:00 2001 From: Kirill Bulatov Date: Mon, 30 Oct 2023 11:14:00 +0200 Subject: [PATCH 16/25] Properly log prettier paths --- crates/prettier/src/prettier.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/prettier/src/prettier.rs b/crates/prettier/src/prettier.rs index 79fef40908c3589a192bf4a57e56e84eb68846ef..53e3101a3bebef89f515ed42864a6d806a9deeb5 100644 --- a/crates/prettier/src/prettier.rs +++ b/crates/prettier/src/prettier.rs @@ -93,7 +93,7 @@ impl Prettier { ) })?; (worktree_root_data.unwrap_or_else(|| { - panic!("cannot query prettier for non existing worktree root at {worktree_root_data:?}") + panic!("cannot query prettier for non existing worktree root at {worktree_root:?}") }), None) } else { let full_starting_path = worktree_root.join(&starting_path.starting_path); @@ -106,7 +106,7 @@ impl Prettier { })?; ( worktree_root_data.unwrap_or_else(|| { - panic!("cannot query prettier for non existing worktree root at {worktree_root_data:?}") + panic!("cannot query prettier for non existing worktree root at {worktree_root:?}") }), start_path_data, ) From 
249bec3cac269909d96226e5732ea36ce8b3569d Mon Sep 17 00:00:00 2001 From: Kirill Bulatov Date: Mon, 30 Oct 2023 12:13:34 +0200 Subject: [PATCH 17/25] Do not panic on prettier search --- crates/prettier/src/prettier.rs | 120 +++++++++++++++----------------- 1 file changed, 55 insertions(+), 65 deletions(-) diff --git a/crates/prettier/src/prettier.rs b/crates/prettier/src/prettier.rs index 53e3101a3bebef89f515ed42864a6d806a9deeb5..6784dba7dcf92aa2b2a6f6f1dbadae062edff6a3 100644 --- a/crates/prettier/src/prettier.rs +++ b/crates/prettier/src/prettier.rs @@ -81,77 +81,67 @@ impl Prettier { if worktree_root != starting_path.worktree_root_path.as_ref() { vec![worktree_root] } else { - let (worktree_root_metadata, start_path_metadata) = if starting_path - .starting_path - .as_ref() - == Path::new("") - { - let worktree_root_data = - fs.metadata(&worktree_root).await.with_context(|| { + let worktree_root_metadata = fs + .metadata(&worktree_root) + .await + .with_context(|| { + format!("FS metadata fetch for worktree root path {worktree_root:?}",) + })? + .with_context(|| { + format!("empty FS metadata for worktree root at {worktree_root:?}") + })?; + if starting_path.starting_path.as_ref() == Path::new("") { + anyhow::ensure!( + !worktree_root_metadata.is_dir, + "For empty start path, worktree root should not be a directory {starting_path:?}" + ); + anyhow::ensure!( + !worktree_root_metadata.is_symlink, + "For empty start path, worktree root should not be a symlink {starting_path:?}" + ); + worktree_root + .parent() + .map(|path| vec![path.to_path_buf()]) + .unwrap_or_default() + } else { + let full_starting_path = worktree_root.join(&starting_path.starting_path); + let start_path_metadata = fs + .metadata(&full_starting_path) + .await + .with_context(|| { + format!( + "FS metadata fetch for starting path {full_starting_path:?}" + ) + })? 
+ .with_context(|| { format!( - "FS metadata fetch for worktree root path {worktree_root:?}", + "empty FS metadata for starting path {full_starting_path:?}" ) })?; - (worktree_root_data.unwrap_or_else(|| { - panic!("cannot query prettier for non existing worktree root at {worktree_root:?}") - }), None) - } else { - let full_starting_path = worktree_root.join(&starting_path.starting_path); - let (worktree_root_data, start_path_data) = futures::try_join!( - fs.metadata(&worktree_root), - fs.metadata(&full_starting_path), - ) - .with_context(|| { - format!("FS metadata fetch for starting path {full_starting_path:?}",) - })?; - ( - worktree_root_data.unwrap_or_else(|| { - panic!("cannot query prettier for non existing worktree root at {worktree_root:?}") - }), - start_path_data, - ) - }; - - match start_path_metadata { - Some(start_path_metadata) => { - anyhow::ensure!(worktree_root_metadata.is_dir, - "For non-empty start path, worktree root {starting_path:?} should be a directory"); - anyhow::ensure!( - !start_path_metadata.is_dir, - "For non-empty start path, it should not be a directory {starting_path:?}" - ); - anyhow::ensure!( - !start_path_metadata.is_symlink, - "For non-empty start path, it should not be a symlink {starting_path:?}" - ); - let file_to_format = starting_path.starting_path.as_ref(); - let mut paths_to_check = VecDeque::from(vec![worktree_root.clone()]); - let mut current_path = worktree_root; - for path_component in file_to_format.components().into_iter() { - current_path = current_path.join(path_component); - paths_to_check.push_front(current_path.clone()); - if path_component.as_os_str().to_string_lossy() == "node_modules" { - break; - } + anyhow::ensure!(worktree_root_metadata.is_dir, + "For non-empty start path, worktree root {starting_path:?} should be a directory"); + anyhow::ensure!( + !start_path_metadata.is_dir, + "For non-empty start path, it should not be a directory {starting_path:?}" + ); + anyhow::ensure!( + !start_path_metadata.is_symlink, + "For non-empty start path, it should not be a symlink {starting_path:?}" + ); + + let file_to_format = starting_path.starting_path.as_ref(); + let mut paths_to_check = VecDeque::from(vec![worktree_root.clone()]); + let mut current_path = worktree_root; + for path_component in file_to_format.components().into_iter() { + current_path = current_path.join(path_component); + paths_to_check.push_front(current_path.clone()); + if path_component.as_os_str().to_string_lossy() == "node_modules" { + break; } - paths_to_check.pop_front(); // last one is the file itself or node_modules, skip it - Vec::from(paths_to_check) - } - None => { - anyhow::ensure!( - !worktree_root_metadata.is_dir, - "For empty start path, worktree root should not be a directory {starting_path:?}" - ); - anyhow::ensure!( - !worktree_root_metadata.is_symlink, - "For empty start path, worktree root should not be a symlink {starting_path:?}" - ); - worktree_root - .parent() - .map(|path| vec![path.to_path_buf()]) - .unwrap_or_default() } + paths_to_check.pop_front(); // last one is the file itself or node_modules, skip it + Vec::from(paths_to_check) } } } From b46a4b56808f7c3521250bef6ee9e4f4389b6973 Mon Sep 17 00:00:00 2001 From: Kirill Bulatov Date: Mon, 30 Oct 2023 12:07:11 +0200 Subject: [PATCH 18/25] Be more lenient when searching for prettier instance Do not check FS for existence (we'll error when start running prettier), simplify the code for looking it up --- crates/prettier/src/prettier.rs | 62 ++++++--------------------------- 1 file changed, 10 
insertions(+), 52 deletions(-) diff --git a/crates/prettier/src/prettier.rs b/crates/prettier/src/prettier.rs index 6784dba7dcf92aa2b2a6f6f1dbadae062edff6a3..7517b4ee4345733fcf0d08c3292276c1d90b67a3 100644 --- a/crates/prettier/src/prettier.rs +++ b/crates/prettier/src/prettier.rs @@ -67,80 +67,38 @@ impl Prettier { starting_path: Option, fs: Arc, ) -> anyhow::Result { + fn is_node_modules(path_component: &std::path::Component<'_>) -> bool { + path_component.as_os_str().to_string_lossy() == "node_modules" + } + let paths_to_check = match starting_path.as_ref() { Some(starting_path) => { let worktree_root = starting_path .worktree_root_path .components() .into_iter() - .take_while(|path_component| { - path_component.as_os_str().to_string_lossy() != "node_modules" - }) + .take_while(|path_component| !is_node_modules(path_component)) .collect::(); - if worktree_root != starting_path.worktree_root_path.as_ref() { vec![worktree_root] } else { - let worktree_root_metadata = fs - .metadata(&worktree_root) - .await - .with_context(|| { - format!("FS metadata fetch for worktree root path {worktree_root:?}",) - })? - .with_context(|| { - format!("empty FS metadata for worktree root at {worktree_root:?}") - })?; if starting_path.starting_path.as_ref() == Path::new("") { - anyhow::ensure!( - !worktree_root_metadata.is_dir, - "For empty start path, worktree root should not be a directory {starting_path:?}" - ); - anyhow::ensure!( - !worktree_root_metadata.is_symlink, - "For empty start path, worktree root should not be a symlink {starting_path:?}" - ); worktree_root .parent() .map(|path| vec![path.to_path_buf()]) .unwrap_or_default() } else { - let full_starting_path = worktree_root.join(&starting_path.starting_path); - let start_path_metadata = fs - .metadata(&full_starting_path) - .await - .with_context(|| { - format!( - "FS metadata fetch for starting path {full_starting_path:?}" - ) - })? 
- .with_context(|| { - format!( - "empty FS metadata for starting path {full_starting_path:?}" - ) - })?; - - anyhow::ensure!(worktree_root_metadata.is_dir, - "For non-empty start path, worktree root {starting_path:?} should be a directory"); - anyhow::ensure!( - !start_path_metadata.is_dir, - "For non-empty start path, it should not be a directory {starting_path:?}" - ); - anyhow::ensure!( - !start_path_metadata.is_symlink, - "For non-empty start path, it should not be a symlink {starting_path:?}" - ); - let file_to_format = starting_path.starting_path.as_ref(); - let mut paths_to_check = VecDeque::from(vec![worktree_root.clone()]); + let mut paths_to_check = VecDeque::new(); let mut current_path = worktree_root; for path_component in file_to_format.components().into_iter() { - current_path = current_path.join(path_component); - paths_to_check.push_front(current_path.clone()); - if path_component.as_os_str().to_string_lossy() == "node_modules" { + let new_path = current_path.join(path_component); + let old_path = std::mem::replace(&mut current_path, new_path); + paths_to_check.push_front(old_path); + if is_node_modules(&path_component) { break; } } - paths_to_check.pop_front(); // last one is the file itself or node_modules, skip it Vec::from(paths_to_check) } } From a2c3971ad6202bcec51dc0f36ef13497e94d1597 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Mon, 30 Oct 2023 10:02:27 -0400 Subject: [PATCH 19/25] moved authentication for the semantic index into the EmbeddingProvider --- crates/ai/src/auth.rs | 11 +-- crates/ai/src/completion.rs | 18 +--- crates/ai/src/embedding.rs | 15 +--- crates/ai/src/providers/open_ai/auth.rs | 46 ---------- crates/ai/src/providers/open_ai/completion.rs | 78 ++++++++++++---- crates/ai/src/providers/open_ai/embedding.rs | 88 ++++++++++++++----- crates/ai/src/providers/open_ai/mod.rs | 3 +- crates/ai/src/providers/open_ai/new.rs | 11 +++ crates/ai/src/test.rs | 48 +++++----- crates/assistant/src/assistant_panel.rs | 7 +- crates/assistant/src/codegen.rs | 8 +- crates/semantic_index/src/embedding_queue.rs | 17 +--- crates/semantic_index/src/semantic_index.rs | 50 ++++------- .../src/semantic_index_tests.rs | 6 +- 14 files changed, 200 insertions(+), 206 deletions(-) delete mode 100644 crates/ai/src/providers/open_ai/auth.rs create mode 100644 crates/ai/src/providers/open_ai/new.rs diff --git a/crates/ai/src/auth.rs b/crates/ai/src/auth.rs index cb3f2beabb00ef79a52c5bbf10a19ef8d1a5cc48..c6256df2160abc7284c070ff204ee849911e4cf3 100644 --- a/crates/ai/src/auth.rs +++ b/crates/ai/src/auth.rs @@ -8,17 +8,8 @@ pub enum ProviderCredential { } pub trait CredentialProvider: Send + Sync { + fn has_credentials(&self) -> bool; fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential; fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential); fn delete_credentials(&self, cx: &AppContext); } - -#[derive(Clone)] -pub struct NullCredentialProvider; -impl CredentialProvider for NullCredentialProvider { - fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential { - ProviderCredential::NotNeeded - } - fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {} - fn delete_credentials(&self, cx: &AppContext) {} -} diff --git a/crates/ai/src/completion.rs b/crates/ai/src/completion.rs index 6a2806a5cb30b702952e90362a72df41a5c80786..7fdc49e91802995b467839a3c3db2f5d3659e834 100644 --- a/crates/ai/src/completion.rs +++ b/crates/ai/src/completion.rs @@ -1,28 +1,14 @@ use anyhow::Result; use futures::{future::BoxFuture, 
stream::BoxStream}; -use gpui::AppContext; -use crate::{ - auth::{CredentialProvider, ProviderCredential}, - models::LanguageModel, -}; +use crate::{auth::CredentialProvider, models::LanguageModel}; pub trait CompletionRequest: Send + Sync { fn data(&self) -> serde_json::Result; } -pub trait CompletionProvider { +pub trait CompletionProvider: CredentialProvider { fn base_model(&self) -> Box; - fn credential_provider(&self) -> Box; - fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential { - self.credential_provider().retrieve_credentials(cx) - } - fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) { - self.credential_provider().save_credentials(cx, credential); - } - fn delete_credentials(&self, cx: &AppContext) { - self.credential_provider().delete_credentials(cx); - } fn complete( &self, prompt: Box, diff --git a/crates/ai/src/embedding.rs b/crates/ai/src/embedding.rs index 50f04232ab64f4f16571efb9c11fd4f1bccf89e2..6768b7ce7bab7898ee4222d5217a4de120ca57ed 100644 --- a/crates/ai/src/embedding.rs +++ b/crates/ai/src/embedding.rs @@ -2,12 +2,11 @@ use std::time::Instant; use anyhow::Result; use async_trait::async_trait; -use gpui::AppContext; use ordered_float::OrderedFloat; use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef}; use rusqlite::ToSql; -use crate::auth::{CredentialProvider, ProviderCredential}; +use crate::auth::CredentialProvider; use crate::models::LanguageModel; #[derive(Debug, PartialEq, Clone)] @@ -70,17 +69,9 @@ impl Embedding { } #[async_trait] -pub trait EmbeddingProvider: Sync + Send { +pub trait EmbeddingProvider: CredentialProvider { fn base_model(&self) -> Box; - fn credential_provider(&self) -> Box; - fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential { - self.credential_provider().retrieve_credentials(cx) - } - async fn embed_batch( - &self, - spans: Vec, - credential: ProviderCredential, - ) -> Result>; + async fn embed_batch(&self, spans: Vec) -> Result>; fn max_tokens_per_batch(&self) -> usize; fn rate_limit_expiration(&self) -> Option; } diff --git a/crates/ai/src/providers/open_ai/auth.rs b/crates/ai/src/providers/open_ai/auth.rs deleted file mode 100644 index 7cb51ab449a066d683ab7e5921b3713bc5482cde..0000000000000000000000000000000000000000 --- a/crates/ai/src/providers/open_ai/auth.rs +++ /dev/null @@ -1,46 +0,0 @@ -use std::env; - -use gpui::AppContext; -use util::ResultExt; - -use crate::auth::{CredentialProvider, ProviderCredential}; -use crate::providers::open_ai::OPENAI_API_URL; - -#[derive(Clone)] -pub struct OpenAICredentialProvider {} - -impl CredentialProvider for OpenAICredentialProvider { - fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential { - let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") { - Some(api_key) - } else if let Some((_, api_key)) = cx - .platform() - .read_credentials(OPENAI_API_URL) - .log_err() - .flatten() - { - String::from_utf8(api_key).log_err() - } else { - None - }; - - if let Some(api_key) = api_key { - ProviderCredential::Credentials { api_key } - } else { - ProviderCredential::NoCredentials - } - } - fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) { - match credential { - ProviderCredential::Credentials { api_key } => { - cx.platform() - .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()) - .log_err(); - } - _ => {} - } - } - fn delete_credentials(&self, cx: &AppContext) { - cx.platform().delete_credentials(OPENAI_API_URL).log_err(); - } -} diff --git 
a/crates/ai/src/providers/open_ai/completion.rs b/crates/ai/src/providers/open_ai/completion.rs index febe491123fe848e74b49a50c6b18d0fafc3a702..02d25a7eec2a841bbc5226f4bd447909c34662fe 100644 --- a/crates/ai/src/providers/open_ai/completion.rs +++ b/crates/ai/src/providers/open_ai/completion.rs @@ -3,14 +3,17 @@ use futures::{ future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt, Stream, StreamExt, }; -use gpui::executor::Background; +use gpui::{executor::Background, AppContext}; use isahc::{http::StatusCode, Request, RequestExt}; +use parking_lot::RwLock; use serde::{Deserialize, Serialize}; use std::{ + env, fmt::{self, Display}, io, sync::Arc, }; +use util::ResultExt; use crate::{ auth::{CredentialProvider, ProviderCredential}, @@ -18,9 +21,7 @@ use crate::{ models::LanguageModel, }; -use super::{auth::OpenAICredentialProvider, OpenAILanguageModel}; - -pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; +use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL}; #[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] #[serde(rename_all = "lowercase")] @@ -194,42 +195,83 @@ pub async fn stream_completion( pub struct OpenAICompletionProvider { model: OpenAILanguageModel, - credential_provider: OpenAICredentialProvider, - credential: ProviderCredential, + credential: Arc>, executor: Arc, } impl OpenAICompletionProvider { - pub fn new( - model_name: &str, - credential: ProviderCredential, - executor: Arc, - ) -> Self { + pub fn new(model_name: &str, executor: Arc) -> Self { let model = OpenAILanguageModel::load(model_name); - let credential_provider = OpenAICredentialProvider {}; + let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials)); Self { model, - credential_provider, credential, executor, } } } +impl CredentialProvider for OpenAICompletionProvider { + fn has_credentials(&self) -> bool { + match *self.credential.read() { + ProviderCredential::Credentials { .. } => true, + _ => false, + } + } + fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential { + let mut credential = self.credential.write(); + match *credential { + ProviderCredential::Credentials { .. 
} => { + return credential.clone(); + } + _ => { + if let Ok(api_key) = env::var("OPENAI_API_KEY") { + *credential = ProviderCredential::Credentials { api_key }; + } else if let Some((_, api_key)) = cx + .platform() + .read_credentials(OPENAI_API_URL) + .log_err() + .flatten() + { + if let Some(api_key) = String::from_utf8(api_key).log_err() { + *credential = ProviderCredential::Credentials { api_key }; + } + } else { + }; + } + } + + credential.clone() + } + + fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) { + match credential.clone() { + ProviderCredential::Credentials { api_key } => { + cx.platform() + .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()) + .log_err(); + } + _ => {} + } + + *self.credential.write() = credential; + } + fn delete_credentials(&self, cx: &AppContext) { + cx.platform().delete_credentials(OPENAI_API_URL).log_err(); + *self.credential.write() = ProviderCredential::NoCredentials; + } +} + impl CompletionProvider for OpenAICompletionProvider { fn base_model(&self) -> Box { let model: Box = Box::new(self.model.clone()); model } - fn credential_provider(&self) -> Box { - let provider: Box = Box::new(self.credential_provider.clone()); - provider - } fn complete( &self, prompt: Box, ) -> BoxFuture<'static, Result>>> { - let credential = self.credential.clone(); + let credential = self.credential.read().clone(); let request = stream_completion(credential, self.executor.clone(), prompt); async move { let response = request.await?; diff --git a/crates/ai/src/providers/open_ai/embedding.rs b/crates/ai/src/providers/open_ai/embedding.rs index dafc94580d16d3c79987311ee0e6558fb45d3669..fbfd0028f9fd0f99dc60649d40425c1cbd822485 100644 --- a/crates/ai/src/providers/open_ai/embedding.rs +++ b/crates/ai/src/providers/open_ai/embedding.rs @@ -2,27 +2,29 @@ use anyhow::{anyhow, Result}; use async_trait::async_trait; use futures::AsyncReadExt; use gpui::executor::Background; -use gpui::serde_json; +use gpui::{serde_json, AppContext}; use isahc::http::StatusCode; use isahc::prelude::Configurable; use isahc::{AsyncBody, Response}; use lazy_static::lazy_static; -use parking_lot::Mutex; +use parking_lot::{Mutex, RwLock}; use parse_duration::parse; use postage::watch; use serde::{Deserialize, Serialize}; +use std::env; use std::ops::Add; use std::sync::Arc; use std::time::{Duration, Instant}; use tiktoken_rs::{cl100k_base, CoreBPE}; use util::http::{HttpClient, Request}; +use util::ResultExt; use crate::auth::{CredentialProvider, ProviderCredential}; use crate::embedding::{Embedding, EmbeddingProvider}; use crate::models::LanguageModel; use crate::providers::open_ai::OpenAILanguageModel; -use crate::providers::open_ai::auth::OpenAICredentialProvider; +use crate::providers::open_ai::OPENAI_API_URL; lazy_static! { static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap(); @@ -31,7 +33,7 @@ lazy_static! 
{ #[derive(Clone)] pub struct OpenAIEmbeddingProvider { model: OpenAILanguageModel, - credential_provider: OpenAICredentialProvider, + credential: Arc>, pub client: Arc, pub executor: Arc, rate_limit_count_rx: watch::Receiver>, @@ -69,10 +71,11 @@ impl OpenAIEmbeddingProvider { let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx)); let model = OpenAILanguageModel::load("text-embedding-ada-002"); + let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials)); OpenAIEmbeddingProvider { model, - credential_provider: OpenAICredentialProvider {}, + credential, client, executor, rate_limit_count_rx, @@ -80,6 +83,13 @@ impl OpenAIEmbeddingProvider { } } + fn get_api_key(&self) -> Result { + match self.credential.read().clone() { + ProviderCredential::Credentials { api_key } => Ok(api_key), + _ => Err(anyhow!("api credentials not provided")), + } + } + fn resolve_rate_limit(&self) { let reset_time = *self.rate_limit_count_tx.lock().borrow(); @@ -136,6 +146,57 @@ impl OpenAIEmbeddingProvider { } } +impl CredentialProvider for OpenAIEmbeddingProvider { + fn has_credentials(&self) -> bool { + match *self.credential.read() { + ProviderCredential::Credentials { .. } => true, + _ => false, + } + } + fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential { + let mut credential = self.credential.write(); + match *credential { + ProviderCredential::Credentials { .. } => { + return credential.clone(); + } + _ => { + if let Ok(api_key) = env::var("OPENAI_API_KEY") { + *credential = ProviderCredential::Credentials { api_key }; + } else if let Some((_, api_key)) = cx + .platform() + .read_credentials(OPENAI_API_URL) + .log_err() + .flatten() + { + if let Some(api_key) = String::from_utf8(api_key).log_err() { + *credential = ProviderCredential::Credentials { api_key }; + } + } else { + }; + } + } + + credential.clone() + } + + fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) { + match credential.clone() { + ProviderCredential::Credentials { api_key } => { + cx.platform() + .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()) + .log_err(); + } + _ => {} + } + + *self.credential.write() = credential; + } + fn delete_credentials(&self, cx: &AppContext) { + cx.platform().delete_credentials(OPENAI_API_URL).log_err(); + *self.credential.write() = ProviderCredential::NoCredentials; + } +} + #[async_trait] impl EmbeddingProvider for OpenAIEmbeddingProvider { fn base_model(&self) -> Box { @@ -143,12 +204,6 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider { model } - fn credential_provider(&self) -> Box { - let credential_provider: Box = - Box::new(self.credential_provider.clone()); - credential_provider - } - fn max_tokens_per_batch(&self) -> usize { 50000 } @@ -157,18 +212,11 @@ impl EmbeddingProvider for OpenAIEmbeddingProvider { *self.rate_limit_count_rx.borrow() } - async fn embed_batch( - &self, - spans: Vec, - credential: ProviderCredential, - ) -> Result> { + async fn embed_batch(&self, spans: Vec) -> Result> { const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; const MAX_RETRIES: usize = 4; - let api_key = match credential { - ProviderCredential::Credentials { api_key } => anyhow::Ok(api_key), - _ => Err(anyhow!("no api key provided")), - }?; + let api_key = self.get_api_key()?; let mut request_number = 0; let mut rate_limiting = false; diff --git a/crates/ai/src/providers/open_ai/mod.rs b/crates/ai/src/providers/open_ai/mod.rs index 49e29fbc8c184e9bc39e923563778e3678ed7c9e..7d2f86045d9e1049c55111aef175ac9b56dc7e16 100644 --- 
a/crates/ai/src/providers/open_ai/mod.rs +++ b/crates/ai/src/providers/open_ai/mod.rs @@ -1,4 +1,3 @@ -pub mod auth; pub mod completion; pub mod embedding; pub mod model; @@ -6,3 +5,5 @@ pub mod model; pub use completion::*; pub use embedding::*; pub use model::OpenAILanguageModel; + +pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; diff --git a/crates/ai/src/providers/open_ai/new.rs b/crates/ai/src/providers/open_ai/new.rs new file mode 100644 index 0000000000000000000000000000000000000000..c7d67f2ba1d252a6865124d8ffdfb79130a8c3a0 --- /dev/null +++ b/crates/ai/src/providers/open_ai/new.rs @@ -0,0 +1,11 @@ +pub trait LanguageModel { + fn name(&self) -> String; + fn count_tokens(&self, content: &str) -> anyhow::Result; + fn truncate( + &self, + content: &str, + length: usize, + direction: TruncationDirection, + ) -> anyhow::Result; + fn capacity(&self) -> anyhow::Result; +} diff --git a/crates/ai/src/test.rs b/crates/ai/src/test.rs index b8f99af400f213f96ef7a1911d06f419fb990c54..bc9a6a3e434dc8374fb181b27f7bbfa81f63f235 100644 --- a/crates/ai/src/test.rs +++ b/crates/ai/src/test.rs @@ -5,10 +5,11 @@ use std::{ use async_trait::async_trait; use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; +use gpui::AppContext; use parking_lot::Mutex; use crate::{ - auth::{CredentialProvider, NullCredentialProvider, ProviderCredential}, + auth::{CredentialProvider, ProviderCredential}, completion::{CompletionProvider, CompletionRequest}, embedding::{Embedding, EmbeddingProvider}, models::{LanguageModel, TruncationDirection}, @@ -52,14 +53,12 @@ impl LanguageModel for FakeLanguageModel { pub struct FakeEmbeddingProvider { pub embedding_count: AtomicUsize, - pub credential_provider: NullCredentialProvider, } impl Clone for FakeEmbeddingProvider { fn clone(&self) -> Self { FakeEmbeddingProvider { embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)), - credential_provider: self.credential_provider.clone(), } } } @@ -68,7 +67,6 @@ impl Default for FakeEmbeddingProvider { fn default() -> Self { FakeEmbeddingProvider { embedding_count: AtomicUsize::default(), - credential_provider: NullCredentialProvider {}, } } } @@ -99,16 +97,22 @@ impl FakeEmbeddingProvider { } } +impl CredentialProvider for FakeEmbeddingProvider { + fn has_credentials(&self) -> bool { + true + } + fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential { + ProviderCredential::NotNeeded + } + fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {} + fn delete_credentials(&self, _cx: &AppContext) {} +} + #[async_trait] impl EmbeddingProvider for FakeEmbeddingProvider { fn base_model(&self) -> Box { Box::new(FakeLanguageModel { capacity: 1000 }) } - fn credential_provider(&self) -> Box { - let credential_provider: Box = - Box::new(self.credential_provider.clone()); - credential_provider - } fn max_tokens_per_batch(&self) -> usize { 1000 } @@ -117,11 +121,7 @@ impl EmbeddingProvider for FakeEmbeddingProvider { None } - async fn embed_batch( - &self, - spans: Vec, - _credential: ProviderCredential, - ) -> anyhow::Result> { + async fn embed_batch(&self, spans: Vec) -> anyhow::Result> { self.embedding_count .fetch_add(spans.len(), atomic::Ordering::SeqCst); @@ -129,11 +129,11 @@ impl EmbeddingProvider for FakeEmbeddingProvider { } } -pub struct TestCompletionProvider { +pub struct FakeCompletionProvider { last_completion_tx: Mutex>>, } -impl TestCompletionProvider { +impl FakeCompletionProvider { pub fn new() -> Self { Self { 
last_completion_tx: Mutex::new(None), @@ -150,14 +150,22 @@ impl TestCompletionProvider { } } -impl CompletionProvider for TestCompletionProvider { +impl CredentialProvider for FakeCompletionProvider { + fn has_credentials(&self) -> bool { + true + } + fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential { + ProviderCredential::NotNeeded + } + fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {} + fn delete_credentials(&self, _cx: &AppContext) {} +} + +impl CompletionProvider for FakeCompletionProvider { fn base_model(&self) -> Box { let model: Box = Box::new(FakeLanguageModel { capacity: 8190 }); model } - fn credential_provider(&self) -> Box { - Box::new(NullCredentialProvider {}) - } fn complete( &self, _prompt: Box, diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index f9187b87855b748b536aecb908c0c44ce3377b5f..c10ad2c362fbea3118aa77f2f89af42168c3a8f9 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -10,7 +10,7 @@ use ai::{ auth::ProviderCredential, completion::{CompletionProvider, CompletionRequest}, providers::open_ai::{ - stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL, + stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, }, }; @@ -48,7 +48,7 @@ use semantic_index::{SemanticIndex, SemanticIndexStatus}; use settings::SettingsStore; use std::{ cell::{Cell, RefCell}, - cmp, env, + cmp, fmt::Write, iter, ops::Range, @@ -210,7 +210,6 @@ impl AssistantPanel { // Defaulting currently to GPT4, allow for this to be set via config. let completion_provider = Box::new(OpenAICompletionProvider::new( "gpt-4", - ProviderCredential::NoCredentials, cx.background().clone(), )); @@ -298,7 +297,6 @@ impl AssistantPanel { cx: &mut ViewContext, project: &ModelHandle, ) { - let credential = self.credential.borrow().clone(); let selection = editor.read(cx).selections.newest_anchor().clone(); if selection.start.excerpt_id() != selection.end.excerpt_id() { return; @@ -330,7 +328,6 @@ impl AssistantPanel { let inline_assist_id = post_inc(&mut self.next_inline_assist_id); let provider = Arc::new(OpenAICompletionProvider::new( "gpt-4", - credential, cx.background().clone(), )); diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs index 7f4c95f655e7bf1ea222b109865fbf4a98ad2d0d..8d8e49902f1f8d1365d365ae507d208244ca39ce 100644 --- a/crates/assistant/src/codegen.rs +++ b/crates/assistant/src/codegen.rs @@ -335,7 +335,7 @@ fn strip_markdown_codeblock( #[cfg(test)] mod tests { use super::*; - use ai::test::TestCompletionProvider; + use ai::test::FakeCompletionProvider; use futures::stream::{self}; use gpui::{executor::Deterministic, TestAppContext}; use indoc::indoc; @@ -379,7 +379,7 @@ mod tests { let snapshot = buffer.snapshot(cx); snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5)) }); - let provider = Arc::new(TestCompletionProvider::new()); + let provider = Arc::new(FakeCompletionProvider::new()); let codegen = cx.add_model(|cx| { Codegen::new( buffer.clone(), @@ -445,7 +445,7 @@ mod tests { let snapshot = buffer.snapshot(cx); snapshot.anchor_before(Point::new(1, 6)) }); - let provider = Arc::new(TestCompletionProvider::new()); + let provider = Arc::new(FakeCompletionProvider::new()); let codegen = cx.add_model(|cx| { Codegen::new( buffer.clone(), @@ -511,7 +511,7 @@ mod tests { let snapshot = buffer.snapshot(cx); 
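The fake provider these tests now pull from crates/ai/src/test.rs is driven by a bounded channel: complete() hands back the receiving end, send_completion pushes chunks, and finish_completion ends the stream by dropping the sender. A stripped-down sketch of that mechanic, assuming only the futures crate — the names and the plain String payload are illustrative, while the real fake wraps everything in BoxStream and Mutex:

    use futures::{channel::mpsc, executor::block_on, StreamExt};

    fn main() {
        // "Provider" side: the receiver is what complete() would return.
        let (mut completion_tx, completion_rx) = mpsc::channel::<String>(1);

        // "Test" side: push partial completions, then finish by dropping the
        // sender, as finish_completion does by taking the stored channel.
        completion_tx.try_send("fn main() {".to_string()).unwrap();
        completion_tx.try_send(" }".to_string()).unwrap();
        drop(completion_tx);

        // Consumers see the chunks in order, then a cleanly closed stream.
        let chunks: Vec<String> = block_on(completion_rx.collect());
        assert_eq!(chunks.concat(), "fn main() { }");
    }
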
snapshot.anchor_before(Point::new(1, 2)) }); - let provider = Arc::new(TestCompletionProvider::new()); + let provider = Arc::new(FakeCompletionProvider::new()); let codegen = cx.add_model(|cx| { Codegen::new( buffer.clone(), diff --git a/crates/semantic_index/src/embedding_queue.rs b/crates/semantic_index/src/embedding_queue.rs index 6f792c78e21dc91f333c22a918f6d13047c506c9..6ae8faa4cdb0998acad8ab361c83183cff909578 100644 --- a/crates/semantic_index/src/embedding_queue.rs +++ b/crates/semantic_index/src/embedding_queue.rs @@ -1,5 +1,5 @@ use crate::{parsing::Span, JobHandle}; -use ai::{auth::ProviderCredential, embedding::EmbeddingProvider}; +use ai::embedding::EmbeddingProvider; use gpui::executor::Background; use parking_lot::Mutex; use smol::channel; @@ -41,7 +41,6 @@ pub struct EmbeddingQueue { pending_batch_token_count: usize, finished_files_tx: channel::Sender, finished_files_rx: channel::Receiver, - pub provider_credential: ProviderCredential, } #[derive(Clone)] @@ -51,11 +50,7 @@ pub struct FileFragmentToEmbed { } impl EmbeddingQueue { - pub fn new( - embedding_provider: Arc, - executor: Arc, - provider_credential: ProviderCredential, - ) -> Self { + pub fn new(embedding_provider: Arc, executor: Arc) -> Self { let (finished_files_tx, finished_files_rx) = channel::unbounded(); Self { embedding_provider, @@ -64,14 +59,9 @@ impl EmbeddingQueue { pending_batch_token_count: 0, finished_files_tx, finished_files_rx, - provider_credential, } } - pub fn set_credential(&mut self, credential: ProviderCredential) { - self.provider_credential = credential; - } - pub fn push(&mut self, file: FileToEmbed) { if file.spans.is_empty() { self.finished_files_tx.try_send(file).unwrap(); @@ -118,7 +108,6 @@ impl EmbeddingQueue { let finished_files_tx = self.finished_files_tx.clone(); let embedding_provider = self.embedding_provider.clone(); - let credential = self.provider_credential.clone(); self.executor .spawn(async move { @@ -143,7 +132,7 @@ impl EmbeddingQueue { return; }; - match embedding_provider.embed_batch(spans, credential).await { + match embedding_provider.embed_batch(spans).await { Ok(embeddings) => { let mut embeddings = embeddings.into_iter(); for fragment in batch { diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs index 7fb5f749b4f6851b9201ddac6cd7fc59ce6f782a..818faa044475d8ed76a90dafcb6a2f6e3b7f7ee6 100644 --- a/crates/semantic_index/src/semantic_index.rs +++ b/crates/semantic_index/src/semantic_index.rs @@ -7,7 +7,6 @@ pub mod semantic_index_settings; mod semantic_index_tests; use crate::semantic_index_settings::SemanticIndexSettings; -use ai::auth::ProviderCredential; use ai::embedding::{Embedding, EmbeddingProvider}; use ai::providers::open_ai::OpenAIEmbeddingProvider; use anyhow::{anyhow, Result}; @@ -125,8 +124,6 @@ pub struct SemanticIndex { _embedding_task: Task<()>, _parsing_files_tasks: Vec>, projects: HashMap, ProjectState>, - provider_credential: ProviderCredential, - embedding_queue: Arc>, } struct ProjectState { @@ -281,24 +278,17 @@ impl SemanticIndex { } pub fn authenticate(&mut self, cx: &AppContext) -> bool { - let existing_credential = self.provider_credential.clone(); - let credential = match existing_credential { - ProviderCredential::NoCredentials => self.embedding_provider.retrieve_credentials(cx), - _ => existing_credential, - }; + if !self.embedding_provider.has_credentials() { + self.embedding_provider.retrieve_credentials(cx); + } else { + return true; + } - self.provider_credential = credential.clone(); - 
self.embedding_queue.lock().set_credential(credential); - self.is_authenticated() + self.embedding_provider.has_credentials() } pub fn is_authenticated(&self) -> bool { - let credential = &self.provider_credential; - match credential { - &ProviderCredential::Credentials { .. } => true, - &ProviderCredential::NotNeeded => true, - _ => false, - } + self.embedding_provider.has_credentials() } pub fn enabled(cx: &AppContext) -> bool { @@ -348,7 +338,7 @@ impl SemanticIndex { Ok(cx.add_model(|cx| { let t0 = Instant::now(); let embedding_queue = - EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), ProviderCredential::NoCredentials); + EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone()); let _embedding_task = cx.background().spawn({ let embedded_files = embedding_queue.finished_files(); let db = db.clone(); @@ -413,8 +403,6 @@ impl SemanticIndex { _embedding_task, _parsing_files_tasks, projects: Default::default(), - provider_credential: ProviderCredential::NoCredentials, - embedding_queue } })) } @@ -729,14 +717,13 @@ impl SemanticIndex { let index = self.index_project(project.clone(), cx); let embedding_provider = self.embedding_provider.clone(); - let credential = self.provider_credential.clone(); cx.spawn(|this, mut cx| async move { index.await?; let t0 = Instant::now(); let query = embedding_provider - .embed_batch(vec![query], credential) + .embed_batch(vec![query]) .await? .pop() .ok_or_else(|| anyhow!("could not embed query"))?; @@ -954,7 +941,6 @@ impl SemanticIndex { let fs = self.fs.clone(); let db_path = self.db.path().clone(); let background = cx.background().clone(); - let credential = self.provider_credential.clone(); cx.background().spawn(async move { let db = VectorDatabase::new(fs, db_path.clone(), background).await?; let mut results = Vec::::new(); @@ -969,15 +955,10 @@ impl SemanticIndex { .parse_file_with_template(None, &snapshot.text(), language) .log_err() .unwrap_or_default(); - if Self::embed_spans( - &mut spans, - embedding_provider.as_ref(), - &db, - credential.clone(), - ) - .await - .log_err() - .is_some() + if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db) + .await + .log_err() + .is_some() { for span in spans { let similarity = span.embedding.unwrap().similarity(&query); @@ -1201,7 +1182,6 @@ impl SemanticIndex { spans: &mut [Span], embedding_provider: &dyn EmbeddingProvider, db: &VectorDatabase, - credential: ProviderCredential, ) -> Result<()> { let mut batch = Vec::new(); let mut batch_tokens = 0; @@ -1224,7 +1204,7 @@ impl SemanticIndex { if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() { let batch_embeddings = embedding_provider - .embed_batch(mem::take(&mut batch), credential.clone()) + .embed_batch(mem::take(&mut batch)) .await?; embeddings.extend(batch_embeddings); batch_tokens = 0; @@ -1236,7 +1216,7 @@ impl SemanticIndex { if !batch.is_empty() { let batch_embeddings = embedding_provider - .embed_batch(mem::take(&mut batch), credential) + .embed_batch(mem::take(&mut batch)) .await?; embeddings.extend(batch_embeddings); diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs index 7d5a4e22e80530b4d22314beb12582bb300ff406..7a91d1e100ca96d21e937fe7945c3cfd78fc68fd 100644 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ b/crates/semantic_index/src/semantic_index_tests.rs @@ -220,11 +220,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) { let embedding_provider = 
Arc::new(FakeEmbeddingProvider::default());
-    let mut queue = EmbeddingQueue::new(
-        embedding_provider.clone(),
-        cx.background(),
-        ai::auth::ProviderCredential::NoCredentials,
-    );
+    let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
     for file in &files {
         queue.push(file.clone());
     }

From f3c113fe02489748823b67f6e3340a094d412795 Mon Sep 17 00:00:00 2001
From: KCaverly
Date: Mon, 30 Oct 2023 11:07:24 -0400
Subject: [PATCH 20/25] clean up warnings and fix tests in the ai crate

---
 crates/ai/src/completion.rs                   |   7 +
 crates/ai/src/prompts/base.rs                 |   4 +-
 crates/ai/src/providers/open_ai/completion.rs |   8 +
 crates/ai/src/test.rs                         |  14 ++
 crates/assistant/src/assistant_panel.rs       | 214 ++++++------------
 5 files changed, 103 insertions(+), 144 deletions(-)

diff --git a/crates/ai/src/completion.rs b/crates/ai/src/completion.rs
index 7fdc49e91802995b467839a3c3db2f5d3659e834..30a60fcf1d5c5dc66717773968e432e510d6421f 100644
--- a/crates/ai/src/completion.rs
+++ b/crates/ai/src/completion.rs
@@ -13,4 +13,11 @@ pub trait CompletionProvider: CredentialProvider {
         &self,
         prompt: Box<dyn CompletionRequest>,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
+    fn box_clone(&self) -> Box<dyn CompletionProvider>;
+}
+
+impl Clone for Box<dyn CompletionProvider> {
+    fn clone(&self) -> Box<dyn CompletionProvider> {
+        self.box_clone()
+    }
 }
diff --git a/crates/ai/src/prompts/base.rs b/crates/ai/src/prompts/base.rs
index a2106c74106b0b976cc70717b4535ccb557573bf..75bad00154b001a356f35cb5d80e4ac4962fe0f9 100644
--- a/crates/ai/src/prompts/base.rs
+++ b/crates/ai/src/prompts/base.rs
@@ -147,7 +147,7 @@ pub(crate) mod tests {
                     content = args.model.truncate(
                         &content,
                         max_token_length,
-                        TruncationDirection::Start,
+                        TruncationDirection::End,
                     )?;
                     token_count = max_token_length;
                 }
@@ -172,7 +172,7 @@ pub(crate) mod tests {
                     content = args.model.truncate(
                         &content,
                         max_token_length,
-                        TruncationDirection::Start,
+                        TruncationDirection::End,
                     )?;
                     token_count = max_token_length;
                 }
diff --git a/crates/ai/src/providers/open_ai/completion.rs b/crates/ai/src/providers/open_ai/completion.rs
index 02d25a7eec2a841bbc5226f4bd447909c34662fe..94685fd233520fc919d2708a20c202828250a481 100644
--- a/crates/ai/src/providers/open_ai/completion.rs
+++ b/crates/ai/src/providers/open_ai/completion.rs
@@ -193,6 +193,7 @@ pub async fn stream_completion(
     }
 }
 
+#[derive(Clone)]
 pub struct OpenAICompletionProvider {
     model: OpenAILanguageModel,
     credential: Arc<RwLock<ProviderCredential>>,
@@ -271,6 +272,10 @@ impl CompletionProvider for OpenAICompletionProvider {
         &self,
         prompt: Box<dyn CompletionRequest>,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        // Currently the CompletionRequest for OpenAI includes a 'model' parameter.
+        // This means the model is determined by the CompletionRequest and not the
+        // CompletionProvider, which is currently model-based due to the language
+        // model. At some point in the future we should rectify this.
         let credential = self.credential.read().clone();
         let request = stream_completion(credential, self.executor.clone(), prompt);
         async move {
@@ -287,4 +292,7 @@
         }
         .boxed()
     }
+    fn box_clone(&self) -> Box<dyn CompletionProvider> {
+        Box::new((*self).clone())
+    }
 }
diff --git a/crates/ai/src/test.rs b/crates/ai/src/test.rs
index bc9a6a3e434dc8374fb181b27f7bbfa81f63f235..d4165f3cca897c4adbf11c2babf6038a8d86f0a6 100644
--- a/crates/ai/src/test.rs
+++ b/crates/ai/src/test.rs
@@ -33,7 +33,10 @@ impl LanguageModel for FakeLanguageModel {
         length: usize,
         direction: TruncationDirection,
     ) -> anyhow::Result<String> {
+        println!("TRYING TO TRUNCATE: {:?}", length.clone());
+
         if length > self.count_tokens(content)? 
{ + println!("NOT TRUNCATING"); return anyhow::Ok(content.to_string()); } @@ -133,6 +136,14 @@ pub struct FakeCompletionProvider { last_completion_tx: Mutex>>, } +impl Clone for FakeCompletionProvider { + fn clone(&self) -> Self { + Self { + last_completion_tx: Mutex::new(None), + } + } +} + impl FakeCompletionProvider { pub fn new() -> Self { Self { @@ -174,4 +185,7 @@ impl CompletionProvider for FakeCompletionProvider { *self.last_completion_tx.lock() = Some(tx); async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed() } + fn box_clone(&self) -> Box { + Box::new((*self).clone()) + } } diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index c10ad2c362fbea3118aa77f2f89af42168c3a8f9..d0c7e7e8831c7de3cdd19593f60a0fab68c35393 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -9,9 +9,7 @@ use crate::{ use ai::{ auth::ProviderCredential, completion::{CompletionProvider, CompletionRequest}, - providers::open_ai::{ - stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, - }, + providers::open_ai::{OpenAICompletionProvider, OpenAIRequest, RequestMessage}, }; use ai::prompts::repository_context::PromptCodeSnippet; @@ -47,7 +45,7 @@ use search::BufferSearchBar; use semantic_index::{SemanticIndex, SemanticIndexStatus}; use settings::SettingsStore; use std::{ - cell::{Cell, RefCell}, + cell::Cell, cmp, fmt::Write, iter, @@ -144,10 +142,8 @@ pub struct AssistantPanel { zoomed: bool, has_focus: bool, toolbar: ViewHandle, - credential: Rc>, completion_provider: Box, api_key_editor: Option>, - has_read_credentials: bool, languages: Arc, fs: Arc, subscriptions: Vec, @@ -223,10 +219,8 @@ impl AssistantPanel { zoomed: false, has_focus: false, toolbar, - credential: Rc::new(RefCell::new(ProviderCredential::NoCredentials)), completion_provider, api_key_editor: None, - has_read_credentials: false, languages: workspace.app_state().languages.clone(), fs: workspace.app_state().fs.clone(), width: None, @@ -265,7 +259,7 @@ impl AssistantPanel { cx: &mut ViewContext, ) { let this = if let Some(this) = workspace.panel::(cx) { - if this.update(cx, |assistant, cx| assistant.has_credentials(cx)) { + if this.update(cx, |assistant, _| assistant.has_credentials()) { this } else { workspace.focus_panel::(cx); @@ -331,6 +325,9 @@ impl AssistantPanel { cx.background().clone(), )); + // Retrieve Credentials Authenticates the Provider + // provider.retrieve_credentials(cx); + let codegen = cx.add_model(|cx| { Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx) }); @@ -814,7 +811,7 @@ impl AssistantPanel { fn new_conversation(&mut self, cx: &mut ViewContext) -> ViewHandle { let editor = cx.add_view(|cx| { ConversationEditor::new( - self.credential.clone(), + self.completion_provider.clone(), self.languages.clone(), self.fs.clone(), self.workspace.clone(), @@ -883,9 +880,8 @@ impl AssistantPanel { let credential = ProviderCredential::Credentials { api_key: api_key.clone(), }; - self.completion_provider - .save_credentials(cx, credential.clone()); - *self.credential.borrow_mut() = credential; + + self.completion_provider.save_credentials(cx, credential); self.api_key_editor.take(); cx.focus_self(); @@ -898,7 +894,6 @@ impl AssistantPanel { fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext) { self.completion_provider.delete_credentials(cx); - *self.credential.borrow_mut() = ProviderCredential::NoCredentials; self.api_key_editor = Some(build_api_key_editor(cx)); cx.focus_self(); 
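// The box_clone additions in this patch are the standard object-safety
// workaround: Clone cannot be a supertrait of CompletionProvider if we want to
// hold a Box<dyn CompletionProvider>, because Clone::clone returns Self and
// requires Self: Sized. A self-contained sketch of the idiom, independent of
// this crate's types (Provider and EchoProvider are illustrative names):

trait Provider {
    fn complete(&self, prompt: &str) -> String;
    // Each implementor clones itself behind the trait object for us.
    fn box_clone(&self) -> Box<dyn Provider>;
}

// With box_clone in the vtable, the boxed trait object itself becomes cloneable.
impl Clone for Box<dyn Provider> {
    fn clone(&self) -> Box<dyn Provider> {
        self.box_clone()
    }
}

#[derive(Clone)]
struct EchoProvider;

impl Provider for EchoProvider {
    fn complete(&self, prompt: &str) -> String {
        prompt.to_string()
    }
    fn box_clone(&self) -> Box<dyn Provider> {
        Box::new(self.clone())
    }
}

fn demo() {
    let provider: Box<dyn Provider> = Box::new(EchoProvider);
    let copy = provider.clone(); // dispatches through box_clone
    assert_eq!(copy.complete("hi"), "hi");
}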
cx.notify(); @@ -1157,19 +1152,12 @@ impl AssistantPanel { let fs = self.fs.clone(); let workspace = self.workspace.clone(); - let credential = self.credential.clone(); let languages = self.languages.clone(); cx.spawn(|this, mut cx| async move { let saved_conversation = fs.load(&path).await?; let saved_conversation = serde_json::from_str(&saved_conversation)?; let conversation = cx.add_model(|cx| { - Conversation::deserialize( - saved_conversation, - path.clone(), - credential, - languages, - cx, - ) + Conversation::deserialize(saved_conversation, path.clone(), languages, cx) }); this.update(&mut cx, |this, cx| { // If, by the time we've loaded the conversation, the user has already opened @@ -1193,39 +1181,12 @@ impl AssistantPanel { .position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path)) } - fn has_credentials(&mut self, cx: &mut ViewContext) -> bool { - let credential = self.load_credentials(cx); - match credential { - ProviderCredential::Credentials { .. } => true, - ProviderCredential::NotNeeded => true, - ProviderCredential::NoCredentials => false, - } + fn has_credentials(&mut self) -> bool { + self.completion_provider.has_credentials() } - fn load_credentials(&mut self, cx: &mut ViewContext) -> ProviderCredential { - let existing_credential = self.credential.clone(); - let existing_credential = existing_credential.borrow().clone(); - match existing_credential { - ProviderCredential::NoCredentials => { - if !self.has_read_credentials { - self.has_read_credentials = true; - let retrieved_credentials = self.completion_provider.retrieve_credentials(cx); - - match retrieved_credentials { - ProviderCredential::NoCredentials {} => { - self.api_key_editor = Some(build_api_key_editor(cx)); - cx.notify(); - } - _ => { - *self.credential.borrow_mut() = retrieved_credentials; - } - } - } - } - _ => {} - } - - self.credential.borrow().clone() + fn load_credentials(&mut self, cx: &mut ViewContext) { + self.completion_provider.retrieve_credentials(cx); } } @@ -1475,10 +1436,10 @@ struct Conversation { token_count: Option, max_token_count: usize, pending_token_count: Task>, - credential: Rc>, pending_save: Task>, path: Option, _subscriptions: Vec, + completion_provider: Box, } impl Entity for Conversation { @@ -1487,10 +1448,9 @@ impl Entity for Conversation { impl Conversation { fn new( - credential: Rc>, - language_registry: Arc, cx: &mut ModelContext, + completion_provider: Box, ) -> Self { let markdown = language_registry.language_for_name("Markdown"); let buffer = cx.add_model(|cx| { @@ -1529,8 +1489,8 @@ impl Conversation { _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)], pending_save: Task::ready(Ok(())), path: None, - credential, buffer, + completion_provider, }; let message = MessageAnchor { id: MessageId(post_inc(&mut this.next_message_id.0)), @@ -1576,7 +1536,6 @@ impl Conversation { fn deserialize( saved_conversation: SavedConversation, path: PathBuf, - credential: Rc>, language_registry: Arc, cx: &mut ModelContext, ) -> Self { @@ -1585,6 +1544,10 @@ impl Conversation { None => Some(Uuid::new_v4().to_string()), }; let model = saved_conversation.model; + let completion_provider: Box = Box::new( + OpenAICompletionProvider::new(model.full_name(), cx.background().clone()), + ); + completion_provider.retrieve_credentials(cx); let markdown = language_registry.language_for_name("Markdown"); let mut message_anchors = Vec::new(); let mut next_message_id = MessageId(0); @@ -1631,8 +1594,8 @@ impl Conversation { _subscriptions: 
vec![cx.subscribe(&buffer, Self::handle_buffer_event)], pending_save: Task::ready(Ok(())), path: Some(path), - credential, buffer, + completion_provider, }; this.count_remaining_tokens(cx); this @@ -1753,12 +1716,8 @@ impl Conversation { } if should_assist { - let credential = self.credential.borrow().clone(); - match credential { - ProviderCredential::NoCredentials => { - return Default::default(); - } - _ => {} + if !self.completion_provider.has_credentials() { + return Default::default(); } let request: Box = Box::new(OpenAIRequest { @@ -1773,7 +1732,7 @@ impl Conversation { temperature: 1.0, }); - let stream = stream_completion(credential, cx.background().clone(), request); + let stream = self.completion_provider.complete(request); let assistant_message = self .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx) .unwrap(); @@ -1791,33 +1750,28 @@ impl Conversation { let mut messages = stream.await?; while let Some(message) = messages.next().await { - let mut message = message?; - if let Some(choice) = message.choices.pop() { - this.upgrade(&cx) - .ok_or_else(|| anyhow!("conversation was dropped"))? - .update(&mut cx, |this, cx| { - let text: Arc = choice.delta.content?.into(); - let message_ix = - this.message_anchors.iter().position(|message| { - message.id == assistant_message_id - })?; - this.buffer.update(cx, |buffer, cx| { - let offset = this.message_anchors[message_ix + 1..] - .iter() - .find(|message| message.start.is_valid(buffer)) - .map_or(buffer.len(), |message| { - message - .start - .to_offset(buffer) - .saturating_sub(1) - }); - buffer.edit([(offset..offset, text)], None, cx); - }); - cx.emit(ConversationEvent::StreamedCompletion); - - Some(()) + let text = message?; + + this.upgrade(&cx) + .ok_or_else(|| anyhow!("conversation was dropped"))? + .update(&mut cx, |this, cx| { + let message_ix = this + .message_anchors + .iter() + .position(|message| message.id == assistant_message_id)?; + this.buffer.update(cx, |buffer, cx| { + let offset = this.message_anchors[message_ix + 1..] 
+ .iter() + .find(|message| message.start.is_valid(buffer)) + .map_or(buffer.len(), |message| { + message.start.to_offset(buffer).saturating_sub(1) + }); + buffer.edit([(offset..offset, text)], None, cx); }); - } + cx.emit(ConversationEvent::StreamedCompletion); + + Some(()) + }); smol::future::yield_now().await; } @@ -2039,13 +1993,8 @@ impl Conversation { fn summarize(&mut self, cx: &mut ModelContext) { if self.message_anchors.len() >= 2 && self.summary.is_none() { - let credential = self.credential.borrow().clone(); - - match credential { - ProviderCredential::NoCredentials => { - return; - } - _ => {} + if !self.completion_provider.has_credentials() { + return; } let messages = self @@ -2065,23 +2014,20 @@ impl Conversation { temperature: 1.0, }); - let stream = stream_completion(credential, cx.background().clone(), request); + let stream = self.completion_provider.complete(request); self.pending_summary = cx.spawn(|this, mut cx| { async move { let mut messages = stream.await?; while let Some(message) = messages.next().await { - let mut message = message?; - if let Some(choice) = message.choices.pop() { - let text = choice.delta.content.unwrap_or_default(); - this.update(&mut cx, |this, cx| { - this.summary - .get_or_insert(Default::default()) - .text - .push_str(&text); - cx.emit(ConversationEvent::SummaryChanged); - }); - } + let text = message?; + this.update(&mut cx, |this, cx| { + this.summary + .get_or_insert(Default::default()) + .text + .push_str(&text); + cx.emit(ConversationEvent::SummaryChanged); + }); } this.update(&mut cx, |this, cx| { @@ -2255,13 +2201,14 @@ struct ConversationEditor { impl ConversationEditor { fn new( - credential: Rc>, + completion_provider: Box, language_registry: Arc, fs: Arc, workspace: WeakViewHandle, cx: &mut ViewContext, ) -> Self { - let conversation = cx.add_model(|cx| Conversation::new(credential, language_registry, cx)); + let conversation = + cx.add_model(|cx| Conversation::new(language_registry, cx, completion_provider)); Self::for_conversation(conversation, fs, workspace, cx) } @@ -3450,6 +3397,7 @@ fn merge_ranges(ranges: &mut Vec>, buffer: &MultiBufferSnapshot) { mod tests { use super::*; use crate::MessageId; + use ai::test::FakeCompletionProvider; use gpui::AppContext; #[gpui::test] @@ -3457,13 +3405,9 @@ mod tests { cx.set_global(SettingsStore::test(cx)); init(cx); let registry = Arc::new(LanguageRegistry::test()); - let conversation = cx.add_model(|cx| { - Conversation::new( - Rc::new(RefCell::new(ProviderCredential::NotNeeded)), - registry, - cx, - ) - }); + + let completion_provider = Box::new(FakeCompletionProvider::new()); + let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider)); let buffer = conversation.read(cx).buffer.clone(); let message_1 = conversation.read(cx).message_anchors[0].clone(); @@ -3591,13 +3535,9 @@ mod tests { cx.set_global(SettingsStore::test(cx)); init(cx); let registry = Arc::new(LanguageRegistry::test()); - let conversation = cx.add_model(|cx| { - Conversation::new( - Rc::new(RefCell::new(ProviderCredential::NotNeeded)), - registry, - cx, - ) - }); + let completion_provider = Box::new(FakeCompletionProvider::new()); + + let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider)); let buffer = conversation.read(cx).buffer.clone(); let message_1 = conversation.read(cx).message_anchors[0].clone(); @@ -3693,13 +3633,8 @@ mod tests { cx.set_global(SettingsStore::test(cx)); init(cx); let registry = Arc::new(LanguageRegistry::test()); - let 
conversation = cx.add_model(|cx| { - Conversation::new( - Rc::new(RefCell::new(ProviderCredential::NotNeeded)), - registry, - cx, - ) - }); + let completion_provider = Box::new(FakeCompletionProvider::new()); + let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider)); let buffer = conversation.read(cx).buffer.clone(); let message_1 = conversation.read(cx).message_anchors[0].clone(); @@ -3781,13 +3716,9 @@ mod tests { cx.set_global(SettingsStore::test(cx)); init(cx); let registry = Arc::new(LanguageRegistry::test()); - let conversation = cx.add_model(|cx| { - Conversation::new( - Rc::new(RefCell::new(ProviderCredential::NotNeeded)), - registry.clone(), - cx, - ) - }); + let completion_provider = Box::new(FakeCompletionProvider::new()); + let conversation = + cx.add_model(|cx| Conversation::new(registry.clone(), cx, completion_provider)); let buffer = conversation.read(cx).buffer.clone(); let message_0 = conversation.read(cx).message_anchors[0].id; let message_1 = conversation.update(cx, |conversation, cx| { @@ -3824,7 +3755,6 @@ mod tests { Conversation::deserialize( conversation.read(cx).serialize(cx), Default::default(), - Rc::new(RefCell::new(ProviderCredential::NotNeeded)), registry.clone(), cx, ) From dc8a8538421102c1f41cbec1f79f5130ecb7ed35 Mon Sep 17 00:00:00 2001 From: Piotr Osiewicz <24362066+osiewicz@users.noreply.github.com> Date: Mon, 30 Oct 2023 18:27:05 +0100 Subject: [PATCH 21/25] lsp/next-ls: Fix wrong nls binary being fetched. (#3181) CPU types had to be swapped around. Fixed zed-industries/community#2185 Release Notes: - Fixed Elixir next-ls LSP installation failing due to fetching a binary for the wrong architecture (zed-industries/community#2185). --- crates/zed/src/languages/elixir.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/zed/src/languages/elixir.rs b/crates/zed/src/languages/elixir.rs index 5c0ff273ae385ce1aa2efff99cc2d00860898f08..df438d89eef7b7c5c89a376670d7d270db4fa413 100644 --- a/crates/zed/src/languages/elixir.rs +++ b/crates/zed/src/languages/elixir.rs @@ -321,8 +321,8 @@ impl LspAdapter for NextLspAdapter { latest_github_release("elixir-tools/next-ls", false, delegate.http_client()).await?; let version = release.name.clone(); let platform = match consts::ARCH { - "x86_64" => "darwin_arm64", - "aarch64" => "darwin_amd64", + "x86_64" => "darwin_amd64", + "aarch64" => "darwin_arm64", other => bail!("Running on unsupported platform: {other}"), }; let asset_name = format!("next_ls_{}", platform); From 04ab68502b950e7d23c0522347b586ebd3670a4f Mon Sep 17 00:00:00 2001 From: KCaverly Date: Mon, 30 Oct 2023 14:40:31 -0400 Subject: [PATCH 22/25] port ai crate to ai2, with all tests passing --- Cargo.lock | 28 ++ crates/Cargo.toml | 38 ++ crates/ai2/Cargo.toml | 38 ++ crates/ai2/src/ai2.rs | 8 + crates/ai2/src/auth.rs | 17 + crates/ai2/src/completion.rs | 23 ++ crates/ai2/src/embedding.rs | 123 +++++++ crates/ai2/src/models.rs | 16 + crates/ai2/src/prompts/base.rs | 330 ++++++++++++++++++ crates/ai2/src/prompts/file_context.rs | 164 +++++++++ crates/ai2/src/prompts/generate.rs | 99 ++++++ crates/ai2/src/prompts/mod.rs | 5 + crates/ai2/src/prompts/preamble.rs | 52 +++ crates/ai2/src/prompts/repository_context.rs | 98 ++++++ crates/ai2/src/providers/mod.rs | 1 + .../ai2/src/providers/open_ai/completion.rs | 306 ++++++++++++++++ crates/ai2/src/providers/open_ai/embedding.rs | 313 +++++++++++++++++ crates/ai2/src/providers/open_ai/mod.rs | 9 + crates/ai2/src/providers/open_ai/model.rs | 57 +++ 
crates/ai2/src/providers/open_ai/new.rs | 11 + crates/ai2/src/test.rs | 193 ++++++++++ crates/zed2/Cargo.toml | 1 + 22 files changed, 1930 insertions(+) create mode 100644 crates/Cargo.toml create mode 100644 crates/ai2/Cargo.toml create mode 100644 crates/ai2/src/ai2.rs create mode 100644 crates/ai2/src/auth.rs create mode 100644 crates/ai2/src/completion.rs create mode 100644 crates/ai2/src/embedding.rs create mode 100644 crates/ai2/src/models.rs create mode 100644 crates/ai2/src/prompts/base.rs create mode 100644 crates/ai2/src/prompts/file_context.rs create mode 100644 crates/ai2/src/prompts/generate.rs create mode 100644 crates/ai2/src/prompts/mod.rs create mode 100644 crates/ai2/src/prompts/preamble.rs create mode 100644 crates/ai2/src/prompts/repository_context.rs create mode 100644 crates/ai2/src/providers/mod.rs create mode 100644 crates/ai2/src/providers/open_ai/completion.rs create mode 100644 crates/ai2/src/providers/open_ai/embedding.rs create mode 100644 crates/ai2/src/providers/open_ai/mod.rs create mode 100644 crates/ai2/src/providers/open_ai/model.rs create mode 100644 crates/ai2/src/providers/open_ai/new.rs create mode 100644 crates/ai2/src/test.rs diff --git a/Cargo.lock b/Cargo.lock index 0caaaeceeffdc78ce96d165721ee8eb4e9bae081..a5d187d08e5c20a77132c6e19024925520120af9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -108,6 +108,33 @@ dependencies = [ "util", ] +[[package]] +name = "ai2" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-trait", + "bincode", + "futures 0.3.28", + "gpui2", + "isahc", + "language2", + "lazy_static", + "log", + "matrixmultiply", + "ordered-float 2.10.0", + "parking_lot 0.11.2", + "parse_duration", + "postage", + "rand 0.8.5", + "regex", + "rusqlite", + "serde", + "serde_json", + "tiktoken-rs", + "util", +] + [[package]] name = "alacritty_config" version = "0.1.2-dev" @@ -10903,6 +10930,7 @@ dependencies = [ name = "zed2" version = "0.109.0" dependencies = [ + "ai2", "anyhow", "async-compression", "async-recursion 0.3.2", diff --git a/crates/Cargo.toml b/crates/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..fb49a4b515540836a757610db5c268321f9f068b --- /dev/null +++ b/crates/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "ai" +version = "0.1.0" +edition = "2021" +publish = false + +[lib] +path = "src/ai.rs" +doctest = false + +[features] +test-support = [] + +[dependencies] +gpui = { path = "../gpui" } +util = { path = "../util" } +language = { path = "../language" } +async-trait.workspace = true +anyhow.workspace = true +futures.workspace = true +lazy_static.workspace = true +ordered-float.workspace = true +parking_lot.workspace = true +isahc.workspace = true +regex.workspace = true +serde.workspace = true +serde_json.workspace = true +postage.workspace = true +rand.workspace = true +log.workspace = true +parse_duration = "2.1.1" +tiktoken-rs = "0.5.0" +matrixmultiply = "0.3.7" +rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] } +bincode = "1.3.3" + +[dev-dependencies] +gpui = { path = "../gpui", features = ["test-support"] } diff --git a/crates/ai2/Cargo.toml b/crates/ai2/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..4f06840e8e53bbcb06c377a6304fb8be13b85946 --- /dev/null +++ b/crates/ai2/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "ai2" +version = "0.1.0" +edition = "2021" +publish = false + +[lib] +path = "src/ai2.rs" +doctest = false + +[features] +test-support = [] + +[dependencies] +gpui2 = { path = "../gpui2" } +util = { path = 
"../util" } +language2 = { path = "../language2" } +async-trait.workspace = true +anyhow.workspace = true +futures.workspace = true +lazy_static.workspace = true +ordered-float.workspace = true +parking_lot.workspace = true +isahc.workspace = true +regex.workspace = true +serde.workspace = true +serde_json.workspace = true +postage.workspace = true +rand.workspace = true +log.workspace = true +parse_duration = "2.1.1" +tiktoken-rs = "0.5.0" +matrixmultiply = "0.3.7" +rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] } +bincode = "1.3.3" + +[dev-dependencies] +gpui2 = { path = "../gpui2", features = ["test-support"] } diff --git a/crates/ai2/src/ai2.rs b/crates/ai2/src/ai2.rs new file mode 100644 index 0000000000000000000000000000000000000000..dda22d2a1d04dd6083fb1ae9879f49e74c8b4627 --- /dev/null +++ b/crates/ai2/src/ai2.rs @@ -0,0 +1,8 @@ +pub mod auth; +pub mod completion; +pub mod embedding; +pub mod models; +pub mod prompts; +pub mod providers; +#[cfg(any(test, feature = "test-support"))] +pub mod test; diff --git a/crates/ai2/src/auth.rs b/crates/ai2/src/auth.rs new file mode 100644 index 0000000000000000000000000000000000000000..e4670bb449025d5ecc5f0cabe65ad6ff4727c10c --- /dev/null +++ b/crates/ai2/src/auth.rs @@ -0,0 +1,17 @@ +use async_trait::async_trait; +use gpui2::AppContext; + +#[derive(Clone, Debug)] +pub enum ProviderCredential { + Credentials { api_key: String }, + NoCredentials, + NotNeeded, +} + +#[async_trait] +pub trait CredentialProvider: Send + Sync { + fn has_credentials(&self) -> bool; + async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential; + async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential); + async fn delete_credentials(&self, cx: &mut AppContext); +} diff --git a/crates/ai2/src/completion.rs b/crates/ai2/src/completion.rs new file mode 100644 index 0000000000000000000000000000000000000000..30a60fcf1d5c5dc66717773968e432e510d6421f --- /dev/null +++ b/crates/ai2/src/completion.rs @@ -0,0 +1,23 @@ +use anyhow::Result; +use futures::{future::BoxFuture, stream::BoxStream}; + +use crate::{auth::CredentialProvider, models::LanguageModel}; + +pub trait CompletionRequest: Send + Sync { + fn data(&self) -> serde_json::Result; +} + +pub trait CompletionProvider: CredentialProvider { + fn base_model(&self) -> Box; + fn complete( + &self, + prompt: Box, + ) -> BoxFuture<'static, Result>>>; + fn box_clone(&self) -> Box; +} + +impl Clone for Box { + fn clone(&self) -> Box { + self.box_clone() + } +} diff --git a/crates/ai2/src/embedding.rs b/crates/ai2/src/embedding.rs new file mode 100644 index 0000000000000000000000000000000000000000..7ea47861782cf9002796a8b6e655989b871e0191 --- /dev/null +++ b/crates/ai2/src/embedding.rs @@ -0,0 +1,123 @@ +use std::time::Instant; + +use anyhow::Result; +use async_trait::async_trait; +use ordered_float::OrderedFloat; +use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef}; +use rusqlite::ToSql; + +use crate::auth::CredentialProvider; +use crate::models::LanguageModel; + +#[derive(Debug, PartialEq, Clone)] +pub struct Embedding(pub Vec); + +// This is needed for semantic index functionality +// Unfortunately it has to live wherever the "Embedding" struct is created. 
+// Keeping it here, though, introduces a 'rusqlite' dependency into the AI
+// crate, which is less than ideal.
+impl FromSql for Embedding {
+    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
+        let bytes = value.as_blob()?;
+        let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
+        if embedding.is_err() {
+            return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
+        }
+        Ok(Embedding(embedding.unwrap()))
+    }
+}
+
+impl ToSql for Embedding {
+    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
+        let bytes = bincode::serialize(&self.0)
+            .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
+        Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
+    }
+}
+impl From<Vec<f32>> for Embedding {
+    fn from(value: Vec<f32>) -> Self {
+        Embedding(value)
+    }
+}
+
+impl Embedding {
+    pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
+        let len = self.0.len();
+        assert_eq!(len, other.0.len());
+
+        let mut result = 0.0;
+        unsafe {
+            // A (1 x len) by (len x 1) sgemm is just the dot product of the
+            // two vectors, computed with matrixmultiply's SIMD kernels.
+            matrixmultiply::sgemm(
+                1,
+                len,
+                1,
+                1.0,
+                self.0.as_ptr(),
+                len as isize,
+                1,
+                other.0.as_ptr(),
+                1,
+                len as isize,
+                0.0,
+                &mut result as *mut f32,
+                1,
+                1,
+            );
+        }
+        OrderedFloat(result)
+    }
+}
+
+#[async_trait]
+pub trait EmbeddingProvider: CredentialProvider {
+    fn base_model(&self) -> Box<dyn LanguageModel>;
+    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
+    fn max_tokens_per_batch(&self) -> usize;
+    fn rate_limit_expiration(&self) -> Option<Instant>;
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use rand::prelude::*;
+
+    #[gpui2::test]
+    fn test_similarity(mut rng: StdRng) {
+        assert_eq!(
+            Embedding::from(vec![1., 0., 0., 0., 0.])
+                .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
+            0.
+        );
+        assert_eq!(
+            Embedding::from(vec![2., 0., 0., 0., 0.])
+                .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
+            6.
+ ); + + for _ in 0..100 { + let size = 1536; + let mut a = vec![0.; size]; + let mut b = vec![0.; size]; + for (a, b) in a.iter_mut().zip(b.iter_mut()) { + *a = rng.gen(); + *b = rng.gen(); + } + let a = Embedding::from(a); + let b = Embedding::from(b); + + assert_eq!( + round_to_decimals(a.similarity(&b), 1), + round_to_decimals(reference_dot(&a.0, &b.0), 1) + ); + } + + fn round_to_decimals(n: OrderedFloat, decimal_places: i32) -> f32 { + let factor = (10.0 as f32).powi(decimal_places); + (n * factor).round() / factor + } + + fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat { + OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()) + } + } +} diff --git a/crates/ai2/src/models.rs b/crates/ai2/src/models.rs new file mode 100644 index 0000000000000000000000000000000000000000..1db3d58c6f54ad613cb98fc3f425df3d47e5e97f --- /dev/null +++ b/crates/ai2/src/models.rs @@ -0,0 +1,16 @@ +pub enum TruncationDirection { + Start, + End, +} + +pub trait LanguageModel { + fn name(&self) -> String; + fn count_tokens(&self, content: &str) -> anyhow::Result; + fn truncate( + &self, + content: &str, + length: usize, + direction: TruncationDirection, + ) -> anyhow::Result; + fn capacity(&self) -> anyhow::Result; +} diff --git a/crates/ai2/src/prompts/base.rs b/crates/ai2/src/prompts/base.rs new file mode 100644 index 0000000000000000000000000000000000000000..29091d0f5b435b556a0c2cae60aa4526370832ab --- /dev/null +++ b/crates/ai2/src/prompts/base.rs @@ -0,0 +1,330 @@ +use std::cmp::Reverse; +use std::ops::Range; +use std::sync::Arc; + +use language2::BufferSnapshot; +use util::ResultExt; + +use crate::models::LanguageModel; +use crate::prompts::repository_context::PromptCodeSnippet; + +pub(crate) enum PromptFileType { + Text, + Code, +} + +// TODO: Set this up to manage for defaults well +pub struct PromptArguments { + pub model: Arc, + pub user_prompt: Option, + pub language_name: Option, + pub project_name: Option, + pub snippets: Vec, + pub reserved_tokens: usize, + pub buffer: Option, + pub selected_range: Option>, +} + +impl PromptArguments { + pub(crate) fn get_file_type(&self) -> PromptFileType { + if self + .language_name + .as_ref() + .and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str()))) + .unwrap_or(true) + { + PromptFileType::Code + } else { + PromptFileType::Text + } + } +} + +pub trait PromptTemplate { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)>; +} + +#[repr(i8)] +#[derive(PartialEq, Eq, Ord)] +pub enum PromptPriority { + Mandatory, // Ignores truncation + Ordered { order: usize }, // Truncates based on priority +} + +impl PartialOrd for PromptPriority { + fn partial_cmp(&self, other: &Self) -> Option { + match (self, other) { + (Self::Mandatory, Self::Mandatory) => Some(std::cmp::Ordering::Equal), + (Self::Mandatory, Self::Ordered { .. }) => Some(std::cmp::Ordering::Greater), + (Self::Ordered { .. 
}, Self::Mandatory) => Some(std::cmp::Ordering::Less), + (Self::Ordered { order: a }, Self::Ordered { order: b }) => b.partial_cmp(a), + } + } +} + +pub struct PromptChain { + args: PromptArguments, + templates: Vec<(PromptPriority, Box)>, +} + +impl PromptChain { + pub fn new( + args: PromptArguments, + templates: Vec<(PromptPriority, Box)>, + ) -> Self { + PromptChain { args, templates } + } + + pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> { + // Argsort based on Prompt Priority + let seperator = "\n"; + let seperator_tokens = self.args.model.count_tokens(seperator)?; + let mut sorted_indices = (0..self.templates.len()).collect::>(); + sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0)); + + // If Truncate + let mut tokens_outstanding = if truncate { + Some(self.args.model.capacity()? - self.args.reserved_tokens) + } else { + None + }; + + let mut prompts = vec!["".to_string(); sorted_indices.len()]; + for idx in sorted_indices { + let (_, template) = &self.templates[idx]; + + if let Some((template_prompt, prompt_token_count)) = + template.generate(&self.args, tokens_outstanding).log_err() + { + if template_prompt != "" { + prompts[idx] = template_prompt; + + if let Some(remaining_tokens) = tokens_outstanding { + let new_tokens = prompt_token_count + seperator_tokens; + tokens_outstanding = if remaining_tokens > new_tokens { + Some(remaining_tokens - new_tokens) + } else { + Some(0) + }; + } + } + } + } + + prompts.retain(|x| x != ""); + + let full_prompt = prompts.join(seperator); + let total_token_count = self.args.model.count_tokens(&full_prompt)?; + anyhow::Ok((prompts.join(seperator), total_token_count)) + } +} + +#[cfg(test)] +pub(crate) mod tests { + use crate::models::TruncationDirection; + use crate::test::FakeLanguageModel; + + use super::*; + + #[test] + pub fn test_prompt_chain() { + struct TestPromptTemplate {} + impl PromptTemplate for TestPromptTemplate { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + let mut content = "This is a test prompt template".to_string(); + + let mut token_count = args.model.count_tokens(&content)?; + if let Some(max_token_length) = max_token_length { + if token_count > max_token_length { + content = args.model.truncate( + &content, + max_token_length, + TruncationDirection::End, + )?; + token_count = max_token_length; + } + } + + anyhow::Ok((content, token_count)) + } + } + + struct TestLowPriorityTemplate {} + impl PromptTemplate for TestLowPriorityTemplate { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + let mut content = "This is a low priority test prompt template".to_string(); + + let mut token_count = args.model.count_tokens(&content)?; + if let Some(max_token_length) = max_token_length { + if token_count > max_token_length { + content = args.model.truncate( + &content, + max_token_length, + TruncationDirection::End, + )?; + token_count = max_token_length; + } + } + + anyhow::Ok((content, token_count)) + } + } + + let model: Arc = Arc::new(FakeLanguageModel { capacity: 100 }); + let args = PromptArguments { + model: model.clone(), + language_name: None, + project_name: None, + snippets: Vec::new(), + reserved_tokens: 0, + buffer: None, + selected_range: None, + user_prompt: None, + }; + + let templates: Vec<(PromptPriority, Box)> = vec![ + ( + PromptPriority::Ordered { order: 0 }, + Box::new(TestPromptTemplate {}), + ), + ( + PromptPriority::Ordered { 
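// A note on the arithmetic the assertions below rely on (FakeLanguageModel
// counts one token per character): the ordinary template renders 30 characters
// and the low-priority one 43. With capacity = 20 and truncate = true, the
// order-0 template is visited first and truncated to exactly 20 characters,
// "This is a test promp"; that exhausts the budget, so the remaining templates
// render empty and are dropped. In the final case, capacity = 120 with
// reserved_tokens = 10 leaves a budget of 110: the Mandatory template is
// rendered first regardless of its position in the vec, the order-0 template
// still fits in full, and the trailing order-1 template is truncated to
// whatever budget remains once the newline separators are accounted for,
// which is why its assertion ends mid-phrase.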
order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), + ]; + let chain = PromptChain::new(args, templates); + + let (prompt, token_count) = chain.generate(false).unwrap(); + + assert_eq!( + prompt, + "This is a test prompt template\nThis is a low priority test prompt template" + .to_string() + ); + + assert_eq!(model.count_tokens(&prompt).unwrap(), token_count); + + // Testing with Truncation Off + // Should ignore capacity and return all prompts + let model: Arc = Arc::new(FakeLanguageModel { capacity: 20 }); + let args = PromptArguments { + model: model.clone(), + language_name: None, + project_name: None, + snippets: Vec::new(), + reserved_tokens: 0, + buffer: None, + selected_range: None, + user_prompt: None, + }; + + let templates: Vec<(PromptPriority, Box)> = vec![ + ( + PromptPriority::Ordered { order: 0 }, + Box::new(TestPromptTemplate {}), + ), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), + ]; + let chain = PromptChain::new(args, templates); + + let (prompt, token_count) = chain.generate(false).unwrap(); + + assert_eq!( + prompt, + "This is a test prompt template\nThis is a low priority test prompt template" + .to_string() + ); + + assert_eq!(model.count_tokens(&prompt).unwrap(), token_count); + + // Testing with Truncation Off + // Should ignore capacity and return all prompts + let capacity = 20; + let model: Arc = Arc::new(FakeLanguageModel { capacity }); + let args = PromptArguments { + model: model.clone(), + language_name: None, + project_name: None, + snippets: Vec::new(), + reserved_tokens: 0, + buffer: None, + selected_range: None, + user_prompt: None, + }; + + let templates: Vec<(PromptPriority, Box)> = vec![ + ( + PromptPriority::Ordered { order: 0 }, + Box::new(TestPromptTemplate {}), + ), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), + ( + PromptPriority::Ordered { order: 2 }, + Box::new(TestLowPriorityTemplate {}), + ), + ]; + let chain = PromptChain::new(args, templates); + + let (prompt, token_count) = chain.generate(true).unwrap(); + + assert_eq!(prompt, "This is a test promp".to_string()); + assert_eq!(token_count, capacity); + + // Change Ordering of Prompts Based on Priority + let capacity = 120; + let reserved_tokens = 10; + let model: Arc = Arc::new(FakeLanguageModel { capacity }); + let args = PromptArguments { + model: model.clone(), + language_name: None, + project_name: None, + snippets: Vec::new(), + reserved_tokens, + buffer: None, + selected_range: None, + user_prompt: None, + }; + let templates: Vec<(PromptPriority, Box)> = vec![ + ( + PromptPriority::Mandatory, + Box::new(TestLowPriorityTemplate {}), + ), + ( + PromptPriority::Ordered { order: 0 }, + Box::new(TestPromptTemplate {}), + ), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), + ]; + let chain = PromptChain::new(args, templates); + + let (prompt, token_count) = chain.generate(true).unwrap(); + + assert_eq!( + prompt, + "This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt " + .to_string() + ); + assert_eq!(token_count, capacity - reserved_tokens); + } +} diff --git a/crates/ai2/src/prompts/file_context.rs b/crates/ai2/src/prompts/file_context.rs new file mode 100644 index 0000000000000000000000000000000000000000..4a741beb24984c5c038ef10439d5b182438d2866 --- /dev/null +++ b/crates/ai2/src/prompts/file_context.rs @@ -0,0 +1,164 @@ +use anyhow::anyhow; +use language2::BufferSnapshot; +use language2::ToOffset; 
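// Both truncation directions appear in this file, so the contract deserves a
// note: with the OpenAI tokenizer implementation from earlier in this series,
// TruncationDirection::End decodes tokens[..length] (the head of the text
// survives) and TruncationDirection::Start decodes tokens[length..] (the tail
// survives). Roughly, for a model that counts one token per character:
//
//     truncate("abcdef", 4, TruncationDirection::End)   // -> "abcd"
//     truncate("abcdef", 4, TruncationDirection::Start) // -> "ef"
//
// Below, the window before the selection is truncated from the Start and the
// window after it from the End, so in both cases the text bordering the
// selection is what survives.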
+ +use crate::models::LanguageModel; +use crate::models::TruncationDirection; +use crate::prompts::base::PromptArguments; +use crate::prompts::base::PromptTemplate; +use std::fmt::Write; +use std::ops::Range; +use std::sync::Arc; + +fn retrieve_context( + buffer: &BufferSnapshot, + selected_range: &Option>, + model: Arc, + max_token_count: Option, +) -> anyhow::Result<(String, usize, bool)> { + let mut prompt = String::new(); + let mut truncated = false; + if let Some(selected_range) = selected_range { + let start = selected_range.start.to_offset(buffer); + let end = selected_range.end.to_offset(buffer); + + let start_window = buffer.text_for_range(0..start).collect::(); + + let mut selected_window = String::new(); + if start == end { + write!(selected_window, "<|START|>").unwrap(); + } else { + write!(selected_window, "<|START|").unwrap(); + } + + write!( + selected_window, + "{}", + buffer.text_for_range(start..end).collect::() + ) + .unwrap(); + + if start != end { + write!(selected_window, "|END|>").unwrap(); + } + + let end_window = buffer.text_for_range(end..buffer.len()).collect::(); + + if let Some(max_token_count) = max_token_count { + let selected_tokens = model.count_tokens(&selected_window)?; + if selected_tokens > max_token_count { + return Err(anyhow!( + "selected range is greater than model context window, truncation not possible" + )); + }; + + let mut remaining_tokens = max_token_count - selected_tokens; + let start_window_tokens = model.count_tokens(&start_window)?; + let end_window_tokens = model.count_tokens(&end_window)?; + let outside_tokens = start_window_tokens + end_window_tokens; + if outside_tokens > remaining_tokens { + let (start_goal_tokens, end_goal_tokens) = + if start_window_tokens < end_window_tokens { + let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens); + remaining_tokens -= start_goal_tokens; + let end_goal_tokens = remaining_tokens.min(end_window_tokens); + (start_goal_tokens, end_goal_tokens) + } else { + let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens); + remaining_tokens -= end_goal_tokens; + let start_goal_tokens = remaining_tokens.min(start_window_tokens); + (start_goal_tokens, end_goal_tokens) + }; + + let truncated_start_window = + model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?; + let truncated_end_window = + model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?; + writeln!( + prompt, + "{truncated_start_window}{selected_window}{truncated_end_window}" + ) + .unwrap(); + truncated = true; + } else { + writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap(); + } + } else { + // If we dont have a selected range, include entire file. + writeln!(prompt, "{}", &buffer.text()).unwrap(); + + // Dumb truncation strategy + if let Some(max_token_count) = max_token_count { + if model.count_tokens(&prompt)? > max_token_count { + truncated = true; + prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?; + } + } + } + } + + let token_count = model.count_tokens(&prompt)?; + anyhow::Ok((prompt, token_count, truncated)) +} + +pub struct FileContext {} + +impl PromptTemplate for FileContext { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + if let Some(buffer) = &args.buffer { + let mut prompt = String::new(); + // Add Initial Preamble + // TODO: Do we want to add the path in here? 
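// A worked example of the token budgeting in retrieve_context above, assuming
// a one-token-per-character model and hypothetical sizes: say max_token_count
// is 100, the selection costs 40 tokens, the text before it 80, and the text
// after it 20. remaining_tokens starts at 60; the end window is the smaller of
// the two, so it is sized first: end_goal = min(60 / 2, 20) = 20, leaving 40,
// and start_goal = min(40, 80) = 40. The larger window absorbs whatever budget
// the smaller one did not need. Each window is then truncated toward the
// selection (the start window from the Start direction, the end window from
// the End direction), so the text adjacent to the selection is preserved.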
+ writeln!( + prompt, + "The file you are currently working on has the following content:" + ) + .unwrap(); + + let language_name = args + .language_name + .clone() + .unwrap_or("".to_string()) + .to_lowercase(); + + let (context, _, truncated) = retrieve_context( + buffer, + &args.selected_range, + args.model.clone(), + max_token_length, + )?; + writeln!(prompt, "```{language_name}\n{context}\n```").unwrap(); + + if truncated { + writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap(); + } + + if let Some(selected_range) = &args.selected_range { + let start = selected_range.start.to_offset(buffer); + let end = selected_range.end.to_offset(buffer); + + if start == end { + writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap(); + } else { + writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap(); + } + } + + // Really dumb truncation strategy + if let Some(max_tokens) = max_token_length { + prompt = args + .model + .truncate(&prompt, max_tokens, TruncationDirection::End)?; + } + + let token_count = args.model.count_tokens(&prompt)?; + anyhow::Ok((prompt, token_count)) + } else { + Err(anyhow!("no buffer provided to retrieve file context from")) + } + } +} diff --git a/crates/ai2/src/prompts/generate.rs b/crates/ai2/src/prompts/generate.rs new file mode 100644 index 0000000000000000000000000000000000000000..c7be620107ee4d6daca06a8cb38019aceedc40a4 --- /dev/null +++ b/crates/ai2/src/prompts/generate.rs @@ -0,0 +1,99 @@ +use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate}; +use anyhow::anyhow; +use std::fmt::Write; + +pub fn capitalize(s: &str) -> String { + let mut c = s.chars(); + match c.next() { + None => String::new(), + Some(f) => f.to_uppercase().collect::() + c.as_str(), + } +} + +pub struct GenerateInlineContent {} + +impl PromptTemplate for GenerateInlineContent { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + let Some(user_prompt) = &args.user_prompt else { + return Err(anyhow!("user prompt not provided")); + }; + + let file_type = args.get_file_type(); + let content_type = match &file_type { + PromptFileType::Code => "code", + PromptFileType::Text => "text", + }; + + let mut prompt = String::new(); + + if let Some(selected_range) = &args.selected_range { + if selected_range.start == selected_range.end { + writeln!( + prompt, + "Assume the cursor is located where the `<|START|>` span is." 
+            )
+            .unwrap();
+            writeln!(
+                prompt,
+                "{} can't be replaced, so assume your answer will be inserted at the cursor.",
+                capitalize(content_type)
+            )
+            .unwrap();
+            writeln!(
+                prompt,
+                "Generate {content_type} based on the user's prompt: {user_prompt}",
+            )
+            .unwrap();
+            } else {
+                writeln!(prompt, "Modify the user's selected {content_type} based upon the user's prompt: '{user_prompt}'").unwrap();
+                writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans), not the entire file.").unwrap();
+                writeln!(prompt, "Double-check that you only return code and not the '<|START|' and '|END|>' spans").unwrap();
+            }
+        } else {
+            writeln!(
+                prompt,
+                "Generate {content_type} based on the user's prompt: {user_prompt}"
+            )
+            .unwrap();
+        }
+
+        if let Some(language_name) = &args.language_name {
+            writeln!(
+                prompt,
+                "Your answer MUST always and only be valid {}.",
+                language_name
+            )
+            .unwrap();
+        }
+        writeln!(prompt, "Never make remarks about the output.").unwrap();
+        writeln!(
+            prompt,
+            "Do not return anything else except the generated {content_type}."
+        )
+        .unwrap();
+
+        match file_type {
+            PromptFileType::Code => {
+                // writeln!(prompt, "Always wrap your code in a Markdown block.").unwrap();
+            }
+            _ => {}
+        }
+
+        // Really dumb truncation strategy
+        if let Some(max_tokens) = max_token_length {
+            prompt = args.model.truncate(
+                &prompt,
+                max_tokens,
+                crate::models::TruncationDirection::End,
+            )?;
+        }
+
+        let token_count = args.model.count_tokens(&prompt)?;
+
+        anyhow::Ok((prompt, token_count))
+    }
+}
diff --git a/crates/ai2/src/prompts/mod.rs b/crates/ai2/src/prompts/mod.rs
new file mode 100644
index 0000000000000000000000000000000000000000..0025269a440d1e6ead6a81615a64a3c28da62bb8
--- /dev/null
+++ b/crates/ai2/src/prompts/mod.rs
@@ -0,0 +1,5 @@
+pub mod base;
+pub mod file_context;
+pub mod generate;
+pub mod preamble;
+pub mod repository_context;
diff --git a/crates/ai2/src/prompts/preamble.rs b/crates/ai2/src/prompts/preamble.rs
new file mode 100644
index 0000000000000000000000000000000000000000..92e0edeb78b48169379aae2e88e81f62463a1057
--- /dev/null
+++ b/crates/ai2/src/prompts/preamble.rs
@@ -0,0 +1,52 @@
+use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
+use std::fmt::Write;
+
+pub struct EngineerPreamble {}
+
+impl PromptTemplate for EngineerPreamble {
+    fn generate(
+        &self,
+        args: &PromptArguments,
+        max_token_length: Option<usize>,
+    ) -> anyhow::Result<(String, usize)> {
+        let mut prompts = Vec::new();
+
+        match args.get_file_type() {
+            PromptFileType::Code => {
+                prompts.push(format!(
+                    "You are an expert {}engineer.",
+                    args.language_name.clone().unwrap_or("".to_string()) + " "
+                ));
+            }
+            PromptFileType::Text => {
+                prompts.push("You are an expert engineer.".to_string());
+            }
+        }
+
+        if let Some(project_name) = args.project_name.clone() {
+            prompts.push(format!(
+                "You are currently working inside the '{project_name}' project in code editor Zed."
+            ));
+        }
+
+        if let Some(mut remaining_tokens) = max_token_length {
+            let mut prompt = String::new();
+            let mut total_count = 0;
+            for prompt_piece in prompts {
+                let prompt_token_count =
+                    args.model.count_tokens(&prompt_piece)?
+ args.model.count_tokens("\n")?; + if remaining_tokens > prompt_token_count { + writeln!(prompt, "{prompt_piece}").unwrap(); + remaining_tokens -= prompt_token_count; + total_count += prompt_token_count; + } + } + + anyhow::Ok((prompt, total_count)) + } else { + let prompt = prompts.join("\n"); + let token_count = args.model.count_tokens(&prompt)?; + anyhow::Ok((prompt, token_count)) + } + } +} diff --git a/crates/ai2/src/prompts/repository_context.rs b/crates/ai2/src/prompts/repository_context.rs new file mode 100644 index 0000000000000000000000000000000000000000..78db5a16516f9797927f10ff136905054f36fa65 --- /dev/null +++ b/crates/ai2/src/prompts/repository_context.rs @@ -0,0 +1,98 @@ +use crate::prompts::base::{PromptArguments, PromptTemplate}; +use std::fmt::Write; +use std::{ops::Range, path::PathBuf}; + +use gpui2::{AsyncAppContext, Handle}; +use language2::{Anchor, Buffer}; + +#[derive(Clone)] +pub struct PromptCodeSnippet { + path: Option, + language_name: Option, + content: String, +} + +impl PromptCodeSnippet { + pub fn new( + buffer: Handle, + range: Range, + cx: &mut AsyncAppContext, + ) -> anyhow::Result { + let (content, language_name, file_path) = buffer.update(cx, |buffer, _| { + let snapshot = buffer.snapshot(); + let content = snapshot.text_for_range(range.clone()).collect::(); + + let language_name = buffer + .language() + .and_then(|language| Some(language.name().to_string().to_lowercase())); + + let file_path = buffer + .file() + .and_then(|file| Some(file.path().to_path_buf())); + + (content, language_name, file_path) + })?; + + anyhow::Ok(PromptCodeSnippet { + path: file_path, + language_name, + content, + }) + } +} + +impl ToString for PromptCodeSnippet { + fn to_string(&self) -> String { + let path = self + .path + .as_ref() + .and_then(|path| Some(path.to_string_lossy().to_string())) + .unwrap_or("".to_string()); + let language_name = self.language_name.clone().unwrap_or("".to_string()); + let content = self.content.clone(); + + format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```") + } +} + +pub struct RepositoryContext {} + +impl PromptTemplate for RepositoryContext { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500; + let template = "You are working inside a large repository, here are a few code snippets that may be useful."; + let mut prompt = String::new(); + + let mut remaining_tokens = max_token_length.clone(); + let seperator_token_length = args.model.count_tokens("\n")?; + for snippet in &args.snippets { + let mut snippet_prompt = template.to_string(); + let content = snippet.to_string(); + writeln!(snippet_prompt, "{content}").unwrap(); + + let token_count = args.model.count_tokens(&snippet_prompt)?; + if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT { + if let Some(tokens_left) = remaining_tokens { + if tokens_left >= token_count { + writeln!(prompt, "{snippet_prompt}").unwrap(); + remaining_tokens = if tokens_left >= (token_count + seperator_token_length) + { + Some(tokens_left - token_count - seperator_token_length) + } else { + Some(0) + }; + } + } else { + writeln!(prompt, "{snippet_prompt}").unwrap(); + } + } + } + + let total_token_count = args.model.count_tokens(&prompt)?; + anyhow::Ok((prompt, total_token_count)) + } +} diff --git a/crates/ai2/src/providers/mod.rs b/crates/ai2/src/providers/mod.rs new file mode 100644 index 
0000000000000000000000000000000000000000..acd0f9d91053869e3999ef0c1a23326480a7cbdd --- /dev/null +++ b/crates/ai2/src/providers/mod.rs @@ -0,0 +1 @@ +pub mod open_ai; diff --git a/crates/ai2/src/providers/open_ai/completion.rs b/crates/ai2/src/providers/open_ai/completion.rs new file mode 100644 index 0000000000000000000000000000000000000000..eca56110271a3be407a1c8b9f82cbb63c41bef23 --- /dev/null +++ b/crates/ai2/src/providers/open_ai/completion.rs @@ -0,0 +1,306 @@ +use anyhow::{anyhow, Result}; +use async_trait::async_trait; +use futures::{ + future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt, + Stream, StreamExt, +}; +use gpui2::{AppContext, Executor}; +use isahc::{http::StatusCode, Request, RequestExt}; +use parking_lot::RwLock; +use serde::{Deserialize, Serialize}; +use std::{ + env, + fmt::{self, Display}, + io, + sync::Arc, +}; +use util::ResultExt; + +use crate::{ + auth::{CredentialProvider, ProviderCredential}, + completion::{CompletionProvider, CompletionRequest}, + models::LanguageModel, +}; + +use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL}; + +#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum Role { + User, + Assistant, + System, +} + +impl Role { + pub fn cycle(&mut self) { + *self = match self { + Role::User => Role::Assistant, + Role::Assistant => Role::System, + Role::System => Role::User, + } + } +} + +impl Display for Role { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Role::User => write!(f, "User"), + Role::Assistant => write!(f, "Assistant"), + Role::System => write!(f, "System"), + } + } +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct RequestMessage { + pub role: Role, + pub content: String, +} + +#[derive(Debug, Default, Serialize)] +pub struct OpenAIRequest { + pub model: String, + pub messages: Vec, + pub stream: bool, + pub stop: Vec, + pub temperature: f32, +} + +impl CompletionRequest for OpenAIRequest { + fn data(&self) -> serde_json::Result { + serde_json::to_string(self) + } +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct ResponseMessage { + pub role: Option, + pub content: Option, +} + +#[derive(Deserialize, Debug)] +pub struct OpenAIUsage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, +} + +#[derive(Deserialize, Debug)] +pub struct ChatChoiceDelta { + pub index: u32, + pub delta: ResponseMessage, + pub finish_reason: Option, +} + +#[derive(Deserialize, Debug)] +pub struct OpenAIResponseStreamEvent { + pub id: Option, + pub object: String, + pub created: u32, + pub model: String, + pub choices: Vec, + pub usage: Option, +} + +pub async fn stream_completion( + credential: ProviderCredential, + executor: Arc, + request: Box, +) -> Result>> { + let api_key = match credential { + ProviderCredential::Credentials { api_key } => api_key, + _ => { + return Err(anyhow!("no credentials provider for completion")); + } + }; + + let (tx, rx) = futures::channel::mpsc::unbounded::>(); + + let json_data = request.data()?; + let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions")) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key)) + .body(json_data)? 
+ .send_async() + .await?; + + let status = response.status(); + if status == StatusCode::OK { + executor + .spawn(async move { + let mut lines = BufReader::new(response.body_mut()).lines(); + + fn parse_line( + line: Result, + ) -> Result> { + if let Some(data) = line?.strip_prefix("data: ") { + let event = serde_json::from_str(&data)?; + Ok(Some(event)) + } else { + Ok(None) + } + } + + while let Some(line) = lines.next().await { + if let Some(event) = parse_line(line).transpose() { + let done = event.as_ref().map_or(false, |event| { + event + .choices + .last() + .map_or(false, |choice| choice.finish_reason.is_some()) + }); + if tx.unbounded_send(event).is_err() { + break; + } + + if done { + break; + } + } + } + + anyhow::Ok(()) + }) + .detach(); + + Ok(rx) + } else { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + #[derive(Deserialize)] + struct OpenAIResponse { + error: OpenAIError, + } + + #[derive(Deserialize)] + struct OpenAIError { + message: String, + } + + match serde_json::from_str::(&body) { + Ok(response) if !response.error.message.is_empty() => Err(anyhow!( + "Failed to connect to OpenAI API: {}", + response.error.message, + )), + + _ => Err(anyhow!( + "Failed to connect to OpenAI API: {} {}", + response.status(), + body, + )), + } + } +} + +#[derive(Clone)] +pub struct OpenAICompletionProvider { + model: OpenAILanguageModel, + credential: Arc>, + executor: Arc, +} + +impl OpenAICompletionProvider { + pub fn new(model_name: &str, executor: Arc) -> Self { + let model = OpenAILanguageModel::load(model_name); + let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials)); + Self { + model, + credential, + executor, + } + } +} + +#[async_trait] +impl CredentialProvider for OpenAICompletionProvider { + fn has_credentials(&self) -> bool { + match *self.credential.read() { + ProviderCredential::Credentials { .. } => true, + _ => false, + } + } + async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential { + let existing_credential = self.credential.read().clone(); + + let retrieved_credential = cx + .run_on_main(move |cx| match existing_credential { + ProviderCredential::Credentials { .. 
} => {
+                    return existing_credential.clone();
+                }
+                _ => {
+                    if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
+                        return ProviderCredential::Credentials { api_key };
+                    }
+
+                    if let Some(Some((_, api_key))) = cx.read_credentials(OPENAI_API_URL).log_err()
+                    {
+                        if let Some(api_key) = String::from_utf8(api_key).log_err() {
+                            return ProviderCredential::Credentials { api_key };
+                        } else {
+                            return ProviderCredential::NoCredentials;
+                        }
+                    } else {
+                        return ProviderCredential::NoCredentials;
+                    }
+                }
+            })
+            .await;
+
+        *self.credential.write() = retrieved_credential.clone();
+        retrieved_credential
+    }
+
+    async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) {
+        *self.credential.write() = credential.clone();
+        cx.run_on_main(move |cx| match credential {
+            ProviderCredential::Credentials { api_key } => {
+                cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
+                    .log_err();
+            }
+            _ => {}
+        })
+        .await;
+    }
+    async fn delete_credentials(&self, cx: &mut AppContext) {
+        cx.run_on_main(move |cx| cx.delete_credentials(OPENAI_API_URL).log_err())
+            .await;
+        *self.credential.write() = ProviderCredential::NoCredentials;
+    }
+}
+
+impl CompletionProvider for OpenAICompletionProvider {
+    fn base_model(&self) -> Box<dyn LanguageModel> {
+        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
+        model
+    }
+    fn complete(
+        &self,
+        prompt: Box<dyn CompletionRequest>,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        // Currently the CompletionRequest for OpenAI includes a 'model' parameter.
+        // This means that the model is determined by the CompletionRequest and not the
+        // CompletionProvider, which is currently model-based, due to the language model.
+        // At some point in the future we should rectify this.
+        let credential = self.credential.read().clone();
+        let request = stream_completion(credential, self.executor.clone(), prompt);
+        async move {
+            let response = request.await?;
+            let stream = response
+                .filter_map(|response| async move {
+                    match response {
+                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
+                        Err(error) => Some(Err(error)),
+                    }
+                })
+                .boxed();
+            Ok(stream)
+        }
+        .boxed()
+    }
+    fn box_clone(&self) -> Box<dyn CompletionProvider> {
+        Box::new((*self).clone())
+    }
+}
diff --git a/crates/ai2/src/providers/open_ai/embedding.rs b/crates/ai2/src/providers/open_ai/embedding.rs
new file mode 100644
index 0000000000000000000000000000000000000000..fc49c15134d0aba787968acbd412daff60ce6106
--- /dev/null
+++ b/crates/ai2/src/providers/open_ai/embedding.rs
@@ -0,0 +1,313 @@
+use anyhow::{anyhow, Result};
+use async_trait::async_trait;
+use futures::AsyncReadExt;
+use gpui2::Executor;
+use gpui2::{serde_json, AppContext};
+use isahc::http::StatusCode;
+use isahc::prelude::Configurable;
+use isahc::{AsyncBody, Response};
+use lazy_static::lazy_static;
+use parking_lot::{Mutex, RwLock};
+use parse_duration::parse;
+use postage::watch;
+use serde::{Deserialize, Serialize};
+use std::env;
+use std::ops::Add;
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+use tiktoken_rs::{cl100k_base, CoreBPE};
+use util::http::{HttpClient, Request};
+use util::ResultExt;
+
+use crate::auth::{CredentialProvider, ProviderCredential};
+use crate::embedding::{Embedding, EmbeddingProvider};
+use crate::models::LanguageModel;
+use crate::providers::open_ai::OpenAILanguageModel;
+
+use crate::providers::open_ai::OPENAI_API_URL;
+
+lazy_static!
{ + static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap(); +} + +#[derive(Clone)] +pub struct OpenAIEmbeddingProvider { + model: OpenAILanguageModel, + credential: Arc>, + pub client: Arc, + pub executor: Arc, + rate_limit_count_rx: watch::Receiver>, + rate_limit_count_tx: Arc>>>, +} + +#[derive(Serialize)] +struct OpenAIEmbeddingRequest<'a> { + model: &'static str, + input: Vec<&'a str>, +} + +#[derive(Deserialize)] +struct OpenAIEmbeddingResponse { + data: Vec, + usage: OpenAIEmbeddingUsage, +} + +#[derive(Debug, Deserialize)] +struct OpenAIEmbedding { + embedding: Vec, + index: usize, + object: String, +} + +#[derive(Deserialize)] +struct OpenAIEmbeddingUsage { + prompt_tokens: usize, + total_tokens: usize, +} + +impl OpenAIEmbeddingProvider { + pub fn new(client: Arc, executor: Arc) -> Self { + let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None); + let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx)); + + let model = OpenAILanguageModel::load("text-embedding-ada-002"); + let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials)); + + OpenAIEmbeddingProvider { + model, + credential, + client, + executor, + rate_limit_count_rx, + rate_limit_count_tx, + } + } + + fn get_api_key(&self) -> Result { + match self.credential.read().clone() { + ProviderCredential::Credentials { api_key } => Ok(api_key), + _ => Err(anyhow!("api credentials not provided")), + } + } + + fn resolve_rate_limit(&self) { + let reset_time = *self.rate_limit_count_tx.lock().borrow(); + + if let Some(reset_time) = reset_time { + if Instant::now() >= reset_time { + *self.rate_limit_count_tx.lock().borrow_mut() = None + } + } + + log::trace!( + "resolving reset time: {:?}", + *self.rate_limit_count_tx.lock().borrow() + ); + } + + fn update_reset_time(&self, reset_time: Instant) { + let original_time = *self.rate_limit_count_tx.lock().borrow(); + + let updated_time = if let Some(original_time) = original_time { + if reset_time < original_time { + Some(reset_time) + } else { + Some(original_time) + } + } else { + Some(reset_time) + }; + + log::trace!("updating rate limit time: {:?}", updated_time); + + *self.rate_limit_count_tx.lock().borrow_mut() = updated_time; + } + async fn send_request( + &self, + api_key: &str, + spans: Vec<&str>, + request_timeout: u64, + ) -> Result> { + let request = Request::post("https://api.openai.com/v1/embeddings") + .redirect_policy(isahc::config::RedirectPolicy::Follow) + .timeout(Duration::from_secs(request_timeout)) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key)) + .body( + serde_json::to_string(&OpenAIEmbeddingRequest { + input: spans.clone(), + model: "text-embedding-ada-002", + }) + .unwrap() + .into(), + )?; + + Ok(self.client.send(request).await?) + } +} + +#[async_trait] +impl CredentialProvider for OpenAIEmbeddingProvider { + fn has_credentials(&self) -> bool { + match *self.credential.read() { + ProviderCredential::Credentials { .. } => true, + _ => false, + } + } + async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential { + let existing_credential = self.credential.read().clone(); + + let retrieved_credential = cx + .run_on_main(move |cx| match existing_credential { + ProviderCredential::Credentials { .. 
} => { + return existing_credential.clone(); + } + _ => { + if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() { + return ProviderCredential::Credentials { api_key }; + } + + if let Some(Some((_, api_key))) = cx.read_credentials(OPENAI_API_URL).log_err() + { + if let Some(api_key) = String::from_utf8(api_key).log_err() { + return ProviderCredential::Credentials { api_key }; + } else { + return ProviderCredential::NoCredentials; + } + } else { + return ProviderCredential::NoCredentials; + } + } + }) + .await; + + *self.credential.write() = retrieved_credential.clone(); + retrieved_credential + } + + async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) { + *self.credential.write() = credential.clone(); + let credential = credential.clone(); + cx.run_on_main(move |cx| match credential { + ProviderCredential::Credentials { api_key } => { + cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()) + .log_err(); + } + _ => {} + }) + .await; + } + async fn delete_credentials(&self, cx: &mut AppContext) { + cx.run_on_main(move |cx| cx.delete_credentials(OPENAI_API_URL).log_err()) + .await; + *self.credential.write() = ProviderCredential::NoCredentials; + } +} + +#[async_trait] +impl EmbeddingProvider for OpenAIEmbeddingProvider { + fn base_model(&self) -> Box { + let model: Box = Box::new(self.model.clone()); + model + } + + fn max_tokens_per_batch(&self) -> usize { + 50000 + } + + fn rate_limit_expiration(&self) -> Option { + *self.rate_limit_count_rx.borrow() + } + + async fn embed_batch(&self, spans: Vec) -> Result> { + const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; + const MAX_RETRIES: usize = 4; + + let api_key = self.get_api_key()?; + + let mut request_number = 0; + let mut rate_limiting = false; + let mut request_timeout: u64 = 15; + let mut response: Response; + while request_number < MAX_RETRIES { + response = self + .send_request( + &api_key, + spans.iter().map(|x| &**x).collect(), + request_timeout, + ) + .await?; + + request_number += 1; + + match response.status() { + StatusCode::REQUEST_TIMEOUT => { + request_timeout += 5; + } + StatusCode::OK => { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?; + + log::trace!( + "openai embedding completed. 
tokens: {:?}", + response.usage.total_tokens + ); + + // If we complete a request successfully that was previously rate_limited + // resolve the rate limit + if rate_limiting { + self.resolve_rate_limit() + } + + return Ok(response + .data + .into_iter() + .map(|embedding| Embedding::from(embedding.embedding)) + .collect()); + } + StatusCode::TOO_MANY_REQUESTS => { + rate_limiting = true; + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + let delay_duration = { + let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64); + if let Some(time_to_reset) = + response.headers().get("x-ratelimit-reset-tokens") + { + if let Ok(time_str) = time_to_reset.to_str() { + parse(time_str).unwrap_or(delay) + } else { + delay + } + } else { + delay + } + }; + + // If we've previously rate limited, increment the duration but not the count + let reset_time = Instant::now().add(delay_duration); + self.update_reset_time(reset_time); + + log::trace!( + "openai rate limiting: waiting {:?} until lifted", + &delay_duration + ); + + self.executor.timer(delay_duration).await; + } + _ => { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + return Err(anyhow!( + "open ai bad request: {:?} {:?}", + &response.status(), + body + )); + } + } + } + Err(anyhow!("openai max retries")) + } +} diff --git a/crates/ai2/src/providers/open_ai/mod.rs b/crates/ai2/src/providers/open_ai/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..7d2f86045d9e1049c55111aef175ac9b56dc7e16 --- /dev/null +++ b/crates/ai2/src/providers/open_ai/mod.rs @@ -0,0 +1,9 @@ +pub mod completion; +pub mod embedding; +pub mod model; + +pub use completion::*; +pub use embedding::*; +pub use model::OpenAILanguageModel; + +pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; diff --git a/crates/ai2/src/providers/open_ai/model.rs b/crates/ai2/src/providers/open_ai/model.rs new file mode 100644 index 0000000000000000000000000000000000000000..6e306c80b905865c011c9064934827085ca126d6 --- /dev/null +++ b/crates/ai2/src/providers/open_ai/model.rs @@ -0,0 +1,57 @@ +use anyhow::anyhow; +use tiktoken_rs::CoreBPE; +use util::ResultExt; + +use crate::models::{LanguageModel, TruncationDirection}; + +#[derive(Clone)] +pub struct OpenAILanguageModel { + name: String, + bpe: Option, +} + +impl OpenAILanguageModel { + pub fn load(model_name: &str) -> Self { + let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err(); + OpenAILanguageModel { + name: model_name.to_string(), + bpe, + } + } +} + +impl LanguageModel for OpenAILanguageModel { + fn name(&self) -> String { + self.name.clone() + } + fn count_tokens(&self, content: &str) -> anyhow::Result { + if let Some(bpe) = &self.bpe { + anyhow::Ok(bpe.encode_with_special_tokens(content).len()) + } else { + Err(anyhow!("bpe for open ai model was not retrieved")) + } + } + fn truncate( + &self, + content: &str, + length: usize, + direction: TruncationDirection, + ) -> anyhow::Result { + if let Some(bpe) = &self.bpe { + let tokens = bpe.encode_with_special_tokens(content); + if tokens.len() > length { + match direction { + TruncationDirection::End => bpe.decode(tokens[..length].to_vec()), + TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()), + } + } else { + bpe.decode(tokens) + } + } else { + Err(anyhow!("bpe for open ai model was not retrieved")) + } + } + fn capacity(&self) -> anyhow::Result { + anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name)) + } +} diff --git 
a/crates/ai2/src/providers/open_ai/new.rs b/crates/ai2/src/providers/open_ai/new.rs new file mode 100644 index 0000000000000000000000000000000000000000..c7d67f2ba1d252a6865124d8ffdfb79130a8c3a0 --- /dev/null +++ b/crates/ai2/src/providers/open_ai/new.rs @@ -0,0 +1,11 @@ +pub trait LanguageModel { + fn name(&self) -> String; + fn count_tokens(&self, content: &str) -> anyhow::Result; + fn truncate( + &self, + content: &str, + length: usize, + direction: TruncationDirection, + ) -> anyhow::Result; + fn capacity(&self) -> anyhow::Result; +} diff --git a/crates/ai2/src/test.rs b/crates/ai2/src/test.rs new file mode 100644 index 0000000000000000000000000000000000000000..ee88529aecb004ce3b725fb61abd679359673404 --- /dev/null +++ b/crates/ai2/src/test.rs @@ -0,0 +1,193 @@ +use std::{ + sync::atomic::{self, AtomicUsize, Ordering}, + time::Instant, +}; + +use async_trait::async_trait; +use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; +use gpui2::AppContext; +use parking_lot::Mutex; + +use crate::{ + auth::{CredentialProvider, ProviderCredential}, + completion::{CompletionProvider, CompletionRequest}, + embedding::{Embedding, EmbeddingProvider}, + models::{LanguageModel, TruncationDirection}, +}; + +#[derive(Clone)] +pub struct FakeLanguageModel { + pub capacity: usize, +} + +impl LanguageModel for FakeLanguageModel { + fn name(&self) -> String { + "dummy".to_string() + } + fn count_tokens(&self, content: &str) -> anyhow::Result { + anyhow::Ok(content.chars().collect::>().len()) + } + fn truncate( + &self, + content: &str, + length: usize, + direction: TruncationDirection, + ) -> anyhow::Result { + println!("TRYING TO TRUNCATE: {:?}", length.clone()); + + if length > self.count_tokens(content)? { + println!("NOT TRUNCATING"); + return anyhow::Ok(content.to_string()); + } + + anyhow::Ok(match direction { + TruncationDirection::End => content.chars().collect::>()[..length] + .into_iter() + .collect::(), + TruncationDirection::Start => content.chars().collect::>()[length..] 
+ .into_iter() + .collect::(), + }) + } + fn capacity(&self) -> anyhow::Result { + anyhow::Ok(self.capacity) + } +} + +pub struct FakeEmbeddingProvider { + pub embedding_count: AtomicUsize, +} + +impl Clone for FakeEmbeddingProvider { + fn clone(&self) -> Self { + FakeEmbeddingProvider { + embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)), + } + } +} + +impl Default for FakeEmbeddingProvider { + fn default() -> Self { + FakeEmbeddingProvider { + embedding_count: AtomicUsize::default(), + } + } +} + +impl FakeEmbeddingProvider { + pub fn embedding_count(&self) -> usize { + self.embedding_count.load(atomic::Ordering::SeqCst) + } + + pub fn embed_sync(&self, span: &str) -> Embedding { + let mut result = vec![1.0; 26]; + for letter in span.chars() { + let letter = letter.to_ascii_lowercase(); + if letter as u32 >= 'a' as u32 { + let ix = (letter as u32) - ('a' as u32); + if ix < 26 { + result[ix as usize] += 1.0; + } + } + } + + let norm = result.iter().map(|x| x * x).sum::().sqrt(); + for x in &mut result { + *x /= norm; + } + + result.into() + } +} + +#[async_trait] +impl CredentialProvider for FakeEmbeddingProvider { + fn has_credentials(&self) -> bool { + true + } + async fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential { + ProviderCredential::NotNeeded + } + async fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {} + async fn delete_credentials(&self, _cx: &mut AppContext) {} +} + +#[async_trait] +impl EmbeddingProvider for FakeEmbeddingProvider { + fn base_model(&self) -> Box { + Box::new(FakeLanguageModel { capacity: 1000 }) + } + fn max_tokens_per_batch(&self) -> usize { + 1000 + } + + fn rate_limit_expiration(&self) -> Option { + None + } + + async fn embed_batch(&self, spans: Vec) -> anyhow::Result> { + self.embedding_count + .fetch_add(spans.len(), atomic::Ordering::SeqCst); + + anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect()) + } +} + +pub struct FakeCompletionProvider { + last_completion_tx: Mutex>>, +} + +impl Clone for FakeCompletionProvider { + fn clone(&self) -> Self { + Self { + last_completion_tx: Mutex::new(None), + } + } +} + +impl FakeCompletionProvider { + pub fn new() -> Self { + Self { + last_completion_tx: Mutex::new(None), + } + } + + pub fn send_completion(&self, completion: impl Into) { + let mut tx = self.last_completion_tx.lock(); + tx.as_mut().unwrap().try_send(completion.into()).unwrap(); + } + + pub fn finish_completion(&self) { + self.last_completion_tx.lock().take().unwrap(); + } +} + +#[async_trait] +impl CredentialProvider for FakeCompletionProvider { + fn has_credentials(&self) -> bool { + true + } + async fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential { + ProviderCredential::NotNeeded + } + async fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {} + async fn delete_credentials(&self, _cx: &mut AppContext) {} +} + +impl CompletionProvider for FakeCompletionProvider { + fn base_model(&self) -> Box { + let model: Box = Box::new(FakeLanguageModel { capacity: 8190 }); + model + } + fn complete( + &self, + _prompt: Box, + ) -> BoxFuture<'static, anyhow::Result>>> { + let (tx, rx) = mpsc::channel(1); + *self.last_completion_tx.lock() = Some(tx); + async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed() + } + fn box_clone(&self) -> Box { + Box::new((*self).clone()) + } +} diff --git a/crates/zed2/Cargo.toml b/crates/zed2/Cargo.toml index 
a6b31871dd6d8290f60a0862aa96ef6433df364f..9f681a49e95565b58d8cba5ba2e775ba1d0e37bb 100644 --- a/crates/zed2/Cargo.toml +++ b/crates/zed2/Cargo.toml @@ -15,6 +15,7 @@ name = "Zed" path = "src/main.rs" [dependencies] +ai2 = { path = "../ai2"} # audio = { path = "../audio" } # activity_indicator = { path = "../activity_indicator" } # auto_update = { path = "../auto_update" } From 5ff70f7dbab60cda01d0e17e3941c6410d0a890e Mon Sep 17 00:00:00 2001 From: KCaverly Date: Mon, 30 Oct 2023 14:49:31 -0400 Subject: [PATCH 23/25] keeping this bad boy green during fmt checks --- crates/ui2/src/elements/icon.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/ui2/src/elements/icon.rs b/crates/ui2/src/elements/icon.rs index ef3644229685dbfaa2ad466c4d1b242272377479..2273ec24f2369e50bb9152eeedddfd5f9aad8cbf 100644 --- a/crates/ui2/src/elements/icon.rs +++ b/crates/ui2/src/elements/icon.rs @@ -36,7 +36,7 @@ impl IconColor { IconColor::Error => gpui2::red(), IconColor::Warning => gpui2::red(), IconColor::Success => gpui2::red(), - IconColor::Info => gpui2::red() + IconColor::Info => gpui2::red(), } } } From bc4f8fbf4e88f1c9d5139b840fbc58c557b94370 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 30 Oct 2023 19:53:48 +0100 Subject: [PATCH 24/25] Rename other references from "handle" to "model" Co-Authored-By: Max Co-Authored-By: Mikayla --- crates/call2/src/call2.rs | 6 +- crates/call2/src/participant.rs | 4 +- crates/call2/src/room.rs | 8 +- crates/client2/src/client2.rs | 16 +-- crates/copilot2/src/copilot2.rs | 8 +- crates/gpui2/src/app.rs | 10 +- crates/gpui2/src/app/entity_map.rs | 140 +++++++++++++------------- crates/gpui2/src/app/model_context.rs | 12 +-- crates/gpui2/src/interactive.rs | 2 +- crates/gpui2/src/view.rs | 4 +- crates/gpui2/src/window.rs | 14 +-- crates/project2/src/project2.rs | 26 ++--- crates/project2/src/terminals.rs | 6 +- 13 files changed, 128 insertions(+), 128 deletions(-) diff --git a/crates/call2/src/call2.rs b/crates/call2/src/call2.rs index ffa2e5e9dc6072a543aff43eec5d2582992d2144..d8678b7ed46d155ece9373a161bd59b0322e7bce 100644 --- a/crates/call2/src/call2.rs +++ b/crates/call2/src/call2.rs @@ -13,7 +13,7 @@ use collections::HashSet; use futures::{future::Shared, FutureExt}; use gpui2::{ AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Subscription, Task, - WeakHandle, + WeakModel, }; use postage::watch; use project2::Project; @@ -42,7 +42,7 @@ pub struct IncomingCall { pub struct ActiveCall { room: Option<(Model, Vec)>, pending_room_creation: Option, Arc>>>>, - location: Option>, + location: Option>, pending_invites: HashSet, incoming_call: ( watch::Sender>, @@ -347,7 +347,7 @@ impl ActiveCall { } } - pub fn location(&self) -> Option<&WeakHandle> { + pub fn location(&self) -> Option<&WeakModel> { self.location.as_ref() } diff --git a/crates/call2/src/participant.rs b/crates/call2/src/participant.rs index c5c873a78aa277d045b04da7571e5c9948677c10..7f3e91dbba0116a7b7f7ef5b1c471fb1a768529f 100644 --- a/crates/call2/src/participant.rs +++ b/crates/call2/src/participant.rs @@ -1,7 +1,7 @@ use anyhow::{anyhow, Result}; use client2::ParticipantIndex; use client2::{proto, User}; -use gpui2::WeakHandle; +use gpui2::WeakModel; pub use live_kit_client::Frame; use project2::Project; use std::{fmt, sync::Arc}; @@ -33,7 +33,7 @@ impl ParticipantLocation { #[derive(Clone, Default)] pub struct LocalParticipant { pub projects: Vec, - pub active_project: Option>, + pub active_project: Option>, } #[derive(Clone, Debug)] diff --git 
a/crates/call2/src/room.rs b/crates/call2/src/room.rs index 07873c4cd5887bd39cd48b2a48c29cac7fb9dc54..7f51c64d4b62ff4371dab4bffc36196b28843362 100644 --- a/crates/call2/src/room.rs +++ b/crates/call2/src/room.rs @@ -16,7 +16,7 @@ use collections::{BTreeMap, HashMap, HashSet}; use fs::Fs; use futures::{FutureExt, StreamExt}; use gpui2::{ - AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Task, WeakHandle, + AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Task, WeakModel, }; use language2::LanguageRegistry; use live_kit_client::{LocalTrackPublication, RemoteAudioTrackUpdate, RemoteVideoTrackUpdate}; @@ -61,8 +61,8 @@ pub struct Room { channel_id: Option, // live_kit: Option, status: RoomStatus, - shared_projects: HashSet>, - joined_projects: HashSet>, + shared_projects: HashSet>, + joined_projects: HashSet>, local_participant: LocalParticipant, remote_participants: BTreeMap, pending_participants: Vec>, @@ -424,7 +424,7 @@ impl Room { } async fn maintain_connection( - this: WeakHandle, + this: WeakModel, client: Arc, mut cx: AsyncAppContext, ) -> Result<()> { diff --git a/crates/client2/src/client2.rs b/crates/client2/src/client2.rs index dcea6ded4e61beb60ca22dd644b0fd216a3facac..19e8685c28cd55094b064ea0af60bbd6744fa475 100644 --- a/crates/client2/src/client2.rs +++ b/crates/client2/src/client2.rs @@ -14,8 +14,8 @@ use futures::{ future::BoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt, TryFutureExt as _, TryStreamExt, }; use gpui2::{ - serde_json, AnyHandle, AnyWeakHandle, AppContext, AsyncAppContext, Model, SemanticVersion, - Task, WeakHandle, + serde_json, AnyModel, AnyWeakModel, AppContext, AsyncAppContext, Model, SemanticVersion, Task, + WeakModel, }; use lazy_static::lazy_static; use parking_lot::RwLock; @@ -227,7 +227,7 @@ struct ClientState { _reconnect_task: Option>, reconnect_interval: Duration, entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>, - models_by_message_type: HashMap, + models_by_message_type: HashMap, entity_types_by_message_type: HashMap, #[allow(clippy::type_complexity)] message_handlers: HashMap< @@ -236,7 +236,7 @@ struct ClientState { dyn Send + Sync + Fn( - AnyHandle, + AnyModel, Box, &Arc, AsyncAppContext, @@ -246,7 +246,7 @@ struct ClientState { } enum WeakSubscriber { - Entity { handle: AnyWeakHandle }, + Entity { handle: AnyWeakModel }, Pending(Vec>), } @@ -552,7 +552,7 @@ impl Client { #[track_caller] pub fn add_message_handler( self: &Arc, - entity: WeakHandle, + entity: WeakModel, handler: H, ) -> Subscription where @@ -594,7 +594,7 @@ impl Client { pub fn add_request_handler( self: &Arc, - model: WeakHandle, + model: WeakModel, handler: H, ) -> Subscription where @@ -628,7 +628,7 @@ impl Client { where M: EntityMessage, E: 'static + Send, - H: 'static + Send + Sync + Fn(AnyHandle, TypedEnvelope, Arc, AsyncAppContext) -> F, + H: 'static + Send + Sync + Fn(AnyModel, TypedEnvelope, Arc, AsyncAppContext) -> F, F: 'static + Future> + Send, { let model_type_id = TypeId::of::(); diff --git a/crates/copilot2/src/copilot2.rs b/crates/copilot2/src/copilot2.rs index 42b0e3aa41ad154b6b88fa494a2132aa35ba2c97..c3107a2f4731897199eb80677f5bc8857c7beb39 100644 --- a/crates/copilot2/src/copilot2.rs +++ b/crates/copilot2/src/copilot2.rs @@ -8,7 +8,7 @@ use collections::{HashMap, HashSet}; use futures::{channel::oneshot, future::Shared, Future, FutureExt, TryFutureExt}; use gpui2::{ AppContext, AsyncAppContext, Context, EntityId, EventEmitter, Model, ModelContext, Task, - WeakHandle, + WeakModel, }; 
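To make the handle-to-model rename concrete, here is a minimal sketch of the calling pattern that stays the same after this patch: a holder keeps a WeakModel and upgrades it on use. The Counter type and read_count function are illustrative only and appear nowhere in the patch; the sketch assumes the upgrade and read signatures shown in the entity_map.rs hunk below.

    use gpui2::{AppContext, Model, WeakModel};

    // Illustrative state type; not part of this patch.
    struct Counter {
        count: usize,
    }

    // A WeakModel does not keep its entity alive, so callers upgrade on use,
    // exactly as ActiveCall::location and Copilot::buffers do after the rename.
    fn read_count(weak: &WeakModel<Counter>, cx: &AppContext) -> Option<usize> {
        let counter: Model<Counter> = weak.upgrade()?;
        Some(counter.read(cx).count)
    }
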
use language2::{ language_settings::{all_language_settings, language_settings}, @@ -278,7 +278,7 @@ pub struct Copilot { http: Arc, node_runtime: Arc, server: CopilotServer, - buffers: HashSet>, + buffers: HashSet>, server_id: LanguageServerId, _subscription: gpui2::Subscription, } @@ -383,7 +383,7 @@ impl Copilot { new_server_id: LanguageServerId, http: Arc, node_runtime: Arc, - this: WeakHandle, + this: WeakModel, mut cx: AsyncAppContext, ) -> impl Future { async move { @@ -706,7 +706,7 @@ impl Copilot { Ok(()) } - fn unregister_buffer(&mut self, buffer: &WeakHandle) { + fn unregister_buffer(&mut self, buffer: &WeakModel) { if let Ok(server) = self.server.as_running() { if let Some(buffer) = server.registered_buffers.remove(&buffer.entity_id()) { server diff --git a/crates/gpui2/src/app.rs b/crates/gpui2/src/app.rs index c4ba6a472486106fbefa6d0d61625b2aa9d8331f..0a09bb2ff834bfb71d784d6a537c4772d1eb37aa 100644 --- a/crates/gpui2/src/app.rs +++ b/crates/gpui2/src/app.rs @@ -711,7 +711,7 @@ impl Context for AppContext { type Result = T; /// Build an entity that is owned by the application. The given function will be invoked with - /// a `ModelContext` and must return an object representing the entity. A `Handle` will be returned + /// a `ModelContext` and must return an object representing the entity. A `Model` will be returned /// which can be used to access the entity in a context. fn build_model( &mut self, @@ -724,18 +724,18 @@ impl Context for AppContext { }) } - /// Update the entity referenced by the given handle. The function is passed a mutable reference to the + /// Update the entity referenced by the given model. The function is passed a mutable reference to the /// entity along with a `ModelContext` for the entity. fn update_entity( &mut self, - handle: &Model, + model: &Model, update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R, ) -> R { self.update(|cx| { - let mut entity = cx.entities.lease(handle); + let mut entity = cx.entities.lease(model); let result = update( &mut entity, - &mut ModelContext::mutable(cx, handle.downgrade()), + &mut ModelContext::mutable(cx, model.downgrade()), ); cx.entities.end_lease(entity); result diff --git a/crates/gpui2/src/app/entity_map.rs b/crates/gpui2/src/app/entity_map.rs index 68f0a8fa484b6f25caffb607214e3f70462ab907..bfd457e4fff4df1d6a8e8f4a78b9a76f6cf26e58 100644 --- a/crates/gpui2/src/app/entity_map.rs +++ b/crates/gpui2/src/app/entity_map.rs @@ -61,21 +61,21 @@ impl EntityMap { where T: 'static + Send, { - let handle = slot.0; - self.entities.insert(handle.entity_id, Box::new(entity)); - handle + let model = slot.0; + self.entities.insert(model.entity_id, Box::new(entity)); + model } /// Move an entity to the stack. - pub fn lease<'a, T>(&mut self, handle: &'a Model) -> Lease<'a, T> { - self.assert_valid_context(handle); + pub fn lease<'a, T>(&mut self, model: &'a Model) -> Lease<'a, T> { + self.assert_valid_context(model); let entity = Some( self.entities - .remove(handle.entity_id) + .remove(model.entity_id) .expect("Circular entity lease. Is the entity already being updated?"), ); Lease { - handle, + model, entity, entity_type: PhantomData, } @@ -84,18 +84,18 @@ impl EntityMap { /// Return an entity after moving it to the stack. 
pub fn end_lease(&mut self, mut lease: Lease) { self.entities - .insert(lease.handle.entity_id, lease.entity.take().unwrap()); + .insert(lease.model.entity_id, lease.entity.take().unwrap()); } - pub fn read(&self, handle: &Model) -> &T { - self.assert_valid_context(handle); - self.entities[handle.entity_id].downcast_ref().unwrap() + pub fn read(&self, model: &Model) -> &T { + self.assert_valid_context(model); + self.entities[model.entity_id].downcast_ref().unwrap() } - fn assert_valid_context(&self, handle: &AnyHandle) { + fn assert_valid_context(&self, model: &AnyModel) { debug_assert!( - Weak::ptr_eq(&handle.entity_map, &Arc::downgrade(&self.ref_counts)), - "used a handle with the wrong context" + Weak::ptr_eq(&model.entity_map, &Arc::downgrade(&self.ref_counts)), + "used a model with the wrong context" ); } @@ -115,7 +115,7 @@ impl EntityMap { pub struct Lease<'a, T> { entity: Option, - pub handle: &'a Model, + pub model: &'a Model, entity_type: PhantomData, } @@ -145,13 +145,13 @@ impl<'a, T> Drop for Lease<'a, T> { #[derive(Deref, DerefMut)] pub struct Slot(Model); -pub struct AnyHandle { +pub struct AnyModel { pub(crate) entity_id: EntityId, entity_type: TypeId, entity_map: Weak>, } -impl AnyHandle { +impl AnyModel { fn new(id: EntityId, entity_type: TypeId, entity_map: Weak>) -> Self { Self { entity_id: id, @@ -164,8 +164,8 @@ impl AnyHandle { self.entity_id } - pub fn downgrade(&self) -> AnyWeakHandle { - AnyWeakHandle { + pub fn downgrade(&self) -> AnyWeakModel { + AnyWeakModel { entity_id: self.entity_id, entity_type: self.entity_type, entity_ref_counts: self.entity_map.clone(), @@ -175,7 +175,7 @@ impl AnyHandle { pub fn downcast(&self) -> Option> { if TypeId::of::() == self.entity_type { Some(Model { - any_handle: self.clone(), + any_model: self.clone(), entity_type: PhantomData, }) } else { @@ -184,16 +184,16 @@ impl AnyHandle { } } -impl Clone for AnyHandle { +impl Clone for AnyModel { fn clone(&self) -> Self { if let Some(entity_map) = self.entity_map.upgrade() { let entity_map = entity_map.read(); let count = entity_map .counts .get(self.entity_id) - .expect("detected over-release of a handle"); + .expect("detected over-release of a model"); let prev_count = count.fetch_add(1, SeqCst); - assert_ne!(prev_count, 0, "Detected over-release of a handle."); + assert_ne!(prev_count, 0, "Detected over-release of a model."); } Self { @@ -204,16 +204,16 @@ impl Clone for AnyHandle { } } -impl Drop for AnyHandle { +impl Drop for AnyModel { fn drop(&mut self) { if let Some(entity_map) = self.entity_map.upgrade() { let entity_map = entity_map.upgradable_read(); let count = entity_map .counts .get(self.entity_id) - .expect("Detected over-release of a handle."); + .expect("Detected over-release of a model."); let prev_count = count.fetch_sub(1, SeqCst); - assert_ne!(prev_count, 0, "Detected over-release of a handle."); + assert_ne!(prev_count, 0, "Detected over-release of a model."); if prev_count == 1 { // We were the last reference to this entity, so we can remove it. 
let mut entity_map = RwLockUpgradableReadGuard::upgrade(entity_map); @@ -223,31 +223,31 @@ impl Drop for AnyHandle { } } -impl From> for AnyHandle { - fn from(handle: Model) -> Self { - handle.any_handle +impl From> for AnyModel { + fn from(model: Model) -> Self { + model.any_model } } -impl Hash for AnyHandle { +impl Hash for AnyModel { fn hash(&self, state: &mut H) { self.entity_id.hash(state); } } -impl PartialEq for AnyHandle { +impl PartialEq for AnyModel { fn eq(&self, other: &Self) -> bool { self.entity_id == other.entity_id } } -impl Eq for AnyHandle {} +impl Eq for AnyModel {} #[derive(Deref, DerefMut)] pub struct Model { #[deref] #[deref_mut] - any_handle: AnyHandle, + any_model: AnyModel, entity_type: PhantomData, } @@ -260,14 +260,14 @@ impl Model { T: 'static, { Self { - any_handle: AnyHandle::new(id, TypeId::of::(), entity_map), + any_model: AnyModel::new(id, TypeId::of::(), entity_map), entity_type: PhantomData, } } - pub fn downgrade(&self) -> WeakHandle { - WeakHandle { - any_handle: self.any_handle.downgrade(), + pub fn downgrade(&self) -> WeakModel { + WeakModel { + any_model: self.any_model.downgrade(), entity_type: self.entity_type, } } @@ -276,7 +276,7 @@ impl Model { cx.entities.read(self) } - /// Update the entity referenced by this handle with the given function. + /// Update the entity referenced by this model with the given function. /// /// The update function receives a context appropriate for its environment. /// When updating in an `AppContext`, it receives a `ModelContext`. @@ -296,7 +296,7 @@ impl Model { impl Clone for Model { fn clone(&self) -> Self { Self { - any_handle: self.any_handle.clone(), + any_model: self.any_model.clone(), entity_type: self.entity_type, } } @@ -306,8 +306,8 @@ impl std::fmt::Debug for Model { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "Handle {{ entity_id: {:?}, entity_type: {:?} }}", - self.any_handle.entity_id, + "Model {{ entity_id: {:?}, entity_type: {:?} }}", + self.any_model.entity_id, type_name::() ) } @@ -315,32 +315,32 @@ impl std::fmt::Debug for Model { impl Hash for Model { fn hash(&self, state: &mut H) { - self.any_handle.hash(state); + self.any_model.hash(state); } } impl PartialEq for Model { fn eq(&self, other: &Self) -> bool { - self.any_handle == other.any_handle + self.any_model == other.any_model } } impl Eq for Model {} -impl PartialEq> for Model { - fn eq(&self, other: &WeakHandle) -> bool { +impl PartialEq> for Model { + fn eq(&self, other: &WeakModel) -> bool { self.entity_id() == other.entity_id() } } #[derive(Clone)] -pub struct AnyWeakHandle { +pub struct AnyWeakModel { pub(crate) entity_id: EntityId, entity_type: TypeId, entity_ref_counts: Weak>, } -impl AnyWeakHandle { +impl AnyWeakModel { pub fn entity_id(&self) -> EntityId { self.entity_id } @@ -354,14 +354,14 @@ impl AnyWeakHandle { ref_count > 0 } - pub fn upgrade(&self) -> Option { + pub fn upgrade(&self) -> Option { let entity_map = self.entity_ref_counts.upgrade()?; entity_map .read() .counts .get(self.entity_id)? 
.fetch_add(1, SeqCst); - Some(AnyHandle { + Some(AnyModel { entity_id: self.entity_id, entity_type: self.entity_type, entity_map: self.entity_ref_counts.clone(), @@ -369,55 +369,55 @@ impl AnyWeakHandle { } } -impl From> for AnyWeakHandle { - fn from(handle: WeakHandle) -> Self { - handle.any_handle +impl From> for AnyWeakModel { + fn from(model: WeakModel) -> Self { + model.any_model } } -impl Hash for AnyWeakHandle { +impl Hash for AnyWeakModel { fn hash(&self, state: &mut H) { self.entity_id.hash(state); } } -impl PartialEq for AnyWeakHandle { +impl PartialEq for AnyWeakModel { fn eq(&self, other: &Self) -> bool { self.entity_id == other.entity_id } } -impl Eq for AnyWeakHandle {} +impl Eq for AnyWeakModel {} #[derive(Deref, DerefMut)] -pub struct WeakHandle { +pub struct WeakModel { #[deref] #[deref_mut] - any_handle: AnyWeakHandle, + any_model: AnyWeakModel, entity_type: PhantomData, } -unsafe impl Send for WeakHandle {} -unsafe impl Sync for WeakHandle {} +unsafe impl Send for WeakModel {} +unsafe impl Sync for WeakModel {} -impl Clone for WeakHandle { +impl Clone for WeakModel { fn clone(&self) -> Self { Self { - any_handle: self.any_handle.clone(), + any_model: self.any_model.clone(), entity_type: self.entity_type, } } } -impl WeakHandle { +impl WeakModel { pub fn upgrade(&self) -> Option> { Some(Model { - any_handle: self.any_handle.upgrade()?, + any_model: self.any_model.upgrade()?, entity_type: self.entity_type, }) } - /// Update the entity referenced by this handle with the given function if + /// Update the entity referenced by this model with the given function if /// the referenced entity still exists. Returns an error if the entity has /// been released. /// @@ -441,21 +441,21 @@ impl WeakHandle { } } -impl Hash for WeakHandle { +impl Hash for WeakModel { fn hash(&self, state: &mut H) { - self.any_handle.hash(state); + self.any_model.hash(state); } } -impl PartialEq for WeakHandle { +impl PartialEq for WeakModel { fn eq(&self, other: &Self) -> bool { - self.any_handle == other.any_handle + self.any_model == other.any_model } } -impl Eq for WeakHandle {} +impl Eq for WeakModel {} -impl PartialEq> for WeakHandle { +impl PartialEq> for WeakModel { fn eq(&self, other: &Model) -> bool { self.entity_id() == other.entity_id() } diff --git a/crates/gpui2/src/app/model_context.rs b/crates/gpui2/src/app/model_context.rs index b5f78fbc468a6365f1f3a6dd8f10dfa3efc0d808..463652886b2767c8c155f068c057afd2befe9851 100644 --- a/crates/gpui2/src/app/model_context.rs +++ b/crates/gpui2/src/app/model_context.rs @@ -1,6 +1,6 @@ use crate::{ AppContext, AsyncAppContext, Context, Effect, EntityId, EventEmitter, MainThread, Model, - Reference, Subscription, Task, WeakHandle, + Reference, Subscription, Task, WeakModel, }; use derive_more::{Deref, DerefMut}; use futures::FutureExt; @@ -15,11 +15,11 @@ pub struct ModelContext<'a, T> { #[deref] #[deref_mut] app: Reference<'a, AppContext>, - model_state: WeakHandle, + model_state: WeakModel, } impl<'a, T: 'static> ModelContext<'a, T> { - pub(crate) fn mutable(app: &'a mut AppContext, model_state: WeakHandle) -> Self { + pub(crate) fn mutable(app: &'a mut AppContext, model_state: WeakModel) -> Self { Self { app: Reference::Mutable(app), model_state, @@ -36,7 +36,7 @@ impl<'a, T: 'static> ModelContext<'a, T> { .expect("The entity must be alive if we have a model context") } - pub fn weak_handle(&self) -> WeakHandle { + pub fn weak_handle(&self) -> WeakModel { self.model_state.clone() } @@ -184,7 +184,7 @@ impl<'a, T: 'static> ModelContext<'a, T> { pub fn 
spawn( &self, - f: impl FnOnce(WeakHandle, AsyncAppContext) -> Fut + Send + 'static, + f: impl FnOnce(WeakModel, AsyncAppContext) -> Fut + Send + 'static, ) -> Task where T: 'static, @@ -197,7 +197,7 @@ impl<'a, T: 'static> ModelContext<'a, T> { pub fn spawn_on_main( &self, - f: impl FnOnce(WeakHandle, MainThread) -> Fut + Send + 'static, + f: impl FnOnce(WeakModel, MainThread) -> Fut + Send + 'static, ) -> Task where Fut: Future + 'static, diff --git a/crates/gpui2/src/interactive.rs b/crates/gpui2/src/interactive.rs index a617792bfb88685351c8d9e4eb17652cc49b6dcd..faa7d239757d43f6774610a4544b81f02f61fb17 100644 --- a/crates/gpui2/src/interactive.rs +++ b/crates/gpui2/src/interactive.rs @@ -333,7 +333,7 @@ pub trait StatefulInteractive: StatelessInteractive { Some(Box::new(move |view_state, cursor_offset, cx| { let drag = listener(view_state, cx); let drag_handle_view = Some( - View::for_handle(cx.handle().upgrade().unwrap(), move |view_state, cx| { + View::for_handle(cx.model().upgrade().unwrap(), move |view_state, cx| { (drag.render_drag_handle)(view_state, cx) }) .into_any(), diff --git a/crates/gpui2/src/view.rs b/crates/gpui2/src/view.rs index c988223fd01ddcd2eed5ad69b5a9c4281fb31519..cacca8b91e2f4f7992bbbcc3dbfd5ea9c31cfbdb 100644 --- a/crates/gpui2/src/view.rs +++ b/crates/gpui2/src/view.rs @@ -1,6 +1,6 @@ use crate::{ AnyBox, AnyElement, AvailableSpace, BorrowWindow, Bounds, Component, Element, ElementId, - EntityId, LayoutId, Model, Pixels, Size, ViewContext, VisualContext, WeakHandle, WindowContext, + EntityId, LayoutId, Model, Pixels, Size, ViewContext, VisualContext, WeakModel, WindowContext, }; use anyhow::{Context, Result}; use parking_lot::Mutex; @@ -116,7 +116,7 @@ impl Element<()> for View { } pub struct WeakView { - pub(crate) state: WeakHandle, + pub(crate) state: WeakModel, render: Weak) -> AnyElement + Send + 'static>>, } diff --git a/crates/gpui2/src/window.rs b/crates/gpui2/src/window.rs index 073ffa56bda384bd7703fa3710e551b022035263..3d6a891dfe21097ab06d0f08666aaf793d012a64 100644 --- a/crates/gpui2/src/window.rs +++ b/crates/gpui2/src/window.rs @@ -7,7 +7,7 @@ use crate::{ MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, Path, Pixels, PlatformAtlas, PlatformWindow, Point, PolychromeSprite, Quad, Reference, RenderGlyphParams, RenderImageParams, RenderSvgParams, ScaledPixels, SceneBuilder, Shadow, SharedString, Size, Style, Subscription, - TaffyLayoutEngine, Task, Underline, UnderlineStyle, View, VisualContext, WeakHandle, WeakView, + TaffyLayoutEngine, Task, Underline, UnderlineStyle, View, VisualContext, WeakModel, WeakView, WindowOptions, SUBPIXEL_VARIANTS, }; use anyhow::Result; @@ -1257,13 +1257,13 @@ impl Context for WindowContext<'_, '_> { fn update_entity( &mut self, - handle: &Model, + model: &Model, update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R, ) -> R { - let mut entity = self.entities.lease(handle); + let mut entity = self.entities.lease(model); let result = update( &mut *entity, - &mut ModelContext::mutable(&mut *self.app, handle.downgrade()), + &mut ModelContext::mutable(&mut *self.app, model.downgrade()), ); self.entities.end_lease(entity); result @@ -1555,7 +1555,7 @@ impl<'a, 'w, V: 'static> ViewContext<'a, 'w, V> { self.view.clone() } - pub fn handle(&self) -> WeakHandle { + pub fn model(&self) -> WeakModel { self.view.state.clone() } @@ -1872,10 +1872,10 @@ impl<'a, 'w, V> Context for ViewContext<'a, 'w, V> { fn update_entity( &mut self, - handle: &Model, + model: &Model, update: impl FnOnce(&mut T, &mut 
Self::ModelContext<'_, T>) -> R, ) -> R { - self.window_cx.update_entity(handle, update) + self.window_cx.update_entity(model, update) } } diff --git a/crates/project2/src/project2.rs b/crates/project2/src/project2.rs index e91c3b72631994c07205920eb4bfa2d5f3f61308..ce97b9cc22d386cbc995d9fb8f7330a096a69eb1 100644 --- a/crates/project2/src/project2.rs +++ b/crates/project2/src/project2.rs @@ -26,8 +26,8 @@ use futures::{ }; use globset::{Glob, GlobSet, GlobSetBuilder}; use gpui2::{ - AnyHandle, AppContext, AsyncAppContext, Context, EventEmitter, Executor, Model, ModelContext, - Task, WeakHandle, + AnyModel, AppContext, AsyncAppContext, Context, EventEmitter, Executor, Model, ModelContext, + Task, WeakModel, }; use itertools::Itertools; use language2::{ @@ -153,7 +153,7 @@ pub struct Project { incomplete_remote_buffers: HashMap>>, buffer_snapshots: HashMap>>, // buffer_id -> server_id -> vec of snapshots buffers_being_formatted: HashSet, - buffers_needing_diff: HashSet>, + buffers_needing_diff: HashSet>, git_diff_debouncer: DelayedDebounced, nonce: u128, _maintain_buffer_languages: Task<()>, @@ -245,14 +245,14 @@ enum LocalProjectUpdate { enum OpenBuffer { Strong(Model), - Weak(WeakHandle), + Weak(WeakModel), Operations(Vec), } #[derive(Clone)] enum WorktreeHandle { Strong(Model), - Weak(WeakHandle), + Weak(WeakModel), } enum ProjectClientState { @@ -1671,7 +1671,7 @@ impl Project { &mut self, path: impl Into, cx: &mut ModelContext, - ) -> Task> { + ) -> Task> { let task = self.open_buffer(path, cx); cx.spawn(move |_, mut cx| async move { let buffer = task.await?; @@ -1681,7 +1681,7 @@ impl Project { })? .ok_or_else(|| anyhow!("no project entry"))?; - let buffer: &AnyHandle = &buffer; + let buffer: &AnyModel = &buffer; Ok((project_entry_id, buffer.clone())) }) } @@ -2158,7 +2158,7 @@ impl Project { } async fn send_buffer_ordered_messages( - this: WeakHandle, + this: WeakModel, rx: UnboundedReceiver, mut cx: AsyncAppContext, ) -> Result<()> { @@ -2166,7 +2166,7 @@ impl Project { let mut operations_by_buffer_id = HashMap::default(); async fn flush_operations( - this: &WeakHandle, + this: &WeakModel, operations_by_buffer_id: &mut HashMap>, needs_resync_with_host: &mut bool, is_local: bool, @@ -2931,7 +2931,7 @@ impl Project { } async fn setup_and_insert_language_server( - this: WeakHandle, + this: WeakModel, initialization_options: Option, pending_server: PendingLanguageServer, adapter: Arc, @@ -2970,7 +2970,7 @@ impl Project { } async fn setup_pending_language_server( - this: WeakHandle, + this: WeakModel, initialization_options: Option, pending_server: PendingLanguageServer, adapter: Arc, @@ -3748,7 +3748,7 @@ impl Project { } async fn on_lsp_workspace_edit( - this: WeakHandle, + this: WeakModel, params: lsp2::ApplyWorkspaceEditParams, server_id: LanguageServerId, adapter: Arc, @@ -4360,7 +4360,7 @@ impl Project { } async fn format_via_lsp( - this: &WeakHandle, + this: &WeakModel, buffer: &Model, abs_path: &Path, language_server: &Arc, diff --git a/crates/project2/src/terminals.rs b/crates/project2/src/terminals.rs index 239cb99d86134ad581bb48eac43cac70ebd0c456..5cd62d5ae61cf145c13fa7b6c6cf0f446bb96ff8 100644 --- a/crates/project2/src/terminals.rs +++ b/crates/project2/src/terminals.rs @@ -1,5 +1,5 @@ use crate::Project; -use gpui2::{AnyWindowHandle, Context, Model, ModelContext, WeakHandle}; +use gpui2::{AnyWindowHandle, Context, Model, ModelContext, WeakModel}; use settings2::Settings; use std::path::{Path, PathBuf}; use terminal2::{ @@ -11,7 +11,7 @@ use terminal2::{ use 
std::os::unix::ffi::OsStrExt; pub struct Terminals { - pub(crate) local_handles: Vec>, + pub(crate) local_handles: Vec>, } impl Project { @@ -121,7 +121,7 @@ impl Project { } } - pub fn local_terminal_handles(&self) -> &Vec> { + pub fn local_terminal_handles(&self) -> &Vec> { &self.terminals.local_handles } } From c17b246bac4e57be486f008f712bd63a191c91cb Mon Sep 17 00:00:00 2001 From: KCaverly Date: Mon, 30 Oct 2023 15:04:16 -0400 Subject: [PATCH 25/25] updated for model handle rename --- crates/ai2/src/prompts/repository_context.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/ai2/src/prompts/repository_context.rs b/crates/ai2/src/prompts/repository_context.rs index 78db5a16516f9797927f10ff136905054f36fa65..1bb75de7d242d34315238f2c43baeb0c016dbdfb 100644 --- a/crates/ai2/src/prompts/repository_context.rs +++ b/crates/ai2/src/prompts/repository_context.rs @@ -2,7 +2,7 @@ use crate::prompts::base::{PromptArguments, PromptTemplate}; use std::fmt::Write; use std::{ops::Range, path::PathBuf}; -use gpui2::{AsyncAppContext, Handle}; +use gpui2::{AsyncAppContext, Model}; use language2::{Anchor, Buffer}; #[derive(Clone)] @@ -14,7 +14,7 @@ pub struct PromptCodeSnippet { impl PromptCodeSnippet { pub fn new( - buffer: Handle, + buffer: Model, range: Range, cx: &mut AsyncAppContext, ) -> anyhow::Result {
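The series ends here with the prompt crates ported to the renamed Model API. As a closing illustration, here is a minimal sketch of the truncation surface the patches converge on, exercised through the FakeLanguageModel defined in crates/ai2/src/test.rs above. It assumes ai2 exposes its test module to consumers; the main function is illustrative only.

    use ai2::models::{LanguageModel, TruncationDirection};
    use ai2::test::FakeLanguageModel;

    fn main() -> anyhow::Result<()> {
        // The fake model counts one token per character, so budgets are easy to verify.
        let model = FakeLanguageModel { capacity: 8 };
        let content = "0123456789";

        // TruncationDirection::End keeps the first `length` tokens...
        assert_eq!(model.truncate(content, 4, TruncationDirection::End)?, "0123");

        // ...while TruncationDirection::Start drops them and keeps the tail.
        assert_eq!(model.truncate(content, 4, TruncationDirection::Start)?, "456789");

        Ok(())
    }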