assistant2: Factor out tool use into its own module (#25819)

Marshall Bowers created

This PR factors out the concerns related to tool use out of `Thread` and
into their own module.

Release Notes:

- N/A

Change summary

crates/assistant2/src/active_thread.rs |   5 
crates/assistant2/src/assistant.rs     |   1 
crates/assistant2/src/thread.rs        | 228 ++++-----------------------
crates/assistant2/src/tool_use.rs      | 221 +++++++++++++++++++++++++++
4 files changed, 258 insertions(+), 197 deletions(-)

Detailed changes

crates/assistant2/src/active_thread.rs 🔗

@@ -15,10 +15,9 @@ use theme::ThemeSettings;
 use ui::{prelude::*, Disclosure};
 use workspace::Workspace;
 
-use crate::thread::{
-    MessageId, RequestKind, Thread, ThreadError, ThreadEvent, ToolUse, ToolUseStatus,
-};
+use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent};
 use crate::thread_store::ThreadStore;
+use crate::tool_use::{ToolUse, ToolUseStatus};
 use crate::ui::ContextPill;
 
 pub struct ActiveThread {

crates/assistant2/src/thread.rs 🔗

@@ -4,14 +4,12 @@ use anyhow::Result;
 use assistant_tool::ToolWorkingSet;
 use chrono::{DateTime, Utc};
 use collections::{BTreeMap, HashMap, HashSet};
-use futures::future::Shared;
-use futures::{FutureExt as _, StreamExt as _};
+use futures::StreamExt as _;
 use gpui::{App, Context, EventEmitter, SharedString, Task};
 use language_model::{
     LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
-    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
-    LanguageModelToolUse, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
-    PaymentRequiredError, Role, StopReason,
+    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolUseId,
+    MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError, Role, StopReason,
 };
 use serde::{Deserialize, Serialize};
 use util::{post_inc, TryFutureExt as _};
@@ -19,6 +17,7 @@ use uuid::Uuid;
 
 use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
 use crate::thread_store::SavedThread;
+use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
 
 #[derive(Debug, Clone, Copy)]
 pub enum RequestKind {
@@ -43,7 +42,7 @@ impl std::fmt::Display for ThreadId {
 }
 
 #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
-pub struct MessageId(usize);
+pub struct MessageId(pub(crate) usize);
 
 impl MessageId {
     fn post_inc(&mut self) -> Self {
@@ -59,22 +58,6 @@ pub struct Message {
     pub text: String,
 }
 
-#[derive(Debug)]
-pub struct ToolUse {
-    pub id: LanguageModelToolUseId,
-    pub name: SharedString,
-    pub status: ToolUseStatus,
-    pub input: serde_json::Value,
-}
-
-#[derive(Debug, Clone)]
-pub enum ToolUseStatus {
-    Pending,
-    Running,
-    Finished(SharedString),
-    Error(SharedString),
-}
-
 /// A thread of conversation with the LLM.
 pub struct Thread {
     id: ThreadId,
@@ -88,10 +71,7 @@ pub struct Thread {
     completion_count: usize,
     pending_completions: Vec<PendingCompletion>,
     tools: Arc<ToolWorkingSet>,
-    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
-    tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
-    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
-    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
+    tool_use: ToolUseState,
 }
 
 impl Thread {
@@ -108,10 +88,7 @@ impl Thread {
             completion_count: 0,
             pending_completions: Vec::new(),
             tools,
-            tool_uses_by_assistant_message: HashMap::default(),
-            tool_uses_by_user_message: HashMap::default(),
-            tool_results: HashMap::default(),
-            pending_tool_uses_by_id: HashMap::default(),
+            tool_use: ToolUseState::default(),
         }
     }
 
@@ -143,10 +120,7 @@ impl Thread {
             completion_count: 0,
             pending_completions: Vec::new(),
             tools,
-            tool_uses_by_assistant_message: HashMap::default(),
-            tool_uses_by_user_message: HashMap::default(),
-            tool_results: HashMap::default(),
-            pending_tool_uses_by_id: HashMap::default(),
+            tool_use: ToolUseState::default(),
         }
     }
 
@@ -208,56 +182,15 @@ impl Thread {
     }
 
     pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
-        self.pending_tool_uses_by_id.values().collect()
+        self.tool_use.pending_tool_uses()
     }
 
     pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
-        let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
-            return Vec::new();
-        };
-
-        let mut tool_uses = Vec::new();
-
-        for tool_use in tool_uses_for_message.iter() {
-            let tool_result = self.tool_results.get(&tool_use.id);
-
-            let status = (|| {
-                if let Some(tool_result) = tool_result {
-                    return if tool_result.is_error {
-                        ToolUseStatus::Error(tool_result.content.clone().into())
-                    } else {
-                        ToolUseStatus::Finished(tool_result.content.clone().into())
-                    };
-                }
-
-                if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
-                    return match pending_tool_use.status {
-                        PendingToolUseStatus::Idle => ToolUseStatus::Pending,
-                        PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
-                        PendingToolUseStatus::Error(ref err) => {
-                            ToolUseStatus::Error(err.clone().into())
-                        }
-                    };
-                }
-
-                ToolUseStatus::Pending
-            })();
-
-            tool_uses.push(ToolUse {
-                id: tool_use.id.clone(),
-                name: tool_use.name.clone().into(),
-                input: tool_use.input.clone(),
-                status,
-            })
-        }
-
-        tool_uses
+        self.tool_use.tool_uses_for_message(id)
     }
 
     pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
