diff --git a/assets/keymaps/default.json b/assets/keymaps/default.json index 45e85fd04ff616054ac2a7d259c453d3ac92d76a..f6682a9f0b7cc51f0e173e8e37152bd8b5b1a2cf 100644 --- a/assets/keymaps/default.json +++ b/assets/keymaps/default.json @@ -200,7 +200,8 @@ "context": "AssistantEditor > Editor", "bindings": { "cmd-enter": "assistant::Assist", - "cmd->": "assistant::QuoteSelection" + "cmd->": "assistant::QuoteSelection", + "shift-enter": "assistant::Split" } }, { diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index e5702cb677b62398c1bb75ae2054980d10c028f9..eff3dc4d20562b9d005b1d2b95f8ee4f7934c694 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -8,7 +8,7 @@ use collections::{HashMap, HashSet}; use editor::{ display_map::{BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint}, scroll::autoscroll::{Autoscroll, AutoscrollStrategy}, - Anchor, Editor, ToOffset as _, + Anchor, Editor, }; use fs::Fs; use futures::{io::BufReader, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; @@ -40,7 +40,14 @@ const OPENAI_API_URL: &'static str = "https://api.openai.com/v1"; actions!( assistant, - [NewContext, Assist, QuoteSelection, ToggleFocus, ResetKey] + [ + NewContext, + Assist, + Split, + QuoteSelection, + ToggleFocus, + ResetKey + ] ); pub fn init(cx: &mut AppContext) { @@ -64,6 +71,7 @@ pub fn init(cx: &mut AppContext) { cx.capture_action(AssistantEditor::cancel_last_assist); cx.add_action(AssistantEditor::quote_selection); cx.capture_action(AssistantEditor::copy); + cx.capture_action(AssistantEditor::split); cx.add_action(AssistantPanel::save_api_key); cx.add_action(AssistantPanel::reset_api_key); cx.add_action( @@ -711,6 +719,113 @@ impl Assistant { } } + fn split_message( + &mut self, + range: Range, + cx: &mut ModelContext, + ) -> (Option, Option) { + let start_message = self.message_for_offset(range.start, cx); + let end_message = self.message_for_offset(range.end, cx); + if let Some((start_message, end_message)) = start_message.zip(end_message) { + let (start_message_ix, _, metadata, message_range) = start_message; + let (end_message_ix, _, _, _) = end_message; + + // Prevent splitting when range spans multiple messages. + if start_message_ix != end_message_ix { + return (None, None); + } + + let role = metadata.role; + let mut edited_buffer = false; + + let mut suffix_start = None; + if range.start > message_range.start && range.end < message_range.end - 1 { + if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') { + suffix_start = Some(range.end + 1); + } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') { + suffix_start = Some(range.end); + } + } + + let suffix = if let Some(suffix_start) = suffix_start { + Message { + id: MessageId(post_inc(&mut self.next_message_id.0)), + start: self.buffer.read(cx).anchor_before(suffix_start), + } + } else { + self.buffer.update(cx, |buffer, cx| { + buffer.edit([(range.end..range.end, "\n")], None, cx); + }); + edited_buffer = true; + Message { + id: MessageId(post_inc(&mut self.next_message_id.0)), + start: self.buffer.read(cx).anchor_before(range.end + 1), + } + }; + + self.messages.insert(start_message_ix + 1, suffix.clone()); + self.messages_metadata.insert( + suffix.id, + MessageMetadata { + role, + sent_at: Local::now(), + error: None, + }, + ); + + let new_messages = if range.start == range.end || range.start == message_range.start { + (None, Some(suffix)) + } else { + let mut prefix_end = None; + if range.start > message_range.start && range.end < message_range.end - 1 { + if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') { + prefix_end = Some(range.start + 1); + } else if self.buffer.read(cx).reversed_chars_at(range.start).next() + == Some('\n') + { + prefix_end = Some(range.start); + } + } + + let selection = if let Some(prefix_end) = prefix_end { + cx.emit(AssistantEvent::MessagesEdited); + Message { + id: MessageId(post_inc(&mut self.next_message_id.0)), + start: self.buffer.read(cx).anchor_before(prefix_end), + } + } else { + self.buffer.update(cx, |buffer, cx| { + buffer.edit([(range.start..range.start, "\n")], None, cx) + }); + edited_buffer = true; + Message { + id: MessageId(post_inc(&mut self.next_message_id.0)), + start: self.buffer.read(cx).anchor_before(range.end + 1), + } + }; + + self.messages + .insert(start_message_ix + 1, selection.clone()); + self.messages_metadata.insert( + selection.id, + MessageMetadata { + role, + sent_at: Local::now(), + error: None, + }, + ); + (Some(selection), Some(suffix)) + }; + + if !edited_buffer { + cx.emit(AssistantEvent::MessagesEdited); + } + new_messages + } else { + (None, None) + } + } + fn summarize(&mut self, cx: &mut ModelContext) { if self.messages.len() >= 2 && self.summary.is_none() { let api_key = self.api_key.borrow().clone(); @@ -755,35 +870,39 @@ impl Assistant { fn open_ai_request_messages(&self, cx: &AppContext) -> Vec { let buffer = self.buffer.read(cx); self.messages(cx) - .map(|(_message, metadata, range)| RequestMessage { + .map(|(_ix, _message, metadata, range)| RequestMessage { role: metadata.role, content: buffer.text_for_range(range).collect(), }) .collect() } - fn message_id_for_offset(&self, offset: usize, cx: &AppContext) -> Option { - Some( - self.messages(cx) - .find(|(_, _, range)| range.contains(&offset)) - .map(|(message, _, _)| message) - .or(self.messages.last())? - .id, - ) + fn message_for_offset<'a>( + &'a self, + offset: usize, + cx: &'a AppContext, + ) -> Option<(usize, &Message, &MessageMetadata, Range)> { + let mut messages = self.messages(cx).peekable(); + while let Some((ix, message, metadata, range)) = messages.next() { + if range.contains(&offset) || messages.peek().is_none() { + return Some((ix, message, metadata, range)); + } + } + None } fn messages<'a>( &'a self, cx: &'a AppContext, - ) -> impl 'a + Iterator)> { + ) -> impl 'a + Iterator)> { let buffer = self.buffer.read(cx); - let mut messages = self.messages.iter().peekable(); + let mut messages = self.messages.iter().enumerate().peekable(); iter::from_fn(move || { - while let Some(message) = messages.next() { + while let Some((ix, message)) = messages.next() { let metadata = self.messages_metadata.get(&message.id)?; let message_start = message.start.to_offset(buffer); let mut message_end = None; - while let Some(next_message) = messages.peek() { + while let Some((_, next_message)) = messages.peek() { if next_message.start.is_valid(buffer) { message_end = Some(next_message.start); break; @@ -794,7 +913,7 @@ impl Assistant { let message_end = message_end .unwrap_or(language::Anchor::MAX) .to_offset(buffer); - return Some((message, metadata, message_start..message_end)); + return Some((ix, message, metadata, message_start..message_end)); } None }) @@ -857,21 +976,7 @@ impl AssistantEditor { fn assist(&mut self, _: &Assist, cx: &mut ViewContext) { let user_message = self.assistant.update(cx, |assistant, cx| { - let editor = self.editor.read(cx); - let newest_selection = editor - .selections - .newest_anchor() - .head() - .to_offset(&editor.buffer().read(cx).snapshot(cx)); - let message_id = assistant.message_id_for_offset(newest_selection, cx)?; - let metadata = assistant.messages_metadata.get(&message_id)?; - let user_message = if metadata.role == Role::User { - let (_, user_message) = assistant.assist(cx)?; - user_message - } else { - let user_message = assistant.insert_message_after(message_id, Role::User, cx)?; - user_message - }; + let (_, user_message) = assistant.assist(cx)?; Some(user_message) }); @@ -982,7 +1087,7 @@ impl AssistantEditor { .assistant .read(cx) .messages(cx) - .map(|(message, metadata, _)| BlockProperties { + .map(|(_, message, metadata, _)| BlockProperties { position: buffer.anchor_in_excerpt(excerpt_id, message.start), height: 2, style: BlockStyle::Sticky, @@ -1147,7 +1252,7 @@ impl AssistantEditor { let selection = editor.selections.newest::(cx); let mut copied_text = String::new(); let mut spanned_messages = 0; - for (_message, metadata, message_range) in assistant.messages(cx) { + for (_ix, _message, metadata, message_range) in assistant.messages(cx) { if message_range.start >= selection.range().end { break; } else if message_range.end >= selection.range().start { @@ -1174,6 +1279,13 @@ impl AssistantEditor { cx.propagate_action(); } + fn split(&mut self, _: &Split, cx: &mut ViewContext) { + self.assistant.update(cx, |assistant, cx| { + let range = self.editor.read(cx).selections.newest::(cx).range(); + assistant.split_message(range, cx); + }); + } + fn cycle_model(&mut self, cx: &mut ViewContext) { self.assistant.update(cx, |assistant, cx| { let new_model = match assistant.model.as_str() { @@ -1512,6 +1624,99 @@ mod tests { ); } + #[gpui::test] + fn test_message_splitting(cx: &mut AppContext) { + let registry = Arc::new(LanguageRegistry::test()); + let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx)); + let buffer = assistant.read(cx).buffer.clone(); + + let message_1 = assistant.read(cx).messages[0].clone(); + assert_eq!( + messages(&assistant, cx), + vec![(message_1.id, Role::User, 0..0)] + ); + + buffer.update(cx, |buffer, cx| { + buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx) + }); + + let (_, message_2) = + assistant.update(cx, |assistant, cx| assistant.split_message(3..3, cx)); + let message_2 = message_2.unwrap(); + + // We recycle newlines in the middle of a split message + assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n"); + assert_eq!( + messages(&assistant, cx), + vec![ + (message_1.id, Role::User, 0..4), + (message_2.id, Role::User, 4..16), + ] + ); + + let (_, message_3) = + assistant.update(cx, |assistant, cx| assistant.split_message(3..3, cx)); + let message_3 = message_3.unwrap(); + + // We don't recycle newlines at the end of a split message + assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n"); + assert_eq!( + messages(&assistant, cx), + vec![ + (message_1.id, Role::User, 0..4), + (message_3.id, Role::User, 4..5), + (message_2.id, Role::User, 5..17), + ] + ); + + let (_, message_4) = + assistant.update(cx, |assistant, cx| assistant.split_message(9..9, cx)); + let message_4 = message_4.unwrap(); + assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n"); + assert_eq!( + messages(&assistant, cx), + vec![ + (message_1.id, Role::User, 0..4), + (message_3.id, Role::User, 4..5), + (message_2.id, Role::User, 5..9), + (message_4.id, Role::User, 9..17), + ] + ); + + let (_, message_5) = + assistant.update(cx, |assistant, cx| assistant.split_message(9..9, cx)); + let message_5 = message_5.unwrap(); + assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n"); + assert_eq!( + messages(&assistant, cx), + vec![ + (message_1.id, Role::User, 0..4), + (message_3.id, Role::User, 4..5), + (message_2.id, Role::User, 5..9), + (message_4.id, Role::User, 9..10), + (message_5.id, Role::User, 10..18), + ] + ); + + let (message_6, message_7) = + assistant.update(cx, |assistant, cx| assistant.split_message(14..16, cx)); + let message_6 = message_6.unwrap(); + let message_7 = message_7.unwrap(); + assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n"); + assert_eq!( + messages(&assistant, cx), + vec![ + (message_1.id, Role::User, 0..4), + (message_3.id, Role::User, 4..5), + (message_2.id, Role::User, 5..9), + (message_4.id, Role::User, 9..10), + (message_5.id, Role::User, 10..14), + (message_6.id, Role::User, 14..17), + (message_7.id, Role::User, 17..19), + ] + ); + } + fn messages( assistant: &ModelHandle, cx: &AppContext, @@ -1519,7 +1724,7 @@ mod tests { assistant .read(cx) .messages(cx) - .map(|(message, metadata, range)| (message.id, metadata.role, range)) + .map(|(_, message, metadata, range)| (message.id, metadata.role, range)) .collect() } }