Add support for streaming tool input to more providers (#50682)

Created by Bennet Bo Fenner

To test:
- [x] Bedrock
- [x] Copilot Chat
- [x] DeepSeek
- [x] OpenAI
- [x] OpenRouter
- [x] Vercel
- [x] Vercel AI Gateway
- [x] xAI
- [x] Mistral

Release Notes:

- N/A

Change summary

crates/agent/src/thread.rs                                |  12 
crates/language_models/src/provider/bedrock.rs            |  23 +
crates/language_models/src/provider/copilot_chat.rs       |  21 +
crates/language_models/src/provider/deepseek.rs           |  21 +
crates/language_models/src/provider/mistral.rs            |  21 +
crates/language_models/src/provider/open_ai.rs            | 142 ++++++++
crates/language_models/src/provider/open_ai_compatible.rs |   4 
crates/language_models/src/provider/open_router.rs        |  21 +
crates/language_models/src/provider/vercel.rs             |   4 
crates/language_models/src/provider/vercel_ai_gateway.rs  |   4 
crates/language_models/src/provider/x_ai.rs               |   7 
crates/x_ai/src/x_ai.rs                                   |  12 
12 files changed, 279 insertions(+), 13 deletions(-)

Detailed changes

crates/agent/src/thread.rs 🔗

@@ -2335,20 +2335,18 @@ impl Thread {
     ) {
         // Ensure the last message ends in the current tool use
         let last_message = self.pending_message();
-        let push_new_tool_use = last_message.content.last_mut().is_none_or(|content| {
+
+        let has_tool_use = last_message.content.iter_mut().rev().any(|content| {
             if let AgentMessageContent::ToolUse(last_tool_use) = content {
                 if last_tool_use.id == tool_use.id {
                     *last_tool_use = tool_use.clone();
-                    false
-                } else {
-                    true
+                    return true;
                 }
-            } else {
-                true
             }
+            false
         });
 
-        if push_new_tool_use {
+        if !has_tool_use {
             event_stream.send_tool_call(
                 &tool_use.id,
                 &tool_use.name,

crates/language_models/src/provider/bedrock.rs 🔗

@@ -658,6 +658,10 @@ impl LanguageModel for BedrockModel {
         }
     }
 
+    fn supports_streaming_tools(&self) -> bool {
+        true
+    }
+
     fn telemetry_id(&self) -> String {
         format!("bedrock/{}", self.model.id())
     }
@@ -1200,8 +1204,25 @@ pub fn map_to_language_model_completion_events(
                                     .get_mut(&cb_delta.content_block_index)
                                 {
                                     tool_use.input_json.push_str(tool_output.input());
+                                    if let Ok(input) = serde_json::from_str::<serde_json::Value>(
+                                        &partial_json_fixer::fix_json(&tool_use.input_json),
+                                    ) {
+                                        Some(Ok(LanguageModelCompletionEvent::ToolUse(
+                                            LanguageModelToolUse {
+                                                id: tool_use.id.clone().into(),
+                                                name: tool_use.name.clone().into(),
+                                                is_input_complete: false,
+                                                raw_input: tool_use.input_json.clone(),
+                                                input,
+                                                thought_signature: None,
+                                            },
+                                        )))
+                                    } else {
+                                        None
+                                    }
+                                } else {
+                                    None
                                 }
-                                None
                             }
                             Some(ContentBlockDelta::ReasoningContent(thinking)) => match thinking {
                                 ReasoningContentBlockDelta::Text(thoughts) => {

crates/language_models/src/provider/copilot_chat.rs 🔗

@@ -246,6 +246,10 @@ impl LanguageModel for CopilotChatLanguageModel {
         self.model.supports_tools()
     }
 
+    fn supports_streaming_tools(&self) -> bool {
+        true
+    }
+
     fn supports_images(&self) -> bool {
         self.model.supports_vision()
     }
@@ -455,6 +459,23 @@ pub fn map_to_language_model_completion_events(
                                     entry.thought_signature = Some(thought_signature);
                                 }
                             }
+
+                            if !entry.id.is_empty() && !entry.name.is_empty() {
+                                if let Ok(input) = serde_json::from_str::<serde_json::Value>(
+                                    &partial_json_fixer::fix_json(&entry.arguments),
+                                ) {
+                                    events.push(Ok(LanguageModelCompletionEvent::ToolUse(
+                                        LanguageModelToolUse {
+                                            id: entry.id.clone().into(),
+                                            name: entry.name.as_str().into(),
+                                            is_input_complete: false,
+                                            input,
+                                            raw_input: entry.arguments.clone(),
+                                            thought_signature: entry.thought_signature.clone(),
+                                        },
+                                    )));
+                                }
+                            }
                         }
 
                         if let Some(usage) = event.usage {

crates/language_models/src/provider/deepseek.rs 🔗

@@ -246,6 +246,10 @@ impl LanguageModel for DeepSeekLanguageModel {
         true
     }
 
+    fn supports_streaming_tools(&self) -> bool {
+        true
+    }
+
     fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
         true
     }
@@ -469,6 +473,23 @@ impl DeepSeekEventMapper {
                         entry.arguments.push_str(&arguments);
                     }
                 }
+
+                if !entry.id.is_empty() && !entry.name.is_empty() {
+                    if let Ok(input) = serde_json::from_str::<serde_json::Value>(
+                        &partial_json_fixer::fix_json(&entry.arguments),
+                    ) {
+                        events.push(Ok(LanguageModelCompletionEvent::ToolUse(
+                            LanguageModelToolUse {
+                                id: entry.id.clone().into(),
+                                name: entry.name.as_str().into(),
+                                is_input_complete: false,
+                                input,
+                                raw_input: entry.arguments.clone(),
+                                thought_signature: None,
+                            },
+                        )));
+                    }
+                }
             }
         }
 

crates/language_models/src/provider/mistral.rs 🔗

@@ -280,6 +280,10 @@ impl LanguageModel for MistralLanguageModel {
         self.model.supports_tools()
     }
 
+    fn supports_streaming_tools(&self) -> bool {
+        true
+    }
+
     fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
         self.model.supports_tools()
     }
@@ -629,6 +633,23 @@ impl MistralEventMapper {
                         entry.arguments.push_str(&arguments);
                     }
                 }
+
+                if !entry.id.is_empty() && !entry.name.is_empty() {
+                    if let Ok(input) = serde_json::from_str::<serde_json::Value>(
+                        &partial_json_fixer::fix_json(&entry.arguments),
+                    ) {
+                        events.push(Ok(LanguageModelCompletionEvent::ToolUse(
+                            LanguageModelToolUse {
+                                id: entry.id.clone().into(),
+                                name: entry.name.as_str().into(),
+                                is_input_complete: false,
+                                input,
+                                raw_input: entry.arguments.clone(),
+                                thought_signature: None,
+                            },
+                        )));
+                    }
+                }
             }
         }
 

