Replace OpenAIRequest with a more generalized Box&lt;dyn CompletionRequest&gt;

KCaverly created

Change summary

crates/ai/src/completion.rs                   |  6 +++-
crates/ai/src/providers/dummy.rs              | 13 ++++++++++
crates/ai/src/providers/mod.rs                |  1 +
crates/ai/src/providers/open_ai/completion.rs | 16 ++++++++-----
crates/assistant/src/assistant_panel.rs       | 20 ++++++++++------
crates/assistant/src/codegen.rs               | 25 +++++++++++++++-----
6 files changed, 58 insertions(+), 23 deletions(-)

Detailed changes

crates/ai/src/completion.rs 🔗

@@ -1,11 +1,13 @@
 use anyhow::Result;
 use futures::{future::BoxFuture, stream::BoxStream};
 
-use crate::providers::open_ai::completion::OpenAIRequest;
+pub trait CompletionRequest: Send + Sync {
+    fn data(&self) -> serde_json::Result<String>;
+}
 
 pub trait CompletionProvider {
     fn complete(
         &self,
-        prompt: OpenAIRequest,
+        prompt: Box<dyn CompletionRequest>,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
 }

crates/ai/src/providers/dummy.rs 🔗

@@ -0,0 +1,13 @@
+use crate::completion::CompletionRequest;
+use serde::Serialize;
+
+#[derive(Serialize)]
+pub struct DummyCompletionRequest {
+    pub name: String,
+}
+
+impl CompletionRequest for DummyCompletionRequest {
+    fn data(&self) -> serde_json::Result<String> {
+        serde_json::to_string(self)
+    }
+}

crates/ai/src/providers/open_ai/completion.rs 🔗

@@ -12,7 +12,7 @@ use std::{
     sync::Arc,
 };
 
-use crate::completion::CompletionProvider;
+use crate::completion::{CompletionProvider, CompletionRequest};
 
 pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
 
@@ -59,6 +59,12 @@ pub struct OpenAIRequest {
     pub temperature: f32,
 }
 
+impl CompletionRequest for OpenAIRequest {
+    fn data(&self) -> serde_json::Result<String> {
+        serde_json::to_string(self)
+    }
+}
+
 #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 pub struct ResponseMessage {
     pub role: Option<Role>,
@@ -92,13 +98,11 @@ pub struct OpenAIResponseStreamEvent {
 pub async fn stream_completion(
     api_key: String,
     executor: Arc<Background>,
-    mut request: OpenAIRequest,
+    request: Box<dyn CompletionRequest>,
 ) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
-    request.stream = true;
-
     let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
 
-    let json_data = serde_json::to_string(&request)?;
+    let json_data = request.data()?;
     let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
         .header("Content-Type", "application/json")
         .header("Authorization", format!("Bearer {}", api_key))
@@ -189,7 +193,7 @@ impl OpenAICompletionProvider {
 impl CompletionProvider for OpenAICompletionProvider {
     fn complete(
         &self,
-        prompt: OpenAIRequest,
+        prompt: Box<dyn CompletionRequest>,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
         let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
         async move {

crates/assistant/src/assistant_panel.rs 🔗

@@ -6,8 +6,11 @@ use crate::{
     SavedMessage,
 };
 
-use ai::providers::open_ai::{
-    stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
+use ai::{
+    completion::CompletionRequest,
+    providers::open_ai::{
+        stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
+    },
 };
 
 use ai::prompts::repository_context::PromptCodeSnippet;
@@ -745,13 +748,14 @@ impl AssistantPanel {
                 content: prompt,
             });
 
-            let request = OpenAIRequest {
+            let request = Box::new(OpenAIRequest {
                 model: model.full_name().into(),
                 messages,
                 stream: true,
                 stop: vec!["|END|>".to_string()],
                 temperature,
-            };
+            });
+
             codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx));
             anyhow::Ok(())
         })
@@ -1735,7 +1739,7 @@ impl Conversation {
                 return Default::default();
             };
 
-            let request = OpenAIRequest {
+            let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
                 model: self.model.full_name().to_string(),
                 messages: self
                     .messages(cx)
@@ -1745,7 +1749,7 @@ impl Conversation {
                 stream: true,
                 stop: vec![],
                 temperature: 1.0,
-            };
+            });
 
             let stream = stream_completion(api_key, cx.background().clone(), request);
             let assistant_message = self
@@ -2025,13 +2029,13 @@ impl Conversation {
                             "Summarize the conversation into a short title without punctuation"
                                 .into(),
                     }));
-                let request = OpenAIRequest {
+                let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
                     model: self.model.full_name().to_string(),
                     messages: messages.collect(),
                     stream: true,
                     stop: vec![],
                     temperature: 1.0,
-                };
+                });
 
                 let stream = stream_completion(api_key, cx.background().clone(), request);
                 self.pending_summary = cx.spawn(|this, mut cx| {

crates/assistant/src/codegen.rs 🔗

@@ -1,6 +1,5 @@
 use crate::streaming_diff::{Hunk, StreamingDiff};
-use ai::completion::CompletionProvider;
-use ai::providers::open_ai::OpenAIRequest;
+use ai::completion::{CompletionProvider, CompletionRequest};
 use anyhow::Result;
 use editor::{multi_buffer, Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
 use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
@@ -96,7 +95,7 @@ impl Codegen {
         self.error.as_ref()
     }
 
-    pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
+    pub fn start(&mut self, prompt: Box<dyn CompletionRequest>, cx: &mut ModelContext<Self>) {
         let range = self.range();
         let snapshot = self.snapshot.clone();
         let selected_text = snapshot
@@ -336,6 +335,7 @@ fn strip_markdown_codeblock(
 #[cfg(test)]
 mod tests {
     use super::*;
+    use ai::providers::dummy::DummyCompletionRequest;
     use futures::{
         future::BoxFuture,
         stream::{self, BoxStream},
@@ -381,7 +381,10 @@ mod tests {
                 cx,
             )
         });
-        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
+        let request = Box::new(DummyCompletionRequest {
+            name: "test".to_string(),
+        });
+        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
 
         let mut new_text = concat!(
             "       let mut x = 0;\n",
@@ -443,7 +446,11 @@ mod tests {
                 cx,
             )
         });
-        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
+
+        let request = Box::new(DummyCompletionRequest {
+            name: "test".to_string(),
+        });
+        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
 
         let mut new_text = concat!(
             "t mut x = 0;\n",
@@ -505,7 +512,11 @@ mod tests {
                 cx,
             )
         });
-        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
+
+        let request = Box::new(DummyCompletionRequest {
+            name: "test".to_string(),
+        });
+        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
 
         let mut new_text = concat!(
             "let mut x = 0;\n",
@@ -617,7 +628,7 @@ mod tests {
     impl CompletionProvider for TestCompletionProvider {
         fn complete(
             &self,
-            _prompt: OpenAIRequest,
+            _prompt: Box<dyn CompletionRequest>,
         ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
             let (tx, rx) = mpsc::channel(1);
             *self.last_completion_tx.lock() = Some(tx);