WIP: Stream in completions

Nathan Sobo created

Drop the dependency on tokio that async-openai pulled in, and implement the OpenAI streaming-completion request ourselves (isahc + a hand-rolled SSE line parser).

The current approach — replacing the inserted buffer range with the accumulated message on every stream event, instead of appending only the new delta — is causing issues. Need to switch to appending each delta as it arrives.

Change summary

Cargo.lock                    | 204 --------------------------------
Cargo.toml                    |   1 
crates/ai/Cargo.toml          |   7 
crates/ai/src/ai.rs           | 223 +++++++++++++++++++++++++++++++-----
crates/auto_update/Cargo.toml |   2 
crates/feedback/Cargo.toml    |   2 
crates/gpui/src/executor.rs   |   2 
crates/util/Cargo.toml        |   2 
crates/zed/Cargo.toml         |   2 
9 files changed, 209 insertions(+), 236 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -100,11 +100,16 @@ name = "ai"
 version = "0.1.0"
 dependencies = [
  "anyhow",
- "async-openai",
+ "async-stream",
  "editor",
+ "futures 0.3.28",
  "gpui",
+ "isahc",
  "pulldown-cmark",
+ "serde",
+ "serde_json",
  "unindent",
+ "util",
 ]
 
 [[package]]
@@ -354,28 +359,6 @@ dependencies = [
  "futures-lite",
 ]
 
-[[package]]
-name = "async-openai"
-version = "0.10.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e5d5e93aca1b2f0ca772c76cadd43e965809df87ef98e25e47244c7f006c85d2"
-dependencies = [
- "backoff",
- "base64 0.21.0",
- "derive_builder",
- "futures 0.3.28",
- "rand 0.8.5",
- "reqwest",
- "reqwest-eventsource",
- "serde",
- "serde_json",
- "thiserror",
- "tokio",
- "tokio-stream",
- "tokio-util 0.7.8",
- "tracing",
-]
-
 [[package]]
 name = "async-pipe"
 version = "0.1.3"
@@ -676,20 +659,6 @@ dependencies = [
  "tower-service",
 ]
 
-[[package]]
-name = "backoff"
-version = "0.4.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1"
-dependencies = [
- "futures-core",
- "getrandom 0.2.9",
- "instant",
- "pin-project-lite 0.2.9",
- "rand 0.8.5",
- "tokio",
-]
-
 [[package]]
 name = "backtrace"
 version = "0.3.67"
@@ -1849,41 +1818,6 @@ dependencies = [
  "syn 2.0.15",
 ]
 
-[[package]]
-name = "darling"
-version = "0.14.4"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850"
-dependencies = [
- "darling_core",
- "darling_macro",
-]
-
-[[package]]
-name = "darling_core"
-version = "0.14.4"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "109c1ca6e6b7f82cc233a97004ea8ed7ca123a9af07a8230878fcfda9b158bf0"
-dependencies = [
- "fnv",
- "ident_case",
- "proc-macro2",
- "quote",
- "strsim 0.10.0",
- "syn 1.0.109",
-]
-
-[[package]]
-name = "darling_macro"
-version = "0.14.4"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e"
-dependencies = [
- "darling_core",
- "quote",
- "syn 1.0.109",
-]
-
 [[package]]
 name = "dashmap"
 version = "5.4.0"
@@ -1938,37 +1872,6 @@ dependencies = [
  "byteorder",
 ]
 
-[[package]]
-name = "derive_builder"
-version = "0.12.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8d67778784b508018359cbc8696edb3db78160bab2c2a28ba7f56ef6932997f8"
-dependencies = [
- "derive_builder_macro",
-]
-
-[[package]]
-name = "derive_builder_core"
-version = "0.12.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c11bdc11a0c47bc7d37d582b5285da6849c96681023680b906673c5707af7b0f"
-dependencies = [
- "darling",
- "proc-macro2",
- "quote",
- "syn 1.0.109",
-]
-
-[[package]]
-name = "derive_builder_macro"
-version = "0.12.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ebcda35c7a396850a55ffeac740804b40ffec779b98fffbb1738f4033f0ee79e"
-dependencies = [
- "derive_builder_core",
- "syn 1.0.109",
-]
-
 [[package]]
 name = "dhat"
 version = "0.3.2"