crates/language_models/src/provider/open_ai.rs 🔗

@@ -328,6 +328,10 @@ impl LanguageModel for OpenAiLanguageModel {
         }
     }
 
+    fn supports_streaming_tools(&self) -> bool {
+        true
+    }
+
     fn supports_thinking(&self) -> bool {
         self.model.reasoning_effort().is_some()
     }
@@ -824,6 +828,23 @@ impl OpenAiEventMapper {
                             entry.arguments.push_str(&arguments);
                         }
                     }
+
+                    if !entry.id.is_empty() && !entry.name.is_empty() {
+                        if let Ok(input) = serde_json::from_str::<serde_json::Value>(
+                            &partial_json_fixer::fix_json(&entry.arguments),
+                        ) {
+                            events.push(Ok(LanguageModelCompletionEvent::ToolUse(
+                                LanguageModelToolUse {
+                                    id: entry.id.clone().into(),
+                                    name: entry.name.as_str().into(),
+                                    is_input_complete: false,
+                                    input,
+                                    raw_input: entry.arguments.clone(),
+                                    thought_signature: None,
+                                },
+                            )));
+                        }
+                    }
                 }
             }
         }
@@ -954,6 +975,20 @@ impl OpenAiResponseEventMapper {
             ResponsesStreamEvent::FunctionCallArgumentsDelta { item_id, delta, .. } => {
                 if let Some(entry) = self.function_calls_by_item.get_mut(&item_id) {
                     entry.arguments.push_str(&delta);
+                    if let Ok(input) = serde_json::from_str::<serde_json::Value>(
+                        &partial_json_fixer::fix_json(&entry.arguments),
+                    ) {
+                        return vec![Ok(LanguageModelCompletionEvent::ToolUse(
+                            LanguageModelToolUse {
+                                id: LanguageModelToolUseId::from(entry.call_id.clone()),
+                                name: entry.name.clone(),
+                                is_input_complete: false,
+                                input,
+                                raw_input: entry.arguments.clone(),
+                                thought_signature: None,
+                            },
+                        ))];
+                    }
                 }
                 Vec::new()
             }
@@ -1670,19 +1705,30 @@ mod tests {
         ];
 
         let mapped = map_response_events(events);
+        assert_eq!(mapped.len(), 3);
+        // First event is the partial tool use (from FunctionCallArgumentsDelta)
         assert!(matches!(
             mapped[0],
+            LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
+                is_input_complete: false,
+                ..
+            })
+        ));
+        // Second event is the complete tool use (from FunctionCallArgumentsDone)
+        assert!(matches!(
+            mapped[1],
             LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
                 ref id,
                 ref name,
                 ref raw_input,
+                is_input_complete: true,
                 ..
             }) if id.to_string() == "call_123"
                 && name.as_ref() == "get_weather"
                 && raw_input == "{\"city\":\"Boston\"}"
         ));
         assert!(matches!(
-            mapped[1],
+            mapped[2],
             LanguageModelCompletionEvent::Stop(StopReason::ToolUse)
         ));
     }
