Insert new message right before the next valid one

Antonio Scandurra created

Change summary

crates/ai/src/assistant.rs | 141 ++++++++++++++++++++++-----------------
1 file changed, 80 insertions(+), 61 deletions(-)

Detailed changes

crates/ai/src/assistant.rs 🔗

@@ -870,7 +870,7 @@ impl Conversation {
                 .messages(cx)
                 .map(|message| SavedMessage {
                     id: message.id,
-                    start: message.range.start,
+                    start: message.offset_range.start,
                 })
                 .collect(),
             summary: self
@@ -968,7 +968,11 @@ impl Conversation {
                         Role::Assistant => "assistant".into(),
                         Role::System => "system".into(),
                     },
-                    content: self.buffer.read(cx).text_for_range(message.range).collect(),
+                    content: self
+                        .buffer
+                        .read(cx)
+                        .text_for_range(message.offset_range)
+                        .collect(),
                     name: None,
                 })
             })
@@ -1183,10 +1187,19 @@ impl Conversation {
             .iter()
             .position(|message| message.id == message_id)
         {
+            // Find the next valid message after the one we were given.
+            let mut next_message_ix = prev_message_ix + 1;
+            while let Some(next_message) = self.message_anchors.get(next_message_ix) {
+                if next_message.start.is_valid(self.buffer.read(cx)) {
+                    break;
+                }
+                next_message_ix += 1;
+            }
+
             let start = self.buffer.update(cx, |buffer, cx| {
-                let offset = self.message_anchors[prev_message_ix + 1..]
-                    .iter()
-                    .find(|message| message.start.is_valid(buffer))
+                let offset = self
+                    .message_anchors
+                    .get(next_message_ix)
                     .map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1);
                 buffer.edit([(offset..offset, "\n")], None, cx);
                 buffer.anchor_before(offset + 1)
@@ -1196,7 +1209,7 @@ impl Conversation {
                 start,
             };
             self.message_anchors
-                .insert(prev_message_ix + 1, message.clone());
+                .insert(next_message_ix, message.clone());
             self.messages_metadata.insert(
                 message.id,
                 MessageMetadata {
@@ -1221,7 +1234,7 @@ impl Conversation {
         let end_message = self.message_for_offset(range.end, cx);
         if let Some((start_message, end_message)) = start_message.zip(end_message) {
             // Prevent splitting when range spans multiple messages.
-            if start_message.index != end_message.index {
+            if start_message.id != end_message.id {
                 return (None, None);
             }
 
@@ -1230,7 +1243,8 @@ impl Conversation {
             let mut edited_buffer = false;
 
             let mut suffix_start = None;
-            if range.start > message.range.start && range.end < message.range.end - 1 {
+            if range.start > message.offset_range.start && range.end < message.offset_range.end - 1
+            {
                 if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') {
                     suffix_start = Some(range.end + 1);
                 } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') {
@@ -1255,7 +1269,7 @@ impl Conversation {
             };
 
             self.message_anchors
-                .insert(message.index + 1, suffix.clone());
+                .insert(message.index_range.end + 1, suffix.clone());
             self.messages_metadata.insert(
                 suffix.id,
                 MessageMetadata {
@@ -1265,49 +1279,52 @@ impl Conversation {
                 },
             );
 
-            let new_messages = if range.start == range.end || range.start == message.range.start {
-                (None, Some(suffix))
-            } else {
-                let mut prefix_end = None;
-                if range.start > message.range.start && range.end < message.range.end - 1 {
-                    if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') {
-                        prefix_end = Some(range.start + 1);
-                    } else if self.buffer.read(cx).reversed_chars_at(range.start).next()
-                        == Some('\n')
+            let new_messages =
+                if range.start == range.end || range.start == message.offset_range.start {
+                    (None, Some(suffix))
+                } else {
+                    let mut prefix_end = None;
+                    if range.start > message.offset_range.start
+                        && range.end < message.offset_range.end - 1
                     {
-                        prefix_end = Some(range.start);
+                        if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') {
+                            prefix_end = Some(range.start + 1);
+                        } else if self.buffer.read(cx).reversed_chars_at(range.start).next()
+                            == Some('\n')
+                        {
+                            prefix_end = Some(range.start);
+                        }
                     }
-                }
 
-                let selection = if let Some(prefix_end) = prefix_end {
-                    cx.emit(ConversationEvent::MessagesEdited);
-                    MessageAnchor {
-                        id: MessageId(post_inc(&mut self.next_message_id.0)),
-                        start: self.buffer.read(cx).anchor_before(prefix_end),
-                    }
-                } else {
-                    self.buffer.update(cx, |buffer, cx| {
-                        buffer.edit([(range.start..range.start, "\n")], None, cx)
-                    });
-                    edited_buffer = true;
-                    MessageAnchor {
-                        id: MessageId(post_inc(&mut self.next_message_id.0)),
-                        start: self.buffer.read(cx).anchor_before(range.end + 1),
-                    }
-                };
+                    let selection = if let Some(prefix_end) = prefix_end {
+                        cx.emit(ConversationEvent::MessagesEdited);
+                        MessageAnchor {
+                            id: MessageId(post_inc(&mut self.next_message_id.0)),
+                            start: self.buffer.read(cx).anchor_before(prefix_end),
+                        }
+                    } else {
+                        self.buffer.update(cx, |buffer, cx| {
+                            buffer.edit([(range.start..range.start, "\n")], None, cx)
+                        });
+                        edited_buffer = true;
+                        MessageAnchor {
+                            id: MessageId(post_inc(&mut self.next_message_id.0)),
+                            start: self.buffer.read(cx).anchor_before(range.end + 1),
+                        }
+                    };
 
-                self.message_anchors
-                    .insert(message.index + 1, selection.clone());
-                self.messages_metadata.insert(
-                    selection.id,
-                    MessageMetadata {
-                        role,
-                        sent_at: Local::now(),
-                        status: MessageStatus::Done,
-                    },
-                );
-                (Some(selection), Some(suffix))
-            };
+                    self.message_anchors
+                        .insert(message.index_range.end + 1, selection.clone());
+                    self.messages_metadata.insert(
+                        selection.id,
+                        MessageMetadata {
+                            role,
+                            sent_at: Local::now(),
+                            status: MessageStatus::Done,
+                        },
+                    );
+                    (Some(selection), Some(suffix))
+                };
 
             if !edited_buffer {
                 cx.emit(ConversationEvent::MessagesEdited);
@@ -1389,7 +1406,7 @@ impl Conversation {
         while let Some(offset) = offsets.next() {
             // Locate the message that contains the offset.
             while current_message.as_ref().map_or(false, |message| {
-                !message.range.contains(&offset) && messages.peek().is_some()
+                !message.offset_range.contains(&offset) && messages.peek().is_some()
             }) {
                 current_message = messages.next();
             }
@@ -1397,7 +1414,7 @@ impl Conversation {
 
             // Skip offsets that are in the same message.
             while offsets.peek().map_or(false, |offset| {
-                message.range.contains(offset) || messages.peek().is_none()
+                message.offset_range.contains(offset) || messages.peek().is_none()
             }) {
                 offsets.next();
             }
@@ -1411,15 +1428,17 @@ impl Conversation {
         let buffer = self.buffer.read(cx);
         let mut message_anchors = self.message_anchors.iter().enumerate().peekable();
         iter::from_fn(move || {
-            while let Some((ix, message_anchor)) = message_anchors.next() {
+            while let Some((start_ix, message_anchor)) = message_anchors.next() {
                 let metadata = self.messages_metadata.get(&message_anchor.id)?;
                 let message_start = message_anchor.start.to_offset(buffer);
                 let mut message_end = None;
+                let mut end_ix = start_ix;
                 while let Some((_, next_message)) = message_anchors.peek() {
                     if next_message.start.is_valid(buffer) {
                         message_end = Some(next_message.start);
                         break;
                     } else {
+                        end_ix += 1;
                         message_anchors.next();
                     }
                 }
@@ -1427,8 +1446,8 @@ impl Conversation {
                     .unwrap_or(language::Anchor::MAX)
                     .to_offset(buffer);
                 return Some(Message {
-                    index: ix,
-                    range: message_start..message_end,
+                    index_range: start_ix..end_ix,
+                    offset_range: message_start..message_end,
                     id: message_anchor.id,
                     anchor: message_anchor.start,
                     role: metadata.role,
@@ -1885,11 +1904,11 @@ impl ConversationEditor {
             let mut copied_text = String::new();
             let mut spanned_messages = 0;
             for message in conversation.messages(cx) {
-                if message.range.start >= selection.range().end {
+                if message.offset_range.start >= selection.range().end {
                     break;
-                } else if message.range.end >= selection.range().start {
-                    let range = cmp::max(message.range.start, selection.range().start)
-                        ..cmp::min(message.range.end, selection.range().end);
+                } else if message.offset_range.end >= selection.range().start {
+                    let range = cmp::max(message.offset_range.start, selection.range().start)
+                        ..cmp::min(message.offset_range.end, selection.range().end);
                     if !range.is_empty() {
                         spanned_messages += 1;
                         write!(&mut copied_text, "## {}\n\n", message.role).unwrap();
@@ -2005,8 +2024,8 @@ struct MessageAnchor {
 
 #[derive(Clone, Debug)]
 pub struct Message {
-    range: Range<usize>,
-    index: usize,
+    offset_range: Range<usize>,
+    index_range: Range<usize>,
     id: MessageId,
     anchor: language::Anchor,
     role: Role,
@@ -2017,7 +2036,7 @@ pub struct Message {
 impl Message {
     fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage {
         let mut content = format!("[Message {}]\n", self.id.0).to_string();
-        content.extend(buffer.text_for_range(self.range.clone()));
+        content.extend(buffer.text_for_range(self.offset_range.clone()));
         RequestMessage {
             role: self.role,
             content: content.trim_end().into(),
@@ -2525,7 +2544,7 @@ mod tests {
         conversation
             .read(cx)
             .messages(cx)
-            .map(|message| (message.id, message.role, message.range))
+            .map(|message| (message.id, message.role, message.offset_range))
             .collect()
     }
 }