@@ -2304,17 +2207,6 @@ version = "2.5.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0"
 
-[[package]]
-name = "eventsource-stream"
-version = "0.2.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab"
-dependencies = [
- "futures-core",
- "nom",
- "pin-project-lite 0.2.9",
-]
-
 [[package]]
 name = "fallible-iterator"
 version = "0.2.0"
@@ -2711,12 +2603,6 @@ version = "0.3.28"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65"
 
-[[package]]
-name = "futures-timer"
-version = "3.0.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c"
-
 [[package]]
 name = "futures-util"
 version = "0.3.28"
@@ -3200,19 +3086,6 @@ dependencies = [
  "want",
 ]
 
-[[package]]
-name = "hyper-rustls"
-version = "0.23.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1788965e61b367cd03a62950836d5cd41560c3577d90e40e0819373194d1661c"
-dependencies = [
- "http",
- "hyper",
- "rustls 0.20.8",
- "tokio",
- "tokio-rustls",
-]
-
 [[package]]
 name = "hyper-timeout"
 version = "0.4.1"
@@ -3262,12 +3135,6 @@ dependencies = [
  "cxx-build",
 ]
 
-[[package]]
-name = "ident_case"
-version = "1.0.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
-
 [[package]]
 name = "idna"
 version = "0.3.0"
@@ -4062,16 +3929,6 @@ version = "0.3.17"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
 
-[[package]]
-name = "mime_guess"
-version = "2.0.4"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef"
-dependencies = [
- "mime",
- "unicase",
-]
-
 [[package]]
 name = "minimal-lexical"
 version = "0.2.1"
@@ -5537,52 +5394,28 @@ dependencies = [
  "http",
  "http-body",
  "hyper",
- "hyper-rustls",
  "hyper-tls",
  "ipnet",
  "js-sys",
  "log",
  "mime",
- "mime_guess",
  "native-tls",
  "once_cell",
  "percent-encoding",
  "pin-project-lite 0.2.9",
- "rustls 0.20.8",
- "rustls-native-certs",
- "rustls-pemfile",
  "serde",
  "serde_json",
  "serde_urlencoded",
  "tokio",
  "tokio-native-tls",
- "tokio-rustls",
- "tokio-util 0.7.8",
  "tower-service",
  "url",
  "wasm-bindgen",
  "wasm-bindgen-futures",
- "wasm-streams",
  "web-sys",
  "winreg",
 ]
 
-[[package]]
-name = "reqwest-eventsource"
-version = "0.4.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8f03f570355882dd8d15acc3a313841e6e90eddbc76a93c748fd82cc13ba9f51"
-dependencies = [
- "eventsource-stream",
- "futures-core",
- "futures-timer",
- "mime",
- "nom",
- "pin-project-lite 0.2.9",
- "reqwest",
- "thiserror",
-]
-
 [[package]]
 name = "resvg"
 version = "0.14.1"
@@ -5870,18 +5703,6 @@ dependencies = [
  "webpki 0.22.0",
 ]
 
-[[package]]
-name = "rustls-native-certs"
-version = "0.6.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0167bac7a9f490495f3c33013e7722b53cb087ecbe082fb0c6387c96f634ea50"
-dependencies = [
- "openssl-probe",
- "rustls-pemfile",
- "schannel",
- "security-framework",
-]
-
 [[package]]
 name = "rustls-pemfile"
 version = "1.0.2"
@@ -8245,19 +8066,6 @@ dependencies = [
  "leb128",
 ]
 
-[[package]]
-name = "wasm-streams"
-version = "0.2.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6bbae3363c08332cadccd13b67db371814cd214c2524020932f0804b8cf7c078"
-dependencies = [
- "futures-util",
- "js-sys",
- "wasm-bindgen",
- "wasm-bindgen-futures",
- "web-sys",
-]
-
 [[package]]
 name = "wasmparser"
 version = "0.85.0"

