Parse markdown into an OpenAI chat completion request

Nathan Sobo created

Change summary

Cargo.lock                        | 220 +++++++++++++++++++++++++++++++++
Cargo.toml                        |   1 
crates/ai/Cargo.toml              |  21 +++
crates/ai/src/ai.rs               | 112 ++++++++++++++++
crates/live_kit_client/Cargo.toml |   1 
crates/zed/Cargo.toml             |   1 
crates/zed/src/main.rs            |   1 
7 files changed, 357 insertions(+)

Detailed changes

Cargo.lock 🔗

@@ -95,6 +95,18 @@ dependencies = [
  "memchr",
 ]
 
+[[package]]
+name = "ai"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "async-openai",
+ "editor",
+ "gpui",
+ "pulldown-cmark",
+ "unindent",
+]
+
 [[package]]
 name = "alacritty_config"
 version = "0.1.1-dev"
@@ -342,6 +354,28 @@ dependencies = [
  "futures-lite",
 ]
 
+[[package]]
+name = "async-openai"
+version = "0.10.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e5d5e93aca1b2f0ca772c76cadd43e965809df87ef98e25e47244c7f006c85d2"
+dependencies = [
+ "backoff",
+ "base64 0.21.0",
+ "derive_builder",
+ "futures 0.3.28",
+ "rand 0.8.5",
+ "reqwest",
+ "reqwest-eventsource",
+ "serde",
+ "serde_json",
+ "thiserror",
+ "tokio",
+ "tokio-stream",
+ "tokio-util 0.7.8",
+ "tracing",
+]
+
 [[package]]
 name = "async-pipe"
 version = "0.1.3"
@@ -642,6 +676,20 @@ dependencies = [
  "tower-service",
 ]
 
+[[package]]
+name = "backoff"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1"
+dependencies = [
+ "futures-core",
+ "getrandom 0.2.9",
+ "instant",
+ "pin-project-lite 0.2.9",
+ "rand 0.8.5",
+ "tokio",
+]
+
 [[package]]
 name = "backtrace"
 version = "0.3.67"
@@ -1801,6 +1849,41 @@ dependencies = [
  "syn 2.0.15",
 ]
 
+[[package]]
+name = "darling"
+version = "0.14.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850"
+dependencies = [
+ "darling_core",
+ "darling_macro",
+]
+
+[[package]]
+name = "darling_core"
+version = "0.14.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "109c1ca6e6b7f82cc233a97004ea8ed7ca123a9af07a8230878fcfda9b158bf0"
+dependencies = [
+ "fnv",
+ "ident_case",
+ "proc-macro2",
+ "quote",
+ "strsim 0.10.0",
+ "syn 1.0.109",
+]
+
+[[package]]
+name = "darling_macro"
+version = "0.14.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e"
+dependencies = [
+ "darling_core",
+ "quote",
+ "syn 1.0.109",
+]
+
 [[package]]
 name = "dashmap"
 version = "5.4.0"
@@ -1855,6 +1938,37 @@ dependencies = [
  "byteorder",
 ]
 
+[[package]]
+name = "derive_builder"
+version = "0.12.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8d67778784b508018359cbc8696edb3db78160bab2c2a28ba7f56ef6932997f8"
+dependencies = [
+ "derive_builder_macro",
+]
+
+[[package]]
+name = "derive_builder_core"
+version = "0.12.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c11bdc11a0c47bc7d37d582b5285da6849c96681023680b906673c5707af7b0f"
+dependencies = [
+ "darling",
+ "proc-macro2",
+ "quote",
+ "syn 1.0.109",
+]
+
+[[package]]
+name = "derive_builder_macro"
+version = "0.12.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ebcda35c7a396850a55ffeac740804b40ffec779b98fffbb1738f4033f0ee79e"
+dependencies = [
+ "derive_builder_core",
+ "syn 1.0.109",
+]
+
 [[package]]
 name = "dhat"
 version = "0.3.2"
@@ -2190,6 +2304,17 @@ version = "2.5.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0"
 
+[[package]]
+name = "eventsource-stream"
+version = "0.2.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "74fef4569247a5f429d9156b9d0a2599914385dd189c539334c625d8099d90ab"
+dependencies = [
+ "futures-core",
+ "nom",
+ "pin-project-lite 0.2.9",
+]
+
 [[package]]
 name = "fallible-iterator"
 version = "0.2.0"
@@ -2586,6 +2711,12 @@ version = "0.3.28"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65"
 
+[[package]]
+name = "futures-timer"
+version = "3.0.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c"
+
 [[package]]
 name = "futures-util"
 version = "0.3.28"
@@ -2633,6 +2764,15 @@ dependencies = [
  "version_check",
 ]
 
