assistant2: Restore tool uses when loading saved threads (#25942)

Marshall Bowers created

This PR makes it so tool uses are restored when loading saved threads in
Assistant 2.

Release Notes:

- N/A

Change summary

crates/assistant2/src/thread.rs       | 14 +++-
crates/assistant2/src/thread_store.rs | 40 ++++++++++++++
crates/assistant2/src/tool_use.rs     | 74 ++++++++++++++++++++++++++++
3 files changed, 120 insertions(+), 8 deletions(-)

Detailed changes

crates/assistant2/src/thread.rs 🔗

@@ -8,8 +8,9 @@ use futures::StreamExt as _;
 use gpui::{App, Context, EventEmitter, SharedString, Task};
 use language_model::{
     LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
-    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolUseId,
-    MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError, Role, StopReason,
+    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
+    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
+    Role, StopReason,
 };
 use serde::{Deserialize, Serialize};
 use util::{post_inc, TryFutureExt as _};
@@ -88,7 +89,7 @@ impl Thread {
             completion_count: 0,
             pending_completions: Vec::new(),
             tools,
-            tool_use: ToolUseState::default(),
+            tool_use: ToolUseState::new(),
         }
     }
 
@@ -99,6 +100,7 @@ impl Thread {
         _cx: &mut Context<Self>,
     ) -> Self {
         let next_message_id = MessageId(saved.messages.len());
+        let tool_use = ToolUseState::from_saved_messages(&saved.messages);
 
         Self {
             id,
@@ -120,7 +122,7 @@ impl Thread {
             completion_count: 0,
             pending_completions: Vec::new(),
             tools,
-            tool_use: ToolUseState::default(),
+            tool_use,
         }
     }
 
@@ -189,6 +191,10 @@ impl Thread {
         self.tool_use.tool_uses_for_message(id)
     }
 
+    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
+        self.tool_use.tool_results_for_message(id)
+    }
+
     pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
         self.tool_use.message_has_tool_results(message_id)
     }

crates/assistant2/src/thread_store.rs 🔗

@@ -14,7 +14,7 @@ use gpui::{
 };
 use heed::types::{SerdeBincode, SerdeJson};
 use heed::Database;
-use language_model::Role;
+use language_model::{LanguageModelToolUseId, Role};
 use project::Project;
 use serde::{Deserialize, Serialize};
 use util::ResultExt as _;
@@ -113,6 +113,24 @@ impl ThreadStore {
                         id: message.id,
                         role: message.role,
                         text: message.text.clone(),
+                        tool_uses: thread
+                            .tool_uses_for_message(message.id)
+                            .into_iter()
+                            .map(|tool_use| SavedToolUse {
+                                id: tool_use.id,
+                                name: tool_use.name,
+                                input: tool_use.input,
+                            })
+                            .collect(),
+                        tool_results: thread
+                            .tool_results_for_message(message.id)
+                            .into_iter()
+                            .map(|tool_result| SavedToolResult {
+                                tool_use_id: tool_result.tool_use_id.clone(),
+                                is_error: tool_result.is_error,
+                                content: tool_result.content.clone(),
+                            })
+                            .collect(),
                     })
                     .collect(),
             };
@@ -239,11 +257,29 @@ pub struct SavedThread {
     pub messages: Vec<SavedMessage>,
 }
 
-#[derive(Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize)]
 pub struct SavedMessage {
     pub id: MessageId,
     pub role: Role,
     pub text: String,
+    #[serde(default)]
+    pub tool_uses: Vec<SavedToolUse>,
+    #[serde(default)]
+    pub tool_results: Vec<SavedToolResult>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct SavedToolUse {
+    pub id: LanguageModelToolUseId,
+    pub name: SharedString,
+    pub input: serde_json::Value,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct SavedToolResult {
+    pub tool_use_id: LanguageModelToolUseId,
+    pub is_error: bool,
+    pub content: Arc<str>,
 }
 
 struct GlobalThreadsDatabase(

crates/assistant2/src/tool_use.rs 🔗

@@ -7,10 +7,11 @@ use futures::FutureExt as _;
 use gpui::{SharedString, Task};
 use language_model::{
     LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
-    LanguageModelToolUseId, MessageContent,
+    LanguageModelToolUseId, MessageContent, Role,
 };
 
 use crate::thread::MessageId;
+use crate::thread_store::SavedMessage;
 
 #[derive(Debug)]
 pub struct ToolUse {
@@ -28,7 +29,6 @@ pub enum ToolUseStatus {
     Error(SharedString),
 }
 
-#[derive(Default)]
 pub struct ToolUseState {
     tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
     tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
@@ -37,6 +37,65 @@ pub struct ToolUseState {
 }
 
 impl ToolUseState {
+    pub fn new() -> Self {
+        Self {
+            tool_uses_by_assistant_message: HashMap::default(),
+            tool_uses_by_user_message: HashMap::default(),
+            tool_results: HashMap::default(),
+            pending_tool_uses_by_id: HashMap::default(),
+        }
+    }
+
+    pub fn from_saved_messages(messages: &[SavedMessage]) -> Self {
+        let mut this = Self::new();
+
+        for message in messages {
+            match message.role {
+                Role::Assistant => {
+                    if !message.tool_uses.is_empty() {
+                        this.tool_uses_by_assistant_message.insert(
+                            message.id,
+                            message
+                                .tool_uses
+                                .iter()
+                                .map(|tool_use| LanguageModelToolUse {
+                                    id: tool_use.id.clone(),
+                                    name: tool_use.name.clone().into(),
+                                    input: tool_use.input.clone(),
+                                })
+                                .collect(),
+                        );
+                    }
+                }
+                Role::User => {
+                    if !message.tool_results.is_empty() {
+                        let tool_uses_by_user_message = this
+                            .tool_uses_by_user_message
+                            .entry(message.id)
+                            .or_default();
+
+                        for tool_result in &message.tool_results {
+                            let tool_use_id = tool_result.tool_use_id.clone();
+
+                            tool_uses_by_user_message.push(tool_use_id.clone());
+                            this.tool_results.insert(
+                                tool_use_id.clone(),
+                                LanguageModelToolResult {
+                                    tool_use_id,
+                                    is_error: tool_result.is_error,
+                                    content: tool_result.content.clone(),
+                                },
+                            );
+                        }
+                    }
+                }
+                Role::System => {}
+            }
+        }
+
+        this
+    }
+
     pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
         self.pending_tool_uses_by_id.values().collect()
     }
@@ -84,6 +143,17 @@ impl ToolUseState {
         tool_uses
     }
 
+    pub fn tool_results_for_message(&self, message_id: MessageId) -> Vec<&LanguageModelToolResult> {
+        let empty = Vec::new();
+
+        self.tool_uses_by_user_message
+            .get(&message_id)
+            .unwrap_or(&empty)
+            .iter()
+            .filter_map(|tool_use_id| self.tool_results.get(&tool_use_id))
+            .collect()
+    }
+
     pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
         self.tool_uses_by_user_message
             .get(&message_id)