Cargo.toml 🔗

@@ -79,6 +79,7 @@ ctor = { version = "0.1" }
 env_logger = { version = "0.9" }
 futures = { version = "0.3" }
 glob = { version = "0.3.1" }
+isahc = "1.7.2"
 lazy_static = { version = "1.4.0" }
 log = { version = "0.4.16", features = ["kv_unstable_serde"] }
 ordered-float = { version = "2.1.1" }

crates/ai/Cargo.toml 🔗

@@ -11,11 +11,16 @@ doctest = false
 [dependencies]
 editor = { path = "../editor" }
 gpui = { path = "../gpui" }
+util = { path = "../util" }
 
+serde.workspace = true
+serde_json.workspace = true
 anyhow.workspace = true
-async-openai = "0.10.3"
 pulldown-cmark = "0.9.2"
+futures.workspace = true
+isahc.workspace = true
 unindent.workspace = true
+async-stream = "0.3.5"
 
 [dev-dependencies]
 editor = { path = "../editor", features = ["test-support"] }

crates/ai/src/ai.rs 🔗

@@ -1,11 +1,87 @@
-use anyhow::Result;
-use async_openai::types::{ChatCompletionRequestMessage, CreateChatCompletionRequest, Role};
+use std::io;
+use std::rc::Rc;
+
+use anyhow::{anyhow, Result};
 use editor::Editor;
+use futures::AsyncBufReadExt;
+use futures::{io::BufReader, AsyncReadExt, Stream, StreamExt};
+use gpui::executor::Foreground;
 use gpui::{actions, AppContext, Task, ViewContext};
+use isahc::prelude::*;
+use isahc::{http::StatusCode, Request};
 use pulldown_cmark::{Event, HeadingLevel, Parser, Tag};
+use serde::{Deserialize, Serialize};
+use util::ResultExt;
 
 actions!(ai, [Assist]);
 
