Reduce segment cloning when rendering messages (#33340)

Richard Feldman created

While working on retries, I discovered some opportunities to reduce
cloning of message segments. These segments have full `String`s (not
`SharedString`s), so cloning them means copying cloning all the bytes of
all the strings in the message, which would be nice to avoid!

Release Notes:

- N/A

Change summary

crates/agent/src/thread.rs           |   7 +
crates/agent_ui/src/active_thread.rs | 177 +++++++++++++++--------------
2 files changed, 98 insertions(+), 86 deletions(-)

Detailed changes

crates/agent/src/thread.rs 🔗

@@ -198,6 +198,13 @@ impl MessageSegment {
             Self::RedactedThinking(_) => false,
         }
     }
+
+    pub fn text(&self) -> Option<&str> {
+        match self {
+            MessageSegment::Text(text) => Some(text),
+            _ => None,
+        }
+    }
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]

crates/agent_ui/src/active_thread.rs 🔗

@@ -809,7 +809,12 @@ impl ActiveThread {
         };
 
         for message in thread.read(cx).messages().cloned().collect::<Vec<_>>() {
-            this.push_message(&message.id, &message.segments, window, cx);
+            let rendered_message = RenderedMessage::from_segments(
+                &message.segments,
+                this.language_registry.clone(),
+                cx,
+            );
+            this.push_rendered_message(message.id, rendered_message);
 
             for tool_use in thread.read(cx).tool_uses_for_message(message.id, cx) {
                 this.render_tool_use_markdown(
@@ -875,36 +880,11 @@ impl ActiveThread {
         &self.text_thread_store
     }
 
-    fn push_message(
-        &mut self,
-        id: &MessageId,
-        segments: &[MessageSegment],
-        _window: &mut Window,
-        cx: &mut Context<Self>,
-    ) {
+    fn push_rendered_message(&mut self, id: MessageId, rendered_message: RenderedMessage) {
         let old_len = self.messages.len();
-        self.messages.push(*id);
+        self.messages.push(id);
         self.list_state.splice(old_len..old_len, 1);
-
-        let rendered_message =
-            RenderedMessage::from_segments(segments, self.language_registry.clone(), cx);
-        self.rendered_messages_by_id.insert(*id, rendered_message);
-    }
-
-    fn edited_message(
-        &mut self,
-        id: &MessageId,
-        segments: &[MessageSegment],
-        _window: &mut Window,
-        cx: &mut Context<Self>,
-    ) {
-        let Some(index) = self.messages.iter().position(|message_id| message_id == id) else {
-            return;
-        };
-        self.list_state.splice(index..index + 1, 1);
-        let rendered_message =
-            RenderedMessage::from_segments(segments, self.language_registry.clone(), cx);
-        self.rendered_messages_by_id.insert(*id, rendered_message);
+        self.rendered_messages_by_id.insert(id, rendered_message);
     }
 
     fn deleted_message(&mut self, id: &MessageId) {
@@ -1037,31 +1017,43 @@ impl ActiveThread {
                 }
             }
             ThreadEvent::MessageAdded(message_id) => {
-                if let Some(message_segments) = self
-                    .thread
-                    .read(cx)
-                    .message(*message_id)
-                    .map(|message| message.segments.clone())
-                {
-                    self.push_message(message_id, &message_segments, window, cx);
+                if let Some(rendered_message) = self.thread.update(cx, |thread, cx| {
+                    thread.message(*message_id).map(|message| {
+                        RenderedMessage::from_segments(
+                            &message.segments,
+                            self.language_registry.clone(),
+                            cx,
+                        )
+                    })
+                }) {
+                    self.push_rendered_message(*message_id, rendered_message);
                 }
 
                 self.save_thread(cx);
                 cx.notify();
             }
             ThreadEvent::MessageEdited(message_id) => {
-                if let Some(message_segments) = self
-                    .thread
-                    .read(cx)
-                    .message(*message_id)
-                    .map(|message| message.segments.clone())
-                {
-                    self.edited_message(message_id, &message_segments, window, cx);
+                if let Some(index) = self.messages.iter().position(|id| id == message_id) {
+                    if let Some(rendered_message) = self.thread.update(cx, |thread, cx| {
+                        thread.message(*message_id).map(|message| {
+                            let mut rendered_message = RenderedMessage {
+                                language_registry: self.language_registry.clone(),
+                                segments: Vec::with_capacity(message.segments.len()),
+                            };
+                            for segment in &message.segments {
+                                rendered_message.push_segment(segment, cx);
+                            }
+                            rendered_message
+                        })
+                    }) {
+                        self.list_state.splice(index..index + 1, 1);
+                        self.rendered_messages_by_id
+                            .insert(*message_id, rendered_message);
+                        self.scroll_to_bottom(cx);
+                        self.save_thread(cx);
+                        cx.notify();
+                    }
                 }
-
-                self.scroll_to_bottom(cx);
-                self.save_thread(cx);
-                cx.notify();
             }
             ThreadEvent::MessageDeleted(message_id) => {
                 self.deleted_message(message_id);
@@ -1311,17 +1303,11 @@ impl ActiveThread {
     fn start_editing_message(
         &mut self,
         message_id: MessageId,
-        message_segments: &[MessageSegment],
+        message_text: impl Into<Arc<str>>,
         message_creases: &[MessageCrease],
         window: &mut Window,
         cx: &mut Context<Self>,
     ) {
-        // User message should always consist of a single text segment,
-        // therefore we can skip returning early if it's not a text segment.
-        let Some(MessageSegment::Text(message_text)) = message_segments.first() else {
-            return;
-        };
-
         let editor = crate::message_editor::create_editor(
             self.workspace.clone(),
             self.context_store.downgrade(),
@@ -1333,7 +1319,7 @@ impl ActiveThread {
             cx,
         );
         editor.update(cx, |editor, cx| {
-            editor.set_text(message_text.clone(), window, cx);
+            editor.set_text(message_text, window, cx);
             insert_message_creases(editor, message_creases, &self.context_store, window, cx);
             editor.focus_handle(cx).focus(window);
             editor.move_to_end(&editor::actions::MoveToEnd, window, cx);
@@ -1828,8 +1814,6 @@ impl ActiveThread {
             return div().children(loading_dots).into_any();
         }
 
-        let message_creases = message.creases.clone();
-
         let Some(rendered_message) = self.rendered_messages_by_id.get(&message_id) else {
             return Empty.into_any();
         };
@@ -2144,15 +2128,30 @@ impl ActiveThread {
                                 }),
                         )
                         .on_click(cx.listener({
-                            let message_segments = message.segments.clone();
+                            let message_creases = message.creases.clone();
                             move |this, _, window, cx| {
-                                this.start_editing_message(
-                                    message_id,
-                                    &message_segments,
-                                    &message_creases,
-                                    window,
-                                    cx,
-                                );
+                                if let Some(message_text) =
+                                    this.thread.read(cx).message(message_id).and_then(|message| {
+                                        message.segments.first().and_then(|segment| {
+                                            match segment {
+                                                MessageSegment::Text(message_text) => {
+                                                    Some(Into::<Arc<str>>::into(message_text.as_str()))
+                                                }
+                                                _ => {
+                                                    None
+                                                }
+                                            }
+                                        })
+                                    })
+                                {
+                                    this.start_editing_message(
+                                        message_id,
+                                        message_text,
+                                        &message_creases,
+                                        window,
+                                        cx,
+                                    );
+                                }
                             }
                         })),
                 ),
@@ -3826,13 +3825,15 @@ mod tests {
         });
 
         active_thread.update_in(cx, |active_thread, window, cx| {
-            active_thread.start_editing_message(
-                message.id,
-                message.segments.as_slice(),
-                message.creases.as_slice(),
-                window,
-                cx,
-            );
+            if let Some(message_text) = message.segments.first().and_then(MessageSegment::text) {
+                active_thread.start_editing_message(
+                    message.id,
+                    message_text,
+                    message.creases.as_slice(),
+                    window,
+                    cx,
+                );
+            }
             let editor = active_thread
                 .editing_message
                 .as_ref()
@@ -3847,13 +3848,15 @@ mod tests {
 
         let message = thread.update(cx, |thread, _| thread.message(message.id).cloned().unwrap());
         active_thread.update_in(cx, |active_thread, window, cx| {
-            active_thread.start_editing_message(
-                message.id,
-                message.segments.as_slice(),
-                message.creases.as_slice(),
-                window,
-                cx,
-            );
+            if let Some(message_text) = message.segments.first().and_then(MessageSegment::text) {
+                active_thread.start_editing_message(
+                    message.id,
+                    message_text,
+                    message.creases.as_slice(),
+                    window,
+                    cx,
+                );
+            }
             let editor = active_thread
                 .editing_message
                 .as_ref()
@@ -3935,13 +3938,15 @@ mod tests {
 
         // Edit the message while the completion is still running
         active_thread.update_in(cx, |active_thread, window, cx| {
-            active_thread.start_editing_message(
-                message.id,
-                message.segments.as_slice(),
-                message.creases.as_slice(),
-                window,
-                cx,
-            );
+            if let Some(message_text) = message.segments.first().and_then(MessageSegment::text) {
+                active_thread.start_editing_message(
+                    message.id,
+                    message_text,
+                    message.creases.as_slice(),
+                    window,
+                    cx,
+                );
+            }
             let editor = active_thread
                 .editing_message
                 .as_ref()