Avoid inserting redundant newlines

Nathan Sobo , Piotr Osiewicz , and Antonio Scandurra created

Co-Authored-By: Piotr Osiewicz <piotr@zed.dev>
Co-Authored-By: Antonio Scandurra <antonio@zed.dev>

Change summary

crates/ai/src/assistant.rs | 173 +++++++++++++++++++++++++++++++++++----
1 file changed, 154 insertions(+), 19 deletions(-)

Detailed changes

crates/ai/src/assistant.rs 🔗

@@ -736,18 +736,33 @@ impl Assistant {
             }
 
             let role = metadata.role;
-            let is_newline = self.buffer.update(cx, |buffer, cx| {
-                if buffer.chars_at(range.end).next() != Some('\n') {
+            let mut edited_buffer = false;
+
+            let mut suffix_start = None;
+            if range.start > message_range.start && range.end < message_range.end - 1 {
+                if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') {
+                    suffix_start = Some(range.end + 1);
+                } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') {
+                    suffix_start = Some(range.end);
+                }
+            }
+
+            let suffix = if let Some(suffix_start) = suffix_start {
+                Message {
+                    id: MessageId(post_inc(&mut self.next_message_id.0)),
+                    start: self.buffer.read(cx).anchor_before(suffix_start),
+                }
+            } else {
+                self.buffer.update(cx, |buffer, cx| {
                     buffer.edit([(range.end..range.end, "\n")], None, cx);
-                    false
-                } else {
-                    true
+                });
+                edited_buffer = true;
+                Message {
+                    id: MessageId(post_inc(&mut self.next_message_id.0)),
+                    start: self.buffer.read(cx).anchor_before(range.end + 1),
                 }
-            });
-            let suffix = Message {
-                id: MessageId(post_inc(&mut self.next_message_id.0)),
-                start: self.buffer.read(cx).anchor_before(range.end + 1),
             };
+
             self.messages.insert(start_message_ix + 1, suffix.clone());
             self.messages_metadata.insert(
                 suffix.id,
@@ -757,19 +772,38 @@ impl Assistant {
                     error: None,
                 },
             );
-            if is_newline {
-                cx.emit(AssistantEvent::MessagesEdited);
-            }
-            if range.start == range.end || range.start == message_range.start {
+
+            let new_messages = if range.start == range.end || range.start == message_range.start {
                 (None, Some(suffix))
             } else {
-                self.buffer.update(cx, |buffer, cx| {
-                    buffer.edit([(range.start..range.start, "\n")], None, cx)
-                });
-                let selection = Message {
-                    id: MessageId(post_inc(&mut self.next_message_id.0)),
-                    start: self.buffer.read(cx).anchor_before(range.start + 1),
+                let mut prefix_end = None;
+                if range.start > message_range.start && range.end < message_range.end - 1 {
+                    if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') {
+                        prefix_end = Some(range.start + 1);
+                    } else if self.buffer.read(cx).reversed_chars_at(range.start).next()
+                        == Some('\n')
+                    {
+                        prefix_end = Some(range.start);
+                    }
+                }
+
+                let selection = if let Some(prefix_end) = prefix_end {
+                    cx.emit(AssistantEvent::MessagesEdited);
+                    Message {
+                        id: MessageId(post_inc(&mut self.next_message_id.0)),
+                        start: self.buffer.read(cx).anchor_before(prefix_end),
+                    }
+                } else {
+                    self.buffer.update(cx, |buffer, cx| {
+                        buffer.edit([(range.start..range.start, "\n")], None, cx)
+                    });
+                    edited_buffer = true;
+                    Message {
+                        id: MessageId(post_inc(&mut self.next_message_id.0)),
+                        start: self.buffer.read(cx).anchor_before(range.end + 1),
+                    }
                 };
+
                 self.messages
                     .insert(start_message_ix + 1, selection.clone());
                 self.messages_metadata.insert(
@@ -781,7 +815,12 @@ impl Assistant {
                     },
                 );
                 (Some(selection), Some(suffix))
+            };
+
+            if !edited_buffer {
+                cx.emit(AssistantEvent::MessagesEdited);
             }
+            new_messages
         } else {
             (None, None)
         }
