diff --git a/crates/assistant2/src/active_thread.rs b/crates/assistant2/src/active_thread.rs index 0e8471a4a2c15d0f0db4cb120265e3d5c8a72e4e..15fc434eebf325e8dd9edd74fb85d4df3340fa42 100644 --- a/crates/assistant2/src/active_thread.rs +++ b/crates/assistant2/src/active_thread.rs @@ -10,6 +10,7 @@ use gpui::{ use language::{Buffer, LanguageRegistry}; use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role}; use markdown::{Markdown, MarkdownStyle}; +use scripting_tool::{ScriptingTool, ScriptingToolInput}; use settings::Settings as _; use theme::ThemeSettings; use ui::{prelude::*, Disclosure, KeyBinding}; @@ -28,6 +29,7 @@ pub struct ActiveThread { messages: Vec, list_state: ListState, rendered_messages_by_id: HashMap>, + rendered_scripting_tool_uses: HashMap>, editing_message: Option<(MessageId, EditMessageState)>, expanded_tool_uses: HashMap, last_error: Option, @@ -58,6 +60,7 @@ impl ActiveThread { save_thread_task: None, messages: Vec::new(), rendered_messages_by_id: HashMap::default(), + rendered_scripting_tool_uses: HashMap::default(), expanded_tool_uses: HashMap::default(), list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), { let this = cx.entity().downgrade(); @@ -296,7 +299,27 @@ impl ActiveThread { thread.use_pending_tools(cx); }); } - ThreadEvent::ToolFinished { .. } => { + ThreadEvent::ToolFinished { + pending_tool_use, .. + } => { + if let Some(tool_use) = pending_tool_use { + if tool_use.name.as_ref() == ScriptingTool::NAME { + let lua_script = + serde_json::from_value::(tool_use.input.clone()) + .map(|input| input.lua_script) + .unwrap_or_default(); + + let lua_script = self.render_markdown( + format!("```lua\n{lua_script}\n```").into(), + window, + cx, + ); + + self.rendered_scripting_tool_uses + .insert(tool_use.id.clone(), lua_script.clone()); + } + } + if self.thread.read(cx).all_tools_finished() { let model_registry = LanguageModelRegistry::read_global(cx); if let Some(model) = model_registry.active_model() { @@ -306,14 +329,6 @@ impl ActiveThread { } } } - ThreadEvent::ScriptFinished => { - let model_registry = LanguageModelRegistry::read_global(cx); - if let Some(model) = model_registry.active_model() { - self.thread.update(cx, |thread, cx| { - thread.send_to_model(model, RequestKind::Chat, false, cx); - }); - } - } } } @@ -662,8 +677,13 @@ impl ActiveThread { .pl_1() .pr_2() .bg(cx.theme().colors().editor_foreground.opacity(0.02)) - .when(is_open, |element| element.border_b_1().rounded_t(px(6.))) - .when(!is_open, |element| element.rounded_md()) + .map(|element| { + if is_open { + element.border_b_1().rounded_t(px(6.)) + } else { + element.rounded_md() + } + }) .border_color(cx.theme().colors().border) .child( h_flex() @@ -743,8 +763,106 @@ impl ActiveThread { tool_use: ToolUse, cx: &mut Context, ) -> impl IntoElement { - // TODO: Add custom rendering for scripting tool uses. - self.render_tool_use(tool_use, cx) + let is_open = self + .expanded_tool_uses + .get(&tool_use.id) + .copied() + .unwrap_or_default(); + + div().px_2p5().child( + v_flex() + .gap_1() + .rounded_lg() + .border_1() + .border_color(cx.theme().colors().border) + .child( + h_flex() + .justify_between() + .py_0p5() + .pl_1() + .pr_2() + .bg(cx.theme().colors().editor_foreground.opacity(0.02)) + .map(|element| { + if is_open { + element.border_b_1().rounded_t(px(6.)) + } else { + element.rounded_md() + } + }) + .border_color(cx.theme().colors().border) + .child( + h_flex() + .gap_1() + .child(Disclosure::new("tool-use-disclosure", is_open).on_click( + cx.listener({ + let tool_use_id = tool_use.id.clone(); + move |this, _event, _window, _cx| { + let is_open = this + .expanded_tool_uses + .entry(tool_use_id.clone()) + .or_insert(false); + + *is_open = !*is_open; + } + }), + )) + .child(Label::new(tool_use.name)), + ) + .child( + Label::new(match tool_use.status { + ToolUseStatus::Pending => "Pending", + ToolUseStatus::Running => "Running", + ToolUseStatus::Finished(_) => "Finished", + ToolUseStatus::Error(_) => "Error", + }) + .size(LabelSize::XSmall) + .buffer_font(cx), + ), + ) + .map(|parent| { + if !is_open { + return parent; + } + + let lua_script_markdown = + self.rendered_scripting_tool_uses.get(&tool_use.id).cloned(); + + parent.child( + v_flex() + .child( + v_flex() + .gap_0p5() + .py_1() + .px_2p5() + .border_b_1() + .border_color(cx.theme().colors().border) + .child(Label::new("Input:")) + .children(lua_script_markdown.map(|lua_script| { + div().p_2p5().text_ui(cx).child(lua_script) + })), + ) + .map(|parent| match tool_use.status { + ToolUseStatus::Finished(output) => parent.child( + v_flex() + .gap_0p5() + .py_1() + .px_2p5() + .child(Label::new("Result:")) + .child(Label::new(output)), + ), + ToolUseStatus::Error(err) => parent.child( + v_flex() + .gap_0p5() + .py_1() + .px_2p5() + .child(Label::new("Error:")) + .child(Label::new(err)), + ), + ToolUseStatus::Pending | ToolUseStatus::Running => parent, + }), + ) + }), + ) } } diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index d4b01b987841decd369fb068916f840735428f65..2f1425fccc9053eb58cd8bc5a15e9b9e71a9b468 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -20,7 +20,7 @@ use uuid::Uuid; use crate::context::{attach_context_to_message, ContextId, ContextSnapshot}; use crate::thread_store::SavedThread; -use crate::tool_use::{ToolUse, ToolUseState}; +use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState}; #[derive(Debug, Clone, Copy)] pub enum RequestKind { @@ -653,11 +653,14 @@ impl Thread { let output = output.await; thread .update(&mut cx, |thread, cx| { - thread + let pending_tool_use = thread .tool_use .insert_tool_output(tool_use_id.clone(), output); - cx.emit(ThreadEvent::ToolFinished { tool_use_id }); + cx.emit(ThreadEvent::ToolFinished { + tool_use_id, + pending_tool_use, + }); }) .ok(); } @@ -679,11 +682,14 @@ impl Thread { let output = output.await; thread .update(&mut cx, |thread, cx| { - thread + let pending_tool_use = thread .scripting_tool_use .insert_tool_output(tool_use_id.clone(), output); - cx.emit(ThreadEvent::ToolFinished { tool_use_id }); + cx.emit(ThreadEvent::ToolFinished { + tool_use_id, + pending_tool_use, + }); }) .ok(); } @@ -742,8 +748,9 @@ pub enum ThreadEvent { ToolFinished { #[allow(unused)] tool_use_id: LanguageModelToolUseId, + /// The pending tool use that corresponds to this tool. + pending_tool_use: Option, }, - ScriptFinished, } impl EventEmitter for Thread {} diff --git a/crates/assistant2/src/tool_use.rs b/crates/assistant2/src/tool_use.rs index 4161797dc2d7aab174caac37eeb61b0129d00477..565728ef3fbdfff20bb18bed415cb3e7cf794e7e 100644 --- a/crates/assistant2/src/tool_use.rs +++ b/crates/assistant2/src/tool_use.rs @@ -202,7 +202,7 @@ impl ToolUseState { &mut self, tool_use_id: LanguageModelToolUseId, output: Result, - ) { + ) -> Option { match output { Ok(output) => { self.tool_results.insert( @@ -213,7 +213,7 @@ impl ToolUseState { is_error: false, }, ); - self.pending_tool_uses_by_id.remove(&tool_use_id); + self.pending_tool_uses_by_id.remove(&tool_use_id) } Err(err) => { self.tool_results.insert( @@ -228,6 +228,8 @@ impl ToolUseState { if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) { tool_use.status = PendingToolUseStatus::Error(err.to_string().into()); } + + self.pending_tool_uses_by_id.get(&tool_use_id).cloned() } } } diff --git a/crates/scripting_tool/src/scripting_tool.rs b/crates/scripting_tool/src/scripting_tool.rs index 885240c9a6b8547264303d057baec4b1c41af3d1..efb5c8855518351a118130d07b48a77b2db9de6f 100644 --- a/crates/scripting_tool/src/scripting_tool.rs +++ b/crates/scripting_tool/src/scripting_tool.rs @@ -8,8 +8,8 @@ use schemars::JsonSchema; use serde::Deserialize; #[derive(Debug, Deserialize, JsonSchema)] -struct ScriptingToolInput { - lua_script: String, +pub struct ScriptingToolInput { + pub lua_script: String, } pub struct ScriptingTool;