From d813ae88458ed5e14899c5ebdd4437daa033ae6e Mon Sep 17 00:00:00 2001 From: KCaverly Date: Sun, 22 Oct 2023 14:33:19 +0200 Subject: [PATCH] replace OpenAIRequest with more generalized Box --- crates/ai/src/completion.rs | 6 +++-- crates/ai/src/providers/dummy.rs | 13 ++++++++++ crates/ai/src/providers/mod.rs | 1 + crates/ai/src/providers/open_ai/completion.rs | 16 +++++++----- crates/assistant/src/assistant_panel.rs | 20 +++++++++------ crates/assistant/src/codegen.rs | 25 +++++++++++++------ 6 files changed, 58 insertions(+), 23 deletions(-) create mode 100644 crates/ai/src/providers/dummy.rs diff --git a/crates/ai/src/completion.rs b/crates/ai/src/completion.rs index f45893898fccaf54e794d151ce4c6c64eff34bc5..ba89c869d214c8772430e30b3177af502a41031e 100644 --- a/crates/ai/src/completion.rs +++ b/crates/ai/src/completion.rs @@ -1,11 +1,13 @@ use anyhow::Result; use futures::{future::BoxFuture, stream::BoxStream}; -use crate::providers::open_ai::completion::OpenAIRequest; +pub trait CompletionRequest: Send + Sync { + fn data(&self) -> serde_json::Result; +} pub trait CompletionProvider { fn complete( &self, - prompt: OpenAIRequest, + prompt: Box, ) -> BoxFuture<'static, Result>>>; } diff --git a/crates/ai/src/providers/dummy.rs b/crates/ai/src/providers/dummy.rs new file mode 100644 index 0000000000000000000000000000000000000000..be42b13f2fe53dd1568ef5f59fc24e394bef2fef --- /dev/null +++ b/crates/ai/src/providers/dummy.rs @@ -0,0 +1,13 @@ +use crate::completion::CompletionRequest; +use serde::Serialize; + +#[derive(Serialize)] +pub struct DummyCompletionRequest { + pub name: String, +} + +impl CompletionRequest for DummyCompletionRequest { + fn data(&self) -> serde_json::Result { + serde_json::to_string(self) + } +} diff --git a/crates/ai/src/providers/mod.rs b/crates/ai/src/providers/mod.rs index acd0f9d91053869e3999ef0c1a23326480a7cbdd..7a7092baf39d40da8b09ffb8085463a4a8b2efdf 100644 --- a/crates/ai/src/providers/mod.rs +++ b/crates/ai/src/providers/mod.rs @@ -1 +1,2 @@ +pub mod dummy; pub mod open_ai; diff --git a/crates/ai/src/providers/open_ai/completion.rs b/crates/ai/src/providers/open_ai/completion.rs index bb6138eee3b74d1471ae03d2f8352d3906391087..95ed13c0dd7f1d519c8279ea75fd3980dd4f5d50 100644 --- a/crates/ai/src/providers/open_ai/completion.rs +++ b/crates/ai/src/providers/open_ai/completion.rs @@ -12,7 +12,7 @@ use std::{ sync::Arc, }; -use crate::completion::CompletionProvider; +use crate::completion::{CompletionProvider, CompletionRequest}; pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; @@ -59,6 +59,12 @@ pub struct OpenAIRequest { pub temperature: f32, } +impl CompletionRequest for OpenAIRequest { + fn data(&self) -> serde_json::Result { + serde_json::to_string(self) + } +} + #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] pub struct ResponseMessage { pub role: Option, @@ -92,13 +98,11 @@ pub struct OpenAIResponseStreamEvent { pub async fn stream_completion( api_key: String, executor: Arc, - mut request: OpenAIRequest, + request: Box, ) -> Result>> { - request.stream = true; - let (tx, rx) = futures::channel::mpsc::unbounded::>(); - let json_data = serde_json::to_string(&request)?; + let json_data = request.data()?; let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions")) .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {}", api_key)) @@ -189,7 +193,7 @@ impl OpenAICompletionProvider { impl CompletionProvider for OpenAICompletionProvider { fn complete( &self, - prompt: OpenAIRequest, + prompt: Box, ) -> BoxFuture<'static, Result>>> { let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt); async move { diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 9b749e5091fd394eab07739e509b2a574bc43640..ec16c8fd046b48b8ef8dc0fc52afbdc6a0da1b3e 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -6,8 +6,11 @@ use crate::{ SavedMessage, }; -use ai::providers::open_ai::{ - stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL, +use ai::{ + completion::CompletionRequest, + providers::open_ai::{ + stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL, + }, }; use ai::prompts::repository_context::PromptCodeSnippet; @@ -745,13 +748,14 @@ impl AssistantPanel { content: prompt, }); - let request = OpenAIRequest { + let request = Box::new(OpenAIRequest { model: model.full_name().into(), messages, stream: true, stop: vec!["|END|>".to_string()], temperature, - }; + }); + codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx)); anyhow::Ok(()) }) @@ -1735,7 +1739,7 @@ impl Conversation { return Default::default(); }; - let request = OpenAIRequest { + let request: Box = Box::new(OpenAIRequest { model: self.model.full_name().to_string(), messages: self .messages(cx) @@ -1745,7 +1749,7 @@ impl Conversation { stream: true, stop: vec![], temperature: 1.0, - }; + }); let stream = stream_completion(api_key, cx.background().clone(), request); let assistant_message = self @@ -2025,13 +2029,13 @@ impl Conversation { "Summarize the conversation into a short title without punctuation" .into(), })); - let request = OpenAIRequest { + let request: Box = Box::new(OpenAIRequest { model: self.model.full_name().to_string(), messages: messages.collect(), stream: true, stop: vec![], temperature: 1.0, - }; + }); let stream = stream_completion(api_key, cx.background().clone(), request); self.pending_summary = cx.spawn(|this, mut cx| { diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs index 66d2f60690d23628bfdc7d62f84f869c24347119..e535eca1441b4c8137d043c304a68cb2866e38fc 100644 --- a/crates/assistant/src/codegen.rs +++ b/crates/assistant/src/codegen.rs @@ -1,6 +1,5 @@ use crate::streaming_diff::{Hunk, StreamingDiff}; -use ai::completion::CompletionProvider; -use ai::providers::open_ai::OpenAIRequest; +use ai::completion::{CompletionProvider, CompletionRequest}; use anyhow::Result; use editor::{multi_buffer, Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint}; use futures::{channel::mpsc, SinkExt, Stream, StreamExt}; @@ -96,7 +95,7 @@ impl Codegen { self.error.as_ref() } - pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext) { + pub fn start(&mut self, prompt: Box, cx: &mut ModelContext) { let range = self.range(); let snapshot = self.snapshot.clone(); let selected_text = snapshot @@ -336,6 +335,7 @@ fn strip_markdown_codeblock( #[cfg(test)] mod tests { use super::*; + use ai::providers::dummy::DummyCompletionRequest; use futures::{ future::BoxFuture, stream::{self, BoxStream}, @@ -381,7 +381,10 @@ mod tests { cx, ) }); - codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx)); + let request = Box::new(DummyCompletionRequest { + name: "test".to_string(), + }); + codegen.update(cx, |codegen, cx| codegen.start(request, cx)); let mut new_text = concat!( " let mut x = 0;\n", @@ -443,7 +446,11 @@ mod tests { cx, ) }); - codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx)); + + let request = Box::new(DummyCompletionRequest { + name: "test".to_string(), + }); + codegen.update(cx, |codegen, cx| codegen.start(request, cx)); let mut new_text = concat!( "t mut x = 0;\n", @@ -505,7 +512,11 @@ mod tests { cx, ) }); - codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx)); + + let request = Box::new(DummyCompletionRequest { + name: "test".to_string(), + }); + codegen.update(cx, |codegen, cx| codegen.start(request, cx)); let mut new_text = concat!( "let mut x = 0;\n", @@ -617,7 +628,7 @@ mod tests { impl CompletionProvider for TestCompletionProvider { fn complete( &self, - _prompt: OpenAIRequest, + _prompt: Box, ) -> BoxFuture<'static, Result>>> { let (tx, rx) = mpsc::channel(1); *self.last_completion_tx.lock() = Some(tx);