assistant2: Visualize tool use (#25692)

Marshall Bowers created

This PR adds visuals for tool use in Assistant 2:

<img width="1309" alt="Screenshot 2025-02-26 at 5 57 14 PM"
src="https://github.com/user-attachments/assets/4083ff65-a2f1-4a43-8815-0bade2c00af2"
/>

Release Notes:

- N/A

Change summary

crates/assistant2/src/active_thread.rs | 113 +++++++++++++++++++++++++++
crates/assistant2/src/thread.rs        |  72 +++++++++++++++++
2 files changed, 181 insertions(+), 4 deletions(-)

Detailed changes

crates/assistant2/src/active_thread.rs 🔗

@@ -8,14 +8,14 @@ use gpui::{
     UnderlineStyle, WeakEntity,
 };
 use language::LanguageRegistry;
-use language_model::Role;
+use language_model::{LanguageModelToolUseId, Role};
 use markdown::{Markdown, MarkdownStyle};
 use settings::Settings as _;
 use theme::ThemeSettings;
-use ui::prelude::*;
+use ui::{prelude::*, Disclosure};
 use workspace::Workspace;
 
-use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent};
+use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent, ToolUse, ToolUseStatus};
 use crate::thread_store::ThreadStore;
 use crate::ui::ContextPill;
 
@@ -28,6 +28,7 @@ pub struct ActiveThread {
     messages: Vec<MessageId>,
     list_state: ListState,
     rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
+    expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
     last_error: Option<ThreadError>,
     _subscriptions: Vec<Subscription>,
 }
