diff --git a/crates/agent/src/db.rs b/crates/agent/src/db.rs index c72e20571e2761788157a5fd10df147c2b414e4a..84d080ff48107e7173226df81a419b90603d82fd 100644 --- a/crates/agent/src/db.rs +++ b/crates/agent/src/db.rs @@ -150,6 +150,7 @@ impl DbThread { .unwrap_or_default(), input: tool_use.input, is_input_complete: true, + thought_signature: None, }, )); } diff --git a/crates/agent/src/edit_agent/evals.rs b/crates/agent/src/edit_agent/evals.rs index 54aa6ae5c95022ee1ef022aed78d46533de356be..ddb9052b84b986229720efa89b9e912452411d86 100644 --- a/crates/agent/src/edit_agent/evals.rs +++ b/crates/agent/src/edit_agent/evals.rs @@ -1108,6 +1108,7 @@ fn tool_use( raw_input: serde_json::to_string_pretty(&input).unwrap(), input: serde_json::to_value(input).unwrap(), is_input_complete: true, + thought_signature: None, }) } diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs index 5d4bdce27cc05d1cf46a4b73821f0a97878fd6f4..ffc5dbc6d30e58b5d819c3778b063951b0ed0861 100644 --- a/crates/agent/src/tests/mod.rs +++ b/crates/agent/src/tests/mod.rs @@ -274,6 +274,7 @@ async fn test_prompt_caching(cx: &mut TestAppContext) { raw_input: json!({"text": "test"}).to_string(), input: json!({"text": "test"}), is_input_complete: true, + thought_signature: None, }; fake_model .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone())); @@ -461,6 +462,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) { raw_input: "{}".into(), input: json!({}), is_input_complete: true, + thought_signature: None, }, )); fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( @@ -470,6 +472,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) { raw_input: "{}".into(), input: json!({}), is_input_complete: true, + thought_signature: None, }, )); fake_model.end_last_completion_stream(); @@ -520,6 +523,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) { raw_input: "{}".into(), input: json!({}), 
is_input_complete: true, + thought_signature: None, }, )); fake_model.end_last_completion_stream(); @@ -554,6 +558,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) { raw_input: "{}".into(), input: json!({}), is_input_complete: true, + thought_signature: None, }, )); fake_model.end_last_completion_stream(); @@ -592,6 +597,7 @@ async fn test_tool_hallucination(cx: &mut TestAppContext) { raw_input: "{}".into(), input: json!({}), is_input_complete: true, + thought_signature: None, }, )); fake_model.end_last_completion_stream(); @@ -621,6 +627,7 @@ async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) { raw_input: "{}".into(), input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(), is_input_complete: true, + thought_signature: None, }; fake_model .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone())); @@ -731,6 +738,7 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) { raw_input: "{}".into(), input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(), is_input_complete: true, + thought_signature: None, }; let tool_result = LanguageModelToolResult { tool_use_id: "tool_id_1".into(), @@ -1037,6 +1045,7 @@ async fn test_mcp_tools(cx: &mut TestAppContext) { raw_input: json!({"text": "test"}).to_string(), input: json!({"text": "test"}), is_input_complete: true, + thought_signature: None, }, )); fake_model.end_last_completion_stream(); @@ -1080,6 +1089,7 @@ async fn test_mcp_tools(cx: &mut TestAppContext) { raw_input: json!({"text": "mcp"}).to_string(), input: json!({"text": "mcp"}), is_input_complete: true, + thought_signature: None, }, )); fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( @@ -1089,6 +1099,7 @@ async fn test_mcp_tools(cx: &mut TestAppContext) { raw_input: json!({"text": "native"}).to_string(), input: json!({"text": "native"}), is_input_complete: true, + thought_signature: None, }, )); 
fake_model.end_last_completion_stream(); @@ -1788,6 +1799,7 @@ async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) { raw_input: "{}".into(), input: json!({}), is_input_complete: true, + thought_signature: None, }; let echo_tool_use = LanguageModelToolUse { id: "tool_id_2".into(), @@ -1795,6 +1807,7 @@ async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) { raw_input: json!({"text": "test"}).to_string(), input: json!({"text": "test"}), is_input_complete: true, + thought_signature: None, }; fake_model.send_last_completion_stream_text_chunk("Hi!"); fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( @@ -2000,6 +2013,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) { raw_input: input.to_string(), input, is_input_complete: false, + thought_signature: None, }, )); @@ -2012,6 +2026,7 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) { raw_input: input.to_string(), input, is_input_complete: true, + thought_signature: None, }, )); fake_model.end_last_completion_stream(); @@ -2214,6 +2229,7 @@ async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) { raw_input: json!({"text": "test"}).to_string(), input: json!({"text": "test"}), is_input_complete: true, + thought_signature: None, }; fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse( tool_use_1.clone(), diff --git a/crates/google_ai/src/google_ai.rs b/crates/google_ai/src/google_ai.rs index 9b7e5ec8d1c42fc846d131cfd063de5bba8287ae..84f8e8ef8dbaac1d55f73515f625b670a4a52709 100644 --- a/crates/google_ai/src/google_ai.rs +++ b/crates/google_ai/src/google_ai.rs @@ -229,6 +229,10 @@ pub struct GenerativeContentBlob { #[serde(rename_all = "camelCase")] pub struct FunctionCallPart { pub function_call: FunctionCall, + /// Thought signature returned by the model for function calls. + /// Only present on the first function call in parallel call scenarios. 
+ #[serde(skip_serializing_if = "Option::is_none")] + pub thought_signature: Option<String>, } #[derive(Debug, Serialize, Deserialize)] @@ -636,3 +640,109 @@ impl std::fmt::Display for Model { write!(f, "{}", self.id()) } } + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn test_function_call_part_with_signature_serializes_correctly() { + let part = FunctionCallPart { + function_call: FunctionCall { + name: "test_function".to_string(), + args: json!({"arg": "value"}), + }, + thought_signature: Some("test_signature".to_string()), + }; + + let serialized = serde_json::to_value(&part).unwrap(); + + assert_eq!(serialized["functionCall"]["name"], "test_function"); + assert_eq!(serialized["functionCall"]["args"]["arg"], "value"); + assert_eq!(serialized["thoughtSignature"], "test_signature"); + } + + #[test] + fn test_function_call_part_without_signature_omits_field() { + let part = FunctionCallPart { + function_call: FunctionCall { + name: "test_function".to_string(), + args: json!({"arg": "value"}), + }, + thought_signature: None, + }; + + let serialized = serde_json::to_value(&part).unwrap(); + + assert_eq!(serialized["functionCall"]["name"], "test_function"); + assert_eq!(serialized["functionCall"]["args"]["arg"], "value"); + // thoughtSignature field should not be present when None + assert!(serialized.get("thoughtSignature").is_none()); + } + + #[test] + fn test_function_call_part_deserializes_with_signature() { + let json = json!({ + "functionCall": { + "name": "test_function", + "args": {"arg": "value"} + }, + "thoughtSignature": "test_signature" + }); + + let part: FunctionCallPart = serde_json::from_value(json).unwrap(); + + assert_eq!(part.function_call.name, "test_function"); + assert_eq!(part.thought_signature, Some("test_signature".to_string())); + } + + #[test] + fn test_function_call_part_deserializes_without_signature() { + let json = json!({ + "functionCall": { + "name": "test_function", + "args": {"arg": "value"} + } + }); + 
+ let part: FunctionCallPart = serde_json::from_value(json).unwrap(); + + assert_eq!(part.function_call.name, "test_function"); + assert_eq!(part.thought_signature, None); + } + + #[test] + fn test_function_call_part_round_trip() { + let original = FunctionCallPart { + function_call: FunctionCall { + name: "test_function".to_string(), + args: json!({"arg": "value", "nested": {"key": "val"}}), + }, + thought_signature: Some("round_trip_signature".to_string()), + }; + + let serialized = serde_json::to_value(&original).unwrap(); + let deserialized: FunctionCallPart = serde_json::from_value(serialized).unwrap(); + + assert_eq!(deserialized.function_call.name, original.function_call.name); + assert_eq!(deserialized.function_call.args, original.function_call.args); + assert_eq!(deserialized.thought_signature, original.thought_signature); + } + + #[test] + fn test_function_call_part_with_empty_signature_serializes() { + let part = FunctionCallPart { + function_call: FunctionCall { + name: "test_function".to_string(), + args: json!({"arg": "value"}), + }, + thought_signature: Some("".to_string()), + }; + + let serialized = serde_json::to_value(&part).unwrap(); + + // Empty string should still be serialized (normalization happens at a higher level) + assert_eq!(serialized["thoughtSignature"], ""); + } +} diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 606b0921b29f056ddea22947f08b2686af37d639..785bb0dbdc7b6bb82d052cce16eb1c4b2fd66a48 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -515,6 +515,9 @@ pub struct LanguageModelToolUse { pub raw_input: String, pub input: serde_json::Value, pub is_input_complete: bool, + /// Thought signature the model sent us. Some models require that this + /// signature be preserved and sent back in conversation history for validation. 
+ pub thought_signature: Option<String>, } pub struct LanguageModelTextStream { @@ -921,4 +924,85 @@ mod tests { ), } } + + #[test] + fn test_language_model_tool_use_serializes_with_signature() { + use serde_json::json; + + let tool_use = LanguageModelToolUse { + id: LanguageModelToolUseId::from("test_id"), + name: "test_tool".into(), + raw_input: json!({"arg": "value"}).to_string(), + input: json!({"arg": "value"}), + is_input_complete: true, + thought_signature: Some("test_signature".to_string()), + }; + + let serialized = serde_json::to_value(&tool_use).unwrap(); + + assert_eq!(serialized["id"], "test_id"); + assert_eq!(serialized["name"], "test_tool"); + assert_eq!(serialized["thought_signature"], "test_signature"); + } + + #[test] + fn test_language_model_tool_use_deserializes_with_missing_signature() { + use serde_json::json; + + let json = json!({ + "id": "test_id", + "name": "test_tool", + "raw_input": "{\"arg\":\"value\"}", + "input": {"arg": "value"}, + "is_input_complete": true + }); + + let tool_use: LanguageModelToolUse = serde_json::from_value(json).unwrap(); + + assert_eq!(tool_use.id, LanguageModelToolUseId::from("test_id")); + assert_eq!(tool_use.name.as_ref(), "test_tool"); + assert_eq!(tool_use.thought_signature, None); + } + + #[test] + fn test_language_model_tool_use_round_trip_with_signature() { + use serde_json::json; + + let original = LanguageModelToolUse { + id: LanguageModelToolUseId::from("round_trip_id"), + name: "round_trip_tool".into(), + raw_input: json!({"key": "value"}).to_string(), + input: json!({"key": "value"}), + is_input_complete: true, + thought_signature: Some("round_trip_sig".to_string()), + }; + + let serialized = serde_json::to_value(&original).unwrap(); + let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap(); + + assert_eq!(deserialized.id, original.id); + assert_eq!(deserialized.name, original.name); + assert_eq!(deserialized.thought_signature, original.thought_signature); + } + + #[test] + fn 
test_language_model_tool_use_round_trip_without_signature() { + use serde_json::json; + + let original = LanguageModelToolUse { + id: LanguageModelToolUseId::from("no_sig_id"), + name: "no_sig_tool".into(), + raw_input: json!({"key": "value"}).to_string(), + input: json!({"key": "value"}), + is_input_complete: true, + thought_signature: None, + }; + + let serialized = serde_json::to_value(&original).unwrap(); + let deserialized: LanguageModelToolUse = serde_json::from_value(serialized).unwrap(); + + assert_eq!(deserialized.id, original.id); + assert_eq!(deserialized.name, original.name); + assert_eq!(deserialized.thought_signature, None); + } } diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index 287c76fc6dfea530ce53b48178024ef185b98134..2491e8277a8b2632f6835af13736c23e94966c4c 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -711,6 +711,7 @@ impl AnthropicEventMapper { is_input_complete: false, raw_input: tool_use.input_json.clone(), input, + thought_signature: None, }, ))]; } @@ -734,6 +735,7 @@ impl AnthropicEventMapper { is_input_complete: true, input, raw_input: tool_use.input_json.clone(), + thought_signature: None, }, )), Err(json_parse_err) => { diff --git a/crates/language_models/src/provider/bedrock.rs b/crates/language_models/src/provider/bedrock.rs index 61f36428d2e69af013103c8ca06b38d8d4a96e8d..9672d61f90512be62ea58e77682d63cc8553710f 100644 --- a/crates/language_models/src/provider/bedrock.rs +++ b/crates/language_models/src/provider/bedrock.rs @@ -970,6 +970,7 @@ pub fn map_to_language_model_completion_events( is_input_complete: true, raw_input: tool_use.input_json, input, + thought_signature: None, }, )) }), diff --git a/crates/language_models/src/provider/copilot_chat.rs b/crates/language_models/src/provider/copilot_chat.rs index 2f2469fa770821c208e037665c02d9ea8c20408f..f62b899318ae56452509f8d9e7cca05f8859cf27 
100644 --- a/crates/language_models/src/provider/copilot_chat.rs +++ b/crates/language_models/src/provider/copilot_chat.rs @@ -458,6 +458,7 @@ pub fn map_to_language_model_completion_events( is_input_complete: true, input, raw_input: tool_call.arguments, + thought_signature: None, }, )), Err(error) => Ok( @@ -560,6 +561,7 @@ impl CopilotResponsesEventMapper { is_input_complete: true, input, raw_input: arguments.clone(), + thought_signature: None, }, ))), Err(error) => { diff --git a/crates/language_models/src/provider/deepseek.rs b/crates/language_models/src/provider/deepseek.rs index 1d573fd964d0f183393bb766c492566f622a4901..4bc7164f421bfbaa075c72faff7f731c0defcdba 100644 --- a/crates/language_models/src/provider/deepseek.rs +++ b/crates/language_models/src/provider/deepseek.rs @@ -501,6 +501,7 @@ impl DeepSeekEventMapper { is_input_complete: true, input, raw_input: tool_call.arguments.clone(), + thought_signature: None, }, )), Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index e33b118e30fca60e147bd2f311e844626da9b368..68b6f976418b2125027e5800527f73cc49e5a1bb 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -439,11 +439,15 @@ pub fn into_google( })] } language_model::MessageContent::ToolUse(tool_use) => { + // Normalize empty string signatures to None + let thought_signature = tool_use.thought_signature.filter(|s| !s.is_empty()); + vec![Part::FunctionCallPart(google_ai::FunctionCallPart { function_call: google_ai::FunctionCall { name: tool_use.name.to_string(), args: tool_use.input, }, + thought_signature, })] } language_model::MessageContent::ToolResult(tool_result) => { @@ -655,6 +659,11 @@ impl GoogleEventMapper { let id: LanguageModelToolUseId = format!("{}-{}", name, next_tool_id).into(); + // Normalize empty string signatures to None + let thought_signature = 
function_call_part + .thought_signature + .filter(|s| !s.is_empty()); + events.push(Ok(LanguageModelCompletionEvent::ToolUse( LanguageModelToolUse { id, @@ -662,6 +671,7 @@ impl GoogleEventMapper { is_input_complete: true, raw_input: function_call_part.function_call.args.to_string(), input: function_call_part.function_call.args, + thought_signature, }, ))); } @@ -891,3 +901,424 @@ impl Render for ConfigurationView { } } } + +#[cfg(test)] +mod tests { + use super::*; + use google_ai::{ + Content, FunctionCall, FunctionCallPart, GenerateContentCandidate, GenerateContentResponse, + Part, Role as GoogleRole, TextPart, + }; + use language_model::{LanguageModelToolUseId, MessageContent, Role}; + use serde_json::json; + + #[test] + fn test_function_call_with_signature_creates_tool_use_with_signature() { + let mut mapper = GoogleEventMapper::new(); + + let response = GenerateContentResponse { + candidates: Some(vec![GenerateContentCandidate { + index: Some(0), + content: Content { + parts: vec![Part::FunctionCallPart(FunctionCallPart { + function_call: FunctionCall { + name: "test_function".to_string(), + args: json!({"arg": "value"}), + }, + thought_signature: Some("test_signature_123".to_string()), + })], + role: GoogleRole::Model, + }, + finish_reason: None, + finish_message: None, + safety_ratings: None, + citation_metadata: None, + }]), + prompt_feedback: None, + usage_metadata: None, + }; + + let events = mapper.map_event(response); + + assert_eq!(events.len(), 2); // ToolUse event + Stop event + + if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { + assert_eq!(tool_use.name.as_ref(), "test_function"); + assert_eq!( + tool_use.thought_signature.as_deref(), + Some("test_signature_123") + ); + } else { + panic!("Expected ToolUse event"); + } + } + + #[test] + fn test_function_call_without_signature_has_none() { + let mut mapper = GoogleEventMapper::new(); + + let response = GenerateContentResponse { + candidates: 
Some(vec![GenerateContentCandidate { + index: Some(0), + content: Content { + parts: vec![Part::FunctionCallPart(FunctionCallPart { + function_call: FunctionCall { + name: "test_function".to_string(), + args: json!({"arg": "value"}), + }, + thought_signature: None, + })], + role: GoogleRole::Model, + }, + finish_reason: None, + finish_message: None, + safety_ratings: None, + citation_metadata: None, + }]), + prompt_feedback: None, + usage_metadata: None, + }; + + let events = mapper.map_event(response); + + if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { + assert_eq!(tool_use.thought_signature, None); + } else { + panic!("Expected ToolUse event"); + } + } + + #[test] + fn test_empty_string_signature_normalized_to_none() { + let mut mapper = GoogleEventMapper::new(); + + let response = GenerateContentResponse { + candidates: Some(vec![GenerateContentCandidate { + index: Some(0), + content: Content { + parts: vec![Part::FunctionCallPart(FunctionCallPart { + function_call: FunctionCall { + name: "test_function".to_string(), + args: json!({"arg": "value"}), + }, + thought_signature: Some("".to_string()), + })], + role: GoogleRole::Model, + }, + finish_reason: None, + finish_message: None, + safety_ratings: None, + citation_metadata: None, + }]), + prompt_feedback: None, + usage_metadata: None, + }; + + let events = mapper.map_event(response); + + if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { + assert_eq!(tool_use.thought_signature, None); + } else { + panic!("Expected ToolUse event"); + } + } + + #[test] + fn test_parallel_function_calls_preserve_signatures() { + let mut mapper = GoogleEventMapper::new(); + + let response = GenerateContentResponse { + candidates: Some(vec![GenerateContentCandidate { + index: Some(0), + content: Content { + parts: vec![ + Part::FunctionCallPart(FunctionCallPart { + function_call: FunctionCall { + name: "function_1".to_string(), + args: json!({"arg": "value1"}), + }, + 
thought_signature: Some("signature_1".to_string()), + }), + Part::FunctionCallPart(FunctionCallPart { + function_call: FunctionCall { + name: "function_2".to_string(), + args: json!({"arg": "value2"}), + }, + thought_signature: None, + }), + ], + role: GoogleRole::Model, + }, + finish_reason: None, + finish_message: None, + safety_ratings: None, + citation_metadata: None, + }]), + prompt_feedback: None, + usage_metadata: None, + }; + + let events = mapper.map_event(response); + + assert_eq!(events.len(), 3); // 2 ToolUse events + Stop event + + if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { + assert_eq!(tool_use.name.as_ref(), "function_1"); + assert_eq!(tool_use.thought_signature.as_deref(), Some("signature_1")); + } else { + panic!("Expected ToolUse event for function_1"); + } + + if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] { + assert_eq!(tool_use.name.as_ref(), "function_2"); + assert_eq!(tool_use.thought_signature, None); + } else { + panic!("Expected ToolUse event for function_2"); + } + } + + #[test] + fn test_tool_use_with_signature_converts_to_function_call_part() { + let tool_use = language_model::LanguageModelToolUse { + id: LanguageModelToolUseId::from("test_id"), + name: "test_function".into(), + raw_input: json!({"arg": "value"}).to_string(), + input: json!({"arg": "value"}), + is_input_complete: true, + thought_signature: Some("test_signature_456".to_string()), + }; + + let request = super::into_google( + LanguageModelRequest { + messages: vec![language_model::LanguageModelRequestMessage { + role: Role::Assistant, + content: vec![MessageContent::ToolUse(tool_use)], + cache: false, + }], + ..Default::default() + }, + "gemini-2.5-flash".to_string(), + GoogleModelMode::Default, + ); + + assert_eq!(request.contents[0].parts.len(), 1); + if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] { + assert_eq!(fc_part.function_call.name, "test_function"); + assert_eq!( + 
fc_part.thought_signature.as_deref(), + Some("test_signature_456") + ); + } else { + panic!("Expected FunctionCallPart"); + } + } + + #[test] + fn test_tool_use_without_signature_omits_field() { + let tool_use = language_model::LanguageModelToolUse { + id: LanguageModelToolUseId::from("test_id"), + name: "test_function".into(), + raw_input: json!({"arg": "value"}).to_string(), + input: json!({"arg": "value"}), + is_input_complete: true, + thought_signature: None, + }; + + let request = super::into_google( + LanguageModelRequest { + messages: vec![language_model::LanguageModelRequestMessage { + role: Role::Assistant, + content: vec![MessageContent::ToolUse(tool_use)], + cache: false, + }], + ..Default::default() + }, + "gemini-2.5-flash".to_string(), + GoogleModelMode::Default, + ); + + assert_eq!(request.contents[0].parts.len(), 1); + if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] { + assert_eq!(fc_part.thought_signature, None); + } else { + panic!("Expected FunctionCallPart"); + } + } + + #[test] + fn test_empty_signature_in_tool_use_normalized_to_none() { + let tool_use = language_model::LanguageModelToolUse { + id: LanguageModelToolUseId::from("test_id"), + name: "test_function".into(), + raw_input: json!({"arg": "value"}).to_string(), + input: json!({"arg": "value"}), + is_input_complete: true, + thought_signature: Some("".to_string()), + }; + + let request = super::into_google( + LanguageModelRequest { + messages: vec![language_model::LanguageModelRequestMessage { + role: Role::Assistant, + content: vec![MessageContent::ToolUse(tool_use)], + cache: false, + }], + ..Default::default() + }, + "gemini-2.5-flash".to_string(), + GoogleModelMode::Default, + ); + + if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] { + assert_eq!(fc_part.thought_signature, None); + } else { + panic!("Expected FunctionCallPart"); + } + } + + #[test] + fn test_round_trip_preserves_signature() { + let mut mapper = GoogleEventMapper::new(); + + 
// Simulate receiving a response from Google with a signature + let response = GenerateContentResponse { + candidates: Some(vec![GenerateContentCandidate { + index: Some(0), + content: Content { + parts: vec![Part::FunctionCallPart(FunctionCallPart { + function_call: FunctionCall { + name: "test_function".to_string(), + args: json!({"arg": "value"}), + }, + thought_signature: Some("round_trip_sig".to_string()), + })], + role: GoogleRole::Model, + }, + finish_reason: None, + finish_message: None, + safety_ratings: None, + citation_metadata: None, + }]), + prompt_feedback: None, + usage_metadata: None, + }; + + let events = mapper.map_event(response); + + let tool_use = if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { + tool_use.clone() + } else { + panic!("Expected ToolUse event"); + }; + + // Convert back to Google format + let request = super::into_google( + LanguageModelRequest { + messages: vec![language_model::LanguageModelRequestMessage { + role: Role::Assistant, + content: vec![MessageContent::ToolUse(tool_use)], + cache: false, + }], + ..Default::default() + }, + "gemini-2.5-flash".to_string(), + GoogleModelMode::Default, + ); + + // Verify signature is preserved + if let Part::FunctionCallPart(fc_part) = &request.contents[0].parts[0] { + assert_eq!(fc_part.thought_signature.as_deref(), Some("round_trip_sig")); + } else { + panic!("Expected FunctionCallPart"); + } + } + + #[test] + fn test_mixed_text_and_function_call_with_signature() { + let mut mapper = GoogleEventMapper::new(); + + let response = GenerateContentResponse { + candidates: Some(vec![GenerateContentCandidate { + index: Some(0), + content: Content { + parts: vec![ + Part::TextPart(TextPart { + text: "I'll help with that.".to_string(), + }), + Part::FunctionCallPart(FunctionCallPart { + function_call: FunctionCall { + name: "helper_function".to_string(), + args: json!({"query": "help"}), + }, + thought_signature: Some("mixed_sig".to_string()), + }), + ], + role: 
GoogleRole::Model, + }, + finish_reason: None, + finish_message: None, + safety_ratings: None, + citation_metadata: None, + }]), + prompt_feedback: None, + usage_metadata: None, + }; + + let events = mapper.map_event(response); + + assert_eq!(events.len(), 3); // Text event + ToolUse event + Stop event + + if let Ok(LanguageModelCompletionEvent::Text(text)) = &events[0] { + assert_eq!(text, "I'll help with that."); + } else { + panic!("Expected Text event"); + } + + if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[1] { + assert_eq!(tool_use.name.as_ref(), "helper_function"); + assert_eq!(tool_use.thought_signature.as_deref(), Some("mixed_sig")); + } else { + panic!("Expected ToolUse event"); + } + } + + #[test] + fn test_special_characters_in_signature_preserved() { + let mut mapper = GoogleEventMapper::new(); + + let signature_with_special_chars = "sig<>\"'&%$#@!{}[]".to_string(); + + let response = GenerateContentResponse { + candidates: Some(vec![GenerateContentCandidate { + index: Some(0), + content: Content { + parts: vec![Part::FunctionCallPart(FunctionCallPart { + function_call: FunctionCall { + name: "test_function".to_string(), + args: json!({"arg": "value"}), + }, + thought_signature: Some(signature_with_special_chars.clone()), + })], + role: GoogleRole::Model, + }, + finish_reason: None, + finish_message: None, + safety_ratings: None, + citation_metadata: None, + }]), + prompt_feedback: None, + usage_metadata: None, + }; + + let events = mapper.map_event(response); + + if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] { + assert_eq!( + tool_use.thought_signature.as_deref(), + Some(signature_with_special_chars.as_str()) + ); + } else { + panic!("Expected ToolUse event"); + } + } +} diff --git a/crates/language_models/src/provider/lmstudio.rs b/crates/language_models/src/provider/lmstudio.rs index c0b3509c0e2c9636ca48cdb0de0cc6ed32a2b792..a16bd351a9d779bcba5b2a4111fc62e0dc9dc639 100644 --- 
a/crates/language_models/src/provider/lmstudio.rs +++ b/crates/language_models/src/provider/lmstudio.rs @@ -569,6 +569,7 @@ impl LmStudioEventMapper { is_input_complete: true, input, raw_input: tool_call.arguments, + thought_signature: None, }, )), Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { diff --git a/crates/language_models/src/provider/mistral.rs b/crates/language_models/src/provider/mistral.rs index 2d30dfca21d8cbc4fd1be3575801919148f705b3..0c45913bea83e32c508daa6c6579ecd0382b3dc0 100644 --- a/crates/language_models/src/provider/mistral.rs +++ b/crates/language_models/src/provider/mistral.rs @@ -720,6 +720,7 @@ impl MistralEventMapper { is_input_complete: true, input, raw_input: tool_call.arguments, + thought_signature: None, }, ))), Err(error) => { diff --git a/crates/language_models/src/provider/ollama.rs b/crates/language_models/src/provider/ollama.rs index b6870f5f72b08d2ca4decc101deae59b6a56c224..8345db3cce9fc51c487ec039c4257bfb39b162c3 100644 --- a/crates/language_models/src/provider/ollama.rs +++ b/crates/language_models/src/provider/ollama.rs @@ -592,6 +592,7 @@ fn map_to_language_model_completion_events( raw_input: function.arguments.to_string(), input: function.arguments, is_input_complete: true, + thought_signature: None, }); events.push(Ok(event)); state.used_tools = true; diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index 792d280950ceafa24cdf5e4104b80dd49bd45f3f..ee62522882c214dfa1384f75ced6eba46c9ec35f 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -586,6 +586,7 @@ impl OpenAiEventMapper { is_input_complete: true, input, raw_input: tool_call.arguments.clone(), + thought_signature: None, }, )), Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { diff --git a/crates/language_models/src/provider/open_router.rs b/crates/language_models/src/provider/open_router.rs index 
6326968a916a7d6a21811ee22c328564e1ec4682..c98ee02efd7b7af32ea6c649f29eef685753ba7d 100644 --- a/crates/language_models/src/provider/open_router.rs +++ b/crates/language_models/src/provider/open_router.rs @@ -635,6 +635,7 @@ impl OpenRouterEventMapper { is_input_complete: true, input, raw_input: tool_call.arguments.clone(), + thought_signature: None, }, )), Err(error) => Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {