From 87bc2aac5cc99e8425e1c29c1af6bb7bc15e280f Mon Sep 17 00:00:00 2001 From: Bennet Bo Fenner Date: Wed, 4 Mar 2026 17:36:25 +0100 Subject: [PATCH] Add support for streaming tool input to more providers (#50682) To test: - [x] Bedrock - [x] Copilot Chat - [x] Deepseek - [x] Open AI - [x] Open Router - [x] Vercel - [x] Vercel AI Gateway - [x] xAI - [x] Mistral Release Notes: - N/A --- crates/agent/src/thread.rs | 12 +- .../language_models/src/provider/bedrock.rs | 23 ++- .../src/provider/copilot_chat.rs | 21 +++ .../language_models/src/provider/deepseek.rs | 21 +++ .../language_models/src/provider/mistral.rs | 21 +++ .../language_models/src/provider/open_ai.rs | 142 +++++++++++++++++- .../src/provider/open_ai_compatible.rs | 4 + .../src/provider/open_router.rs | 21 +++ crates/language_models/src/provider/vercel.rs | 4 + .../src/provider/vercel_ai_gateway.rs | 4 + crates/language_models/src/provider/x_ai.rs | 7 +- crates/x_ai/src/x_ai.rs | 12 ++ 12 files changed, 279 insertions(+), 13 deletions(-) diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 616ae414d4d51a384a18460e8339fd07770fa6b9..be87a6a1e1e5ddba8a5d4b3b5bca82168a141840 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -2335,20 +2335,18 @@ impl Thread { ) { // Ensure the last message ends in the current tool use let last_message = self.pending_message(); - let push_new_tool_use = last_message.content.last_mut().is_none_or(|content| { + + let has_tool_use = last_message.content.iter_mut().rev().any(|content| { if let AgentMessageContent::ToolUse(last_tool_use) = content { if last_tool_use.id == tool_use.id { *last_tool_use = tool_use.clone(); - false - } else { - true + return true; } - } else { - true } + false }); - if push_new_tool_use { + if !has_tool_use { event_stream.send_tool_call( &tool_use.id, &tool_use.name, diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index bcf8401c1c14ae1a74bb7136141d0b35509cdd40..5b493fdf1087911372d8796cc88f4ad14eef8df0 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -658,6 +658,10 @@ impl LanguageModel for BedrockModel { } } + fn supports_streaming_tools(&self) -> bool { + true + } + fn telemetry_id(&self) -> String { format!("bedrock/{}", self.model.id()) } @@ -1200,8 +1204,25 @@ pub fn map_to_language_model_completion_events( .get_mut(&cb_delta.content_block_index) { tool_use.input_json.push_str(tool_output.input()); + if let Ok(input) = serde_json::from_str::( + &partial_json_fixer::fix_json(&tool_use.input_json), + ) { + Some(Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: tool_use.id.clone().into(), + name: tool_use.name.clone().into(), + is_input_complete: false, + raw_input: tool_use.input_json.clone(), + input, + thought_signature: None, + }, + ))) + } else { + None + } + } else { + None } - None } Some(ContentBlockDelta::ReasoningContent(thinking)) => match thinking { ReasoningContentBlockDelta::Text(thoughts) => { diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index 4363430f865de63ed5fec0d6b40b085d9413fc2a..7d714cd93a2a93dbb9fd02ec4d2b95149bb43330 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -246,6 +246,10 @@ impl LanguageModel for CopilotChatLanguageModel { self.model.supports_tools() } + fn supports_streaming_tools(&self) -> bool { + true + } + fn supports_images(&self) -> bool { self.model.supports_vision() } @@ -455,6 +459,23 @@ pub fn map_to_language_model_completion_events( entry.thought_signature = Some(thought_signature); } } + + if !entry.id.is_empty() && !entry.name.is_empty() { + if let Ok(input) = serde_json::from_str::( + &partial_json_fixer::fix_json(&entry.arguments), + ) { + events.push(Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: entry.id.clone().into(), + name: entry.name.as_str().into(), + is_input_complete: false, + input, + raw_input: entry.arguments.clone(), + thought_signature: entry.thought_signature.clone(), + }, + ))); + } + } } if let Some(usage) = event.usage { diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs index 2a9f7322b1fb5d3d1e6713c5a084b83dc2b01ce2..0bf86ef15c91b16dbc496ff732b087fedd0da0a9 100644 --- a/crates/language_models/src/provider/deepseek.rs +++ b/crates/language_models/src/provider/deepseek.rs @@ -246,6 +246,10 @@ impl LanguageModel for DeepSeekLanguageModel { true } + fn supports_streaming_tools(&self) -> bool { + true + } + fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool { true } @@ -469,6 +473,23 @@ impl DeepSeekEventMapper { entry.arguments.push_str(&arguments); } } + + if !entry.id.is_empty() && !entry.name.is_empty() { + if let Ok(input) = serde_json::from_str::( + &partial_json_fixer::fix_json(&entry.arguments), + ) { + events.push(Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: entry.id.clone().into(), + name: entry.name.as_str().into(), + is_input_complete: false, + input, + raw_input: entry.arguments.clone(), + thought_signature: None, + }, + ))); + } + } } } diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs index 02d46dcaa7ce7acc76d85c93cad610a7d2489bf0..6af66f4e9a9d257b385c84a6c0c6d989f04c013f 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -280,6 +280,10 @@ impl LanguageModel for MistralLanguageModel { self.model.supports_tools() } + fn supports_streaming_tools(&self) -> bool { + true + } + fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool { self.model.supports_tools() } @@ -629,6 +633,23 @@ impl MistralEventMapper { entry.arguments.push_str(&arguments); } } + + if !entry.id.is_empty() && !entry.name.is_empty() { + if let Ok(input) = serde_json::from_str::( + &partial_json_fixer::fix_json(&entry.arguments), + ) { + events.push(Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: entry.id.clone().into(), + name: entry.name.as_str().into(), + is_input_complete: false, + input, + raw_input: entry.arguments.clone(), + thought_signature: None, + }, + ))); + } + } } } diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index 7fb65df0a534c7600f7315fd85d7adda0d66314a..57b3a6b20a9712e7c4d99b3ccfc48719e632da9d 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -328,6 +328,10 @@ impl LanguageModel for OpenAiLanguageModel { } } + fn supports_streaming_tools(&self) -> bool { + true + } + fn supports_thinking(&self) -> bool { self.model.reasoning_effort().is_some() } @@ -824,6 +828,23 @@ impl OpenAiEventMapper { entry.arguments.push_str(&arguments); } } + + if !entry.id.is_empty() && !entry.name.is_empty() { + if let Ok(input) = serde_json::from_str::( + &partial_json_fixer::fix_json(&entry.arguments), + ) { + events.push(Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: entry.id.clone().into(), + name: entry.name.as_str().into(), + is_input_complete: false, + input, + raw_input: entry.arguments.clone(), + thought_signature: None, + }, + ))); + } + } } } } @@ -954,6 +975,20 @@ impl OpenAiResponseEventMapper { ResponsesStreamEvent::FunctionCallArgumentsDelta { item_id, delta, .. } => { if let Some(entry) = self.function_calls_by_item.get_mut(&item_id) { entry.arguments.push_str(&delta); + if let Ok(input) = serde_json::from_str::( + &partial_json_fixer::fix_json(&entry.arguments), + ) { + return vec![Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: LanguageModelToolUseId::from(entry.call_id.clone()), + name: entry.name.clone(), + is_input_complete: false, + input, + raw_input: entry.arguments.clone(), + thought_signature: None, + }, + ))]; + } } Vec::new() } @@ -1670,19 +1705,30 @@ mod tests { ]; let mapped = map_response_events(events); + assert_eq!(mapped.len(), 3); + // First event is the partial tool use (from FunctionCallArgumentsDelta) assert!(matches!( mapped[0], + LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { + is_input_complete: false, + .. + }) + )); + // Second event is the complete tool use (from FunctionCallArgumentsDone) + assert!(matches!( + mapped[1], LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref id, ref name, ref raw_input, + is_input_complete: true, .. }) if id.to_string() == "call_123" && name.as_ref() == "get_weather" && raw_input == "{\"city\":\"Boston\"}" )); assert!(matches!( - mapped[1], + mapped[2], LanguageModelCompletionEvent::Stop(StopReason::ToolUse) )); } @@ -1878,13 +1924,27 @@ mod tests { ]; let mapped = map_response_events(events); + assert_eq!(mapped.len(), 3); + // First event is the partial tool use (from FunctionCallArgumentsDelta) assert!(matches!( mapped[0], - LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { ref raw_input, .. }) - if raw_input == "{\"city\":\"Boston\"}" + LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { + is_input_complete: false, + .. + }) )); + // Second event is the complete tool use (from the Incomplete response output) assert!(matches!( mapped[1], + LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { + ref raw_input, + is_input_complete: true, + .. + }) + if raw_input == "{\"city\":\"Boston\"}" + )); + assert!(matches!( + mapped[2], LanguageModelCompletionEvent::Stop(StopReason::MaxTokens) )); } @@ -1976,4 +2036,80 @@ mod tests { LanguageModelCompletionEvent::Stop(StopReason::ToolUse) )); } + + #[test] + fn responses_stream_emits_partial_tool_use_events() { + let events = vec![ + ResponsesStreamEvent::OutputItemAdded { + output_index: 0, + sequence_number: None, + item: ResponseOutputItem::FunctionCall(ResponseFunctionToolCall { + id: Some("item_fn".to_string()), + status: Some("in_progress".to_string()), + name: Some("get_weather".to_string()), + call_id: Some("call_abc".to_string()), + arguments: String::new(), + }), + }, + ResponsesStreamEvent::FunctionCallArgumentsDelta { + item_id: "item_fn".into(), + output_index: 0, + delta: "{\"city\":\"Bos".into(), + sequence_number: None, + }, + ResponsesStreamEvent::FunctionCallArgumentsDelta { + item_id: "item_fn".into(), + output_index: 0, + delta: "ton\"}".into(), + sequence_number: None, + }, + ResponsesStreamEvent::FunctionCallArgumentsDone { + item_id: "item_fn".into(), + output_index: 0, + arguments: "{\"city\":\"Boston\"}".into(), + sequence_number: None, + }, + ResponsesStreamEvent::Completed { + response: ResponseSummary::default(), + }, + ]; + + let mapped = map_response_events(events); + // Two partial events + one complete event + Stop + assert!(mapped.len() >= 3); + + // The last complete ToolUse event should have is_input_complete: true + let complete_tool_use = mapped.iter().find(|e| { + matches!( + e, + LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { + is_input_complete: true, + .. + }) + ) + }); + assert!( + complete_tool_use.is_some(), + "should have a complete tool use event" + ); + + // All ToolUse events before the final one should have is_input_complete: false + let tool_uses: Vec<_> = mapped + .iter() + .filter(|e| matches!(e, LanguageModelCompletionEvent::ToolUse(_))) + .collect(); + assert!( + tool_uses.len() >= 2, + "should have at least one partial and one complete event" + ); + + let last = tool_uses.last().unwrap(); + assert!(matches!( + last, + LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse { + is_input_complete: true, + .. + }) + )); + } } diff --git a/crates/language_models/src/provider/open_ai_compatible.rs b/crates/language_models/src/provider/open_ai_compatible.rs index d47ea26c594ab0abb5c859ed549d43e0ed3f859b..b478bc843c05e01d428561d9c255ef0d2ca97148 100644 --- a/crates/language_models/src/provider/open_ai_compatible.rs +++ b/crates/language_models/src/provider/open_ai_compatible.rs @@ -319,6 +319,10 @@ impl LanguageModel for OpenAiCompatibleLanguageModel { } } + fn supports_streaming_tools(&self) -> bool { + true + } + fn supports_split_token_display(&self) -> bool { true } diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs index 7a74125d606ddc4be56d113fbbf3fa66866fb595..e0e56bc1beadd8309a4c1b3c7626efa99c1c6473 100644 --- a/crates/language_models/src/provider/open_router.rs +++ b/crates/language_models/src/provider/open_router.rs @@ -314,6 +314,10 @@ impl LanguageModel for OpenRouterLanguageModel { self.model.supports_tool_calls() } + fn supports_streaming_tools(&self) -> bool { + true + } + fn supports_thinking(&self) -> bool { matches!(self.model.mode, OpenRouterModelMode::Thinking { .. }) } @@ -650,6 +654,23 @@ impl OpenRouterEventMapper { entry.thought_signature = Some(signature); } } + + if !entry.id.is_empty() && !entry.name.is_empty() { + if let Ok(input) = serde_json::from_str::( + &partial_json_fixer::fix_json(&entry.arguments), + ) { + events.push(Ok(LanguageModelCompletionEvent::ToolUse( + LanguageModelToolUse { + id: entry.id.clone().into(), + name: entry.name.as_str().into(), + is_input_complete: false, + input, + raw_input: entry.arguments.clone(), + thought_signature: entry.thought_signature.clone(), + }, + ))); + } + } } } diff --git a/crates/language_models/src/provider/vercel.rs b/crates/language_models/src/provider/vercel.rs index 3b324e46927f5864d83a5e4b74c46f5e39e8ab3a..b71da5b7db05710ee30115ab54379c9ee4e4c750 100644 --- a/crates/language_models/src/provider/vercel.rs +++ b/crates/language_models/src/provider/vercel.rs @@ -248,6 +248,10 @@ impl LanguageModel for VercelLanguageModel { true } + fn supports_streaming_tools(&self) -> bool { + true + } + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { match choice { LanguageModelToolChoice::Auto diff --git a/crates/language_models/src/provider/vercel_ai_gateway.rs b/crates/language_models/src/provider/vercel_ai_gateway.rs index 69c54e624b9e7289abaefbe7ab654d73df385b62..78f900de0c94fd3bbbff3962e92d1a8cb9f3e118 100644 --- a/crates/language_models/src/provider/vercel_ai_gateway.rs +++ b/crates/language_models/src/provider/vercel_ai_gateway.rs @@ -385,6 +385,10 @@ impl LanguageModel for VercelAiGatewayLanguageModel { } } + fn supports_streaming_tools(&self) -> bool { + true + } + fn supports_split_token_display(&self) -> bool { true } diff --git a/crates/language_models/src/provider/x_ai.rs b/crates/language_models/src/provider/x_ai.rs index 06564224dea9621d594e5cf3f4a84093f1620446..f1f8bb658f04a91341951d1602af04f858af7bd3 100644 --- a/crates/language_models/src/provider/x_ai.rs +++ b/crates/language_models/src/provider/x_ai.rs @@ -257,6 +257,10 @@ impl LanguageModel for XAiLanguageModel { self.model.supports_images() } + fn supports_streaming_tools(&self) -> bool { + true + } + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { match choice { LanguageModelToolChoice::Auto @@ -265,8 +269,7 @@ impl LanguageModel for XAiLanguageModel { } } fn tool_input_format(&self) -> LanguageModelToolSchemaFormat { - let model_id = self.model.id().trim().to_lowercase(); - if model_id.eq(x_ai::Model::Grok4.id()) || model_id.eq(x_ai::Model::GrokCodeFast1.id()) { + if self.model.requires_json_schema_subset() { LanguageModelToolSchemaFormat::JsonSchemaSubset } else { LanguageModelToolSchemaFormat::JsonSchema diff --git a/crates/x_ai/src/x_ai.rs b/crates/x_ai/src/x_ai.rs index 072a893a6a8f4fc7fbc8a6f4f5ed43316915b974..1abb2b53771fa1e29e2979560e9f394744b26158 100644 --- a/crates/x_ai/src/x_ai.rs +++ b/crates/x_ai/src/x_ai.rs @@ -165,6 +165,18 @@ impl Model { } } + pub fn requires_json_schema_subset(&self) -> bool { + match self { + Self::Grok4 + | Self::Grok4FastReasoning + | Self::Grok4FastNonReasoning + | Self::Grok41FastNonReasoning + | Self::Grok41FastReasoning + | Self::GrokCodeFast1 => true, + _ => false, + } + } + pub fn supports_prompt_cache_key(&self) -> bool { false }