assistant2: Suggest current thread in inline assistant (#22586)

Agus Zubiaga and Marshall created

Release Notes:

- N/A

---------

Co-authored-by: Marshall <marshall@zed.com>

Change summary

crates/assistant2/src/active_thread.rs                        |   2 
crates/assistant2/src/assistant_panel.rs                      |   6 
crates/assistant2/src/context.rs                              |   8 
crates/assistant2/src/context_picker/thread_context_picker.rs |  24 
crates/assistant2/src/context_store.rs                        |  16 
crates/assistant2/src/context_strip.rs                        | 196 ++--
crates/assistant2/src/thread.rs                               |  21 
crates/assistant2/src/ui/context_pill.rs                      |   2 
8 files changed, 161 insertions(+), 114 deletions(-)

Detailed changes

crates/assistant2/src/active_thread.rs 🔗

@@ -22,7 +22,7 @@ pub struct ActiveThread {
     workspace: WeakView<Workspace>,
     language_registry: Arc<LanguageRegistry>,
     tools: Arc<ToolWorkingSet>,
-    thread: Model<Thread>,
+    pub(crate) thread: Model<Thread>,
     messages: Vec<MessageId>,
     list_state: ListState,
     rendered_messages_by_id: HashMap<MessageId, View<Markdown>>,

crates/assistant2/src/assistant_panel.rs 🔗

@@ -19,7 +19,7 @@ use workspace::Workspace;
 use crate::active_thread::ActiveThread;
 use crate::assistant_settings::{AssistantDockPosition, AssistantSettings};
 use crate::message_editor::MessageEditor;
-use crate::thread::{ThreadError, ThreadId};
+use crate::thread::{Thread, ThreadError, ThreadId};
 use crate::thread_history::{PastThread, ThreadHistory};
 use crate::thread_store::ThreadStore;
 use crate::{NewThread, OpenHistory, ToggleFocus};
@@ -206,6 +206,10 @@ impl AssistantPanel {
         self.message_editor.focus_handle(cx).focus(cx);
     }
 
+    pub(crate) fn active_thread(&self, cx: &AppContext) -> Model<Thread> {
+        self.thread.read(cx).thread.clone()
+    }
+
     pub(crate) fn delete_thread(&mut self, thread_id: &ThreadId, cx: &mut ViewContext<Self>) {
         self.thread_store
             .update(cx, |this, cx| this.delete_thread(thread_id, cx));

crates/assistant2/src/context.rs 🔗

@@ -4,6 +4,8 @@ use project::ProjectEntryId;
 use serde::{Deserialize, Serialize};
 use util::post_inc;
 
+use crate::thread::ThreadId;
+
 #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
 pub struct ContextId(pub(crate) usize);
 
@@ -22,12 +24,12 @@ pub struct Context {
     pub text: SharedString,
 }
 
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 pub enum ContextKind {
     File(ProjectEntryId),
     Directory,
     FetchedUrl,
-    Thread,
+    Thread(ThreadId),
 }
 
 pub fn attach_context_to_message(
@@ -55,7 +57,7 @@ pub fn attach_context_to_message(
                 fetch_context.push_str(&context.text);
                 fetch_context.push('\n');
             }
-            ContextKind::Thread => {
+            ContextKind::Thread(_) => {
                 thread_context.push_str(&context.name);
                 thread_context.push('\n');
                 thread_context.push_str(&context.text);

crates/assistant2/src/context_picker/thread_context_picker.rs 🔗

@@ -169,25 +169,11 @@ impl PickerDelegate for ThreadContextPickerDelegate {
 
         self.context_store
             .update(cx, |context_store, cx| {
-                let text = thread.update(cx, |thread, _cx| {
-                    let mut text = String::new();
-
-                    for message in thread.messages() {
-                        text.push_str(match message.role {
-                            language_model::Role::User => "User:",
-                            language_model::Role::Assistant => "Assistant:",
-                            language_model::Role::System => "System:",
-                        });
-                        text.push('\n');
-
-                        text.push_str(&message.text);
-                        text.push('\n');
-                    }
-
-                    text
-                });
-
-                context_store.insert_context(ContextKind::Thread, entry.summary.clone(), text);
+                context_store.insert_context(
+                    ContextKind::Thread(thread.read(cx).id().clone()),
+                    entry.summary.clone(),
+                    thread.read(cx).text(),
+                );
             })
             .ok();
 

crates/assistant2/src/context_store.rs 🔗

@@ -1,7 +1,10 @@
 use gpui::SharedString;
 use project::ProjectEntryId;
 
-use crate::context::{Context, ContextId, ContextKind};
+use crate::{
+    context::{Context, ContextId, ContextKind},
+    thread::ThreadId,
+};
 
 pub struct ContextStore {
     context: Vec<Context>,
@@ -49,9 +52,14 @@ impl ContextStore {
     pub fn contains_project_entry(&self, entry_id: ProjectEntryId) -> bool {
         self.context.iter().any(|probe| match probe.kind {
             ContextKind::File(probe_entry_id) => probe_entry_id == entry_id,
-            ContextKind::Directory => false,
-            ContextKind::FetchedUrl => false,
-            ContextKind::Thread => false,
+            ContextKind::Directory | ContextKind::FetchedUrl | ContextKind::Thread(_) => false,
+        })
+    }
+
+    pub fn contains_thread(&self, thread_id: &ThreadId) -> bool {
+        self.context.iter().any(|probe| match probe.kind {
+            ContextKind::Thread(ref probe_thread_id) => probe_thread_id == thread_id,
+            ContextKind::File(_) | ContextKind::Directory | ContextKind::FetchedUrl => false,
         })
     }
 }

crates/assistant2/src/context_strip.rs 🔗

@@ -1,18 +1,19 @@
 use std::rc::Rc;
 
 use editor::Editor;
-use gpui::{EntityId, FocusHandle, Model, Subscription, View, WeakModel, WeakView};
+use gpui::{AppContext, FocusHandle, Model, View, WeakModel, WeakView};
 use language::Buffer;
 use project::ProjectEntryId;
 use ui::{prelude::*, PopoverMenu, PopoverMenuHandle, Tooltip};
-use workspace::{ItemHandle, Workspace};
+use workspace::Workspace;
 
 use crate::context::ContextKind;
 use crate::context_picker::{ConfirmBehavior, ContextPicker};
 use crate::context_store::ContextStore;
+use crate::thread::{Thread, ThreadId};
 use crate::thread_store::ThreadStore;
 use crate::ui::ContextPill;
-use crate::ToggleContextPicker;
+use crate::{AssistantPanel, ToggleContextPicker};
 use settings::Settings;
 
 pub struct ContextStrip {
@@ -20,21 +21,8 @@ pub struct ContextStrip {
     context_picker: View<ContextPicker>,
     context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
     focus_handle: FocusHandle,
-    workspace_active_pane_id: Option<EntityId>,
-    suggested_context: Option<SuggestedContext>,
-    _subscription: Option<Subscription>,
-}
-
-pub enum SuggestContextKind {
-    File,
-    Thread,
-}
-
-#[derive(Clone)]
-pub struct SuggestedContext {
-    entry_id: ProjectEntryId,
-    title: SharedString,
-    buffer: WeakModel<Buffer>,
+    suggest_context_kind: SuggestContextKind,
+    workspace: WeakView<Workspace>,
 }
 
 impl ContextStrip {
@@ -47,20 +35,6 @@ impl ContextStrip {
         suggest_context_kind: SuggestContextKind,
         cx: &mut ViewContext<Self>,
     ) -> Self {
-        let subscription = match suggest_context_kind {
-            SuggestContextKind::File => {
-                if let Some(workspace) = workspace.upgrade() {
-                    Some(cx.subscribe(&workspace, Self::handle_workspace_event))
-                } else {
-                    None
-                }
-            }
-            SuggestContextKind::Thread => {
-                // TODO: Suggest current thread
-                None
-            }
-        };
-
         Self {
             context_store: context_store.clone(),
             context_picker: cx.new_view(|cx| {
@@ -74,56 +48,65 @@ impl ContextStrip {
             }),
             context_picker_menu_handle,
             focus_handle,
-            workspace_active_pane_id: None,
-            suggested_context: None,
-            _subscription: subscription,
+            suggest_context_kind,
+            workspace,
         }
     }
 
-    fn handle_workspace_event(
-        &mut self,
-        workspace: View<Workspace>,
-        event: &workspace::Event,
-        cx: &mut ViewContext<Self>,
-    ) {
-        match event {
-            workspace::Event::WorkspaceCreated(_) | workspace::Event::ActiveItemChanged => {
-                let workspace = workspace.read(cx);
-
-                if let Some(active_item) = workspace.active_item(cx) {
-                    let new_active_item_id = Some(active_item.item_id());
-
-                    if self.workspace_active_pane_id != new_active_item_id {
-                        self.suggested_context = Self::suggested_file(active_item, cx);
-                        self.workspace_active_pane_id = new_active_item_id;
-                    }
-                } else {
-                    self.suggested_context = None;
-                    self.workspace_active_pane_id = None;
-                }
-            }
-            _ => {}
+    fn suggested_context(&self, cx: &ViewContext<Self>) -> Option<SuggestedContext> {
+        match self.suggest_context_kind {
+            SuggestContextKind::File => self.suggested_file(cx),
+            SuggestContextKind::Thread => self.suggested_thread(cx),
         }
     }
 
-    fn suggested_file(
-        active_item: Box<dyn ItemHandle>,
-        cx: &WindowContext,
-    ) -> Option<SuggestedContext> {
+    fn suggested_file(&self, cx: &ViewContext<Self>) -> Option<SuggestedContext> {
+        let workspace = self.workspace.upgrade()?;
+        let active_item = workspace.read(cx).active_item(cx)?;
         let entry_id = *active_item.project_entry_ids(cx).first()?;
 
+        if self.context_store.read(cx).contains_project_entry(entry_id) {
+            return None;
+        }
+
         let editor = active_item.to_any().downcast::<Editor>().ok()?.read(cx);
         let active_buffer = editor.buffer().read(cx).as_singleton()?;
 
         let file = active_buffer.read(cx).file()?;
         let title = file.path().to_string_lossy().into_owned().into();
 
-        Some(SuggestedContext {
+        Some(SuggestedContext::File {
             entry_id,
             title,
             buffer: active_buffer.downgrade(),
         })
     }
+
+    fn suggested_thread(&self, cx: &ViewContext<Self>) -> Option<SuggestedContext> {
+        let workspace = self.workspace.upgrade()?;
+        let active_thread = workspace
+            .read(cx)
+            .panel::<AssistantPanel>(cx)?
+            .read(cx)
+            .active_thread(cx);
+        let weak_active_thread = active_thread.downgrade();
+
+        let active_thread = active_thread.read(cx);
+
+        if self
+            .context_store
+            .read(cx)
+            .contains_thread(active_thread.id())
+        {
+            return None;
+        }
+
+        Some(SuggestedContext::Thread {
+            id: active_thread.id().clone(),
+            title: active_thread.summary().unwrap_or("Active Thread".into()),
+            thread: weak_active_thread,
+        })
+    }
 }
 
 impl Render for ContextStrip {
@@ -133,13 +116,7 @@ impl Render for ContextStrip {
         let context_picker = self.context_picker.clone();
         let focus_handle = self.focus_handle.clone();
 
-        let suggested_context = self.suggested_context.as_ref().and_then(|suggested| {
-            if context_store.contains_project_entry(suggested.entry_id) {
-                None
-            } else {
-                Some(suggested.clone())
-            }
-        });
+        let suggested_context = self.suggested_context(cx);
 
         h_flex()
             .flex_wrap()
@@ -172,7 +149,7 @@ impl Render for ContextStrip {
                     })
                     .with_handle(self.context_picker_menu_handle.clone()),
             )
-            .when(context.is_empty() && self.suggested_context.is_none(), {
+            .when(context.is_empty() && suggested_context.is_none(), {
                 |parent| {
                     parent.child(
                         h_flex()
@@ -209,24 +186,13 @@ impl Render for ContextStrip {
             }))
             .when_some(suggested_context, |el, suggested| {
                 el.child(
-                    Button::new("add-suggested-context", suggested.title.clone())
+                    Button::new("add-suggested-context", suggested.title().clone())
                         .on_click({
                             let context_store = self.context_store.clone();
 
                             cx.listener(move |_this, _event, cx| {
-                                let Some(buffer) = suggested.buffer.upgrade() else {
-                                    return;
-                                };
-
-                                let title = suggested.title.clone();
-                                let text = buffer.read(cx).text();
-
-                                context_store.update(cx, move |context_store, _cx| {
-                                    context_store.insert_context(
-                                        ContextKind::File(suggested.entry_id),
-                                        title,
-                                        text,
-                                    );
+                                context_store.update(cx, |context_store, cx| {
+                                    suggested.accept(context_store, cx);
                                 });
                                 cx.notify();
                             })
@@ -260,3 +226,63 @@ impl Render for ContextStrip {
             })
     }
 }
+
+pub enum SuggestContextKind {
+    File,
+    Thread,
+}
+
+#[derive(Clone)]
+pub enum SuggestedContext {
+    File {
+        entry_id: ProjectEntryId,
+        title: SharedString,
+        buffer: WeakModel<Buffer>,
+    },
+    Thread {
+        id: ThreadId,
+        title: SharedString,
+        thread: WeakModel<Thread>,
+    },
+}
+
+impl SuggestedContext {
+    pub fn title(&self) -> &SharedString {
+        match self {
+            Self::File { title, .. } => title,
+            Self::Thread { title, .. } => title,
+        }
+    }
+
+    pub fn accept(&self, context_store: &mut ContextStore, cx: &mut AppContext) {
+        match self {
+            Self::File {
+                entry_id,
+                title,
+                buffer,
+            } => {
+                let Some(buffer) = buffer.upgrade() else {
+                    return;
+                };
+                let text = buffer.read(cx).text();
+
+                context_store.insert_context(
+                    ContextKind::File(*entry_id),
+                    title.clone(),
+                    text.clone(),
+                );
+            }
+            Self::Thread { id, title, thread } => {
+                let Some(thread) = thread.upgrade() else {
+                    return;
+                };
+
+                context_store.insert_context(
+                    ContextKind::Thread(id.clone()),
+                    title.clone(),
+                    thread.read(cx).text(),
+                );
+            }
+        }
+    }
+}

crates/assistant2/src/thread.rs 🔗

@@ -164,6 +164,27 @@ impl Thread {
         id
     }
 
+    /// Returns the representation of this [`Thread`] in a textual form.
+    ///
+    /// This is the representation we use when attaching a thread as context to another thread.
+    pub fn text(&self) -> String {
+        let mut text = String::new();
+
+        for message in &self.messages {
+            text.push_str(match message.role {
+                language_model::Role::User => "User:",
+                language_model::Role::Assistant => "Assistant:",
+                language_model::Role::System => "System:",
+            });
+            text.push('\n');
+
+            text.push_str(&message.text);
+            text.push('\n');
+        }
+
+        text
+    }
+
     pub fn to_completion_request(
         &self,
         _request_kind: RequestKind,

crates/assistant2/src/ui/context_pill.rs 🔗

@@ -36,7 +36,7 @@ impl RenderOnce for ContextPill {
             ContextKind::File(_) => IconName::File,
             ContextKind::Directory => IconName::Folder,
             ContextKind::FetchedUrl => IconName::Globe,
-            ContextKind::Thread => IconName::MessageCircle,
+            ContextKind::Thread(_) => IconName::MessageCircle,
         };
 
         h_flex()