@@ -1594,8 +1633,11 @@ mod tests {
                 (message_3.id, Role::User, 5..6)
             ]
         );
+
         let (message_6, message_7) =
             assistant.update(cx, |assistant, cx| assistant.split_message(2..3, cx));
+
+        assert_eq!(buffer.read(cx).text(), "1C\n3\n\n\nD"); // We insert a newline for the new empty message
         let (message_6, message_7) = (message_6.unwrap(), message_7.unwrap());
         assert_eq!(
             messages(&assistant, cx),
@@ -1626,6 +1668,99 @@ mod tests {
         );
     }
 
+    #[gpui::test]
+    fn test_message_splitting(cx: &mut AppContext) {
+        let registry = Arc::new(LanguageRegistry::test());
+        let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx));
+        let buffer = assistant.read(cx).buffer.clone();
+
+        let message_1 = assistant.read(cx).messages[0].clone();
+        assert_eq!(
+            messages(&assistant, cx),
+            vec![(message_1.id, Role::User, 0..0)]
+        );
+
+        buffer.update(cx, |buffer, cx| {
+            buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
+        });
+
+        let (_, message_2) =
+            assistant.update(cx, |assistant, cx| assistant.split_message(3..3, cx));
+        let message_2 = message_2.unwrap();
+
+        // We recycle newlines in the middle of a split message
+        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
+        assert_eq!(
+            messages(&assistant, cx),
+            vec![
+                (message_1.id, Role::User, 0..4),
+                (message_2.id, Role::User, 4..16),
+            ]
+        );
+
+        let (_, message_3) =
+            assistant.update(cx, |assistant, cx| assistant.split_message(3..3, cx));
+        let message_3 = message_3.unwrap();
+
+        // We don't recycle newlines at the end of a split message
+        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
+        assert_eq!(
+            messages(&assistant, cx),
+            vec![
+                (message_1.id, Role::User, 0..4),
+                (message_3.id, Role::User, 4..5),
+                (message_2.id, Role::User, 5..17),
+            ]
+        );
+
+        let (_, message_4) =
+            assistant.update(cx, |assistant, cx| assistant.split_message(9..9, cx));
+        let message_4 = message_4.unwrap();
+        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
+        assert_eq!(
+            messages(&assistant, cx),
+            vec![
+                (message_1.id, Role::User, 0..4),
+                (message_3.id, Role::User, 4..5),
+                (message_2.id, Role::User, 5..9),
+                (message_4.id, Role::User, 9..17),
+            ]
+        );
+
+        let (_, message_5) =
+            assistant.update(cx, |assistant, cx| assistant.split_message(9..9, cx));
+        let message_5 = message_5.unwrap();
+        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
+        assert_eq!(
+            messages(&assistant, cx),
+            vec![
+                (message_1.id, Role::User, 0..4),
+                (message_3.id, Role::User, 4..5),
+                (message_2.id, Role::User, 5..9),
+                (message_4.id, Role::User, 9..10),
+                (message_5.id, Role::User, 10..18),
+            ]
+        );
+
+        let (message_6, message_7) =
+            assistant.update(cx, |assistant, cx| assistant.split_message(14..16, cx));
+        let message_6 = message_6.unwrap();
+        let message_7 = message_7.unwrap();
+        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
+        assert_eq!(
+            messages(&assistant, cx),
+            vec![
+                (message_1.id, Role::User, 0..4),
+                (message_3.id, Role::User, 4..5),
+                (message_2.id, Role::User, 5..9),
+                (message_4.id, Role::User, 9..10),
+                (message_5.id, Role::User, 10..14),
+                (message_6.id, Role::User, 14..17),
+                (message_7.id, Role::User, 17..19),
+            ]
+        );
+    }
+
     fn messages(
         assistant: &ModelHandle<Assistant>,
         cx: &AppContext,