+[[package]]
+name = "getopts"
+version = "0.2.21"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "14dbbfd5c71d70241ecf9e6f13737f7b5ce823821063188d7e46c41d371eebd5"
+dependencies = [
+ "unicode-width",
+]
+
 [[package]]
 name = "getrandom"
 version = "0.1.16"
@@ -3060,6 +3200,19 @@ dependencies = [
  "want",
 ]
 
+[[package]]
+name = "hyper-rustls"
+version = "0.23.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1788965e61b367cd03a62950836d5cd41560c3577d90e40e0819373194d1661c"
+dependencies = [
+ "http",
+ "hyper",
+ "rustls 0.20.8",
+ "tokio",
+ "tokio-rustls",
+]
+
 [[package]]
 name = "hyper-timeout"
 version = "0.4.1"
@@ -3109,6 +3262,12 @@ dependencies = [
  "cxx-build",
 ]
 
+[[package]]
+name = "ident_case"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
+
 [[package]]
 name = "idna"
 version = "0.3.0"
@@ -3903,6 +4062,16 @@ version = "0.3.17"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
 
+[[package]]
+name = "mime_guess"
+version = "2.0.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef"
+dependencies = [
+ "mime",
+ "unicase",
+]
+
 [[package]]
 name = "minimal-lexical"
 version = "0.2.1"
@@ -5071,6 +5240,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2d9cc634bc78768157b5cbfe988ffcd1dcba95cd2b2f03a88316c08c6d00ed63"
 dependencies = [
  "bitflags",
+ "getopts",
  "memchr",
  "unicase",
 ]
@@ -5367,28 +5537,52 @@ dependencies = [
  "http",
  "http-body",
  "hyper",
+ "hyper-rustls",
  "hyper-tls",
  "ipnet",
  "js-sys",
  "log",
  "mime",
+ "mime_guess",
  "native-tls",
  "once_cell",
  "percent-encoding",
  "pin-project-lite 0.2.9",
+ "rustls 0.20.8",
+ "rustls-native-certs",
+ "rustls-pemfile",
  "serde",
  "serde_json",
  "serde_urlencoded",
  "tokio",
  "tokio-native-tls",
+ "tokio-rustls",
+ "tokio-util 0.7.8",
  "tower-service",
  "url",
  "wasm-bindgen",
  "wasm-bindgen-futures",
+ "wasm-streams",
  "web-sys",
  "winreg",
 ]
 
+[[package]]
+name = "reqwest-eventsource"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8f03f570355882dd8d15acc3a313841e6e90eddbc76a93c748fd82cc13ba9f51"
+dependencies = [
+ "eventsource-stream",
+ "futures-core",
+ "futures-timer",
+ "mime",
+ "nom",
+ "pin-project-lite 0.2.9",
+ "reqwest",
+ "thiserror",
+]
+
 [[package]]
 name = "resvg"
 version = "0.14.1"
@@ -5676,6 +5870,18 @@ dependencies = [
  "webpki 0.22.0",
 ]
 
+[[package]]
+name = "rustls-native-certs"
+version = "0.6.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0167bac7a9f490495f3c33013e7722b53cb087ecbe082fb0c6387c96f634ea50"
+dependencies = [
+ "openssl-probe",
+ "rustls-pemfile",
+ "schannel",
+ "security-framework",
+]
+
 [[package]]
 name = "rustls-pemfile"
 version = "1.0.2"
@@ -8039,6 +8245,19 @@ dependencies = [
  "leb128",
 ]
 
+[[package]]
+name = "wasm-streams"
+version = "0.2.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6bbae3363c08332cadccd13b67db371814cd214c2524020932f0804b8cf7c078"
+dependencies = [
+ "futures-util",
+ "js-sys",
+ "wasm-bindgen",
+ "wasm-bindgen-futures",
+ "web-sys",
+]
+
 [[package]]
 name = "wasmparser"
 version = "0.85.0"
