diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index 2f1425fccc9053eb58cd8bc5a15e9b9e71a9b468..bfc7e852c9772b4bf6b3d863c331b5966648cb87 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -117,8 +117,10 @@ impl Thread { .map(|message| message.id.0 + 1) .unwrap_or(0), ); - let tool_use = ToolUseState::from_saved_messages(&saved.messages); - let scripting_tool_use = ToolUseState::new(); + let tool_use = + ToolUseState::from_saved_messages(&saved.messages, |name| name != ScriptingTool::NAME); + let scripting_tool_use = + ToolUseState::from_saved_messages(&saved.messages, |name| name == ScriptingTool::NAME); Self { id, diff --git a/crates/assistant2/src/thread_store.rs b/crates/assistant2/src/thread_store.rs index 4d9dc8244945ac1e11ce0e1a620612d905831bc4..2200d914f3c8e3104a3d8256b214fe7f64ac807e 100644 --- a/crates/assistant2/src/thread_store.rs +++ b/crates/assistant2/src/thread_store.rs @@ -116,28 +116,35 @@ impl ThreadStore { updated_at: thread.updated_at(), messages: thread .messages() - .map(|message| SavedMessage { - id: message.id, - role: message.role, - text: message.text.clone(), - tool_uses: thread + .map(|message| { + let all_tool_uses = thread .tool_uses_for_message(message.id) .into_iter() + .chain(thread.scripting_tool_uses_for_message(message.id)) .map(|tool_use| SavedToolUse { id: tool_use.id, name: tool_use.name, input: tool_use.input, }) - .collect(), - tool_results: thread + .collect(); + let all_tool_results = thread .tool_results_for_message(message.id) .into_iter() + .chain(thread.scripting_tool_results_for_message(message.id)) .map(|tool_result| SavedToolResult { tool_use_id: tool_result.tool_use_id.clone(), is_error: tool_result.is_error, content: tool_result.content.clone(), }) - .collect(), + .collect(); + + SavedMessage { + id: message.id, + role: message.role, + text: message.text.clone(), + tool_uses: all_tool_uses, + tool_results: all_tool_results, + } }) .collect(), }; diff --git a/crates/assistant2/src/tool_use.rs b/crates/assistant2/src/tool_use.rs index 565728ef3fbdfff20bb18bed415cb3e7cf794e7e..d5eb50f99cc467b731cda62f04724eeba4db0c14 100644 --- a/crates/assistant2/src/tool_use.rs +++ b/crates/assistant2/src/tool_use.rs @@ -46,25 +46,39 @@ impl ToolUseState { } } - pub fn from_saved_messages(messages: &[SavedMessage]) -> Self { + /// Constructs a [`ToolUseState`] from the given list of [`SavedMessage`]s. + /// + /// Accepts a function to filter the tools that should be used to populate the state. + pub fn from_saved_messages( + messages: &[SavedMessage], + mut filter_by_tool_name: impl FnMut(&str) -> bool, + ) -> Self { let mut this = Self::new(); + let mut tool_names_by_id = HashMap::default(); for message in messages { match message.role { Role::Assistant => { if !message.tool_uses.is_empty() { - this.tool_uses_by_assistant_message.insert( - message.id, - message - .tool_uses + let tool_uses = message + .tool_uses + .iter() + .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref())) + .map(|tool_use| LanguageModelToolUse { + id: tool_use.id.clone(), + name: tool_use.name.clone().into(), + input: tool_use.input.clone(), + }) + .collect::>(); + + tool_names_by_id.extend( + tool_uses .iter() - .map(|tool_use| LanguageModelToolUse { - id: tool_use.id.clone(), - name: tool_use.name.clone().into(), - input: tool_use.input.clone(), - }) - .collect(), + .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())), ); + + this.tool_uses_by_assistant_message + .insert(message.id, tool_uses); } } Role::User => { @@ -76,6 +90,14 @@ impl ToolUseState { for tool_result in &message.tool_results { let tool_use_id = tool_result.tool_use_id.clone(); + let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else { + log::warn!("no tool name found for tool use: {tool_use_id:?}"); + continue; + }; + + if !(filter_by_tool_name)(tool_use.as_ref()) { + continue; + } tool_uses_by_user_message.push(tool_use_id.clone()); this.tool_results.insert(