diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index d6b8f52dd191e4d604d6584b4a883f77c706f3a8..0e361a6fa30c0b4ed2ff121d90390c45250ad523 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -736,18 +736,33 @@ impl Assistant { } let role = metadata.role; - let is_newline = self.buffer.update(cx, |buffer, cx| { - if buffer.chars_at(range.end).next() != Some('\n') { + let mut edited_buffer = false; + + let mut suffix_start = None; + if range.start > message_range.start && range.end < message_range.end - 1 { + if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') { + suffix_start = Some(range.end + 1); + } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') { + suffix_start = Some(range.end); + } + } + + let suffix = if let Some(suffix_start) = suffix_start { + Message { + id: MessageId(post_inc(&mut self.next_message_id.0)), + start: self.buffer.read(cx).anchor_before(suffix_start), + } + } else { + self.buffer.update(cx, |buffer, cx| { buffer.edit([(range.end..range.end, "\n")], None, cx); - false - } else { - true + }); + edited_buffer = true; + Message { + id: MessageId(post_inc(&mut self.next_message_id.0)), + start: self.buffer.read(cx).anchor_before(range.end + 1), } - }); - let suffix = Message { - id: MessageId(post_inc(&mut self.next_message_id.0)), - start: self.buffer.read(cx).anchor_before(range.end + 1), }; + self.messages.insert(start_message_ix + 1, suffix.clone()); self.messages_metadata.insert( suffix.id, @@ -757,19 +772,38 @@ impl Assistant { error: None, }, ); - if is_newline { - cx.emit(AssistantEvent::MessagesEdited); - } - if range.start == range.end || range.start == message_range.start { + + let new_messages = if range.start == range.end || range.start == message_range.start { (None, Some(suffix)) } else { - self.buffer.update(cx, |buffer, cx| { - buffer.edit([(range.start..range.start, "\n")], None, cx) - }); - let selection = Message { - id: MessageId(post_inc(&mut self.next_message_id.0)), - start: self.buffer.read(cx).anchor_before(range.start + 1), + let mut prefix_end = None; + if range.start > message_range.start && range.end < message_range.end - 1 { + if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') { + prefix_end = Some(range.start + 1); + } else if self.buffer.read(cx).reversed_chars_at(range.start).next() + == Some('\n') + { + prefix_end = Some(range.start); + } + } + + let selection = if let Some(prefix_end) = prefix_end { + cx.emit(AssistantEvent::MessagesEdited); + Message { + id: MessageId(post_inc(&mut self.next_message_id.0)), + start: self.buffer.read(cx).anchor_before(prefix_end), + } + } else { + self.buffer.update(cx, |buffer, cx| { + buffer.edit([(range.start..range.start, "\n")], None, cx) + }); + edited_buffer = true; + Message { + id: MessageId(post_inc(&mut self.next_message_id.0)), + start: self.buffer.read(cx).anchor_before(range.end + 1), + } }; + self.messages .insert(start_message_ix + 1, selection.clone()); self.messages_metadata.insert( @@ -781,7 +815,12 @@ impl Assistant { }, ); (Some(selection), Some(suffix)) + }; + + if !edited_buffer { + cx.emit(AssistantEvent::MessagesEdited); } + new_messages } else { (None, None) } @@ -1594,8 +1633,11 @@ mod tests { (message_3.id, Role::User, 5..6) ] ); + let (message_6, message_7) = assistant.update(cx, |assistant, cx| assistant.split_message(2..3, cx)); + + assert_eq!(buffer.read(cx).text(), "1C\n3\n\n\nD"); // We insert a newline for the new empty message let (message_6, message_7) = (message_6.unwrap(), message_7.unwrap()); assert_eq!( messages(&assistant, cx), @@ -1626,6 +1668,99 @@ mod tests { ); } + #[gpui::test] + fn test_message_splitting(cx: &mut AppContext) { + let registry = Arc::new(LanguageRegistry::test()); + let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx)); + let buffer = assistant.read(cx).buffer.clone(); + + let message_1 = assistant.read(cx).messages[0].clone(); + assert_eq!( + messages(&assistant, cx), + vec![(message_1.id, Role::User, 0..0)] + ); + + buffer.update(cx, |buffer, cx| { + buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx) + }); + + let (_, message_2) = + assistant.update(cx, |assistant, cx| assistant.split_message(3..3, cx)); + let message_2 = message_2.unwrap(); + + // We recycle newlines in the middle of a split message + assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n"); + assert_eq!( + messages(&assistant, cx), + vec![ + (message_1.id, Role::User, 0..4), + (message_2.id, Role::User, 4..16), + ] + ); + + let (_, message_3) = + assistant.update(cx, |assistant, cx| assistant.split_message(3..3, cx)); + let message_3 = message_3.unwrap(); + + // We don't recycle newlines at the end of a split message + assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n"); + assert_eq!( + messages(&assistant, cx), + vec![ + (message_1.id, Role::User, 0..4), + (message_3.id, Role::User, 4..5), + (message_2.id, Role::User, 5..17), + ] + ); + + let (_, message_4) = + assistant.update(cx, |assistant, cx| assistant.split_message(9..9, cx)); + let message_4 = message_4.unwrap(); + assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n"); + assert_eq!( + messages(&assistant, cx), + vec![ + (message_1.id, Role::User, 0..4), + (message_3.id, Role::User, 4..5), + (message_2.id, Role::User, 5..9), + (message_4.id, Role::User, 9..17), + ] + ); + + let (_, message_5) = + assistant.update(cx, |assistant, cx| assistant.split_message(9..9, cx)); + let message_5 = message_5.unwrap(); + assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n"); + assert_eq!( + messages(&assistant, cx), + vec![ + (message_1.id, Role::User, 0..4), + (message_3.id, Role::User, 4..5), + (message_2.id, Role::User, 5..9), + (message_4.id, Role::User, 9..10), + (message_5.id, Role::User, 10..18), + ] + ); + + let (message_6, message_7) = + assistant.update(cx, |assistant, cx| assistant.split_message(14..16, cx)); + let message_6 = message_6.unwrap(); + let message_7 = message_7.unwrap(); + assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n"); + assert_eq!( + messages(&assistant, cx), + vec![ + (message_1.id, Role::User, 0..4), + (message_3.id, Role::User, 4..5), + (message_2.id, Role::User, 5..9), + (message_4.id, Role::User, 9..10), + (message_5.id, Role::User, 10..14), + (message_6.id, Role::User, 14..17), + (message_7.id, Role::User, 17..19), + ] + ); + } + fn messages( assistant: &ModelHandle, cx: &AppContext,