Allow loading a previously-saved conversation

Antonio Scandurra created

Change summary

crates/ai/Cargo.toml              |   2 
crates/ai/src/ai.rs               |  38 ++++
crates/ai/src/assistant.rs        | 271 +++++++++++++++++++++++++-------
crates/theme/src/theme.rs         |   1 
styles/src/styleTree/assistant.ts |  31 +++
5 files changed, 279 insertions(+), 64 deletions(-)

Detailed changes

crates/ai/Cargo.toml 🔗

@@ -22,7 +22,7 @@ util = { path = "../util" }
 workspace = { path = "../workspace" }
 
 anyhow.workspace = true
-chrono = "0.4"
+chrono = { version = "0.4", features = ["serde"] }
 futures.workspace = true
 isahc.workspace = true
 regex.workspace = true

crates/ai/src/ai.rs 🔗

@@ -3,6 +3,8 @@ mod assistant_settings;
 
 use anyhow::Result;
 pub use assistant::AssistantPanel;
+use chrono::{DateTime, Local};
+use collections::HashMap;
 use fs::Fs;
 use futures::StreamExt;
 use gpui::AppContext;
@@ -12,7 +14,6 @@ use std::{
     fmt::{self, Display},
     path::PathBuf,
     sync::Arc,
-    time::SystemTime,
 };
 use util::paths::CONVERSATIONS_DIR;
 
@@ -24,11 +25,44 @@ struct OpenAIRequest {
     stream: bool,
 }
 
