From b445e4ce24410642996ce5d297d28eb8f37061ae Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Fri, 28 Feb 2025 11:33:08 -0500 Subject: [PATCH] assistant2: Rework how tool results are stored and referred to (#25817) This PR reworks how we store tool results and refer to them later. We now maintain a mapping of the tool uses to their corresponding results, with separate mappings for the messages and the tool uses they correspond to. Release Notes: - N/A --- crates/assistant2/src/active_thread.rs | 7 +- crates/assistant2/src/thread.rs | 115 ++++++++++++------------- 2 files changed, 57 insertions(+), 65 deletions(-) diff --git a/crates/assistant2/src/active_thread.rs b/crates/assistant2/src/active_thread.rs index fd6c7705c8385d7529f6fef85fc88d7cd3ec71bb..ffa650d23057b5ed34de57b533f2329b176e7822 100644 --- a/crates/assistant2/src/active_thread.rs +++ b/crates/assistant2/src/active_thread.rs @@ -255,12 +255,7 @@ impl ActiveThread { let task = tool.run(tool_use.input, self.workspace.clone(), window, cx); self.thread.update(cx, |thread, cx| { - thread.insert_tool_output( - tool_use.assistant_message_id, - tool_use.id.clone(), - task, - cx, - ); + thread.insert_tool_output(tool_use.id.clone(), task, cx); }); } } diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index e9469e57dc9af8393dfe34d1d9574521408391eb..50ea13e670b7a38bf79fcc140e6cc7dbbb19df97 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -88,8 +88,9 @@ pub struct Thread { completion_count: usize, pending_completions: Vec, tools: Arc, - tool_uses_by_message: HashMap>, - tool_results_by_message: HashMap>, + tool_uses_by_assistant_message: HashMap>, + tool_uses_by_user_message: HashMap>, + tool_results: HashMap, pending_tool_uses_by_id: HashMap, } @@ -107,8 +108,9 @@ impl Thread { completion_count: 0, pending_completions: Vec::new(), tools, - tool_uses_by_message: HashMap::default(), - tool_results_by_message: HashMap::default(), + tool_uses_by_assistant_message: HashMap::default(), + tool_uses_by_user_message: HashMap::default(), + tool_results: HashMap::default(), pending_tool_uses_by_id: HashMap::default(), } } @@ -141,8 +143,9 @@ impl Thread { completion_count: 0, pending_completions: Vec::new(), tools, - tool_uses_by_message: HashMap::default(), - tool_results_by_message: HashMap::default(), + tool_uses_by_assistant_message: HashMap::default(), + tool_uses_by_user_message: HashMap::default(), + tool_results: HashMap::default(), pending_tool_uses_by_id: HashMap::default(), } } @@ -209,26 +212,14 @@ impl Thread { } pub fn tool_uses_for_message(&self, id: MessageId) -> Vec { - let Some(tool_uses_for_message) = &self.tool_uses_by_message.get(&id) else { + let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else { return Vec::new(); }; - // The tool use was requested by an Assistant message, so we need to - // look for the tool results on the next user message. - let next_user_message = MessageId(id.0 + 1); - - let empty = Vec::new(); - let tool_results_for_message = self - .tool_results_by_message - .get(&next_user_message) - .unwrap_or_else(|| &empty); - let mut tool_uses = Vec::new(); for tool_use in tool_uses_for_message.iter() { - let tool_result = tool_results_for_message - .iter() - .find(|tool_result| tool_result.tool_use_id == tool_use.id); + let tool_result = self.tool_results.get(&tool_use.id); let status = (|| { if let Some(tool_result) = tool_result { @@ -264,7 +255,7 @@ impl Thread { } pub fn message_has_tool_results(&self, message_id: MessageId) -> bool { - self.tool_results_by_message + self.tool_uses_by_user_message .get(&message_id) .map_or(false, |results| !results.is_empty()) } @@ -369,13 +360,15 @@ impl Thread { content: Vec::new(), cache: false, }; - if let Some(tool_results) = self.tool_results_by_message.get(&message.id) { + if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message.id) { match request_kind { RequestKind::Chat => { - for tool_result in tool_results { - request_message - .content - .push(MessageContent::ToolResult(tool_result.clone())); + for tool_use_id in tool_uses { + if let Some(tool_result) = self.tool_results.get(tool_use_id) { + request_message + .content + .push(MessageContent::ToolResult(tool_result.clone())); + } } } RequestKind::Summarize => { @@ -390,7 +383,7 @@ impl Thread { .push(MessageContent::Text(message.text.clone())); } - if let Some(tool_uses) = self.tool_uses_by_message.get(&message.id) { + if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message.id) { match request_kind { RequestKind::Chat => { for tool_use in tool_uses { @@ -477,11 +470,22 @@ impl Thread { .rfind(|message| message.role == Role::Assistant) { thread - .tool_uses_by_message + .tool_uses_by_assistant_message .entry(last_assistant_message.id) .or_default() .push(tool_use.clone()); + // The tool use is being requested by the + // Assistant, so we want to attach the tool + // results to the next user message. + let next_user_message_id = + MessageId(last_assistant_message.id.0 + 1); + thread + .tool_uses_by_user_message + .entry(next_user_message_id) + .or_default() + .push(tool_use.id.clone()); + thread.pending_tool_uses_by_id.insert( tool_use.id.clone(), PendingToolUse { @@ -611,7 +615,6 @@ impl Thread { pub fn insert_tool_output( &mut self, - assistant_message_id: MessageId, tool_use_id: LanguageModelToolUseId, output: Task>, cx: &mut Context, @@ -621,44 +624,38 @@ impl Thread { async move { let output = output.await; thread - .update(&mut cx, |thread, cx| { - // The tool use was requested by an Assistant message, - // so we want to attach the tool results to the next - // user message. - let next_user_message = MessageId(assistant_message_id.0 + 1); - - let tool_results = thread - .tool_results_by_message - .entry(next_user_message) - .or_default(); - - match output { - Ok(output) => { - tool_results.push(LanguageModelToolResult { + .update(&mut cx, |thread, cx| match output { + Ok(output) => { + thread.tool_results.insert( + tool_use_id.clone(), + LanguageModelToolResult { tool_use_id: tool_use_id.clone(), content: output.into(), is_error: false, - }); - thread.pending_tool_uses_by_id.remove(&tool_use_id); + }, + ); + thread.pending_tool_uses_by_id.remove(&tool_use_id); - cx.emit(ThreadEvent::ToolFinished { tool_use_id }); - } - Err(err) => { - tool_results.push(LanguageModelToolResult { + cx.emit(ThreadEvent::ToolFinished { tool_use_id }); + } + Err(err) => { + thread.tool_results.insert( + tool_use_id.clone(), + LanguageModelToolResult { tool_use_id: tool_use_id.clone(), content: err.to_string().into(), is_error: true, - }); - - if let Some(tool_use) = - thread.pending_tool_uses_by_id.get_mut(&tool_use_id) - { - tool_use.status = - PendingToolUseStatus::Error(err.to_string().into()); - } - - cx.emit(ThreadEvent::ToolFinished { tool_use_id }); + }, + ); + + if let Some(tool_use) = + thread.pending_tool_uses_by_id.get_mut(&tool_use_id) + { + tool_use.status = + PendingToolUseStatus::Error(err.to_string().into()); } + + cx.emit(ThreadEvent::ToolFinished { tool_use_id }); } }) .ok();