diff --git a/Cargo.lock b/Cargo.lock index 0caaaeceeffdc78ce96d165721ee8eb4e9bae081..a5d187d08e5c20a77132c6e19024925520120af9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -108,6 +108,33 @@ dependencies = [ "util", ] +[[package]] +name = "ai2" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-trait", + "bincode", + "futures 0.3.28", + "gpui2", + "isahc", + "language2", + "lazy_static", + "log", + "matrixmultiply", + "ordered-float 2.10.0", + "parking_lot 0.11.2", + "parse_duration", + "postage", + "rand 0.8.5", + "regex", + "rusqlite", + "serde", + "serde_json", + "tiktoken-rs", + "util", +] + [[package]] name = "alacritty_config" version = "0.1.2-dev" @@ -10903,6 +10930,7 @@ dependencies = [ name = "zed2" version = "0.109.0" dependencies = [ + "ai2", "anyhow", "async-compression", "async-recursion 0.3.2", diff --git a/crates/Cargo.toml b/crates/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..fb49a4b515540836a757610db5c268321f9f068b --- /dev/null +++ b/crates/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "ai" +version = "0.1.0" +edition = "2021" +publish = false + +[lib] +path = "src/ai.rs" +doctest = false + +[features] +test-support = [] + +[dependencies] +gpui = { path = "../gpui" } +util = { path = "../util" } +language = { path = "../language" } +async-trait.workspace = true +anyhow.workspace = true +futures.workspace = true +lazy_static.workspace = true +ordered-float.workspace = true +parking_lot.workspace = true +isahc.workspace = true +regex.workspace = true +serde.workspace = true +serde_json.workspace = true +postage.workspace = true +rand.workspace = true +log.workspace = true +parse_duration = "2.1.1" +tiktoken-rs = "0.5.0" +matrixmultiply = "0.3.7" +rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] } +bincode = "1.3.3" + +[dev-dependencies] +gpui = { path = "../gpui", features = ["test-support"] } diff --git a/crates/ai/Cargo.toml b/crates/ai/Cargo.toml index b24c4e5ece5b02eac003a6c18f186faa1eaef7ef..fb49a4b515540836a757610db5c268321f9f068b 100644 --- a/crates/ai/Cargo.toml +++ b/crates/ai/Cargo.toml @@ -8,6 +8,9 @@ publish = false path = "src/ai.rs" doctest = false +[features] +test-support = [] + [dependencies] gpui = { path = "../gpui" } util = { path = "../util" } diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs index f168c157934f6b70be775f7e17e9ba27ef9b3103..dda22d2a1d04dd6083fb1ae9879f49e74c8b4627 100644 --- a/crates/ai/src/ai.rs +++ b/crates/ai/src/ai.rs @@ -1,4 +1,8 @@ +pub mod auth; pub mod completion; pub mod embedding; pub mod models; -pub mod templates; +pub mod prompts; +pub mod providers; +#[cfg(any(test, feature = "test-support"))] +pub mod test; diff --git a/crates/ai/src/auth.rs b/crates/ai/src/auth.rs new file mode 100644 index 0000000000000000000000000000000000000000..c6256df2160abc7284c070ff204ee849911e4cf3 --- /dev/null +++ b/crates/ai/src/auth.rs @@ -0,0 +1,15 @@ +use gpui::AppContext; + +#[derive(Clone, Debug)] +pub enum ProviderCredential { + Credentials { api_key: String }, + NoCredentials, + NotNeeded, +} + +pub trait CredentialProvider: Send + Sync { + fn has_credentials(&self) -> bool; + fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential; + fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential); + fn delete_credentials(&self, cx: &AppContext); +} diff --git a/crates/ai/src/completion.rs b/crates/ai/src/completion.rs index de6ce9da711ee17f9fc072276a499d1769b874ce..30a60fcf1d5c5dc66717773968e432e510d6421f 100644 --- 
a/crates/ai/src/completion.rs +++ b/crates/ai/src/completion.rs @@ -1,214 +1,23 @@ -use anyhow::{anyhow, Result}; -use futures::{ - future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt, - Stream, StreamExt, -}; -use gpui::executor::Background; -use isahc::{http::StatusCode, Request, RequestExt}; -use serde::{Deserialize, Serialize}; -use std::{ - fmt::{self, Display}, - io, - sync::Arc, -}; +use anyhow::Result; +use futures::{future::BoxFuture, stream::BoxStream}; -pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; +use crate::{auth::CredentialProvider, models::LanguageModel}; -#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] -#[serde(rename_all = "lowercase")] -pub enum Role { - User, - Assistant, - System, +pub trait CompletionRequest: Send + Sync { + fn data(&self) -> serde_json::Result; } -impl Role { - pub fn cycle(&mut self) { - *self = match self { - Role::User => Role::Assistant, - Role::Assistant => Role::System, - Role::System => Role::User, - } - } -} - -impl Display for Role { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Role::User => write!(f, "User"), - Role::Assistant => write!(f, "Assistant"), - Role::System => write!(f, "System"), - } - } -} - -#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] -pub struct RequestMessage { - pub role: Role, - pub content: String, -} - -#[derive(Debug, Default, Serialize)] -pub struct OpenAIRequest { - pub model: String, - pub messages: Vec, - pub stream: bool, - pub stop: Vec, - pub temperature: f32, -} - -#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] -pub struct ResponseMessage { - pub role: Option, - pub content: Option, -} - -#[derive(Deserialize, Debug)] -pub struct OpenAIUsage { - pub prompt_tokens: u32, - pub completion_tokens: u32, - pub total_tokens: u32, -} - -#[derive(Deserialize, Debug)] -pub struct ChatChoiceDelta { - pub index: u32, - pub delta: ResponseMessage, - pub finish_reason: Option, -} - -#[derive(Deserialize, Debug)] -pub struct OpenAIResponseStreamEvent { - pub id: Option, - pub object: String, - pub created: u32, - pub model: String, - pub choices: Vec, - pub usage: Option, -} - -pub async fn stream_completion( - api_key: String, - executor: Arc, - mut request: OpenAIRequest, -) -> Result>> { - request.stream = true; - - let (tx, rx) = futures::channel::mpsc::unbounded::>(); - - let json_data = serde_json::to_string(&request)?; - let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions")) - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", api_key)) - .body(json_data)? 
- .send_async() - .await?; - - let status = response.status(); - if status == StatusCode::OK { - executor - .spawn(async move { - let mut lines = BufReader::new(response.body_mut()).lines(); - - fn parse_line( - line: Result, - ) -> Result> { - if let Some(data) = line?.strip_prefix("data: ") { - let event = serde_json::from_str(&data)?; - Ok(Some(event)) - } else { - Ok(None) - } - } - - while let Some(line) = lines.next().await { - if let Some(event) = parse_line(line).transpose() { - let done = event.as_ref().map_or(false, |event| { - event - .choices - .last() - .map_or(false, |choice| choice.finish_reason.is_some()) - }); - if tx.unbounded_send(event).is_err() { - break; - } - - if done { - break; - } - } - } - - anyhow::Ok(()) - }) - .detach(); - - Ok(rx) - } else { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - - #[derive(Deserialize)] - struct OpenAIResponse { - error: OpenAIError, - } - - #[derive(Deserialize)] - struct OpenAIError { - message: String, - } - - match serde_json::from_str::(&body) { - Ok(response) if !response.error.message.is_empty() => Err(anyhow!( - "Failed to connect to OpenAI API: {}", - response.error.message, - )), - - _ => Err(anyhow!( - "Failed to connect to OpenAI API: {} {}", - response.status(), - body, - )), - } - } -} - -pub trait CompletionProvider { +pub trait CompletionProvider: CredentialProvider { + fn base_model(&self) -> Box; fn complete( &self, - prompt: OpenAIRequest, + prompt: Box, ) -> BoxFuture<'static, Result>>>; + fn box_clone(&self) -> Box; } -pub struct OpenAICompletionProvider { - api_key: String, - executor: Arc, -} - -impl OpenAICompletionProvider { - pub fn new(api_key: String, executor: Arc) -> Self { - Self { api_key, executor } - } -} - -impl CompletionProvider for OpenAICompletionProvider { - fn complete( - &self, - prompt: OpenAIRequest, - ) -> BoxFuture<'static, Result>>> { - let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt); - async move { - let response = request.await?; - let stream = response - .filter_map(|response| async move { - match response { - Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)), - Err(error) => Some(Err(error)), - } - }) - .boxed(); - Ok(stream) - } - .boxed() +impl Clone for Box { + fn clone(&self) -> Box { + self.box_clone() } } diff --git a/crates/ai/src/embedding.rs b/crates/ai/src/embedding.rs index b791414ba274d3529e62868a5747ff7bcb946b31..6768b7ce7bab7898ee4222d5217a4de120ca57ed 100644 --- a/crates/ai/src/embedding.rs +++ b/crates/ai/src/embedding.rs @@ -1,32 +1,13 @@ -use anyhow::{anyhow, Result}; +use std::time::Instant; + +use anyhow::Result; use async_trait::async_trait; -use futures::AsyncReadExt; -use gpui::executor::Background; -use gpui::{serde_json, AppContext}; -use isahc::http::StatusCode; -use isahc::prelude::Configurable; -use isahc::{AsyncBody, Response}; -use lazy_static::lazy_static; use ordered_float::OrderedFloat; -use parking_lot::Mutex; -use parse_duration::parse; -use postage::watch; use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef}; use rusqlite::ToSql; -use serde::{Deserialize, Serialize}; -use std::env; -use std::ops::Add; -use std::sync::Arc; -use std::time::{Duration, Instant}; -use tiktoken_rs::{cl100k_base, CoreBPE}; -use util::http::{HttpClient, Request}; -use util::ResultExt; - -use crate::completion::OPENAI_API_URL; -lazy_static! 
{ - static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap(); -} +use crate::auth::CredentialProvider; +use crate::models::LanguageModel; #[derive(Debug, PartialEq, Clone)] pub struct Embedding(pub Vec); @@ -87,301 +68,14 @@ impl Embedding { } } -#[derive(Clone)] -pub struct OpenAIEmbeddings { - pub client: Arc, - pub executor: Arc, - rate_limit_count_rx: watch::Receiver>, - rate_limit_count_tx: Arc>>>, -} - -#[derive(Serialize)] -struct OpenAIEmbeddingRequest<'a> { - model: &'static str, - input: Vec<&'a str>, -} - -#[derive(Deserialize)] -struct OpenAIEmbeddingResponse { - data: Vec, - usage: OpenAIEmbeddingUsage, -} - -#[derive(Debug, Deserialize)] -struct OpenAIEmbedding { - embedding: Vec, - index: usize, - object: String, -} - -#[derive(Deserialize)] -struct OpenAIEmbeddingUsage { - prompt_tokens: usize, - total_tokens: usize, -} - #[async_trait] -pub trait EmbeddingProvider: Sync + Send { - fn retrieve_credentials(&self, cx: &AppContext) -> Option; - async fn embed_batch( - &self, - spans: Vec, - api_key: Option, - ) -> Result>; +pub trait EmbeddingProvider: CredentialProvider { + fn base_model(&self) -> Box; + async fn embed_batch(&self, spans: Vec) -> Result>; fn max_tokens_per_batch(&self) -> usize; - fn truncate(&self, span: &str) -> (String, usize); fn rate_limit_expiration(&self) -> Option; } -pub struct DummyEmbeddings {} - -#[async_trait] -impl EmbeddingProvider for DummyEmbeddings { - fn retrieve_credentials(&self, _cx: &AppContext) -> Option { - Some("Dummy API KEY".to_string()) - } - fn rate_limit_expiration(&self) -> Option { - None - } - async fn embed_batch( - &self, - spans: Vec, - _api_key: Option, - ) -> Result> { - // 1024 is the OpenAI Embeddings size for ada models. - // the model we will likely be starting with. 
- let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]); - return Ok(vec![dummy_vec; spans.len()]); - } - - fn max_tokens_per_batch(&self) -> usize { - OPENAI_INPUT_LIMIT - } - - fn truncate(&self, span: &str) -> (String, usize) { - let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); - let token_count = tokens.len(); - let output = if token_count > OPENAI_INPUT_LIMIT { - tokens.truncate(OPENAI_INPUT_LIMIT); - let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone()); - new_input.ok().unwrap_or_else(|| span.to_string()) - } else { - span.to_string() - }; - - (output, tokens.len()) - } -} - -const OPENAI_INPUT_LIMIT: usize = 8190; - -impl OpenAIEmbeddings { - pub fn new(client: Arc, executor: Arc) -> Self { - let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None); - let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx)); - - OpenAIEmbeddings { - client, - executor, - rate_limit_count_rx, - rate_limit_count_tx, - } - } - - fn resolve_rate_limit(&self) { - let reset_time = *self.rate_limit_count_tx.lock().borrow(); - - if let Some(reset_time) = reset_time { - if Instant::now() >= reset_time { - *self.rate_limit_count_tx.lock().borrow_mut() = None - } - } - - log::trace!( - "resolving reset time: {:?}", - *self.rate_limit_count_tx.lock().borrow() - ); - } - - fn update_reset_time(&self, reset_time: Instant) { - let original_time = *self.rate_limit_count_tx.lock().borrow(); - - let updated_time = if let Some(original_time) = original_time { - if reset_time < original_time { - Some(reset_time) - } else { - Some(original_time) - } - } else { - Some(reset_time) - }; - - log::trace!("updating rate limit time: {:?}", updated_time); - - *self.rate_limit_count_tx.lock().borrow_mut() = updated_time; - } - async fn send_request( - &self, - api_key: &str, - spans: Vec<&str>, - request_timeout: u64, - ) -> Result> { - let request = Request::post("https://api.openai.com/v1/embeddings") - .redirect_policy(isahc::config::RedirectPolicy::Follow) - .timeout(Duration::from_secs(request_timeout)) - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", api_key)) - .body( - serde_json::to_string(&OpenAIEmbeddingRequest { - input: spans.clone(), - model: "text-embedding-ada-002", - }) - .unwrap() - .into(), - )?; - - Ok(self.client.send(request).await?) 
- } -} - -#[async_trait] -impl EmbeddingProvider for OpenAIEmbeddings { - fn retrieve_credentials(&self, cx: &AppContext) -> Option { - if let Ok(api_key) = env::var("OPENAI_API_KEY") { - Some(api_key) - } else if let Some((_, api_key)) = cx - .platform() - .read_credentials(OPENAI_API_URL) - .log_err() - .flatten() - { - String::from_utf8(api_key).log_err() - } else { - None - } - } - - fn max_tokens_per_batch(&self) -> usize { - 50000 - } - - fn rate_limit_expiration(&self) -> Option { - *self.rate_limit_count_rx.borrow() - } - fn truncate(&self, span: &str) -> (String, usize) { - let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span); - let output = if tokens.len() > OPENAI_INPUT_LIMIT { - tokens.truncate(OPENAI_INPUT_LIMIT); - OPENAI_BPE_TOKENIZER - .decode(tokens.clone()) - .ok() - .unwrap_or_else(|| span.to_string()) - } else { - span.to_string() - }; - - (output, tokens.len()) - } - - async fn embed_batch( - &self, - spans: Vec, - api_key: Option, - ) -> Result> { - const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; - const MAX_RETRIES: usize = 4; - - let Some(api_key) = api_key else { - return Err(anyhow!("no open ai key provided")); - }; - - let mut request_number = 0; - let mut rate_limiting = false; - let mut request_timeout: u64 = 15; - let mut response: Response; - while request_number < MAX_RETRIES { - response = self - .send_request( - &api_key, - spans.iter().map(|x| &**x).collect(), - request_timeout, - ) - .await?; - - request_number += 1; - - match response.status() { - StatusCode::REQUEST_TIMEOUT => { - request_timeout += 5; - } - StatusCode::OK => { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?; - - log::trace!( - "openai embedding completed. 
tokens: {:?}", - response.usage.total_tokens - ); - - // If we complete a request successfully that was previously rate_limited - // resolve the rate limit - if rate_limiting { - self.resolve_rate_limit() - } - - return Ok(response - .data - .into_iter() - .map(|embedding| Embedding::from(embedding.embedding)) - .collect()); - } - StatusCode::TOO_MANY_REQUESTS => { - rate_limiting = true; - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - - let delay_duration = { - let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64); - if let Some(time_to_reset) = - response.headers().get("x-ratelimit-reset-tokens") - { - if let Ok(time_str) = time_to_reset.to_str() { - parse(time_str).unwrap_or(delay) - } else { - delay - } - } else { - delay - } - }; - - // If we've previously rate limited, increment the duration but not the count - let reset_time = Instant::now().add(delay_duration); - self.update_reset_time(reset_time); - - log::trace!( - "openai rate limiting: waiting {:?} until lifted", - &delay_duration - ); - - self.executor.timer(delay_duration).await; - } - _ => { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - return Err(anyhow!( - "open ai bad request: {:?} {:?}", - &response.status(), - body - )); - } - } - } - Err(anyhow!("openai max retries")) - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/crates/ai/src/models.rs b/crates/ai/src/models.rs index d0206cc41c526f171fef8521a120f8f4ff70aa74..1db3d58c6f54ad613cb98fc3f425df3d47e5e97f 100644 --- a/crates/ai/src/models.rs +++ b/crates/ai/src/models.rs @@ -1,66 +1,16 @@ -use anyhow::anyhow; -use tiktoken_rs::CoreBPE; -use util::ResultExt; +pub enum TruncationDirection { + Start, + End, +} pub trait LanguageModel { fn name(&self) -> String; fn count_tokens(&self, content: &str) -> anyhow::Result; - fn truncate(&self, content: &str, length: usize) -> anyhow::Result; - fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result; + fn truncate( + &self, + content: &str, + length: usize, + direction: TruncationDirection, + ) -> anyhow::Result; fn capacity(&self) -> anyhow::Result; } - -pub struct OpenAILanguageModel { - name: String, - bpe: Option, -} - -impl OpenAILanguageModel { - pub fn load(model_name: &str) -> Self { - let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err(); - OpenAILanguageModel { - name: model_name.to_string(), - bpe, - } - } -} - -impl LanguageModel for OpenAILanguageModel { - fn name(&self) -> String { - self.name.clone() - } - fn count_tokens(&self, content: &str) -> anyhow::Result { - if let Some(bpe) = &self.bpe { - anyhow::Ok(bpe.encode_with_special_tokens(content).len()) - } else { - Err(anyhow!("bpe for open ai model was not retrieved")) - } - } - fn truncate(&self, content: &str, length: usize) -> anyhow::Result { - if let Some(bpe) = &self.bpe { - let tokens = bpe.encode_with_special_tokens(content); - if tokens.len() > length { - bpe.decode(tokens[..length].to_vec()) - } else { - bpe.decode(tokens) - } - } else { - Err(anyhow!("bpe for open ai model was not retrieved")) - } - } - fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result { - if let Some(bpe) = &self.bpe { - let tokens = bpe.encode_with_special_tokens(content); - if tokens.len() > length { - bpe.decode(tokens[length..].to_vec()) - } else { - bpe.decode(tokens) - } - } else { - Err(anyhow!("bpe for open ai model was not retrieved")) - } - } - fn capacity(&self) -> anyhow::Result { - 
anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name)) - } -} diff --git a/crates/ai/src/templates/base.rs b/crates/ai/src/prompts/base.rs similarity index 85% rename from crates/ai/src/templates/base.rs rename to crates/ai/src/prompts/base.rs index bda1d6c30e61a9e2fd3808fa45a34cbe041cf2b6..75bad00154b001a356f35cb5d80e4ac4962fe0f9 100644 --- a/crates/ai/src/templates/base.rs +++ b/crates/ai/src/prompts/base.rs @@ -6,7 +6,7 @@ use language::BufferSnapshot; use util::ResultExt; use crate::models::LanguageModel; -use crate::templates::repository_context::PromptCodeSnippet; +use crate::prompts::repository_context::PromptCodeSnippet; pub(crate) enum PromptFileType { Text, @@ -125,6 +125,9 @@ impl PromptChain { #[cfg(test)] pub(crate) mod tests { + use crate::models::TruncationDirection; + use crate::test::FakeLanguageModel; + use super::*; #[test] @@ -141,7 +144,11 @@ pub(crate) mod tests { let mut token_count = args.model.count_tokens(&content)?; if let Some(max_token_length) = max_token_length { if token_count > max_token_length { - content = args.model.truncate(&content, max_token_length)?; + content = args.model.truncate( + &content, + max_token_length, + TruncationDirection::End, + )?; token_count = max_token_length; } } @@ -162,7 +169,11 @@ pub(crate) mod tests { let mut token_count = args.model.count_tokens(&content)?; if let Some(max_token_length) = max_token_length { if token_count > max_token_length { - content = args.model.truncate(&content, max_token_length)?; + content = args.model.truncate( + &content, + max_token_length, + TruncationDirection::End, + )?; token_count = max_token_length; } } @@ -171,38 +182,7 @@ pub(crate) mod tests { } } - #[derive(Clone)] - struct DummyLanguageModel { - capacity: usize, - } - - impl LanguageModel for DummyLanguageModel { - fn name(&self) -> String { - "dummy".to_string() - } - fn count_tokens(&self, content: &str) -> anyhow::Result { - anyhow::Ok(content.chars().collect::>().len()) - } - fn truncate(&self, content: &str, length: usize) -> anyhow::Result { - anyhow::Ok( - content.chars().collect::>()[..length] - .into_iter() - .collect::(), - ) - } - fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result { - anyhow::Ok( - content.chars().collect::>()[length..] 
- .into_iter() - .collect::(), - ) - } - fn capacity(&self) -> anyhow::Result { - anyhow::Ok(self.capacity) - } - } - - let model: Arc = Arc::new(DummyLanguageModel { capacity: 100 }); + let model: Arc = Arc::new(FakeLanguageModel { capacity: 100 }); let args = PromptArguments { model: model.clone(), language_name: None, @@ -238,7 +218,7 @@ pub(crate) mod tests { // Testing with Truncation Off // Should ignore capacity and return all prompts - let model: Arc = Arc::new(DummyLanguageModel { capacity: 20 }); + let model: Arc = Arc::new(FakeLanguageModel { capacity: 20 }); let args = PromptArguments { model: model.clone(), language_name: None, @@ -275,7 +255,7 @@ pub(crate) mod tests { // Testing with Truncation Off // Should ignore capacity and return all prompts let capacity = 20; - let model: Arc = Arc::new(DummyLanguageModel { capacity }); + let model: Arc = Arc::new(FakeLanguageModel { capacity }); let args = PromptArguments { model: model.clone(), language_name: None, @@ -311,7 +291,7 @@ pub(crate) mod tests { // Change Ordering of Prompts Based on Priority let capacity = 120; let reserved_tokens = 10; - let model: Arc = Arc::new(DummyLanguageModel { capacity }); + let model: Arc = Arc::new(FakeLanguageModel { capacity }); let args = PromptArguments { model: model.clone(), language_name: None, diff --git a/crates/ai/src/templates/file_context.rs b/crates/ai/src/prompts/file_context.rs similarity index 91% rename from crates/ai/src/templates/file_context.rs rename to crates/ai/src/prompts/file_context.rs index 1afd61192edc02b153abe8cd00836d67caa42f02..f108a62f6f0f82e1d25f01cb2e2ae2a755d69fda 100644 --- a/crates/ai/src/templates/file_context.rs +++ b/crates/ai/src/prompts/file_context.rs @@ -3,8 +3,9 @@ use language::BufferSnapshot; use language::ToOffset; use crate::models::LanguageModel; -use crate::templates::base::PromptArguments; -use crate::templates::base::PromptTemplate; +use crate::models::TruncationDirection; +use crate::prompts::base::PromptArguments; +use crate::prompts::base::PromptTemplate; use std::fmt::Write; use std::ops::Range; use std::sync::Arc; @@ -70,8 +71,9 @@ fn retrieve_context( }; let truncated_start_window = - model.truncate_start(&start_window, start_goal_tokens)?; - let truncated_end_window = model.truncate(&end_window, end_goal_tokens)?; + model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?; + let truncated_end_window = + model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?; writeln!( prompt, "{truncated_start_window}{selected_window}{truncated_end_window}" @@ -89,7 +91,7 @@ fn retrieve_context( if let Some(max_token_count) = max_token_count { if model.count_tokens(&prompt)? 
> max_token_count { truncated = true; - prompt = model.truncate(&prompt, max_token_count)?; + prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?; } } } @@ -148,7 +150,9 @@ impl PromptTemplate for FileContext { // Really dumb truncation strategy if let Some(max_tokens) = max_token_length { - prompt = args.model.truncate(&prompt, max_tokens)?; + prompt = args + .model + .truncate(&prompt, max_tokens, TruncationDirection::End)?; } let token_count = args.model.count_tokens(&prompt)?; diff --git a/crates/ai/src/templates/generate.rs b/crates/ai/src/prompts/generate.rs similarity index 92% rename from crates/ai/src/templates/generate.rs rename to crates/ai/src/prompts/generate.rs index 1eeb197f932db0dc13963982e7e8bc983c338db7..c7be620107ee4d6daca06a8cb38019aceedc40a4 100644 --- a/crates/ai/src/templates/generate.rs +++ b/crates/ai/src/prompts/generate.rs @@ -1,4 +1,4 @@ -use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate}; +use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate}; use anyhow::anyhow; use std::fmt::Write; @@ -85,7 +85,11 @@ impl PromptTemplate for GenerateInlineContent { // Really dumb truncation strategy if let Some(max_tokens) = max_token_length { - prompt = args.model.truncate(&prompt, max_tokens)?; + prompt = args.model.truncate( + &prompt, + max_tokens, + crate::models::TruncationDirection::End, + )?; } let token_count = args.model.count_tokens(&prompt)?; diff --git a/crates/ai/src/templates/mod.rs b/crates/ai/src/prompts/mod.rs similarity index 100% rename from crates/ai/src/templates/mod.rs rename to crates/ai/src/prompts/mod.rs diff --git a/crates/ai/src/templates/preamble.rs b/crates/ai/src/prompts/preamble.rs similarity index 95% rename from crates/ai/src/templates/preamble.rs rename to crates/ai/src/prompts/preamble.rs index 9eabaaeb97fe4216c6bac44cf4eabfb7c129ecf2..92e0edeb78b48169379aae2e88e81f62463a1057 100644 --- a/crates/ai/src/templates/preamble.rs +++ b/crates/ai/src/prompts/preamble.rs @@ -1,4 +1,4 @@ -use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate}; +use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate}; use std::fmt::Write; pub struct EngineerPreamble {} diff --git a/crates/ai/src/templates/repository_context.rs b/crates/ai/src/prompts/repository_context.rs similarity index 98% rename from crates/ai/src/templates/repository_context.rs rename to crates/ai/src/prompts/repository_context.rs index a8e7f4b5af7bee4d3f29d70c665965dc7e05ed4b..c21b0f995c19361ed89243bf49696f3b0ba3c865 100644 --- a/crates/ai/src/templates/repository_context.rs +++ b/crates/ai/src/prompts/repository_context.rs @@ -1,4 +1,4 @@ -use crate::templates::base::{PromptArguments, PromptTemplate}; +use crate::prompts::base::{PromptArguments, PromptTemplate}; use std::fmt::Write; use std::{ops::Range, path::PathBuf}; diff --git a/crates/ai/src/providers/mod.rs b/crates/ai/src/providers/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..acd0f9d91053869e3999ef0c1a23326480a7cbdd --- /dev/null +++ b/crates/ai/src/providers/mod.rs @@ -0,0 +1 @@ +pub mod open_ai; diff --git a/crates/ai/src/providers/open_ai/completion.rs b/crates/ai/src/providers/open_ai/completion.rs new file mode 100644 index 0000000000000000000000000000000000000000..94685fd233520fc919d2708a20c202828250a481 --- /dev/null +++ b/crates/ai/src/providers/open_ai/completion.rs @@ -0,0 +1,298 @@ +use anyhow::{anyhow, Result}; +use futures::{ + future::BoxFuture, io::BufReader, stream::BoxStream, 
AsyncBufReadExt, AsyncReadExt, FutureExt, + Stream, StreamExt, +}; +use gpui::{executor::Background, AppContext}; +use isahc::{http::StatusCode, Request, RequestExt}; +use parking_lot::RwLock; +use serde::{Deserialize, Serialize}; +use std::{ + env, + fmt::{self, Display}, + io, + sync::Arc, +}; +use util::ResultExt; + +use crate::{ + auth::{CredentialProvider, ProviderCredential}, + completion::{CompletionProvider, CompletionRequest}, + models::LanguageModel, +}; + +use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL}; + +#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum Role { + User, + Assistant, + System, +} + +impl Role { + pub fn cycle(&mut self) { + *self = match self { + Role::User => Role::Assistant, + Role::Assistant => Role::System, + Role::System => Role::User, + } + } +} + +impl Display for Role { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Role::User => write!(f, "User"), + Role::Assistant => write!(f, "Assistant"), + Role::System => write!(f, "System"), + } + } +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct RequestMessage { + pub role: Role, + pub content: String, +} + +#[derive(Debug, Default, Serialize)] +pub struct OpenAIRequest { + pub model: String, + pub messages: Vec, + pub stream: bool, + pub stop: Vec, + pub temperature: f32, +} + +impl CompletionRequest for OpenAIRequest { + fn data(&self) -> serde_json::Result { + serde_json::to_string(self) + } +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct ResponseMessage { + pub role: Option, + pub content: Option, +} + +#[derive(Deserialize, Debug)] +pub struct OpenAIUsage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, +} + +#[derive(Deserialize, Debug)] +pub struct ChatChoiceDelta { + pub index: u32, + pub delta: ResponseMessage, + pub finish_reason: Option, +} + +#[derive(Deserialize, Debug)] +pub struct OpenAIResponseStreamEvent { + pub id: Option, + pub object: String, + pub created: u32, + pub model: String, + pub choices: Vec, + pub usage: Option, +} + +pub async fn stream_completion( + credential: ProviderCredential, + executor: Arc, + request: Box, +) -> Result>> { + let api_key = match credential { + ProviderCredential::Credentials { api_key } => api_key, + _ => { + return Err(anyhow!("no credentials provider for completion")); + } + }; + + let (tx, rx) = futures::channel::mpsc::unbounded::>(); + + let json_data = request.data()?; + let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions")) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key)) + .body(json_data)? 
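+        // Note: unlike the removed `stream_completion`, this version no longer forces
+        // `stream = true` on the request; as far as this diff shows, the caller is now
+        // expected to set it on the `OpenAIRequest` itself.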
+        .send_async()
+        .await?;
+
+    let status = response.status();
+    if status == StatusCode::OK {
+        executor
+            .spawn(async move {
+                let mut lines = BufReader::new(response.body_mut()).lines();
+
+                fn parse_line(
+                    line: Result<String, io::Error>,
+                ) -> Result<Option<OpenAIResponseStreamEvent>> {
+                    if let Some(data) = line?.strip_prefix("data: ") {
+                        let event = serde_json::from_str(&data)?;
+                        Ok(Some(event))
+                    } else {
+                        Ok(None)
+                    }
+                }
+
+                while let Some(line) = lines.next().await {
+                    if let Some(event) = parse_line(line).transpose() {
+                        let done = event.as_ref().map_or(false, |event| {
+                            event
+                                .choices
+                                .last()
+                                .map_or(false, |choice| choice.finish_reason.is_some())
+                        });
+                        if tx.unbounded_send(event).is_err() {
+                            break;
+                        }
+
+                        if done {
+                            break;
+                        }
+                    }
+                }
+
+                anyhow::Ok(())
+            })
+            .detach();
+
+        Ok(rx)
+    } else {
+        let mut body = String::new();
+        response.body_mut().read_to_string(&mut body).await?;
+
+        #[derive(Deserialize)]
+        struct OpenAIResponse {
+            error: OpenAIError,
+        }
+
+        #[derive(Deserialize)]
+        struct OpenAIError {
+            message: String,
+        }
+
+        match serde_json::from_str::<OpenAIResponse>(&body) {
+            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
+                "Failed to connect to OpenAI API: {}",
+                response.error.message,
+            )),
+
+            _ => Err(anyhow!(
+                "Failed to connect to OpenAI API: {} {}",
+                response.status(),
+                body,
+            )),
+        }
+    }
+}
+
+#[derive(Clone)]
+pub struct OpenAICompletionProvider {
+    model: OpenAILanguageModel,
+    credential: Arc<RwLock<ProviderCredential>>,
+    executor: Arc<Background>,
+}
+
+impl OpenAICompletionProvider {
+    pub fn new(model_name: &str, executor: Arc<Background>) -> Self {
+        let model = OpenAILanguageModel::load(model_name);
+        let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
+        Self {
+            model,
+            credential,
+            executor,
+        }
+    }
+}
+
+impl CredentialProvider for OpenAICompletionProvider {
+    fn has_credentials(&self) -> bool {
+        match *self.credential.read() {
+            ProviderCredential::Credentials { .. } => true,
+            _ => false,
+        }
+    }
+    fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
+        let mut credential = self.credential.write();
+        match *credential {
+            ProviderCredential::Credentials { .. } => {
+                return credential.clone();
+            }
+            _ => {
+                if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+                    *credential = ProviderCredential::Credentials { api_key };
+                } else if let Some((_, api_key)) = cx
+                    .platform()
+                    .read_credentials(OPENAI_API_URL)
+                    .log_err()
+                    .flatten()
+                {
+                    if let Some(api_key) = String::from_utf8(api_key).log_err() {
+                        *credential = ProviderCredential::Credentials { api_key };
+                    }
+                } else {
+                };
+            }
+        }
+
+        credential.clone()
+    }
+
+    fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
+        match credential.clone() {
+            ProviderCredential::Credentials { api_key } => {
+                cx.platform()
+                    .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
+                    .log_err();
+            }
+            _ => {}
+        }
+
+        *self.credential.write() = credential;
+    }
+    fn delete_credentials(&self, cx: &AppContext) {
+        cx.platform().delete_credentials(OPENAI_API_URL).log_err();
+        *self.credential.write() = ProviderCredential::NoCredentials;
+    }
+}
+
+impl CompletionProvider for OpenAICompletionProvider {
+    fn base_model(&self) -> Box<dyn LanguageModel> {
+        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
+        model
+    }
+    fn complete(
+        &self,
+        prompt: Box<dyn CompletionRequest>,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        // Currently the CompletionRequest for OpenAI includes a 'model' parameter.
+        // This means that the model is determined by the CompletionRequest and not
+        // the CompletionProvider, which is currently model-based, due to the language model.
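+        // For illustration, a hypothetical call site supplies its own model name:
+        //
+        //     let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
+        //         model: "gpt-4".into(),
+        //         stream: true,
+        //         ..Default::default()
+        //     });
+        //     let response_stream = provider.complete(request);
+        //
+        // and nothing forces that `model` field to agree with `provider.base_model()`.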
+ // At some point in the future we should rectify this. + let credential = self.credential.read().clone(); + let request = stream_completion(credential, self.executor.clone(), prompt); + async move { + let response = request.await?; + let stream = response + .filter_map(|response| async move { + match response { + Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)), + Err(error) => Some(Err(error)), + } + }) + .boxed(); + Ok(stream) + } + .boxed() + } + fn box_clone(&self) -> Box { + Box::new((*self).clone()) + } +} diff --git a/crates/ai/src/providers/open_ai/embedding.rs b/crates/ai/src/providers/open_ai/embedding.rs new file mode 100644 index 0000000000000000000000000000000000000000..fbfd0028f9fd0f99dc60649d40425c1cbd822485 --- /dev/null +++ b/crates/ai/src/providers/open_ai/embedding.rs @@ -0,0 +1,306 @@ +use anyhow::{anyhow, Result}; +use async_trait::async_trait; +use futures::AsyncReadExt; +use gpui::executor::Background; +use gpui::{serde_json, AppContext}; +use isahc::http::StatusCode; +use isahc::prelude::Configurable; +use isahc::{AsyncBody, Response}; +use lazy_static::lazy_static; +use parking_lot::{Mutex, RwLock}; +use parse_duration::parse; +use postage::watch; +use serde::{Deserialize, Serialize}; +use std::env; +use std::ops::Add; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tiktoken_rs::{cl100k_base, CoreBPE}; +use util::http::{HttpClient, Request}; +use util::ResultExt; + +use crate::auth::{CredentialProvider, ProviderCredential}; +use crate::embedding::{Embedding, EmbeddingProvider}; +use crate::models::LanguageModel; +use crate::providers::open_ai::OpenAILanguageModel; + +use crate::providers::open_ai::OPENAI_API_URL; + +lazy_static! { + static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap(); +} + +#[derive(Clone)] +pub struct OpenAIEmbeddingProvider { + model: OpenAILanguageModel, + credential: Arc>, + pub client: Arc, + pub executor: Arc, + rate_limit_count_rx: watch::Receiver>, + rate_limit_count_tx: Arc>>>, +} + +#[derive(Serialize)] +struct OpenAIEmbeddingRequest<'a> { + model: &'static str, + input: Vec<&'a str>, +} + +#[derive(Deserialize)] +struct OpenAIEmbeddingResponse { + data: Vec, + usage: OpenAIEmbeddingUsage, +} + +#[derive(Debug, Deserialize)] +struct OpenAIEmbedding { + embedding: Vec, + index: usize, + object: String, +} + +#[derive(Deserialize)] +struct OpenAIEmbeddingUsage { + prompt_tokens: usize, + total_tokens: usize, +} + +impl OpenAIEmbeddingProvider { + pub fn new(client: Arc, executor: Arc) -> Self { + let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None); + let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx)); + + let model = OpenAILanguageModel::load("text-embedding-ada-002"); + let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials)); + + OpenAIEmbeddingProvider { + model, + credential, + client, + executor, + rate_limit_count_rx, + rate_limit_count_tx, + } + } + + fn get_api_key(&self) -> Result { + match self.credential.read().clone() { + ProviderCredential::Credentials { api_key } => Ok(api_key), + _ => Err(anyhow!("api credentials not provided")), + } + } + + fn resolve_rate_limit(&self) { + let reset_time = *self.rate_limit_count_tx.lock().borrow(); + + if let Some(reset_time) = reset_time { + if Instant::now() >= reset_time { + *self.rate_limit_count_tx.lock().borrow_mut() = None + } + } + + log::trace!( + "resolving reset time: {:?}", + *self.rate_limit_count_tx.lock().borrow() + ); + } + + fn update_reset_time(&self, reset_time: 
Instant) { + let original_time = *self.rate_limit_count_tx.lock().borrow(); + + let updated_time = if let Some(original_time) = original_time { + if reset_time < original_time { + Some(reset_time) + } else { + Some(original_time) + } + } else { + Some(reset_time) + }; + + log::trace!("updating rate limit time: {:?}", updated_time); + + *self.rate_limit_count_tx.lock().borrow_mut() = updated_time; + } + async fn send_request( + &self, + api_key: &str, + spans: Vec<&str>, + request_timeout: u64, + ) -> Result> { + let request = Request::post("https://api.openai.com/v1/embeddings") + .redirect_policy(isahc::config::RedirectPolicy::Follow) + .timeout(Duration::from_secs(request_timeout)) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key)) + .body( + serde_json::to_string(&OpenAIEmbeddingRequest { + input: spans.clone(), + model: "text-embedding-ada-002", + }) + .unwrap() + .into(), + )?; + + Ok(self.client.send(request).await?) + } +} + +impl CredentialProvider for OpenAIEmbeddingProvider { + fn has_credentials(&self) -> bool { + match *self.credential.read() { + ProviderCredential::Credentials { .. } => true, + _ => false, + } + } + fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential { + let mut credential = self.credential.write(); + match *credential { + ProviderCredential::Credentials { .. } => { + return credential.clone(); + } + _ => { + if let Ok(api_key) = env::var("OPENAI_API_KEY") { + *credential = ProviderCredential::Credentials { api_key }; + } else if let Some((_, api_key)) = cx + .platform() + .read_credentials(OPENAI_API_URL) + .log_err() + .flatten() + { + if let Some(api_key) = String::from_utf8(api_key).log_err() { + *credential = ProviderCredential::Credentials { api_key }; + } + } else { + }; + } + } + + credential.clone() + } + + fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) { + match credential.clone() { + ProviderCredential::Credentials { api_key } => { + cx.platform() + .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()) + .log_err(); + } + _ => {} + } + + *self.credential.write() = credential; + } + fn delete_credentials(&self, cx: &AppContext) { + cx.platform().delete_credentials(OPENAI_API_URL).log_err(); + *self.credential.write() = ProviderCredential::NoCredentials; + } +} + +#[async_trait] +impl EmbeddingProvider for OpenAIEmbeddingProvider { + fn base_model(&self) -> Box { + let model: Box = Box::new(self.model.clone()); + model + } + + fn max_tokens_per_batch(&self) -> usize { + 50000 + } + + fn rate_limit_expiration(&self) -> Option { + *self.rate_limit_count_rx.borrow() + } + + async fn embed_batch(&self, spans: Vec) -> Result> { + const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; + const MAX_RETRIES: usize = 4; + + let api_key = self.get_api_key()?; + + let mut request_number = 0; + let mut rate_limiting = false; + let mut request_timeout: u64 = 15; + let mut response: Response; + while request_number < MAX_RETRIES { + response = self + .send_request( + &api_key, + spans.iter().map(|x| &**x).collect(), + request_timeout, + ) + .await?; + + request_number += 1; + + match response.status() { + StatusCode::REQUEST_TIMEOUT => { + request_timeout += 5; + } + StatusCode::OK => { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?; + + log::trace!( + "openai embedding completed. 
tokens: {:?}", + response.usage.total_tokens + ); + + // If we complete a request successfully that was previously rate_limited + // resolve the rate limit + if rate_limiting { + self.resolve_rate_limit() + } + + return Ok(response + .data + .into_iter() + .map(|embedding| Embedding::from(embedding.embedding)) + .collect()); + } + StatusCode::TOO_MANY_REQUESTS => { + rate_limiting = true; + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + let delay_duration = { + let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64); + if let Some(time_to_reset) = + response.headers().get("x-ratelimit-reset-tokens") + { + if let Ok(time_str) = time_to_reset.to_str() { + parse(time_str).unwrap_or(delay) + } else { + delay + } + } else { + delay + } + }; + + // If we've previously rate limited, increment the duration but not the count + let reset_time = Instant::now().add(delay_duration); + self.update_reset_time(reset_time); + + log::trace!( + "openai rate limiting: waiting {:?} until lifted", + &delay_duration + ); + + self.executor.timer(delay_duration).await; + } + _ => { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + return Err(anyhow!( + "open ai bad request: {:?} {:?}", + &response.status(), + body + )); + } + } + } + Err(anyhow!("openai max retries")) + } +} diff --git a/crates/ai/src/providers/open_ai/mod.rs b/crates/ai/src/providers/open_ai/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..7d2f86045d9e1049c55111aef175ac9b56dc7e16 --- /dev/null +++ b/crates/ai/src/providers/open_ai/mod.rs @@ -0,0 +1,9 @@ +pub mod completion; +pub mod embedding; +pub mod model; + +pub use completion::*; +pub use embedding::*; +pub use model::OpenAILanguageModel; + +pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; diff --git a/crates/ai/src/providers/open_ai/model.rs b/crates/ai/src/providers/open_ai/model.rs new file mode 100644 index 0000000000000000000000000000000000000000..6e306c80b905865c011c9064934827085ca126d6 --- /dev/null +++ b/crates/ai/src/providers/open_ai/model.rs @@ -0,0 +1,57 @@ +use anyhow::anyhow; +use tiktoken_rs::CoreBPE; +use util::ResultExt; + +use crate::models::{LanguageModel, TruncationDirection}; + +#[derive(Clone)] +pub struct OpenAILanguageModel { + name: String, + bpe: Option, +} + +impl OpenAILanguageModel { + pub fn load(model_name: &str) -> Self { + let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err(); + OpenAILanguageModel { + name: model_name.to_string(), + bpe, + } + } +} + +impl LanguageModel for OpenAILanguageModel { + fn name(&self) -> String { + self.name.clone() + } + fn count_tokens(&self, content: &str) -> anyhow::Result { + if let Some(bpe) = &self.bpe { + anyhow::Ok(bpe.encode_with_special_tokens(content).len()) + } else { + Err(anyhow!("bpe for open ai model was not retrieved")) + } + } + fn truncate( + &self, + content: &str, + length: usize, + direction: TruncationDirection, + ) -> anyhow::Result { + if let Some(bpe) = &self.bpe { + let tokens = bpe.encode_with_special_tokens(content); + if tokens.len() > length { + match direction { + TruncationDirection::End => bpe.decode(tokens[..length].to_vec()), + TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()), + } + } else { + bpe.decode(tokens) + } + } else { + Err(anyhow!("bpe for open ai model was not retrieved")) + } + } + fn capacity(&self) -> anyhow::Result { + anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name)) + } +} diff --git 
a/crates/ai/src/providers/open_ai/new.rs b/crates/ai/src/providers/open_ai/new.rs new file mode 100644 index 0000000000000000000000000000000000000000..c7d67f2ba1d252a6865124d8ffdfb79130a8c3a0 --- /dev/null +++ b/crates/ai/src/providers/open_ai/new.rs @@ -0,0 +1,11 @@ +pub trait LanguageModel { + fn name(&self) -> String; + fn count_tokens(&self, content: &str) -> anyhow::Result; + fn truncate( + &self, + content: &str, + length: usize, + direction: TruncationDirection, + ) -> anyhow::Result; + fn capacity(&self) -> anyhow::Result; +} diff --git a/crates/ai/src/test.rs b/crates/ai/src/test.rs new file mode 100644 index 0000000000000000000000000000000000000000..d4165f3cca897c4adbf11c2babf6038a8d86f0a6 --- /dev/null +++ b/crates/ai/src/test.rs @@ -0,0 +1,191 @@ +use std::{ + sync::atomic::{self, AtomicUsize, Ordering}, + time::Instant, +}; + +use async_trait::async_trait; +use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; +use gpui::AppContext; +use parking_lot::Mutex; + +use crate::{ + auth::{CredentialProvider, ProviderCredential}, + completion::{CompletionProvider, CompletionRequest}, + embedding::{Embedding, EmbeddingProvider}, + models::{LanguageModel, TruncationDirection}, +}; + +#[derive(Clone)] +pub struct FakeLanguageModel { + pub capacity: usize, +} + +impl LanguageModel for FakeLanguageModel { + fn name(&self) -> String { + "dummy".to_string() + } + fn count_tokens(&self, content: &str) -> anyhow::Result { + anyhow::Ok(content.chars().collect::>().len()) + } + fn truncate( + &self, + content: &str, + length: usize, + direction: TruncationDirection, + ) -> anyhow::Result { + println!("TRYING TO TRUNCATE: {:?}", length.clone()); + + if length > self.count_tokens(content)? { + println!("NOT TRUNCATING"); + return anyhow::Ok(content.to_string()); + } + + anyhow::Ok(match direction { + TruncationDirection::End => content.chars().collect::>()[..length] + .into_iter() + .collect::(), + TruncationDirection::Start => content.chars().collect::>()[length..] 
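+                // (characters stand in for tokens in this fake model)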
+ .into_iter() + .collect::(), + }) + } + fn capacity(&self) -> anyhow::Result { + anyhow::Ok(self.capacity) + } +} + +pub struct FakeEmbeddingProvider { + pub embedding_count: AtomicUsize, +} + +impl Clone for FakeEmbeddingProvider { + fn clone(&self) -> Self { + FakeEmbeddingProvider { + embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)), + } + } +} + +impl Default for FakeEmbeddingProvider { + fn default() -> Self { + FakeEmbeddingProvider { + embedding_count: AtomicUsize::default(), + } + } +} + +impl FakeEmbeddingProvider { + pub fn embedding_count(&self) -> usize { + self.embedding_count.load(atomic::Ordering::SeqCst) + } + + pub fn embed_sync(&self, span: &str) -> Embedding { + let mut result = vec![1.0; 26]; + for letter in span.chars() { + let letter = letter.to_ascii_lowercase(); + if letter as u32 >= 'a' as u32 { + let ix = (letter as u32) - ('a' as u32); + if ix < 26 { + result[ix as usize] += 1.0; + } + } + } + + let norm = result.iter().map(|x| x * x).sum::().sqrt(); + for x in &mut result { + *x /= norm; + } + + result.into() + } +} + +impl CredentialProvider for FakeEmbeddingProvider { + fn has_credentials(&self) -> bool { + true + } + fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential { + ProviderCredential::NotNeeded + } + fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {} + fn delete_credentials(&self, _cx: &AppContext) {} +} + +#[async_trait] +impl EmbeddingProvider for FakeEmbeddingProvider { + fn base_model(&self) -> Box { + Box::new(FakeLanguageModel { capacity: 1000 }) + } + fn max_tokens_per_batch(&self) -> usize { + 1000 + } + + fn rate_limit_expiration(&self) -> Option { + None + } + + async fn embed_batch(&self, spans: Vec) -> anyhow::Result> { + self.embedding_count + .fetch_add(spans.len(), atomic::Ordering::SeqCst); + + anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect()) + } +} + +pub struct FakeCompletionProvider { + last_completion_tx: Mutex>>, +} + +impl Clone for FakeCompletionProvider { + fn clone(&self) -> Self { + Self { + last_completion_tx: Mutex::new(None), + } + } +} + +impl FakeCompletionProvider { + pub fn new() -> Self { + Self { + last_completion_tx: Mutex::new(None), + } + } + + pub fn send_completion(&self, completion: impl Into) { + let mut tx = self.last_completion_tx.lock(); + tx.as_mut().unwrap().try_send(completion.into()).unwrap(); + } + + pub fn finish_completion(&self) { + self.last_completion_tx.lock().take().unwrap(); + } +} + +impl CredentialProvider for FakeCompletionProvider { + fn has_credentials(&self) -> bool { + true + } + fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential { + ProviderCredential::NotNeeded + } + fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {} + fn delete_credentials(&self, _cx: &AppContext) {} +} + +impl CompletionProvider for FakeCompletionProvider { + fn base_model(&self) -> Box { + let model: Box = Box::new(FakeLanguageModel { capacity: 8190 }); + model + } + fn complete( + &self, + _prompt: Box, + ) -> BoxFuture<'static, anyhow::Result>>> { + let (tx, rx) = mpsc::channel(1); + *self.last_completion_tx.lock() = Some(tx); + async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed() + } + fn box_clone(&self) -> Box { + Box::new((*self).clone()) + } +} diff --git a/crates/ai2/Cargo.toml b/crates/ai2/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..4f06840e8e53bbcb06c377a6304fb8be13b85946 --- /dev/null +++ b/crates/ai2/Cargo.toml @@ -0,0 
+1,38 @@ +[package] +name = "ai2" +version = "0.1.0" +edition = "2021" +publish = false + +[lib] +path = "src/ai2.rs" +doctest = false + +[features] +test-support = [] + +[dependencies] +gpui2 = { path = "../gpui2" } +util = { path = "../util" } +language2 = { path = "../language2" } +async-trait.workspace = true +anyhow.workspace = true +futures.workspace = true +lazy_static.workspace = true +ordered-float.workspace = true +parking_lot.workspace = true +isahc.workspace = true +regex.workspace = true +serde.workspace = true +serde_json.workspace = true +postage.workspace = true +rand.workspace = true +log.workspace = true +parse_duration = "2.1.1" +tiktoken-rs = "0.5.0" +matrixmultiply = "0.3.7" +rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] } +bincode = "1.3.3" + +[dev-dependencies] +gpui2 = { path = "../gpui2", features = ["test-support"] } diff --git a/crates/ai2/src/ai2.rs b/crates/ai2/src/ai2.rs new file mode 100644 index 0000000000000000000000000000000000000000..dda22d2a1d04dd6083fb1ae9879f49e74c8b4627 --- /dev/null +++ b/crates/ai2/src/ai2.rs @@ -0,0 +1,8 @@ +pub mod auth; +pub mod completion; +pub mod embedding; +pub mod models; +pub mod prompts; +pub mod providers; +#[cfg(any(test, feature = "test-support"))] +pub mod test; diff --git a/crates/ai2/src/auth.rs b/crates/ai2/src/auth.rs new file mode 100644 index 0000000000000000000000000000000000000000..e4670bb449025d5ecc5f0cabe65ad6ff4727c10c --- /dev/null +++ b/crates/ai2/src/auth.rs @@ -0,0 +1,17 @@ +use async_trait::async_trait; +use gpui2::AppContext; + +#[derive(Clone, Debug)] +pub enum ProviderCredential { + Credentials { api_key: String }, + NoCredentials, + NotNeeded, +} + +#[async_trait] +pub trait CredentialProvider: Send + Sync { + fn has_credentials(&self) -> bool; + async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential; + async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential); + async fn delete_credentials(&self, cx: &mut AppContext); +} diff --git a/crates/ai2/src/completion.rs b/crates/ai2/src/completion.rs new file mode 100644 index 0000000000000000000000000000000000000000..30a60fcf1d5c5dc66717773968e432e510d6421f --- /dev/null +++ b/crates/ai2/src/completion.rs @@ -0,0 +1,23 @@ +use anyhow::Result; +use futures::{future::BoxFuture, stream::BoxStream}; + +use crate::{auth::CredentialProvider, models::LanguageModel}; + +pub trait CompletionRequest: Send + Sync { + fn data(&self) -> serde_json::Result; +} + +pub trait CompletionProvider: CredentialProvider { + fn base_model(&self) -> Box; + fn complete( + &self, + prompt: Box, + ) -> BoxFuture<'static, Result>>>; + fn box_clone(&self) -> Box; +} + +impl Clone for Box { + fn clone(&self) -> Box { + self.box_clone() + } +} diff --git a/crates/ai2/src/embedding.rs b/crates/ai2/src/embedding.rs new file mode 100644 index 0000000000000000000000000000000000000000..7ea47861782cf9002796a8b6e655989b871e0191 --- /dev/null +++ b/crates/ai2/src/embedding.rs @@ -0,0 +1,123 @@ +use std::time::Instant; + +use anyhow::Result; +use async_trait::async_trait; +use ordered_float::OrderedFloat; +use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef}; +use rusqlite::ToSql; + +use crate::auth::CredentialProvider; +use crate::models::LanguageModel; + +#[derive(Debug, PartialEq, Clone)] +pub struct Embedding(pub Vec); + +// This is needed for semantic index functionality +// Unfortunately it has to live wherever the "Embedding" struct is created. 
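+// As a sketch, the stored form is just a bincode blob of the inner `Vec<f32>`
+// (hypothetical round trip, assuming an `embedding: Embedding` in scope):
+//
+//     let blob: Vec<u8> = bincode::serialize(&embedding.0)?;
+//     let back = Embedding(bincode::deserialize::<Vec<f32>>(&blob)?);
+//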
+// Keeping this in here though, introduces a 'rusqlite' dependency into AI +// which is less than ideal +impl FromSql for Embedding { + fn column_result(value: ValueRef) -> FromSqlResult { + let bytes = value.as_blob()?; + let embedding: Result, Box> = bincode::deserialize(bytes); + if embedding.is_err() { + return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err())); + } + Ok(Embedding(embedding.unwrap())) + } +} + +impl ToSql for Embedding { + fn to_sql(&self) -> rusqlite::Result { + let bytes = bincode::serialize(&self.0) + .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?; + Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes))) + } +} +impl From> for Embedding { + fn from(value: Vec) -> Self { + Embedding(value) + } +} + +impl Embedding { + pub fn similarity(&self, other: &Self) -> OrderedFloat { + let len = self.0.len(); + assert_eq!(len, other.0.len()); + + let mut result = 0.0; + unsafe { + matrixmultiply::sgemm( + 1, + len, + 1, + 1.0, + self.0.as_ptr(), + len as isize, + 1, + other.0.as_ptr(), + 1, + len as isize, + 0.0, + &mut result as *mut f32, + 1, + 1, + ); + } + OrderedFloat(result) + } +} + +#[async_trait] +pub trait EmbeddingProvider: CredentialProvider { + fn base_model(&self) -> Box; + async fn embed_batch(&self, spans: Vec) -> Result>; + fn max_tokens_per_batch(&self) -> usize; + fn rate_limit_expiration(&self) -> Option; +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::prelude::*; + + #[gpui2::test] + fn test_similarity(mut rng: StdRng) { + assert_eq!( + Embedding::from(vec![1., 0., 0., 0., 0.]) + .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])), + 0. + ); + assert_eq!( + Embedding::from(vec![2., 0., 0., 0., 0.]) + .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])), + 6. 
+ ); + + for _ in 0..100 { + let size = 1536; + let mut a = vec![0.; size]; + let mut b = vec![0.; size]; + for (a, b) in a.iter_mut().zip(b.iter_mut()) { + *a = rng.gen(); + *b = rng.gen(); + } + let a = Embedding::from(a); + let b = Embedding::from(b); + + assert_eq!( + round_to_decimals(a.similarity(&b), 1), + round_to_decimals(reference_dot(&a.0, &b.0), 1) + ); + } + + fn round_to_decimals(n: OrderedFloat, decimal_places: i32) -> f32 { + let factor = (10.0 as f32).powi(decimal_places); + (n * factor).round() / factor + } + + fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat { + OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum()) + } + } +} diff --git a/crates/ai2/src/models.rs b/crates/ai2/src/models.rs new file mode 100644 index 0000000000000000000000000000000000000000..1db3d58c6f54ad613cb98fc3f425df3d47e5e97f --- /dev/null +++ b/crates/ai2/src/models.rs @@ -0,0 +1,16 @@ +pub enum TruncationDirection { + Start, + End, +} + +pub trait LanguageModel { + fn name(&self) -> String; + fn count_tokens(&self, content: &str) -> anyhow::Result; + fn truncate( + &self, + content: &str, + length: usize, + direction: TruncationDirection, + ) -> anyhow::Result; + fn capacity(&self) -> anyhow::Result; +} diff --git a/crates/ai2/src/prompts/base.rs b/crates/ai2/src/prompts/base.rs new file mode 100644 index 0000000000000000000000000000000000000000..29091d0f5b435b556a0c2cae60aa4526370832ab --- /dev/null +++ b/crates/ai2/src/prompts/base.rs @@ -0,0 +1,330 @@ +use std::cmp::Reverse; +use std::ops::Range; +use std::sync::Arc; + +use language2::BufferSnapshot; +use util::ResultExt; + +use crate::models::LanguageModel; +use crate::prompts::repository_context::PromptCodeSnippet; + +pub(crate) enum PromptFileType { + Text, + Code, +} + +// TODO: Set this up to manage for defaults well +pub struct PromptArguments { + pub model: Arc, + pub user_prompt: Option, + pub language_name: Option, + pub project_name: Option, + pub snippets: Vec, + pub reserved_tokens: usize, + pub buffer: Option, + pub selected_range: Option>, +} + +impl PromptArguments { + pub(crate) fn get_file_type(&self) -> PromptFileType { + if self + .language_name + .as_ref() + .and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str()))) + .unwrap_or(true) + { + PromptFileType::Code + } else { + PromptFileType::Text + } + } +} + +pub trait PromptTemplate { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)>; +} + +#[repr(i8)] +#[derive(PartialEq, Eq, Ord)] +pub enum PromptPriority { + Mandatory, // Ignores truncation + Ordered { order: usize }, // Truncates based on priority +} + +impl PartialOrd for PromptPriority { + fn partial_cmp(&self, other: &Self) -> Option { + match (self, other) { + (Self::Mandatory, Self::Mandatory) => Some(std::cmp::Ordering::Equal), + (Self::Mandatory, Self::Ordered { .. }) => Some(std::cmp::Ordering::Greater), + (Self::Ordered { .. 
+
+pub struct PromptChain {
+    args: PromptArguments,
+    templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
+}
+
+impl PromptChain {
+    pub fn new(
+        args: PromptArguments,
+        templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
+    ) -> Self {
+        PromptChain { args, templates }
+    }
+
+    pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> {
+        // Argsort based on PromptPriority
+        let separator = "\n";
+        let separator_tokens = self.args.model.count_tokens(separator)?;
+        let mut sorted_indices = (0..self.templates.len()).collect::<Vec<usize>>();
+        sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));
+
+        // If truncating, budget the model's capacity minus any reserved tokens
+        let mut tokens_outstanding = if truncate {
+            Some(self.args.model.capacity()? - self.args.reserved_tokens)
+        } else {
+            None
+        };
+
+        let mut prompts = vec!["".to_string(); sorted_indices.len()];
+        for idx in sorted_indices {
+            let (_, template) = &self.templates[idx];
+
+            if let Some((template_prompt, prompt_token_count)) =
+                template.generate(&self.args, tokens_outstanding).log_err()
+            {
+                if !template_prompt.is_empty() {
+                    prompts[idx] = template_prompt;
+
+                    if let Some(remaining_tokens) = tokens_outstanding {
+                        let new_tokens = prompt_token_count + separator_tokens;
+                        tokens_outstanding = if remaining_tokens > new_tokens {
+                            Some(remaining_tokens - new_tokens)
+                        } else {
+                            Some(0)
+                        };
+                    }
+                }
+            }
+        }
+
+        prompts.retain(|prompt| !prompt.is_empty());
+
+        let full_prompt = prompts.join(separator);
+        let total_token_count = self.args.model.count_tokens(&full_prompt)?;
+        anyhow::Ok((full_prompt, total_token_count))
+    }
+}
+
+#[cfg(test)]
+pub(crate) mod tests {
+    use crate::models::TruncationDirection;
+    use crate::test::FakeLanguageModel;
+
+    use super::*;
+
+    #[test]
+    pub fn test_prompt_chain() {
+        struct TestPromptTemplate {}
+        impl PromptTemplate for TestPromptTemplate {
+            fn generate(
+                &self,
+                args: &PromptArguments,
+                max_token_length: Option<usize>,
+            ) -> anyhow::Result<(String, usize)> {
+                let mut content = "This is a test prompt template".to_string();
+
+                let mut token_count = args.model.count_tokens(&content)?;
+                if let Some(max_token_length) = max_token_length {
+                    if token_count > max_token_length {
+                        content = args.model.truncate(
+                            &content,
+                            max_token_length,
+                            TruncationDirection::End,
+                        )?;
+                        token_count = max_token_length;
+                    }
+                }
+
+                anyhow::Ok((content, token_count))
+            }
+        }
+
+        struct TestLowPriorityTemplate {}
+        impl PromptTemplate for TestLowPriorityTemplate {
+            fn generate(
+                &self,
+                args: &PromptArguments,
+                max_token_length: Option<usize>,
+            ) -> anyhow::Result<(String, usize)> {
+                let mut content = "This is a low priority test prompt template".to_string();
+
+                let mut token_count = args.model.count_tokens(&content)?;
+                if let Some(max_token_length) = max_token_length {
+                    if token_count > max_token_length {
+                        content = args.model.truncate(
+                            &content,
+                            max_token_length,
+                            TruncationDirection::End,
+                        )?;
+                        token_count = max_token_length;
+                    }
+                }
+
+                anyhow::Ok((content, token_count))
+            }
+        }
+
+        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
+        let args = PromptArguments {
+            model: model.clone(),
+            language_name: None,
+            project_name: None,
+            snippets: Vec::new(),
+            reserved_tokens: 0,
+            buffer: None,
+            selected_range: None,
+            user_prompt: None,
+        };
+
+        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+            (
+                PromptPriority::Ordered { order: 0 },
+                Box::new(TestPromptTemplate {}),
+            ),
+            (
+                PromptPriority::Ordered {
order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), + ]; + let chain = PromptChain::new(args, templates); + + let (prompt, token_count) = chain.generate(false).unwrap(); + + assert_eq!( + prompt, + "This is a test prompt template\nThis is a low priority test prompt template" + .to_string() + ); + + assert_eq!(model.count_tokens(&prompt).unwrap(), token_count); + + // Testing with Truncation Off + // Should ignore capacity and return all prompts + let model: Arc = Arc::new(FakeLanguageModel { capacity: 20 }); + let args = PromptArguments { + model: model.clone(), + language_name: None, + project_name: None, + snippets: Vec::new(), + reserved_tokens: 0, + buffer: None, + selected_range: None, + user_prompt: None, + }; + + let templates: Vec<(PromptPriority, Box)> = vec![ + ( + PromptPriority::Ordered { order: 0 }, + Box::new(TestPromptTemplate {}), + ), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), + ]; + let chain = PromptChain::new(args, templates); + + let (prompt, token_count) = chain.generate(false).unwrap(); + + assert_eq!( + prompt, + "This is a test prompt template\nThis is a low priority test prompt template" + .to_string() + ); + + assert_eq!(model.count_tokens(&prompt).unwrap(), token_count); + + // Testing with Truncation Off + // Should ignore capacity and return all prompts + let capacity = 20; + let model: Arc = Arc::new(FakeLanguageModel { capacity }); + let args = PromptArguments { + model: model.clone(), + language_name: None, + project_name: None, + snippets: Vec::new(), + reserved_tokens: 0, + buffer: None, + selected_range: None, + user_prompt: None, + }; + + let templates: Vec<(PromptPriority, Box)> = vec![ + ( + PromptPriority::Ordered { order: 0 }, + Box::new(TestPromptTemplate {}), + ), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), + ( + PromptPriority::Ordered { order: 2 }, + Box::new(TestLowPriorityTemplate {}), + ), + ]; + let chain = PromptChain::new(args, templates); + + let (prompt, token_count) = chain.generate(true).unwrap(); + + assert_eq!(prompt, "This is a test promp".to_string()); + assert_eq!(token_count, capacity); + + // Change Ordering of Prompts Based on Priority + let capacity = 120; + let reserved_tokens = 10; + let model: Arc = Arc::new(FakeLanguageModel { capacity }); + let args = PromptArguments { + model: model.clone(), + language_name: None, + project_name: None, + snippets: Vec::new(), + reserved_tokens, + buffer: None, + selected_range: None, + user_prompt: None, + }; + let templates: Vec<(PromptPriority, Box)> = vec![ + ( + PromptPriority::Mandatory, + Box::new(TestLowPriorityTemplate {}), + ), + ( + PromptPriority::Ordered { order: 0 }, + Box::new(TestPromptTemplate {}), + ), + ( + PromptPriority::Ordered { order: 1 }, + Box::new(TestLowPriorityTemplate {}), + ), + ]; + let chain = PromptChain::new(args, templates); + + let (prompt, token_count) = chain.generate(true).unwrap(); + + assert_eq!( + prompt, + "This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt " + .to_string() + ); + assert_eq!(token_count, capacity - reserved_tokens); + } +} diff --git a/crates/ai2/src/prompts/file_context.rs b/crates/ai2/src/prompts/file_context.rs new file mode 100644 index 0000000000000000000000000000000000000000..4a741beb24984c5c038ef10439d5b182438d2866 --- /dev/null +++ b/crates/ai2/src/prompts/file_context.rs @@ -0,0 +1,164 @@ +use anyhow::anyhow; +use language2::BufferSnapshot; +use language2::ToOffset; 
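The `retrieve_context` function defined just below splits whatever token budget remains after the selection between the text before and after it: the smaller window is capped at half the remainder, and the other window receives whatever is left. A standalone sketch of that arithmetic (illustrative stand-in that mirrors the logic, not the crate's API):

```rust
// Mirrors the window-budget split used by `retrieve_context` below.
fn split_budget(mut remaining: usize, start_tokens: usize, end_tokens: usize) -> (usize, usize) {
    if start_tokens < end_tokens {
        let start_goal = (remaining / 2).min(start_tokens);
        remaining -= start_goal;
        (start_goal, remaining.min(end_tokens))
    } else {
        let end_goal = (remaining / 2).min(end_tokens);
        remaining -= end_goal;
        (remaining.min(start_tokens), end_goal)
    }
}

fn main() {
    // 100 spare tokens around the selection: a 30-token prefix fits whole,
    // so the 500-token suffix gets the remaining 70.
    assert_eq!(split_budget(100, 30, 500), (30, 70));
    // Both windows oversized: the budget splits roughly in half.
    assert_eq!(split_budget(100, 500, 500), (50, 50));
}
```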
+
+use crate::models::LanguageModel;
+use crate::models::TruncationDirection;
+use crate::prompts::base::PromptArguments;
+use crate::prompts::base::PromptTemplate;
+use std::fmt::Write;
+use std::ops::Range;
+use std::sync::Arc;
+
+fn retrieve_context(
+    buffer: &BufferSnapshot,
+    selected_range: &Option<Range<usize>>,
+    model: Arc<dyn LanguageModel>,
+    max_token_count: Option<usize>,
+) -> anyhow::Result<(String, usize, bool)> {
+    let mut prompt = String::new();
+    let mut truncated = false;
+    if let Some(selected_range) = selected_range {
+        let start = selected_range.start.to_offset(buffer);
+        let end = selected_range.end.to_offset(buffer);
+
+        let start_window = buffer.text_for_range(0..start).collect::<String>();
+
+        let mut selected_window = String::new();
+        if start == end {
+            write!(selected_window, "<|START|>").unwrap();
+        } else {
+            write!(selected_window, "<|START|").unwrap();
+        }
+
+        write!(
+            selected_window,
+            "{}",
+            buffer.text_for_range(start..end).collect::<String>()
+        )
+        .unwrap();
+
+        if start != end {
+            write!(selected_window, "|END|>").unwrap();
+        }
+
+        let end_window = buffer.text_for_range(end..buffer.len()).collect::<String>();
+
+        if let Some(max_token_count) = max_token_count {
+            let selected_tokens = model.count_tokens(&selected_window)?;
+            if selected_tokens > max_token_count {
+                return Err(anyhow!(
+                    "selected range is greater than model context window, truncation not possible"
+                ));
+            }
+
+            let mut remaining_tokens = max_token_count - selected_tokens;
+            let start_window_tokens = model.count_tokens(&start_window)?;
+            let end_window_tokens = model.count_tokens(&end_window)?;
+            let outside_tokens = start_window_tokens + end_window_tokens;
+            if outside_tokens > remaining_tokens {
+                let (start_goal_tokens, end_goal_tokens) =
+                    if start_window_tokens < end_window_tokens {
+                        let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens);
+                        remaining_tokens -= start_goal_tokens;
+                        let end_goal_tokens = remaining_tokens.min(end_window_tokens);
+                        (start_goal_tokens, end_goal_tokens)
+                    } else {
+                        let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens);
+                        remaining_tokens -= end_goal_tokens;
+                        let start_goal_tokens = remaining_tokens.min(start_window_tokens);
+                        (start_goal_tokens, end_goal_tokens)
+                    };
+
+                let truncated_start_window =
+                    model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?;
+                let truncated_end_window =
+                    model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?;
+                writeln!(
+                    prompt,
+                    "{truncated_start_window}{selected_window}{truncated_end_window}"
+                )
+                .unwrap();
+                truncated = true;
+            } else {
+                writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
+            }
+        } else {
+            writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
+        }
+    } else {
+        // If we don't have a selected range, include the entire file.
+        writeln!(prompt, "{}", &buffer.text()).unwrap();
+
+        // Dumb truncation strategy
+        if let Some(max_token_count) = max_token_count {
+            if model.count_tokens(&prompt)? > max_token_count {
+                truncated = true;
+                prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?;
+            }
+        }
+    }
+
+    let token_count = model.count_tokens(&prompt)?;
+    anyhow::Ok((prompt, token_count, truncated))
+}
+
+pub struct FileContext {}
+
+impl PromptTemplate for FileContext {
+    fn generate(
+        &self,
+        args: &PromptArguments,
+        max_token_length: Option<usize>,
+    ) -> anyhow::Result<(String, usize)> {
+        if let Some(buffer) = &args.buffer {
+            let mut prompt = String::new();
+            // Add initial preamble
+            // TODO: Do we want to add the path in here?
+ writeln!( + prompt, + "The file you are currently working on has the following content:" + ) + .unwrap(); + + let language_name = args + .language_name + .clone() + .unwrap_or("".to_string()) + .to_lowercase(); + + let (context, _, truncated) = retrieve_context( + buffer, + &args.selected_range, + args.model.clone(), + max_token_length, + )?; + writeln!(prompt, "```{language_name}\n{context}\n```").unwrap(); + + if truncated { + writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap(); + } + + if let Some(selected_range) = &args.selected_range { + let start = selected_range.start.to_offset(buffer); + let end = selected_range.end.to_offset(buffer); + + if start == end { + writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap(); + } else { + writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap(); + } + } + + // Really dumb truncation strategy + if let Some(max_tokens) = max_token_length { + prompt = args + .model + .truncate(&prompt, max_tokens, TruncationDirection::End)?; + } + + let token_count = args.model.count_tokens(&prompt)?; + anyhow::Ok((prompt, token_count)) + } else { + Err(anyhow!("no buffer provided to retrieve file context from")) + } + } +} diff --git a/crates/ai2/src/prompts/generate.rs b/crates/ai2/src/prompts/generate.rs new file mode 100644 index 0000000000000000000000000000000000000000..c7be620107ee4d6daca06a8cb38019aceedc40a4 --- /dev/null +++ b/crates/ai2/src/prompts/generate.rs @@ -0,0 +1,99 @@ +use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate}; +use anyhow::anyhow; +use std::fmt::Write; + +pub fn capitalize(s: &str) -> String { + let mut c = s.chars(); + match c.next() { + None => String::new(), + Some(f) => f.to_uppercase().collect::() + c.as_str(), + } +} + +pub struct GenerateInlineContent {} + +impl PromptTemplate for GenerateInlineContent { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + let Some(user_prompt) = &args.user_prompt else { + return Err(anyhow!("user prompt not provided")); + }; + + let file_type = args.get_file_type(); + let content_type = match &file_type { + PromptFileType::Code => "code", + PromptFileType::Text => "text", + }; + + let mut prompt = String::new(); + + if let Some(selected_range) = &args.selected_range { + if selected_range.start == selected_range.end { + writeln!( + prompt, + "Assume the cursor is located where the `<|START|>` span is." 
+ ) + .unwrap(); + writeln!( + prompt, + "{} can't be replaced, so assume your answer will be inserted at the cursor.", + capitalize(content_type) + ) + .unwrap(); + writeln!( + prompt, + "Generate {content_type} based on the users prompt: {user_prompt}", + ) + .unwrap(); + } else { + writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap(); + writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap(); + writeln!(prompt, "Double check that you only return code and not the '<|START|' and '|END|'> spans").unwrap(); + } + } else { + writeln!( + prompt, + "Generate {content_type} based on the users prompt: {user_prompt}" + ) + .unwrap(); + } + + if let Some(language_name) = &args.language_name { + writeln!( + prompt, + "Your answer MUST always and only be valid {}.", + language_name + ) + .unwrap(); + } + writeln!(prompt, "Never make remarks about the output.").unwrap(); + writeln!( + prompt, + "Do not return anything else, except the generated {content_type}." + ) + .unwrap(); + + match file_type { + PromptFileType::Code => { + // writeln!(prompt, "Always wrap your code in a Markdown block.").unwrap(); + } + _ => {} + } + + // Really dumb truncation strategy + if let Some(max_tokens) = max_token_length { + prompt = args.model.truncate( + &prompt, + max_tokens, + crate::models::TruncationDirection::End, + )?; + } + + let token_count = args.model.count_tokens(&prompt)?; + + anyhow::Ok((prompt, token_count)) + } +} diff --git a/crates/ai2/src/prompts/mod.rs b/crates/ai2/src/prompts/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..0025269a440d1e6ead6a81615a64a3c28da62bb8 --- /dev/null +++ b/crates/ai2/src/prompts/mod.rs @@ -0,0 +1,5 @@ +pub mod base; +pub mod file_context; +pub mod generate; +pub mod preamble; +pub mod repository_context; diff --git a/crates/ai2/src/prompts/preamble.rs b/crates/ai2/src/prompts/preamble.rs new file mode 100644 index 0000000000000000000000000000000000000000..92e0edeb78b48169379aae2e88e81f62463a1057 --- /dev/null +++ b/crates/ai2/src/prompts/preamble.rs @@ -0,0 +1,52 @@ +use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate}; +use std::fmt::Write; + +pub struct EngineerPreamble {} + +impl PromptTemplate for EngineerPreamble { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + let mut prompts = Vec::new(); + + match args.get_file_type() { + PromptFileType::Code => { + prompts.push(format!( + "You are an expert {}engineer.", + args.language_name.clone().unwrap_or("".to_string()) + " " + )); + } + PromptFileType::Text => { + prompts.push("You are an expert engineer.".to_string()); + } + } + + if let Some(project_name) = args.project_name.clone() { + prompts.push(format!( + "You are currently working inside the '{project_name}' project in code editor Zed." + )); + } + + if let Some(mut remaining_tokens) = max_token_length { + let mut prompt = String::new(); + let mut total_count = 0; + for prompt_piece in prompts { + let prompt_token_count = + args.model.count_tokens(&prompt_piece)? 
+ args.model.count_tokens("\n")?; + if remaining_tokens > prompt_token_count { + writeln!(prompt, "{prompt_piece}").unwrap(); + remaining_tokens -= prompt_token_count; + total_count += prompt_token_count; + } + } + + anyhow::Ok((prompt, total_count)) + } else { + let prompt = prompts.join("\n"); + let token_count = args.model.count_tokens(&prompt)?; + anyhow::Ok((prompt, token_count)) + } + } +} diff --git a/crates/ai2/src/prompts/repository_context.rs b/crates/ai2/src/prompts/repository_context.rs new file mode 100644 index 0000000000000000000000000000000000000000..1bb75de7d242d34315238f2c43baeb0c016dbdfb --- /dev/null +++ b/crates/ai2/src/prompts/repository_context.rs @@ -0,0 +1,98 @@ +use crate::prompts::base::{PromptArguments, PromptTemplate}; +use std::fmt::Write; +use std::{ops::Range, path::PathBuf}; + +use gpui2::{AsyncAppContext, Model}; +use language2::{Anchor, Buffer}; + +#[derive(Clone)] +pub struct PromptCodeSnippet { + path: Option, + language_name: Option, + content: String, +} + +impl PromptCodeSnippet { + pub fn new( + buffer: Model, + range: Range, + cx: &mut AsyncAppContext, + ) -> anyhow::Result { + let (content, language_name, file_path) = buffer.update(cx, |buffer, _| { + let snapshot = buffer.snapshot(); + let content = snapshot.text_for_range(range.clone()).collect::(); + + let language_name = buffer + .language() + .and_then(|language| Some(language.name().to_string().to_lowercase())); + + let file_path = buffer + .file() + .and_then(|file| Some(file.path().to_path_buf())); + + (content, language_name, file_path) + })?; + + anyhow::Ok(PromptCodeSnippet { + path: file_path, + language_name, + content, + }) + } +} + +impl ToString for PromptCodeSnippet { + fn to_string(&self) -> String { + let path = self + .path + .as_ref() + .and_then(|path| Some(path.to_string_lossy().to_string())) + .unwrap_or("".to_string()); + let language_name = self.language_name.clone().unwrap_or("".to_string()); + let content = self.content.clone(); + + format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```") + } +} + +pub struct RepositoryContext {} + +impl PromptTemplate for RepositoryContext { + fn generate( + &self, + args: &PromptArguments, + max_token_length: Option, + ) -> anyhow::Result<(String, usize)> { + const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500; + let template = "You are working inside a large repository, here are a few code snippets that may be useful."; + let mut prompt = String::new(); + + let mut remaining_tokens = max_token_length.clone(); + let seperator_token_length = args.model.count_tokens("\n")?; + for snippet in &args.snippets { + let mut snippet_prompt = template.to_string(); + let content = snippet.to_string(); + writeln!(snippet_prompt, "{content}").unwrap(); + + let token_count = args.model.count_tokens(&snippet_prompt)?; + if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT { + if let Some(tokens_left) = remaining_tokens { + if tokens_left >= token_count { + writeln!(prompt, "{snippet_prompt}").unwrap(); + remaining_tokens = if tokens_left >= (token_count + seperator_token_length) + { + Some(tokens_left - token_count - seperator_token_length) + } else { + Some(0) + }; + } + } else { + writeln!(prompt, "{snippet_prompt}").unwrap(); + } + } + } + + let total_token_count = args.model.count_tokens(&prompt)?; + anyhow::Ok((prompt, total_token_count)) + } +} diff --git a/crates/ai2/src/providers/mod.rs b/crates/ai2/src/providers/mod.rs new file mode 100644 index 
0000000000000000000000000000000000000000..acd0f9d91053869e3999ef0c1a23326480a7cbdd --- /dev/null +++ b/crates/ai2/src/providers/mod.rs @@ -0,0 +1 @@ +pub mod open_ai; diff --git a/crates/ai2/src/providers/open_ai/completion.rs b/crates/ai2/src/providers/open_ai/completion.rs new file mode 100644 index 0000000000000000000000000000000000000000..eca56110271a3be407a1c8b9f82cbb63c41bef23 --- /dev/null +++ b/crates/ai2/src/providers/open_ai/completion.rs @@ -0,0 +1,306 @@ +use anyhow::{anyhow, Result}; +use async_trait::async_trait; +use futures::{ + future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt, + Stream, StreamExt, +}; +use gpui2::{AppContext, Executor}; +use isahc::{http::StatusCode, Request, RequestExt}; +use parking_lot::RwLock; +use serde::{Deserialize, Serialize}; +use std::{ + env, + fmt::{self, Display}, + io, + sync::Arc, +}; +use util::ResultExt; + +use crate::{ + auth::{CredentialProvider, ProviderCredential}, + completion::{CompletionProvider, CompletionRequest}, + models::LanguageModel, +}; + +use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL}; + +#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum Role { + User, + Assistant, + System, +} + +impl Role { + pub fn cycle(&mut self) { + *self = match self { + Role::User => Role::Assistant, + Role::Assistant => Role::System, + Role::System => Role::User, + } + } +} + +impl Display for Role { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Role::User => write!(f, "User"), + Role::Assistant => write!(f, "Assistant"), + Role::System => write!(f, "System"), + } + } +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct RequestMessage { + pub role: Role, + pub content: String, +} + +#[derive(Debug, Default, Serialize)] +pub struct OpenAIRequest { + pub model: String, + pub messages: Vec, + pub stream: bool, + pub stop: Vec, + pub temperature: f32, +} + +impl CompletionRequest for OpenAIRequest { + fn data(&self) -> serde_json::Result { + serde_json::to_string(self) + } +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct ResponseMessage { + pub role: Option, + pub content: Option, +} + +#[derive(Deserialize, Debug)] +pub struct OpenAIUsage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, +} + +#[derive(Deserialize, Debug)] +pub struct ChatChoiceDelta { + pub index: u32, + pub delta: ResponseMessage, + pub finish_reason: Option, +} + +#[derive(Deserialize, Debug)] +pub struct OpenAIResponseStreamEvent { + pub id: Option, + pub object: String, + pub created: u32, + pub model: String, + pub choices: Vec, + pub usage: Option, +} + +pub async fn stream_completion( + credential: ProviderCredential, + executor: Arc, + request: Box, +) -> Result>> { + let api_key = match credential { + ProviderCredential::Credentials { api_key } => api_key, + _ => { + return Err(anyhow!("no credentials provider for completion")); + } + }; + + let (tx, rx) = futures::channel::mpsc::unbounded::>(); + + let json_data = request.data()?; + let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions")) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key)) + .body(json_data)? 
+ .send_async() + .await?; + + let status = response.status(); + if status == StatusCode::OK { + executor + .spawn(async move { + let mut lines = BufReader::new(response.body_mut()).lines(); + + fn parse_line( + line: Result, + ) -> Result> { + if let Some(data) = line?.strip_prefix("data: ") { + let event = serde_json::from_str(&data)?; + Ok(Some(event)) + } else { + Ok(None) + } + } + + while let Some(line) = lines.next().await { + if let Some(event) = parse_line(line).transpose() { + let done = event.as_ref().map_or(false, |event| { + event + .choices + .last() + .map_or(false, |choice| choice.finish_reason.is_some()) + }); + if tx.unbounded_send(event).is_err() { + break; + } + + if done { + break; + } + } + } + + anyhow::Ok(()) + }) + .detach(); + + Ok(rx) + } else { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + #[derive(Deserialize)] + struct OpenAIResponse { + error: OpenAIError, + } + + #[derive(Deserialize)] + struct OpenAIError { + message: String, + } + + match serde_json::from_str::(&body) { + Ok(response) if !response.error.message.is_empty() => Err(anyhow!( + "Failed to connect to OpenAI API: {}", + response.error.message, + )), + + _ => Err(anyhow!( + "Failed to connect to OpenAI API: {} {}", + response.status(), + body, + )), + } + } +} + +#[derive(Clone)] +pub struct OpenAICompletionProvider { + model: OpenAILanguageModel, + credential: Arc>, + executor: Arc, +} + +impl OpenAICompletionProvider { + pub fn new(model_name: &str, executor: Arc) -> Self { + let model = OpenAILanguageModel::load(model_name); + let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials)); + Self { + model, + credential, + executor, + } + } +} + +#[async_trait] +impl CredentialProvider for OpenAICompletionProvider { + fn has_credentials(&self) -> bool { + match *self.credential.read() { + ProviderCredential::Credentials { .. } => true, + _ => false, + } + } + async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential { + let existing_credential = self.credential.read().clone(); + + let retrieved_credential = cx + .run_on_main(move |cx| match existing_credential { + ProviderCredential::Credentials { .. 
} => { + return existing_credential.clone(); + } + _ => { + if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() { + return ProviderCredential::Credentials { api_key }; + } + + if let Some(Some((_, api_key))) = cx.read_credentials(OPENAI_API_URL).log_err() + { + if let Some(api_key) = String::from_utf8(api_key).log_err() { + return ProviderCredential::Credentials { api_key }; + } else { + return ProviderCredential::NoCredentials; + } + } else { + return ProviderCredential::NoCredentials; + } + } + }) + .await; + + *self.credential.write() = retrieved_credential.clone(); + retrieved_credential + } + + async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) { + *self.credential.write() = credential.clone(); + let credential = credential.clone(); + cx.run_on_main(move |cx| match credential { + ProviderCredential::Credentials { api_key } => { + cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()) + .log_err(); + } + _ => {} + }) + .await; + } + async fn delete_credentials(&self, cx: &mut AppContext) { + cx.run_on_main(move |cx| cx.delete_credentials(OPENAI_API_URL).log_err()) + .await; + *self.credential.write() = ProviderCredential::NoCredentials; + } +} + +impl CompletionProvider for OpenAICompletionProvider { + fn base_model(&self) -> Box { + let model: Box = Box::new(self.model.clone()); + model + } + fn complete( + &self, + prompt: Box, + ) -> BoxFuture<'static, Result>>> { + // Currently the CompletionRequest for OpenAI, includes a 'model' parameter + // This means that the model is determined by the CompletionRequest and not the CompletionProvider, + // which is currently model based, due to the langauge model. + // At some point in the future we should rectify this. + let credential = self.credential.read().clone(); + let request = stream_completion(credential, self.executor.clone(), prompt); + async move { + let response = request.await?; + let stream = response + .filter_map(|response| async move { + match response { + Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)), + Err(error) => Some(Err(error)), + } + }) + .boxed(); + Ok(stream) + } + .boxed() + } + fn box_clone(&self) -> Box { + Box::new((*self).clone()) + } +} diff --git a/crates/ai2/src/providers/open_ai/embedding.rs b/crates/ai2/src/providers/open_ai/embedding.rs new file mode 100644 index 0000000000000000000000000000000000000000..fc49c15134d0aba787968acbd412daff60ce6106 --- /dev/null +++ b/crates/ai2/src/providers/open_ai/embedding.rs @@ -0,0 +1,313 @@ +use anyhow::{anyhow, Result}; +use async_trait::async_trait; +use futures::AsyncReadExt; +use gpui2::Executor; +use gpui2::{serde_json, AppContext}; +use isahc::http::StatusCode; +use isahc::prelude::Configurable; +use isahc::{AsyncBody, Response}; +use lazy_static::lazy_static; +use parking_lot::{Mutex, RwLock}; +use parse_duration::parse; +use postage::watch; +use serde::{Deserialize, Serialize}; +use std::env; +use std::ops::Add; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tiktoken_rs::{cl100k_base, CoreBPE}; +use util::http::{HttpClient, Request}; +use util::ResultExt; + +use crate::auth::{CredentialProvider, ProviderCredential}; +use crate::embedding::{Embedding, EmbeddingProvider}; +use crate::models::LanguageModel; +use crate::providers::open_ai::OpenAILanguageModel; + +use crate::providers::open_ai::OPENAI_API_URL; + +lazy_static! 
{ + static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap(); +} + +#[derive(Clone)] +pub struct OpenAIEmbeddingProvider { + model: OpenAILanguageModel, + credential: Arc>, + pub client: Arc, + pub executor: Arc, + rate_limit_count_rx: watch::Receiver>, + rate_limit_count_tx: Arc>>>, +} + +#[derive(Serialize)] +struct OpenAIEmbeddingRequest<'a> { + model: &'static str, + input: Vec<&'a str>, +} + +#[derive(Deserialize)] +struct OpenAIEmbeddingResponse { + data: Vec, + usage: OpenAIEmbeddingUsage, +} + +#[derive(Debug, Deserialize)] +struct OpenAIEmbedding { + embedding: Vec, + index: usize, + object: String, +} + +#[derive(Deserialize)] +struct OpenAIEmbeddingUsage { + prompt_tokens: usize, + total_tokens: usize, +} + +impl OpenAIEmbeddingProvider { + pub fn new(client: Arc, executor: Arc) -> Self { + let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None); + let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx)); + + let model = OpenAILanguageModel::load("text-embedding-ada-002"); + let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials)); + + OpenAIEmbeddingProvider { + model, + credential, + client, + executor, + rate_limit_count_rx, + rate_limit_count_tx, + } + } + + fn get_api_key(&self) -> Result { + match self.credential.read().clone() { + ProviderCredential::Credentials { api_key } => Ok(api_key), + _ => Err(anyhow!("api credentials not provided")), + } + } + + fn resolve_rate_limit(&self) { + let reset_time = *self.rate_limit_count_tx.lock().borrow(); + + if let Some(reset_time) = reset_time { + if Instant::now() >= reset_time { + *self.rate_limit_count_tx.lock().borrow_mut() = None + } + } + + log::trace!( + "resolving reset time: {:?}", + *self.rate_limit_count_tx.lock().borrow() + ); + } + + fn update_reset_time(&self, reset_time: Instant) { + let original_time = *self.rate_limit_count_tx.lock().borrow(); + + let updated_time = if let Some(original_time) = original_time { + if reset_time < original_time { + Some(reset_time) + } else { + Some(original_time) + } + } else { + Some(reset_time) + }; + + log::trace!("updating rate limit time: {:?}", updated_time); + + *self.rate_limit_count_tx.lock().borrow_mut() = updated_time; + } + async fn send_request( + &self, + api_key: &str, + spans: Vec<&str>, + request_timeout: u64, + ) -> Result> { + let request = Request::post("https://api.openai.com/v1/embeddings") + .redirect_policy(isahc::config::RedirectPolicy::Follow) + .timeout(Duration::from_secs(request_timeout)) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key)) + .body( + serde_json::to_string(&OpenAIEmbeddingRequest { + input: spans.clone(), + model: "text-embedding-ada-002", + }) + .unwrap() + .into(), + )?; + + Ok(self.client.send(request).await?) + } +} + +#[async_trait] +impl CredentialProvider for OpenAIEmbeddingProvider { + fn has_credentials(&self) -> bool { + match *self.credential.read() { + ProviderCredential::Credentials { .. } => true, + _ => false, + } + } + async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential { + let existing_credential = self.credential.read().clone(); + + let retrieved_credential = cx + .run_on_main(move |cx| match existing_credential { + ProviderCredential::Credentials { .. 
} => { + return existing_credential.clone(); + } + _ => { + if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() { + return ProviderCredential::Credentials { api_key }; + } + + if let Some(Some((_, api_key))) = cx.read_credentials(OPENAI_API_URL).log_err() + { + if let Some(api_key) = String::from_utf8(api_key).log_err() { + return ProviderCredential::Credentials { api_key }; + } else { + return ProviderCredential::NoCredentials; + } + } else { + return ProviderCredential::NoCredentials; + } + } + }) + .await; + + *self.credential.write() = retrieved_credential.clone(); + retrieved_credential + } + + async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) { + *self.credential.write() = credential.clone(); + let credential = credential.clone(); + cx.run_on_main(move |cx| match credential { + ProviderCredential::Credentials { api_key } => { + cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()) + .log_err(); + } + _ => {} + }) + .await; + } + async fn delete_credentials(&self, cx: &mut AppContext) { + cx.run_on_main(move |cx| cx.delete_credentials(OPENAI_API_URL).log_err()) + .await; + *self.credential.write() = ProviderCredential::NoCredentials; + } +} + +#[async_trait] +impl EmbeddingProvider for OpenAIEmbeddingProvider { + fn base_model(&self) -> Box { + let model: Box = Box::new(self.model.clone()); + model + } + + fn max_tokens_per_batch(&self) -> usize { + 50000 + } + + fn rate_limit_expiration(&self) -> Option { + *self.rate_limit_count_rx.borrow() + } + + async fn embed_batch(&self, spans: Vec) -> Result> { + const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45]; + const MAX_RETRIES: usize = 4; + + let api_key = self.get_api_key()?; + + let mut request_number = 0; + let mut rate_limiting = false; + let mut request_timeout: u64 = 15; + let mut response: Response; + while request_number < MAX_RETRIES { + response = self + .send_request( + &api_key, + spans.iter().map(|x| &**x).collect(), + request_timeout, + ) + .await?; + + request_number += 1; + + match response.status() { + StatusCode::REQUEST_TIMEOUT => { + request_timeout += 5; + } + StatusCode::OK => { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?; + + log::trace!( + "openai embedding completed. 
tokens: {:?}", + response.usage.total_tokens + ); + + // If we complete a request successfully that was previously rate_limited + // resolve the rate limit + if rate_limiting { + self.resolve_rate_limit() + } + + return Ok(response + .data + .into_iter() + .map(|embedding| Embedding::from(embedding.embedding)) + .collect()); + } + StatusCode::TOO_MANY_REQUESTS => { + rate_limiting = true; + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + let delay_duration = { + let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64); + if let Some(time_to_reset) = + response.headers().get("x-ratelimit-reset-tokens") + { + if let Ok(time_str) = time_to_reset.to_str() { + parse(time_str).unwrap_or(delay) + } else { + delay + } + } else { + delay + } + }; + + // If we've previously rate limited, increment the duration but not the count + let reset_time = Instant::now().add(delay_duration); + self.update_reset_time(reset_time); + + log::trace!( + "openai rate limiting: waiting {:?} until lifted", + &delay_duration + ); + + self.executor.timer(delay_duration).await; + } + _ => { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + return Err(anyhow!( + "open ai bad request: {:?} {:?}", + &response.status(), + body + )); + } + } + } + Err(anyhow!("openai max retries")) + } +} diff --git a/crates/ai2/src/providers/open_ai/mod.rs b/crates/ai2/src/providers/open_ai/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..7d2f86045d9e1049c55111aef175ac9b56dc7e16 --- /dev/null +++ b/crates/ai2/src/providers/open_ai/mod.rs @@ -0,0 +1,9 @@ +pub mod completion; +pub mod embedding; +pub mod model; + +pub use completion::*; +pub use embedding::*; +pub use model::OpenAILanguageModel; + +pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; diff --git a/crates/ai2/src/providers/open_ai/model.rs b/crates/ai2/src/providers/open_ai/model.rs new file mode 100644 index 0000000000000000000000000000000000000000..6e306c80b905865c011c9064934827085ca126d6 --- /dev/null +++ b/crates/ai2/src/providers/open_ai/model.rs @@ -0,0 +1,57 @@ +use anyhow::anyhow; +use tiktoken_rs::CoreBPE; +use util::ResultExt; + +use crate::models::{LanguageModel, TruncationDirection}; + +#[derive(Clone)] +pub struct OpenAILanguageModel { + name: String, + bpe: Option, +} + +impl OpenAILanguageModel { + pub fn load(model_name: &str) -> Self { + let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err(); + OpenAILanguageModel { + name: model_name.to_string(), + bpe, + } + } +} + +impl LanguageModel for OpenAILanguageModel { + fn name(&self) -> String { + self.name.clone() + } + fn count_tokens(&self, content: &str) -> anyhow::Result { + if let Some(bpe) = &self.bpe { + anyhow::Ok(bpe.encode_with_special_tokens(content).len()) + } else { + Err(anyhow!("bpe for open ai model was not retrieved")) + } + } + fn truncate( + &self, + content: &str, + length: usize, + direction: TruncationDirection, + ) -> anyhow::Result { + if let Some(bpe) = &self.bpe { + let tokens = bpe.encode_with_special_tokens(content); + if tokens.len() > length { + match direction { + TruncationDirection::End => bpe.decode(tokens[..length].to_vec()), + TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()), + } + } else { + bpe.decode(tokens) + } + } else { + Err(anyhow!("bpe for open ai model was not retrieved")) + } + } + fn capacity(&self) -> anyhow::Result { + anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name)) + } +} diff --git 
a/crates/ai2/src/providers/open_ai/new.rs b/crates/ai2/src/providers/open_ai/new.rs new file mode 100644 index 0000000000000000000000000000000000000000..c7d67f2ba1d252a6865124d8ffdfb79130a8c3a0 --- /dev/null +++ b/crates/ai2/src/providers/open_ai/new.rs @@ -0,0 +1,11 @@ +pub trait LanguageModel { + fn name(&self) -> String; + fn count_tokens(&self, content: &str) -> anyhow::Result; + fn truncate( + &self, + content: &str, + length: usize, + direction: TruncationDirection, + ) -> anyhow::Result; + fn capacity(&self) -> anyhow::Result; +} diff --git a/crates/ai2/src/test.rs b/crates/ai2/src/test.rs new file mode 100644 index 0000000000000000000000000000000000000000..ee88529aecb004ce3b725fb61abd679359673404 --- /dev/null +++ b/crates/ai2/src/test.rs @@ -0,0 +1,193 @@ +use std::{ + sync::atomic::{self, AtomicUsize, Ordering}, + time::Instant, +}; + +use async_trait::async_trait; +use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; +use gpui2::AppContext; +use parking_lot::Mutex; + +use crate::{ + auth::{CredentialProvider, ProviderCredential}, + completion::{CompletionProvider, CompletionRequest}, + embedding::{Embedding, EmbeddingProvider}, + models::{LanguageModel, TruncationDirection}, +}; + +#[derive(Clone)] +pub struct FakeLanguageModel { + pub capacity: usize, +} + +impl LanguageModel for FakeLanguageModel { + fn name(&self) -> String { + "dummy".to_string() + } + fn count_tokens(&self, content: &str) -> anyhow::Result { + anyhow::Ok(content.chars().collect::>().len()) + } + fn truncate( + &self, + content: &str, + length: usize, + direction: TruncationDirection, + ) -> anyhow::Result { + println!("TRYING TO TRUNCATE: {:?}", length.clone()); + + if length > self.count_tokens(content)? { + println!("NOT TRUNCATING"); + return anyhow::Ok(content.to_string()); + } + + anyhow::Ok(match direction { + TruncationDirection::End => content.chars().collect::>()[..length] + .into_iter() + .collect::(), + TruncationDirection::Start => content.chars().collect::>()[length..] 
+ .into_iter() + .collect::(), + }) + } + fn capacity(&self) -> anyhow::Result { + anyhow::Ok(self.capacity) + } +} + +pub struct FakeEmbeddingProvider { + pub embedding_count: AtomicUsize, +} + +impl Clone for FakeEmbeddingProvider { + fn clone(&self) -> Self { + FakeEmbeddingProvider { + embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)), + } + } +} + +impl Default for FakeEmbeddingProvider { + fn default() -> Self { + FakeEmbeddingProvider { + embedding_count: AtomicUsize::default(), + } + } +} + +impl FakeEmbeddingProvider { + pub fn embedding_count(&self) -> usize { + self.embedding_count.load(atomic::Ordering::SeqCst) + } + + pub fn embed_sync(&self, span: &str) -> Embedding { + let mut result = vec![1.0; 26]; + for letter in span.chars() { + let letter = letter.to_ascii_lowercase(); + if letter as u32 >= 'a' as u32 { + let ix = (letter as u32) - ('a' as u32); + if ix < 26 { + result[ix as usize] += 1.0; + } + } + } + + let norm = result.iter().map(|x| x * x).sum::().sqrt(); + for x in &mut result { + *x /= norm; + } + + result.into() + } +} + +#[async_trait] +impl CredentialProvider for FakeEmbeddingProvider { + fn has_credentials(&self) -> bool { + true + } + async fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential { + ProviderCredential::NotNeeded + } + async fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {} + async fn delete_credentials(&self, _cx: &mut AppContext) {} +} + +#[async_trait] +impl EmbeddingProvider for FakeEmbeddingProvider { + fn base_model(&self) -> Box { + Box::new(FakeLanguageModel { capacity: 1000 }) + } + fn max_tokens_per_batch(&self) -> usize { + 1000 + } + + fn rate_limit_expiration(&self) -> Option { + None + } + + async fn embed_batch(&self, spans: Vec) -> anyhow::Result> { + self.embedding_count + .fetch_add(spans.len(), atomic::Ordering::SeqCst); + + anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect()) + } +} + +pub struct FakeCompletionProvider { + last_completion_tx: Mutex>>, +} + +impl Clone for FakeCompletionProvider { + fn clone(&self) -> Self { + Self { + last_completion_tx: Mutex::new(None), + } + } +} + +impl FakeCompletionProvider { + pub fn new() -> Self { + Self { + last_completion_tx: Mutex::new(None), + } + } + + pub fn send_completion(&self, completion: impl Into) { + let mut tx = self.last_completion_tx.lock(); + tx.as_mut().unwrap().try_send(completion.into()).unwrap(); + } + + pub fn finish_completion(&self) { + self.last_completion_tx.lock().take().unwrap(); + } +} + +#[async_trait] +impl CredentialProvider for FakeCompletionProvider { + fn has_credentials(&self) -> bool { + true + } + async fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential { + ProviderCredential::NotNeeded + } + async fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {} + async fn delete_credentials(&self, _cx: &mut AppContext) {} +} + +impl CompletionProvider for FakeCompletionProvider { + fn base_model(&self) -> Box { + let model: Box = Box::new(FakeLanguageModel { capacity: 8190 }); + model + } + fn complete( + &self, + _prompt: Box, + ) -> BoxFuture<'static, anyhow::Result>>> { + let (tx, rx) = mpsc::channel(1); + *self.last_completion_tx.lock() = Some(tx); + async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed() + } + fn box_clone(&self) -> Box { + Box::new((*self).clone()) + } +} diff --git a/crates/assistant/Cargo.toml b/crates/assistant/Cargo.toml index 
256f4d841603f14d97b2f663e5a3e39d7ea6f0d3..fc885f6b36248d67bd2ce288c6f5d2931d8540da 100644 --- a/crates/assistant/Cargo.toml +++ b/crates/assistant/Cargo.toml @@ -45,6 +45,7 @@ tiktoken-rs = "0.5" [dev-dependencies] editor = { path = "../editor", features = ["test-support"] } project = { path = "../project", features = ["test-support"] } +ai = { path = "../ai", features = ["test-support"]} ctor.workspace = true env_logger.workspace = true diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index 6c9b14333e34cbf5fd49d8299ba7bd891b607526..91d61a19f98b4cf4a61257a68d8f7212c6a33586 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -4,7 +4,7 @@ mod codegen; mod prompts; mod streaming_diff; -use ai::completion::Role; +use ai::providers::open_ai::Role; use anyhow::Result; pub use assistant_panel::AssistantPanel; use assistant_settings::OpenAIModel; diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 0dee8be510cb329025f8aeefa5b4185efa250b04..03eb3c238f98d55e7bd40793dafa931244bcc073 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -5,12 +5,14 @@ use crate::{ MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata, SavedMessage, }; + use ai::{ - completion::{ - stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL, - }, - templates::repository_context::PromptCodeSnippet, + auth::ProviderCredential, + completion::{CompletionProvider, CompletionRequest}, + providers::open_ai::{OpenAICompletionProvider, OpenAIRequest, RequestMessage}, }; + +use ai::prompts::repository_context::PromptCodeSnippet; use anyhow::{anyhow, Result}; use chrono::{DateTime, Local}; use client::{telemetry::AssistantKind, ClickhouseEvent, TelemetrySettings}; @@ -43,8 +45,8 @@ use search::BufferSearchBar; use semantic_index::{SemanticIndex, SemanticIndexStatus}; use settings::SettingsStore; use std::{ - cell::{Cell, RefCell}, - cmp, env, + cell::Cell, + cmp, fmt::Write, iter, ops::Range, @@ -97,8 +99,8 @@ pub fn init(cx: &mut AppContext) { cx.capture_action(ConversationEditor::copy); cx.add_action(ConversationEditor::split); cx.capture_action(ConversationEditor::cycle_message_role); - cx.add_action(AssistantPanel::save_api_key); - cx.add_action(AssistantPanel::reset_api_key); + cx.add_action(AssistantPanel::save_credentials); + cx.add_action(AssistantPanel::reset_credentials); cx.add_action(AssistantPanel::toggle_zoom); cx.add_action(AssistantPanel::deploy); cx.add_action(AssistantPanel::select_next_match); @@ -140,9 +142,8 @@ pub struct AssistantPanel { zoomed: bool, has_focus: bool, toolbar: ViewHandle, - api_key: Rc>>, + completion_provider: Box, api_key_editor: Option>, - has_read_credentials: bool, languages: Arc, fs: Arc, subscriptions: Vec, @@ -202,6 +203,11 @@ impl AssistantPanel { }); let semantic_index = SemanticIndex::global(cx); + // Defaulting currently to GPT4, allow for this to be set via config. 
+ let completion_provider = Box::new(OpenAICompletionProvider::new( + "gpt-4", + cx.background().clone(), + )); let mut this = Self { workspace: workspace_handle, @@ -213,9 +219,8 @@ impl AssistantPanel { zoomed: false, has_focus: false, toolbar, - api_key: Rc::new(RefCell::new(None)), + completion_provider, api_key_editor: None, - has_read_credentials: false, languages: workspace.app_state().languages.clone(), fs: workspace.app_state().fs.clone(), width: None, @@ -254,10 +259,7 @@ impl AssistantPanel { cx: &mut ViewContext, ) { let this = if let Some(this) = workspace.panel::(cx) { - if this - .update(cx, |assistant, cx| assistant.load_api_key(cx)) - .is_some() - { + if this.update(cx, |assistant, _| assistant.has_credentials()) { this } else { workspace.focus_panel::(cx); @@ -289,12 +291,6 @@ impl AssistantPanel { cx: &mut ViewContext, project: &ModelHandle, ) { - let api_key = if let Some(api_key) = self.api_key.borrow().clone() { - api_key - } else { - return; - }; - let selection = editor.read(cx).selections.newest_anchor().clone(); if selection.start.excerpt_id != selection.end.excerpt_id { return; @@ -325,10 +321,13 @@ impl AssistantPanel { let inline_assist_id = post_inc(&mut self.next_inline_assist_id); let provider = Arc::new(OpenAICompletionProvider::new( - api_key, + "gpt-4", cx.background().clone(), )); + // Retrieve Credentials Authenticates the Provider + // provider.retrieve_credentials(cx); + let codegen = cx.add_model(|cx| { Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx) }); @@ -745,13 +744,14 @@ impl AssistantPanel { content: prompt, }); - let request = OpenAIRequest { + let request = Box::new(OpenAIRequest { model: model.full_name().into(), messages, stream: true, stop: vec!["|END|>".to_string()], temperature, - }; + }); + codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx)); anyhow::Ok(()) }) @@ -811,7 +811,7 @@ impl AssistantPanel { fn new_conversation(&mut self, cx: &mut ViewContext) -> ViewHandle { let editor = cx.add_view(|cx| { ConversationEditor::new( - self.api_key.clone(), + self.completion_provider.clone(), self.languages.clone(), self.fs.clone(), self.workspace.clone(), @@ -870,17 +870,19 @@ impl AssistantPanel { } } - fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext) { + fn save_credentials(&mut self, _: &menu::Confirm, cx: &mut ViewContext) { if let Some(api_key) = self .api_key_editor .as_ref() .map(|editor| editor.read(cx).text(cx)) { if !api_key.is_empty() { - cx.platform() - .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()) - .log_err(); - *self.api_key.borrow_mut() = Some(api_key); + let credential = ProviderCredential::Credentials { + api_key: api_key.clone(), + }; + + self.completion_provider.save_credentials(cx, credential); + self.api_key_editor.take(); cx.focus_self(); cx.notify(); @@ -890,9 +892,8 @@ impl AssistantPanel { } } - fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext) { - cx.platform().delete_credentials(OPENAI_API_URL).log_err(); - self.api_key.take(); + fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext) { + self.completion_provider.delete_credentials(cx); self.api_key_editor = Some(build_api_key_editor(cx)); cx.focus_self(); cx.notify(); @@ -1151,13 +1152,12 @@ impl AssistantPanel { let fs = self.fs.clone(); let workspace = self.workspace.clone(); - let api_key = self.api_key.clone(); let languages = self.languages.clone(); cx.spawn(|this, mut cx| async move { let saved_conversation = fs.load(&path).await?; let saved_conversation = 
serde_json::from_str(&saved_conversation)?; let conversation = cx.add_model(|cx| { - Conversation::deserialize(saved_conversation, path.clone(), api_key, languages, cx) + Conversation::deserialize(saved_conversation, path.clone(), languages, cx) }); this.update(&mut cx, |this, cx| { // If, by the time we've loaded the conversation, the user has already opened @@ -1181,30 +1181,12 @@ impl AssistantPanel { .position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path)) } - fn load_api_key(&mut self, cx: &mut ViewContext) -> Option { - if self.api_key.borrow().is_none() && !self.has_read_credentials { - self.has_read_credentials = true; - let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") { - Some(api_key) - } else if let Some((_, api_key)) = cx - .platform() - .read_credentials(OPENAI_API_URL) - .log_err() - .flatten() - { - String::from_utf8(api_key).log_err() - } else { - None - }; - if let Some(api_key) = api_key { - *self.api_key.borrow_mut() = Some(api_key); - } else if self.api_key_editor.is_none() { - self.api_key_editor = Some(build_api_key_editor(cx)); - cx.notify(); - } - } + fn has_credentials(&mut self) -> bool { + self.completion_provider.has_credentials() + } - self.api_key.borrow().clone() + fn load_credentials(&mut self, cx: &mut ViewContext) { + self.completion_provider.retrieve_credentials(cx); } } @@ -1389,7 +1371,7 @@ impl Panel for AssistantPanel { fn set_active(&mut self, active: bool, cx: &mut ViewContext) { if active { - self.load_api_key(cx); + self.load_credentials(cx); if self.editors.is_empty() { self.new_conversation(cx); @@ -1454,10 +1436,10 @@ struct Conversation { token_count: Option, max_token_count: usize, pending_token_count: Task>, - api_key: Rc>>, pending_save: Task>, path: Option, _subscriptions: Vec, + completion_provider: Box, } impl Entity for Conversation { @@ -1466,9 +1448,9 @@ impl Entity for Conversation { impl Conversation { fn new( - api_key: Rc>>, language_registry: Arc, cx: &mut ModelContext, + completion_provider: Box, ) -> Self { let markdown = language_registry.language_for_name("Markdown"); let buffer = cx.add_model(|cx| { @@ -1507,8 +1489,8 @@ impl Conversation { _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)], pending_save: Task::ready(Ok(())), path: None, - api_key, buffer, + completion_provider, }; let message = MessageAnchor { id: MessageId(post_inc(&mut this.next_message_id.0)), @@ -1554,7 +1536,6 @@ impl Conversation { fn deserialize( saved_conversation: SavedConversation, path: PathBuf, - api_key: Rc>>, language_registry: Arc, cx: &mut ModelContext, ) -> Self { @@ -1563,6 +1544,10 @@ impl Conversation { None => Some(Uuid::new_v4().to_string()), }; let model = saved_conversation.model; + let completion_provider: Box = Box::new( + OpenAICompletionProvider::new(model.full_name(), cx.background().clone()), + ); + completion_provider.retrieve_credentials(cx); let markdown = language_registry.language_for_name("Markdown"); let mut message_anchors = Vec::new(); let mut next_message_id = MessageId(0); @@ -1609,8 +1594,8 @@ impl Conversation { _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)], pending_save: Task::ready(Ok(())), path: Some(path), - api_key, buffer, + completion_provider, }; this.count_remaining_tokens(cx); this @@ -1731,11 +1716,11 @@ impl Conversation { } if should_assist { - let Some(api_key) = self.api_key.borrow().clone() else { + if !self.completion_provider.has_credentials() { return Default::default(); - }; + } - let request = OpenAIRequest 
{ + let request: Box = Box::new(OpenAIRequest { model: self.model.full_name().to_string(), messages: self .messages(cx) @@ -1745,9 +1730,9 @@ impl Conversation { stream: true, stop: vec![], temperature: 1.0, - }; + }); - let stream = stream_completion(api_key, cx.background().clone(), request); + let stream = self.completion_provider.complete(request); let assistant_message = self .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx) .unwrap(); @@ -1765,33 +1750,28 @@ impl Conversation { let mut messages = stream.await?; while let Some(message) = messages.next().await { - let mut message = message?; - if let Some(choice) = message.choices.pop() { - this.upgrade(&cx) - .ok_or_else(|| anyhow!("conversation was dropped"))? - .update(&mut cx, |this, cx| { - let text: Arc = choice.delta.content?.into(); - let message_ix = - this.message_anchors.iter().position(|message| { - message.id == assistant_message_id - })?; - this.buffer.update(cx, |buffer, cx| { - let offset = this.message_anchors[message_ix + 1..] - .iter() - .find(|message| message.start.is_valid(buffer)) - .map_or(buffer.len(), |message| { - message - .start - .to_offset(buffer) - .saturating_sub(1) - }); - buffer.edit([(offset..offset, text)], None, cx); - }); - cx.emit(ConversationEvent::StreamedCompletion); - - Some(()) + let text = message?; + + this.upgrade(&cx) + .ok_or_else(|| anyhow!("conversation was dropped"))? + .update(&mut cx, |this, cx| { + let message_ix = this + .message_anchors + .iter() + .position(|message| message.id == assistant_message_id)?; + this.buffer.update(cx, |buffer, cx| { + let offset = this.message_anchors[message_ix + 1..] + .iter() + .find(|message| message.start.is_valid(buffer)) + .map_or(buffer.len(), |message| { + message.start.to_offset(buffer).saturating_sub(1) + }); + buffer.edit([(offset..offset, text)], None, cx); }); - } + cx.emit(ConversationEvent::StreamedCompletion); + + Some(()) + }); smol::future::yield_now().await; } @@ -2013,57 +1993,54 @@ impl Conversation { fn summarize(&mut self, cx: &mut ModelContext) { if self.message_anchors.len() >= 2 && self.summary.is_none() { - let api_key = self.api_key.borrow().clone(); - if let Some(api_key) = api_key { - let messages = self - .messages(cx) - .take(2) - .map(|message| message.to_open_ai_message(self.buffer.read(cx))) - .chain(Some(RequestMessage { - role: Role::User, - content: - "Summarize the conversation into a short title without punctuation" - .into(), - })); - let request = OpenAIRequest { - model: self.model.full_name().to_string(), - messages: messages.collect(), - stream: true, - stop: vec![], - temperature: 1.0, - }; + if !self.completion_provider.has_credentials() { + return; + } - let stream = stream_completion(api_key, cx.background().clone(), request); - self.pending_summary = cx.spawn(|this, mut cx| { - async move { - let mut messages = stream.await?; + let messages = self + .messages(cx) + .take(2) + .map(|message| message.to_open_ai_message(self.buffer.read(cx))) + .chain(Some(RequestMessage { + role: Role::User, + content: "Summarize the conversation into a short title without punctuation" + .into(), + })); + let request: Box = Box::new(OpenAIRequest { + model: self.model.full_name().to_string(), + messages: messages.collect(), + stream: true, + stop: vec![], + temperature: 1.0, + }); - while let Some(message) = messages.next().await { - let mut message = message?; - if let Some(choice) = message.choices.pop() { - let text = choice.delta.content.unwrap_or_default(); - this.update(&mut cx, 
diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs
index 6b79daba420b121c5650e4b8f07e4cb800eeaaa0..0466259b24d9cb05c165ce8c3f6192a558922454 100644
--- a/crates/assistant/src/codegen.rs
+++ b/crates/assistant/src/codegen.rs
@@ -1,5 +1,5 @@
 use crate::streaming_diff::{Hunk, StreamingDiff};
-use ai::completion::{CompletionProvider, OpenAIRequest};
+use ai::completion::{CompletionProvider, CompletionRequest};
 use anyhow::Result;
 use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
 use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
@@ -96,7 +96,7 @@ impl Codegen {
         self.error.as_ref()
     }
 
-    pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
+    pub fn start(&mut self, prompt: Box<dyn CompletionRequest>, cx: &mut ModelContext<Self>) {
         let range = self.range();
         let snapshot = self.snapshot.clone();
         let selected_text = snapshot
@@ -336,17 +336,25 @@ fn strip_markdown_codeblock(
 #[cfg(test)]
 mod tests {
     use super::*;
-    use futures::{
-        future::BoxFuture,
-        stream::{self, BoxStream},
-    };
+    use ai::test::FakeCompletionProvider;
+    use futures::stream::{self};
     use gpui::{executor::Deterministic, TestAppContext};
     use indoc::indoc;
     use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
-    use parking_lot::Mutex;
     use rand::prelude::*;
+    use serde::Serialize;
     use settings::SettingsStore;
-    use smol::future::FutureExt;
+
+    #[derive(Serialize)]
+    pub struct DummyCompletionRequest {
+        pub name: String,
+    }
+
+    impl CompletionRequest for DummyCompletionRequest {
+        fn data(&self) -> serde_json::Result<String> {
+            serde_json::to_string(self)
+        }
+    }
 
     #[gpui::test(iterations = 10)]
     async fn test_transform_autoindent(
@@ -372,7 +380,7 @@ mod tests {
             let snapshot = buffer.snapshot(cx);
             snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
         });
-        let provider = Arc::new(TestCompletionProvider::new());
+        let provider = Arc::new(FakeCompletionProvider::new());
         let codegen = cx.add_model(|cx| {
             Codegen::new(
                 buffer.clone(),
@@ -381,7 +389,11 @@
                 cx,
             )
         });
-        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
+
+        let request = Box::new(DummyCompletionRequest {
+            name: "test".to_string(),
+        });
+        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
 
         let mut new_text = concat!(
             "    let mut x = 0;\n",
@@ -434,7 +446,7 @@
             let snapshot = buffer.snapshot(cx);
             snapshot.anchor_before(Point::new(1, 6))
         });
-        let provider = Arc::new(TestCompletionProvider::new());
+        let provider = Arc::new(FakeCompletionProvider::new());
         let codegen = cx.add_model(|cx| {
             Codegen::new(
                 buffer.clone(),
@@ -443,7 +455,11 @@
                 cx,
             )
         });
-        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
+
+        let request = Box::new(DummyCompletionRequest {
+            name: "test".to_string(),
+        });
+        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
 
         let mut new_text = concat!(
             "t mut x = 0;\n",
@@ -496,7 +512,7 @@
             let snapshot = buffer.snapshot(cx);
             snapshot.anchor_before(Point::new(1, 2))
         });
-        let provider = Arc::new(TestCompletionProvider::new());
+        let provider = Arc::new(FakeCompletionProvider::new());
        let codegen = cx.add_model(|cx| {
             Codegen::new(
                 buffer.clone(),
@@ -505,7 +521,11 @@
                 cx,
             )
         });
-        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
+
+        let request = Box::new(DummyCompletionRequest {
+            name: "test".to_string(),
+        });
+        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
 
         let mut new_text = concat!(
             "let mut x = 0;\n",
@@ -593,38 +613,6 @@
         }
     }
 
-    struct TestCompletionProvider {
-        last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
-    }
-
-    impl TestCompletionProvider {
-        fn new() -> Self {
-            Self {
-                last_completion_tx: Mutex::new(None),
-            }
-        }
-
-        fn send_completion(&self, completion: impl Into<String>) {
-            let mut tx = self.last_completion_tx.lock();
-            tx.as_mut().unwrap().try_send(completion.into()).unwrap();
-        }
-
-        fn finish_completion(&self) {
-            self.last_completion_tx.lock().take().unwrap();
-        }
-    }
-
-    impl CompletionProvider for TestCompletionProvider {
-        fn complete(
-            &self,
-            _prompt: OpenAIRequest,
-        ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-            let (tx, rx) = mpsc::channel(1);
-            *self.last_completion_tx.lock() = Some(tx);
-            async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
-        }
-    }
-
     fn rust_lang() -> Language {
         Language::new(
             LanguageConfig {
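
The codegen tests above drop the in-file TestCompletionProvider in favor of ai::test::FakeCompletionProvider plus a serde-serialized DummyCompletionRequest. A standalone sketch of that request shape, assuming the serde (with derive) and serde_json crates just as the tests do:

use serde::Serialize;

// Any Serialize type can satisfy CompletionRequest by delegating to
// serde_json; this mirrors the DummyCompletionRequest in the hunk above.
trait CompletionRequest {
    fn data(&self) -> serde_json::Result<String>;
}

#[derive(Serialize)]
struct DummyCompletionRequest {
    name: String,
}

impl CompletionRequest for DummyCompletionRequest {
    fn data(&self) -> serde_json::Result<String> {
        serde_json::to_string(self)
    }
}

fn main() -> serde_json::Result<()> {
    let request = DummyCompletionRequest {
        name: "test".to_string(),
    };
    // A real provider would use this JSON as the HTTP request body.
    println!("{}", request.data()?);
    Ok(())
}
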
diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs
index dffcbc29234d3f24174d1d9a6610045105eae890..25af023c40072cbc56ba2865658544999e0fa7c7 100644
--- a/crates/assistant/src/prompts.rs
+++ b/crates/assistant/src/prompts.rs
@@ -1,9 +1,10 @@
-use ai::models::{LanguageModel, OpenAILanguageModel};
-use ai::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
-use ai::templates::file_context::FileContext;
-use ai::templates::generate::GenerateInlineContent;
-use ai::templates::preamble::EngineerPreamble;
-use ai::templates::repository_context::{PromptCodeSnippet, RepositoryContext};
+use ai::models::LanguageModel;
+use ai::prompts::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
+use ai::prompts::file_context::FileContext;
+use ai::prompts::generate::GenerateInlineContent;
+use ai::prompts::preamble::EngineerPreamble;
+use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext};
+use ai::providers::open_ai::OpenAILanguageModel;
 use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
 use std::cmp::{self, Reverse};
 use std::ops::Range;
diff --git a/crates/semantic_index/Cargo.toml b/crates/semantic_index/Cargo.toml
index 1febb2af786e205e19764e90ca1c78d954d8a1bb..875440ef3fa866734c734830ff5c7b95550fd33c 100644
--- a/crates/semantic_index/Cargo.toml
+++ b/crates/semantic_index/Cargo.toml
@@ -42,6 +42,7 @@ sha1 = "0.10.5"
 ndarray = { version = "0.15.0" }
 
 [dev-dependencies]
+ai = { path = "../ai", features = ["test-support"] }
 collections = { path = "../collections", features = ["test-support"] }
 gpui = { path = "../gpui", features = ["test-support"] }
 language = { path = "../language", features = ["test-support"] }
diff --git a/crates/semantic_index/src/embedding_queue.rs b/crates/semantic_index/src/embedding_queue.rs
index d57d5c7bbe43130d8df135116adeaa76d0a4e357..6ae8faa4cdb0998acad8ab361c83183cff909578 100644
--- a/crates/semantic_index/src/embedding_queue.rs
+++ b/crates/semantic_index/src/embedding_queue.rs
@@ -41,7 +41,6 @@ pub struct EmbeddingQueue {
     pending_batch_token_count: usize,
     finished_files_tx: channel::Sender<FileToEmbed>,
     finished_files_rx: channel::Receiver<FileToEmbed>,
-    api_key: Option<String>,
 }
 
 #[derive(Clone)]
@@ -51,11 +50,7 @@ pub struct FileFragmentToEmbed {
 }
 
 impl EmbeddingQueue {
-    pub fn new(
-        embedding_provider: Arc<dyn EmbeddingProvider>,
-        executor: Arc<Background>,
-        api_key: Option<String>,
-    ) -> Self {
+    pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
         let (finished_files_tx, finished_files_rx) = channel::unbounded();
         Self {
             embedding_provider,
@@ -64,14 +59,9 @@ impl EmbeddingQueue {
             pending_batch_token_count: 0,
             finished_files_tx,
             finished_files_rx,
-            api_key,
         }
     }
 
-    pub fn set_api_key(&mut self, api_key: Option<String>) {
-        self.api_key = api_key
-    }
-
     pub fn push(&mut self, file: FileToEmbed) {
         if file.spans.is_empty() {
             self.finished_files_tx.try_send(file).unwrap();
@@ -118,7 +108,6 @@ impl EmbeddingQueue {
 
         let finished_files_tx = self.finished_files_tx.clone();
         let embedding_provider = self.embedding_provider.clone();
-        let api_key = self.api_key.clone();
 
         self.executor
             .spawn(async move {
@@ -143,7 +132,7 @@ impl EmbeddingQueue {
                     return;
                 };
 
-                match embedding_provider.embed_batch(spans, api_key).await {
+                match embedding_provider.embed_batch(spans).await {
                     Ok(embeddings) => {
                         let mut embeddings = embeddings.into_iter();
                         for fragment in batch {
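
EmbeddingQueue now reaches the provider without an api_key argument, since credentials live behind the provider itself. The queue's flush rule, accumulate spans until the next one would exceed the provider's token budget, can be sketched with toy types (the Span fields, flush function, and budget constant here are illustrative stand-ins, not the crate's real API):

// Toy types standing in for the crate's Span and EmbeddingProvider.
struct Span {
    content: String,
    token_count: usize,
}

// The real value comes from embedding_provider.max_tokens_per_batch().
const MAX_TOKENS_PER_BATCH: usize = 8;

// Stands in for `embedding_provider.embed_batch(spans).await`; note there is
// no api_key parameter, matching the hunks above.
fn flush(batch: &mut Vec<Span>) {
    let contents: Vec<&str> = batch.iter().map(|span| span.content.as_str()).collect();
    println!("embedding batch: {contents:?}");
    batch.clear();
}

fn main() {
    let spans = vec![
        Span { content: "fn a() {}".into(), token_count: 4 },
        Span { content: "fn b() {}".into(), token_count: 4 },
        Span { content: "fn c() {}".into(), token_count: 4 },
    ];

    let mut batch = Vec::new();
    let mut batch_tokens = 0;
    for span in spans {
        // Flush before the pending batch would exceed the provider's budget.
        if batch_tokens + span.token_count > MAX_TOKENS_PER_BATCH {
            flush(&mut batch);
            batch_tokens = 0;
        }
        batch_tokens += span.token_count;
        batch.push(span);
    }
    if !batch.is_empty() {
        flush(&mut batch);
    }
}
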
diff --git a/crates/semantic_index/src/parsing.rs b/crates/semantic_index/src/parsing.rs
index f9b8bac9a484bfae48c62683ee096e2e49420622..cb15ca453b2c0640739bd44a95482ca527b8d91b 100644
--- a/crates/semantic_index/src/parsing.rs
+++ b/crates/semantic_index/src/parsing.rs
@@ -1,4 +1,7 @@
-use ai::embedding::{Embedding, EmbeddingProvider};
+use ai::{
+    embedding::{Embedding, EmbeddingProvider},
+    models::TruncationDirection,
+};
 use anyhow::{anyhow, Result};
 use language::{Grammar, Language};
 use rusqlite::{
@@ -108,7 +111,14 @@ impl CodeContextRetriever {
             .replace("<language>", language_name.as_ref())
             .replace("<item>", &content);
         let digest = SpanDigest::from(document_span.as_str());
-        let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
+        let model = self.embedding_provider.base_model();
+        let document_span = model.truncate(
+            &document_span,
+            model.capacity()?,
+            ai::models::TruncationDirection::End,
+        )?;
+        let token_count = model.count_tokens(&document_span)?;
+
         Ok(vec![Span {
             range: 0..content.len(),
             content: document_span,
@@ -131,7 +141,15 @@ impl CodeContextRetriever {
             )
             .replace("<item>", &content);
         let digest = SpanDigest::from(document_span.as_str());
-        let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
+
+        let model = self.embedding_provider.base_model();
+        let document_span = model.truncate(
+            &document_span,
+            model.capacity()?,
+            ai::models::TruncationDirection::End,
+        )?;
+        let token_count = model.count_tokens(&document_span)?;
+
         Ok(vec![Span {
             range: 0..content.len(),
             content: document_span,
@@ -222,8 +240,13 @@ impl CodeContextRetriever {
                 .replace("<language>", language_name.as_ref())
                 .replace("item", &span.content);
 
-            let (document_content, token_count) =
-                self.embedding_provider.truncate(&document_content);
+            let model = self.embedding_provider.base_model();
+            let document_content = model.truncate(
+                &document_content,
+                model.capacity()?,
+                TruncationDirection::End,
+            )?;
+            let token_count = model.count_tokens(&document_content)?;
 
             span.content = document_content;
             span.token_count = token_count;
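
The parsing changes above stop using the provider's combined truncate() and instead ask the base model to truncate to its capacity, then count the tokens that actually survive. A toy stand-in for that flow, using whitespace-separated words where the real models use tiktoken tokens (ToyModel and its methods are assumptions for illustration):

// Mirrors the enum named in the hunk above.
enum TruncationDirection {
    Start,
    End,
}

struct ToyModel {
    capacity: usize,
}

impl ToyModel {
    fn capacity(&self) -> usize {
        self.capacity
    }

    // Words stand in for tokenizer tokens.
    fn count_tokens(&self, content: &str) -> usize {
        content.split_whitespace().count()
    }

    fn truncate(&self, content: &str, length: usize, direction: TruncationDirection) -> String {
        let words: Vec<&str> = content.split_whitespace().collect();
        if words.len() <= length {
            return content.to_string();
        }
        match direction {
            TruncationDirection::End => words[..length].join(" "),
            TruncationDirection::Start => words[words.len() - length..].join(" "),
        }
    }
}

fn main() {
    let model = ToyModel { capacity: 4 };
    let span = "fn main () { println ! (\"hi\") }";
    // Truncate to the model's capacity first, then count what remains.
    let truncated = model.truncate(span, model.capacity(), TruncationDirection::End);
    println!("kept {} tokens: {truncated}", model.count_tokens(&truncated));
    let tail = model.truncate(span, model.capacity(), TruncationDirection::Start);
    println!("kept tail: {tail}");
}
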
diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs
index 8839d25a841efe57136e92f0d107694410e4ec81..818faa044475d8ed76a90dafcb6a2f6e3b7f7ee6 100644
--- a/crates/semantic_index/src/semantic_index.rs
+++ b/crates/semantic_index/src/semantic_index.rs
@@ -7,7 +7,8 @@ pub mod semantic_index_settings;
 mod semantic_index_tests;
 
 use crate::semantic_index_settings::SemanticIndexSettings;
-use ai::embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
+use ai::embedding::{Embedding, EmbeddingProvider};
+use ai::providers::open_ai::OpenAIEmbeddingProvider;
 use anyhow::{anyhow, Result};
 use collections::{BTreeMap, HashMap, HashSet};
 use db::VectorDatabase;
@@ -88,7 +89,7 @@ pub fn init(
     let semantic_index = SemanticIndex::new(
         fs,
         db_file_path,
-        Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
+        Arc::new(OpenAIEmbeddingProvider::new(http_client, cx.background())),
         language_registry,
         cx.clone(),
     )
@@ -123,8 +124,6 @@ pub struct SemanticIndex {
     _embedding_task: Task<()>,
     _parsing_files_tasks: Vec<Task<()>>,
     projects: HashMap<WeakModelHandle<Project>, ProjectState>,
-    api_key: Option<String>,
-    embedding_queue: Arc<Mutex<EmbeddingQueue>>,
 }
 
 struct ProjectState {
@@ -278,18 +277,18 @@ impl SemanticIndex {
         }
     }
 
-    pub fn authenticate(&mut self, cx: &AppContext) {
-        if self.api_key.is_none() {
-            self.api_key = self.embedding_provider.retrieve_credentials(cx);
-
-            self.embedding_queue
-                .lock()
-                .set_api_key(self.api_key.clone());
+    pub fn authenticate(&mut self, cx: &AppContext) -> bool {
+        if !self.embedding_provider.has_credentials() {
+            self.embedding_provider.retrieve_credentials(cx);
+        } else {
+            return true;
         }
+
+        self.embedding_provider.has_credentials()
     }
 
     pub fn is_authenticated(&self) -> bool {
-        self.api_key.is_some()
+        self.embedding_provider.has_credentials()
     }
 
     pub fn enabled(cx: &AppContext) -> bool {
@@ -339,7 +338,7 @@ impl SemanticIndex {
         Ok(cx.add_model(|cx| {
             let t0 = Instant::now();
             let embedding_queue =
-                EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), None);
+                EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
             let _embedding_task = cx.background().spawn({
                 let embedded_files = embedding_queue.finished_files();
                 let db = db.clone();
@@ -404,8 +403,6 @@ impl SemanticIndex {
                 _embedding_task,
                 _parsing_files_tasks,
                 projects: Default::default(),
-                api_key: None,
-                embedding_queue
             }
         }))
     }
@@ -720,13 +717,13 @@ impl SemanticIndex {
 
         let index = self.index_project(project.clone(), cx);
         let embedding_provider = self.embedding_provider.clone();
-        let api_key = self.api_key.clone();
 
         cx.spawn(|this, mut cx| async move {
             index.await?;
             let t0 = Instant::now();
+
             let query = embedding_provider
-                .embed_batch(vec![query], api_key)
+                .embed_batch(vec![query])
                 .await?
                 .pop()
                 .ok_or_else(|| anyhow!("could not embed query"))?;
@@ -944,7 +941,6 @@ impl SemanticIndex {
         let fs = self.fs.clone();
         let db_path = self.db.path().clone();
         let background = cx.background().clone();
-        let api_key = self.api_key.clone();
         cx.background().spawn(async move {
             let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
             let mut results = Vec::<SearchResult>::new();
@@ -959,15 +955,10 @@ impl SemanticIndex {
                     .parse_file_with_template(None, &snapshot.text(), language)
                     .log_err()
                     .unwrap_or_default();
-                if Self::embed_spans(
-                    &mut spans,
-                    embedding_provider.as_ref(),
-                    &db,
-                    api_key.clone(),
-                )
-                .await
-                .log_err()
-                .is_some()
+                if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
+                    .await
+                    .log_err()
+                    .is_some()
                 {
                     for span in spans {
                         let similarity = span.embedding.unwrap().similarity(&query);
@@ -1007,9 +998,8 @@ impl SemanticIndex {
         project: ModelHandle<Project>,
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<()>> {
-        if self.api_key.is_none() {
-            self.authenticate(cx);
-            if self.api_key.is_none() {
+        if !self.is_authenticated() {
+            if !self.authenticate(cx) {
                 return Task::ready(Err(anyhow!("user is not authenticated")));
             }
         }
@@ -1192,7 +1182,6 @@ impl SemanticIndex {
         spans: &mut [Span],
         embedding_provider: &dyn EmbeddingProvider,
         db: &VectorDatabase,
-        api_key: Option<String>,
     ) -> Result<()> {
         let mut batch = Vec::new();
         let mut batch_tokens = 0;
@@ -1215,7 +1204,7 @@ impl SemanticIndex {
 
             if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
                 let batch_embeddings = embedding_provider
-                    .embed_batch(mem::take(&mut batch), api_key.clone())
+                    .embed_batch(mem::take(&mut batch))
                     .await?;
                 embeddings.extend(batch_embeddings);
                 batch_tokens = 0;
@@ -1227,7 +1216,7 @@ impl SemanticIndex {
 
         if !batch.is_empty() {
             let batch_embeddings = embedding_provider
-                .embed_batch(mem::take(&mut batch), api_key)
+                .embed_batch(mem::take(&mut batch))
                 .await?;
 
             embeddings.extend(batch_embeddings);
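
SemanticIndex::authenticate now returns a bool and delegates credential storage to the embedding provider. A rough sketch of that control flow with a hypothetical provider backed by an environment variable (the real retrieve_credentials takes a gpui AppContext and may consult the keychain; the Provider type and field names here are assumptions):

use std::cell::RefCell;

// Simplified stand-in for a provider that owns its own credentials.
struct Provider {
    api_key: RefCell<Option<String>>,
}

impl Provider {
    fn has_credentials(&self) -> bool {
        self.api_key.borrow().is_some()
    }

    // Hypothetical: an environment variable stands in for the keychain.
    fn retrieve_credentials(&self) {
        if let Ok(key) = std::env::var("OPENAI_API_KEY") {
            *self.api_key.borrow_mut() = Some(key);
        }
    }
}

// Mirrors the authenticate() -> bool shape in the hunk above: load
// credentials only when missing, then report whether any are present.
fn authenticate(provider: &Provider) -> bool {
    if !provider.has_credentials() {
        provider.retrieve_credentials();
    } else {
        return true;
    }
    provider.has_credentials()
}

fn main() {
    let provider = Provider {
        api_key: RefCell::new(None),
    };
    println!("authenticated: {}", authenticate(&provider));
}
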
diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs
index a1ee3e5ada45a93da405153d3da75f1fb6779560..7a91d1e100ca96d21e937fe7945c3cfd78fc68fd 100644
--- a/crates/semantic_index/src/semantic_index_tests.rs
+++ b/crates/semantic_index/src/semantic_index_tests.rs
@@ -4,10 +4,9 @@ use crate::{
     semantic_index_settings::SemanticIndexSettings,
     FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
 };
-use ai::embedding::{DummyEmbeddings, Embedding, EmbeddingProvider};
-use anyhow::Result;
-use async_trait::async_trait;
-use gpui::{executor::Deterministic, AppContext, Task, TestAppContext};
+use ai::test::FakeEmbeddingProvider;
+
+use gpui::{executor::Deterministic, Task, TestAppContext};
 use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
 use parking_lot::Mutex;
 use pretty_assertions::assert_eq;
@@ -15,14 +14,7 @@ use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs
 use rand::{rngs::StdRng, Rng};
 use serde_json::json;
 use settings::SettingsStore;
-use std::{
-    path::Path,
-    sync::{
-        atomic::{self, AtomicUsize},
-        Arc,
-    },
-    time::{Instant, SystemTime},
-};
+use std::{path::Path, sync::Arc, time::SystemTime};
 use unindent::Unindent;
 use util::RandomCharIter;
 
@@ -228,7 +220,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
 
     let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 
-    let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background(), None);
+    let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
     for file in &files {
         queue.push(file.clone());
     }
@@ -280,7 +272,7 @@ fn assert_search_results(
 #[gpui::test]
 async fn test_code_context_retrieval_rust() {
     let language = rust_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = "
@@ -382,7 +374,7 @@ async fn test_code_context_retrieval_rust() {
 #[gpui::test]
 async fn test_code_context_retrieval_json() {
     let language = json_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
@@ -466,7 +458,7 @@ fn assert_documents_eq(
 #[gpui::test]
 async fn test_code_context_retrieval_javascript() {
     let language = js_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = "
@@ -565,7 +557,7 @@ async fn test_code_context_retrieval_javascript() {
 #[gpui::test]
 async fn test_code_context_retrieval_lua() {
     let language = lua_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
@@ -639,7 +631,7 @@ async fn test_code_context_retrieval_lua() {
 #[gpui::test]
 async fn test_code_context_retrieval_elixir() {
     let language = elixir_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
@@ -756,7 +748,7 @@ async fn test_code_context_retrieval_elixir() {
 #[gpui::test]
 async fn test_code_context_retrieval_cpp() {
     let language = cpp_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = "
@@ -909,7 +901,7 @@ async fn test_code_context_retrieval_cpp() {
 #[gpui::test]
 async fn test_code_context_retrieval_ruby() {
     let language = ruby_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
@@ -1100,7 +1092,7 @@ async fn test_code_context_retrieval_ruby() {
 #[gpui::test]
 async fn test_code_context_retrieval_php() {
     let language = php_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
@@ -1248,65 +1240,6 @@ async fn test_code_context_retrieval_php() {
     );
 }
 
-#[derive(Default)]
-struct FakeEmbeddingProvider {
-    embedding_count: AtomicUsize,
-}
-
-impl FakeEmbeddingProvider {
-    fn embedding_count(&self) -> usize {
-        self.embedding_count.load(atomic::Ordering::SeqCst)
-    }
-
-    fn embed_sync(&self, span: &str) -> Embedding {
-        let mut result = vec![1.0; 26];
-        for letter in span.chars() {
-            let letter = letter.to_ascii_lowercase();
-            if letter as u32 >= 'a' as u32 {
-                let ix = (letter as u32) - ('a' as u32);
-                if ix < 26 {
-                    result[ix as usize] += 1.0;
-                }
-            }
-        }
-
-        let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
-        for x in &mut result {
-            *x /= norm;
-        }
-
-        result.into()
-    }
-}
-
-#[async_trait]
-impl EmbeddingProvider for FakeEmbeddingProvider {
-    fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
-        Some("Fake Credentials".to_string())
-    }
-    fn truncate(&self, span: &str) -> (String, usize) {
-        (span.to_string(), 1)
-    }
-
-    fn max_tokens_per_batch(&self) -> usize {
-        200
-    }
-
-    fn rate_limit_expiration(&self) -> Option<Instant> {
-        None
-    }
-
-    async fn embed_batch(
-        &self,
-        spans: Vec<String>,
-        _api_key: Option<String>,
-    ) -> Result<Vec<Embedding>> {
-        self.embedding_count
-            .fetch_add(spans.len(), atomic::Ordering::SeqCst);
-        Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
-    }
-}
-
 fn js_lang() -> Arc<Language> {
     Arc::new(
         Language::new(
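
The FakeEmbeddingProvider removed above moved into ai::test behind the test-support feature; it embeds a string as its normalized letter-frequency vector, which keeps the retrieval tests deterministic and offline. The same trick in isolation, as plain functions over Vec<f32>:

// Deterministic 26-dim "embedding": one slot per letter, L2-normalized so
// dot products behave like cosine similarity. Mirrors the removed embed_sync.
fn embed_sync(span: &str) -> Vec<f32> {
    let mut result = vec![1.0f32; 26];
    for letter in span.chars() {
        let letter = letter.to_ascii_lowercase();
        if letter.is_ascii_lowercase() {
            result[(letter as u32 - 'a' as u32) as usize] += 1.0;
        }
    }

    let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
    for x in &mut result {
        *x /= norm;
    }
    result
}

fn main() {
    let a = embed_sync("fn main() {}");
    let b = embed_sync("fn main() { }");
    let similarity: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
    // Near-identical spans score close to 1.0.
    println!("similarity: {similarity:.4}");
}
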
diff --git a/crates/ui2/src/elements/icon.rs b/crates/ui2/src/elements/icon.rs
index ef3644229685dbfaa2ad466c4d1b242272377479..2273ec24f2369e50bb9152eeedddfd5f9aad8cbf 100644
--- a/crates/ui2/src/elements/icon.rs
+++ b/crates/ui2/src/elements/icon.rs
@@ -36,7 +36,7 @@ impl IconColor {
             IconColor::Error => gpui2::red(),
             IconColor::Warning => gpui2::red(),
             IconColor::Success => gpui2::red(),
-            IconColor::Info => gpui2::red()
+            IconColor::Info => gpui2::red(),
         }
     }
 }
diff --git a/crates/zed/examples/semantic_index_eval.rs b/crates/zed/examples/semantic_index_eval.rs
index e7503078002a56780894498c1872320e6ed37694..caf8e5f5c73fe29cd701e540b7dc8e8847b6808c 100644
--- a/crates/zed/examples/semantic_index_eval.rs
+++ b/crates/zed/examples/semantic_index_eval.rs
@@ -1,4 +1,4 @@
-use ai::embedding::OpenAIEmbeddings;
+use ai::providers::open_ai::OpenAIEmbeddingProvider;
 use anyhow::{anyhow, Result};
 use client::{self, UserStore};
 use gpui::{AsyncAppContext, ModelHandle, Task};
@@ -475,7 +475,7 @@ fn main() {
     let semantic_index = SemanticIndex::new(
         fs.clone(),
         db_file_path,
-        Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
+        Arc::new(OpenAIEmbeddingProvider::new(http_client, cx.background())),
         languages.clone(),
         cx.clone(),
     )
diff --git a/crates/zed2/Cargo.toml b/crates/zed2/Cargo.toml
index a6b31871dd6d8290f60a0862aa96ef6433df364f..9f681a49e95565b58d8cba5ba2e775ba1d0e37bb 100644
--- a/crates/zed2/Cargo.toml
+++ b/crates/zed2/Cargo.toml
@@ -15,6 +15,7 @@ name = "Zed"
 path = "src/main.rs"
 
 [dependencies]
+ai2 = { path = "../ai2"}
 # audio = { path = "../audio" }
 # activity_indicator = { path = "../activity_indicator" }
 # auto_update = { path = "../auto_update" }