From fdf144fb7292e23d4360afbb2b1729a2549a700b Mon Sep 17 00:00:00 2001 From: Finn Eitreim <48069764+feitreim@users.noreply.github.com> Date: Fri, 20 Mar 2026 03:09:19 -0400 Subject: [PATCH] language_models: Fix the partial json streaming to not blast `\` everywhere (#51976) ## Context This PR fixes one of the issues in #51905, where model outputs are full of errant `\` characters. Here's the problem: As the response is streamed back to Zed, we accumulate the message chunks and need to convert those chunks to valid JSON. To do that we use `partial_json_fixer::fix_json`. When the last character of a chunk is `\`, `fix_json` has to escape that backslash, because it's inside of a string (if it isn't, it's invalid JSON and the tool call will crash) and otherwise you would end up escaping the closing `"` and everything would be messed up. Why is this a problem for Zed: T_0 is the output at some step. T_1 is the output at the next step. The `fix_json` system is meant to be used by replacing T_0 with T_1; however, in the editor, replacing the entirety of T_0 with T_1 would be slow, cause flickering, etc., so we calculate the difference between T_0 and T_1 and just add it to the current buffer state. So when a chunk ends on `\`, we end up with something like `... end of line\\"}` at the end of T_0; in T_1, this becomes `... end of line\n ...`. Then when we add the new chunk from T_1, it just picks up after the `\n` because it's tracking the length to manage the deltas, so the stray `\` that T_0 put in the buffer is never replaced. ## How to Review util.rs: fix_streamed_json => strip a trailing unpaired backslash (an incomplete escape sequence) from the incoming JSON stream so that `partial_json_fixer::fix_json` doesn't try to escape it. Other files: call fix_streamed_json instead of calling `partial_json_fixer::fix_json` directly before passing the result to `serde_json`. I had Claude write a bunch of tests while I was working on the fix, which I have kept in for now, but the end functionality of fix_streamed_json is really simple now, so maybe these aren't really needed. 
## Videos Behavior Before: https://github.com/user-attachments/assets/f23f5579-b2e1-4d71-9e24-f15ea831de52 Behavior After: https://github.com/user-attachments/assets/40acdc23-4522-4621-be28-895965f4f262 ## Self-Review Checklist - [x] I've reviewed my own diff for quality, security, and reliability - [x] Unsafe blocks (if any) have justifying comments - [x] The content is consistent with the [UI/UX checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) - [x] Tests cover the new/changed behavior - [x] Performance impact has been considered and is acceptable Release Notes: - language_models: fixed partial json streaming --- crates/language_models/Cargo.toml | 1 - .../language_models/src/provider/anthropic.rs | 8 +- .../language_models/src/provider/bedrock.rs | 4 +- .../src/provider/copilot_chat.rs | 4 +- .../language_models/src/provider/deepseek.rs | 4 +- .../language_models/src/provider/mistral.rs | 4 +- .../language_models/src/provider/open_ai.rs | 6 +- .../src/provider/open_router.rs | 4 +- crates/language_models/src/provider/util.rs | 96 +++++++++++++++++++ 9 files changed, 113 insertions(+), 18 deletions(-) diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml index f9dc4266d69ae9164f6b187162ed32069de5c10c..911dfb813ac54d89e764b3d62c50b4411cf8ba9c 100644 --- a/crates/language_models/Cargo.toml +++ b/crates/language_models/Cargo.toml @@ -67,7 +67,6 @@ vercel = { workspace = true, features = ["schemars"] } x_ai = { workspace = true, features = ["schemars"] } [dev-dependencies] - language_model = { workspace = true, features = ["test-support"] } pretty_assertions.workspace = true diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index cd60935b59c6f3f1c15ebd9a91b5683639408618..1fd79fb3a93d978d0912abbc4f0688e0bbe846e6 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -24,7 
+24,7 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*}; use ui_input::InputField; use util::ResultExt; -use crate::provider::util::parse_tool_arguments; +use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; pub use settings::AnthropicAvailableModel as AvailableModel; @@ -873,9 +873,9 @@ impl AnthropicEventMapper { // valid JSON that serde can accept, e.g. by closing // unclosed delimiters. This way, we can update the // UI with whatever has been streamed back so far. - if let Ok(input) = serde_json::Value::from_str( - &partial_json_fixer::fix_json(&tool_use.input_json), - ) { + if let Ok(input) = + serde_json::Value::from_str(&fix_streamed_json(&tool_use.input_json)) + { return vec![Ok(LanguageModelCompletionEvent::ToolUse( LanguageModelToolUse { id: tool_use.id.clone().into(), diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index b05e9159cee443662c82153f205c4600afe0de34..734e97ee335c4106fced9d334d31b5ed5b86d407 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -48,7 +48,7 @@ use ui_input::InputField; use util::ResultExt; use crate::AllLanguageModelSettings; -use crate::provider::util::parse_tool_arguments; +use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; actions!(bedrock, [Tab, TabPrev]); @@ -1244,7 +1244,7 @@ pub fn map_to_language_model_completion_events( { tool_use.input_json.push_str(tool_output.input()); if let Ok(input) = serde_json::from_str::( - &partial_json_fixer::fix_json(&tool_use.input_json), + &fix_streamed_json(&tool_use.input_json), ) { Some(Ok(LanguageModelCompletionEvent::ToolUse( LanguageModelToolUse { diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index 286eb872795642be47dfd46f16e561dcd53f93dc..7063db83bf65b82a4f314ad97e9463b106400c0b 100644 --- 
a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -33,7 +33,7 @@ use ui::prelude::*; use util::debug_panic; use crate::provider::anthropic::{AnthropicEventMapper, into_anthropic}; -use crate::provider::util::parse_tool_arguments; +use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("copilot_chat"); const PROVIDER_NAME: LanguageModelProviderName = @@ -579,7 +579,7 @@ pub fn map_to_language_model_completion_events( if !entry.id.is_empty() && !entry.name.is_empty() { if let Ok(input) = serde_json::from_str::( - &partial_json_fixer::fix_json(&entry.arguments), + &fix_streamed_json(&entry.arguments), ) { events.push(Ok(LanguageModelCompletionEvent::ToolUse( LanguageModelToolUse { diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs index 0bf86ef15c91b16dbc496ff732b087fedd0da0a9..e27bd510dbb0b0f518e615e31fc194675a5c3cfe 100644 --- a/crates/language_models/src/provider/deepseek.rs +++ b/crates/language_models/src/provider/deepseek.rs @@ -22,7 +22,7 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*}; use ui_input::InputField; use util::ResultExt; -use crate::provider::util::parse_tool_arguments; +use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("deepseek"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("DeepSeek"); @@ -476,7 +476,7 @@ impl DeepSeekEventMapper { if !entry.id.is_empty() && !entry.name.is_empty() { if let Ok(input) = serde_json::from_str::( - &partial_json_fixer::fix_json(&entry.arguments), + &fix_streamed_json(&entry.arguments), ) { events.push(Ok(LanguageModelCompletionEvent::ToolUse( LanguageModelToolUse { diff --git a/crates/language_models/src/provider/mistral.rs 
b/crates/language_models/src/provider/mistral.rs index 338931cf7ca902225e10a7d09c9e7528128f1491..72f0cae2993da4efb3e19cb19ec42b186290920d 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -22,7 +22,7 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*}; use ui_input::InputField; use util::ResultExt; -use crate::provider::util::parse_tool_arguments; +use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("mistral"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Mistral"); @@ -647,7 +647,7 @@ impl MistralEventMapper { if !entry.id.is_empty() && !entry.name.is_empty() { if let Ok(input) = serde_json::from_str::( - &partial_json_fixer::fix_json(&entry.arguments), + &fix_streamed_json(&entry.arguments), ) { events.push(Ok(LanguageModelCompletionEvent::ToolUse( LanguageModelToolUse { diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index 4ea7c8f49ce7f745e1aa108062cd0bb4def08097..2548a6b26f39dbb67add7262fc4b2796c1d8306f 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -33,7 +33,7 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*}; use ui_input::InputField; use util::ResultExt; -use crate::provider::util::parse_tool_arguments; +use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; const PROVIDER_ID: LanguageModelProviderId = language_model::OPEN_AI_PROVIDER_ID; const PROVIDER_NAME: LanguageModelProviderName = language_model::OPEN_AI_PROVIDER_NAME; @@ -836,7 +836,7 @@ impl OpenAiEventMapper { if !entry.id.is_empty() && !entry.name.is_empty() { if let Ok(input) = serde_json::from_str::( - &partial_json_fixer::fix_json(&entry.arguments), + &fix_streamed_json(&entry.arguments), ) { 
events.push(Ok(LanguageModelCompletionEvent::ToolUse( LanguageModelToolUse { @@ -991,7 +991,7 @@ impl OpenAiResponseEventMapper { if let Some(entry) = self.function_calls_by_item.get_mut(&item_id) { entry.arguments.push_str(&delta); if let Ok(input) = serde_json::from_str::( - &partial_json_fixer::fix_json(&entry.arguments), + &fix_streamed_json(&entry.arguments), ) { return vec![Ok(LanguageModelCompletionEvent::ToolUse( LanguageModelToolUse { diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs index e0e56bc1beadd8309a4c1b3c7626efa99c1c6473..a4a679be73c0276351a6524ad7e8fc40e2c26860 100644 --- a/crates/language_models/src/provider/open_router.rs +++ b/crates/language_models/src/provider/open_router.rs @@ -21,7 +21,7 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*}; use ui_input::InputField; use util::ResultExt; -use crate::provider::util::parse_tool_arguments; +use crate::provider::util::{fix_streamed_json, parse_tool_arguments}; const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("openrouter"); const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("OpenRouter"); @@ -657,7 +657,7 @@ impl OpenRouterEventMapper { if !entry.id.is_empty() && !entry.name.is_empty() { if let Ok(input) = serde_json::from_str::( - &partial_json_fixer::fix_json(&entry.arguments), + &fix_streamed_json(&entry.arguments), ) { events.push(Ok(LanguageModelCompletionEvent::ToolUse( LanguageModelToolUse { diff --git a/crates/language_models/src/provider/util.rs b/crates/language_models/src/provider/util.rs index 6b1cf7afbb7e3a068dabbc6787c322649d50393d..76a02b6de40a3e36c7c506f11a6f6d34d2aaca3e 100644 --- a/crates/language_models/src/provider/util.rs +++ b/crates/language_models/src/provider/util.rs @@ -11,3 +11,99 @@ pub fn parse_tool_arguments(arguments: &str) -> Result String { + let json = strip_trailing_incomplete_escape(partial_json); + 
partial_json_fixer::fix_json(json) +} + +fn strip_trailing_incomplete_escape(json: &str) -> &str { + let trailing_backslashes = json + .as_bytes() + .iter() + .rev() + .take_while(|&&b| b == b'\\') + .count(); + if trailing_backslashes % 2 == 1 { + &json[..json.len() - 1] + } else { + json + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_fix_streamed_json_strips_incomplete_escape() { + // Trailing `\` inside a string — incomplete escape sequence + let fixed = fix_streamed_json(r#"{"text": "hello\"#); + let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json"); + assert_eq!(parsed["text"], "hello"); + } + + #[test] + fn test_fix_streamed_json_preserves_complete_escape() { + // `\\` is a complete escape (literal backslash) + let fixed = fix_streamed_json(r#"{"text": "hello\\"#); + let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json"); + assert_eq!(parsed["text"], "hello\\"); + } + + #[test] + fn test_fix_streamed_json_strips_escape_after_complete_escape() { + // `\\\` = complete `\\` (literal backslash) + incomplete `\` + let fixed = fix_streamed_json(r#"{"text": "hello\\\"#); + let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json"); + assert_eq!(parsed["text"], "hello\\"); + } + + #[test] + fn test_fix_streamed_json_no_escape_at_end() { + let fixed = fix_streamed_json(r#"{"text": "hello"#); + let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json"); + assert_eq!(parsed["text"], "hello"); + } + + #[test] + fn test_fix_streamed_json_newline_escape_boundary() { + // Simulates a stream boundary landing between `\` and `n` + let fixed = fix_streamed_json(r#"{"text": "line1\"#); + let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json"); + assert_eq!(parsed["text"], "line1"); + + // Next chunk completes the escape + let fixed = fix_streamed_json(r#"{"text": "line1\nline2"#); + let parsed: serde_json::Value = 
serde_json::from_str(&fixed).expect("valid json"); + assert_eq!(parsed["text"], "line1\nline2"); + } + + #[test] + fn test_fix_streamed_json_incremental_delta_correctness() { + // This is the actual scenario that causes the bug: + // chunk 1 ends mid-escape, chunk 2 completes it. + let chunk1 = r#"{"replacement_text": "fn foo() {\"#; + let fixed1 = fix_streamed_json(chunk1); + let parsed1: serde_json::Value = serde_json::from_str(&fixed1).expect("valid json"); + let text1 = parsed1["replacement_text"].as_str().expect("string"); + assert_eq!(text1, "fn foo() {"); + + let chunk2 = r#"{"replacement_text": "fn foo() {\n return bar;\n}"}"#; + let fixed2 = fix_streamed_json(chunk2); + let parsed2: serde_json::Value = serde_json::from_str(&fixed2).expect("valid json"); + let text2 = parsed2["replacement_text"].as_str().expect("string"); + assert_eq!(text2, "fn foo() {\n return bar;\n}"); + + // The delta should be the newline + rest, with no spurious backslash + let delta = &text2[text1.len()..]; + assert_eq!(delta, "\n return bar;\n}"); + } +}