@@ -10,6 +10,7 @@ use gpui::{
use language::{Buffer, LanguageRegistry};
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
use markdown::{Markdown, MarkdownStyle};
+use scripting_tool::{ScriptingTool, ScriptingToolInput};
use settings::Settings as _;
use theme::ThemeSettings;
use ui::{prelude::*, Disclosure, KeyBinding};
@@ -28,6 +29,7 @@ pub struct ActiveThread {
messages: Vec<MessageId>,
list_state: ListState,
rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
+ rendered_scripting_tool_uses: HashMap<LanguageModelToolUseId, Entity<Markdown>>,
editing_message: Option<(MessageId, EditMessageState)>,
expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
last_error: Option<ThreadError>,
@@ -58,6 +60,7 @@ impl ActiveThread {
save_thread_task: None,
messages: Vec::new(),
rendered_messages_by_id: HashMap::default(),
+ rendered_scripting_tool_uses: HashMap::default(),
expanded_tool_uses: HashMap::default(),
list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
let this = cx.entity().downgrade();
@@ -296,7 +299,27 @@ impl ActiveThread {
thread.use_pending_tools(cx);
});
}
- ThreadEvent::ToolFinished { .. } => {
+ ThreadEvent::ToolFinished {
+ pending_tool_use, ..
+ } => {
+ if let Some(tool_use) = pending_tool_use {
+ if tool_use.name.as_ref() == ScriptingTool::NAME {
+ let lua_script =
+ serde_json::from_value::<ScriptingToolInput>(tool_use.input.clone())
+ .map(|input| input.lua_script)
+ .unwrap_or_default();
+
+ let lua_script = self.render_markdown(
+ format!("```lua\n{lua_script}\n```").into(),
+ window,
+ cx,
+ );
+
+ self.rendered_scripting_tool_uses
+ .insert(tool_use.id.clone(), lua_script.clone());
+ }
+ }
+
if self.thread.read(cx).all_tools_finished() {
let model_registry = LanguageModelRegistry::read_global(cx);
if let Some(model) = model_registry.active_model() {
@@ -306,14 +329,6 @@ impl ActiveThread {
}
}
}
- ThreadEvent::ScriptFinished => {
- let model_registry = LanguageModelRegistry::read_global(cx);
- if let Some(model) = model_registry.active_model() {
- self.thread.update(cx, |thread, cx| {
- thread.send_to_model(model, RequestKind::Chat, false, cx);
- });
- }
- }
}
}
@@ -662,8 +677,13 @@ impl ActiveThread {
.pl_1()
.pr_2()
.bg(cx.theme().colors().editor_foreground.opacity(0.02))
- .when(is_open, |element| element.border_b_1().rounded_t(px(6.)))
- .when(!is_open, |element| element.rounded_md())
+ .map(|element| {
+ if is_open {
+ element.border_b_1().rounded_t(px(6.))
+ } else {
+ element.rounded_md()
+ }
+ })
.border_color(cx.theme().colors().border)
.child(
h_flex()
@@ -743,8 +763,106 @@ impl ActiveThread {
tool_use: ToolUse,
cx: &mut Context<Self>,
) -> impl IntoElement {
- // TODO: Add custom rendering for scripting tool uses.
- self.render_tool_use(tool_use, cx)
+ let is_open = self
+ .expanded_tool_uses
+ .get(&tool_use.id)
+ .copied()
+ .unwrap_or_default();
+
+ div().px_2p5().child(
+ v_flex()
+ .gap_1()
+ .rounded_lg()
+ .border_1()
+ .border_color(cx.theme().colors().border)
+ .child(
+ h_flex()
+ .justify_between()
+ .py_0p5()
+ .pl_1()
+ .pr_2()
+ .bg(cx.theme().colors().editor_foreground.opacity(0.02))
+ .map(|element| {
+ if is_open {
+ element.border_b_1().rounded_t(px(6.))
+ } else {
+ element.rounded_md()
+ }
+ })
+ .border_color(cx.theme().colors().border)
+ .child(
+ h_flex()
+ .gap_1()
+ .child(Disclosure::new("tool-use-disclosure", is_open).on_click(
+ cx.listener({
+ let tool_use_id = tool_use.id.clone();
+ move |this, _event, _window, _cx| {
+ let is_open = this
+ .expanded_tool_uses
+ .entry(tool_use_id.clone())
+ .or_insert(false);
+
+ *is_open = !*is_open;
+ }
+ }),
+ ))
+ .child(Label::new(tool_use.name)),
+ )
+ .child(
+ Label::new(match tool_use.status {
+ ToolUseStatus::Pending => "Pending",
+ ToolUseStatus::Running => "Running",
+ ToolUseStatus::Finished(_) => "Finished",
+ ToolUseStatus::Error(_) => "Error",
+ })
+ .size(LabelSize::XSmall)
+ .buffer_font(cx),
+ ),
+ )
+ .map(|parent| {
+ if !is_open {
+ return parent;
+ }
+
+ let lua_script_markdown =
+ self.rendered_scripting_tool_uses.get(&tool_use.id).cloned();
+
+ parent.child(
+ v_flex()
+ .child(
+ v_flex()
+ .gap_0p5()
+ .py_1()
+ .px_2p5()
+ .border_b_1()
+ .border_color(cx.theme().colors().border)
+ .child(Label::new("Input:"))
+ .children(lua_script_markdown.map(|lua_script| {
+ div().p_2p5().text_ui(cx).child(lua_script)
+ })),
+ )
+ .map(|parent| match tool_use.status {
+ ToolUseStatus::Finished(output) => parent.child(
+ v_flex()
+ .gap_0p5()
+ .py_1()
+ .px_2p5()
+ .child(Label::new("Result:"))
+ .child(Label::new(output)),
+ ),
+ ToolUseStatus::Error(err) => parent.child(
+ v_flex()
+ .gap_0p5()
+ .py_1()
+ .px_2p5()
+ .child(Label::new("Error:"))
+ .child(Label::new(err)),
+ ),
+ ToolUseStatus::Pending | ToolUseStatus::Running => parent,
+ }),
+ )
+ }),
+ )
}
}
@@ -20,7 +20,7 @@ use uuid::Uuid;
use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
use crate::thread_store::SavedThread;
-use crate::tool_use::{ToolUse, ToolUseState};
+use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
@@ -653,11 +653,14 @@ impl Thread {
let output = output.await;
thread
.update(&mut cx, |thread, cx| {
- thread
+ let pending_tool_use = thread
.tool_use
.insert_tool_output(tool_use_id.clone(), output);
- cx.emit(ThreadEvent::ToolFinished { tool_use_id });
+ cx.emit(ThreadEvent::ToolFinished {
+ tool_use_id,
+ pending_tool_use,
+ });
})
.ok();
}
@@ -679,11 +682,14 @@ impl Thread {
let output = output.await;
thread
.update(&mut cx, |thread, cx| {
- thread
+ let pending_tool_use = thread
.scripting_tool_use
.insert_tool_output(tool_use_id.clone(), output);
- cx.emit(ThreadEvent::ToolFinished { tool_use_id });
+ cx.emit(ThreadEvent::ToolFinished {
+ tool_use_id,
+ pending_tool_use,
+ });
})
.ok();
}
@@ -742,8 +748,9 @@ pub enum ThreadEvent {
ToolFinished {
#[allow(unused)]
tool_use_id: LanguageModelToolUseId,
+ /// The pending tool use that corresponds to this tool.
+ pending_tool_use: Option<PendingToolUse>,
},
- ScriptFinished,
}
impl EventEmitter<ThreadEvent> for Thread {}
@@ -202,7 +202,7 @@ impl ToolUseState {
&mut self,
tool_use_id: LanguageModelToolUseId,
output: Result<String>,
- ) {
+ ) -> Option<PendingToolUse> {
match output {
Ok(output) => {
self.tool_results.insert(
@@ -213,7 +213,7 @@ impl ToolUseState {
is_error: false,
},
);
- self.pending_tool_uses_by_id.remove(&tool_use_id);
+ self.pending_tool_uses_by_id.remove(&tool_use_id)
}
Err(err) => {
self.tool_results.insert(
@@ -228,6 +228,8 @@ impl ToolUseState {
if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
}
+
+ self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
}
}
}