util.rs

  1use std::str::FromStr;
  2
  3/// Parses tool call arguments JSON, treating empty strings as empty objects.
  4///
  5/// Many LLM providers return empty strings for tool calls with no arguments.
  6/// This helper normalizes that behavior by converting empty strings to `{}`.
  7pub fn parse_tool_arguments(arguments: &str) -> Result<serde_json::Value, serde_json::Error> {
  8    if arguments.is_empty() {
  9        Ok(serde_json::Value::Object(Default::default()))
 10    } else {
 11        serde_json::Value::from_str(arguments)
 12    }
 13}
 14
 15/// `partial_json_fixer::fix_json` converts a trailing `\` inside a string into `\\`
 16/// (a literal backslash). When used for incremental parsing (comparing successive
 17/// parses to extract deltas), this produces a spurious backslash character that
 18/// doesn't exist in the final text, corrupting the output.
 19///
 20/// This function strips any trailing incomplete escape sequence before fixing,
 21/// so each intermediate parse produces a true prefix of the final string value.
 22pub fn fix_streamed_json(partial_json: &str) -> String {
 23    let json = strip_trailing_incomplete_escape(partial_json);
 24    partial_json_fixer::fix_json(json)
 25}
 26
 27fn strip_trailing_incomplete_escape(json: &str) -> &str {
 28    let trailing_backslashes = json
 29        .as_bytes()
 30        .iter()
 31        .rev()
 32        .take_while(|&&b| b == b'\\')
 33        .count();
 34    if trailing_backslashes % 2 == 1 {
 35        &json[..json.len() - 1]
 36    } else {
 37        json
 38    }
 39}
 40
 41/// Parses a "prompt is too long: N tokens ..." message and extracts the token count.
 42pub fn parse_prompt_too_long(message: &str) -> Option<u64> {
 43    message
 44        .strip_prefix("prompt is too long: ")?
 45        .split_once(" tokens")?
 46        .0
 47        .parse()
 48        .ok()
 49}
 50
 51#[cfg(test)]
 52mod tests {
 53    use super::*;
 54
 55    #[test]
 56    fn test_fix_streamed_json_strips_incomplete_escape() {
 57        let fixed = fix_streamed_json(r#"{"text": "hello\"#);
 58        let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json");
 59        assert_eq!(parsed["text"], "hello");
 60    }
 61
 62    #[test]
 63    fn test_fix_streamed_json_preserves_complete_escape() {
 64        let fixed = fix_streamed_json(r#"{"text": "hello\\"#);
 65        let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json");
 66        assert_eq!(parsed["text"], "hello\\");
 67    }
 68
 69    #[test]
 70    fn test_fix_streamed_json_strips_escape_after_complete_escape() {
 71        let fixed = fix_streamed_json(r#"{"text": "hello\\\"#);
 72        let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json");
 73        assert_eq!(parsed["text"], "hello\\");
 74    }
 75
 76    #[test]
 77    fn test_fix_streamed_json_no_escape_at_end() {
 78        let fixed = fix_streamed_json(r#"{"text": "hello"#);
 79        let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json");
 80        assert_eq!(parsed["text"], "hello");
 81    }
 82
 83    #[test]
 84    fn test_fix_streamed_json_newline_escape_boundary() {
 85        let fixed = fix_streamed_json(r#"{"text": "line1\"#);
 86        let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json");
 87        assert_eq!(parsed["text"], "line1");
 88
 89        let fixed = fix_streamed_json(r#"{"text": "line1\nline2"#);
 90        let parsed: serde_json::Value = serde_json::from_str(&fixed).expect("valid json");
 91        assert_eq!(parsed["text"], "line1\nline2");
 92    }
 93
 94    #[test]
 95    fn test_fix_streamed_json_incremental_delta_correctness() {
 96        let chunk1 = r#"{"replacement_text": "fn foo() {\"#;
 97        let fixed1 = fix_streamed_json(chunk1);
 98        let parsed1: serde_json::Value = serde_json::from_str(&fixed1).expect("valid json");
 99        let text1 = parsed1["replacement_text"].as_str().expect("string");
100        assert_eq!(text1, "fn foo() {");
101
102        let chunk2 = r#"{"replacement_text": "fn foo() {\n    return bar;\n}"}"#;
103        let fixed2 = fix_streamed_json(chunk2);
104        let parsed2: serde_json::Value = serde_json::from_str(&fixed2).expect("valid json");
105        let text2 = parsed2["replacement_text"].as_str().expect("string");
106        assert_eq!(text2, "fn foo() {\n    return bar;\n}");
107
108        let delta = &text2[text1.len()..];
109        assert_eq!(delta, "\n    return bar;\n}");
110    }
111}