+// Data types for chat completion requests
+#[derive(Serialize)]
+struct OpenAIRequest {
+    model: String,
+    messages: Vec<RequestMessage>,
+    stream: bool,
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+struct RequestMessage {
+    role: Role,
+    content: String,
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+struct ResponseMessage {
+    role: Option<Role>,
+    content: Option<String>,
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(rename_all = "lowercase")]
+enum Role {
+    User,
+    Assistant,
+    System,
+}
+
+#[derive(Deserialize, Debug)]
+struct OpenAIResponseStreamEvent {
+    pub id: Option<String>,
+    pub object: String,
+    pub created: u32,
+    pub model: String,
+    pub choices: Vec<ChatChoiceDelta>,
+    pub usage: Option<Usage>,
+}
+
+#[derive(Deserialize, Debug)]
+struct Usage {
+    pub prompt_tokens: u32,
+    pub completion_tokens: u32,
+    pub total_tokens: u32,
+}
+
+#[derive(Deserialize, Debug)]
+struct ChatChoiceDelta {
+    pub index: u32,
+    pub delta: ResponseMessage,
+    pub finish_reason: Option<String>,
+}
+
+#[derive(Deserialize, Debug)]
+struct OpenAIUsage {
+    prompt_tokens: u64,
+    completion_tokens: u64,
+    total_tokens: u64,
+}
+
+#[derive(Deserialize, Debug)]
+struct OpenAIChoice {
+    text: String,
+    index: u32,
+    logprobs: Option<serde_json::Value>,
+    finish_reason: Option<String>,
+}
+
 pub fn init(cx: &mut AppContext) {
     cx.add_async_action(assist)
 }
@@ -15,26 +91,58 @@ fn assist(
     _: &Assist,
     cx: &mut ViewContext<Editor>,
 ) -> Option<Task<Result<()>>> {
+    let api_key = std::env::var("OPENAI_API_KEY").log_err()?;
+
     let markdown = editor.text(cx);
-    parse_dialog(&markdown);
-    None
+    let prompt = parse_dialog(&markdown);
+    let response = stream_completion(api_key, prompt, cx.foreground().clone());
+
+    let range = editor.buffer().update(cx, |buffer, cx| {
+        let snapshot = buffer.snapshot(cx);
+        let chars = snapshot.reversed_chars_at(snapshot.len());
+        let trailing_newlines = chars.take(2).take_while(|c| *c == '\n').count();
+        let suffix = "\n".repeat(2 - trailing_newlines);
+        let end = snapshot.len();
+        buffer.edit([(end..end, suffix.clone())], None, cx);
+        let snapshot = buffer.snapshot(cx);
+        let start = snapshot.anchor_before(snapshot.len());
+        let end = snapshot.anchor_after(snapshot.len());
+        start..end
+    });
+    let buffer = editor.buffer().clone();
+
+    Some(cx.spawn(|_, mut cx| async move {
+        let mut stream = response.await?;
+        let mut message = String::new();
+        while let Some(stream_event) = stream.next().await {
+            if let Some(choice) = stream_event?.choices.first() {
+                if let Some(content) = &choice.delta.content {
+                    message.push_str(content);
+                }
+            }
+
+            buffer.update(&mut cx, |buffer, cx| {
+                buffer.edit([(range.clone(), message.clone())], None, cx);
+            });
+        }
+        Ok(())
+    }))
 }
 
-fn parse_dialog(markdown: &str) -> CreateChatCompletionRequest {
+fn parse_dialog(markdown: &str) -> OpenAIRequest {
     let parser = Parser::new(markdown);
     let mut messages = Vec::new();
 
-    let mut current_role: Option<(Role, Option<String>)> = None;
+    let mut current_role: Option<Role> = None;
     let mut buffer = String::new();
     for event in parser {
         match event {
             Event::Start(Tag::Heading(HeadingLevel::H2, _, _)) => {
-                if let Some((role, name)) = current_role.take() {
+                if let Some(role) = current_role.take() {
                     if !buffer.is_empty() {
-                        messages.push(ChatCompletionRequestMessage {
+                        messages.push(RequestMessage {
                             role,
                             content: buffer.trim().to_string(),
-                            name,
                         });
                         buffer.clear();
                     }
@@ -45,36 +153,89 @@ fn parse_dialog(markdown: &str) -> CreateChatCompletionRequest {
                     buffer.push_str(&text);
                 } else {
                     // Determine the current role based on the H2 header text
-                    let mut chars = text.chars();
-                    let first_char = chars.by_ref().skip_while(|c| c.is_whitespace()).next();
-                    let name = chars.take_while(|c| *c != '\n').collect::<String>();
-                    let name = if name.is_empty() { None } else { Some(name) };
-
-                    let role = match first_char {
-                        Some('@') => Some(Role::User),
-                        Some('/') => Some(Role::Assistant),
-                        Some('#') => Some(Role::System),
-                        _ => None,
+                    let text = text.to_lowercase();
+                    current_role = if text.contains("user") {
+                        Some(Role::User)
+                    } else if text.contains("assistant") {
+                        Some(Role::Assistant)
+                    } else if text.contains("system") {
+                        Some(Role::System)
+                    } else {
+                        None
                     };
-
-                    current_role = role.map(|role| (role, name));
                 }
             }
             _ => (),
         }
     }
-    if let Some((role, name)) = current_role {
-        messages.push(ChatCompletionRequestMessage {
+    if let Some(role) = current_role {
+        messages.push(RequestMessage {
             role,
             content: buffer,
-            name,
         });
     }
 
-    CreateChatCompletionRequest {
+    OpenAIRequest {
         model: "gpt-4".into(),
         messages,
-        ..Default::default()
+        stream: true,
+    }
+}
+
+async fn stream_completion(
+    api_key: String,
+    mut request: OpenAIRequest,
+    executor: Rc<Foreground>,
+) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
+    request.stream = true;
+
+    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
+
+    let json_data = serde_json::to_string(&request)?;
+    let mut response = Request::post("https://api.openai.com/v1/chat/completions")
+        .header("Content-Type", "application/json")
+        .header("Authorization", format!("Bearer {}", api_key))
+        .body(json_data)?
+        .send_async()
+        .await?;
+
+    let status = response.status();
+    if status == StatusCode::OK {
+        executor
+            .spawn(async move {
+                let mut lines = BufReader::new(response.body_mut()).lines();
+
+                fn parse_line(
+                    line: Result<String, io::Error>,
+                ) -> Result<Option<OpenAIResponseStreamEvent>> {
+                    if let Some(data) = line?.strip_prefix("data: ") {
+                        let event = serde_json::from_str(&data)?;
+                        Ok(Some(event))
+                    } else {
+                        Ok(None)
+                    }
+                }
+
+                while let Some(line) = lines.next().await {
+                    if let Some(event) = parse_line(line).transpose() {
+                        tx.unbounded_send(event).log_err();
+                    }
+                }
+
+                anyhow::Ok(())
+            })
+            .detach();
+
+        Ok(rx)
+    } else {
+        let mut body = String::new();
+        response.body_mut().read_to_string(&mut body).await?;
+
+        Err(anyhow!(
+            "Failed to connect to OpenAI API: {} {}",
+            response.status(),
+            body,
+        ))
     }
 }
 
@@ -87,23 +248,21 @@ mod tests {
         use unindent::Unindent;
 
         let test_input = r#"
-            ## @nathan
+            ## System
             Hey there, welcome to Zed!
 
-            ## /sky
+            ## Assintant
             Thanks! I'm excited to be here. I have much to learn, but also much to teach, and I'm growing fast.
         "#.unindent();
 
         let expected_output = vec![
-            ChatCompletionRequestMessage {
+            RequestMessage {
                 role: Role::User,
                 content: "Hey there, welcome to Zed!".to_string(),
-                name: Some("nathan".to_string()),
             },
-            ChatCompletionRequestMessage {
+            RequestMessage {
                 role: Role::Assistant,
                 content: "Thanks! I'm excited to be here. I have much to learn, but also much to teach, and I'm growing fast.".to_string(),
-                name: Some("sky".to_string()),
             },
         ];
 

crates/auto_update/Cargo.toml 🔗

@@ -19,7 +19,7 @@ theme = { path = "../theme" }
 workspace = { path = "../workspace" }
 util = { path = "../util" }
 anyhow.workspace = true
-isahc = "1.7"
+isahc.workspace = true
 lazy_static.workspace = true
 log.workspace = true
 serde.workspace = true

crates/feedback/Cargo.toml 🔗

@@ -27,7 +27,7 @@ futures.workspace = true
 anyhow.workspace = true
 smallvec.workspace = true
 human_bytes = "0.4.1"
-isahc = "1.7"
+isahc.workspace = true
 lazy_static.workspace = true
 postage.workspace = true
 serde.workspace = true

crates/gpui/src/executor.rs 🔗

@@ -960,7 +960,7 @@ impl<T: 'static, E: 'static + Display> Task<Result<T, E>> {
     pub fn detach_and_log_err(self, cx: &mut AppContext) {
         cx.spawn(|_| async move {
             if let Err(err) = self.await {
-                log::error!("{}", err);
+                log::error!("{:#}", err);
             }
         })
         .detach();

crates/util/Cargo.toml 🔗

@@ -17,7 +17,7 @@ backtrace = "0.3"
 log.workspace = true
 lazy_static.workspace = true
 futures.workspace = true
-isahc = "1.7"
+isahc.workspace = true
 smol.workspace = true
 url = "2.2"
 rand.workspace = true

crates/zed/Cargo.toml 🔗

@@ -82,7 +82,7 @@ futures.workspace = true
 ignore = "0.4"
 image = "0.23"
 indexmap = "1.6.2"
-isahc = "1.7"
+isahc.workspace = true
 lazy_static.workspace = true
 libc = "0.2"
 log.workspace = true