From 21e8e8763e5a28264a15a4f5ef3d19447247169e Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 15 Jun 2023 13:59:01 +0200 Subject: [PATCH 1/5] Allow splitting of messages using `shift-enter` --- assets/keymaps/default.json | 3 +- crates/ai/src/assistant.rs | 158 ++++++++++++++++++++++++++++-------- 2 files changed, 126 insertions(+), 35 deletions(-) diff --git a/assets/keymaps/default.json b/assets/keymaps/default.json index 45e85fd04ff616054ac2a7d259c453d3ac92d76a..f6682a9f0b7cc51f0e173e8e37152bd8b5b1a2cf 100644 --- a/assets/keymaps/default.json +++ b/assets/keymaps/default.json @@ -200,7 +200,8 @@ "context": "AssistantEditor > Editor", "bindings": { "cmd-enter": "assistant::Assist", - "cmd->": "assistant::QuoteSelection" + "cmd->": "assistant::QuoteSelection", + "shift-enter": "assistant::Split" } }, { diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index e5702cb677b62398c1bb75ae2054980d10c028f9..cd334d77b141d253367a933896de55e5ce09016f 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -8,7 +8,7 @@ use collections::{HashMap, HashSet}; use editor::{ display_map::{BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint}, scroll::autoscroll::{Autoscroll, AutoscrollStrategy}, - Anchor, Editor, ToOffset as _, + Anchor, Editor, }; use fs::Fs; use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; @@ -40,7 +40,14 @@ const OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; actions!( assistant, - [NewContext, Assist, QuoteSelection, ToggleFocus, ResetKey] + [ + NewContext, + Assist, + Split, + QuoteSelection, + ToggleFocus, + ResetKey + ] ); pub fn init(cx: &mut AppContext) { @@ -64,6 +71,7 @@ pub fn init(cx: &mut AppContext) { cx.capture_action(AssistantEditor::cancel_last_assist); cx.add_action(AssistantEditor::quote_selection); cx.capture_action(AssistantEditor::copy); + cx.capture_action(AssistantEditor::split); cx.add_action(AssistantPanel::save_api_key); cx.add_action(AssistantPanel::reset_api_key); cx.add_action( @@ -711,6 +719,67 @@ impl Assistant { } } + fn split_message( + &mut self, + range: Range, + cx: &mut ModelContext, + ) -> (Option, Option) { + let start_message = self.message_for_offset(range.start, cx); + let end_message = self.message_for_offset(range.end, cx); + if let Some((start_message, end_message)) = start_message.zip(end_message) { + let (start_message_ix, _, start_message_metadata) = start_message; + let (end_message_ix, _, _) = end_message; + + // Prevent splitting when range spans multiple messages. + if start_message_ix != end_message_ix { + return (None, None); + } + + let role = start_message_metadata.role; + self.buffer.update(cx, |buffer, cx| { + buffer.edit([(range.end..range.end, "\n")], None, cx) + }); + let suffix = Message { + id: MessageId(post_inc(&mut self.next_message_id.0)), + start: self.buffer.read(cx).anchor_before(range.end + 1), + }; + self.messages.insert(start_message_ix + 1, suffix.clone()); + self.messages_metadata.insert( + suffix.id, + MessageMetadata { + role, + sent_at: Local::now(), + error: None, + }, + ); + + if range.start == range.end { + (None, Some(suffix)) + } else { + self.buffer.update(cx, |buffer, cx| { + buffer.edit([(range.start..range.start, "\n")], None, cx) + }); + let selection = Message { + id: MessageId(post_inc(&mut self.next_message_id.0)), + start: self.buffer.read(cx).anchor_before(range.start + 1), + }; + self.messages + .insert(start_message_ix + 1, selection.clone()); + self.messages_metadata.insert( + selection.id, + MessageMetadata { + role, + sent_at: Local::now(), + error: None, + }, + ); + (Some(selection), Some(suffix)) + } + } else { + (None, None) + } + } + fn summarize(&mut self, cx: &mut ModelContext) { if self.messages.len() >= 2 && self.summary.is_none() { let api_key = self.api_key.borrow().clone(); @@ -755,35 +824,39 @@ impl Assistant { fn open_ai_request_messages(&self, cx: &AppContext) -> Vec { let buffer = self.buffer.read(cx); self.messages(cx) - .map(|(_message, metadata, range)| RequestMessage { + .map(|(_ix, _message, metadata, range)| RequestMessage { role: metadata.role, content: buffer.text_for_range(range).collect(), }) .collect() } - fn message_id_for_offset(&self, offset: usize, cx: &AppContext) -> Option { - Some( - self.messages(cx) - .find(|(_, _, range)| range.contains(&offset)) - .map(|(message, _, _)| message) - .or(self.messages.last())? - .id, - ) + fn message_for_offset<'a>( + &'a self, + offset: usize, + cx: &'a AppContext, + ) -> Option<(usize, &Message, &MessageMetadata)> { + let mut messages = self.messages(cx).peekable(); + while let Some((ix, message, metadata, range)) = messages.next() { + if range.contains(&offset) || messages.peek().is_none() { + return Some((ix, message, metadata)); + } + } + None } fn messages<'a>( &'a self, cx: &'a AppContext, - ) -> impl 'a + Iterator)> { + ) -> impl 'a + Iterator)> { let buffer = self.buffer.read(cx); - let mut messages = self.messages.iter().peekable(); + let mut messages = self.messages.iter().enumerate().peekable(); iter::from_fn(move || { - while let Some(message) = messages.next() { + while let Some((ix, message)) = messages.next() { let metadata = self.messages_metadata.get(&message.id)?; let message_start = message.start.to_offset(buffer); let mut message_end = None; - while let Some(next_message) = messages.peek() { + while let Some((_, next_message)) = messages.peek() { if next_message.start.is_valid(buffer) { message_end = Some(next_message.start); break; @@ -794,7 +867,7 @@ impl Assistant { let message_end = message_end .unwrap_or(language::Anchor::MAX) .to_offset(buffer); - return Some((message, metadata, message_start..message_end)); + return Some((ix, message, metadata, message_start..message_end)); } None }) @@ -857,21 +930,7 @@ impl AssistantEditor { fn assist(&mut self, _: &Assist, cx: &mut ViewContext) { let user_message = self.assistant.update(cx, |assistant, cx| { - let editor = self.editor.read(cx); - let newest_selection = editor - .selections - .newest_anchor() - .head() - .to_offset(&editor.buffer().read(cx).snapshot(cx)); - let message_id = assistant.message_id_for_offset(newest_selection, cx)?; - let metadata = assistant.messages_metadata.get(&message_id)?; - let user_message = if metadata.role == Role::User { - let (_, user_message) = assistant.assist(cx)?; - user_message - } else { - let user_message = assistant.insert_message_after(message_id, Role::User, cx)?; - user_message - }; + let (_, user_message) = assistant.assist(cx)?; Some(user_message) }); @@ -982,7 +1041,7 @@ impl AssistantEditor { .assistant .read(cx) .messages(cx) - .map(|(message, metadata, _)| BlockProperties { + .map(|(_, message, metadata, _)| BlockProperties { position: buffer.anchor_in_excerpt(excerpt_id, message.start), height: 2, style: BlockStyle::Sticky, @@ -1147,7 +1206,7 @@ impl AssistantEditor { let selection = editor.selections.newest::(cx); let mut copied_text = String::new(); let mut spanned_messages = 0; - for (_message, metadata, message_range) in assistant.messages(cx) { + for (_ix, _message, metadata, message_range) in assistant.messages(cx) { if message_range.start >= selection.range().end { break; } else if message_range.end >= selection.range().start { @@ -1174,6 +1233,13 @@ impl AssistantEditor { cx.propagate_action(); } + fn split(&mut self, _: &Split, cx: &mut ViewContext) { + self.assistant.update(cx, |assistant, cx| { + let range = self.editor.read(cx).selections.newest::(cx).range(); + assistant.split_message(range, cx); + }); + } + fn cycle_model(&mut self, cx: &mut ViewContext) { self.assistant.update(cx, |assistant, cx| { let new_model = match assistant.model.as_str() { @@ -1510,6 +1576,30 @@ mod tests { (message_3.id, Role::User, 4..5) ] ); + + // Split a message into prefix, selection and suffix. + buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "3")], None, cx)); + assert_eq!( + messages(&assistant, cx), + vec![ + (message_1.id, Role::User, 0..4), + (message_5.id, Role::System, 4..5), + (message_3.id, Role::User, 5..6) + ] + ); + let (message_6, message_7) = + assistant.update(cx, |assistant, cx| assistant.split_message(2..3, cx)); + let (message_6, message_7) = (message_6.unwrap(), message_7.unwrap()); + assert_eq!( + messages(&assistant, cx), + vec![ + (message_1.id, Role::User, 0..3), + (message_6.id, Role::User, 3..5), + (message_7.id, Role::User, 5..6), + (message_5.id, Role::System, 6..7), + (message_3.id, Role::User, 7..8) + ] + ); } fn messages( @@ -1519,7 +1609,7 @@ mod tests { assistant .read(cx) .messages(cx) - .map(|(message, metadata, range)| (message.id, metadata.role, range)) + .map(|(_, message, metadata, range)| (message.id, metadata.role, range)) .collect() } } From 8c6ba13fef17bd7579a82bb4d1535057f4af4685 Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Thu, 15 Jun 2023 09:02:15 -0600 Subject: [PATCH 2/5] Never insert an empty prefix when splitting a message with a non-empty range Co-Authored-By: Antonio Scandurra Co-Authored-By: Piotr Osiewicz --- crates/ai/src/assistant.rs | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index cd334d77b141d253367a933896de55e5ce09016f..0d748e52b2082faaed7867605b0ff59d31b76741 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -727,15 +727,15 @@ impl Assistant { let start_message = self.message_for_offset(range.start, cx); let end_message = self.message_for_offset(range.end, cx); if let Some((start_message, end_message)) = start_message.zip(end_message) { - let (start_message_ix, _, start_message_metadata) = start_message; - let (end_message_ix, _, _) = end_message; + let (start_message_ix, _, metadata, message_range) = start_message; + let (end_message_ix, _, _, _) = end_message; // Prevent splitting when range spans multiple messages. if start_message_ix != end_message_ix { return (None, None); } - let role = start_message_metadata.role; + let role = metadata.role; self.buffer.update(cx, |buffer, cx| { buffer.edit([(range.end..range.end, "\n")], None, cx) }); @@ -753,7 +753,7 @@ impl Assistant { }, ); - if range.start == range.end { + if range.start == range.end || range.start == message_range.start { (None, Some(suffix)) } else { self.buffer.update(cx, |buffer, cx| { @@ -835,11 +835,11 @@ impl Assistant { &'a self, offset: usize, cx: &'a AppContext, - ) -> Option<(usize, &Message, &MessageMetadata)> { + ) -> Option<(usize, &Message, &MessageMetadata, Range)> { let mut messages = self.messages(cx).peekable(); while let Some((ix, message, metadata, range)) = messages.next() { if range.contains(&offset) || messages.peek().is_none() { - return Some((ix, message, metadata)); + return Some((ix, message, metadata, range)); } } None @@ -1600,6 +1600,23 @@ mod tests { (message_3.id, Role::User, 7..8) ] ); + + // Don't include an empty prefix when splitting with a non-empty range + let (no_message, message_8) = + assistant.update(cx, |assistant, cx| assistant.split_message(3..4, cx)); + assert!(no_message.is_none()); + let message_8 = message_8.unwrap(); + assert_eq!( + messages(&assistant, cx), + vec![ + (message_1.id, Role::User, 0..3), + (message_6.id, Role::User, 3..5), + (message_8.id, Role::User, 5..6), + (message_7.id, Role::User, 6..7), + (message_5.id, Role::System, 7..8), + (message_3.id, Role::User, 8..9) + ] + ); } fn messages( From ef6cb11d5c553517d7e5dc400b06aac500e135e5 Mon Sep 17 00:00:00 2001 From: Piotr Osiewicz <24362066+osiewicz@users.noreply.github.com> Date: Fri, 16 Jun 2023 13:29:12 +0200 Subject: [PATCH 3/5] Emit editor event whether we insert a newline or not. --- crates/ai/src/assistant.rs | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index 0d748e52b2082faaed7867605b0ff59d31b76741..d6b8f52dd191e4d604d6584b4a883f77c706f3a8 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -736,8 +736,13 @@ impl Assistant { } let role = metadata.role; - self.buffer.update(cx, |buffer, cx| { - buffer.edit([(range.end..range.end, "\n")], None, cx) + let is_newline = self.buffer.update(cx, |buffer, cx| { + if buffer.chars_at(range.end).next() != Some('\n') { + buffer.edit([(range.end..range.end, "\n")], None, cx); + false + } else { + true + } }); let suffix = Message { id: MessageId(post_inc(&mut self.next_message_id.0)), @@ -752,7 +757,9 @@ impl Assistant { error: None, }, ); - + if is_newline { + cx.emit(AssistantEvent::MessagesEdited); + } if range.start == range.end || range.start == message_range.start { (None, Some(suffix)) } else { From 6c0f65cfe0823545ce0eee843dc818e7f80130d1 Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Fri, 16 Jun 2023 10:36:42 -0600 Subject: [PATCH 4/5] Avoid inserting redundant newlines Co-Authored-By: Piotr Osiewicz Co-Authored-By: Antonio Scandurra --- crates/ai/src/assistant.rs | 173 +++++++++++++++++++++++++++++++++---- 1 file changed, 154 insertions(+), 19 deletions(-) diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index d6b8f52dd191e4d604d6584b4a883f77c706f3a8..0e361a6fa30c0b4ed2ff121d90390c45250ad523 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -736,18 +736,33 @@ impl Assistant { } let role = metadata.role; - let is_newline = self.buffer.update(cx, |buffer, cx| { - if buffer.chars_at(range.end).next() != Some('\n') { + let mut edited_buffer = false; + + let mut suffix_start = None; + if range.start > message_range.start && range.end < message_range.end - 1 { + if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') { + suffix_start = Some(range.end + 1); + } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') { + suffix_start = Some(range.end); + } + } + + let suffix = if let Some(suffix_start) = suffix_start { + Message { + id: MessageId(post_inc(&mut self.next_message_id.0)), + start: self.buffer.read(cx).anchor_before(suffix_start), + } + } else { + self.buffer.update(cx, |buffer, cx| { buffer.edit([(range.end..range.end, "\n")], None, cx); - false - } else { - true + }); + edited_buffer = true; + Message { + id: MessageId(post_inc(&mut self.next_message_id.0)), + start: self.buffer.read(cx).anchor_before(range.end + 1), } - }); - let suffix = Message { - id: MessageId(post_inc(&mut self.next_message_id.0)), - start: self.buffer.read(cx).anchor_before(range.end + 1), }; + self.messages.insert(start_message_ix + 1, suffix.clone()); self.messages_metadata.insert( suffix.id, @@ -757,19 +772,38 @@ impl Assistant { error: None, }, ); - if is_newline { - cx.emit(AssistantEvent::MessagesEdited); - } - if range.start == range.end || range.start == message_range.start { + + let new_messages = if range.start == range.end || range.start == message_range.start { (None, Some(suffix)) } else { - self.buffer.update(cx, |buffer, cx| { - buffer.edit([(range.start..range.start, "\n")], None, cx) - }); - let selection = Message { - id: MessageId(post_inc(&mut self.next_message_id.0)), - start: self.buffer.read(cx).anchor_before(range.start + 1), + let mut prefix_end = None; + if range.start > message_range.start && range.end < message_range.end - 1 { + if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') { + prefix_end = Some(range.start + 1); + } else if self.buffer.read(cx).reversed_chars_at(range.start).next() + == Some('\n') + { + prefix_end = Some(range.start); + } + } + + let selection = if let Some(prefix_end) = prefix_end { + cx.emit(AssistantEvent::MessagesEdited); + Message { + id: MessageId(post_inc(&mut self.next_message_id.0)), + start: self.buffer.read(cx).anchor_before(prefix_end), + } + } else { + self.buffer.update(cx, |buffer, cx| { + buffer.edit([(range.start..range.start, "\n")], None, cx) + }); + edited_buffer = true; + Message { + id: MessageId(post_inc(&mut self.next_message_id.0)), + start: self.buffer.read(cx).anchor_before(range.end + 1), + } }; + self.messages .insert(start_message_ix + 1, selection.clone()); self.messages_metadata.insert( @@ -781,7 +815,12 @@ impl Assistant { }, ); (Some(selection), Some(suffix)) + }; + + if !edited_buffer { + cx.emit(AssistantEvent::MessagesEdited); } + new_messages } else { (None, None) } @@ -1594,8 +1633,11 @@ mod tests { (message_3.id, Role::User, 5..6) ] ); + let (message_6, message_7) = assistant.update(cx, |assistant, cx| assistant.split_message(2..3, cx)); + + assert_eq!(buffer.read(cx).text(), "1C\n3\n\n\nD"); // We insert a newline for the new empty message let (message_6, message_7) = (message_6.unwrap(), message_7.unwrap()); assert_eq!( messages(&assistant, cx), @@ -1626,6 +1668,99 @@ mod tests { ); } + #[gpui::test] + fn test_message_splitting(cx: &mut AppContext) { + let registry = Arc::new(LanguageRegistry::test()); + let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx)); + let buffer = assistant.read(cx).buffer.clone(); + + let message_1 = assistant.read(cx).messages[0].clone(); + assert_eq!( + messages(&assistant, cx), + vec![(message_1.id, Role::User, 0..0)] + ); + + buffer.update(cx, |buffer, cx| { + buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx) + }); + + let (_, message_2) = + assistant.update(cx, |assistant, cx| assistant.split_message(3..3, cx)); + let message_2 = message_2.unwrap(); + + // We recycle newlines in the middle of a split message + assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n"); + assert_eq!( + messages(&assistant, cx), + vec![ + (message_1.id, Role::User, 0..4), + (message_2.id, Role::User, 4..16), + ] + ); + + let (_, message_3) = + assistant.update(cx, |assistant, cx| assistant.split_message(3..3, cx)); + let message_3 = message_3.unwrap(); + + // We don't recycle newlines at the end of a split message + assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n"); + assert_eq!( + messages(&assistant, cx), + vec![ + (message_1.id, Role::User, 0..4), + (message_3.id, Role::User, 4..5), + (message_2.id, Role::User, 5..17), + ] + ); + + let (_, message_4) = + assistant.update(cx, |assistant, cx| assistant.split_message(9..9, cx)); + let message_4 = message_4.unwrap(); + assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n"); + assert_eq!( + messages(&assistant, cx), + vec![ + (message_1.id, Role::User, 0..4), + (message_3.id, Role::User, 4..5), + (message_2.id, Role::User, 5..9), + (message_4.id, Role::User, 9..17), + ] + ); + + let (_, message_5) = + assistant.update(cx, |assistant, cx| assistant.split_message(9..9, cx)); + let message_5 = message_5.unwrap(); + assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n"); + assert_eq!( + messages(&assistant, cx), + vec![ + (message_1.id, Role::User, 0..4), + (message_3.id, Role::User, 4..5), + (message_2.id, Role::User, 5..9), + (message_4.id, Role::User, 9..10), + (message_5.id, Role::User, 10..18), + ] + ); + + let (message_6, message_7) = + assistant.update(cx, |assistant, cx| assistant.split_message(14..16, cx)); + let message_6 = message_6.unwrap(); + let message_7 = message_7.unwrap(); + assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n"); + assert_eq!( + messages(&assistant, cx), + vec![ + (message_1.id, Role::User, 0..4), + (message_3.id, Role::User, 4..5), + (message_2.id, Role::User, 5..9), + (message_4.id, Role::User, 9..10), + (message_5.id, Role::User, 10..14), + (message_6.id, Role::User, 14..17), + (message_7.id, Role::User, 17..19), + ] + ); + } + fn messages( assistant: &ModelHandle, cx: &AppContext, From c179dd999035ad6c23c258371bb7ee104c3ed3a0 Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Fri, 16 Jun 2023 11:43:16 -0600 Subject: [PATCH 5/5] Remove redundant tests --- crates/ai/src/assistant.rs | 44 -------------------------------------- 1 file changed, 44 deletions(-) diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index 0e361a6fa30c0b4ed2ff121d90390c45250ad523..eff3dc4d20562b9d005b1d2b95f8ee4f7934c694 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -1622,50 +1622,6 @@ mod tests { (message_3.id, Role::User, 4..5) ] ); - - // Split a message into prefix, selection and suffix. - buffer.update(cx, |buffer, cx| buffer.edit([(2..2, "3")], None, cx)); - assert_eq!( - messages(&assistant, cx), - vec![ - (message_1.id, Role::User, 0..4), - (message_5.id, Role::System, 4..5), - (message_3.id, Role::User, 5..6) - ] - ); - - let (message_6, message_7) = - assistant.update(cx, |assistant, cx| assistant.split_message(2..3, cx)); - - assert_eq!(buffer.read(cx).text(), "1C\n3\n\n\nD"); // We insert a newline for the new empty message - let (message_6, message_7) = (message_6.unwrap(), message_7.unwrap()); - assert_eq!( - messages(&assistant, cx), - vec![ - (message_1.id, Role::User, 0..3), - (message_6.id, Role::User, 3..5), - (message_7.id, Role::User, 5..6), - (message_5.id, Role::System, 6..7), - (message_3.id, Role::User, 7..8) - ] - ); - - // Don't include an empty prefix when splitting with a non-empty range - let (no_message, message_8) = - assistant.update(cx, |assistant, cx| assistant.split_message(3..4, cx)); - assert!(no_message.is_none()); - let message_8 = message_8.unwrap(); - assert_eq!( - messages(&assistant, cx), - vec![ - (message_1.id, Role::User, 0..3), - (message_6.id, Role::User, 3..5), - (message_8.id, Role::User, 5..6), - (message_7.id, Role::User, 6..7), - (message_5.id, Role::System, 7..8), - (message_3.id, Role::User, 8..9) - ] - ); } #[gpui::test]