@@ -1878,13 +1924,27 @@ mod tests {
         ];
 
         let mapped = map_response_events(events);
+        assert_eq!(mapped.len(), 3);
+        // First event is the partial tool use (from FunctionCallArgumentsDelta)
         assert!(matches!(
             mapped[0],
-            LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, .. })
-            if raw_input == "{\"city\":\"Boston\"}"
+            LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
+                is_input_complete: false,
+                ..
+            })
         ));
+        // Second event is the complete tool use (from the Incomplete response output)
         assert!(matches!(
             mapped[1],
+            LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
+                ref raw_input,
+                is_input_complete: true,
+                ..
+            })
+            if raw_input == "{\"city\":\"Boston\"}"
+        ));
+        assert!(matches!(
+            mapped[2],
             LanguageModelCompletionEvent::Stop(StopReason::MaxTokens)
         ));
     }
@@ -1976,4 +2036,80 @@ mod tests {
             LanguageModelCompletionEvent::Stop(StopReason::ToolUse)
         ));
     }
+
+    #[test]
+    fn responses_stream_emits_partial_tool_use_events() {
+        let events = vec![
+            ResponsesStreamEvent::OutputItemAdded {
+                output_index: 0,
+                sequence_number: None,
+                item: ResponseOutputItem::FunctionCall(ResponseFunctionToolCall {
+                    id: Some("item_fn".to_string()),
+                    status: Some("in_progress".to_string()),
+                    name: Some("get_weather".to_string()),
+                    call_id: Some("call_abc".to_string()),
+                    arguments: String::new(),
+                }),
+            },
+            ResponsesStreamEvent::FunctionCallArgumentsDelta {
+                item_id: "item_fn".into(),
+                output_index: 0,
+                delta: "{\"city\":\"Bos".into(),
+                sequence_number: None,
+            },
+            ResponsesStreamEvent::FunctionCallArgumentsDelta {
+                item_id: "item_fn".into(),
+                output_index: 0,
+                delta: "ton\"}".into(),
+                sequence_number: None,
+            },
+            ResponsesStreamEvent::FunctionCallArgumentsDone {
+                item_id: "item_fn".into(),
+                output_index: 0,
+                arguments: "{\"city\":\"Boston\"}".into(),
+                sequence_number: None,
+            },
+            ResponsesStreamEvent::Completed {
+                response: ResponseSummary::default(),
+            },
+        ];
+
+        let mapped = map_response_events(events);
+        // Two partial events + one complete event + Stop
+        assert!(mapped.len() >= 3);
+
+        // The last complete ToolUse event should have is_input_complete: true
+        let complete_tool_use = mapped.iter().find(|e| {
+            matches!(
+                e,
+                LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
+                    is_input_complete: true,
+                    ..
+                })
+            )
+        });
+        assert!(
+            complete_tool_use.is_some(),
+            "should have a complete tool use event"
+        );
+
+        // All ToolUse events before the final one should have is_input_complete: false
+        let tool_uses: Vec<_> = mapped
+            .iter()
+            .filter(|e| matches!(e, LanguageModelCompletionEvent::ToolUse(_)))
+            .collect();
+        assert!(
+            tool_uses.len() >= 2,
+            "should have at least one partial and one complete event"
+        );
+
+        let last = tool_uses.last().unwrap();
+        assert!(matches!(
+            last,
+            LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
+                is_input_complete: true,
+                ..
+            })
+        ));
+    }
 }