+#[derive(
+    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
+)]
+struct MessageId(usize);
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+struct MessageMetadata {
+    role: Role,
+    sent_at: DateTime<Local>,
+    status: MessageStatus,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+enum MessageStatus {
+    Pending,
+    Done,
+    Error(Arc<str>),
+}
+
+#[derive(Serialize, Deserialize)]
+struct SavedMessage {
+    id: MessageId,
+    start: usize,
+}
+
 #[derive(Serialize, Deserialize)]
 struct SavedConversation {
     zed: String,
     version: String,
-    messages: Vec<RequestMessage>,
+    text: String,
+    messages: Vec<SavedMessage>,
+    message_metadata: HashMap<MessageId, MessageMetadata>,
+    summary: String,
+    model: String,
+}
+
+impl SavedConversation {
+    const VERSION: &'static str = "0.1.0";
 }
 
 struct SavedConversationMetadata {

crates/ai/src/assistant.rs 🔗

@@ -1,7 +1,7 @@
 use crate::{
     assistant_settings::{AssistantDockPosition, AssistantSettings},
-    OpenAIRequest, OpenAIResponseStreamEvent, RequestMessage, Role, SavedConversation,
-    SavedConversationMetadata,
+    MessageId, MessageMetadata, MessageStatus, OpenAIRequest, OpenAIResponseStreamEvent,
+    RequestMessage, Role, SavedConversation, SavedConversationMetadata, SavedMessage,
 };
 use anyhow::{anyhow, Result};
 use chrono::{DateTime, Local};
@@ -27,10 +27,18 @@ use language::{language_settings::SoftWrap, Buffer, LanguageRegistry, ToOffset a
 use serde::Deserialize;
 use settings::SettingsStore;
 use std::{
-    borrow::Cow, cell::RefCell, cmp, env, fmt::Write, io, iter, ops::Range, path::PathBuf, rc::Rc,
-    sync::Arc, time::Duration,
+    borrow::Cow,
+    cell::RefCell,
+    cmp, env,
+    fmt::Write,
+    io, iter,
+    ops::Range,
+    path::{Path, PathBuf},
+    rc::Rc,
+    sync::Arc,
+    time::Duration,
 };
-use theme::{ui::IconStyle, IconButton, Theme};
+use theme::ui::IconStyle;
 use util::{
     channel::ReleaseChannel, paths::CONVERSATIONS_DIR, post_inc, truncate_and_trailoff, ResultExt,
     TryFutureExt,
@@ -68,7 +76,7 @@ pub fn init(cx: &mut AppContext) {
         |workspace: &mut Workspace, _: &NewContext, cx: &mut ViewContext<Workspace>| {
             if let Some(this) = workspace.panel::<AssistantPanel>(cx) {
                 this.update(cx, |this, cx| {
-                    this.add_conversation(cx);
+                    this.new_conversation(cx);
                 })
             }
 
@@ -187,13 +195,8 @@ impl AssistantPanel {
         })
     }
 
-    fn add_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> {
-        let focus = self.has_focus(cx);
+    fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> {
         let editor = cx.add_view(|cx| {
-            if focus {
-                cx.focus_self();
-            }
-
             ConversationEditor::new(
                 self.api_key.clone(),
                 self.languages.clone(),
@@ -201,14 +204,24 @@ impl AssistantPanel {
                 cx,
             )
         });
+        self.add_conversation(editor.clone(), cx);
+        editor
+    }
+
+    fn add_conversation(
+        &mut self,
+        editor: ViewHandle<ConversationEditor>,
+        cx: &mut ViewContext<Self>,
+    ) {
         self.subscriptions
             .push(cx.subscribe(&editor, Self::handle_conversation_editor_event));
 
         self.active_conversation_index = Some(self.conversation_editors.len());
         self.conversation_editors.push(editor.clone());
-
+        if self.has_focus(cx) {
+            cx.focus(&editor);
+        }
         cx.notify();
-        editor
     }
 
     fn handle_conversation_editor_event(
@@ -264,9 +277,28 @@ impl AssistantPanel {
     }
 
     fn render_hamburger_button(style: &IconStyle) -> impl Element<Self> {
+        enum ListConversations {}
         Svg::for_style(style.icon.clone())
             .contained()
             .with_style(style.container)
+            .mouse::<ListConversations>(0)
+            .with_cursor_style(CursorStyle::PointingHand)
+            .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
+                this.active_conversation_index = None;
+                cx.notify();
+            })
+    }
+
+    fn render_plus_button(style: &IconStyle) -> impl Element<Self> {
+        enum AddConversation {}
+        Svg::for_style(style.icon.clone())
+            .contained()
+            .with_style(style.container)
+            .mouse::<AddConversation>(0)
+            .with_cursor_style(CursorStyle::PointingHand)
+            .on_click(MouseButton::Left, |_, this: &mut Self, cx| {
+                this.new_conversation(cx);
+            })
     }
 
     fn render_saved_conversation(
@@ -274,20 +306,23 @@ impl AssistantPanel {
         index: usize,
         cx: &mut ViewContext<Self>,
     ) -> impl Element<Self> {
+        let conversation = &self.saved_conversations[index];
+        let path = conversation.path.clone();
         MouseEventHandler::<SavedConversationMetadata, _>::new(index, cx, move |state, cx| {
             let style = &theme::current(cx).assistant.saved_conversation;
-            let conversation = &self.saved_conversations[index];
             Flex::row()
                 .with_child(
                     Label::new(
-                        conversation.mtime.format("%c").to_string(),
+                        conversation.mtime.format("%F %I:%M%p").to_string(),
                         style.saved_at.text.clone(),
                     )
+                    .aligned()
                     .contained()
                     .with_style(style.saved_at.container),
                 )
                 .with_child(
                     Label::new(conversation.title.clone(), style.title.text.clone())
+                        .aligned()
                         .contained()
                         .with_style(style.title.container),
                 )
@@ -295,7 +330,48 @@ impl AssistantPanel {
                 .with_style(*style.container.style_for(state, false))
         })
         .with_cursor_style(CursorStyle::PointingHand)
-        .on_click(MouseButton::Left, |_, this, cx| {})
+        .on_click(MouseButton::Left, move |_, this, cx| {
+            this.open_conversation(path.clone(), cx)
+                .detach_and_log_err(cx)
+        })
+    }
+
+    fn open_conversation(&mut self, path: PathBuf, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
+        if let Some(ix) = self.conversation_editor_index_for_path(&path, cx) {
+            self.active_conversation_index = Some(ix);
+            cx.notify();
+            return Task::ready(Ok(()));
+        }
+
+        let fs = self.fs.clone();
+        let conversation = Conversation::load(
+            path.clone(),
+            self.api_key.clone(),
+            self.languages.clone(),
+            self.fs.clone(),
+            cx,
+        );
+        cx.spawn(|this, mut cx| async move {
+            let conversation = conversation.await?;
+            this.update(&mut cx, |this, cx| {
+                // If, by the time we've loaded the conversation, the user has already opened
+                // the same conversation, we don't want to open it again.
+                if let Some(ix) = this.conversation_editor_index_for_path(&path, cx) {
+                    this.active_conversation_index = Some(ix);
+                } else {
+                    let editor = cx
+                        .add_view(|cx| ConversationEditor::from_conversation(conversation, fs, cx));
+                    this.add_conversation(editor, cx);
+                }
+            })?;
+            Ok(())
+        })
+    }
+
+    fn conversation_editor_index_for_path(&self, path: &Path, cx: &AppContext) -> Option<usize> {
+        self.conversation_editors
+            .iter()
+            .position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
     }
 }
 
@@ -341,30 +417,37 @@ impl View for AssistantPanel {
                 .with_style(style.api_key_prompt.container)
                 .aligned()
                 .into_any()
-        } else if let Some(editor) = self.active_conversation_editor() {
+        } else {
             Flex::column()
                 .with_child(
                     Flex::row()
-                        .with_child(Self::render_hamburger_button(&style.hamburger_button))
+                        .with_child(
+                            Self::render_hamburger_button(&style.hamburger_button).aligned(),
+                        )
+                        .with_child(Self::render_plus_button(&style.plus_button).aligned())
                         .contained()
                         .with_style(theme.workspace.tab_bar.container)
+                        .expanded()
                         .constrained()
                         .with_height(theme.workspace.tab_bar.height),
                 )
-                .with_child(ChildView::new(editor, cx).flex(1., true))
+                .with_child(if let Some(editor) = self.active_conversation_editor() {
+                    ChildView::new(editor, cx).flex(1., true).into_any()
+                } else {
+                    UniformList::new(
+                        self.saved_conversations_list_state.clone(),
+                        self.saved_conversations.len(),
+                        cx,
+                        |this, range, items, cx| {
+                            for ix in range {
+                                items.push(this.render_saved_conversation(ix, cx).into_any());
+                            }
+                        },
+                    )
+                    .flex(1., true)
+                    .into_any()
+                })
                 .into_any()
-        } else {
-            UniformList::new(
-                self.saved_conversations_list_state.clone(),
-                self.saved_conversations.len(),
-                cx,
-                |this, range, items, cx| {
-                    for ix in range {
-                        items.push(this.render_saved_conversation(ix, cx).into_any());
-                    }
-                },
-            )
-            .into_any()
         }
     }
 
@@ -468,7 +551,7 @@ impl Panel for AssistantPanel {
             }
 
             if self.conversation_editors.is_empty() {
-                self.add_conversation(cx);
+                self.new_conversation(cx);
             }
         }
     }
@@ -598,6 +681,74 @@ impl Conversation {
         this
     }
 
+    fn load(
+        path: PathBuf,
+        api_key: Rc<RefCell<Option<String>>>,
+        language_registry: Arc<LanguageRegistry>,
+        fs: Arc<dyn Fs>,
+        cx: &mut AppContext,
+    ) -> Task<Result<ModelHandle<Self>>> {
+        cx.spawn(|mut cx| async move {
+            let saved_conversation = fs.load(&path).await?;
+            let saved_conversation: SavedConversation = serde_json::from_str(&saved_conversation)?;
+
+            let model = saved_conversation.model;
+            let markdown = language_registry.language_for_name("Markdown");
+            let mut message_anchors = Vec::new();
+            let mut next_message_id = MessageId(0);
+            let buffer = cx.add_model(|cx| {
+                let mut buffer = Buffer::new(0, saved_conversation.text, cx);
+                for message in saved_conversation.messages {
+                    message_anchors.push(MessageAnchor {
+                        id: message.id,
+                        start: buffer.anchor_before(message.start),
+                    });
+                    next_message_id = cmp::max(next_message_id, MessageId(message.id.0 + 1));
+                }
+                buffer.set_language_registry(language_registry);
+                cx.spawn_weak(|buffer, mut cx| async move {
+                    let markdown = markdown.await?;
+                    let buffer = buffer
+                        .upgrade(&cx)
+                        .ok_or_else(|| anyhow!("buffer was dropped"))?;
+                    buffer.update(&mut cx, |buffer, cx| {
+                        buffer.set_language(Some(markdown), cx)
+                    });
+                    anyhow::Ok(())
+                })
+                .detach_and_log_err(cx);
+                buffer
+            });
+            let conversation = cx.add_model(|cx| {
+                let mut this = Self {
+                    message_anchors,
+                    messages_metadata: saved_conversation.message_metadata,
+                    next_message_id,
+                    summary: Some(Summary {
+                        text: saved_conversation.summary,
+                        done: true,
+                    }),
+                    pending_summary: Task::ready(None),
+                    completion_count: Default::default(),
+                    pending_completions: Default::default(),
+                    token_count: None,
+                    max_token_count: tiktoken_rs::model::get_context_size(&model),
+                    pending_token_count: Task::ready(None),
+                    model,
+                    _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
+                    pending_save: Task::ready(Ok(())),
+                    path: Some(path),
+                    api_key,
+                    buffer,
+                };
+
+                this.count_remaining_tokens(cx);
+                this
+            });
+            Ok(conversation)
+        })
+    }
+
     fn handle_buffer_event(
         &mut self,
         _: ModelHandle<Buffer>,
@@ -1122,15 +1273,22 @@ impl Conversation {
             });
 
             if let Some(summary) = summary {
-                let conversation = SavedConversation {
+                let conversation = this.read_with(&cx, |this, cx| SavedConversation {
                     zed: "conversation".into(),
-                    version: "0.1".into(),
-                    messages: this.read_with(&cx, |this, cx| {
-                        this.messages(cx)
-                            .map(|message| message.to_open_ai_message(this.buffer.read(cx)))
-                            .collect()
-                    }),
-                };
+                    version: SavedConversation::VERSION.into(),
+                    text: this.buffer.read(cx).text(),
+                    message_metadata: this.messages_metadata.clone(),
+                    messages: this
+                        .message_anchors
+                        .iter()
+                        .map(|message| SavedMessage {
+                            id: message.id,
+                            start: message.start.to_offset(this.buffer.read(cx)),
+                        })
+                        .collect(),
+                    summary: summary.clone(),
+                    model: this.model.clone(),
+                });
 
                 let path = if let Some(old_path) = old_path {
                     old_path
@@ -1195,6 +1353,14 @@ impl ConversationEditor {
         cx: &mut ViewContext<Self>,
     ) -> Self {
         let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx));
+        Self::from_conversation(conversation, fs, cx)
+    }
+
+    fn from_conversation(
+        conversation: ModelHandle<Conversation>,
+        fs: Arc<dyn Fs>,
+        cx: &mut ViewContext<Self>,
+    ) -> Self {
         let editor = cx.add_view(|cx| {
             let mut editor = Editor::for_buffer(conversation.read(cx).buffer.clone(), None, cx);
             editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
@@ -1524,7 +1690,7 @@ impl ConversationEditor {
                 let conversation = panel
                     .active_conversation_editor()
                     .cloned()
-                    .unwrap_or_else(|| panel.add_conversation(cx));
+                    .unwrap_or_else(|| panel.new_conversation(cx));
                 conversation.update(cx, |conversation, cx| {
                     conversation
                         .editor
@@ -1693,29 +1859,12 @@ impl Item for ConversationEditor {
     }
 }
 
-#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Hash)]
-struct MessageId(usize);
-
 #[derive(Clone, Debug)]
 struct MessageAnchor {
     id: MessageId,
     start: language::Anchor,
 }
 
-#[derive(Clone, Debug)]
-struct MessageMetadata {
-    role: Role,
-    sent_at: DateTime<Local>,
-    status: MessageStatus,
-}
-
-#[derive(Clone, Debug)]
-enum MessageStatus {
-    Pending,
-    Done,
-    Error(Arc<str>),
-}
-
 #[derive(Clone, Debug)]
 pub struct Message {
     range: Range<usize>,
@@ -1733,7 +1882,7 @@ impl Message {
         content.extend(buffer.text_for_range(self.range.clone()));
         RequestMessage {
             role: self.role,
-            content,
+            content: content.trim_end().into(),
         }
     }
 }
@@ -1826,6 +1975,8 @@ async fn stream_completion(
 
 #[cfg(test)]
 mod tests {
+    use crate::MessageId;
+
     use super::*;
     use fs::FakeFs;
     use gpui::{AppContext, TestAppContext};

crates/theme/src/theme.rs 🔗

@@ -995,6 +995,7 @@ pub struct TerminalStyle {
 pub struct AssistantStyle {
     pub container: ContainerStyle,
     pub hamburger_button: IconStyle,
+    pub plus_button: IconStyle,
     pub message_header: ContainerStyle,
     pub sent_at: ContainedText,
     pub user_sender: Interactive<ContainedText>,

styles/src/styleTree/assistant.ts 🔗

@@ -23,7 +23,36 @@ export default function assistant(colorScheme: ColorScheme) {
               height: 15,
             },
           },
-          container: {}
+          container: {
+            margin: { left: 8 },
+          }
+        },
+        plusButton: {
+          icon: {
+            color: text(layer, "sans", "default", { size: "sm" }).color,
+            asset: "icons/plus_12.svg",
+            dimensions: {
+              width: 12,
+              height: 12,
+            },
+          },
+          container: {
+            margin: { left: 8 },
+          }
+        },
+        savedConversation: {
+          background: background(layer, "on"),
+          hover: {
+            background: background(layer, "on", "hovered"),
+          },
+          savedAt: {
+            margin: { left: 8 },
+            ...text(layer, "sans", "default", { size: "xs" }),
+          },
+          title: {
+            margin: { left: 8 },
+            ...text(layer, "sans", "default", { size: "sm", weight: "bold" }),
+          }
         },
         userSender: {
             ...text(layer, "sans", "default", { size: "sm", weight: "bold" }),