Render messages as early as possible to show progress (#11569)

Kyle Kelley created

This shows "Researching..." as placeholder text as early as possible so
that the user can see the model is working on reading/researching/etc.

This also adds on an `Option<Value>` to the `render_running` function so
that tools can hopefully render based on partially completed JSON (still
to come).

Release Notes:

- N/A

Change summary

Cargo.lock                                        |  1 
crates/assistant2/src/assistant2.rs               |  7 ++
crates/assistant2/src/tools/project_index.rs      | 13 ++++-
crates/assistant_tooling/Cargo.toml               |  1 
crates/assistant_tooling/src/assistant_tooling.rs |  3 
crates/assistant_tooling/src/tool_registry.rs     | 36 ++++++++++++----
6 files changed, 46 insertions(+), 15 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -423,6 +423,7 @@ dependencies = [
  "serde_json",
  "settings",
  "sum_tree",
+ "ui",
  "unindent",
  "util",
 ]

crates/assistant2/src/assistant2.rs 🔗

@@ -16,7 +16,8 @@ use crate::{
 use ::ui::{div, prelude::*, Color, Tooltip, ViewContext};
 use anyhow::{Context, Result};
 use assistant_tooling::{
-    AttachmentRegistry, ProjectContext, ToolFunctionCall, ToolRegistry, UserAttachment,
+    tool_running_placeholder, AttachmentRegistry, ProjectContext, ToolFunctionCall, ToolRegistry,
+    UserAttachment,
 };
 use client::{proto, Client, UserStore};
 use collections::HashMap;
@@ -864,6 +865,10 @@ impl AssistantChat {
                     }
                 }
 
+                if message_elements.is_empty() {
+                    message_elements.push(tool_running_placeholder());
+                }
+
                 div()
                     .when(is_first, |this| this.pt(padding))
                     .child(

crates/assistant2/src/tools/project_index.rs 🔗

@@ -6,6 +6,7 @@ use project::ProjectPath;
 use schemars::JsonSchema;
 use semantic_index::{ProjectIndex, Status};
 use serde::Deserialize;
+use serde_json::Value;
 use std::{fmt::Write as _, ops::Range};
 use ui::{div, prelude::*, CollapsibleContainer, Color, Icon, IconName, Label, WindowContext};
 
@@ -202,8 +203,14 @@ impl LanguageModelTool for ProjectIndexTool {
         cx.new_view(|_cx| ProjectIndexView::new(input, output))
     }
 
-    fn render_running(_: &mut WindowContext) -> impl IntoElement {
-        CollapsibleContainer::new(ElementId::Name(nanoid::nanoid!().into()), false)
-            .start_slot("Searching code base")
+    fn render_running(arguments: &Option<Value>, _: &mut WindowContext) -> impl IntoElement {
+        let text: String = arguments
+            .as_ref()
+            .and_then(|arguments| arguments.get("query"))
+            .and_then(|query| query.as_str())
+            .map(|query| format!("Searching for: {}", query))
+            .unwrap_or_else(|| "Preparing search...".to_string());
+
+        CollapsibleContainer::new(ElementId::Name(nanoid::nanoid!().into()), false).start_slot(text)
     }
 }

crates/assistant_tooling/Cargo.toml 🔗

@@ -21,6 +21,7 @@ schemars.workspace = true
 serde.workspace = true
 serde_json.workspace = true
 sum_tree.workspace = true
+ui.workspace = true
 util.workspace = true
 
 [dev-dependencies]

crates/assistant_tooling/src/assistant_tooling.rs 🔗

@@ -5,5 +5,6 @@ mod tool_registry;
 pub use attachment_registry::{AttachmentRegistry, LanguageModelAttachment, UserAttachment};
 pub use project_context::ProjectContext;
 pub use tool_registry::{
-    LanguageModelTool, ToolFunctionCall, ToolFunctionDefinition, ToolOutput, ToolRegistry,
+    tool_running_placeholder, LanguageModelTool, ToolFunctionCall, ToolFunctionDefinition,
+    ToolOutput, ToolRegistry,
 };

crates/assistant_tooling/src/tool_registry.rs 🔗

@@ -4,6 +4,7 @@ use gpui::{
 };
 use schemars::{schema::RootSchema, schema_for, JsonSchema};
 use serde::Deserialize;
+use serde_json::Value;
 use std::{
     any::TypeId,
     collections::HashMap,
@@ -78,17 +79,22 @@ pub trait LanguageModelTool {
     /// Executes the tool with the given input.
     fn execute(&self, input: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>>;
 
+    /// A view of the output of running the tool, for displaying to the user.
     fn output_view(
         input: Self::Input,
         output: Result<Self::Output>,
         cx: &mut WindowContext,
     ) -> View<Self::View>;
 
-    fn render_running(_cx: &mut WindowContext) -> impl IntoElement {
-        div()
+    fn render_running(_arguments: &Option<Value>, _cx: &mut WindowContext) -> impl IntoElement {
+        tool_running_placeholder()
     }
 }
 
+pub fn tool_running_placeholder() -> AnyElement {
+    ui::Label::new("Researching...").into_any_element()
+}
+
 pub trait ToolOutput: Sized {
     fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
 }
@@ -97,7 +103,7 @@ struct RegisteredTool {
     enabled: AtomicBool,
     type_id: TypeId,
     call: Box<dyn Fn(&ToolFunctionCall, &mut WindowContext) -> Task<Result<ToolFunctionCall>>>,
-    render_running: fn(&mut WindowContext) -> gpui::AnyElement,
+    render_running: fn(&ToolFunctionCall, &mut WindowContext) -> gpui::AnyElement,
     definition: ToolFunctionDefinition,
 }
 
@@ -144,11 +150,15 @@ impl ToolRegistry {
                 .p_2()
                 .child(result.into_any_element(&tool_call.name))
                 .into_any_element(),
-            None => self
-                .registered_tools
-                .get(&tool_call.name)
-                .map(|tool| (tool.render_running)(cx))
-                .unwrap_or_else(|| div().into_any_element()),
+            None => {
+                let tool = self.registered_tools.get(&tool_call.name);
+
+                if let Some(tool) = tool {
+                    (tool.render_running)(&tool_call, cx)
+                } else {
+                    tool_running_placeholder()
+                }
+            }
         }
     }
 
@@ -205,8 +215,14 @@ impl ToolRegistry {
 
         return Ok(());
 
-        fn render_running<T: LanguageModelTool>(cx: &mut WindowContext) -> AnyElement {
-            T::render_running(cx).into_any_element()
+        fn render_running<T: LanguageModelTool>(
+            tool_call: &ToolFunctionCall,
+            cx: &mut WindowContext,
+        ) -> AnyElement {
+            // Attempt to parse the string arguments that are JSON as a JSON value
+            let maybe_arguments = serde_json::to_value(tool_call.arguments.clone()).ok();
+
+            T::render_running(&maybe_arguments, cx).into_any_element()
         }
 
         fn generate<T: LanguageModelTool>(