Streaming tool calls (#29179)

Created by Richard Feldman and Marshall Bowers

https://github.com/user-attachments/assets/7854a737-ef83-414c-b397-45122e4f32e8



Release Notes:

- Create file and edit file tools now stream their tool descriptions, so
you can see what they're doing sooner.

---------

Co-authored-by: Marshall Bowers <git@maxdeviant.com>

Change summary

Cargo.lock                                          |  7 +
Cargo.toml                                          |  1 
crates/agent/src/active_thread.rs                   | 72 +++++++----
crates/agent/src/thread.rs                          | 22 +++
crates/agent/src/tool_use.rs                        | 81 +++++++++--
crates/assistant_tool/src/assistant_tool.rs         |  8 +
crates/assistant_tools/src/create_file_tool.rs      | 76 +++++++++++
crates/assistant_tools/src/edit_file_tool.rs        | 96 +++++++++++++++
crates/eval/src/example.rs                          |  1 
crates/language_model/src/language_model.rs         |  1 
crates/language_models/Cargo.toml                   |  1 
crates/language_models/src/provider/anthropic.rs    | 30 ++++
crates/language_models/src/provider/bedrock.rs      |  1 
crates/language_models/src/provider/copilot_chat.rs |  1 
crates/language_models/src/provider/google.rs       |  1 
crates/language_models/src/provider/open_ai.rs      |  1 
crates/markdown/src/markdown.rs                     |  5 
17 files changed, 358 insertions(+), 47 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -7713,6 +7713,7 @@ dependencies = [
  "mistral",
  "ollama",
  "open_ai",
+ "partial-json-fixer",
  "project",
  "proto",
  "schemars",
@@ -9828,6 +9829,12 @@ dependencies = [
  "windows-targets 0.52.6",
 ]
 
+[[package]]
+name = "partial-json-fixer"
+version = "0.5.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "35ffd90b3f3b6477db7478016b9efb1b7e9d38eafd095f0542fe0ec2ea884a13"
+
 [[package]]
 name = "password-hash"
 version = "0.4.2"

Cargo.toml 🔗

@@ -480,6 +480,7 @@ num-format = "0.4.4"
 ordered-float = "2.1.1"
 palette = { version = "0.7.5", default-features = false, features = ["std"] }
 parking_lot = "0.12.1"
+partial-json-fixer = "0.5.3"
 pathdiff = "0.2"
 pet = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" }
 pet-fs = { git = "https://github.com/microsoft/python-environment-tools.git", rev = "845945b830297a50de0e24020b980a65e4820559" }

crates/agent/src/active_thread.rs 🔗

@@ -266,14 +266,6 @@ fn default_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
     }
 }
 
-fn render_tool_use_markdown(
-    text: SharedString,
-    language_registry: Arc<LanguageRegistry>,
-    cx: &mut App,
-) -> Entity<Markdown> {
-    cx.new(|cx| Markdown::new(text, Some(language_registry), None, cx))
-}
-
 fn tool_use_markdown_style(window: &Window, cx: &mut App) -> MarkdownStyle {
     let theme_settings = ThemeSettings::get_global(cx);
     let colors = cx.theme().colors();
@@ -867,21 +859,34 @@ impl ActiveThread {
         tool_output: SharedString,
         cx: &mut Context<Self>,
     ) {
-        let rendered = RenderedToolUse {
-            label: render_tool_use_markdown(tool_label.into(), self.language_registry.clone(), cx),
-            input: render_tool_use_markdown(
-                format!(
-                    "```json\n{}\n```",
-                    serde_json::to_string_pretty(tool_input).unwrap_or_default()
-                )
-                .into(),
-                self.language_registry.clone(),
-                cx,
-            ),
-            output: render_tool_use_markdown(tool_output, self.language_registry.clone(), cx),
-        };
-        self.rendered_tool_uses
-            .insert(tool_use_id.clone(), rendered);
+        let rendered = self
+            .rendered_tool_uses
+            .entry(tool_use_id.clone())
+            .or_insert_with(|| RenderedToolUse {
+                label: cx.new(|cx| {
+                    Markdown::new("".into(), Some(self.language_registry.clone()), None, cx)
+                }),
+                input: cx.new(|cx| {
+                    Markdown::new("".into(), Some(self.language_registry.clone()), None, cx)
+                }),
+                output: cx.new(|cx| {
+                    Markdown::new("".into(), Some(self.language_registry.clone()), None, cx)
+                }),
+            });
+
+        rendered.label.update(cx, |this, cx| {
+            this.replace(tool_label, cx);
+        });
+        rendered.input.update(cx, |this, cx| {
+            let input = format!(
+                "```json\n{}\n```",
+                serde_json::to_string_pretty(tool_input).unwrap_or_default()
+            );
+            this.replace(input, cx);
+        });
+        rendered.output.update(cx, |this, cx| {
+            this.replace(tool_output, cx);
+        });
     }
 
     fn handle_thread_event(
@@ -974,6 +979,19 @@ impl ActiveThread {
                     );
                 }
             }
