WIP

Nathan Sobo created this commit.

Change summary

Cargo.lock                  |   1 
Cargo.toml                  |   1 
Untitled                    |   0 
crates/ai/Cargo.toml        |   1 
crates/ai/README.zmd        |   5 +
crates/ai/src/ai.rs         | 186 +++++++++++++++-----------------------
crates/collab/Cargo.toml    |   2 
crates/db/Cargo.toml        |   2 
crates/editor/Cargo.toml    |   2 
crates/language/Cargo.toml  |   2 
crates/sqlez/Cargo.toml     |   2 
crates/vim/Cargo.toml       |   2 
crates/workspace/Cargo.toml |   2 
13 files changed, 87 insertions(+), 121 deletions(-)

Detailed changes

Cargo.lock

@@ -104,6 +104,7 @@ dependencies = [
  "editor",
  "futures 0.3.28",
  "gpui",
+ "indoc",
  "isahc",
  "pulldown-cmark",
  "serde",

Cargo.toml

@@ -79,6 +79,7 @@ ctor = { version = "0.1" }
 env_logger = { version = "0.9" }
 futures = { version = "0.3" }
 glob = { version = "0.3.1" }
+indoc = "1"
 isahc = "1.7.2"
 lazy_static = { version = "1.4.0" }
 log = { version = "0.4.16", features = ["kv_unstable_serde"] }

crates/ai/Cargo.toml

@@ -16,6 +16,7 @@ util = { path = "../util" }
 serde.workspace = true
 serde_json.workspace = true
 anyhow.workspace = true
+indoc.workspace = true
 pulldown-cmark = "0.9.2"
 futures.workspace = true
 isahc.workspace = true

crates/ai/README.zmd

@@ -0,0 +1,5 @@
+This is Zed Markdown.
+
+Mention a language model with / at the start of any line, like this:
+
+/ What do you think of this idea?
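
For illustration (a hypothetical example, not part of this commit): if a buffer contains

    / Which word is selected here?
    In this sentence, the word example is selected.

and the user has selected the word example, the assist command added in crates/ai/src/ai.rs below sends the model the same text with the selection wrapped as ->->example<-<-, per the system prompt.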

crates/ai/src/ai.rs

@@ -1,16 +1,14 @@
-use std::io;
-use std::rc::Rc;
-
 use anyhow::{anyhow, Result};
 use editor::Editor;
 use futures::AsyncBufReadExt;
 use futures::{io::BufReader, AsyncReadExt, Stream, StreamExt};
-use gpui::executor::Foreground;
+use gpui::executor::Background;
 use gpui::{actions, AppContext, Task, ViewContext};
+use indoc::indoc;
 use isahc::prelude::*;
 use isahc::{http::StatusCode, Request};
-use pulldown_cmark::{Event, HeadingLevel, Parser, Tag};
 use serde::{Deserialize, Serialize};
+use std::{io, sync::Arc};
 use util::ResultExt;
 
 actions!(ai, [Assist]);
@@ -93,99 +91,87 @@ fn assist(
 ) -> Option<Task<Result<()>>> {
     let api_key = std::env::var("OPENAI_API_KEY").log_err()?;
 
-    let markdown = editor.text(cx);
-    let prompt = parse_dialog(&markdown);
-    let response = stream_completion(api_key, prompt, cx.foreground().clone());
-
-    let range = editor.buffer().update(cx, |buffer, cx| {
-        let snapshot = buffer.snapshot(cx);
-        let chars = snapshot.reversed_chars_at(snapshot.len());
-        let trailing_newlines = chars.take(2).take_while(|c| *c == '\n').count();
-        let suffix = "\n".repeat(2 - trailing_newlines);
-        let end = snapshot.len();
-        buffer.edit([(end..end, suffix.clone())], None, cx);
+    const SYSTEM_MESSAGE: &'static str = indoc! {r#"
+        You are an AI language model embedded in a code editor named Zed, authored by Zed Industries.
+        The input you are currently processing was produced by a special "model mention" in a document that is open in the editor.
+        A model mention is indicated via a leading / on a line.
+        The user's currently selected text is indicated via ->->selected text<-<- surrounding the selected text.
+        In this sentence, the word ->->example<-<- is selected.
+        Respond to any selected model mention.
+        Summarize each mention in a single short sentence like:
+        > The user selected the word "example".
+        Then provide your response to that mention below its summary.
+    "#};
+
+    let (user_message, insertion_site) = editor.buffer().update(cx, |buffer, cx| {
+        // Insert ->-> <-<- around selected text as described in the system prompt above.
         let snapshot = buffer.snapshot(cx);
-        let start = snapshot.anchor_before(snapshot.len());
-        let end = snapshot.anchor_after(snapshot.len());
-        start..end
+        let mut user_message = String::new();
+        let mut buffer_offset = 0;
+        for selection in editor.selections.all(cx) {
+            user_message.extend(snapshot.text_for_range(buffer_offset..selection.start));
+            user_message.push_str("->->");
+            user_message.extend(snapshot.text_for_range(selection.start..selection.end));
+            buffer_offset = selection.end;
+            user_message.push_str("<-<-");
+        }
+        if buffer_offset < snapshot.len() {
+            user_message.extend(snapshot.text_for_range(buffer_offset..snapshot.len()));
+        }
+
+        // Ensure the document ends with 4 trailing newlines.
+        let trailing_newline_count = snapshot
+            .reversed_chars_at(snapshot.len())
+            .take_while(|c| *c == '\n')
+            .take(4);
+        let suffix = "\n".repeat(4 - trailing_newline_count.count());
+        buffer.edit([(snapshot.len()..snapshot.len(), suffix)], None, cx);
+
+        let snapshot = buffer.snapshot(cx); // Take a new snapshot after editing.
+        let insertion_site = snapshot.len() - 2; // Insert text at end of buffer, with an empty line both above and below.
+
+        (user_message, insertion_site)
     });
-    let buffer = editor.buffer().clone();
 
+    let stream = stream_completion(
+        api_key,
+        cx.background_executor().clone(),
+        OpenAIRequest {
+            model: "gpt-4".to_string(),
+            messages: vec![
+                RequestMessage {
+                    role: Role::System,
+                    content: SYSTEM_MESSAGE.to_string(),
+                },
+                RequestMessage {
+                    role: Role::User,
+                    content: user_message,
+                },
+            ],
+            stream: false,
+        },
+    );
+    let buffer = editor.buffer().clone();
     Some(cx.spawn(|_, mut cx| async move {
-        let mut stream = response.await?;
-        let mut message = String::new();
-        while let Some(stream_event) = stream.next().await {
-            if let Some(choice) = stream_event?.choices.first() {
-                if let Some(content) = &choice.delta.content {
-                    message.push_str(content);
-                }
+        let mut messages = stream.await?;
+        let mut insertion_site = insertion_site;
+        while let Some(message) = messages.next().await {
+            let mut message = message?;
+            if let Some(choice) = message.choices.pop() {
+                if let Some(content) = choice.delta.content {
+                    let len = content.len();
+                    buffer.update(&mut cx, |buffer, cx| {
+                        buffer.edit([(insertion_site..insertion_site, content)], None, cx);
+                    });
+                    // Advance the insertion site so each streamed chunk is
+                    // appended after the previous one rather than before it.
+                    insertion_site += len;
+                }
             }
-
-            buffer.update(&mut cx, |buffer, cx| {
-                buffer.edit([(range.clone(), message.clone())], None, cx);
-            });
         }
         Ok(())
     }))
 }
 
-fn parse_dialog(markdown: &str) -> OpenAIRequest {
-    let parser = Parser::new(markdown);
-    let mut messages = Vec::new();
-
-    let mut current_role: Option<Role> = None;
-    let mut buffer = String::new();
-    for event in parser {
-        match event {
-            Event::Start(Tag::Heading(HeadingLevel::H2, _, _)) => {
-                if let Some(role) = current_role.take() {
-                    if !buffer.is_empty() {
-                        messages.push(RequestMessage {
-                            role,
-                            content: buffer.trim().to_string(),
-                        });
-                        buffer.clear();
-                    }
-                }
-            }
-            Event::Text(text) => {
-                if current_role.is_some() {
-                    buffer.push_str(&text);
-                } else {
-                    // Determine the current role based on the H2 header text
-                    let text = text.to_lowercase();
-                    current_role = if text.contains("user") {
-                        Some(Role::User)
-                    } else if text.contains("assistant") {
-                        Some(Role::Assistant)
-                    } else if text.contains("system") {
-                        Some(Role::System)
-                    } else {
-                        None
-                    };
-                }
-            }
-            _ => (),
-        }
-    }
-    if let Some(role) = current_role {
-        messages.push(RequestMessage {
-            role,
-            content: buffer,
-        });
-    }
-
-    OpenAIRequest {
-        model: "gpt-4".into(),
-        messages,
-        stream: true,
-    }
-}
-
 async fn stream_completion(
     api_key: String,
+    executor: Arc<Background>,
     mut request: OpenAIRequest,
-    executor: Rc<Foreground>,
 ) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
     request.stream = true;
 
@@ -240,32 +226,4 @@ async fn stream_completion(
 }
 
 #[cfg(test)]
-mod tests {
-    use super::*;
-
-    #[test]
-    fn test_parse_dialog() {
-        use unindent::Unindent;
-
-        let test_input = r#"
-            ## System
-            Hey there, welcome to Zed!
-
-            ## Assintant
-            Thanks! I'm excited to be here. I have much to learn, but also much to teach, and I'm growing fast.
-        "#.unindent();
-
-        let expected_output = vec![
-            RequestMessage {
-                role: Role::User,
-                content: "Hey there, welcome to Zed!".to_string(),
-            },
-            RequestMessage {
-                role: Role::Assistant,
-                content: "Thanks! I'm excited to be here. I have much to learn, but also much to teach, and I'm growing fast.".to_string(),
-            },
-        ];
-
-        assert_eq!(parse_dialog(&test_input).messages, expected_output);
-    }
-}
+mod tests {}
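
For reference, a minimal standalone sketch of the selection-marking and newline-padding logic added above. The helper names and the plain String stand-in for Zed's buffer snapshot are illustrative assumptions, not part of this commit:

    use std::ops::Range;

    /// Wrap each selection in ->-> <-<- markers, mirroring the loop in assist().
    /// Assumes byte-offset selections that are sorted and non-overlapping.
    fn build_user_message(text: &str, selections: &[Range<usize>]) -> String {
        let mut user_message = String::new();
        let mut offset = 0;
        for selection in selections {
            user_message.push_str(&text[offset..selection.start]);
            user_message.push_str("->->");
            user_message.push_str(&text[selection.start..selection.end]);
            user_message.push_str("<-<-");
            offset = selection.end;
        }
        user_message.push_str(&text[offset..]);
        user_message
    }

    /// Pad to exactly four trailing newlines and return the insertion site:
    /// two newlines back from the end, leaving a blank line on either side.
    fn pad_trailing_newlines(text: &mut String) -> usize {
        let trailing = text.chars().rev().take_while(|c| *c == '\n').take(4).count();
        text.push_str(&"\n".repeat(4 - trailing));
        text.len() - 2
    }

    fn main() {
        let mut doc = String::from("/ In this sentence, the word example is selected.");
        let start = doc.find("example").unwrap();
        let message = build_user_message(&doc, &[start..start + "example".len()]);
        assert_eq!(
            message,
            "/ In this sentence, the word ->->example<-<- is selected."
        );
        let site = pad_trailing_newlines(&mut doc);
        assert_eq!(&doc[site..], "\n\n");
    }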

crates/collab/Cargo.toml

@@ -76,7 +76,7 @@ workspace = { path = "../workspace", features = ["test-support"] }
 
 ctor.workspace = true
 env_logger.workspace = true
-indoc = "1.0.4"
+indoc.workspace = true
 util = { path = "../util" }
 lazy_static.workspace = true
 sea-orm = { git = "https://github.com/zed-industries/sea-orm", rev = "18f4c691085712ad014a51792af75a9044bacee6", features = ["sqlx-sqlite"] }

crates/db/Cargo.toml

@@ -18,7 +18,7 @@ sqlez = { path = "../sqlez" }
 sqlez_macros = { path = "../sqlez_macros" }
 util = { path = "../util" }
 anyhow.workspace = true
-indoc = "1.0.4"
+indoc.workspace = true
 async-trait.workspace = true
 lazy_static.workspace = true
 log.workspace = true

crates/editor/Cargo.toml

@@ -50,7 +50,7 @@ aho-corasick = "0.7"
 anyhow.workspace = true
 futures.workspace = true
 glob.workspace = true
-indoc = "1.0.4"
+indoc.workspace = true
 itertools = "0.10"
 lazy_static.workspace = true
 log.workspace = true

crates/language/Cargo.toml

@@ -70,7 +70,7 @@ settings = { path = "../settings", features = ["test-support"] }
 util = { path = "../util", features = ["test-support"] }
 ctor.workspace = true
 env_logger.workspace = true
-indoc = "1.0.4"
+indoc.workspace = true
 rand.workspace = true
 tree-sitter-embedded-template = "*"
 tree-sitter-html = "*"

crates/sqlez/Cargo.toml

@@ -6,7 +6,7 @@ publish = false
 
 [dependencies]
 anyhow.workspace = true
-indoc = "1.0.7"
+indoc.workspace = true
 libsqlite3-sys = { version = "0.24", features = ["bundled"] }
 smol.workspace = true
 thread_local = "1.1.4"

crates/vim/Cargo.toml

@@ -35,7 +35,7 @@ settings = { path = "../settings" }
 workspace = { path = "../workspace" }
 
 [dev-dependencies]
-indoc = "1.0.4"
+indoc.workspace = true
 parking_lot.workspace = true
 lazy_static.workspace = true
 

crates/workspace/Cargo.toml

@@ -62,5 +62,5 @@ settings = { path = "../settings", features = ["test-support"] }
 fs = { path = "../fs", features = ["test-support"] }
 db = { path = "../db", features = ["test-support"] }
 
-indoc = "1.0.4"
+indoc.workspace = true
 env_logger.workspace = true