@@ -492,15 +492,14 @@ impl AssistantPanel {
}
let fs = self.fs.clone();
- let conversation = Conversation::load(
- path.clone(),
- self.api_key.clone(),
- self.languages.clone(),
- self.fs.clone(),
- cx,
- );
+ let api_key = self.api_key.clone();
+ let languages = self.languages.clone();
cx.spawn(|this, mut cx| async move {
- let conversation = conversation.await?;
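+ // Read the saved conversation off disk, parse the JSON, and rebuild the Conversation model.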
+ let saved_conversation = fs.load(&path).await?;
+ let saved_conversation = serde_json::from_str(&saved_conversation)?;
+ let conversation = cx.add_model(|cx| {
+ Conversation::deserialize(saved_conversation, path.clone(), api_key, languages, cx)
+ });
this.update(&mut cx, |this, cx| {
// If, by the time we've loaded the conversation, the user has already opened
// the same conversation, we don't want to open it again.
@@ -508,7 +507,7 @@ impl AssistantPanel {
this.set_active_editor_index(Some(ix), cx);
} else {
let editor = cx
- .add_view(|cx| ConversationEditor::from_conversation(conversation, fs, cx));
+ .add_view(|cx| ConversationEditor::for_conversation(conversation, fs, cx));
this.add_conversation(editor, cx);
}
})?;
@@ -861,72 +860,86 @@ impl Conversation {
this
}
- fn load(
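+ // Captures the buffer text, message offsets, metadata, summary, and model as a SavedConversation.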
+ fn serialize(&self, cx: &AppContext) -> SavedConversation {
+ SavedConversation {
+ zed: "conversation".into(),
+ version: SavedConversation::VERSION.into(),
+ text: self.buffer.read(cx).text(),
+ message_metadata: self.messages_metadata.clone(),
+ messages: self
+ .messages(cx)
+ .map(|message| SavedMessage {
+ id: message.id,
+ start: message.range.start,
+ })
+ .collect(),
+ summary: self
+ .summary
+ .as_ref()
+ .map(|summary| summary.text.clone())
+ .unwrap_or_default(),
+ model: self.model.clone(),
+ }
+ }
+
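+ // Rebuilds a Conversation from a previously serialized SavedConversation.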
+ fn deserialize(
+ saved_conversation: SavedConversation,
path: PathBuf,
api_key: Rc<RefCell<Option<String>>>,
language_registry: Arc<LanguageRegistry>,
- fs: Arc<dyn Fs>,
- cx: &mut AppContext,
- ) -> Task<Result<ModelHandle<Self>>> {
- cx.spawn(|mut cx| async move {
- let saved_conversation = fs.load(&path).await?;
- let saved_conversation: SavedConversation = serde_json::from_str(&saved_conversation)?;
-
- let model = saved_conversation.model;
- let markdown = language_registry.language_for_name("Markdown");
- let mut message_anchors = Vec::new();
- let mut next_message_id = MessageId(0);
- let buffer = cx.add_model(|cx| {
- let mut buffer = Buffer::new(0, saved_conversation.text, cx);
- for message in saved_conversation.messages {
- message_anchors.push(MessageAnchor {
- id: message.id,
- start: buffer.anchor_before(message.start),
- });
- next_message_id = cmp::max(next_message_id, MessageId(message.id.0 + 1));
- }
- buffer.set_language_registry(language_registry);
- cx.spawn_weak(|buffer, mut cx| async move {
- let markdown = markdown.await?;
- let buffer = buffer
- .upgrade(&cx)
- .ok_or_else(|| anyhow!("buffer was dropped"))?;
- buffer.update(&mut cx, |buffer, cx| {
- buffer.set_language(Some(markdown), cx)
- });
- anyhow::Ok(())
- })
- .detach_and_log_err(cx);
- buffer
- });
- let conversation = cx.add_model(|cx| {
- let mut this = Self {
- message_anchors,
- messages_metadata: saved_conversation.message_metadata,
- next_message_id,
- summary: Some(Summary {
- text: saved_conversation.summary,
- done: true,
- }),
- pending_summary: Task::ready(None),
- completion_count: Default::default(),
- pending_completions: Default::default(),
- token_count: None,
- max_token_count: tiktoken_rs::model::get_context_size(&model),
- pending_token_count: Task::ready(None),
- model,
- _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
- pending_save: Task::ready(Ok(())),
- path: Some(path),
- api_key,
- buffer,
- };
+ cx: &mut ModelContext<Self>,
+ ) -> Self {
+ let model = saved_conversation.model;
+ let markdown = language_registry.language_for_name("Markdown");
+ let mut message_anchors = Vec::new();
+ let mut next_message_id = MessageId(0);
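+ // Restore the buffer from the saved text and re-anchor each message at its saved offset.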
+ let buffer = cx.add_model(|cx| {
+ let mut buffer = Buffer::new(0, saved_conversation.text, cx);
+ for message in saved_conversation.messages {
+ message_anchors.push(MessageAnchor {
+ id: message.id,
+ start: buffer.anchor_before(message.start),
+ });
+ next_message_id = cmp::max(next_message_id, MessageId(message.id.0 + 1));
+ }
+ buffer.set_language_registry(language_registry);
+ cx.spawn_weak(|buffer, mut cx| async move {
+ let markdown = markdown.await?;
+ let buffer = buffer
+ .upgrade(&cx)
+ .ok_or_else(|| anyhow!("buffer was dropped"))?;
+ buffer.update(&mut cx, |buffer, cx| {
+ buffer.set_language(Some(markdown), cx)
+ });
+ anyhow::Ok(())
+ })
+ .detach_and_log_err(cx);
+ buffer
+ });
- this.count_remaining_tokens(cx);
- this
- });
- Ok(conversation)
- })
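+ // Token counts are not persisted, so recompute them once the state is restored.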
+ let mut this = Self {
+ message_anchors,
+ messages_metadata: saved_conversation.message_metadata,
+ next_message_id,
+ summary: Some(Summary {
+ text: saved_conversation.summary,
+ done: true,
+ }),
+ pending_summary: Task::ready(None),
+ completion_count: Default::default(),
+ pending_completions: Default::default(),
+ token_count: None,
+ max_token_count: tiktoken_rs::model::get_context_size(&model),
+ pending_token_count: Task::ready(None),
+ model,
+ _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
+ pending_save: Task::ready(Ok(())),
+ path: Some(path),
+ api_key,
+ buffer,
+ };
+ this.count_remaining_tokens(cx);
+ this
}
fn handle_buffer_event(
@@ -1453,23 +1466,7 @@ impl Conversation {
});
if let Some(summary) = summary {
- let conversation = this.read_with(&cx, |this, cx| SavedConversation {
- zed: "conversation".into(),
- version: SavedConversation::VERSION.into(),
- text: this.buffer.read(cx).text(),
- message_metadata: this.messages_metadata.clone(),
- messages: this
- .message_anchors
- .iter()
- .map(|message| SavedMessage {
- id: message.id,
- start: message.start.to_offset(this.buffer.read(cx)),
- })
- .collect(),
- summary: summary.clone(),
- model: this.model.clone(),
- });
-
+ let conversation = this.read_with(&cx, |this, cx| this.serialize(cx));
let path = if let Some(old_path) = old_path {
old_path
} else {
@@ -1533,10 +1530,10 @@ impl ConversationEditor {
cx: &mut ViewContext<Self>,
) -> Self {
let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx));
- Self::from_conversation(conversation, fs, cx)
+ Self::for_conversation(conversation, fs, cx)
}
- fn from_conversation(
+ fn for_conversation(
conversation: ModelHandle<Conversation>,
fs: Arc<dyn Fs>,
cx: &mut ViewContext<Self>,
@@ -2116,9 +2113,8 @@ async fn stream_completion(
#[cfg(test)]
mod tests {
- use crate::MessageId;
-
use super::*;
+ use crate::MessageId;
use fs::FakeFs;
use gpui::{AppContext, TestAppContext};
use project::Project;
@@ -2464,6 +2460,64 @@ mod tests {
}
}
+ #[gpui::test]
+ fn test_serialization(cx: &mut AppContext) {
+ let registry = Arc::new(LanguageRegistry::test());
+ let conversation =
+ cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx));
+ let buffer = conversation.read(cx).buffer.clone();
+ let message_0 = conversation.read(cx).message_anchors[0].id;
+ let message_1 = conversation.update(cx, |conversation, cx| {
+ conversation
+ .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
+ .unwrap()
+ });
+ let message_2 = conversation.update(cx, |conversation, cx| {
+ conversation
+ .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
+ .unwrap()
+ });
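+ // Give message_0 and message_1 some content; message_2 stays empty.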
+ buffer.update(cx, |buffer, cx| {
+ buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
+ buffer.finalize_last_transaction();
+ });
+ let _message_3 = conversation.update(cx, |conversation, cx| {
+ conversation
+ .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
+ .unwrap()
+ });
+ buffer.update(cx, |buffer, cx| buffer.undo(cx));
+ assert_eq!(buffer.read(cx).text(), "a\nb\nc\n");
+ assert_eq!(
+ messages(&conversation, cx),
+ [
+ (message_0, Role::User, 0..2),
+ (message_1.id, Role::Assistant, 2..6),
+ (message_2.id, Role::System, 6..6),
+ ]
+ );
+
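+ // Round-trip through serialize/deserialize and verify the text and message ranges survive.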
+ let deserialized_conversation = cx.add_model(|cx| {
+ Conversation::deserialize(
+ conversation.read(cx).serialize(cx),
+ Default::default(),
+ Default::default(),
+ registry.clone(),
+ cx,
+ )
+ });
+ let deserialized_buffer = deserialized_conversation.read(cx).buffer.clone();
+ assert_eq!(deserialized_buffer.read(cx).text(), "a\nb\nc\n");
+ assert_eq!(
+ messages(&deserialized_conversation, cx),
+ [
+ (message_0, Role::User, 0..2),
+ (message_1.id, Role::Assistant, 2..6),
+ (message_2.id, Role::System, 6..6),
+ ]
+ );
+ }
+
fn messages(
conversation: &ModelHandle<Conversation>,
cx: &AppContext,