+            ThreadEvent::StreamedToolUse {
+                tool_use_id,
+                ui_text,
+                input,
+            } => {
+                self.render_tool_use_markdown(
+                    tool_use_id.clone(),
+                    ui_text.clone(),
+                    input,
+                    "".into(),
+                    cx,
+                );
+            }
             ThreadEvent::ToolFinished {
                 pending_tool_use, ..
             } => {
@@ -2478,13 +2496,15 @@ impl ActiveThread {
         let edit_tools = tool_use.needs_confirmation;
 
         let status_icons = div().child(match &tool_use.status {
-            ToolUseStatus::Pending | ToolUseStatus::NeedsConfirmation => {
+            ToolUseStatus::NeedsConfirmation => {
                 let icon = Icon::new(IconName::Warning)
                     .color(Color::Warning)
                     .size(IconSize::Small);
                 icon.into_any_element()
             }
-            ToolUseStatus::Running => {
+            ToolUseStatus::Pending
+            | ToolUseStatus::InputStillStreaming
+            | ToolUseStatus::Running => {
                 let icon = Icon::new(IconName::ArrowCircle)
                     .color(Color::Accent)
                     .size(IconSize::Small);
@@ -2570,7 +2590,7 @@ impl ActiveThread {
                             }),
                         )),
                 ),
-                ToolUseStatus::Running => container.child(
+                ToolUseStatus::InputStillStreaming | ToolUseStatus::Running => container.child(
                     results_content_container().child(
                         h_flex()
                             .gap_1()

crates/agent/src/thread.rs 🔗

@@ -1293,12 +1293,27 @@ impl Thread {
                                         thread.insert_message(Role::Assistant, vec![], cx)
                                     });
 
-                                thread.tool_use.request_tool_use(
+                                let tool_use_id = tool_use.id.clone();
+                                let streamed_input = if tool_use.is_input_complete {
+                                    None
+                                } else {
+                                    Some((&tool_use.input).clone())
+                                };
+
+                                let ui_text = thread.tool_use.request_tool_use(
                                     last_assistant_message_id,
                                     tool_use,
                                     tool_use_metadata.clone(),
                                     cx,
                                 );
+
+                                if let Some(input) = streamed_input {
+                                    cx.emit(ThreadEvent::StreamedToolUse {
+                                        tool_use_id,
+                                        ui_text,
+                                        input,
+                                    });
+                                }
                             }
                         }
 
@@ -2189,6 +2204,11 @@ pub enum ThreadEvent {
     StreamedCompletion,
     StreamedAssistantText(MessageId, String),
     StreamedAssistantThinking(MessageId, String),
+    StreamedToolUse {
+        tool_use_id: LanguageModelToolUseId,
+        ui_text: Arc<str>,
+        input: serde_json::Value,
+    },
     Stopped(Result<StopReason, Arc<anyhow::Error>>),
     MessageAdded(MessageId),
     MessageEdited(MessageId),

crates/agent/src/tool_use.rs 🔗

@@ -75,6 +75,7 @@ impl ToolUseState {
                                 id: tool_use.id.clone(),
                                 name: tool_use.name.clone().into(),
                                 input: tool_use.input.clone(),
+                                is_input_complete: true,
                             })
                             .collect::<Vec<_>>();
 
@@ -176,6 +177,9 @@ impl ToolUseState {
                         PendingToolUseStatus::Error(ref err) => {
                             ToolUseStatus::Error(err.clone().into())
                         }
+                        PendingToolUseStatus::InputStillStreaming => {
+                            ToolUseStatus::InputStillStreaming
+                        }
                     }
                 } else {
                     ToolUseStatus::Pending
@@ -192,7 +196,12 @@ impl ToolUseState {
             tool_uses.push(ToolUse {
                 id: tool_use.id.clone(),
                 name: tool_use.name.clone().into(),
-                ui_text: self.tool_ui_label(&tool_use.name, &tool_use.input, cx),
+                ui_text: self.tool_ui_label(
+                    &tool_use.name,
+                    &tool_use.input,
+                    tool_use.is_input_complete,
+                    cx,
+                ),
                 input: tool_use.input.clone(),
                 status,
                 icon,
@@ -207,10 +216,15 @@ impl ToolUseState {
         &self,
         tool_name: &str,
         input: &serde_json::Value,
+        is_input_complete: bool,
         cx: &App,
     ) -> SharedString {
         if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
-            tool.ui_text(input).into()
+            if is_input_complete {
+                tool.ui_text(input).into()
+            } else {
+                tool.still_streaming_ui_text(input).into()
+            }
         } else {
             format!("Unknown tool {tool_name:?}").into()
         }
@@ -258,22 +272,50 @@ impl ToolUseState {
         tool_use: LanguageModelToolUse,
         metadata: ToolUseMetadata,
         cx: &App,
-    ) {
-        self.tool_uses_by_assistant_message
+    ) -> Arc<str> {
+        let tool_uses = self
+            .tool_uses_by_assistant_message
             .entry(assistant_message_id)
-            .or_default()
-            .push(tool_use.clone());
+            .or_default();
 
-        self.tool_use_metadata_by_id
-            .insert(tool_use.id.clone(), metadata);
+        let mut existing_tool_use_found = false;
 
-        // The tool use is being requested by the Assistant, so we want to
-        // attach the tool results to the next user message.
-        let next_user_message_id = MessageId(assistant_message_id.0 + 1);
-        self.tool_uses_by_user_message
-            .entry(next_user_message_id)
-            .or_default()
-            .push(tool_use.id.clone());
+        for existing_tool_use in tool_uses.iter_mut() {
+            if existing_tool_use.id == tool_use.id {
+                *existing_tool_use = tool_use.clone();
+                existing_tool_use_found = true;
+            }
+        }
+
+        if !existing_tool_use_found {
+            tool_uses.push(tool_use.clone());
+        }
+
+        let status = if tool_use.is_input_complete {
+            self.tool_use_metadata_by_id
+                .insert(tool_use.id.clone(), metadata);
+
+            // The tool use is being requested by the Assistant, so we want to
+            // attach the tool results to the next user message.
+            let next_user_message_id = MessageId(assistant_message_id.0 + 1);
+            self.tool_uses_by_user_message
+                .entry(next_user_message_id)
+                .or_default()
+                .push(tool_use.id.clone());
+
+            PendingToolUseStatus::Idle
+        } else {
+            PendingToolUseStatus::InputStillStreaming
+        };
+
+        let ui_text: Arc<str> = self
+            .tool_ui_label(
+                &tool_use.name,
+                &tool_use.input,
+                tool_use.is_input_complete,
+                cx,
+            )
+            .into();
 
         self.pending_tool_uses_by_id.insert(
             tool_use.id.clone(),
@@ -281,13 +323,13 @@ impl ToolUseState {
                 assistant_message_id,
                 id: tool_use.id,
                 name: tool_use.name.clone(),
-                ui_text: self
-                    .tool_ui_label(&tool_use.name, &tool_use.input, cx)
-                    .into(),
+                ui_text: ui_text.clone(),
                 input: tool_use.input,
-                status: PendingToolUseStatus::Idle,
+                status,
             },
         );
+
+        ui_text
     }
 
     pub fn run_pending_tool(
@@ -497,6 +539,7 @@ pub struct Confirmation {
 
 #[derive(Debug, Clone)]
 pub enum PendingToolUseStatus {
+    InputStillStreaming,
     Idle,
     NeedsConfirmation(Arc<Confirmation>),
     Running { _task: Shared<Task<()>> },

crates/assistant_tool/src/assistant_tool.rs 🔗

@@ -30,6 +30,7 @@ pub fn init(cx: &mut App) {
 
 #[derive(Debug, Clone)]
 pub enum ToolUseStatus {
+    InputStillStreaming,
     NeedsConfirmation,
     Pending,
     Running,
@@ -41,6 +42,7 @@ impl ToolUseStatus {
     pub fn text(&self) -> SharedString {
         match self {
             ToolUseStatus::NeedsConfirmation => "".into(),
+            ToolUseStatus::InputStillStreaming => "".into(),
             ToolUseStatus::Pending => "".into(),
             ToolUseStatus::Running => "".into(),
             ToolUseStatus::Finished(out) => out.clone(),
@@ -148,6 +150,12 @@ pub trait Tool: 'static + Send + Sync {
     /// Returns markdown to be displayed in the UI for this tool.
     fn ui_text(&self, input: &serde_json::Value) -> String;
 
+    /// Returns markdown to be displayed in the UI for this tool, while the input JSON is still streaming
+    /// (so information may be missing).
+    fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
+        self.ui_text(input)
+    }
+
     /// Runs the tool with the provided input.
     fn run(
         self: Arc<Self>,

crates/assistant_tools/src/create_file_tool.rs 🔗

@@ -33,8 +33,18 @@ pub struct CreateFileToolInput {
     pub contents: String,
 }
 
+#[derive(Debug, Serialize, Deserialize, JsonSchema)]
+struct PartialInput {
+    #[serde(default)]
+    path: String,
+    #[serde(default)]
+    contents: String,
+}
+
 pub struct CreateFileTool;
 
+const DEFAULT_UI_TEXT: &str = "Create file";
+
 impl Tool for CreateFileTool {
     fn name(&self) -> String {
         "create_file".into()
@@ -62,7 +72,14 @@ impl Tool for CreateFileTool {
                 let path = MarkdownString::inline_code(&input.path);
                 format!("Create file {path}")
             }
-            Err(_) => "Create file".to_string(),
+            Err(_) => DEFAULT_UI_TEXT.to_string(),
+        }
+    }
+
+    fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
+        match serde_json::from_value::<PartialInput>(input.clone()).ok() {
+            Some(input) if !input.path.is_empty() => input.path,
+            _ => DEFAULT_UI_TEXT.to_string(),
         }
     }
 
@@ -111,3 +128,60 @@ impl Tool for CreateFileTool {
         .into()
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use serde_json::json;
+
+    #[test]
+    fn still_streaming_ui_text_with_path() {
+        let tool = CreateFileTool;
+        let input = json!({
+            "path": "src/main.rs",
+            "contents": "fn main() {\n    println!(\"Hello, world!\");\n}"
+        });
+
+        assert_eq!(tool.still_streaming_ui_text(&input), "src/main.rs");
+    }
+
+    #[test]
+    fn still_streaming_ui_text_without_path() {
+        let tool = CreateFileTool;
+        let input = json!({
+            "path": "",
+            "contents": "fn main() {\n    println!(\"Hello, world!\");\n}"
+        });
+
+        assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
+    }
+
+    #[test]
+    fn still_streaming_ui_text_with_null() {
+        let tool = CreateFileTool;
+        let input = serde_json::Value::Null;
+
+        assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
+    }
+
+    #[test]
+    fn ui_text_with_valid_input() {
+        let tool = CreateFileTool;
+        let input = json!({
+            "path": "src/main.rs",
+            "contents": "fn main() {\n    println!(\"Hello, world!\");\n}"
+        });
+
+        assert_eq!(tool.ui_text(&input), "Create file `src/main.rs`");
+    }
+
+    #[test]
+    fn ui_text_with_invalid_input() {
+        let tool = CreateFileTool;
+        let input = json!({
+            "invalid": "field"
+        });
+
+        assert_eq!(tool.ui_text(&input), DEFAULT_UI_TEXT);
+    }
+}

crates/assistant_tools/src/edit_file_tool.rs 🔗

@@ -47,8 +47,22 @@ pub struct EditFileToolInput {
     pub new_string: String,
 }
 
+#[derive(Debug, Serialize, Deserialize, JsonSchema)]
+struct PartialInput {
+    #[serde(default)]
+    path: String,
+    #[serde(default)]
+    display_description: String,
+    #[serde(default)]
+    old_string: String,
+    #[serde(default)]
+    new_string: String,
+}
+
 pub struct EditFileTool;
 
+const DEFAULT_UI_TEXT: &str = "Edit file";
+
 impl Tool for EditFileTool {
     fn name(&self) -> String {
         "edit_file".into()
@@ -77,6 +91,22 @@ impl Tool for EditFileTool {
         }
     }
 
+    fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
+        if let Some(input) = serde_json::from_value::<PartialInput>(input.clone()).ok() {
+            let description = input.display_description.trim();
+            if !description.is_empty() {
+                return description.to_string();
+            }
+
+            let path = input.path.trim();
+            if !path.is_empty() {
+                return path.to_string();
+            }
+        }
+
+        DEFAULT_UI_TEXT.to_string()
+    }
+
     fn run(
         self: Arc<Self>,
         input: serde_json::Value,
@@ -181,3 +211,69 @@ impl Tool for EditFileTool {
         }).into()
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use serde_json::json;
+
+    #[test]
+    fn still_streaming_ui_text_with_path() {
+        let tool = EditFileTool;
+        let input = json!({
+            "path": "src/main.rs",
+            "display_description": "",
+            "old_string": "old code",
+            "new_string": "new code"
+        });
+
+        assert_eq!(tool.still_streaming_ui_text(&input), "src/main.rs");
+    }
+
+    #[test]
+    fn still_streaming_ui_text_with_description() {
+        let tool = EditFileTool;
+        let input = json!({
+            "path": "",
+            "display_description": "Fix error handling",
+            "old_string": "old code",
+            "new_string": "new code"
+        });
+
+        assert_eq!(tool.still_streaming_ui_text(&input), "Fix error handling");
+    }
+
+    #[test]
+    fn still_streaming_ui_text_with_path_and_description() {
+        let tool = EditFileTool;
+        let input = json!({
+            "path": "src/main.rs",
+            "display_description": "Fix error handling",
+            "old_string": "old code",
+            "new_string": "new code"
+        });
+
+        assert_eq!(tool.still_streaming_ui_text(&input), "Fix error handling");
+    }
+
+    #[test]
+    fn still_streaming_ui_text_no_path_or_description() {
+        let tool = EditFileTool;
+        let input = json!({
+            "path": "",
+            "display_description": "",
+            "old_string": "old code",
+            "new_string": "new code"
+        });
+
+        assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
+    }
+
+    #[test]
+    fn still_streaming_ui_text_with_null() {
+        let tool = EditFileTool;
+        let input = serde_json::Value::Null;
+
+        assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
+    }
+}

crates/eval/src/example.rs 🔗

@@ -426,6 +426,7 @@ impl Example {
                             ThreadEvent::ToolConfirmationNeeded => {
                                 panic!("{}Bug: Tool confirmation should not be required in eval", log_prefix);
                             },
+                            ThreadEvent::StreamedToolUse { .. } |
                             ThreadEvent::StreamedCompletion |
                             ThreadEvent::MessageAdded(_) |
                             ThreadEvent::MessageEdited(_) |

crates/language_model/src/language_model.rs 🔗

@@ -187,6 +187,7 @@ pub struct LanguageModelToolUse {
     pub id: LanguageModelToolUseId,
     pub name: Arc<str>,
     pub input: serde_json::Value,
+    pub is_input_complete: bool,
 }
 
 pub struct LanguageModelTextStream {

crates/language_models/Cargo.toml 🔗

@@ -38,6 +38,7 @@ menu.workspace = true
 mistral = { workspace = true, features = ["schemars"] }
 ollama = { workspace = true, features = ["schemars"] }
 open_ai = { workspace = true, features = ["schemars"] }
+partial-json-fixer.workspace = true
 project.workspace = true
 proto.workspace = true
 schemars.workspace = true

crates/language_models/src/provider/anthropic.rs 🔗

@@ -713,6 +713,35 @@ pub fn map_to_language_model_completion_events(
                             ContentDelta::InputJsonDelta { partial_json } => {
                                 if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
                                     tool_use.input_json.push_str(&partial_json);
+
+                                    return Some((
+                                        vec![maybe!({
+                                            Ok(LanguageModelCompletionEvent::ToolUse(
+                                                LanguageModelToolUse {
+                                                    id: tool_use.id.clone().into(),
+                                                    name: tool_use.name.clone().into(),
+                                                    is_input_complete: false,
+                                                    input: if tool_use.input_json.is_empty() {
+                                                        serde_json::Value::Object(
+                                                            serde_json::Map::default(),
+                                                        )
+                                                    } else {
+                                                        serde_json::Value::from_str(
+                                                            // Convert invalid (incomplete) JSON into
+                                                            // JSON that serde will accept, e.g. by closing
+                                                            // unclosed delimiters. This way, we can update
+                                                            // the UI with whatever has been streamed back so far.
+                                                            &partial_json_fixer::fix_json(
+                                                                &tool_use.input_json,
+                                                            ),
+                                                        )
+                                                        .map_err(|err| anyhow!(err))?
+                                                    },
+                                                },
+                                            ))
+                                        })],
+                                        state,
+                                    ));
                                 }
                             }
                         },
@@ -724,6 +753,7 @@ pub fn map_to_language_model_completion_events(
                                             LanguageModelToolUse {
                                                 id: tool_use.id.into(),
                                                 name: tool_use.name.into(),
+                                                is_input_complete: true,
                                                 input: if tool_use.input_json.is_empty() {
                                                     serde_json::Value::Object(
                                                         serde_json::Map::default(),

crates/language_models/src/provider/bedrock.rs 🔗

@@ -893,6 +893,7 @@ pub fn map_to_language_model_completion_events(
                                             let tool_use_event = LanguageModelToolUse {
                                                 id: tool_use.id.into(),
                                                 name: tool_use.name.into(),
+                                                is_input_complete: true,
                                                 input: if tool_use.input_json.is_empty() {
                                                     Value::Null
                                                 } else {

crates/language_models/src/provider/copilot_chat.rs 🔗

@@ -367,6 +367,7 @@ pub fn map_to_language_model_completion_events(
                                                 LanguageModelToolUse {
                                                     id: tool_call.id.into(),
                                                     name: tool_call.name.as_str().into(),
+                                                    is_input_complete: true,
                                                     input: serde_json::Value::from_str(
                                                         &tool_call.arguments,
                                                     )?,

crates/language_models/src/provider/open_ai.rs 🔗

@@ -490,6 +490,7 @@ pub fn map_to_language_model_completion_events(
                                                 LanguageModelToolUse {
                                                     id: tool_call.id.into(),
                                                     name: tool_call.name.as_str().into(),
+                                                    is_input_complete: true,
                                                     input: serde_json::Value::from_str(
                                                         &tool_call.arguments,
                                                     )?,

crates/markdown/src/markdown.rs 🔗

@@ -192,6 +192,11 @@ impl Markdown {
         self.parse(cx);
     }
 
+    pub fn replace(&mut self, source: impl Into<SharedString>, cx: &mut Context<Self>) {
+        self.source = source.into();
+        self.parse(cx);
+    }
+
     pub fn reset(&mut self, source: SharedString, cx: &mut Context<Self>) {
         if source == self.source() {
             return;