assistant2: Persist scripting tool uses in saved threads (#26404)

Marshall Bowers created

This PR makes it so the scripting tool uses are persisted to and
restored from saved threads.

Release Notes:

- N/A

Change summary

crates/assistant2/src/thread.rs       |  6 ++-
crates/assistant2/src/thread_store.rs | 23 +++++++++-----
crates/assistant2/src/tool_use.rs     | 44 +++++++++++++++++++++-------
3 files changed, 52 insertions(+), 21 deletions(-)

Detailed changes

crates/assistant2/src/thread.rs 🔗

@@ -117,8 +117,10 @@ impl Thread {
                 .map(|message| message.id.0 + 1)
                 .unwrap_or(0),
         );
-        let tool_use = ToolUseState::from_saved_messages(&saved.messages);
-        let scripting_tool_use = ToolUseState::new();
+        let tool_use =
+            ToolUseState::from_saved_messages(&saved.messages, |name| name != ScriptingTool::NAME);
+        let scripting_tool_use =
+            ToolUseState::from_saved_messages(&saved.messages, |name| name == ScriptingTool::NAME);
 
         Self {
             id,

crates/assistant2/src/thread_store.rs 🔗

@@ -116,28 +116,35 @@ impl ThreadStore {
                 updated_at: thread.updated_at(),
                 messages: thread
                     .messages()
-                    .map(|message| SavedMessage {
-                        id: message.id,
-                        role: message.role,
-                        text: message.text.clone(),
-                        tool_uses: thread
+                    .map(|message| {
+                        let all_tool_uses = thread
                             .tool_uses_for_message(message.id)
                             .into_iter()
+                            .chain(thread.scripting_tool_uses_for_message(message.id))
                             .map(|tool_use| SavedToolUse {
                                 id: tool_use.id,
                                 name: tool_use.name,
                                 input: tool_use.input,
                             })
-                            .collect(),
-                        tool_results: thread
+                            .collect();
+                        let all_tool_results = thread
                             .tool_results_for_message(message.id)
                             .into_iter()
+                            .chain(thread.scripting_tool_results_for_message(message.id))
                             .map(|tool_result| SavedToolResult {
                                 tool_use_id: tool_result.tool_use_id.clone(),
                                 is_error: tool_result.is_error,
                                 content: tool_result.content.clone(),
                             })
-                            .collect(),
+                            .collect();
+
+                        SavedMessage {
+                            id: message.id,
+                            role: message.role,
+                            text: message.text.clone(),
+                            tool_uses: all_tool_uses,
+                            tool_results: all_tool_results,
+                        }
                     })
                     .collect(),
             };

crates/assistant2/src/tool_use.rs 🔗

@@ -46,25 +46,39 @@ impl ToolUseState {
         }
     }
 
-    pub fn from_saved_messages(messages: &[SavedMessage]) -> Self {
+    /// Constructs a [`ToolUseState`] from the given list of [`SavedMessage`]s.
+    ///
+    /// Accepts a function to filter the tools that should be used to populate the state.
+    pub fn from_saved_messages(
+        messages: &[SavedMessage],
+        mut filter_by_tool_name: impl FnMut(&str) -> bool,
+    ) -> Self {
         let mut this = Self::new();
+        let mut tool_names_by_id = HashMap::default();
 
         for message in messages {
             match message.role {
                 Role::Assistant => {
                     if !message.tool_uses.is_empty() {
-                        this.tool_uses_by_assistant_message.insert(
-                            message.id,
-                            message
-                                .tool_uses
+                        let tool_uses = message
+                            .tool_uses
+                            .iter()
+                            .filter(|tool_use| (filter_by_tool_name)(tool_use.name.as_ref()))
+                            .map(|tool_use| LanguageModelToolUse {
+                                id: tool_use.id.clone(),
+                                name: tool_use.name.clone().into(),
+                                input: tool_use.input.clone(),
+                            })
+                            .collect::<Vec<_>>();
+
+                        tool_names_by_id.extend(
+                            tool_uses
                                 .iter()
-                                .map(|tool_use| LanguageModelToolUse {
-                                    id: tool_use.id.clone(),
-                                    name: tool_use.name.clone().into(),
-                                    input: tool_use.input.clone(),
-                                })
-                                .collect(),
+                                .map(|tool_use| (tool_use.id.clone(), tool_use.name.clone())),
                         );
+
+                        this.tool_uses_by_assistant_message
+                            .insert(message.id, tool_uses);
                     }
                 }
                 Role::User => {
@@ -76,6 +90,14 @@ impl ToolUseState {
 
                         for tool_result in &message.tool_results {
                             let tool_use_id = tool_result.tool_use_id.clone();
+                            let Some(tool_use) = tool_names_by_id.get(&tool_use_id) else {
+                                log::warn!("no tool name found for tool use: {tool_use_id:?}");
+                                continue;
+                            };
+
+                            if !(filter_by_tool_name)(tool_use.as_ref()) {
+                                continue;
+                            }
 
                             tool_uses_by_user_message.push(tool_use_id.clone());
                             this.tool_results.insert(