util.rs

  1use std::str::FromStr;
  2
  3/// Parses tool call arguments JSON, treating empty strings as empty objects.
  4///
  5/// Many LLM providers return empty strings for tool calls with no arguments.
  6/// This helper normalizes that behavior by converting empty strings to `{}`.
  7pub fn parse_tool_arguments(arguments: &str) -> Result<serde_json::Value, serde_json::Error> {
  8    if arguments.is_empty() {
  9        Ok(serde_json::Value::Object(Default::default()))
 10    } else {
 11        serde_json::Value::from_str(arguments)
 12    }
 13}
 14
 15/// `partial_json_fixer::fix_json` converts a trailing `\` inside a string into `\\`
 16/// (a literal backslash). When used for incremental parsing (comparing successive
 17/// parses to extract deltas), this produces a spurious backslash character that
 18/// doesn't exist in the final text, corrupting the output.
 19///
 20/// This function strips any trailing incomplete escape sequence before fixing,
 21/// so each intermediate parse produces a true prefix of the final string value.
 22pub fn fix_streamed_json(partial_json: &str) -> String {
 23    let json = strip_trailing_incomplete_escape(partial_json);
 24    partial_json_fixer::fix_json(json)
 25}
 26
 27fn strip_trailing_incomplete_escape(json: &str) -> &str {
 28    let trailing_backslashes = json
 29        .as_bytes()
 30        .iter()
 31        .rev()
 32        .take_while(|&&b| b == b'\\')
 33        .count();
 34    if trailing_backslashes % 2 == 1 {
 35        &json[..json.len() - 1]
 36    } else {
 37        json
 38    }
 39}
 40
 41#[cfg(test)]
 42mod tests {
 43    use super::*;
 44
 45    #[test]
 46    fn test_fix_streamed_json_strips_incomplete_escape() {
 47        // Trailing `\` inside a string — incomplete escape sequence
 48        let fixed = fix_streamed_json(r#"{"text": "hello\"#);
 49        let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json");
 50        assert_eq!(parsed["text"], "hello");
 51    }
 52
 53    #[test]
 54    fn test_fix_streamed_json_preserves_complete_escape() {
 55        // `\\` is a complete escape (literal backslash)
 56        let fixed = fix_streamed_json(r#"{"text": "hello\\"#);
 57        let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json");
 58        assert_eq!(parsed["text"], "hello\\");
 59    }
 60
 61    #[test]
 62    fn test_fix_streamed_json_strips_escape_after_complete_escape() {
 63        // `\\\` = complete `\\` (literal backslash) + incomplete `\`
 64        let fixed = fix_streamed_json(r#"{"text": "hello\\\"#);
 65        let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json");
 66        assert_eq!(parsed["text"], "hello\\");
 67    }
 68
 69    #[test]
 70    fn test_fix_streamed_json_no_escape_at_end() {
 71        let fixed = fix_streamed_json(r#"{"text": "hello"#);
 72        let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json");
 73        assert_eq!(parsed["text"], "hello");
 74    }
 75
 76    #[test]
 77    fn test_fix_streamed_json_newline_escape_boundary() {
 78        // Simulates a stream boundary landing between `\` and `n`
 79        let fixed = fix_streamed_json(r#"{"text": "line1\"#);
 80        let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json");
 81        assert_eq!(parsed["text"], "line1");
 82
 83        // Next chunk completes the escape
 84        let fixed = fix_streamed_json(r#"{"text": "line1\nline2"#);
 85        let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json");
 86        assert_eq!(parsed["text"], "line1\nline2");
 87    }
 88
 89    #[test]
 90    fn test_fix_streamed_json_incremental_delta_correctness() {
 91        // This is the actual scenario that causes the bug:
 92        // chunk 1 ends mid-escape, chunk 2 completes it.
 93        let chunk1 = r#"{"replacement_text": "fn foo() {\"#;
 94        let fixed1 = fix_streamed_json(chunk1);
 95        let parsed1: serde_json::Value = serde_json::from_str(&fixed1).expect("valid json");
 96        let text1 = parsed1["replacement_text"].as_str().expect("string");
 97        assert_eq!(text1, "fn foo() {");
 98
 99        let chunk2 = r#"{"replacement_text": "fn foo() {\n    return bar;\n}"}"#;
100        let fixed2 = fix_streamed_json(chunk2);
101        let parsed2: serde_json::Value = serde_json::from_str(&fixed2).expect("valid json");
102        let text2 = parsed2["replacement_text"].as_str().expect("string");
103        assert_eq!(text2, "fn foo() {\n    return bar;\n}");
104
105        // The delta should be the newline + rest, with no spurious backslash
106        let delta = &text2[text1.len()..];
107        assert_eq!(delta, "\n    return bar;\n}");
108    }
109}