diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml index f9dc4266d69ae9164f6b187162ed32069de5c10c..911dfb813ac54d89e764b3d62c50b4411cf8ba9c 100644 --- a/crates/language_models/Cargo.toml +++ b/crates/language_models/Cargo.toml @@ -67,7 +67,6 @@ vercel = { workspace = true, features = ["schemars"] } x_ai = { workspace = true, features = ["schemars"] } [dev-dependencies] - language_model = { workspace = true, features = ["test-support"] } pretty_assertions.workspace = true diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index cd60935b59c6f3f1c15ebd9a91b5683639408618..1fd79fb3a93d978d0912abbc4f0688e0bbe846e6 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -24,7 +24,7 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*}; use ui_input::InputField; use util::ResultExt; -use crate::provider::util::parse_tool_arguments; +use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; pub use settings::AnthropicAvailableModel as AvailableModel; @@ -873,9 +873,9 @@ impl AnthropicEventMapper { // valid JSON that serde can accept, e.g. by closing // unclosed delimiters. This way, we can update the // UI with whatever has been streamed back so far. - if let Ok(input) = serde_json::Value::from_str( - &partial_json_fixer::fix_json(&tool_use.input_json), - ) { + if let Ok(input) = + serde_json::Value::from_str(&fix_streamed_json(&tool_use.input_json)) + { return vec![Ok(LanguageModelCompletionEvent::ToolUse( LanguageModelToolUse { id: tool_use.id.clone().into(), diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index b05e9159cee443662c82153f205c4600afe0de34..734e97ee335c4106fced9d334d31b5ed5b86d407 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -48,7 +48,7 @@ use ui_input::InputField; use util::ResultExt; use crate::AllLanguageModelSettings; -use crate::provider::util::parse_tool_arguments; +use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; actions!(bedrock, [Tab, TabPrev]); @@ -1244,7 +1244,7 @@ pub fn map_to_language_model_completion_events( { tool_use.input_json.push_str(tool_output.input()); if let Ok(input) = serde_json::from_str::( - &partial_json_fixer::fix_json(&tool_use.input_json), + &fix_streamed_json(&tool_use.input_json), ) { Some(Ok(LanguageModelCompletionEvent::ToolUse( LanguageModelToolUse { diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index 286eb872795642be47dfd46f16e561dcd53f93dc..7063db83bf65b82a4f314ad97e9463b106400c0b 100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -33,7 +33,7 @@ use ui::prelude::*; use util::debug_panic; use crate::provider::anthropic::{AnthropicEventMapper, into_anthropic}; -use crate::provider::util::parse_tool_arguments; +use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("copilot_chat"); const PROVIDER_NAME: LanguageModelProviderName = @@ -579,7 +579,7 @@ pub fn map_to_language_model_completion_events( if !entry.id.is_empty() && !entry.name.is_empty() { if let Ok(input) = serde_json::from_str::( - &partial_json_fixer::fix_json(&entry.arguments), + &fix_streamed_json(&entry.arguments), ) { events.push(Ok(LanguageModelCompletionEvent::ToolUse( LanguageModelToolUse { diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs index 0bf86ef15c91b16dbc496ff732b087fedd0da0a9..e27bd510dbb0b0f518e615e31fc194675a5c3cfe 100644 --- a/crates/language_models/src/provider/deepseek.rs +++ b/crates/language_models/src/provider/deepseek.rs @@ -22,7 +22,7 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*}; use ui_input::InputField; use util::ResultExt; -use crate::provider::util::parse_tool_arguments; +use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("deepseek"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("DeepSeek"); @@ -476,7 +476,7 @@ impl DeepSeekEventMapper { if !entry.id.is_empty() && !entry.name.is_empty() { if let Ok(input) = serde_json::from_str::( - &partial_json_fixer::fix_json(&entry.arguments), + &fix_streamed_json(&entry.arguments), ) { events.push(Ok(LanguageModelCompletionEvent::ToolUse( LanguageModelToolUse { diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs index 338931cf7ca902225e10a7d09c9e7528128f1491..72f0cae2993da4efb3e19cb19ec42b186290920d 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -22,7 +22,7 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*}; use ui_input::InputField; use util::ResultExt; -use crate::provider::util::parse_tool_arguments; +use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral"); @@ -647,7 +647,7 @@ impl MistralEventMapper { if !entry.id.is_empty() && !entry.name.is_empty() { if let Ok(input) = serde_json::from_str::( - &partial_json_fixer::fix_json(&entry.arguments), + &fix_streamed_json(&entry.arguments), ) { events.push(Ok(LanguageModelCompletionEvent::ToolUse( LanguageModelToolUse { diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index 4ea7c8f49ce7f745e1aa108062cd0bb4def08097..2548a6b26f39dbb67add7262fc4b2796c1d8306f 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -33,7 +33,7 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*}; use ui_input::InputField; use util::ResultExt; -use crate::provider::util::parse_tool_arguments; +use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID; const PROVIDER_NAME: LanguageModelProviderName = language_model::OPEN_AI_PROVIDER_NAME; @@ -836,7 +836,7 @@ impl OpenAiEventMapper { if !entry.id.is_empty() && !entry.name.is_empty() { if let Ok(input) = serde_json::from_str::( - &partial_json_fixer::fix_json(&entry.arguments), + &fix_streamed_json(&entry.arguments), ) { events.push(Ok(LanguageModelCompletionEvent::ToolUse( LanguageModelToolUse { @@ -991,7 +991,7 @@ impl OpenAiResponseEventMapper { if let Some(entry) = self.function_calls_by_item.get_mut(&item_id) { entry.arguments.push_str(&delta); if let Ok(input) = serde_json::from_str::( - &partial_json_fixer::fix_json(&entry.arguments), + &fix_streamed_json(&entry.arguments), ) { return vec![Ok(LanguageModelCompletionEvent::ToolUse( LanguageModelToolUse { diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs index e0e56bc1beadd8309a4c1b3c7626efa99c1c6473..a4a679be73c0276351a6524ad7e8fc40e2c26860 100644 --- a/crates/language_models/src/provider/open_router.rs +++ b/crates/language_models/src/provider/open_router.rs @@ -21,7 +21,7 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*}; use ui_input::InputField; use util::ResultExt; -use crate::provider::util::parse_tool_arguments; +use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openrouter"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("OpenRouter"); @@ -657,7 +657,7 @@ impl OpenRouterEventMapper { if !entry.id.is_empty() && !entry.name.is_empty() { if let Ok(input) = serde_json::from_str::( - &partial_json_fixer::fix_json(&entry.arguments), + &fix_streamed_json(&entry.arguments), ) { events.push(Ok(LanguageModelCompletionEvent::ToolUse( LanguageModelToolUse { diff --git a/crates/language_models/src/provider/util.rs b/crates/language_models/src/provider/util.rs index 6b1cf7afbb7e3a068dabbc6787c322649d50393d..76a02b6de40a3e36c7c506f11a6f6d34d2aaca3e 100644 --- a/crates/language_models/src/provider/util.rs +++ b/crates/language_models/src/provider/util.rs @@ -11,3 +11,99 @@ pub fn parse_tool_arguments(arguments: &str) -> Result String { + let json = strip_trailing_incomplete_escape(partial_json); + partial_json_fixer::fix_json(json) +} + +fn strip_trailing_incomplete_escape(json: &str) -> &str { + let trailing_backslashes = json + .as_bytes() + .iter() + .rev() + .take_while(|&&b| b == b'\\') + .count(); + if trailing_backslashes % 2 == 1 { + &json[..json.len() - 1] + } else { + json + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_fix_streamed_json_strips_incomplete_escape() { + // Trailing `\` inside a string — incomplete escape sequence + let fixed = fix_streamed_json(r#"{"text": "hello\"#); + let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json"); + assert_eq!(parsed["text"], "hello"); + } + + #[test] + fn test_fix_streamed_json_preserves_complete_escape() { + // `\\` is a complete escape (literal backslash) + let fixed = fix_streamed_json(r#"{"text": "hello\\"#); + let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json"); + assert_eq!(parsed["text"], "hello\\"); + } + + #[test] + fn test_fix_streamed_json_strips_escape_after_complete_escape() { + // `\\\` = complete `\\` (literal backslash) + incomplete `\` + let fixed = fix_streamed_json(r#"{"text": "hello\\\"#); + let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json"); + assert_eq!(parsed["text"], "hello\\"); + } + + #[test] + fn test_fix_streamed_json_no_escape_at_end() { + let fixed = fix_streamed_json(r#"{"text": "hello"#); + let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json"); + assert_eq!(parsed["text"], "hello"); + } + + #[test] + fn test_fix_streamed_json_newline_escape_boundary() { + // Simulates a stream boundary landing between `\` and `n` + let fixed = fix_streamed_json(r#"{"text": "line1\"#); + let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json"); + assert_eq!(parsed["text"], "line1"); + + // Next chunk completes the escape + let fixed = fix_streamed_json(r#"{"text": "line1\nline2"#); + let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json"); + assert_eq!(parsed["text"], "line1\nline2"); + } + + #[test] + fn test_fix_streamed_json_incremental_delta_correctness() { + // This is the actual scenario that causes the bug: + // chunk 1 ends mid-escape, chunk 2 completes it. + let chunk1 = r#"{"replacement_text": "fn foo() {\"#; + let fixed1 = fix_streamed_json(chunk1); + let parsed1: serde_json::Value = serde_json::from_str(&fixed1).expect("valid json"); + let text1 = parsed1["replacement_text"].as_str().expect("string"); + assert_eq!(text1, "fn foo() {"); + + let chunk2 = r#"{"replacement_text": "fn foo() {\n return bar;\n}"}"#; + let fixed2 = fix_streamed_json(chunk2); + let parsed2: serde_json::Value = serde_json::from_str(&fixed2).expect("valid json"); + let text2 = parsed2["replacement_text"].as_str().expect("string"); + assert_eq!(text2, "fn foo() {\n return bar;\n}"); + + // The delta should be the newline + rest, with no spurious backslash + let delta = &text2[text1.len()..]; + assert_eq!(delta, "\n return bar;\n}"); + } +}