Introduce LanguageModelToolUse::raw_input (#29322)

Nathan Sobo created

This is to enable alternative streaming solutions at the application
layer. I'm not sure we really should have performed parsing of the input
at this layer. Either way I want to experiment with streaming approaches
in a separate crate on a branch, and this will help.

/cc @maxdeviant @bennetbo @rtfeldman

Closes #ISSUE

Release Notes:

- N/A

Change summary

crates/agent/src/tool_use.rs                        | 1 +
crates/google_ai/src/google_ai.rs                   | 1 +
crates/language_model/src/language_model.rs         | 1 +
crates/language_models/src/provider/anthropic.rs    | 2 ++
crates/language_models/src/provider/bedrock.rs      | 1 +
crates/language_models/src/provider/copilot_chat.rs | 1 +
crates/language_models/src/provider/google.rs       | 5 +++++
crates/language_models/src/provider/open_ai.rs      | 1 +
8 files changed, 13 insertions(+)

Detailed changes

crates/agent/src/tool_use.rs 🔗

@@ -68,6 +68,7 @@ impl ToolUseState {
                             .map(|tool_use| LanguageModelToolUse {
                                 id: tool_use.id.clone(),
                                 name: tool_use.name.clone().into(),
+                                raw_input: tool_use.input.to_string(),
                                 input: tool_use.input.clone(),
                                 is_input_complete: true,
                             })

crates/google_ai/src/google_ai.rs 🔗

@@ -338,6 +338,7 @@ pub struct CountTokensResponse {
 #[derive(Debug, Serialize, Deserialize)]
 pub struct FunctionCall {
     pub name: String,
+    pub raw_args: String,
     pub args: serde_json::Value,
 }
 

crates/language_model/src/language_model.rs 🔗

@@ -186,6 +186,7 @@ where
 pub struct LanguageModelToolUse {
     pub id: LanguageModelToolUseId,
     pub name: Arc<str>,
+    pub raw_input: String,
     pub input: serde_json::Value,
     pub is_input_complete: bool,
 }

crates/language_models/src/provider/anthropic.rs 🔗

@@ -727,6 +727,7 @@ pub fn map_to_language_model_completion_events(
                                                     id: tool_use.id.clone().into(),
                                                     name: tool_use.name.clone().into(),
                                                     is_input_complete: false,
+                                                    raw_input: tool_use.input_json.clone(),
                                                     input,
                                                 },
                                             ))],
@@ -757,6 +758,7 @@ pub fn map_to_language_model_completion_events(
                                                     )
                                                     .map_err(|err| anyhow!("Error parsing tool call input JSON: {err:?} - JSON string was: {input_json:?}"))?
                                                 },
+                                                raw_input: tool_use.input_json.clone(),
                                             },
                                         ))
                                     })],

crates/language_models/src/provider/bedrock.rs 🔗

@@ -894,6 +894,7 @@ pub fn map_to_language_model_completion_events(
                                                 id: tool_use.id.into(),
                                                 name: tool_use.name.into(),
                                                 is_input_complete: true,
+                                                raw_input: tool_use.input_json.clone(),
                                                 input: if tool_use.input_json.is_empty() {
                                                     Value::Null
                                                 } else {

crates/language_models/src/provider/copilot_chat.rs 🔗

@@ -368,6 +368,7 @@ pub fn map_to_language_model_completion_events(
                                                     id: tool_call.id.into(),
                                                     name: tool_call.name.as_str().into(),
                                                     is_input_complete: true,
+                                                    raw_input: tool_call.arguments.clone(),
                                                     input: serde_json::Value::from_str(
                                                         &tool_call.arguments,
                                                     )?,

crates/language_models/src/provider/google.rs 🔗

@@ -396,6 +396,7 @@ pub fn into_google(
                     Some(Part::FunctionCallPart(google_ai::FunctionCallPart {
                         function_call: google_ai::FunctionCall {
                             name: tool_use.name.to_string(),
+                            raw_args: tool_use.raw_input,
                             args: tool_use.input,
                         },
                     }))
@@ -537,6 +538,10 @@ pub fn map_to_language_model_completion_events(
                                                     id,
                                                     name,
                                                     is_input_complete: true,
+                                                    raw_input: function_call_part
+                                                        .function_call
+                                                        .raw_args
+                                                        .clone(),
                                                     input: function_call_part.function_call.args,
                                                 },
                                             )));

crates/language_models/src/provider/open_ai.rs 🔗

@@ -491,6 +491,7 @@ pub fn map_to_language_model_completion_events(
                                                     id: tool_call.id.into(),
                                                     name: tool_call.name.as_str().into(),
                                                     is_input_complete: true,
+                                                    raw_input: tool_call.arguments.clone(),
                                                     input: serde_json::Value::from_str(
                                                         &tool_call.arguments,
                                                     )?,