@@ -8759,6 +8978,7 @@ name = "zed"
 version = "0.88.0"
 dependencies = [
  "activity_indicator",
+ "ai",
  "anyhow",
  "assets",
  "async-compression",

Cargo.toml 🔗

@@ -1,6 +1,7 @@
 [workspace]
 members = [
     "crates/activity_indicator",
+    "crates/ai",
     "crates/assets",
     "crates/auto_update",
     "crates/breadcrumbs",

crates/ai/Cargo.toml 🔗

@@ -0,0 +1,21 @@
+[package]
+name = "ai"
+version = "0.1.0"
+edition = "2021"
+publish = false
+
+[lib]
+path = "src/ai.rs"
+doctest = false
+
+[dependencies]
+editor = { path = "../editor" }
+gpui = { path = "../gpui" }
+
+anyhow.workspace = true
+async-openai = "0.10.3"
+pulldown-cmark = "0.9.2"
+unindent.workspace = true
+
+[dev-dependencies]
+editor = { path = "../editor", features = ["test-support"] }

crates/ai/src/ai.rs 🔗

@@ -0,0 +1,112 @@
+use anyhow::Result;
+use async_openai::types::{ChatCompletionRequestMessage, CreateChatCompletionRequest, Role};
+use editor::Editor;
+use gpui::{actions, AppContext, Task, ViewContext};
+use pulldown_cmark::{Event, HeadingLevel, Parser, Tag};
+
+actions!(ai, [Assist]);
+
+pub fn init(cx: &mut AppContext) {
+    cx.add_async_action(assist)
+}
+
+fn assist(
+    editor: &mut Editor,
+    _: &Assist,
+    cx: &mut ViewContext<Editor>,
+) -> Option<Task<Result<()>>> {
+    let markdown = editor.text(cx);
+    parse_dialog(&markdown);
+    None
+}
+
+fn parse_dialog(markdown: &str) -> CreateChatCompletionRequest {
+    let parser = Parser::new(markdown);
+    let mut messages = Vec::new();
+
+    let mut current_role: Option<(Role, Option<String>)> = None;
+    let mut buffer = String::new();
+    for event in parser {
+        match event {
+            Event::Start(Tag::Heading(HeadingLevel::H2, _, _)) => {
+                if let Some((role, name)) = current_role.take() {
+                    if !buffer.is_empty() {
+                        messages.push(ChatCompletionRequestMessage {
+                            role,
+                            content: buffer.trim().to_string(),
+                            name,
+                        });
+                        buffer.clear();
+                    }
+                }
+            }
+            Event::Text(text) => {
+                if current_role.is_some() {
+                    buffer.push_str(&text);
+                } else {
+                    // Determine the current role based on the H2 header text
+                    let mut chars = text.chars();
+                    let first_char = chars.by_ref().skip_while(|c| c.is_whitespace()).next();
+                    let name = chars.take_while(|c| *c != '\n').collect::<String>();
+                    let name = if name.is_empty() { None } else { Some(name) };
+
+                    let role = match first_char {
+                        Some('@') => Some(Role::User),
+                        Some('/') => Some(Role::Assistant),
+                        Some('#') => Some(Role::System),
+                        _ => None,
+                    };
+
+                    current_role = role.map(|role| (role, name));
+                }
+            }
+            _ => (),
+        }
+    }
+    if let Some((role, name)) = current_role {
+        messages.push(ChatCompletionRequestMessage {
+            role,
+            content: buffer,
+            name,
+        });
+    }
+
+    CreateChatCompletionRequest {
+        model: "gpt-4".into(),
+        messages,
+        ..Default::default()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_parse_dialog() {
+        use unindent::Unindent;
+
+        let test_input = r#"
+            ## @nathan
+            Hey there, welcome to Zed!
+
+            ## /sky
+            Thanks! I'm excited to be here. I have much to learn, but also much to teach, and I'm growing fast.
+        "#.unindent();
+
+        let expected_output = vec![
+            ChatCompletionRequestMessage {
+                role: Role::User,
+                content: "Hey there, welcome to Zed!".to_string(),
+                name: Some("nathan".to_string()),
+            },
+            ChatCompletionRequestMessage {
+                role: Role::Assistant,
+                content: "Thanks! I'm excited to be here. I have much to learn, but also much to teach, and I'm growing fast.".to_string(),
+                name: Some("sky".to_string()),
+            },
+        ];
+
+        assert_eq!(parse_dialog(&test_input).messages, expected_output);
+    }
+}

crates/live_kit_client/Cargo.toml 🔗

@@ -46,6 +46,7 @@ collections = { path = "../collections", features = ["test-support"] }
 gpui = { path = "../gpui", features = ["test-support"] }
 live_kit_server = { path = "../live_kit_server" }
 media = { path = "../media" }
+nanoid = "0.4"
 
 anyhow.workspace = true
 async-trait.workspace = true

crates/zed/Cargo.toml 🔗

@@ -48,6 +48,7 @@ language_selector = { path = "../language_selector" }
 lsp = { path = "../lsp" }
 lsp_log = { path = "../lsp_log" }
 node_runtime = { path = "../node_runtime" }
+ai = { path = "../ai" }
 outline = { path = "../outline" }
 plugin_runtime = { path = "../plugin_runtime" }
 project = { path = "../project" }

crates/zed/src/main.rs 🔗

@@ -162,6 +162,7 @@ fn main() {
         terminal_view::init(cx);
         theme_testbench::init(cx);
         copilot::init(http.clone(), node_runtime, cx);
+        ai::init(cx);
 
         cx.spawn(|cx| watch_themes(fs.clone(), cx)).detach();