diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 1820aebae547afa1a01968bb5d160b34503e9e1e..0a14a19e739abc4be5ac40f5da2ee663c19fbece 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -1965,6 +1965,7 @@ impl Thread { tool_name, raw_input, json_parse_error, + event_stream, ), ))); } @@ -2050,42 +2051,7 @@ impl Thread { kind = tool.kind(); } - // Ensure the last message ends in the current tool use - let last_message = self.pending_message(); - let push_new_tool_use = last_message.content.last_mut().is_none_or(|content| { - if let AgentMessageContent::ToolUse(last_tool_use) = content { - if last_tool_use.id == tool_use.id { - *last_tool_use = tool_use.clone(); - false - } else { - true - } - } else { - true - } - }); - - if push_new_tool_use { - event_stream.send_tool_call( - &tool_use.id, - &tool_use.name, - title, - kind, - tool_use.input.clone(), - ); - last_message - .content - .push(AgentMessageContent::ToolUse(tool_use.clone())); - } else { - event_stream.update_tool_call_fields( - &tool_use.id, - acp::ToolCallUpdateFields::new() - .title(title.as_str()) - .kind(kind) - .raw_input(tool_use.input.clone()), - None, - ); - } + self.send_or_update_tool_use(&tool_use, title, kind, event_stream); if !tool_use.is_input_complete { return None; @@ -2152,7 +2118,23 @@ impl Thread { tool_name: Arc, raw_input: Arc, json_parse_error: String, + event_stream: &ThreadEventStream, ) -> LanguageModelToolResult { + let tool_use = LanguageModelToolUse { + id: tool_use_id.clone(), + name: tool_name.clone(), + raw_input: raw_input.to_string(), + input: serde_json::json!({}), + is_input_complete: true, + thought_signature: None, + }; + self.send_or_update_tool_use( + &tool_use, + SharedString::from(&tool_use.name), + acp::ToolKind::Other, + event_stream, + ); + let tool_output = format!("Error parsing input JSON: {json_parse_error}"); LanguageModelToolResult { tool_use_id, @@ -2163,6 +2145,51 @@ impl Thread { } } + fn send_or_update_tool_use( + &mut self, + tool_use: &LanguageModelToolUse, + title: SharedString, + kind: acp::ToolKind, + event_stream: &ThreadEventStream, + ) { + // Ensure the last message ends in the current tool use + let last_message = self.pending_message(); + let push_new_tool_use = last_message.content.last_mut().is_none_or(|content| { + if let AgentMessageContent::ToolUse(last_tool_use) = content { + if last_tool_use.id == tool_use.id { + *last_tool_use = tool_use.clone(); + false + } else { + true + } + } else { + true + } + }); + + if push_new_tool_use { + event_stream.send_tool_call( + &tool_use.id, + &tool_use.name, + title, + kind, + tool_use.input.clone(), + ); + last_message + .content + .push(AgentMessageContent::ToolUse(tool_use.clone())); + } else { + event_stream.update_tool_call_fields( + &tool_use.id, + acp::ToolCallUpdateFields::new() + .title(title.as_str()) + .kind(kind) + .raw_input(tool_use.input.clone()), + None, + ); + } + } + pub fn title(&self) -> SharedString { self.title.clone().unwrap_or("New Thread".into()) } @@ -3511,3 +3538,117 @@ fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage { size: None, } } + +#[cfg(test)] +mod tests { + use super::*; + use gpui::TestAppContext; + use language_model::LanguageModelToolUseId; + use serde_json::json; + use std::sync::Arc; + + async fn setup_thread_for_test(cx: &mut TestAppContext) -> (Entity, ThreadEventStream) { + cx.update(|cx| { + let settings_store = settings::SettingsStore::test(cx); + cx.set_global(settings_store); + }); + + let fs = fs::FakeFs::new(cx.background_executor.clone()); + let templates = Templates::new(); + let project = Project::test(fs.clone(), [], cx).await; + + cx.update(|cx| { + let project_context = cx.new(|_cx| prompt_store::ProjectContext::default()); + let context_server_store = project.read(cx).context_server_store(); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(context_server_store, cx)); + + let thread = cx.new(|cx| { + Thread::new( + project, + project_context, + context_server_registry, + templates, + None, + cx, + ) + }); + + let (event_tx, _event_rx) = mpsc::unbounded(); + let event_stream = ThreadEventStream(event_tx); + + (thread, event_stream) + }) + } + + #[gpui::test] + async fn test_handle_tool_use_json_parse_error_adds_tool_use_to_content( + cx: &mut TestAppContext, + ) { + let (thread, event_stream) = setup_thread_for_test(cx).await; + + cx.update(|cx| { + thread.update(cx, |thread, _cx| { + let tool_use_id = LanguageModelToolUseId::from("test_tool_id"); + let tool_name: Arc = Arc::from("test_tool"); + let raw_input: Arc = Arc::from("{invalid json"); + let json_parse_error = "expected value at line 1 column 1".to_string(); + + // Call the function under test + let result = thread.handle_tool_use_json_parse_error_event( + tool_use_id.clone(), + tool_name.clone(), + raw_input.clone(), + json_parse_error, + &event_stream, + ); + + // Verify the result is an error + assert!(result.is_error); + assert_eq!(result.tool_use_id, tool_use_id); + assert_eq!(result.tool_name, tool_name); + assert!(matches!( + result.content, + LanguageModelToolResultContent::Text(_) + )); + + // Verify the tool use was added to the message content + { + let last_message = thread.pending_message(); + assert_eq!( + last_message.content.len(), + 1, + "Should have one tool_use in content" + ); + + match &last_message.content[0] { + AgentMessageContent::ToolUse(tool_use) => { + assert_eq!(tool_use.id, tool_use_id); + assert_eq!(tool_use.name, tool_name); + assert_eq!(tool_use.raw_input, raw_input.to_string()); + assert!(tool_use.is_input_complete); + // Should fall back to empty object for invalid JSON + assert_eq!(tool_use.input, json!({})); + } + _ => panic!("Expected ToolUse content"), + } + } + + // Insert the tool result (simulating what the caller does) + thread + .pending_message() + .tool_results + .insert(result.tool_use_id.clone(), result); + + // Verify the tool result was added + let last_message = thread.pending_message(); + assert_eq!( + last_message.tool_results.len(), + 1, + "Should have one tool_result" + ); + assert!(last_message.tool_results.contains_key(&tool_use_id)); + }); + }); + } +}