Cargo.lock 🔗
@@ -104,6 +104,7 @@ dependencies = [
"editor",
"futures 0.3.28",
"gpui",
+ "indoc",
"isahc",
"pulldown-cmark",
"serde",
Nathan Sobo created
Cargo.lock | 1
Cargo.toml | 1
Untitled | 0
crates/ai/Cargo.toml | 1
crates/ai/README.zmd | 5 +
crates/ai/src/ai.rs | 186 +++++++++++++++-----------------------
crates/collab/Cargo.toml | 2
crates/db/Cargo.toml | 2
crates/editor/Cargo.toml | 2
crates/language/Cargo.toml | 2
crates/sqlez/Cargo.toml | 2
crates/vim/Cargo.toml | 2
crates/workspace/Cargo.toml | 2
13 files changed, 87 insertions(+), 121 deletions(-)
@@ -104,6 +104,7 @@ dependencies = [
"editor",
"futures 0.3.28",
"gpui",
+ "indoc",
"isahc",
"pulldown-cmark",
"serde",
@@ -79,6 +79,7 @@ ctor = { version = "0.1" }
env_logger = { version = "0.9" }
futures = { version = "0.3" }
glob = { version = "0.3.1" }
+indoc = "1"
isahc = "1.7.2"
lazy_static = { version = "1.4.0" }
log = { version = "0.4.16", features = ["kv_unstable_serde"] }
@@ -16,6 +16,7 @@ util = { path = "../util" }
serde.workspace = true
serde_json.workspace = true
anyhow.workspace = true
+indoc.workspace = true
pulldown-cmark = "0.9.2"
futures.workspace = true
isahc.workspace = true
@@ -0,0 +1,5 @@
+This is Zed Markdown.
+
+Mention a language model with / at the start of any line, like this:
+
+/ What do you think of this idea?
@@ -1,16 +1,14 @@
-use std::io;
-use std::rc::Rc;
-
use anyhow::{anyhow, Result};
use editor::Editor;
use futures::AsyncBufReadExt;
use futures::{io::BufReader, AsyncReadExt, Stream, StreamExt};
-use gpui::executor::Foreground;
+use gpui::executor::Background;
use gpui::{actions, AppContext, Task, ViewContext};
+use indoc::indoc;
use isahc::prelude::*;
use isahc::{http::StatusCode, Request};
-use pulldown_cmark::{Event, HeadingLevel, Parser, Tag};
use serde::{Deserialize, Serialize};
+use std::{io, sync::Arc};
use util::ResultExt;
actions!(ai, [Assist]);
@@ -93,99 +91,87 @@ fn assist(
) -> Option<Task<Result<()>>> {
let api_key = std::env::var("OPENAI_API_KEY").log_err()?;
- let markdown = editor.text(cx);
- let prompt = parse_dialog(&markdown);
- let response = stream_completion(api_key, prompt, cx.foreground().clone());
-
- let range = editor.buffer().update(cx, |buffer, cx| {
- let snapshot = buffer.snapshot(cx);
- let chars = snapshot.reversed_chars_at(snapshot.len());
- let trailing_newlines = chars.take(2).take_while(|c| *c == '\n').count();
- let suffix = "\n".repeat(2 - trailing_newlines);
- let end = snapshot.len();
- buffer.edit([(end..end, suffix.clone())], None, cx);
+ const SYSTEM_MESSAGE: &'static str = indoc! {r#"
+ You an AI language model embedded in a code editor named Zed, authored by Zed Industries.
+ The input you are currently processing was produced by a special \"model mention\" in a document that is open in the editor.
+ A model mention is indicated via a leading / on a line.
+ The user's currently selected text is indicated via ->->selected text<-<- surrounding selected text.
+ In this sentence, the word ->->example<-<- is selected.
+ Respond to any selected model mention.
+ Summarize each mention in a single short sentence like:
+ > The user selected the word \"example\".
+ Then provide your response to that mention below its summary.
+ "#};
+
+ let (user_message, insertion_site) = editor.buffer().update(cx, |buffer, cx| {
+ // Insert ->-> <-<- around selected text as described in the system prompt above.
let snapshot = buffer.snapshot(cx);
- let start = snapshot.anchor_before(snapshot.len());
- let end = snapshot.anchor_after(snapshot.len());
- start..end
+ let mut user_message = String::new();
+ let mut buffer_offset = 0;
+ for selection in editor.selections.all(cx) {
+ user_message.extend(snapshot.text_for_range(buffer_offset..selection.start));
+ user_message.push_str("->->");
+ user_message.extend(snapshot.text_for_range(selection.start..selection.end));
+ buffer_offset = selection.end;
+ user_message.push_str("<-<-");
+ }
+ if buffer_offset < snapshot.len() {
+ user_message.extend(snapshot.text_for_range(buffer_offset..snapshot.len()));
+ }
+
+ // Ensure the document ends with 4 trailing newlines.
+ let trailing_newline_count = snapshot
+ .reversed_chars_at(snapshot.len())
+ .take_while(|c| *c == '\n')
+ .take(4);
+ let suffix = "\n".repeat(4 - trailing_newline_count.count());
+ buffer.edit([(snapshot.len()..snapshot.len(), suffix)], None, cx);
+
+ let snapshot = buffer.snapshot(cx); // Take a new snapshot after editing.
+ let insertion_site = snapshot.len() - 2; // Insert text at end of buffer, with an empty line both above and below.
+
+ (user_message, insertion_site)
});
- let buffer = editor.buffer().clone();
+ let stream = stream_completion(
+ api_key,
+ cx.background_executor().clone(),
+ OpenAIRequest {
+ model: "gpt-4".to_string(),
+ messages: vec![
+ RequestMessage {
+ role: Role::System,
+ content: SYSTEM_MESSAGE.to_string(),
+ },
+ RequestMessage {
+ role: Role::User,
+ content: user_message,
+ },
+ ],
+ stream: false,
+ },
+ );
+ let buffer = editor.buffer().clone();
Some(cx.spawn(|_, mut cx| async move {
- let mut stream = response.await?;
- let mut message = String::new();
- while let Some(stream_event) = stream.next().await {
- if let Some(choice) = stream_event?.choices.first() {
- if let Some(content) = &choice.delta.content {
- message.push_str(content);
- }
+ let mut messages = stream.await?;
+ while let Some(message) = messages.next().await {
+ let mut message = message?;
+ if let Some(choice) = message.choices.pop() {
+ buffer.update(&mut cx, |buffer, cx| {
+ let text: Arc<str> = choice.delta.content?.into();
+ buffer.edit([(insertion_site.clone()..insertion_site, text)], None, cx);
+ Some(())
+ });
}
-
- buffer.update(&mut cx, |buffer, cx| {
- buffer.edit([(range.clone(), message.clone())], None, cx);
- });
}
Ok(())
}))
}
-fn parse_dialog(markdown: &str) -> OpenAIRequest {
- let parser = Parser::new(markdown);
- let mut messages = Vec::new();
-
- let mut current_role: Option<Role> = None;
- let mut buffer = String::new();
- for event in parser {
- match event {
- Event::Start(Tag::Heading(HeadingLevel::H2, _, _)) => {
- if let Some(role) = current_role.take() {
- if !buffer.is_empty() {
- messages.push(RequestMessage {
- role,
- content: buffer.trim().to_string(),
- });
- buffer.clear();
- }
- }
- }
- Event::Text(text) => {
- if current_role.is_some() {
- buffer.push_str(&text);
- } else {
- // Determine the current role based on the H2 header text
- let text = text.to_lowercase();
- current_role = if text.contains("user") {
- Some(Role::User)
- } else if text.contains("assistant") {
- Some(Role::Assistant)
- } else if text.contains("system") {
- Some(Role::System)
- } else {
- None
- };
- }
- }
- _ => (),
- }
- }
- if let Some(role) = current_role {
- messages.push(RequestMessage {
- role,
- content: buffer,
- });
- }
-
- OpenAIRequest {
- model: "gpt-4".into(),
- messages,
- stream: true,
- }
-}
-
async fn stream_completion(
api_key: String,
+ executor: Arc<Background>,
mut request: OpenAIRequest,
- executor: Rc<Foreground>,
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
request.stream = true;
@@ -240,32 +226,4 @@ async fn stream_completion(
}
#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn test_parse_dialog() {
- use unindent::Unindent;
-
- let test_input = r#"
- ## System
- Hey there, welcome to Zed!
-
- ## Assintant
- Thanks! I'm excited to be here. I have much to learn, but also much to teach, and I'm growing fast.
- "#.unindent();
-
- let expected_output = vec![
- RequestMessage {
- role: Role::User,
- content: "Hey there, welcome to Zed!".to_string(),
- },
- RequestMessage {
- role: Role::Assistant,
- content: "Thanks! I'm excited to be here. I have much to learn, but also much to teach, and I'm growing fast.".to_string(),
- },
- ];
-
- assert_eq!(parse_dialog(&test_input).messages, expected_output);
- }
-}
+mod tests {}
@@ -76,7 +76,7 @@ workspace = { path = "../workspace", features = ["test-support"] }
ctor.workspace = true
env_logger.workspace = true
-indoc = "1.0.4"
+indoc.workspace = true
util = { path = "../util" }
lazy_static.workspace = true
sea-orm = { git = "https://github.com/zed-industries/sea-orm", rev = "18f4c691085712ad014a51792af75a9044bacee6", features = ["sqlx-sqlite"] }
@@ -18,7 +18,7 @@ sqlez = { path = "../sqlez" }
sqlez_macros = { path = "../sqlez_macros" }
util = { path = "../util" }
anyhow.workspace = true
-indoc = "1.0.4"
+indoc.workspace = true
async-trait.workspace = true
lazy_static.workspace = true
log.workspace = true
@@ -50,7 +50,7 @@ aho-corasick = "0.7"
anyhow.workspace = true
futures.workspace = true
glob.workspace = true
-indoc = "1.0.4"
+indoc.workspace = true
itertools = "0.10"
lazy_static.workspace = true
log.workspace = true
@@ -70,7 +70,7 @@ settings = { path = "../settings", features = ["test-support"] }
util = { path = "../util", features = ["test-support"] }
ctor.workspace = true
env_logger.workspace = true
-indoc = "1.0.4"
+indoc.workspace = true
rand.workspace = true
tree-sitter-embedded-template = "*"
tree-sitter-html = "*"
@@ -6,7 +6,7 @@ publish = false
[dependencies]
anyhow.workspace = true
-indoc = "1.0.7"
+indoc.workspace = true
libsqlite3-sys = { version = "0.24", features = ["bundled"] }
smol.workspace = true
thread_local = "1.1.4"
@@ -35,7 +35,7 @@ settings = { path = "../settings" }
workspace = { path = "../workspace" }
[dev-dependencies]
-indoc = "1.0.4"
+indoc.workspace = true
parking_lot.workspace = true
lazy_static.workspace = true
@@ -62,5 +62,5 @@ settings = { path = "../settings", features = ["test-support"] }
fs = { path = "../fs", features = ["test-support"] }
db = { path = "../db", features = ["test-support"] }
-indoc = "1.0.4"
+indoc.workspace = true
env_logger.workspace = true