@@ -55,6 +56,7 @@ impl ActiveThread {
             thread: thread.clone(),
             messages: Vec::new(),
             rendered_messages_by_id: HashMap::default(),
+            expanded_tool_uses: HashMap::default(),
             list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
                 let this = cx.entity().downgrade();
                 move |ix, _: &mut Window, cx: &mut App| {
@@ -276,6 +278,7 @@ impl ActiveThread {
         };
 
         let context = self.thread.read(cx).context_for_message(message_id);
+        let tool_uses = self.thread.read(cx).tool_uses_for_message(message_id);
         let colors = cx.theme().colors();
 
         let message_content = v_flex()
@@ -332,7 +335,22 @@ impl ActiveThread {
                         )
                         .child(message_content),
                 ),
-            Role::Assistant => div().id(("message-container", ix)).child(message_content),
+            Role::Assistant => div()
+                .id(("message-container", ix))
+                .child(message_content)
+                .map(|parent| {
+                    if tool_uses.is_empty() {
+                        return parent;
+                    }
+
+                    parent.child(
+                        v_flex().children(
+                            tool_uses
+                                .into_iter()
+                                .map(|tool_use| self.render_tool_use(tool_use, cx)),
+                        ),
+                    )
+                }),
             Role::System => div().id(("message-container", ix)).py_1().px_2().child(
                 v_flex()
                     .bg(colors.editor_background)
@@ -343,6 +361,93 @@ impl ActiveThread {
 
         styled_message.into_any()
     }
+
+    /// Renders a single tool use as a bordered card.
+    ///
+    /// The card always shows a header row containing a disclosure toggle, the
+    /// tool name, and a status label ("Pending", "Running", "Finished", or
+    /// "Error"). When the tool use is expanded, a body is appended that shows
+    /// the pretty-printed JSON input and, for finished or failed tool uses,
+    /// the result or error text.
+    ///
+    /// Expansion state is read from `expanded_tool_uses`, keyed by the tool
+    /// use ID; a tool use without an entry is treated as collapsed.
+    fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context<Self>) -> impl IntoElement {
+        // A missing entry falls back to `bool::default()`, i.e. collapsed.
+        let is_open = self
+            .expanded_tool_uses
+            .get(&tool_use.id)
+            .copied()
+            .unwrap_or_default();
+
+        v_flex().px_2p5().child(
+            v_flex()
+                .gap_1()
+                .bg(cx.theme().colors().editor_background)
+                .rounded_lg()
+                .border_1()
+                .border_color(cx.theme().colors().border)
+                .shadow_sm()
+                .child(
+                    // Header row: disclosure + tool name on the left, status on the right.
+                    h_flex()
+                        .justify_between()
+                        .py_1()
+                        .px_2()
+                        .bg(cx.theme().colors().editor_foreground.opacity(0.05))
+                        // Only draw the separator under the header when the body is visible.
+                        .when(is_open, |element| element.border_b_1())
+                        .border_color(cx.theme().colors().border)
+                        .rounded_t(px(6.))
+                        .child(
+                            h_flex()
+                                .gap_2()
+                                // NOTE(review): every tool use passes the same
+                                // "tool-use-disclosure" id; confirm this cannot collide
+                                // when one message renders several tool uses.
+                                .child(Disclosure::new("tool-use-disclosure", is_open).on_click(
+                                    cx.listener({
+                                        let tool_use_id = tool_use.id.clone();
+                                        // Flips the stored expansion flag, inserting `false`
+                                        // first when this tool use has no entry yet.
+                                        // NOTE(review): the context argument is unused here;
+                                        // confirm a re-render is triggered after the toggle.
+                                        move |this, _event, _window, _cx| {
+                                            let is_open = this
+                                                .expanded_tool_uses
+                                                .entry(tool_use_id.clone())
+                                                .or_insert(false);
+
+                                            *is_open = !*is_open;
+                                        }
+                                    }),
+                                ))
+                                .child(Label::new(tool_use.name)),
+                        )
+                        .child(Label::new(match tool_use.status {
+                            ToolUseStatus::Pending => "Pending",
+                            ToolUseStatus::Running => "Running",
+                            ToolUseStatus::Finished(_) => "Finished",
+                            ToolUseStatus::Error(_) => "Error",
+                        })),
+                )
+                .map(|parent| {
+                    // Collapsed: the card consists of the header only.
+                    if !is_open {
+                        return parent;
+                    }
+
+                    parent.child(
+                        v_flex()
+                            .gap_2()
+                            .p_2p5()
+                            .child(
+                                v_flex()
+                                    .gap_0p5()
+                                    .child(Label::new("Input:"))
+                                    // If the input cannot be serialized, an empty string
+                                    // is shown instead.
+                                    .child(Label::new(
+                                        serde_json::to_string_pretty(&tool_use.input)
+                                            .unwrap_or_default(),
+                                    )),
+                            )
+                            // Append the outcome section; pending and running tool uses
+                            // have nothing to show yet.
+                            .map(|parent| match tool_use.status {
+                                ToolUseStatus::Finished(output) => parent.child(
+                                    v_flex()
+                                        .gap_0p5()
+                                        .child(Label::new("Result:"))
+                                        .child(Label::new(output)),
+                                ),
+                                ToolUseStatus::Error(err) => parent.child(
+                                    v_flex()
+                                        .gap_0p5()
+                                        .child(Label::new("Error:"))
+                                        .child(Label::new(err)),
+                                ),
+                                ToolUseStatus::Pending | ToolUseStatus::Running => parent,
+                            }),
+                    )
+                }),
+        )
+    }
 }
 
 impl Render for ActiveThread {

crates/assistant2/src/thread.rs 🔗

@@ -59,6 +59,22 @@ pub struct Message {
     pub text: String,
 }
 
+/// A tool use requested by the model, together with its current status.
+#[derive(Debug)]
+pub struct ToolUse {
+    /// Identifier of this tool use.
+    pub id: LanguageModelToolUseId,
+    /// Name of the tool that was requested.
+    pub name: SharedString,
+    /// Current state of the tool use, including its output or error text once available.
+    pub status: ToolUseStatus,
+    /// The JSON input for the tool use.
+    pub input: serde_json::Value,
+}
+
+/// The state of a [`ToolUse`].
+#[derive(Debug, Clone)]
+pub enum ToolUseStatus {
+    /// The tool use has no result and is not currently running.
+    Pending,
+    /// The tool use is currently running.
+    Running,
+    /// The tool use completed; carries the result content.
+    Finished(SharedString),
+    /// The tool use failed; carries the error text.
+    Error(SharedString),
+}
+
 /// A thread of conversation with the LLM.
 pub struct Thread {
     id: ThreadId,
@@ -192,6 +208,61 @@ impl Thread {
         self.pending_tool_uses_by_id.values().collect()
     }
 
+    /// Returns the tool uses requested by the message with the given `id`,
+    /// each paired with its current [`ToolUseStatus`].
+    ///
+    /// Returns an empty `Vec` when no tool uses are recorded for the message.
+    /// The status of each tool use is resolved in this order:
+    ///
+    /// 1. A recorded tool result with a matching `tool_use_id` yields
+    ///    `Error` or `Finished`, depending on its `is_error` flag.
+    /// 2. Otherwise, an entry in `pending_tool_uses_by_id` maps `Idle` to
+    ///    `Pending`, `Running` to `Running`, and `Error` to `Error`.
+    /// 3. Otherwise, the tool use is reported as `Pending`.
+    pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
+        let Some(tool_uses) = self.tool_uses_by_message.get(&id) else {
+            return Vec::new();
+        };
+
+        // The tool use was requested by an Assistant message, so we need to
+        // look for the tool results on the next user message.
+        let next_user_message = MessageId(id.0 + 1);
+
+        // No recorded results is treated the same as an empty list of results.
+        let tool_results = self
+            .tool_results_by_message
+            .get(&next_user_message)
+            .map(Vec::as_slice)
+            .unwrap_or_default();
+
+        tool_uses
+            .iter()
+            .map(|tool_use| {
+                let tool_result = tool_results
+                    .iter()
+                    .find(|tool_result| tool_result.tool_use_id == tool_use.id);
+
+                let status = match tool_result {
+                    // A recorded result always takes precedence over pending state.
+                    Some(tool_result) if tool_result.is_error => {
+                        ToolUseStatus::Error(tool_result.content.clone().into())
+                    }
+                    Some(tool_result) => {
+                        ToolUseStatus::Finished(tool_result.content.clone().into())
+                    }
+                    None => {
+                        let pending_status = self
+                            .pending_tool_uses_by_id
+                            .get(&tool_use.id)
+                            .map(|pending_tool_use| &pending_tool_use.status);
+
+                        match pending_status {
+                            Some(PendingToolUseStatus::Running { .. }) => ToolUseStatus::Running,
+                            Some(PendingToolUseStatus::Error(err)) => {
+                                ToolUseStatus::Error(err.clone().into())
+                            }
+                            // A tool use that is idle or not tracked has not started yet.
+                            Some(PendingToolUseStatus::Idle) | None => ToolUseStatus::Pending,
+                        }
+                    }
+                };
+
+                ToolUse {
+                    id: tool_use.id.clone(),
+                    name: tool_use.name.clone().into(),
+                    input: tool_use.input.clone(),
+                    status,
+                }
+            })
+            .collect()
+    }
+
     pub fn insert_user_message(
         &mut self,
         text: impl Into<String>,
@@ -537,6 +608,7 @@ impl Thread {
                                     content: output,
                                     is_error: false,
                                 });
+                                thread.pending_tool_uses_by_id.remove(&tool_use_id);
 
                                 cx.emit(ThreadEvent::ToolFinished { tool_use_id });
                             }