assistant2: Rework how tool results are stored and referred to (#25817)

Marshall Bowers created

This PR reworks how we store tool results and refer to them later.

We now maintain a mapping of the tool uses to their corresponding
results, with separate mappings for the messages and the tool uses they
correspond to.

Release Notes:

- N/A

Change summary

crates/assistant2/src/active_thread.rs |   7 -
crates/assistant2/src/thread.rs        | 115 +++++++++++++--------------
2 files changed, 57 insertions(+), 65 deletions(-)

Detailed changes

crates/assistant2/src/active_thread.rs 🔗

@@ -255,12 +255,7 @@ impl ActiveThread {
                         let task = tool.run(tool_use.input, self.workspace.clone(), window, cx);
 
                         self.thread.update(cx, |thread, cx| {
-                            thread.insert_tool_output(
-                                tool_use.assistant_message_id,
-                                tool_use.id.clone(),
-                                task,
-                                cx,
-                            );
+                            thread.insert_tool_output(tool_use.id.clone(), task, cx);
                         });
                     }
                 }

crates/assistant2/src/thread.rs 🔗

@@ -88,8 +88,9 @@ pub struct Thread {
     completion_count: usize,
     pending_completions: Vec<PendingCompletion>,
     tools: Arc<ToolWorkingSet>,
-    tool_uses_by_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
-    tool_results_by_message: HashMap<MessageId, Vec<LanguageModelToolResult>>,
+    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
+    tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
+    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
     pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
 }
 
@@ -107,8 +108,9 @@ impl Thread {
             completion_count: 0,
             pending_completions: Vec::new(),
             tools,
-            tool_uses_by_message: HashMap::default(),
-            tool_results_by_message: HashMap::default(),
+            tool_uses_by_assistant_message: HashMap::default(),
+            tool_uses_by_user_message: HashMap::default(),
+            tool_results: HashMap::default(),
             pending_tool_uses_by_id: HashMap::default(),
         }
     }
@@ -141,8 +143,9 @@ impl Thread {
             completion_count: 0,
             pending_completions: Vec::new(),
             tools,
-            tool_uses_by_message: HashMap::default(),
-            tool_results_by_message: HashMap::default(),
+            tool_uses_by_assistant_message: HashMap::default(),
+            tool_uses_by_user_message: HashMap::default(),
+            tool_results: HashMap::default(),
             pending_tool_uses_by_id: HashMap::default(),
         }
     }
@@ -209,26 +212,14 @@ impl Thread {
     }
 
     pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
-        let Some(tool_uses_for_message) = &self.tool_uses_by_message.get(&id) else {
+        let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
             return Vec::new();
         };
 
