introduce ai crate with completion providers

KCaverly created

Change summary

Cargo.lock                              |  15 +
Cargo.toml                              |   1 
crates/ai/Cargo.toml                    |  21 ++
crates/ai/src/ai.rs                     |   1 
crates/ai/src/completion.rs             | 212 +++++++++++++++++++++++++++
crates/assistant/Cargo.toml             |   1 
crates/assistant/src/assistant.rs       | 192 -----------------------
crates/assistant/src/assistant_panel.rs |   9 
crates/assistant/src/codegen.rs         |  59 +------
crates/zed/src/zed.rs                   |   4 
10 files changed, 273 insertions(+), 242 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -86,6 +86,20 @@ dependencies = [
  "memchr",
 ]
 
+[[package]]
+name = "ai"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "ctor",
+ "futures 0.3.28",
+ "gpui",
+ "isahc",
+ "regex",
+ "serde",
+ "serde_json",
+]
+
 [[package]]
 name = "alacritty_config"
 version = "0.1.2-dev"
@@ -272,6 +286,7 @@ checksum = "d92bec98840b8f03a5ff5413de5293bfcd8bf96467cf5452609f939ec6f5de16"
 name = "assistant"
 version = "0.1.0"
 dependencies = [
+ "ai",
  "anyhow",
  "chrono",
  "client",

Cargo.toml 🔗

@@ -1,6 +1,7 @@
 [workspace]
 members = [
     "crates/activity_indicator",
+    "crates/ai",
     "crates/assistant",
     "crates/audio",
     "crates/auto_update",

crates/ai/Cargo.toml 🔗

@@ -0,0 +1,21 @@
+[package]
+name = "ai"
+version = "0.1.0"
+edition = "2021"
+publish = false
+
+[lib]
+path = "src/ai.rs"
+doctest = false
+
+[dependencies]
+gpui = { path = "../gpui" }
+anyhow.workspace = true
+futures.workspace = true
+isahc.workspace = true
+regex.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+
+[dev-dependencies]
+ctor.workspace = true

crates/ai/src/completion.rs 🔗

@@ -0,0 +1,212 @@
+use anyhow::{anyhow, Result};
+use futures::{
+    future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
+    Stream, StreamExt,
+};
+use gpui::executor::Background;
+use isahc::{http::StatusCode, Request, RequestExt};
+use serde::{Deserialize, Serialize};
+use std::{
+    fmt::{self, Display},
+    io,
+    sync::Arc,
+};
+
+pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
+
+#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+    User,
+    Assistant,
+    System,
+}
+
+impl Role {
+    pub fn cycle(&mut self) {
+        *self = match self {
+            Role::User => Role::Assistant,
+            Role::Assistant => Role::System,
+            Role::System => Role::User,
+        }
+    }
+}
+
+impl Display for Role {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Role::User => write!(f, "User"),
+            Role::Assistant => write!(f, "Assistant"),
+            Role::System => write!(f, "System"),
+        }
+    }
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct RequestMessage {
+    pub role: Role,
+    pub content: String,
+}
+
+#[derive(Debug, Default, Serialize)]
+pub struct OpenAIRequest {
+    pub model: String,
+    pub messages: Vec<RequestMessage>,
+    pub stream: bool,
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct ResponseMessage {
+    pub role: Option<Role>,
+    pub content: Option<String>,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct OpenAIUsage {
+    pub prompt_tokens: u32,
+    pub completion_tokens: u32,
+    pub total_tokens: u32,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct ChatChoiceDelta {
+    pub index: u32,
+    pub delta: ResponseMessage,
+    pub finish_reason: Option<String>,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct OpenAIResponseStreamEvent {
+    pub id: Option<String>,
+    pub object: String,
+    pub created: u32,
+    pub model: String,
+    pub choices: Vec<ChatChoiceDelta>,
+    pub usage: Option<OpenAIUsage>,
+}
+
+pub async fn stream_completion(
+    api_key: String,
+    executor: Arc<Background>,
+    mut request: OpenAIRequest,
+) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
+    request.stream = true;
+
+    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
+
+    let json_data = serde_json::to_string(&request)?;
+    let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
+        .header("Content-Type", "application/json")
+        .header("Authorization", format!("Bearer {}", api_key))
+        .body(json_data)?
+        .send_async()
+        .await?;
+
+    let status = response.status();
+    if status == StatusCode::OK {
+        executor
+            .spawn(async move {
+                let mut lines = BufReader::new(response.body_mut()).lines();
+
+                fn parse_line(
+                    line: Result<String, io::Error>,
+                ) -> Result<Option<OpenAIResponseStreamEvent>> {
+                    if let Some(data) = line?.strip_prefix("data: ") {
+                        let event = serde_json::from_str(&data)?;
+                        Ok(Some(event))
+                    } else {
+                        Ok(None)
+                    }
+                }
+
+                while let Some(line) = lines.next().await {
+                    if let Some(event) = parse_line(line).transpose() {
+                        let done = event.as_ref().map_or(false, |event| {
+                            event
+                                .choices
+                                .last()
+                                .map_or(false, |choice| choice.finish_reason.is_some())
+                        });
+                        if tx.unbounded_send(event).is_err() {
+                            break;
+                        }
+
+                        if done {
+                            break;
+                        }
+                    }
+                }
+
+                anyhow::Ok(())
+            })
+            .detach();
+
+        Ok(rx)
+    } else {
+        let mut body = String::new();
+        response.body_mut().read_to_string(&mut body).await?;
+
+        #[derive(Deserialize)]
+        struct OpenAIResponse {
+            error: OpenAIError,
+        }
+
+        #[derive(Deserialize)]
+        struct OpenAIError {
+            message: String,
+        }
+
+        match serde_json::from_str::<OpenAIResponse>(&body) {
+            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
+                "Failed to connect to OpenAI API: {}",
+                response.error.message,
+            )),
+
+            _ => Err(anyhow!(
+                "Failed to connect to OpenAI API: {} {}",
+                response.status(),
+                body,
+            )),
+        }
+    }
+}
+
+pub trait CompletionProvider {
+    fn complete(
+        &self,
+        prompt: OpenAIRequest,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
+}
+
+pub struct OpenAICompletionProvider {
+    api_key: String,
+    executor: Arc<Background>,
+}
+
+impl OpenAICompletionProvider {
+    pub fn new(api_key: String, executor: Arc<Background>) -> Self {
+        Self { api_key, executor }
+    }
+}
+
+impl CompletionProvider for OpenAICompletionProvider {
+    fn complete(
+        &self,
+        prompt: OpenAIRequest,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
+        async move {
+            let response = request.await?;
+            let stream = response
+                .filter_map(|response| async move {
+                    match response {
+                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
+                        Err(error) => Some(Err(error)),
+                    }
+                })
+                .boxed();
+            Ok(stream)
+        }
+        .boxed()
+    }
+}

crates/assistant/Cargo.toml 🔗

@@ -9,6 +9,7 @@ path = "src/assistant.rs"
 doctest = false
 
 [dependencies]
+ai = { path = "../ai" }
 client = { path = "../client" }
 collections = { path = "../collections"}
 editor = { path = "../editor" }

crates/assistant/src/assistant.rs 🔗

@@ -3,37 +3,20 @@ mod assistant_settings;
 mod codegen;
 mod streaming_diff;
 
-use anyhow::{anyhow, Result};
+use ai::completion::Role;
+use anyhow::Result;
 pub use assistant_panel::AssistantPanel;
 use assistant_settings::OpenAIModel;
 use chrono::{DateTime, Local};
 use collections::HashMap;
 use fs::Fs;
-use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
-use gpui::{executor::Background, AppContext};
-use isahc::{http::StatusCode, Request, RequestExt};
+use futures::StreamExt;
+use gpui::AppContext;
 use regex::Regex;
 use serde::{Deserialize, Serialize};
-use std::{
-    cmp::Reverse,
-    ffi::OsStr,
-    fmt::{self, Display},
-    io,
-    path::PathBuf,
-    sync::Arc,
-};
+use std::{cmp::Reverse, ffi::OsStr, path::PathBuf, sync::Arc};
 use util::paths::CONVERSATIONS_DIR;
 
-const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
-
-// Data types for chat completion requests
-#[derive(Debug, Default, Serialize)]
-pub struct OpenAIRequest {
-    model: String,
-    messages: Vec<RequestMessage>,
-    stream: bool,
-}
-
 #[derive(
     Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
 )]
