diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index b56695d08aefc237393948acf758583f6a64c852..6e23c1e7a07b665fec1e95f6730e1c2471bac07b 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -492,15 +492,14 @@ impl AssistantPanel { } let fs = self.fs.clone(); - let conversation = Conversation::load( - path.clone(), - self.api_key.clone(), - self.languages.clone(), - self.fs.clone(), - cx, - ); + let api_key = self.api_key.clone(); + let languages = self.languages.clone(); cx.spawn(|this, mut cx| async move { - let conversation = conversation.await?; + let saved_conversation = fs.load(&path).await?; + let saved_conversation = serde_json::from_str(&saved_conversation)?; + let conversation = cx.add_model(|cx| { + Conversation::deserialize(saved_conversation, path.clone(), api_key, languages, cx) + }); this.update(&mut cx, |this, cx| { // If, by the time we've loaded the conversation, the user has already opened // the same conversation, we don't want to open it again. @@ -508,7 +507,7 @@ impl AssistantPanel { this.set_active_editor_index(Some(ix), cx); } else { let editor = cx - .add_view(|cx| ConversationEditor::from_conversation(conversation, fs, cx)); + .add_view(|cx| ConversationEditor::for_conversation(conversation, fs, cx)); this.add_conversation(editor, cx); } })?; @@ -861,72 +860,86 @@ impl Conversation { this } - fn load( + fn serialize(&self, cx: &AppContext) -> SavedConversation { + SavedConversation { + zed: "conversation".into(), + version: SavedConversation::VERSION.into(), + text: self.buffer.read(cx).text(), + message_metadata: self.messages_metadata.clone(), + messages: self + .messages(cx) + .map(|message| SavedMessage { + id: message.id, + start: message.range.start, + }) + .collect(), + summary: self + .summary + .as_ref() + .map(|summary| summary.text.clone()) + .unwrap_or_default(), + model: self.model.clone(), + } + } + + fn deserialize( + saved_conversation: SavedConversation, path: PathBuf, api_key: Rc>>, language_registry: Arc, - fs: Arc, - cx: &mut AppContext, - ) -> Task>> { - cx.spawn(|mut cx| async move { - let saved_conversation = fs.load(&path).await?; - let saved_conversation: SavedConversation = serde_json::from_str(&saved_conversation)?; - - let model = saved_conversation.model; - let markdown = language_registry.language_for_name("Markdown"); - let mut message_anchors = Vec::new(); - let mut next_message_id = MessageId(0); - let buffer = cx.add_model(|cx| { - let mut buffer = Buffer::new(0, saved_conversation.text, cx); - for message in saved_conversation.messages { - message_anchors.push(MessageAnchor { - id: message.id, - start: buffer.anchor_before(message.start), - }); - next_message_id = cmp::max(next_message_id, MessageId(message.id.0 + 1)); - } - buffer.set_language_registry(language_registry); - cx.spawn_weak(|buffer, mut cx| async move { - let markdown = markdown.await?; - let buffer = buffer - .upgrade(&cx) - .ok_or_else(|| anyhow!("buffer was dropped"))?; - buffer.update(&mut cx, |buffer, cx| { - buffer.set_language(Some(markdown), cx) - }); - anyhow::Ok(()) - }) - .detach_and_log_err(cx); - buffer - }); - let conversation = cx.add_model(|cx| { - let mut this = Self { - message_anchors, - messages_metadata: saved_conversation.message_metadata, - next_message_id, - summary: Some(Summary { - text: saved_conversation.summary, - done: true, - }), - pending_summary: Task::ready(None), - completion_count: Default::default(), - pending_completions: Default::default(), - token_count: None, - max_token_count: tiktoken_rs::model::get_context_size(&model), - pending_token_count: Task::ready(None), - model, - _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)], - pending_save: Task::ready(Ok(())), - path: Some(path), - api_key, - buffer, - }; + cx: &mut ModelContext, + ) -> Self { + let model = saved_conversation.model; + let markdown = language_registry.language_for_name("Markdown"); + let mut message_anchors = Vec::new(); + let mut next_message_id = MessageId(0); + let buffer = cx.add_model(|cx| { + let mut buffer = Buffer::new(0, saved_conversation.text, cx); + for message in saved_conversation.messages { + message_anchors.push(MessageAnchor { + id: message.id, + start: buffer.anchor_before(message.start), + }); + next_message_id = cmp::max(next_message_id, MessageId(message.id.0 + 1)); + } + buffer.set_language_registry(language_registry); + cx.spawn_weak(|buffer, mut cx| async move { + let markdown = markdown.await?; + let buffer = buffer + .upgrade(&cx) + .ok_or_else(|| anyhow!("buffer was dropped"))?; + buffer.update(&mut cx, |buffer, cx| { + buffer.set_language(Some(markdown), cx) + }); + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + buffer + }); - this.count_remaining_tokens(cx); - this - }); - Ok(conversation) - }) + let mut this = Self { + message_anchors, + messages_metadata: saved_conversation.message_metadata, + next_message_id, + summary: Some(Summary { + text: saved_conversation.summary, + done: true, + }), + pending_summary: Task::ready(None), + completion_count: Default::default(), + pending_completions: Default::default(), + token_count: None, + max_token_count: tiktoken_rs::model::get_context_size(&model), + pending_token_count: Task::ready(None), + model, + _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)], + pending_save: Task::ready(Ok(())), + path: Some(path), + api_key, + buffer, + }; + this.count_remaining_tokens(cx); + this } fn handle_buffer_event( @@ -1453,23 +1466,7 @@ impl Conversation { }); if let Some(summary) = summary { - let conversation = this.read_with(&cx, |this, cx| SavedConversation { - zed: "conversation".into(), - version: SavedConversation::VERSION.into(), - text: this.buffer.read(cx).text(), - message_metadata: this.messages_metadata.clone(), - messages: this - .message_anchors - .iter() - .map(|message| SavedMessage { - id: message.id, - start: message.start.to_offset(this.buffer.read(cx)), - }) - .collect(), - summary: summary.clone(), - model: this.model.clone(), - }); - + let conversation = this.read_with(&cx, |this, cx| this.serialize(cx)); let path = if let Some(old_path) = old_path { old_path } else { @@ -1533,10 +1530,10 @@ impl ConversationEditor { cx: &mut ViewContext, ) -> Self { let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx)); - Self::from_conversation(conversation, fs, cx) + Self::for_conversation(conversation, fs, cx) } - fn from_conversation( + fn for_conversation( conversation: ModelHandle, fs: Arc, cx: &mut ViewContext, @@ -2116,9 +2113,8 @@ async fn stream_completion( #[cfg(test)] mod tests { - use crate::MessageId; - use super::*; + use crate::MessageId; use fs::FakeFs; use gpui::{AppContext, TestAppContext}; use project::Project; @@ -2464,6 +2460,64 @@ mod tests { } } + #[gpui::test] + fn test_serialization(cx: &mut AppContext) { + let registry = Arc::new(LanguageRegistry::test()); + let conversation = + cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx)); + let buffer = conversation.read(cx).buffer.clone(); + let message_0 = conversation.read(cx).message_anchors[0].id; + let message_1 = conversation.update(cx, |conversation, cx| { + conversation + .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx) + .unwrap() + }); + let message_2 = conversation.update(cx, |conversation, cx| { + conversation + .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx) + .unwrap() + }); + buffer.update(cx, |buffer, cx| { + buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx); + buffer.finalize_last_transaction(); + }); + let _message_3 = conversation.update(cx, |conversation, cx| { + conversation + .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx) + .unwrap() + }); + buffer.update(cx, |buffer, cx| buffer.undo(cx)); + assert_eq!(buffer.read(cx).text(), "a\nb\nc\n"); + assert_eq!( + messages(&conversation, cx), + [ + (message_0, Role::User, 0..2), + (message_1.id, Role::Assistant, 2..6), + (message_2.id, Role::System, 6..6), + ] + ); + + let deserialized_conversation = cx.add_model(|cx| { + Conversation::deserialize( + conversation.read(cx).serialize(cx), + Default::default(), + Default::default(), + registry.clone(), + cx, + ) + }); + let deserialized_buffer = deserialized_conversation.read(cx).buffer.clone(); + assert_eq!(deserialized_buffer.read(cx).text(), "a\nb\nc\n"); + assert_eq!( + messages(&deserialized_conversation, cx), + [ + (message_0, Role::User, 0..2), + (message_1.id, Role::Assistant, 2..6), + (message_2.id, Role::System, 6..6), + ] + ); + } + fn messages( conversation: &ModelHandle, cx: &AppContext,