Test serialization roundtrip

Antonio Scandurra created

Change summary

crates/ai/src/assistant.rs | 238 ++++++++++++++++++++++++---------------
1 file changed, 146 insertions(+), 92 deletions(-)

Detailed changes

crates/ai/src/assistant.rs 🔗

@@ -492,15 +492,14 @@ impl AssistantPanel {
         }
 
         let fs = self.fs.clone();
-        let conversation = Conversation::load(
-            path.clone(),
-            self.api_key.clone(),
-            self.languages.clone(),
-            self.fs.clone(),
-            cx,
-        );
+        let api_key = self.api_key.clone();
+        let languages = self.languages.clone();
         cx.spawn(|this, mut cx| async move {
-            let conversation = conversation.await?;
+            let saved_conversation = fs.load(&path).await?;
+            let saved_conversation = serde_json::from_str(&saved_conversation)?;
+            let conversation = cx.add_model(|cx| {
+                Conversation::deserialize(saved_conversation, path.clone(), api_key, languages, cx)
+            });
             this.update(&mut cx, |this, cx| {
                 // If, by the time we've loaded the conversation, the user has already opened
                 // the same conversation, we don't want to open it again.
@@ -508,7 +507,7 @@ impl AssistantPanel {
                     this.set_active_editor_index(Some(ix), cx);
                 } else {
                     let editor = cx
-                        .add_view(|cx| ConversationEditor::from_conversation(conversation, fs, cx));
+                        .add_view(|cx| ConversationEditor::for_conversation(conversation, fs, cx));
                     this.add_conversation(editor, cx);
                 }
             })?;
@@ -861,72 +860,86 @@ impl Conversation {
         this
     }
 
-    fn load(
+    fn serialize(&self, cx: &AppContext) -> SavedConversation {
+        SavedConversation {
+            zed: "conversation".into(),
+            version: SavedConversation::VERSION.into(),
+            text: self.buffer.read(cx).text(),
+            message_metadata: self.messages_metadata.clone(),
+            messages: self
+                .messages(cx)
+                .map(|message| SavedMessage {
+                    id: message.id,
+                    start: message.range.start,
+                })
+                .collect(),
+            summary: self
+                .summary
+                .as_ref()
+                .map(|summary| summary.text.clone())
+                .unwrap_or_default(),
+            model: self.model.clone(),
+        }
+    }
+
+    fn deserialize(
+        saved_conversation: SavedConversation,
         path: PathBuf,
         api_key: Rc<RefCell<Option<String>>>,
         language_registry: Arc<LanguageRegistry>,
-        fs: Arc<dyn Fs>,
-        cx: &mut AppContext,
-    ) -> Task<Result<ModelHandle<Self>>> {
-        cx.spawn(|mut cx| async move {
-            let saved_conversation = fs.load(&path).await?;
-            let saved_conversation: SavedConversation = serde_json::from_str(&saved_conversation)?;
-
-            let model = saved_conversation.model;
-            let markdown = language_registry.language_for_name("Markdown");
-            let mut message_anchors = Vec::new();
-            let mut next_message_id = MessageId(0);
-            let buffer = cx.add_model(|cx| {
-                let mut buffer = Buffer::new(0, saved_conversation.text, cx);
-                for message in saved_conversation.messages {
-                    message_anchors.push(MessageAnchor {
-                        id: message.id,
-                        start: buffer.anchor_before(message.start),
-                    });
-                    next_message_id = cmp::max(next_message_id, MessageId(message.id.0 + 1));
-                }
-                buffer.set_language_registry(language_registry);
-                cx.spawn_weak(|buffer, mut cx| async move {
-                    let markdown = markdown.await?;
-                    let buffer = buffer
-                        .upgrade(&cx)
-                        .ok_or_else(|| anyhow!("buffer was dropped"))?;
-                    buffer.update(&mut cx, |buffer, cx| {
-                        buffer.set_language(Some(markdown), cx)
-                    });
-                    anyhow::Ok(())
-                })
-                .detach_and_log_err(cx);
-                buffer
-            });
-            let conversation = cx.add_model(|cx| {
-                let mut this = Self {
-                    message_anchors,
-                    messages_metadata: saved_conversation.message_metadata,
-                    next_message_id,
-                    summary: Some(Summary {
-                        text: saved_conversation.summary,
-                        done: true,
-                    }),
-                    pending_summary: Task::ready(None),
-                    completion_count: Default::default(),
-                    pending_completions: Default::default(),
-                    token_count: None,
-                    max_token_count: tiktoken_rs::model::get_context_size(&model),
-                    pending_token_count: Task::ready(None),
-                    model,
-                    _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
-                    pending_save: Task::ready(Ok(())),
-                    path: Some(path),
-                    api_key,
-                    buffer,
-                };
+        cx: &mut ModelContext<Self>,
+    ) -> Self {
+        let model = saved_conversation.model;
+        let markdown = language_registry.language_for_name("Markdown");
+        let mut message_anchors = Vec::new();
+        let mut next_message_id = MessageId(0);
+        let buffer = cx.add_model(|cx| {
+            let mut buffer = Buffer::new(0, saved_conversation.text, cx);
+            for message in saved_conversation.messages {
+                message_anchors.push(MessageAnchor {
+                    id: message.id,
+                    start: buffer.anchor_before(message.start),
+                });
+                next_message_id = cmp::max(next_message_id, MessageId(message.id.0 + 1));
+            }
+            buffer.set_language_registry(language_registry);
+            cx.spawn_weak(|buffer, mut cx| async move {
+                let markdown = markdown.await?;
+                let buffer = buffer
+                    .upgrade(&cx)
+                    .ok_or_else(|| anyhow!("buffer was dropped"))?;
+                buffer.update(&mut cx, |buffer, cx| {
+                    buffer.set_language(Some(markdown), cx)
+                });
+                anyhow::Ok(())
+            })
+            .detach_and_log_err(cx);
+            buffer
+        });
 
-                this.count_remaining_tokens(cx);
-                this
-            });
-            Ok(conversation)
-        })
+        let mut this = Self {
+            message_anchors,
+            messages_metadata: saved_conversation.message_metadata,
+            next_message_id,
+            summary: Some(Summary {
+                text: saved_conversation.summary,
+                done: true,
+            }),
+            pending_summary: Task::ready(None),
+            completion_count: Default::default(),
+            pending_completions: Default::default(),
+            token_count: None,
+            max_token_count: tiktoken_rs::model::get_context_size(&model),
+            pending_token_count: Task::ready(None),
+            model,
+            _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
+            pending_save: Task::ready(Ok(())),
+            path: Some(path),
+            api_key,
+            buffer,
+        };
+        this.count_remaining_tokens(cx);
+        this
     }
 
     fn handle_buffer_event(
@@ -1453,23 +1466,7 @@ impl Conversation {
             });
 
             if let Some(summary) = summary {
-                let conversation = this.read_with(&cx, |this, cx| SavedConversation {
-                    zed: "conversation".into(),
-                    version: SavedConversation::VERSION.into(),
-                    text: this.buffer.read(cx).text(),
-                    message_metadata: this.messages_metadata.clone(),
-                    messages: this
-                        .message_anchors
-                        .iter()
-                        .map(|message| SavedMessage {
-                            id: message.id,
-                            start: message.start.to_offset(this.buffer.read(cx)),
-                        })
-                        .collect(),
-                    summary: summary.clone(),
-                    model: this.model.clone(),
-                });
-
+                let conversation = this.read_with(&cx, |this, cx| this.serialize(cx));
                 let path = if let Some(old_path) = old_path {
                     old_path
                 } else {
@@ -1533,10 +1530,10 @@ impl ConversationEditor {
         cx: &mut ViewContext<Self>,
     ) -> Self {
         let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx));
-        Self::from_conversation(conversation, fs, cx)
+        Self::for_conversation(conversation, fs, cx)
     }
 
-    fn from_conversation(
+    fn for_conversation(
         conversation: ModelHandle<Conversation>,
         fs: Arc<dyn Fs>,
         cx: &mut ViewContext<Self>,
@@ -2116,9 +2113,8 @@ async fn stream_completion(
 
 #[cfg(test)]
 mod tests {
-    use crate::MessageId;
-
     use super::*;
+    use crate::MessageId;
     use fs::FakeFs;
     use gpui::{AppContext, TestAppContext};
     use project::Project;
@@ -2464,6 +2460,64 @@ mod tests {
         }
     }
 
+    #[gpui::test]
+    fn test_serialization(cx: &mut AppContext) {
+        let registry = Arc::new(LanguageRegistry::test());
+        let conversation =
+            cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx));
+        let buffer = conversation.read(cx).buffer.clone();
+        let message_0 = conversation.read(cx).message_anchors[0].id;
+        let message_1 = conversation.update(cx, |conversation, cx| {
+            conversation
+                .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
+                .unwrap()
+        });
+        let message_2 = conversation.update(cx, |conversation, cx| {
+            conversation
+                .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
+                .unwrap()
+        });
+        buffer.update(cx, |buffer, cx| {
+            buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
+            buffer.finalize_last_transaction();
+        });
+        let _message_3 = conversation.update(cx, |conversation, cx| {
+            conversation
+                .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
+                .unwrap()
+        });
+        buffer.update(cx, |buffer, cx| buffer.undo(cx));
+        assert_eq!(buffer.read(cx).text(), "a\nb\nc\n");
+        assert_eq!(
+            messages(&conversation, cx),
+            [
+                (message_0, Role::User, 0..2),
+                (message_1.id, Role::Assistant, 2..6),
+                (message_2.id, Role::System, 6..6),
+            ]
+        );
+
+        let deserialized_conversation = cx.add_model(|cx| {
+            Conversation::deserialize(
+                conversation.read(cx).serialize(cx),
+                Default::default(),
+                Default::default(),
+                registry.clone(),
+                cx,
+            )
+        });
+        let deserialized_buffer = deserialized_conversation.read(cx).buffer.clone();
+        assert_eq!(deserialized_buffer.read(cx).text(), "a\nb\nc\n");
+        assert_eq!(
+            messages(&deserialized_conversation, cx),
+            [
+                (message_0, Role::User, 0..2),
+                (message_1.id, Role::Assistant, 2..6),
+                (message_2.id, Role::System, 6..6),
+            ]
+        );
+    }
+
     fn messages(
         conversation: &ModelHandle<Conversation>,
         cx: &AppContext,