assistant2: Improve Lua script rendering (#26389)

Marshall Bowers created

This PR improves the rendering of Lua scripts provided to the scripting
tool.

We now render them in code blocks with syntax highlighting:

<img width="1297" alt="Screenshot 2025-03-10 at 2 40 51 PM"
src="https://github.com/user-attachments/assets/def65b5c-86a8-490f-aaa5-5cc1687fe01e"
/>

Release Notes:

- N/A

Change summary

crates/assistant2/src/active_thread.rs      | 144 ++++++++++++++++++++--
crates/assistant2/src/thread.rs             |  19 ++
crates/assistant2/src/tool_use.rs           |   6 
crates/scripting_tool/src/scripting_tool.rs |   4 
4 files changed, 150 insertions(+), 23 deletions(-)

Detailed changes

crates/assistant2/src/active_thread.rs 🔗

@@ -10,6 +10,7 @@ use gpui::{
 use language::{Buffer, LanguageRegistry};
 use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
 use markdown::{Markdown, MarkdownStyle};
+use scripting_tool::{ScriptingTool, ScriptingToolInput};
 use settings::Settings as _;
 use theme::ThemeSettings;
 use ui::{prelude::*, Disclosure, KeyBinding};
@@ -28,6 +29,7 @@ pub struct ActiveThread {
     messages: Vec<MessageId>,
     list_state: ListState,
     rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
+    rendered_scripting_tool_uses: HashMap<LanguageModelToolUseId, Entity<Markdown>>,
     editing_message: Option<(MessageId, EditMessageState)>,
     expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
     last_error: Option<ThreadError>,
@@ -58,6 +60,7 @@ impl ActiveThread {
             save_thread_task: None,
             messages: Vec::new(),
             rendered_messages_by_id: HashMap::default(),
+            rendered_scripting_tool_uses: HashMap::default(),
             expanded_tool_uses: HashMap::default(),
             list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
                 let this = cx.entity().downgrade();
@@ -296,7 +299,27 @@ impl ActiveThread {
                     thread.use_pending_tools(cx);
                 });
             }
-            ThreadEvent::ToolFinished { .. } => {
+            ThreadEvent::ToolFinished {
+                pending_tool_use, ..
+            } => {
+                if let Some(tool_use) = pending_tool_use {
+                    if tool_use.name.as_ref() == ScriptingTool::NAME {
+                        let lua_script =
+                            serde_json::from_value::<ScriptingToolInput>(tool_use.input.clone())
+                                .map(|input| input.lua_script)
+                                .unwrap_or_default();
+
+                        let lua_script = self.render_markdown(
+                            format!("```lua\n{lua_script}\n```").into(),
+                            window,
+                            cx,
+                        );
+
+                        self.rendered_scripting_tool_uses
+                            .insert(tool_use.id.clone(), lua_script.clone());
+                    }
+                }
+
                 if self.thread.read(cx).all_tools_finished() {
                     let model_registry = LanguageModelRegistry::read_global(cx);
                     if let Some(model) = model_registry.active_model() {
@@ -306,14 +329,6 @@ impl ActiveThread {
                     }
                 }
             }
-            ThreadEvent::ScriptFinished => {
-                let model_registry = LanguageModelRegistry::read_global(cx);
-                if let Some(model) = model_registry.active_model() {
-                    self.thread.update(cx, |thread, cx| {
-                        thread.send_to_model(model, RequestKind::Chat, false, cx);
-                    });
-                }
-            }
         }
     }
 
@@ -662,8 +677,13 @@ impl ActiveThread {
                         .pl_1()
                         .pr_2()
                         .bg(cx.theme().colors().editor_foreground.opacity(0.02))
-                        .when(is_open, |element| element.border_b_1().rounded_t(px(6.)))
-                        .when(!is_open, |element| element.rounded_md())
+                        .map(|element| {
+                            if is_open {
+                                element.border_b_1().rounded_t(px(6.))
+                            } else {
+                                element.rounded_md()
+                            }
+                        })
                         .border_color(cx.theme().colors().border)
                         .child(
                             h_flex()
@@ -743,8 +763,106 @@ impl ActiveThread {
         tool_use: ToolUse,
         cx: &mut Context<Self>,
     ) -> impl IntoElement {
-        // TODO: Add custom rendering for scripting tool uses.
-        self.render_tool_use(tool_use, cx)
+        let is_open = self
+            .expanded_tool_uses
+            .get(&tool_use.id)
+            .copied()
+            .unwrap_or_default();
+
+        div().px_2p5().child(
+            v_flex()
+                .gap_1()
+                .rounded_lg()
+                .border_1()
+                .border_color(cx.theme().colors().border)
+                .child(
+                    h_flex()
+                        .justify_between()
+                        .py_0p5()
+                        .pl_1()
+                        .pr_2()
+                        .bg(cx.theme().colors().editor_foreground.opacity(0.02))
+                        .map(|element| {
+                            if is_open {
+                                element.border_b_1().rounded_t(px(6.))
+                            } else {
+                                element.rounded_md()
+                            }
+                        })
+                        .border_color(cx.theme().colors().border)
+                        .child(
+                            h_flex()
+                                .gap_1()
+                                .child(Disclosure::new("tool-use-disclosure", is_open).on_click(
+                                    cx.listener({
+                                        let tool_use_id = tool_use.id.clone();
+                                        move |this, _event, _window, _cx| {
+                                            let is_open = this
+                                                .expanded_tool_uses
+                                                .entry(tool_use_id.clone())
+                                                .or_insert(false);
+
+                                            *is_open = !*is_open;
+                                        }
+                                    }),
+                                ))
+                                .child(Label::new(tool_use.name)),
+                        )
+                        .child(
+                            Label::new(match tool_use.status {
+                                ToolUseStatus::Pending => "Pending",
+                                ToolUseStatus::Running => "Running",
+                                ToolUseStatus::Finished(_) => "Finished",
+                                ToolUseStatus::Error(_) => "Error",
+                            })
+                            .size(LabelSize::XSmall)
+                            .buffer_font(cx),
+                        ),
+                )
+                .map(|parent| {
+                    if !is_open {
+                        return parent;
+                    }
+
+                    let lua_script_markdown =
+                        self.rendered_scripting_tool_uses.get(&tool_use.id).cloned();
+
+                    parent.child(
+                        v_flex()
+                            .child(
+                                v_flex()
+                                    .gap_0p5()
+                                    .py_1()
+                                    .px_2p5()
+                                    .border_b_1()
+                                    .border_color(cx.theme().colors().border)
+                                    .child(Label::new("Input:"))
+                                    .children(lua_script_markdown.map(|lua_script| {
+                                        div().p_2p5().text_ui(cx).child(lua_script)
+                                    })),
+                            )
+                            .map(|parent| match tool_use.status {
+                                ToolUseStatus::Finished(output) => parent.child(
+                                    v_flex()
+                                        .gap_0p5()
+                                        .py_1()
+                                        .px_2p5()
+                                        .child(Label::new("Result:"))
+                                        .child(Label::new(output)),
+                                ),
+                                ToolUseStatus::Error(err) => parent.child(
+                                    v_flex()
+                                        .gap_0p5()
+                                        .py_1()
+                                        .px_2p5()
+                                        .child(Label::new("Error:"))
+                                        .child(Label::new(err)),
+                                ),
+                                ToolUseStatus::Pending | ToolUseStatus::Running => parent,
+                            }),
+                    )
+                }),
+        )
     }
 }
 

crates/assistant2/src/thread.rs 🔗

@@ -20,7 +20,7 @@ use uuid::Uuid;
 
 use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
 use crate::thread_store::SavedThread;
-use crate::tool_use::{ToolUse, ToolUseState};
+use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
 
 #[derive(Debug, Clone, Copy)]
 pub enum RequestKind {
@@ -653,11 +653,14 @@ impl Thread {
                 let output = output.await;
                 thread
                     .update(&mut cx, |thread, cx| {
-                        thread
+                        let pending_tool_use = thread
                             .tool_use
                             .insert_tool_output(tool_use_id.clone(), output);
 
-                        cx.emit(ThreadEvent::ToolFinished { tool_use_id });
+                        cx.emit(ThreadEvent::ToolFinished {
+                            tool_use_id,
+                            pending_tool_use,
+                        });
                     })
                     .ok();
             }
@@ -679,11 +682,14 @@ impl Thread {
                 let output = output.await;
                 thread
                     .update(&mut cx, |thread, cx| {
-                        thread
+                        let pending_tool_use = thread
                             .scripting_tool_use
                             .insert_tool_output(tool_use_id.clone(), output);
 
-                        cx.emit(ThreadEvent::ToolFinished { tool_use_id });
+                        cx.emit(ThreadEvent::ToolFinished {
+                            tool_use_id,
+                            pending_tool_use,
+                        });
                     })
                     .ok();
             }
@@ -742,8 +748,9 @@ pub enum ThreadEvent {
     ToolFinished {
         #[allow(unused)]
         tool_use_id: LanguageModelToolUseId,
+        /// The pending tool use that corresponds to this tool.
+        pending_tool_use: Option<PendingToolUse>,
     },
-    ScriptFinished,
 }
 
 impl EventEmitter<ThreadEvent> for Thread {}

crates/assistant2/src/tool_use.rs 🔗

@@ -202,7 +202,7 @@ impl ToolUseState {
         &mut self,
         tool_use_id: LanguageModelToolUseId,
         output: Result<String>,
-    ) {
+    ) -> Option<PendingToolUse> {
         match output {
             Ok(output) => {
                 self.tool_results.insert(
@@ -213,7 +213,7 @@ impl ToolUseState {
                         is_error: false,
                     },
                 );
-                self.pending_tool_uses_by_id.remove(&tool_use_id);
+                self.pending_tool_uses_by_id.remove(&tool_use_id)
             }
             Err(err) => {
                 self.tool_results.insert(
@@ -228,6 +228,8 @@ impl ToolUseState {
                 if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
                     tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
                 }
+
+                self.pending_tool_uses_by_id.get(&tool_use_id).cloned()
             }
         }
     }

crates/scripting_tool/src/scripting_tool.rs 🔗

@@ -8,8 +8,8 @@ use schemars::JsonSchema;
 use serde::Deserialize;
 
 #[derive(Debug, Deserialize, JsonSchema)]
-struct ScriptingToolInput {
-    lua_script: String,
+pub struct ScriptingToolInput {
+    pub lua_script: String,
 }
 
 pub struct ScriptingTool;