WIP

Created by Nathan Sobo

Change summary

Cargo.lock           |   1 
crates/ai/Cargo.toml |   1 
crates/ai/README.zmd |  10 +
crates/ai/src/ai.rs  | 235 +++++++++++++++++++++++++++++----------------
4 files changed, 158 insertions(+), 89 deletions(-)

Detailed changes

Cargo.lock

@@ -101,6 +101,7 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "assets",
+ "collections",
  "editor",
  "futures 0.3.28",
  "gpui",

crates/ai/Cargo.toml

@@ -10,6 +10,7 @@ doctest = false
 
 [dependencies]
 assets = { path = "../assets"}
+collections = { path = "../collections"}
 editor = { path = "../editor" }
 gpui = { path = "../gpui" }
 util = { path = "../util" }

crates/ai/README.zmd

@@ -2,8 +2,12 @@ This is Zed Markdown.
 
 Mention a language model with / at the start of any line, like this:
 
-/
+/ Please help me articulate Zed's approach to integrating with LLMs.
 
-> To mention a language model, simply include a forward slash (/) at the start of a line, followed by the mention of the model. For example:
+> Zed's approach to integrating with large language models (LLMs) involves seamless communication between the user and the AI model. By incorporating a mention with a / at the beginning of a line, users can directly ask questions or request assistance from the AI model. This provides an interactive and efficient way to collaborate within the editor, enhancing productivity and supporting user needs. <
 
-/gpt-4
+This is a document, but it's also more than that. It's a conversation with the model. The document represents the *context* that feeds into a model invocation. The conversation between one or more users and the model is actually a branching and merging conversation of continuously evolving contexts, and the connection of all edits is modeled as a conversation graph.
+
+/ Confirm you understand the above.
+
+> Yes, I understand. The document serves as both the context for AI model invocation and as a representation of an ongoing conversation between the users and the model. The conversation includes branching and merging contexts, and all edits contribute to the conversation graph. <

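The README frames the document itself as a branching, merging conversation graph of contexts. As a rough illustration of that idea, here is a minimal sketch in Rust; the types and names (`ContextNode`, `ConversationGraph`, `commit`) are hypothetical and not part of this commit:

```rust
use std::collections::HashMap;

/// Hypothetical sketch: one node per context state, where a context is
/// the full document text that would feed a model invocation.
struct ContextNode {
    text: String,         // the context at this point in the conversation
    parents: Vec<usize>,  // a merge node has more than one parent
    children: Vec<usize>, // a branch fans out into multiple children
}

/// A branching-and-merging history of contexts.
#[derive(Default)]
struct ConversationGraph {
    nodes: HashMap<usize, ContextNode>,
    next_id: usize,
}

impl ConversationGraph {
    /// Record a new context derived from zero or more parent contexts.
    fn commit(&mut self, text: String, parents: Vec<usize>) -> usize {
        let id = self.next_id;
        self.next_id += 1;
        for parent in &parents {
            if let Some(node) = self.nodes.get_mut(parent) {
                node.children.push(id);
            }
        }
        self.nodes.insert(
            id,
            ContextNode {
                text,
                parents,
                children: Vec::new(),
            },
        );
        id
    }
}
```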
crates/ai/src/ai.rs

@@ -1,5 +1,6 @@
 use anyhow::{anyhow, Result};
 use assets::Assets;
+use collections::HashMap;
 use editor::Editor;
 use futures::AsyncBufReadExt;
 use futures::{io::BufReader, AsyncReadExt, Stream, StreamExt};
@@ -8,9 +9,11 @@ use gpui::{actions, AppContext, Task, ViewContext};
 use isahc::prelude::*;
 use isahc::{http::StatusCode, Request};
 use serde::{Deserialize, Serialize};
+use std::cell::RefCell;
 use std::fs;
+use std::rc::Rc;
 use std::{io, sync::Arc};
-use util::ResultExt;
+use util::{ResultExt, TryFutureExt};
 
 actions!(ai, [Assist]);
 
@@ -82,101 +85,161 @@ struct OpenAIChoice {
 }
 
 pub fn init(cx: &mut AppContext) {
-    cx.add_async_action(assist)
-}
-
-fn assist(
-    editor: &mut Editor,
-    _: &Assist,
-    cx: &mut ViewContext<Editor>,
-) -> Option<Task<Result<()>>> {
-    let api_key = std::env::var("OPENAI_API_KEY").log_err()?;
-
-    let selections = editor.selections.all(cx);
-    let (user_message, insertion_site) = editor.buffer().update(cx, |buffer, cx| {
-        // Insert ->-> <-<- around selected text as described in the system prompt above.
-        let snapshot = buffer.snapshot(cx);
-        let mut user_message = String::new();
-        let mut buffer_offset = 0;
-        for selection in selections {
-            user_message.extend(snapshot.text_for_range(buffer_offset..selection.start));
-            user_message.push_str("->->");
-            user_message.extend(snapshot.text_for_range(selection.start..selection.end));
-            buffer_offset = selection.end;
-            user_message.push_str("<-<-");
+    let assistant = Rc::new(Assistant::default());
+    cx.add_action({
+        let assistant = assistant.clone();
+        move |editor: &mut Editor, _: &Assist, cx: &mut ViewContext<Editor>| {
+            assistant.assist(editor, cx).log_err();
         }
-        if buffer_offset < snapshot.len() {
-            user_message.extend(snapshot.text_for_range(buffer_offset..snapshot.len()));
+    });
+    cx.capture_action({
+        let assistant = assistant.clone();
+        move |_: &mut Editor, _: &editor::Cancel, cx: &mut ViewContext<Editor>| {
+            dbg!("CANCEL LAST ASSIST");
+
+            if !assistant.cancel_last_assist(cx.view_id()) {
+                cx.propagate_action();
+            }
         }
+    });
+}
 
-        // Ensure the document ends with 4 trailing newlines.
-        let trailing_newline_count = snapshot
-            .reversed_chars_at(snapshot.len())
-            .take_while(|c| *c == '\n')
-            .take(4);
-        let suffix = "\n".repeat(4 - trailing_newline_count.count());
-        buffer.edit([(snapshot.len()..snapshot.len(), suffix)], None, cx);
+type CompletionId = usize;
 
-        let snapshot = buffer.snapshot(cx); // Take a new snapshot after editing.
-        let insertion_site = snapshot.anchor_after(snapshot.len() - 2);
+#[derive(Default)]
+struct Assistant(RefCell<AssistantState>);
 
-        (user_message, insertion_site)
-    });
+#[derive(Default)]
+struct AssistantState {
+    assist_stacks: HashMap<usize, Vec<(CompletionId, Task<Option<()>>)>>,
+    next_completion_id: CompletionId,
+}
 
-    let buffer = editor.buffer().clone();
-    let executor = cx.background_executor().clone();
-    Some(cx.spawn(|_, mut cx| async move {
-        // TODO: We should have a get_string method on assets. This is repeated elsewhere.
-        let content = Assets::get("contexts/system.zmd").unwrap();
-        let mut system_message = std::str::from_utf8(content.data.as_ref())
-            .unwrap()
-            .to_string();
-
-        if let Ok(custom_system_message_path) = std::env::var("ZED_ASSISTANT_SYSTEM_PROMPT_PATH") {
-            system_message
-                .push_str("\n\nAlso consider the following user-defined system prompt:\n\n");
-            // TODO: Replace this with our file system trait object.
-            // What if you could bind dependencies on an action when you bind it?
-            dbg!("reading from {:?}", &custom_system_message_path);
-            system_message.push_str(
-                &cx.background()
-                    .spawn(async move { fs::read_to_string(custom_system_message_path) })
-                    .await?,
-            );
-        }
+impl Assistant {
+    fn assist(self: &Rc<Self>, editor: &mut Editor, cx: &mut ViewContext<Editor>) -> Result<()> {
+        let api_key = std::env::var("OPENAI_API_KEY")?;
 
-        let stream = stream_completion(
-            api_key,
-            executor,
-            OpenAIRequest {
-                model: "gpt-4".to_string(),
-                messages: vec![
-                    RequestMessage {
-                        role: Role::System,
-                        content: system_message.to_string(),
-                    },
-                    RequestMessage {
-                        role: Role::User,
-                        content: user_message,
+        let selections = editor.selections.all(cx);
+        let (user_message, insertion_site) = editor.buffer().update(cx, |buffer, cx| {
+            // Insert ->-> <-<- around selected text as described in the system prompt above.
+            let snapshot = buffer.snapshot(cx);
+            let mut user_message = String::new();
+            let mut buffer_offset = 0;
+            for selection in selections {
+                user_message.extend(snapshot.text_for_range(buffer_offset..selection.start));
+                user_message.push_str("->->");
+                user_message.extend(snapshot.text_for_range(selection.start..selection.end));
+                buffer_offset = selection.end;
+                user_message.push_str("<-<-");
+            }
+            if buffer_offset < snapshot.len() {
+                user_message.extend(snapshot.text_for_range(buffer_offset..snapshot.len()));
+            }
+
+            // Ensure the document ends with 4 trailing newlines.
+            let trailing_newline_count = snapshot
+                .reversed_chars_at(snapshot.len())
+                .take_while(|c| *c == '\n')
+                .take(4);
+            let suffix = "\n".repeat(4 - trailing_newline_count.count());
+            buffer.edit([(snapshot.len()..snapshot.len(), suffix)], None, cx);
+
+            let snapshot = buffer.snapshot(cx); // Take a new snapshot after editing.
+            let insertion_site = snapshot.anchor_after(snapshot.len() - 2);
+
+            (user_message, insertion_site)
+        });
+
+        let this = self.clone();
+        let buffer = editor.buffer().clone();
+        let executor = cx.background_executor().clone();
+        let editor_id = cx.view_id();
+        let assist_id = util::post_inc(&mut self.0.borrow_mut().next_completion_id);
+        let assist_task = cx.spawn(|_, mut cx| {
+            async move {
+                // TODO: We should have a get_string method on assets. This is repeated elsewhere.
+                let content = Assets::get("contexts/system.zmd").unwrap();
+                let mut system_message = std::str::from_utf8(content.data.as_ref())
+                    .unwrap()
+                    .to_string();
+
+                if let Ok(custom_system_message_path) =
+                    std::env::var("ZED_ASSISTANT_SYSTEM_PROMPT_PATH")
+                {
+                    system_message.push_str(
+                        "\n\nAlso consider the following user-defined system prompt:\n\n",
+                    );
+                    // TODO: Replace this with our file system trait object.
+                    // What if you could bind dependencies on an action when you bind it?
+                    dbg!("reading from {:?}", &custom_system_message_path);
+                    system_message.push_str(
+                        &cx.background()
+                            .spawn(async move { fs::read_to_string(custom_system_message_path) })
+                            .await?,
+                    );
+                }
+
+                let stream = stream_completion(
+                    api_key,
+                    executor,
+                    OpenAIRequest {
+                        model: "gpt-4".to_string(),
+                        messages: vec![
+                            RequestMessage {
+                                role: Role::System,
+                                content: system_message.to_string(),
+                            },
+                            RequestMessage {
+                                role: Role::User,
+                                content: user_message,
+                            },
+                        ],
+                        stream: false,
                     },
-                ],
-                stream: false,
-            },
-        );
-
-        let mut messages = stream.await?;
-        while let Some(message) = messages.next().await {
-            let mut message = message?;
-            if let Some(choice) = message.choices.pop() {
-                buffer.update(&mut cx, |buffer, cx| {
-                    let text: Arc<str> = choice.delta.content?.into();
-                    buffer.edit([(insertion_site.clone()..insertion_site, text)], None, cx);
-                    Some(())
-                });
+                );
+
+                let mut messages = stream.await?;
+                while let Some(message) = messages.next().await {
+                    let mut message = message?;
+                    if let Some(choice) = message.choices.pop() {
+                        buffer.update(&mut cx, |buffer, cx| {
+                            let text: Arc<str> = choice.delta.content?.into();
+                            buffer.edit([(insertion_site.clone()..insertion_site, text)], None, cx);
+                            Some(())
+                        });
+                    }
+                }
+
+                this.0
+                    .borrow_mut()
+                    .assist_stacks
+                    .get_mut(&editor_id)
+                    .unwrap()
+                    .retain(|(id, _)| *id != assist_id);
+
+                anyhow::Ok(())
             }
-        }
+            .log_err()
+        });
+
+        self.0
+            .borrow_mut()
+            .assist_stacks
+            .entry(cx.view_id())
+            .or_default()
+            .push((dbg!(assist_id), assist_task));
+
         Ok(())
-    }))
+    }
+
+    fn cancel_last_assist(self: &Rc<Self>, editor_id: usize) -> bool {
+        self.0
+            .borrow_mut()
+            .assist_stacks
+            .get_mut(&editor_id)
+            .and_then(|assists| assists.pop())
+            .is_some()
+    }
 }
 
 async fn stream_completion(
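The main change in ai.rs is that assists are no longer fire-and-forget: each spawned completion task is pushed onto a per-editor stack keyed by view id, `editor::Cancel` (captured before the editor handles it) pops the most recent entry, and a finished assist removes its own entry via `retain`. Popping is enough to cancel because dropping a gpui `Task` cancels its underlying future. Below is a minimal self-contained sketch of that bookkeeping, with a placeholder `Task` type standing in for `gpui::Task`; the names here are illustrative, not the commit's API:

```rust
use std::collections::HashMap;

/// Placeholder for gpui::Task, which cancels its future when dropped.
struct Task;

type CompletionId = usize;

/// Per-editor stacks of in-flight assists, keyed by editor view id.
#[derive(Default)]
struct AssistStacks {
    stacks: HashMap<usize, Vec<(CompletionId, Task)>>,
}

impl AssistStacks {
    /// Track a newly spawned assist for the given editor.
    fn push_assist(&mut self, editor_id: usize, id: CompletionId, task: Task) {
        self.stacks.entry(editor_id).or_default().push((id, task));
    }

    /// Pop the most recent assist; dropping the popped Task is what
    /// cancels it. Returns false when there is nothing to cancel, so
    /// the caller can propagate the action, as capture_action does above.
    fn cancel_last(&mut self, editor_id: usize) -> bool {
        self.stacks
            .get_mut(&editor_id)
            .and_then(|stack| stack.pop())
            .is_some()
    }

    /// A completed assist removes its own entry, mirroring the retain
    /// call at the end of the spawned task in the diff.
    fn finish(&mut self, editor_id: usize, id: CompletionId) {
        if let Some(stack) = self.stacks.get_mut(&editor_id) {
            stack.retain(|(entry_id, _)| *entry_id != id);
        }
    }
}
```

Keeping the stack as `Vec<(CompletionId, Task)>` rather than a plain `Vec<Task>` is what lets a completed assist delete exactly its own entry while later assists above it on the stack remain cancellable in LIFO order.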