-        self.tool_uses_by_user_message
-            .get(&message_id)
-            .map_or(false, |results| !results.is_empty())
+        self.tool_use.message_has_tool_results(message_id)
     }
 
     pub fn insert_user_message(
@@ -360,20 +293,13 @@ impl Thread {
                 content: Vec::new(),
                 cache: false,
             };
-            if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message.id) {
-                match request_kind {
-                    RequestKind::Chat => {
-                        for tool_use_id in tool_uses {
-                            if let Some(tool_result) = self.tool_results.get(tool_use_id) {
-                                request_message
-                                    .content
-                                    .push(MessageContent::ToolResult(tool_result.clone()));
-                            }
-                        }
-                    }
-                    RequestKind::Summarize => {
-                        // We don't care about tool use during summarization.
-                    }
+            match request_kind {
+                RequestKind::Chat => {
+                    self.tool_use
+                        .attach_tool_results(message.id, &mut request_message);
+                }
+                RequestKind::Summarize => {
+                    // We don't care about tool use during summarization.
                 }
             }
 
@@ -383,18 +309,13 @@ impl Thread {
                     .push(MessageContent::Text(message.text.clone()));
             }
 
-            if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message.id) {
-                match request_kind {
-                    RequestKind::Chat => {
-                        for tool_use in tool_uses {
-                            request_message
-                                .content
-                                .push(MessageContent::ToolUse(tool_use.clone()));
-                        }
-                    }
-                    RequestKind::Summarize => {
-                        // We don't care about tool use during summarization.
-                    }
+            match request_kind {
+                RequestKind::Chat => {
+                    self.tool_use
+                        .attach_tool_uses(message.id, &mut request_message);
+                }
+                RequestKind::Summarize => {
+                    // We don't care about tool use during summarization.
                 }
             }
 
@@ -470,32 +391,8 @@ impl Thread {
                                     .rfind(|message| message.role == Role::Assistant)
                                 {
                                     thread
-                                        .tool_uses_by_assistant_message
-                                        .entry(last_assistant_message.id)
-                                        .or_default()
-                                        .push(tool_use.clone());
-
-                                    // The tool use is being requested by the
-                                    // Assistant, so we want to attach the tool
-                                    // results to the next user message.
-                                    let next_user_message_id =
-                                        MessageId(last_assistant_message.id.0 + 1);
-                                    thread
-                                        .tool_uses_by_user_message
-                                        .entry(next_user_message_id)
-                                        .or_default()
-                                        .push(tool_use.id.clone());
-
-                                    thread.pending_tool_uses_by_id.insert(
-                                        tool_use.id.clone(),
-                                        PendingToolUse {
-                                            assistant_message_id: last_assistant_message.id,
-                                            id: tool_use.id,
-                                            name: tool_use.name,
-                                            input: tool_use.input,
-                                            status: PendingToolUseStatus::Idle,
-                                        },
-                                    );
+                                        .tool_use
+                                        .request_tool_use(last_assistant_message.id, tool_use);
                                 }
                             }
                         }
@@ -624,49 +521,19 @@ impl Thread {
             async move {
                 let output = output.await;
                 thread
-                    .update(&mut cx, |thread, cx| match output {
-                        Ok(output) => {
-                            thread.tool_results.insert(
-                                tool_use_id.clone(),
-                                LanguageModelToolResult {
-                                    tool_use_id: tool_use_id.clone(),
-                                    content: output.into(),
-                                    is_error: false,
-                                },
-                            );
-                            thread.pending_tool_uses_by_id.remove(&tool_use_id);
-
-                            cx.emit(ThreadEvent::ToolFinished { tool_use_id });
-                        }
-                        Err(err) => {
-                            thread.tool_results.insert(
-                                tool_use_id.clone(),
-                                LanguageModelToolResult {
-                                    tool_use_id: tool_use_id.clone(),
-                                    content: err.to_string().into(),
-                                    is_error: true,
-                                },
-                            );
-
-                            if let Some(tool_use) =
-                                thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
-                            {
-                                tool_use.status =
-                                    PendingToolUseStatus::Error(err.to_string().into());
-                            }
+                    .update(&mut cx, |thread, cx| {
+                        thread
+                            .tool_use
+                            .insert_tool_output(tool_use_id.clone(), output);
 
-                            cx.emit(ThreadEvent::ToolFinished { tool_use_id });
-                        }
+                        cx.emit(ThreadEvent::ToolFinished { tool_use_id });
                     })
                     .ok();
             }
         });
 
-        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
-            tool_use.status = PendingToolUseStatus::Running {
-                _task: insert_output_task.shared(),
-            };
-        }
+        self.tool_use
+            .run_pending_tool(tool_use_id, insert_output_task);
     }
 
     /// Cancels the last pending completion, if there are any pending.
@@ -708,30 +575,3 @@ struct PendingCompletion {
     id: usize,
     _task: Task<()>,
 }
-
-#[derive(Debug, Clone)]
-pub struct PendingToolUse {
-    pub id: LanguageModelToolUseId,
-    /// The ID of the Assistant message in which the tool use was requested.
-    pub assistant_message_id: MessageId,
-    pub name: Arc<str>,
-    pub input: serde_json::Value,
-    pub status: PendingToolUseStatus,
-}
-
-#[derive(Debug, Clone)]
-pub enum PendingToolUseStatus {
-    Idle,
-    Running { _task: Shared<Task<()>> },
-    Error(#[allow(unused)] Arc<str>),
-}
-
-impl PendingToolUseStatus {
-    pub fn is_idle(&self) -> bool {
-        matches!(self, PendingToolUseStatus::Idle)
-    }
-
-    pub fn is_error(&self) -> bool {
-        matches!(self, PendingToolUseStatus::Error(_))
-    }
-}

crates/assistant2/src/tool_use.rs 🔗

@@ -0,0 +1,221 @@
+use std::sync::Arc;
+
+use anyhow::Result;
+use collections::HashMap;
+use futures::future::Shared;
+use futures::FutureExt as _;
+use gpui::{SharedString, Task};
+use language_model::{
+    LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
+    LanguageModelToolUseId, MessageContent,
+};
+
+use crate::thread::MessageId;
+
+#[derive(Debug)]
+pub struct ToolUse {
+    pub id: LanguageModelToolUseId,
+    pub name: SharedString,
+    pub status: ToolUseStatus,
+    pub input: serde_json::Value,
+}
+
+#[derive(Debug, Clone)]
+pub enum ToolUseStatus {
+    Pending,
+    Running,
+    Finished(SharedString),
+    Error(SharedString),
+}
+
+#[derive(Default)]
+pub struct ToolUseState {
+    tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
+    tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
+    tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
+    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
+}
+
+impl ToolUseState {
+    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
+        self.pending_tool_uses_by_id.values().collect()
+    }
+
+    pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
+        let Some(tool_uses_for_message) = &self.tool_uses_by_assistant_message.get(&id) else {
+            return Vec::new();
+        };
+
+        let mut tool_uses = Vec::new();
+
+        for tool_use in tool_uses_for_message.iter() {
+            let tool_result = self.tool_results.get(&tool_use.id);
+
+            let status = (|| {
+                if let Some(tool_result) = tool_result {
+                    return if tool_result.is_error {
+                        ToolUseStatus::Error(tool_result.content.clone().into())
+                    } else {
+                        ToolUseStatus::Finished(tool_result.content.clone().into())
+                    };
+                }
+
+                if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
+                    return match pending_tool_use.status {
+                        PendingToolUseStatus::Idle => ToolUseStatus::Pending,
+                        PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
+                        PendingToolUseStatus::Error(ref err) => {
+                            ToolUseStatus::Error(err.clone().into())
+                        }
+                    };
+                }
+
+                ToolUseStatus::Pending
+            })();
+
+            tool_uses.push(ToolUse {
+                id: tool_use.id.clone(),
+                name: tool_use.name.clone().into(),
+                input: tool_use.input.clone(),
+                status,
+            })
+        }
+
+        tool_uses
+    }
+
+    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
+        self.tool_uses_by_user_message
+            .get(&message_id)
+            .map_or(false, |results| !results.is_empty())
+    }
+
+    pub fn request_tool_use(
+        &mut self,
+        assistant_message_id: MessageId,
+        tool_use: LanguageModelToolUse,
+    ) {
+        self.tool_uses_by_assistant_message
+            .entry(assistant_message_id)
+            .or_default()
+            .push(tool_use.clone());
+
+        // The tool use is being requested by the Assistant, so we want to
+        // attach the tool results to the next user message.
+        let next_user_message_id = MessageId(assistant_message_id.0 + 1);
+        self.tool_uses_by_user_message
+            .entry(next_user_message_id)
+            .or_default()
+            .push(tool_use.id.clone());
+
+        self.pending_tool_uses_by_id.insert(
+            tool_use.id.clone(),
+            PendingToolUse {
+                assistant_message_id,
+                id: tool_use.id,
+                name: tool_use.name,
+                input: tool_use.input,
+                status: PendingToolUseStatus::Idle,
+            },
+        );
+    }
+
+    pub fn run_pending_tool(&mut self, tool_use_id: LanguageModelToolUseId, task: Task<()>) {
+        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
+            tool_use.status = PendingToolUseStatus::Running {
+                _task: task.shared(),
+            };
+        }
+    }
+
+    pub fn insert_tool_output(
+        &mut self,
+        tool_use_id: LanguageModelToolUseId,
+        output: Result<String>,
+    ) {
+        match output {
+            Ok(output) => {
+                self.tool_results.insert(
+                    tool_use_id.clone(),
+                    LanguageModelToolResult {
+                        tool_use_id: tool_use_id.clone(),
+                        content: output.into(),
+                        is_error: false,
+                    },
+                );
+                self.pending_tool_uses_by_id.remove(&tool_use_id);
+            }
+            Err(err) => {
+                self.tool_results.insert(
+                    tool_use_id.clone(),
+                    LanguageModelToolResult {
+                        tool_use_id: tool_use_id.clone(),
+                        content: err.to_string().into(),
+                        is_error: true,
+                    },
+                );
+
+                if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
+                    tool_use.status = PendingToolUseStatus::Error(err.to_string().into());
+                }
+            }
+        }
+    }
+
+    pub fn attach_tool_uses(
+        &self,
+        message_id: MessageId,
+        request_message: &mut LanguageModelRequestMessage,
+    ) {
+        if let Some(tool_uses) = self.tool_uses_by_assistant_message.get(&message_id) {
+            for tool_use in tool_uses {
+                request_message
+                    .content
+                    .push(MessageContent::ToolUse(tool_use.clone()));
+            }
+        }
+    }
+
+    pub fn attach_tool_results(
+        &self,
+        message_id: MessageId,
+        request_message: &mut LanguageModelRequestMessage,
+    ) {
+        if let Some(tool_uses) = self.tool_uses_by_user_message.get(&message_id) {
+            for tool_use_id in tool_uses {
+                if let Some(tool_result) = self.tool_results.get(tool_use_id) {
+                    request_message
+                        .content
+                        .push(MessageContent::ToolResult(tool_result.clone()));
+                }
+            }
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct PendingToolUse {
+    pub id: LanguageModelToolUseId,
+    /// The ID of the Assistant message in which the tool use was requested.
+    pub assistant_message_id: MessageId,
+    pub name: Arc<str>,
+    pub input: serde_json::Value,
+    pub status: PendingToolUseStatus,
+}
+
+#[derive(Debug, Clone)]
+pub enum PendingToolUseStatus {
+    Idle,
+    Running { _task: Shared<Task<()>> },
+    Error(#[allow(unused)] Arc<str>),
+}
+
+impl PendingToolUseStatus {
+    pub fn is_idle(&self) -> bool {
+        matches!(self, PendingToolUseStatus::Idle)
+    }
+
+    pub fn is_error(&self) -> bool {
+        matches!(self, PendingToolUseStatus::Error(_))
+    }
+}