@@ -116,175 +99,10 @@ impl SavedConversationMetadata {
     }
 }
 
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-struct RequestMessage {
-    role: Role,
-    content: String,
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct ResponseMessage {
-    role: Option<Role>,
-    content: Option<String>,
-}
-
-#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
-#[serde(rename_all = "lowercase")]
-enum Role {
-    User,
-    Assistant,
-    System,
-}
-
-impl Role {
-    pub fn cycle(&mut self) {
-        *self = match self {
-            Role::User => Role::Assistant,
-            Role::Assistant => Role::System,
-            Role::System => Role::User,
-        }
-    }
-}
-
-impl Display for Role {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
-        match self {
-            Role::User => write!(f, "User"),
-            Role::Assistant => write!(f, "Assistant"),
-            Role::System => write!(f, "System"),
-        }
-    }
-}
-
-#[derive(Deserialize, Debug)]
-pub struct OpenAIResponseStreamEvent {
-    pub id: Option<String>,
-    pub object: String,
-    pub created: u32,
-    pub model: String,
-    pub choices: Vec<ChatChoiceDelta>,
-    pub usage: Option<Usage>,
-}
-
-#[derive(Deserialize, Debug)]
-pub struct Usage {
-    pub prompt_tokens: u32,
-    pub completion_tokens: u32,
-    pub total_tokens: u32,
-}
-
-#[derive(Deserialize, Debug)]
-pub struct ChatChoiceDelta {
-    pub index: u32,
-    pub delta: ResponseMessage,
-    pub finish_reason: Option<String>,
-}
-
-#[derive(Deserialize, Debug)]
-struct OpenAIUsage {
-    prompt_tokens: u64,
-    completion_tokens: u64,
-    total_tokens: u64,
-}
-
-#[derive(Deserialize, Debug)]
-struct OpenAIChoice {
-    text: String,
-    index: u32,
-    logprobs: Option<serde_json::Value>,
-    finish_reason: Option<String>,
-}
-
 pub fn init(cx: &mut AppContext) {
     assistant_panel::init(cx);
 }
 
