@@ -117,8 +117,10 @@ impl Thread {
.map(|message| message.id.0 + 1)
.unwrap_or(0),
);
- let tool_use = ToolUseState::from_saved_messages(&saved.messages);
- let scripting_tool_use = ToolUseState::new();
+ let tool_use =
+ ToolUseState::from_saved_messages(&saved.messages, |name| name != ScriptingTool::NAME);
+ let scripting_tool_use =
+ ToolUseState::from_saved_messages(&saved.messages, |name| name == ScriptingTool::NAME);
Self {
id,
@@ -116,28 +116,35 @@ impl ThreadStore {
updated_at: thread.updated_at(),
messages: thread
.messages()
- .map(|message| SavedMessage {
- id: message.id,
- role: message.role,
- text: message.text.clone(),
- tool_uses: thread
+ .map(|message| {
+ let all_tool_uses = thread
.tool_uses_for_message(message.id)
.into_iter()
+ .chain(thread.scripting_tool_uses_for_message(message.id))
.map(|tool_use| SavedToolUse {
id: tool_use.id,
name: tool_use.name,
input: tool_use.input,
})
- .collect(),
- tool_results: thread
+ .collect();
+ let all_tool_results = thread
.tool_results_for_message(message.id)
.into_iter()
+ .chain(thread.scripting_tool_results_for_message(message.id))
.map(|tool_result| SavedToolResult {
tool_use_id: tool_result.tool_use_id.clone(),
is_error: tool_result.is_error,
content: tool_result.content.clone(),
})
- .collect(),
+ .collect();
+
+ SavedMessage {
+ id: message.id,
+ role: message.role,
+ text: message.text.clone(),
+ tool_uses: all_tool_uses,
+ tool_results: all_tool_results,
+ }
})
.collect(),
};
@@ -46,25 +46,39 @@ impl ToolUseState {
}
}
- pub fn from_saved_messages(messages: &[SavedMessage]) -> Self {
+ /// Constructs a [`ToolUseState`] from the given list of [`SavedMessage`]s.
+ ///
+ /// Accepts a function to filter the tools that should be used to populate the state.
+ pub fn from_saved_messages(
+ messages: &[SavedMessage],
+ mut filter_by_tool_name: impl FnMut(&str) -> bool,
+ ) -> Self {
let mut this = Self::new();
+ let mut tool_names_by_id = HashMap::default();
for message in messages {
match message.role {
Role::Assistant => {
if !message.tool_uses.is_empty() {
- this.tool_uses_by_assistant_message.insert(
- message.id,
- message
- .tool_uses
+ let tool_uses = message
+ .tool_uses
+ .iter()
+ .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
+ .map(|tool_use| LanguageModelToolUse {
+ id: tool_use.id.clone(),
+ name: tool_use.name.clone().into(),
+ input: tool_use.input.clone(),
+ })
+ .collect::<Vec<_>>();
+
+ tool_names_by_id.extend(
+ tool_uses
.iter()
- .map(|tool_use| LanguageModelToolUse {
- id: tool_use.id.clone(),
- name: tool_use.name.clone().into(),
- input: tool_use.input.clone(),
- })
- .collect(),
+ .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
);
+
+ this.tool_uses_by_assistant_message
+ .insert(message.id, tool_uses);
}
}
Role::User => {
@@ -76,6 +90,14 @@ impl ToolUseState {
for tool_result in &message.tool_results {
let tool_use_id = tool_result.tool_use_id.clone();
+ let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
+ log::warn!("no tool name found for tool use: {tool_use_id:?}");
+ continue;
+ };
+
+ if !(filter_by_tool_name)(tool_use.as_ref()) {
+ continue;
+ }
tool_uses_by_user_message.push(tool_use_id.clone());
this.tool_results.insert(