From 5362e7d3468095eb57f9b863c1ded32f62e76049 Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Tue, 23 May 2023 01:52:37 -0600 Subject: [PATCH] WIP --- Cargo.lock | 1 + crates/ai/Cargo.toml | 1 + crates/ai/README.zmd | 10 +- crates/ai/src/ai.rs | 235 +++++++++++++++++++++++++++---------------- 4 files changed, 158 insertions(+), 89 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e97865fbbef12f3d48894a3671fd1c4ad4ffc970..5728a270024f5b239d45c358ceb85166d4a926c1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -101,6 +101,7 @@ version = "0.1.0" dependencies = [ "anyhow", "assets", + "collections", "editor", "futures 0.3.28", "gpui", diff --git a/crates/ai/Cargo.toml b/crates/ai/Cargo.toml index 3c87583c9634c6d7d1a1bc030973f3cc0347b5f6..b367a4d43cac845950dc123e66ed0c7be15da1f2 100644 --- a/crates/ai/Cargo.toml +++ b/crates/ai/Cargo.toml @@ -10,6 +10,7 @@ doctest = false [dependencies] assets = { path = "../assets"} +collections = { path = "../collections"} editor = { path = "../editor" } gpui = { path = "../gpui" } util = { path = "../util" } diff --git a/crates/ai/README.zmd b/crates/ai/README.zmd index 83930ebd3e62ff2bd4a8c9ce8c90677110b1ad5c..146e770a26d75b0f39a76b8f9948265d306a5a19 100644 --- a/crates/ai/README.zmd +++ b/crates/ai/README.zmd @@ -2,8 +2,12 @@ This is Zed Markdown. Mention a language model with / at the start of any line, like this: -/ +/ Please help me articulate Zed's approach to integrating with LLMs. -> To mention a language model, simply include a forward slash (/) at the start of a line, followed by the mention of the model. For example: +> Zed's approach to integrating with large language models (LLMs) involves seamless communication between the user and the AI model. By incorporating a mention with a / at the beginning of a line, users can directly ask questions or request assistance from the AI model. This provides an interactive and efficient way to collaborate within the editor, enhancing productivity and supporting user needs. < -/gpt-4 +This is a document, but it's also more than that. It's a conversation with the model. The document represents the *context* that feeds into a model invocation. The conversation between one or more users and the model is actually a branching and merging conversation of continuously evolving contexts, and the connection of all edits is modeled as a conversation graph. + +/ Confirm you understand the above. + +> Yes, I understand. The document serves as both the context for AI model invocation and as a representation of an ongoing conversation between the users and the model. The conversation includes branching and merging contexts, and all edits contribute to the conversation graph. < diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs index 8329b568bf45f6515511ac87c976c9c75d24621d..c2c00914fd8c9e59faa3af7b107bd9283e3ff1c5 100644 --- a/crates/ai/src/ai.rs +++ b/crates/ai/src/ai.rs @@ -1,5 +1,6 @@ use anyhow::{anyhow, Result}; use assets::Assets; +use collections::HashMap; use editor::Editor; use futures::AsyncBufReadExt; use futures::{io::BufReader, AsyncReadExt, Stream, StreamExt}; @@ -8,9 +9,11 @@ use gpui::{actions, AppContext, Task, ViewContext}; use isahc::prelude::*; use isahc::{http::StatusCode, Request}; use serde::{Deserialize, Serialize}; +use std::cell::RefCell; use std::fs; +use std::rc::Rc; use std::{io, sync::Arc}; -use util::ResultExt; +use util::{ResultExt, TryFutureExt}; actions!(ai, [Assist]); @@ -82,101 +85,161 @@ struct OpenAIChoice { } pub fn init(cx: &mut AppContext) { - cx.add_async_action(assist) -} - -fn assist( - editor: &mut Editor, - _: &Assist, - cx: &mut ViewContext, -) -> Option>> { - let api_key = std::env::var("OPENAI_API_KEY").log_err()?; - - let selections = editor.selections.all(cx); - let (user_message, insertion_site) = editor.buffer().update(cx, |buffer, cx| { - // Insert ->-> <-<- around selected text as described in the system prompt above. - let snapshot = buffer.snapshot(cx); - let mut user_message = String::new(); - let mut buffer_offset = 0; - for selection in selections { - user_message.extend(snapshot.text_for_range(buffer_offset..selection.start)); - user_message.push_str("->->"); - user_message.extend(snapshot.text_for_range(selection.start..selection.end)); - buffer_offset = selection.end; - user_message.push_str("<-<-"); + let assistant = Rc::new(Assistant::default()); + cx.add_action({ + let assistant = assistant.clone(); + move |editor: &mut Editor, _: &Assist, cx: &mut ViewContext| { + assistant.assist(editor, cx).log_err(); } - if buffer_offset < snapshot.len() { - user_message.extend(snapshot.text_for_range(buffer_offset..snapshot.len())); + }); + cx.capture_action({ + let assistant = assistant.clone(); + move |_: &mut Editor, _: &editor::Cancel, cx: &mut ViewContext| { + dbg!("CANCEL LAST ASSIST"); + + if !assistant.cancel_last_assist(cx.view_id()) { + cx.propagate_action(); + } } + }); +} - // Ensure the document ends with 4 trailing newlines. - let trailing_newline_count = snapshot - .reversed_chars_at(snapshot.len()) - .take_while(|c| *c == '\n') - .take(4); - let suffix = "\n".repeat(4 - trailing_newline_count.count()); - buffer.edit([(snapshot.len()..snapshot.len(), suffix)], None, cx); +type CompletionId = usize; - let snapshot = buffer.snapshot(cx); // Take a new snapshot after editing. - let insertion_site = snapshot.anchor_after(snapshot.len() - 2); +#[derive(Default)] +struct Assistant(RefCell); - (user_message, insertion_site) - }); +#[derive(Default)] +struct AssistantState { + assist_stacks: HashMap>)>>, + next_completion_id: CompletionId, +} - let buffer = editor.buffer().clone(); - let executor = cx.background_executor().clone(); - Some(cx.spawn(|_, mut cx| async move { - // TODO: We should have a get_string method on assets. This is repateated elsewhere. - let content = Assets::get("contexts/system.zmd").unwrap(); - let mut system_message = std::str::from_utf8(content.data.as_ref()) - .unwrap() - .to_string(); - - if let Ok(custom_system_message_path) = std::env::var("ZED_ASSISTANT_SYSTEM_PROMPT_PATH") { - system_message - .push_str("\n\nAlso consider the following user-defined system prompt:\n\n"); - // TODO: Replace this with our file system trait object. - // What you could bind dependencies on an action when you bind it?. - dbg!("reading from {:?}", &custom_system_message_path); - system_message.push_str( - &cx.background() - .spawn(async move { fs::read_to_string(custom_system_message_path) }) - .await?, - ); - } +impl Assistant { + fn assist(self: &Rc, editor: &mut Editor, cx: &mut ViewContext) -> Result<()> { + let api_key = std::env::var("OPENAI_API_KEY")?; - let stream = stream_completion( - api_key, - executor, - OpenAIRequest { - model: "gpt-4".to_string(), - messages: vec![ - RequestMessage { - role: Role::System, - content: system_message.to_string(), - }, - RequestMessage { - role: Role::User, - content: user_message, + let selections = editor.selections.all(cx); + let (user_message, insertion_site) = editor.buffer().update(cx, |buffer, cx| { + // Insert ->-> <-<- around selected text as described in the system prompt above. + let snapshot = buffer.snapshot(cx); + let mut user_message = String::new(); + let mut buffer_offset = 0; + for selection in selections { + user_message.extend(snapshot.text_for_range(buffer_offset..selection.start)); + user_message.push_str("->->"); + user_message.extend(snapshot.text_for_range(selection.start..selection.end)); + buffer_offset = selection.end; + user_message.push_str("<-<-"); + } + if buffer_offset < snapshot.len() { + user_message.extend(snapshot.text_for_range(buffer_offset..snapshot.len())); + } + + // Ensure the document ends with 4 trailing newlines. + let trailing_newline_count = snapshot + .reversed_chars_at(snapshot.len()) + .take_while(|c| *c == '\n') + .take(4); + let suffix = "\n".repeat(4 - trailing_newline_count.count()); + buffer.edit([(snapshot.len()..snapshot.len(), suffix)], None, cx); + + let snapshot = buffer.snapshot(cx); // Take a new snapshot after editing. + let insertion_site = snapshot.anchor_after(snapshot.len() - 2); + + (user_message, insertion_site) + }); + + let this = self.clone(); + let buffer = editor.buffer().clone(); + let executor = cx.background_executor().clone(); + let editor_id = cx.view_id(); + let assist_id = util::post_inc(&mut self.0.borrow_mut().next_completion_id); + let assist_task = cx.spawn(|_, mut cx| { + async move { + // TODO: We should have a get_string method on assets. This is repateated elsewhere. + let content = Assets::get("contexts/system.zmd").unwrap(); + let mut system_message = std::str::from_utf8(content.data.as_ref()) + .unwrap() + .to_string(); + + if let Ok(custom_system_message_path) = + std::env::var("ZED_ASSISTANT_SYSTEM_PROMPT_PATH") + { + system_message.push_str( + "\n\nAlso consider the following user-defined system prompt:\n\n", + ); + // TODO: Replace this with our file system trait object. + // What you could bind dependencies on an action when you bind it?. + dbg!("reading from {:?}", &custom_system_message_path); + system_message.push_str( + &cx.background() + .spawn(async move { fs::read_to_string(custom_system_message_path) }) + .await?, + ); + } + + let stream = stream_completion( + api_key, + executor, + OpenAIRequest { + model: "gpt-4".to_string(), + messages: vec![ + RequestMessage { + role: Role::System, + content: system_message.to_string(), + }, + RequestMessage { + role: Role::User, + content: user_message, + }, + ], + stream: false, }, - ], - stream: false, - }, - ); - - let mut messages = stream.await?; - while let Some(message) = messages.next().await { - let mut message = message?; - if let Some(choice) = message.choices.pop() { - buffer.update(&mut cx, |buffer, cx| { - let text: Arc = choice.delta.content?.into(); - buffer.edit([(insertion_site.clone()..insertion_site, text)], None, cx); - Some(()) - }); + ); + + let mut messages = stream.await?; + while let Some(message) = messages.next().await { + let mut message = message?; + if let Some(choice) = message.choices.pop() { + buffer.update(&mut cx, |buffer, cx| { + let text: Arc = choice.delta.content?.into(); + buffer.edit([(insertion_site.clone()..insertion_site, text)], None, cx); + Some(()) + }); + } + } + + this.0 + .borrow_mut() + .assist_stacks + .get_mut(&editor_id) + .unwrap() + .retain(|(id, _)| *id != assist_id); + + anyhow::Ok(()) } - } + .log_err() + }); + + self.0 + .borrow_mut() + .assist_stacks + .entry(cx.view_id()) + .or_default() + .push((dbg!(assist_id), assist_task)); + Ok(()) - })) + } + + fn cancel_last_assist(self: &Rc, editor_id: usize) -> bool { + self.0 + .borrow_mut() + .assist_stacks + .get_mut(&editor_id) + .and_then(|assists| assists.pop()) + .is_some() + } } async fn stream_completion(