-pub async fn stream_completion(
-    api_key: String,
-    executor: Arc<Background>,
-    mut request: OpenAIRequest,
-) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
-    request.stream = true;
-
-    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
-
-    let json_data = serde_json::to_string(&request)?;
-    let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
-        .header("Content-Type", "application/json")
-        .header("Authorization", format!("Bearer {}", api_key))
-        .body(json_data)?
-        .send_async()
-        .await?;
-
-    let status = response.status();
-    if status == StatusCode::OK {
-        executor
-            .spawn(async move {
-                let mut lines = BufReader::new(response.body_mut()).lines();
-
-                fn parse_line(
-                    line: Result<String, io::Error>,
-                ) -> Result<Option<OpenAIResponseStreamEvent>> {
-                    if let Some(data) = line?.strip_prefix("data: ") {
-                        let event = serde_json::from_str(&data)?;
-                        Ok(Some(event))
-                    } else {
-                        Ok(None)
-                    }
-                }
-
-                while let Some(line) = lines.next().await {
-                    if let Some(event) = parse_line(line).transpose() {
-                        let done = event.as_ref().map_or(false, |event| {
-                            event
-                                .choices
-                                .last()
-                                .map_or(false, |choice| choice.finish_reason.is_some())
-                        });
-                        if tx.unbounded_send(event).is_err() {
-                            break;
-                        }
-
-                        if done {
-                            break;
-                        }
-                    }
-                }
-
-                anyhow::Ok(())
-            })
-            .detach();
-
-        Ok(rx)
-    } else {
-        let mut body = String::new();
-        response.body_mut().read_to_string(&mut body).await?;
-
-        #[derive(Deserialize)]
-        struct OpenAIResponse {
-            error: OpenAIError,
-        }
-
-        #[derive(Deserialize)]
-        struct OpenAIError {
-            message: String,
-        }
-
-        match serde_json::from_str::<OpenAIResponse>(&body) {
-            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
-                "Failed to connect to OpenAI API: {}",
-                response.error.message,
-            )),
-
-            _ => Err(anyhow!(
-                "Failed to connect to OpenAI API: {} {}",
-                response.status(),
-                body,
-            )),
-        }
-    }
-}
-
 #[cfg(test)]
 #[ctor::ctor]
 fn init_logger() {

crates/assistant/src/assistant_panel.rs 🔗

@@ -1,8 +1,11 @@
 use crate::{
     assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
-    codegen::{self, Codegen, CodegenKind, OpenAICompletionProvider},
-    stream_completion, MessageId, MessageMetadata, MessageStatus, OpenAIRequest, RequestMessage,
-    Role, SavedConversation, SavedConversationMetadata, SavedMessage, OPENAI_API_URL,
+    codegen::{self, Codegen, CodegenKind},
+    MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata,
+    SavedMessage,
+};
+use ai::completion::{
+    stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
 };
 use anyhow::{anyhow, Result};
 use chrono::{DateTime, Local};

crates/assistant/src/codegen.rs 🔗

@@ -1,59 +1,14 @@
-use crate::{
-    stream_completion,
-    streaming_diff::{Hunk, StreamingDiff},
-    OpenAIRequest,
-};
+use crate::streaming_diff::{Hunk, StreamingDiff};
+use ai::completion::{CompletionProvider, OpenAIRequest};
 use anyhow::Result;
 use editor::{
     multi_buffer, Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
 };
-use futures::{
-    channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, SinkExt, Stream, StreamExt,
-};
-use gpui::{executor::Background, Entity, ModelContext, ModelHandle, Task};
+use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
+use gpui::{Entity, ModelContext, ModelHandle, Task};
 use language::{Rope, TransactionId};
 use std::{cmp, future, ops::Range, sync::Arc};
 
-pub trait CompletionProvider {
-    fn complete(
-        &self,
-        prompt: OpenAIRequest,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
-}
-
-pub struct OpenAICompletionProvider {
-    api_key: String,
-    executor: Arc<Background>,
-}
-
-impl OpenAICompletionProvider {
-    pub fn new(api_key: String, executor: Arc<Background>) -> Self {
-        Self { api_key, executor }
-    }
-}
-
-impl CompletionProvider for OpenAICompletionProvider {
-    fn complete(
-        &self,
-        prompt: OpenAIRequest,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-        let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
-        async move {
-            let response = request.await?;
-            let stream = response
-                .filter_map(|response| async move {
-                    match response {
-                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
-                        Err(error) => Some(Err(error)),
-                    }
-                })
-                .boxed();
-            Ok(stream)
-        }
-        .boxed()
-    }
-}
-
 pub enum Event {
     Finished,
     Undone,
@@ -397,13 +352,17 @@ fn strip_markdown_codeblock(
 #[cfg(test)]
 mod tests {
     use super::*;
-    use futures::stream;
+    use futures::{
+        future::BoxFuture,
+        stream::{self, BoxStream},
+    };
     use gpui::{executor::Deterministic, TestAppContext};
     use indoc::indoc;
     use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
     use parking_lot::Mutex;
     use rand::prelude::*;
     use settings::SettingsStore;
+    use smol::future::FutureExt;
 
     #[gpui::test(iterations = 10)]
     async fn test_transform_autoindent(

crates/zed/src/zed.rs 🔗

@@ -5,9 +5,9 @@ pub mod only_instance;
 #[cfg(any(test, feature = "test-support"))]
 pub mod test;
 
-use assistant::AssistantPanel;
 use anyhow::Context;
 use assets::Assets;
+use assistant::AssistantPanel;
 use breadcrumbs::Breadcrumbs;
 pub use client;
 use collab_ui::CollabTitlebarItem; // TODO: Add back toggle collab ui shortcut
@@ -2418,7 +2418,7 @@ mod tests {
             pane::init(cx);
             project_panel::init((), cx);
             terminal_view::init(cx);
-            ai::init(cx);
+            assistant::init(cx);
             app_state
         })
     }