crates/language_models/src/provider/open_router.rs 🔗

@@ -314,6 +314,10 @@ impl LanguageModel for OpenRouterLanguageModel {
         self.model.supports_tool_calls()
     }
 
+    fn supports_streaming_tools(&self) -> bool {
+        true
+    }
+
     fn supports_thinking(&self) -> bool {
         matches!(self.model.mode, OpenRouterModelMode::Thinking { .. })
     }
@@ -650,6 +654,23 @@ impl OpenRouterEventMapper {
                         entry.thought_signature = Some(signature);
                     }
                 }
+
+                if !entry.id.is_empty() && !entry.name.is_empty() {
+                    if let Ok(input) = serde_json::from_str::<serde_json::Value>(
+                        &partial_json_fixer::fix_json(&entry.arguments),
+                    ) {
+                        events.push(Ok(LanguageModelCompletionEvent::ToolUse(
+                            LanguageModelToolUse {
+                                id: entry.id.clone().into(),
+                                name: entry.name.as_str().into(),
+                                is_input_complete: false,
+                                input,
+                                raw_input: entry.arguments.clone(),
+                                thought_signature: entry.thought_signature.clone(),
+                            },
+                        )));
+                    }
+                }
             }
         }
 

crates/language_models/src/provider/vercel.rs 🔗

@@ -248,6 +248,10 @@ impl LanguageModel for VercelLanguageModel {
         true
     }
 
+    fn supports_streaming_tools(&self) -> bool {
+        true
+    }
+
     fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
         match choice {
             LanguageModelToolChoice::Auto

crates/language_models/src/provider/x_ai.rs 🔗

@@ -257,6 +257,10 @@ impl LanguageModel for XAiLanguageModel {
         self.model.supports_images()
     }
 
+    fn supports_streaming_tools(&self) -> bool {
+        true
+    }
+
     fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
         match choice {
             LanguageModelToolChoice::Auto
@@ -265,8 +269,7 @@ impl LanguageModel for XAiLanguageModel {
         }
     }
     fn tool_input_format(&self) -> LanguageModelToolSchemaFormat {
-        let model_id = self.model.id().trim().to_lowercase();
-        if model_id.eq(x_ai::Model::Grok4.id()) || model_id.eq(x_ai::Model::GrokCodeFast1.id()) {
+        if self.model.requires_json_schema_subset() {
             LanguageModelToolSchemaFormat::JsonSchemaSubset
         } else {
             LanguageModelToolSchemaFormat::JsonSchema

crates/x_ai/src/x_ai.rs 🔗

@@ -165,6 +165,18 @@ impl Model {
         }
     }
 
+    pub fn requires_json_schema_subset(&self) -> bool {
+        match self {
+            Self::Grok4
+            | Self::Grok4FastReasoning
+            | Self::Grok4FastNonReasoning
+            | Self::Grok41FastNonReasoning
+            | Self::Grok41FastReasoning
+            | Self::GrokCodeFast1 => true,
+            _ => false,
+        }
+    }
+
     pub fn supports_prompt_cache_key(&self) -> bool {
         false
     }