diff --git a/crates/assistant2/src/active_thread.rs b/crates/assistant2/src/active_thread.rs index aad2aad2651e127b0e846719654121a9055cfbc9..0d4f7c062a41266e491e17e6ca8ca9144c4e6416 100644 --- a/crates/assistant2/src/active_thread.rs +++ b/crates/assistant2/src/active_thread.rs @@ -8,14 +8,14 @@ use gpui::{ UnderlineStyle, WeakEntity, }; use language::LanguageRegistry; -use language_model::Role; +use language_model::{LanguageModelToolUseId, Role}; use markdown::{Markdown, MarkdownStyle}; use settings::Settings as _; use theme::ThemeSettings; -use ui::prelude::*; +use ui::{prelude::*, Disclosure}; use workspace::Workspace; -use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent}; +use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent, ToolUse, ToolUseStatus}; use crate::thread_store::ThreadStore; use crate::ui::ContextPill; @@ -28,6 +28,7 @@ pub struct ActiveThread { messages: Vec, list_state: ListState, rendered_messages_by_id: HashMap>, + expanded_tool_uses: HashMap, last_error: Option, _subscriptions: Vec, } @@ -55,6 +56,7 @@ impl ActiveThread { thread: thread.clone(), messages: Vec::new(), rendered_messages_by_id: HashMap::default(), + expanded_tool_uses: HashMap::default(), list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), { let this = cx.entity().downgrade(); move |ix, _: &mut Window, cx: &mut App| { @@ -276,6 +278,7 @@ impl ActiveThread { }; let context = self.thread.read(cx).context_for_message(message_id); + let tool_uses = self.thread.read(cx).tool_uses_for_message(message_id); let colors = cx.theme().colors(); let message_content = v_flex() @@ -332,7 +335,22 @@ impl ActiveThread { ) .child(message_content), ), - Role::Assistant => div().id(("message-container", ix)).child(message_content), + Role::Assistant => div() + .id(("message-container", ix)) + .child(message_content) + .map(|parent| { + if tool_uses.is_empty() { + return parent; + } + + parent.child( + v_flex().children( + tool_uses + .into_iter() + .map(|tool_use| self.render_tool_use(tool_use, cx)), + ), + ) + }), Role::System => div().id(("message-container", ix)).py_1().px_2().child( v_flex() .bg(colors.editor_background) @@ -343,6 +361,93 @@ impl ActiveThread { styled_message.into_any() } + + fn render_tool_use(&self, tool_use: ToolUse, cx: &mut Context) -> impl IntoElement { + let is_open = self + .expanded_tool_uses + .get(&tool_use.id) + .copied() + .unwrap_or_default(); + + v_flex().px_2p5().child( + v_flex() + .gap_1() + .bg(cx.theme().colors().editor_background) + .rounded_lg() + .border_1() + .border_color(cx.theme().colors().border) + .shadow_sm() + .child( + h_flex() + .justify_between() + .py_1() + .px_2() + .bg(cx.theme().colors().editor_foreground.opacity(0.05)) + .when(is_open, |element| element.border_b_1()) + .border_color(cx.theme().colors().border) + .rounded_t(px(6.)) + .child( + h_flex() + .gap_2() + .child(Disclosure::new("tool-use-disclosure", is_open).on_click( + cx.listener({ + let tool_use_id = tool_use.id.clone(); + move |this, _event, _window, _cx| { + let is_open = this + .expanded_tool_uses + .entry(tool_use_id.clone()) + .or_insert(false); + + *is_open = !*is_open; + } + }), + )) + .child(Label::new(tool_use.name)), + ) + .child(Label::new(match tool_use.status { + ToolUseStatus::Pending => "Pending", + ToolUseStatus::Running => "Running", + ToolUseStatus::Finished(_) => "Finished", + ToolUseStatus::Error(_) => "Error", + })), + ) + .map(|parent| { + if !is_open { + return parent; + } + + parent.child( + v_flex() + .gap_2() + .p_2p5() + .child( + v_flex() + .gap_0p5() + .child(Label::new("Input:")) + .child(Label::new( + serde_json::to_string_pretty(&tool_use.input) + .unwrap_or_default(), + )), + ) + .map(|parent| match tool_use.status { + ToolUseStatus::Finished(output) => parent.child( + v_flex() + .gap_0p5() + .child(Label::new("Result:")) + .child(Label::new(output)), + ), + ToolUseStatus::Error(err) => parent.child( + v_flex() + .gap_0p5() + .child(Label::new("Error:")) + .child(Label::new(err)), + ), + ToolUseStatus::Pending | ToolUseStatus::Running => parent, + }), + ) + }), + ) + } } impl Render for ActiveThread { diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index e742b46fe3578af54387f7fc333b8b3f3b222157..47c1222807033abae2904462de796d6dedcb2ec4 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -59,6 +59,22 @@ pub struct Message { pub text: String, } +#[derive(Debug)] +pub struct ToolUse { + pub id: LanguageModelToolUseId, + pub name: SharedString, + pub status: ToolUseStatus, + pub input: serde_json::Value, +} + +#[derive(Debug, Clone)] +pub enum ToolUseStatus { + Pending, + Running, + Finished(SharedString), + Error(SharedString), +} + /// A thread of conversation with the LLM. pub struct Thread { id: ThreadId, @@ -192,6 +208,61 @@ impl Thread { self.pending_tool_uses_by_id.values().collect() } + pub fn tool_uses_for_message(&self, id: MessageId) -> Vec { + let Some(tool_uses_for_message) = &self.tool_uses_by_message.get(&id) else { + return Vec::new(); + }; + + // The tool use was requested by an Assistant message, so we need to + // look for the tool results on the next user message. + let next_user_message = MessageId(id.0 + 1); + + let empty = Vec::new(); + let tool_results_for_message = self + .tool_results_by_message + .get(&next_user_message) + .unwrap_or_else(|| &empty); + + let mut tool_uses = Vec::new(); + + for tool_use in tool_uses_for_message.iter() { + let tool_result = tool_results_for_message + .iter() + .find(|tool_result| tool_result.tool_use_id == tool_use.id); + + let status = (|| { + if let Some(tool_result) = tool_result { + return if tool_result.is_error { + ToolUseStatus::Error(tool_result.content.clone().into()) + } else { + ToolUseStatus::Finished(tool_result.content.clone().into()) + }; + } + + if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) { + return match pending_tool_use.status { + PendingToolUseStatus::Idle => ToolUseStatus::Pending, + PendingToolUseStatus::Running { .. } => ToolUseStatus::Running, + PendingToolUseStatus::Error(ref err) => { + ToolUseStatus::Error(err.clone().into()) + } + }; + } + + ToolUseStatus::Pending + })(); + + tool_uses.push(ToolUse { + id: tool_use.id.clone(), + name: tool_use.name.clone().into(), + input: tool_use.input.clone(), + status, + }) + } + + tool_uses + } + pub fn insert_user_message( &mut self, text: impl Into, @@ -537,6 +608,7 @@ impl Thread { content: output, is_error: false, }); + thread.pending_tool_uses_by_id.remove(&tool_use_id); cx.emit(ThreadEvent::ToolFinished { tool_use_id }); }