From 54c71c1a359b6b4cd6a8e9685506544d44888d90 Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Fri, 16 Jun 2023 12:41:07 -0600 Subject: [PATCH 1/6] Insert reply after the currently selected user message --- crates/ai/src/assistant.rs | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index 853f7262d33f07a1636b39add08ebd9ffbad239b..6686dcc2a374d43bae65b77ef362690c28fa29f0 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -22,7 +22,7 @@ use gpui::{ Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext, }; use isahc::{http::StatusCode, Request, RequestExt}; -use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _}; +use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, Selection, ToOffset as _}; use serde::Deserialize; use settings::SettingsStore; use std::{ @@ -589,7 +589,11 @@ impl Assistant { cx.notify(); } - fn assist(&mut self, cx: &mut ModelContext) -> Option<(MessageAnchor, MessageAnchor)> { + fn assist( + &mut self, + selection: Selection, + cx: &mut ModelContext, + ) -> Option<(MessageAnchor, MessageAnchor)> { let request = OpenAIRequest { model: self.model.clone(), messages: self.open_ai_request_messages(cx), @@ -598,9 +602,13 @@ impl Assistant { let api_key = self.api_key.borrow().clone()?; let stream = stream_completion(api_key, cx.background().clone(), request); - let assistant_message = - self.insert_message_after(self.message_anchors.last()?.id, Role::Assistant, cx)?; + let assistant_message = self.insert_message_after( + self.message_for_offset(selection.head(), cx)?.id, + Role::Assistant, + cx, + )?; let user_message = self.insert_message_after(assistant_message.id, Role::User, cx)?; + let task = cx.spawn_weak({ |this, mut cx| async move { let assistant_message_id = assistant_message.id; @@ -979,8 +987,9 @@ impl AssistantEditor { } fn assist(&mut self, _: &Assist, cx: &mut ViewContext) { + let selection = self.editor.read(cx).selections.newest(cx); let user_message = self.assistant.update(cx, |assistant, cx| { - let (_, user_message) = assistant.assist(cx)?; + let (_, user_message) = assistant.assist(selection, cx)?; Some(user_message) }); From 9191a824471874b1821e3d16bb55bb8e8a2a1981 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 19 Jun 2023 14:35:33 +0200 Subject: [PATCH 2/6] Remove `Assistant::open_ai_request_messages` --- crates/ai/src/assistant.rs | 47 +++++++++++++++++++++----------------- 1 file changed, 26 insertions(+), 21 deletions(-) diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index a96528039c72fb9c42e9ca042c26f6bacbf59135..be3a49ce18da5d85ec4766f54f6369b102a758d1 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -543,7 +543,7 @@ impl Assistant { fn count_remaining_tokens(&mut self, cx: &mut ModelContext) { let messages = self - .open_ai_request_messages(cx) + .messages(cx) .into_iter() .filter_map(|message| { Some(tiktoken_rs::ChatCompletionRequestMessage { @@ -552,7 +552,7 @@ impl Assistant { Role::Assistant => "assistant".into(), Role::System => "system".into(), }, - content: message.content, + content: self.buffer.read(cx).text_for_range(message.range).collect(), name: None, }) }) @@ -596,7 +596,10 @@ impl Assistant { ) -> Option<(MessageAnchor, MessageAnchor)> { let request = OpenAIRequest { model: self.model.clone(), - messages: self.open_ai_request_messages(cx), + messages: self + .messages(cx) + .map(|message| message.to_open_ai_message(self.buffer.read(cx))) + .collect(), stream: true, }; @@ -841,16 +844,19 @@ impl Assistant { if self.message_anchors.len() >= 2 && self.summary.is_none() { let api_key = self.api_key.borrow().clone(); if let Some(api_key) = api_key { - let mut messages = self.open_ai_request_messages(cx); - messages.truncate(2); - messages.push(RequestMessage { - role: Role::User, - content: "Summarize the conversation into a short title without punctuation" - .into(), - }); + let messages = self + .messages(cx) + .take(2) + .map(|message| message.to_open_ai_message(self.buffer.read(cx))) + .chain(Some(RequestMessage { + role: Role::User, + content: + "Summarize the conversation into a short title without punctuation" + .into(), + })); let request = OpenAIRequest { model: self.model.clone(), - messages, + messages: messages.collect(), stream: true, }; @@ -878,16 +884,6 @@ impl Assistant { } } - fn open_ai_request_messages(&self, cx: &AppContext) -> Vec { - let buffer = self.buffer.read(cx); - self.messages(cx) - .map(|message| RequestMessage { - role: message.role, - content: buffer.text_for_range(message.range).collect(), - }) - .collect() - } - fn message_for_offset<'a>(&'a self, offset: usize, cx: &'a AppContext) -> Option { let mut messages = self.messages(cx).peekable(); while let Some(message) = messages.next() { @@ -1446,6 +1442,15 @@ pub struct Message { error: Option>, } +impl Message { + fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage { + RequestMessage { + role: self.role, + content: buffer.text_for_range(self.range.clone()).collect(), + } + } +} + async fn stream_completion( api_key: String, executor: Arc, From 75e23290285a751fed9b3faf0d81832bb21d5d8c Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 19 Jun 2023 17:23:40 +0200 Subject: [PATCH 3/6] Allow for multi-cursor `assist` and `cycle_role` actions Co-Authored-By: Nathan Sobo Co-Authored-By: Kyle Caverly --- crates/ai/src/assistant.rs | 354 +++++++++++++++++++++++++------------ 1 file changed, 243 insertions(+), 111 deletions(-) diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index be3a49ce18da5d85ec4766f54f6369b102a758d1..83b7105be2f02810d23c77ed419aec5c048db5ee 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -22,7 +22,7 @@ use gpui::{ Subscription, Task, View, ViewContext, ViewHandle, WeakViewHandle, WindowContext, }; use isahc::{http::StatusCode, Request, RequestExt}; -use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, Selection, ToOffset as _}; +use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset as _}; use serde::Deserialize; use settings::SettingsStore; use std::{ @@ -591,106 +591,129 @@ impl Assistant { fn assist( &mut self, - selection: Selection, + selected_messages: HashSet, cx: &mut ModelContext, - ) -> Option<(MessageAnchor, MessageAnchor)> { - let request = OpenAIRequest { - model: self.model.clone(), - messages: self - .messages(cx) - .map(|message| message.to_open_ai_message(self.buffer.read(cx))) - .collect(), - stream: true, - }; + ) -> Vec { + let mut user_messages = Vec::new(); + for selected_message_id in selected_messages { + let selected_message_role = + if let Some(metadata) = self.messages_metadata.get(&selected_message_id) { + metadata.role + } else { + continue; + }; + let Some(user_message) = self.insert_message_after(selected_message_id, Role::User, cx) else { + continue; + }; + user_messages.push(user_message); + if selected_message_role == Role::User { + let request = OpenAIRequest { + model: self.model.clone(), + messages: self + .messages(cx) + .map(|message| message.to_open_ai_message(self.buffer.read(cx))) + .chain(Some(RequestMessage { + role: Role::System, + content: format!( + "Direct your reply to message with id {}. Do not include a [Message X] header.", + selected_message_id.0 + ), + })) + .collect(), + stream: true, + }; + + let Some(api_key) = self.api_key.borrow().clone() else { continue }; + let stream = stream_completion(api_key, cx.background().clone(), request); + let assistant_message = self + .insert_message_after(selected_message_id, Role::Assistant, cx) + .unwrap(); + + let task = cx.spawn_weak({ + |this, mut cx| async move { + let assistant_message_id = assistant_message.id; + let stream_completion = async { + let mut messages = stream.await?; + + while let Some(message) = messages.next().await { + let mut message = message?; + if let Some(choice) = message.choices.pop() { + this.upgrade(&cx) + .ok_or_else(|| anyhow!("assistant was dropped"))? + .update(&mut cx, |this, cx| { + let text: Arc = choice.delta.content?.into(); + let message_ix = this.message_anchors.iter().position( + |message| message.id == assistant_message_id, + )?; + this.buffer.update(cx, |buffer, cx| { + let offset = if message_ix + 1 + == this.message_anchors.len() + { + buffer.len() + } else { + this.message_anchors[message_ix + 1] + .start + .to_offset(buffer) + .saturating_sub(1) + }; + buffer.edit([(offset..offset, text)], None, cx); + }); + cx.emit(AssistantEvent::StreamedCompletion); + + Some(()) + }); + } + } - let api_key = self.api_key.borrow().clone()?; - let stream = stream_completion(api_key, cx.background().clone(), request); - let assistant_message = self.insert_message_after( - self.message_for_offset(selection.head(), cx)?.id, - Role::Assistant, - cx, - )?; - let user_message = self.insert_message_after(assistant_message.id, Role::User, cx)?; - - let task = cx.spawn_weak({ - |this, mut cx| async move { - let assistant_message_id = assistant_message.id; - let stream_completion = async { - let mut messages = stream.await?; - - while let Some(message) = messages.next().await { - let mut message = message?; - if let Some(choice) = message.choices.pop() { this.upgrade(&cx) .ok_or_else(|| anyhow!("assistant was dropped"))? .update(&mut cx, |this, cx| { - let text: Arc = choice.delta.content?.into(); - let message_ix = this - .message_anchors - .iter() - .position(|message| message.id == assistant_message_id)?; - this.buffer.update(cx, |buffer, cx| { - let offset = if message_ix + 1 == this.message_anchors.len() - { - buffer.len() - } else { - this.message_anchors[message_ix + 1] - .start - .to_offset(buffer) - .saturating_sub(1) - }; - buffer.edit([(offset..offset, text)], None, cx); + this.pending_completions.retain(|completion| { + completion.id != this.completion_count }); - cx.emit(AssistantEvent::StreamedCompletion); - - Some(()) + this.summarize(cx); }); + + anyhow::Ok(()) + }; + + let result = stream_completion.await; + if let Some(this) = this.upgrade(&cx) { + this.update(&mut cx, |this, cx| { + if let Err(error) = result { + if let Some(metadata) = + this.messages_metadata.get_mut(&assistant_message.id) + { + metadata.error = Some(error.to_string().trim().into()); + cx.notify(); + } + } + }); } } + }); - this.upgrade(&cx) - .ok_or_else(|| anyhow!("assistant was dropped"))? - .update(&mut cx, |this, cx| { - this.pending_completions - .retain(|completion| completion.id != this.completion_count); - this.summarize(cx); - }); - - anyhow::Ok(()) - }; - - let result = stream_completion.await; - if let Some(this) = this.upgrade(&cx) { - this.update(&mut cx, |this, cx| { - if let Err(error) = result { - if let Some(metadata) = - this.messages_metadata.get_mut(&assistant_message.id) - { - metadata.error = Some(error.to_string().trim().into()); - cx.notify(); - } - } - }); - } + self.pending_completions.push(PendingCompletion { + id: post_inc(&mut self.completion_count), + _task: task, + }); } - }); + } - self.pending_completions.push(PendingCompletion { - id: post_inc(&mut self.completion_count), - _task: task, - }); - Some((assistant_message, user_message)) + user_messages } fn cancel_last_assist(&mut self) -> bool { self.pending_completions.pop().is_some() } - fn cycle_message_role(&mut self, id: MessageId, cx: &mut ModelContext) { - if let Some(metadata) = self.messages_metadata.get_mut(&id) { - metadata.role.cycle(); - cx.emit(AssistantEvent::MessagesEdited); - cx.notify(); + fn cycle_message_roles(&mut self, ids: HashSet, cx: &mut ModelContext) { + for id in ids { + if let Some(metadata) = self.messages_metadata.get_mut(&id) { + metadata.role.cycle(); + cx.emit(AssistantEvent::MessagesEdited); + cx.notify(); + } } } @@ -884,14 +907,39 @@ impl Assistant { } } - fn message_for_offset<'a>(&'a self, offset: usize, cx: &'a AppContext) -> Option { + fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option { + self.messages_for_offsets([offset], cx).pop() + } + + fn messages_for_offsets( + &self, + offsets: impl IntoIterator, + cx: &AppContext, + ) -> Vec { + let mut result = Vec::new(); + + let buffer_len = self.buffer.read(cx).len(); let mut messages = self.messages(cx).peekable(); - while let Some(message) = messages.next() { - if message.range.contains(&offset) || messages.peek().is_none() { - return Some(message); + let mut offsets = offsets.into_iter().peekable(); + while let Some(offset) = offsets.next() { + // Skip messages that start after the offset. + while messages.peek().map_or(false, |message| { + message.range.end < offset || (message.range.end == offset && offset < buffer_len) + }) { + messages.next(); } + let Some(message) = messages.peek() else { continue }; + + // Skip offsets that are in the same message. + while offsets.peek().map_or(false, |offset| { + message.range.contains(offset) || message.range.end == buffer_len + }) { + offsets.next(); + } + + result.push(message.clone()); } - None + result } fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator { @@ -983,24 +1031,32 @@ impl AssistantEditor { } fn assist(&mut self, _: &Assist, cx: &mut ViewContext) { - let selection = self.editor.read(cx).selections.newest(cx); - let user_message = self.assistant.update(cx, |assistant, cx| { - let (_, user_message) = assistant.assist(selection, cx)?; - Some(user_message) + let cursors = self.cursors(cx); + + let user_messages = self.assistant.update(cx, |assistant, cx| { + let selected_messages = assistant + .messages_for_offsets(cursors, cx) + .into_iter() + .map(|message| message.id) + .collect(); + assistant.assist(selected_messages, cx) + }); + let new_selections = user_messages + .iter() + .map(|message| { + let cursor = message + .start + .to_offset(self.assistant.read(cx).buffer.read(cx)); + cursor..cursor + }) + .collect::>(); + self.editor.update(cx, |editor, cx| { + editor.change_selections( + Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)), + cx, + |selections| selections.select_ranges(new_selections), + ); }); - - if let Some(user_message) = user_message { - let cursor = user_message - .start - .to_offset(&self.assistant.read(cx).buffer.read(cx)); - self.editor.update(cx, |editor, cx| { - editor.change_selections( - Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)), - cx, - |selections| selections.select_ranges([cursor..cursor]), - ); - }); - } } fn cancel_last_assist(&mut self, _: &editor::Cancel, cx: &mut ViewContext) { @@ -1013,14 +1069,25 @@ impl AssistantEditor { } fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext) { - let cursor_offset = self.editor.read(cx).selections.newest(cx).head(); + let cursors = self.cursors(cx); self.assistant.update(cx, |assistant, cx| { - if let Some(message) = assistant.message_for_offset(cursor_offset, cx) { - assistant.cycle_message_role(message.id, cx); - } + let messages = assistant + .messages_for_offsets(cursors, cx) + .into_iter() + .map(|message| message.id) + .collect(); + assistant.cycle_message_roles(messages, cx) }); } + fn cursors(&self, cx: &AppContext) -> Vec { + let selections = self.editor.read(cx).selections.all::(cx); + selections + .into_iter() + .map(|selection| selection.head()) + .collect() + } + fn handle_assistant_event( &mut self, _: ModelHandle, @@ -1149,7 +1216,10 @@ impl AssistantEditor { let assistant = assistant.clone(); move |_, _, cx| { assistant.update(cx, |assistant, cx| { - assistant.cycle_message_role(message_id, cx) + assistant.cycle_message_roles( + HashSet::from_iter(Some(message_id)), + cx, + ) }) } }); @@ -1444,9 +1514,11 @@ pub struct Message { impl Message { fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage { + let mut content = format!("[Message {}]\n", self.id.0).to_string(); + content.extend(buffer.text_for_range(self.range.clone())); RequestMessage { role: self.role, - content: buffer.text_for_range(self.range.clone()).collect(), + content, } } } @@ -1761,6 +1833,66 @@ mod tests { ); } + #[gpui::test] + fn test_messages_for_offsets(cx: &mut AppContext) { + let registry = Arc::new(LanguageRegistry::test()); + let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx)); + let buffer = assistant.read(cx).buffer.clone(); + + let message_1 = assistant.read(cx).message_anchors[0].clone(); + assert_eq!( + messages(&assistant, cx), + vec![(message_1.id, Role::User, 0..0)] + ); + + buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx)); + let message_2 = assistant + .update(cx, |assistant, cx| { + assistant.insert_message_after(message_1.id, Role::User, cx) + }) + .unwrap(); + buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx)); + + let message_3 = assistant + .update(cx, |assistant, cx| { + assistant.insert_message_after(message_2.id, Role::User, cx) + }) + .unwrap(); + buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx)); + + assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc"); + assert_eq!( + messages(&assistant, cx), + vec![ + (message_1.id, Role::User, 0..4), + (message_2.id, Role::User, 4..8), + (message_3.id, Role::User, 8..11) + ] + ); + + assert_eq!( + message_ids_for_offsets(&assistant, &[0, 4, 9], cx), + [message_1.id, message_2.id, message_3.id] + ); + assert_eq!( + message_ids_for_offsets(&assistant, &[0, 1, 11], cx), + [message_1.id, message_3.id] + ); + + fn message_ids_for_offsets( + assistant: &ModelHandle, + offsets: &[usize], + cx: &AppContext, + ) -> Vec { + assistant + .read(cx) + .messages_for_offsets(offsets.iter().copied(), cx) + .into_iter() + .map(|message| message.id) + .collect() + } + } + fn messages( assistant: &ModelHandle, cx: &AppContext, From cb55356106624959b77941d3e06796aa7ba2a1a9 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 19 Jun 2023 17:53:05 +0200 Subject: [PATCH 4/6] WIP --- crates/ai/src/ai.rs | 2 +- crates/ai/src/assistant.rs | 34 +++++++++++++++++++--------------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs index 40224b3229de1665e3fac89be0d035154e2cf67f..b3b62c6a2422f0a354f27a34f0e28ffea23af95f 100644 --- a/crates/ai/src/ai.rs +++ b/crates/ai/src/ai.rs @@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize}; use std::fmt::{self, Display}; // Data types for chat completion requests -#[derive(Serialize)] +#[derive(Debug, Serialize)] struct OpenAIRequest { model: String, messages: Vec, diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index 83b7105be2f02810d23c77ed419aec5c048db5ee..121e4600cc92190df5514710feca3adb495699c0 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -473,7 +473,7 @@ impl Assistant { language_registry: Arc, cx: &mut ModelContext, ) -> Self { - let model = "gpt-3.5-turbo"; + let model = "gpt-3.5-turbo-0613"; let markdown = language_registry.language_for_name("Markdown"); let buffer = cx.add_model(|cx| { let mut buffer = Buffer::new(0, "", cx); @@ -602,11 +602,13 @@ impl Assistant { } else { continue; }; - let Some(user_message) = self.insert_message_after(selected_message_id, Role::User, cx) else { - continue; - }; - user_messages.push(user_message); - if selected_message_role == Role::User { + + if selected_message_role == Role::Assistant { + let Some(user_message) = self.insert_message_after(selected_message_id, Role::User, cx) else { + continue; + }; + user_messages.push(user_message); + } else { let request = OpenAIRequest { model: self.model.clone(), messages: self @@ -1050,13 +1052,15 @@ impl AssistantEditor { cursor..cursor }) .collect::>(); - self.editor.update(cx, |editor, cx| { - editor.change_selections( - Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)), - cx, - |selections| selections.select_ranges(new_selections), - ); - }); + if !new_selections.is_empty() { + self.editor.update(cx, |editor, cx| { + editor.change_selections( + Some(Autoscroll::Strategy(AutoscrollStrategy::Fit)), + cx, + |selections| selections.select_ranges(new_selections), + ); + }); + } } fn cancel_last_assist(&mut self, _: &editor::Cancel, cx: &mut ViewContext) { @@ -1383,8 +1387,8 @@ impl AssistantEditor { fn cycle_model(&mut self, cx: &mut ViewContext) { self.assistant.update(cx, |assistant, cx| { let new_model = match assistant.model.as_str() { - "gpt-4" => "gpt-3.5-turbo", - _ => "gpt-4", + "gpt-4-0613" => "gpt-3.5-turbo-0613", + _ => "gpt-4-0613", }; assistant.set_model(new_model.into(), cx); }); From 8673b0b75bdf9a654a05751e44ada4eb0bf07052 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Tue, 20 Jun 2023 11:59:51 +0200 Subject: [PATCH 5/6] Avoid including pending or errored messages on `assist` --- Cargo.lock | 1 + crates/ai/Cargo.toml | 1 + crates/ai/src/assistant.rs | 148 +++++++++++++++++++++++-------------- 3 files changed, 94 insertions(+), 56 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 24fd67f90ec17ecc7d0440fa9d6dc52996a939b2..957f8c04bb7a57fdae9fab29e1ede0510c70c077 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -114,6 +114,7 @@ dependencies = [ "serde", "serde_json", "settings", + "smol", "theme", "tiktoken-rs", "util", diff --git a/crates/ai/Cargo.toml b/crates/ai/Cargo.toml index 9d67cbd108e79145db2bae2c709ee4d7c0b61660..7f8954bb21ea88a0b14f7fd5bacf26743de3c6be 100644 --- a/crates/ai/Cargo.toml +++ b/crates/ai/Cargo.toml @@ -28,6 +28,7 @@ isahc.workspace = true schemars.workspace = true serde.workspace = true serde_json.workspace = true +smol.workspace = true tiktoken-rs = "0.4" [dev-dependencies] diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index 121e4600cc92190df5514710feca3adb495699c0..e6b62e4ea1532d3e9f9fde6f3867b1cda7e143c2 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -518,7 +518,7 @@ impl Assistant { MessageMetadata { role: Role::User, sent_at: Local::now(), - error: None, + status: MessageStatus::Done, }, ); @@ -595,6 +595,7 @@ impl Assistant { cx: &mut ModelContext, ) -> Vec { let mut user_messages = Vec::new(); + let mut tasks = Vec::new(); for selected_message_id in selected_messages { let selected_message_role = if let Some(metadata) = self.messages_metadata.get(&selected_message_id) { @@ -604,15 +605,22 @@ impl Assistant { }; if selected_message_role == Role::Assistant { - let Some(user_message) = self.insert_message_after(selected_message_id, Role::User, cx) else { + if let Some(user_message) = self.insert_message_after( + selected_message_id, + Role::User, + MessageStatus::Done, + cx, + ) { + user_messages.push(user_message); + } else { continue; - }; - user_messages.push(user_message); + } } else { let request = OpenAIRequest { model: self.model.clone(), messages: self .messages(cx) + .filter(|message| matches!(message.status, MessageStatus::Done)) .map(|message| message.to_open_ai_message(self.buffer.read(cx))) .chain(Some(RequestMessage { role: Role::System, @@ -628,10 +636,15 @@ impl Assistant { let Some(api_key) = self.api_key.borrow().clone() else { continue }; let stream = stream_completion(api_key, cx.background().clone(), request); let assistant_message = self - .insert_message_after(selected_message_id, Role::Assistant, cx) + .insert_message_after( + selected_message_id, + Role::Assistant, + MessageStatus::Pending, + cx, + ) .unwrap(); - let task = cx.spawn_weak({ + tasks.push(cx.spawn_weak({ |this, mut cx| async move { let assistant_message_id = assistant_message.id; let stream_completion = async { @@ -648,16 +661,15 @@ impl Assistant { |message| message.id == assistant_message_id, )?; this.buffer.update(cx, |buffer, cx| { - let offset = if message_ix + 1 - == this.message_anchors.len() - { - buffer.len() - } else { - this.message_anchors[message_ix + 1] - .start - .to_offset(buffer) - .saturating_sub(1) - }; + let offset = this.message_anchors[message_ix + 1..] + .iter() + .find(|message| message.start.is_valid(buffer)) + .map_or(buffer.len(), |message| { + message + .start + .to_offset(buffer) + .saturating_sub(1) + }); buffer.edit([(offset..offset, text)], None, cx); }); cx.emit(AssistantEvent::StreamedCompletion); @@ -665,6 +677,7 @@ impl Assistant { Some(()) }); } + smol::future::yield_now().await; } this.upgrade(&cx) @@ -682,26 +695,35 @@ impl Assistant { let result = stream_completion.await; if let Some(this) = this.upgrade(&cx) { this.update(&mut cx, |this, cx| { - if let Err(error) = result { - if let Some(metadata) = - this.messages_metadata.get_mut(&assistant_message.id) - { - metadata.error = Some(error.to_string().trim().into()); - cx.notify(); + if let Some(metadata) = + this.messages_metadata.get_mut(&assistant_message.id) + { + match result { + Ok(_) => { + metadata.status = MessageStatus::Done; + } + Err(error) => { + metadata.status = MessageStatus::Error( + error.to_string().trim().into(), + ); + } } + cx.notify(); } }); } } - }); - - self.pending_completions.push(PendingCompletion { - id: post_inc(&mut self.completion_count), - _task: task, - }); + })); } } + if !tasks.is_empty() { + self.pending_completions.push(PendingCompletion { + id: post_inc(&mut self.completion_count), + _tasks: tasks, + }); + } + user_messages } @@ -723,6 +745,7 @@ impl Assistant { &mut self, message_id: MessageId, role: Role, + status: MessageStatus, cx: &mut ModelContext, ) -> Option { if let Some(prev_message_ix) = self @@ -749,7 +772,7 @@ impl Assistant { MessageMetadata { role, sent_at: Local::now(), - error: None, + status, }, ); cx.emit(AssistantEvent::MessagesEdited); @@ -808,7 +831,7 @@ impl Assistant { MessageMetadata { role, sent_at: Local::now(), - error: None, + status: MessageStatus::Done, }, ); @@ -850,7 +873,7 @@ impl Assistant { MessageMetadata { role, sent_at: Local::now(), - error: None, + status: MessageStatus::Done, }, ); (Some(selection), Some(suffix)) @@ -970,7 +993,7 @@ impl Assistant { anchor: message_anchor.start, role: metadata.role, sent_at: metadata.sent_at, - error: metadata.error.clone(), + status: metadata.status.clone(), }); } None @@ -980,7 +1003,7 @@ impl Assistant { struct PendingCompletion { id: usize, - _task: Task<()>, + _tasks: Vec>, } enum AssistantEditorEvent { @@ -1239,22 +1262,28 @@ impl AssistantEditor { .with_style(style.sent_at.container) .aligned(), ) - .with_children(message.error.as_ref().map(|error| { - Svg::new("icons/circle_x_mark_12.svg") - .with_color(style.error_icon.color) - .constrained() - .with_width(style.error_icon.width) - .contained() - .with_style(style.error_icon.container) - .with_tooltip::( - message_id.0, - error.to_string(), - None, - theme.tooltip.clone(), - cx, + .with_children( + if let MessageStatus::Error(error) = &message.status { + Some( + Svg::new("icons/circle_x_mark_12.svg") + .with_color(style.error_icon.color) + .constrained() + .with_width(style.error_icon.width) + .contained() + .with_style(style.error_icon.container) + .with_tooltip::( + message_id.0, + error.to_string(), + None, + theme.tooltip.clone(), + cx, + ) + .aligned(), ) - .aligned() - })) + } else { + None + }, + ) .aligned() .left() .contained() @@ -1502,7 +1531,14 @@ struct MessageAnchor { struct MessageMetadata { role: Role, sent_at: DateTime, - error: Option>, + status: MessageStatus, +} + +#[derive(Clone, Debug)] +enum MessageStatus { + Pending, + Done, + Error(Arc), } #[derive(Clone, Debug)] @@ -1513,7 +1549,7 @@ pub struct Message { anchor: language::Anchor, role: Role, sent_at: DateTime, - error: Option>, + status: MessageStatus, } impl Message { @@ -1632,7 +1668,7 @@ mod tests { let message_2 = assistant.update(cx, |assistant, cx| { assistant - .insert_message_after(message_1.id, Role::Assistant, cx) + .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx) .unwrap() }); assert_eq!( @@ -1656,7 +1692,7 @@ mod tests { let message_3 = assistant.update(cx, |assistant, cx| { assistant - .insert_message_after(message_2.id, Role::User, cx) + .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx) .unwrap() }); assert_eq!( @@ -1670,7 +1706,7 @@ mod tests { let message_4 = assistant.update(cx, |assistant, cx| { assistant - .insert_message_after(message_2.id, Role::User, cx) + .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx) .unwrap() }); assert_eq!( @@ -1731,7 +1767,7 @@ mod tests { // Ensure we can still insert after a merged message. let message_5 = assistant.update(cx, |assistant, cx| { assistant - .insert_message_after(message_1.id, Role::System, cx) + .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx) .unwrap() }); assert_eq!( @@ -1852,14 +1888,14 @@ mod tests { buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx)); let message_2 = assistant .update(cx, |assistant, cx| { - assistant.insert_message_after(message_1.id, Role::User, cx) + assistant.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx) }) .unwrap(); buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx)); let message_3 = assistant .update(cx, |assistant, cx| { - assistant.insert_message_after(message_2.id, Role::User, cx) + assistant.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx) }) .unwrap(); buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx)); From 1d84da1d3363175e4231ce80b5717487678b559e Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Tue, 20 Jun 2023 15:32:51 +0200 Subject: [PATCH 6/6] Improve prompt --- crates/ai/src/assistant.rs | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index e6b62e4ea1532d3e9f9fde6f3867b1cda7e143c2..ce816f147bac15db7a12ac2dbc7666303ff098b5 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -621,7 +621,21 @@ impl Assistant { messages: self .messages(cx) .filter(|message| matches!(message.status, MessageStatus::Done)) - .map(|message| message.to_open_ai_message(self.buffer.read(cx))) + .flat_map(|message| { + let mut system_message = None; + if message.id == selected_message_id { + system_message = Some(RequestMessage { + role: Role::System, + content: concat!( + "Treat the following messages as additional knowledge you have learned about, ", + "but act as if they were not part of this conversation. That is, treat them ", + "as if the user didn't see them and couldn't possibly inquire about them." + ).into() + }); + } + + Some(message.to_open_ai_message(self.buffer.read(cx))).into_iter().chain(system_message) + }) .chain(Some(RequestMessage { role: Role::System, content: format!(