From 30de64845f3b23cebf3d367424ce79202c4708c5 Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Mon, 22 May 2023 23:11:22 -0600 Subject: [PATCH] WIP --- Cargo.lock | 1 + Cargo.toml | 1 + Untitled | 0 crates/ai/Cargo.toml | 1 + crates/ai/README.zmd | 5 + crates/ai/src/ai.rs | 186 ++++++++++++++---------------------- crates/collab/Cargo.toml | 2 +- crates/db/Cargo.toml | 2 +- crates/editor/Cargo.toml | 2 +- crates/language/Cargo.toml | 2 +- crates/sqlez/Cargo.toml | 2 +- crates/vim/Cargo.toml | 2 +- crates/workspace/Cargo.toml | 2 +- 13 files changed, 87 insertions(+), 121 deletions(-) create mode 100644 Untitled create mode 100644 crates/ai/README.zmd diff --git a/Cargo.lock b/Cargo.lock index 5c0570f912f48762c8c114a049c79ca668b79a64..0ea65f93acb25a9883a559c2226dfc67655673ce 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -104,6 +104,7 @@ dependencies = [ "editor", "futures 0.3.28", "gpui", + "indoc", "isahc", "pulldown-cmark", "serde", diff --git a/Cargo.toml b/Cargo.toml index d8bf005b774395b2eed77fa319eb0f94d3e1c223..7411dd53ad7b1087340b4c6931ff33c7bc74809a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -79,6 +79,7 @@ ctor = { version = "0.1" } env_logger = { version = "0.9" } futures = { version = "0.3" } glob = { version = "0.3.1" } +indoc = "1" isahc = "1.7.2" lazy_static = { version = "1.4.0" } log = { version = "0.4.16", features = ["kv_unstable_serde"] } diff --git a/Untitled b/Untitled new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/crates/ai/Cargo.toml b/crates/ai/Cargo.toml index 0953330a69e0ff675f18579df0c77a51b1666143..dacdbbbf630115a3ad1192a21fcf4e916c2cf6e5 100644 --- a/crates/ai/Cargo.toml +++ b/crates/ai/Cargo.toml @@ -16,6 +16,7 @@ util = { path = "../util" } serde.workspace = true serde_json.workspace = true anyhow.workspace = true +indoc.workspace = true pulldown-cmark = "0.9.2" futures.workspace = true isahc.workspace = true diff --git a/crates/ai/README.zmd b/crates/ai/README.zmd new file mode 100644 index 0000000000000000000000000000000000000000..44cda74cd5cebba504dd7b97f9ee3f43134c870b --- /dev/null +++ b/crates/ai/README.zmd @@ -0,0 +1,5 @@ +This is Zed Markdown. + +Mention a language model with / at the start of any line, like this: + +/ What do you think of this idea? diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs index b0bbd15d5939315ddd25bb1fe22e0e3e83bc56af..101378e747fe9d859ea3ea345510bbf2b44f65b8 100644 --- a/crates/ai/src/ai.rs +++ b/crates/ai/src/ai.rs @@ -1,16 +1,14 @@ -use std::io; -use std::rc::Rc; - use anyhow::{anyhow, Result}; use editor::Editor; use futures::AsyncBufReadExt; use futures::{io::BufReader, AsyncReadExt, Stream, StreamExt}; -use gpui::executor::Foreground; +use gpui::executor::Background; use gpui::{actions, AppContext, Task, ViewContext}; +use indoc::indoc; use isahc::prelude::*; use isahc::{http::StatusCode, Request}; -use pulldown_cmark::{Event, HeadingLevel, Parser, Tag}; use serde::{Deserialize, Serialize}; +use std::{io, sync::Arc}; use util::ResultExt; actions!(ai, [Assist]); @@ -93,99 +91,87 @@ fn assist( ) -> Option>> { let api_key = std::env::var("OPENAI_API_KEY").log_err()?; - let markdown = editor.text(cx); - let prompt = parse_dialog(&markdown); - let response = stream_completion(api_key, prompt, cx.foreground().clone()); - - let range = editor.buffer().update(cx, |buffer, cx| { - let snapshot = buffer.snapshot(cx); - let chars = snapshot.reversed_chars_at(snapshot.len()); - let trailing_newlines = chars.take(2).take_while(|c| *c == '\n').count(); - let suffix = "\n".repeat(2 - trailing_newlines); - let end = snapshot.len(); - buffer.edit([(end..end, suffix.clone())], None, cx); + const SYSTEM_MESSAGE: &'static str = indoc! {r#" + You an AI language model embedded in a code editor named Zed, authored by Zed Industries. + The input you are currently processing was produced by a special \"model mention\" in a document that is open in the editor. + A model mention is indicated via a leading / on a line. + The user's currently selected text is indicated via ->->selected text<-<- surrounding selected text. + In this sentence, the word ->->example<-<- is selected. + Respond to any selected model mention. + Summarize each mention in a single short sentence like: + > The user selected the word \"example\". + Then provide your response to that mention below its summary. + "#}; + + let (user_message, insertion_site) = editor.buffer().update(cx, |buffer, cx| { + // Insert ->-> <-<- around selected text as described in the system prompt above. let snapshot = buffer.snapshot(cx); - let start = snapshot.anchor_before(snapshot.len()); - let end = snapshot.anchor_after(snapshot.len()); - start..end + let mut user_message = String::new(); + let mut buffer_offset = 0; + for selection in editor.selections.all(cx) { + user_message.extend(snapshot.text_for_range(buffer_offset..selection.start)); + user_message.push_str("->->"); + user_message.extend(snapshot.text_for_range(selection.start..selection.end)); + buffer_offset = selection.end; + user_message.push_str("<-<-"); + } + if buffer_offset < snapshot.len() { + user_message.extend(snapshot.text_for_range(buffer_offset..snapshot.len())); + } + + // Ensure the document ends with 4 trailing newlines. + let trailing_newline_count = snapshot + .reversed_chars_at(snapshot.len()) + .take_while(|c| *c == '\n') + .take(4); + let suffix = "\n".repeat(4 - trailing_newline_count.count()); + buffer.edit([(snapshot.len()..snapshot.len(), suffix)], None, cx); + + let snapshot = buffer.snapshot(cx); // Take a new snapshot after editing. + let insertion_site = snapshot.len() - 2; // Insert text at end of buffer, with an empty line both above and below. + + (user_message, insertion_site) }); - let buffer = editor.buffer().clone(); + let stream = stream_completion( + api_key, + cx.background_executor().clone(), + OpenAIRequest { + model: "gpt-4".to_string(), + messages: vec![ + RequestMessage { + role: Role::System, + content: SYSTEM_MESSAGE.to_string(), + }, + RequestMessage { + role: Role::User, + content: user_message, + }, + ], + stream: false, + }, + ); + let buffer = editor.buffer().clone(); Some(cx.spawn(|_, mut cx| async move { - let mut stream = response.await?; - let mut message = String::new(); - while let Some(stream_event) = stream.next().await { - if let Some(choice) = stream_event?.choices.first() { - if let Some(content) = &choice.delta.content { - message.push_str(content); - } + let mut messages = stream.await?; + while let Some(message) = messages.next().await { + let mut message = message?; + if let Some(choice) = message.choices.pop() { + buffer.update(&mut cx, |buffer, cx| { + let text: Arc = choice.delta.content?.into(); + buffer.edit([(insertion_site.clone()..insertion_site, text)], None, cx); + Some(()) + }); } - - buffer.update(&mut cx, |buffer, cx| { - buffer.edit([(range.clone(), message.clone())], None, cx); - }); } Ok(()) })) } -fn parse_dialog(markdown: &str) -> OpenAIRequest { - let parser = Parser::new(markdown); - let mut messages = Vec::new(); - - let mut current_role: Option = None; - let mut buffer = String::new(); - for event in parser { - match event { - Event::Start(Tag::Heading(HeadingLevel::H2, _, _)) => { - if let Some(role) = current_role.take() { - if !buffer.is_empty() { - messages.push(RequestMessage { - role, - content: buffer.trim().to_string(), - }); - buffer.clear(); - } - } - } - Event::Text(text) => { - if current_role.is_some() { - buffer.push_str(&text); - } else { - // Determine the current role based on the H2 header text - let text = text.to_lowercase(); - current_role = if text.contains("user") { - Some(Role::User) - } else if text.contains("assistant") { - Some(Role::Assistant) - } else if text.contains("system") { - Some(Role::System) - } else { - None - }; - } - } - _ => (), - } - } - if let Some(role) = current_role { - messages.push(RequestMessage { - role, - content: buffer, - }); - } - - OpenAIRequest { - model: "gpt-4".into(), - messages, - stream: true, - } -} - async fn stream_completion( api_key: String, + executor: Arc, mut request: OpenAIRequest, - executor: Rc, ) -> Result>> { request.stream = true; @@ -240,32 +226,4 @@ async fn stream_completion( } #[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_parse_dialog() { - use unindent::Unindent; - - let test_input = r#" - ## System - Hey there, welcome to Zed! - - ## Assintant - Thanks! I'm excited to be here. I have much to learn, but also much to teach, and I'm growing fast. - "#.unindent(); - - let expected_output = vec![ - RequestMessage { - role: Role::User, - content: "Hey there, welcome to Zed!".to_string(), - }, - RequestMessage { - role: Role::Assistant, - content: "Thanks! I'm excited to be here. I have much to learn, but also much to teach, and I'm growing fast.".to_string(), - }, - ]; - - assert_eq!(parse_dialog(&test_input).messages, expected_output); - } -} +mod tests {} diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index f2202618f49f1c127701a11e6e7a77c71c2ced7f..cd06b9a70a253eb9668d2704880638c1eeabaaba 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -76,7 +76,7 @@ workspace = { path = "../workspace", features = ["test-support"] } ctor.workspace = true env_logger.workspace = true -indoc = "1.0.4" +indoc.workspace = true util = { path = "../util" } lazy_static.workspace = true sea-orm = { git = "https://github.com/zed-industries/sea-orm", rev = "18f4c691085712ad014a51792af75a9044bacee6", features = ["sqlx-sqlite"] } diff --git a/crates/db/Cargo.toml b/crates/db/Cargo.toml index 8cb7170ef6a518ccd005d5b8d3d5fa5a691c80af..b49078e860ff0d502c7ff1fbe5cdfa26df5fac38 100644 --- a/crates/db/Cargo.toml +++ b/crates/db/Cargo.toml @@ -18,7 +18,7 @@ sqlez = { path = "../sqlez" } sqlez_macros = { path = "../sqlez_macros" } util = { path = "../util" } anyhow.workspace = true -indoc = "1.0.4" +indoc.workspace = true async-trait.workspace = true lazy_static.workspace = true log.workspace = true diff --git a/crates/editor/Cargo.toml b/crates/editor/Cargo.toml index fc7bf4b8abad6732ab338e439db6f30ba2f49e83..482923fee72e03a820701f1fbc33ca642f9713c9 100644 --- a/crates/editor/Cargo.toml +++ b/crates/editor/Cargo.toml @@ -50,7 +50,7 @@ aho-corasick = "0.7" anyhow.workspace = true futures.workspace = true glob.workspace = true -indoc = "1.0.4" +indoc.workspace = true itertools = "0.10" lazy_static.workspace = true log.workspace = true diff --git a/crates/language/Cargo.toml b/crates/language/Cargo.toml index 5a7644d98e6220d77bda66bcd6d6f35f895b67a6..79121b3799abb933fdef4c6f78ce2c096f0d6010 100644 --- a/crates/language/Cargo.toml +++ b/crates/language/Cargo.toml @@ -70,7 +70,7 @@ settings = { path = "../settings", features = ["test-support"] } util = { path = "../util", features = ["test-support"] } ctor.workspace = true env_logger.workspace = true -indoc = "1.0.4" +indoc.workspace = true rand.workspace = true tree-sitter-embedded-template = "*" tree-sitter-html = "*" diff --git a/crates/sqlez/Cargo.toml b/crates/sqlez/Cargo.toml index 7371a7863a30701ae1995501ab7741f47ba4d196..01d17d48123f181b5913029102aa3215f75fafbb 100644 --- a/crates/sqlez/Cargo.toml +++ b/crates/sqlez/Cargo.toml @@ -6,7 +6,7 @@ publish = false [dependencies] anyhow.workspace = true -indoc = "1.0.7" +indoc.workspace = true libsqlite3-sys = { version = "0.24", features = ["bundled"] } smol.workspace = true thread_local = "1.1.4" diff --git a/crates/vim/Cargo.toml b/crates/vim/Cargo.toml index c34a5b469b40e73cb13bbf84803576c6ba48b643..ee3144fd566ba4fae33a4333f159c64b6140595a 100644 --- a/crates/vim/Cargo.toml +++ b/crates/vim/Cargo.toml @@ -35,7 +35,7 @@ settings = { path = "../settings" } workspace = { path = "../workspace" } [dev-dependencies] -indoc = "1.0.4" +indoc.workspace = true parking_lot.workspace = true lazy_static.workspace = true diff --git a/crates/workspace/Cargo.toml b/crates/workspace/Cargo.toml index 33e5e7aefe0cc451efb4ad14b9639a59bd471fbf..b22607e20dec0ac9f285c9a66f5df638c5a66809 100644 --- a/crates/workspace/Cargo.toml +++ b/crates/workspace/Cargo.toml @@ -62,5 +62,5 @@ settings = { path = "../settings", features = ["test-support"] } fs = { path = "../fs", features = ["test-support"] } db = { path = "../db", features = ["test-support"] } -indoc = "1.0.4" +indoc.workspace = true env_logger.workspace = true