-        // The tool use was requested by an Assistant message, so we need to
-        // look for the tool results on the next user message.
-        let next_user_message = MessageId(id.0 + 1);
-
-        let empty = Vec::new();
-        let tool_results_for_message = self
-            .tool_results_by_message
-            .get(&next_user_message)
-            .unwrap_or_else(|| &empty);
-
         let mut tool_uses = Vec::new();
 
         for tool_use in tool_uses_for_message.iter() {
-            let tool_result = tool_results_for_message
-                .iter()
-                .find(|tool_result| tool_result.tool_use_id == tool_use.id);
+            let tool_result = self.tool_results.get(&tool_use.id);
 
             let status = (|| {
                 if let Some(tool_result) = tool_result {
@@ -264,7 +255,7 @@ impl Thread {
     }
 
     pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
-        self.tool_results_by_message
+        self.tool_uses_by_user_message
             .get(&message_id)
             .map_or(false, |results| !results.is_empty())
     }
@@ -369,13 +360,15 @@ impl Thread {
                 content: Vec::new(),
                 cache: false,
             };
-            if let Some(tool_results) = self.tool_results_by_message.get(&message.id) {
+            if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message.id) {
                 match request_kind {
                     RequestKind::Chat => {
-                        for tool_result in tool_results {
-                            request_message
-                                .content
-                                .push(MessageContent::ToolResult(tool_result.clone()));
+                        for tool_use_id in tool_uses {
+                            if let Some(tool_result) = self.tool_results.get(tool_use_id) {
+                                request_message
+                                    .content
+                                    .push(MessageContent::ToolResult(tool_result.clone()));
+                            }
                         }
                     }
                     RequestKind::Summarize => {
@@ -390,7 +383,7 @@ impl Thread {
                     .push(MessageContent::Text(message.text.clone()));
             }
 
-            if let Some(tool_uses) = self.tool_uses_by_message.get(&message.id) {
+            if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message.id) {
                 match request_kind {
                     RequestKind::Chat => {
                         for tool_use in tool_uses {
@@ -477,11 +470,22 @@ impl Thread {
                                     .rfind(|message| message.role == Role::Assistant)
                                 {
                                     thread
-                                        .tool_uses_by_message
+                                        .tool_uses_by_assistant_message
                                         .entry(last_assistant_message.id)
                                         .or_default()
                                         .push(tool_use.clone());
 
+                                    // The tool use is being requested by the
+                                    // Assistant, so we want to attach the tool
+                                    // results to the next user message.
+                                    let next_user_message_id =
+                                        MessageId(last_assistant_message.id.0 + 1);
+                                    thread
+                                        .tool_uses_by_user_message
+                                        .entry(next_user_message_id)
+                                        .or_default()
+                                        .push(tool_use.id.clone());
+
                                     thread.pending_tool_uses_by_id.insert(
                                         tool_use.id.clone(),
                                         PendingToolUse {
@@ -611,7 +615,6 @@ impl Thread {
 
     pub fn insert_tool_output(
         &mut self,
-        assistant_message_id: MessageId,
         tool_use_id: LanguageModelToolUseId,
         output: Task<Result<String>>,
         cx: &mut Context<Self>,
@@ -621,44 +624,38 @@ impl Thread {
             async move {
                 let output = output.await;
                 thread
-                    .update(&mut cx, |thread, cx| {
-                        // The tool use was requested by an Assistant message,
-                        // so we want to attach the tool results to the next
-                        // user message.
-                        let next_user_message = MessageId(assistant_message_id.0 + 1);
-
-                        let tool_results = thread
-                            .tool_results_by_message
-                            .entry(next_user_message)
-                            .or_default();
-
-                        match output {
-                            Ok(output) => {
-                                tool_results.push(LanguageModelToolResult {
+                    .update(&mut cx, |thread, cx| match output {
+                        Ok(output) => {
+                            thread.tool_results.insert(
+                                tool_use_id.clone(),
+                                LanguageModelToolResult {
                                     tool_use_id: tool_use_id.clone(),
                                     content: output.into(),
                                     is_error: false,
-                                });
-                                thread.pending_tool_uses_by_id.remove(&tool_use_id);
+                                },
+                            );
+                            thread.pending_tool_uses_by_id.remove(&tool_use_id);
 
-                                cx.emit(ThreadEvent::ToolFinished { tool_use_id });
-                            }
-                            Err(err) => {
-                                tool_results.push(LanguageModelToolResult {
+                            cx.emit(ThreadEvent::ToolFinished { tool_use_id });
+                        }
+                        Err(err) => {
+                            thread.tool_results.insert(
+                                tool_use_id.clone(),
+                                LanguageModelToolResult {
                                     tool_use_id: tool_use_id.clone(),
                                     content: err.to_string().into(),
                                     is_error: true,
-                                });
-
-                                if let Some(tool_use) =
-                                    thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
-                                {
-                                    tool_use.status =
-                                        PendingToolUseStatus::Error(err.to_string().into());
-                                }
-
-                                cx.emit(ThreadEvent::ToolFinished { tool_use_id });
+                                },
+                            );
+
+                            if let Some(tool_use) =
+                                thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
+                            {
+                                tool_use.status =
+                                    PendingToolUseStatus::Error(err.to_string().into());
                             }
+
+                            cx.emit(ThreadEvent::ToolFinished { tool_use_id });
                         }
                     })
                     .ok();