assistant2: Add helper methods to `Thread` for dealing with tool use (#26310)

Marshall Bowers created

This PR adds two new helper methods to the `Thread` for dealing with
tool use:

- `use_pending_tools` - This uses all of the tools that are pending
- The reason we aren't calling this directly in `stream_completion` is
that we still might need to have a way for users to confirm that they
want tools to be run, which would need to happen at the UI layer in the
`ActiveThread`.
- `send_tool_results_to_model` - This encapsulates inserting a new user
message that contains the tool results and sending them up to the model.

Release Notes:

- N/A

Change summary

crates/assistant2/src/active_thread.rs   | 42 ++--------------------
crates/assistant2/src/assistant_panel.rs | 11 -----
crates/assistant2/src/thread.rs          | 47 ++++++++++++++++++++++++-
crates/assistant2/src/thread_store.rs    | 13 +++++-
4 files changed, 61 insertions(+), 52 deletions(-)

Detailed changes

crates/assistant2/src/active_thread.rs 🔗

@@ -1,17 +1,15 @@
 use std::sync::Arc;
 
-use assistant_tool::ToolWorkingSet;
 use collections::HashMap;
 use editor::{Editor, MultiBuffer};
 use gpui::{
     list, AbsoluteLength, AnyElement, App, ClickEvent, DefiniteLength, EdgesRefinement, Empty,
     Entity, Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription,
-    Task, TextStyleRefinement, UnderlineStyle, WeakEntity,
+    Task, TextStyleRefinement, UnderlineStyle,
 };
 use language::{Buffer, LanguageRegistry};
 use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
 use markdown::{Markdown, MarkdownStyle};
-use project::Project;
 use settings::Settings as _;
 use theme::ThemeSettings;
 use ui::{prelude::*, Disclosure, KeyBinding};
@@ -23,9 +21,7 @@ use crate::tool_use::{ToolUse, ToolUseStatus};
 use crate::ui::ContextPill;
 
 pub struct ActiveThread {
-    project: WeakEntity<Project>,
     language_registry: Arc<LanguageRegistry>,
-    tools: Arc<ToolWorkingSet>,
     thread_store: Entity<ThreadStore>,
     thread: Entity<Thread>,
     save_thread_task: Option<Task<()>>,
@@ -46,9 +42,7 @@ impl ActiveThread {
     pub fn new(
         thread: Entity<Thread>,
         thread_store: Entity<ThreadStore>,
-        project: WeakEntity<Project>,
         language_registry: Arc<LanguageRegistry>,
-        tools: Arc<ToolWorkingSet>,
         window: &mut Window,
         cx: &mut Context<Self>,
     ) -> Self {
@@ -58,9 +52,7 @@ impl ActiveThread {
         ];
 
         let mut this = Self {
-            project,
             language_registry,
-            tools,
             thread_store,
             thread: thread.clone(),
             save_thread_task: None,
@@ -300,24 +292,9 @@ impl ActiveThread {
                 cx.notify();
             }
             ThreadEvent::UsePendingTools => {
-                let pending_tool_uses = self
-                    .thread
-                    .read(cx)
-                    .pending_tool_uses()
-                    .into_iter()
-                    .filter(|tool_use| tool_use.status.is_idle())
-                    .cloned()
-                    .collect::<Vec<_>>();
-
-                for tool_use in pending_tool_uses {
-                    if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
-                        let task = tool.run(tool_use.input, self.project.clone(), cx);
-
-                        self.thread.update(cx, |thread, cx| {
-                            thread.insert_tool_output(tool_use.id.clone(), task, cx);
-                        });
-                    }
-                }
+                self.thread.update(cx, |thread, cx| {
+                    thread.use_pending_tools(cx);
+                });
             }
             ThreadEvent::ToolFinished { .. } => {
                 let all_tools_finished = self
@@ -330,16 +307,7 @@ impl ActiveThread {
                     let model_registry = LanguageModelRegistry::read_global(cx);
                     if let Some(model) = model_registry.active_model() {
                         self.thread.update(cx, |thread, cx| {
-                            // Insert a user message to contain the tool results.
-                            thread.insert_user_message(
-                                // TODO: Sending up a user message without any content results in the model sending back
-                                // responses that also don't have any content. We currently don't handle this case well,
-                                // so for now we provide some text to keep the model on track.
-                                "Here are the tool results.",
-                                Vec::new(),
-                                cx,
-                            );
-                            thread.send_to_model(model, RequestKind::Chat, true, cx);
+                            thread.send_tool_results_to_model(model, cx);
                         });
                     }
                 }

crates/assistant2/src/assistant_panel.rs 🔗

@@ -92,7 +92,6 @@ pub struct AssistantPanel {
     context_editor: Option<Entity<ContextEditor>>,
     configuration: Option<Entity<AssistantConfiguration>>,
     configuration_subscription: Option<Subscription>,
-    tools: Arc<ToolWorkingSet>,
     local_timezone: UtcOffset,
     active_view: ActiveView,
     history_store: Entity<HistoryStore>,
@@ -133,7 +132,7 @@ impl AssistantPanel {
             log::info!("[assistant2-debug] finished initializing ContextStore");
 
             workspace.update_in(&mut cx, |workspace, window, cx| {
-                cx.new(|cx| Self::new(workspace, thread_store, context_store, tools, window, cx))
+                cx.new(|cx| Self::new(workspace, thread_store, context_store, window, cx))
             })
         })
     }
@@ -142,7 +141,6 @@ impl AssistantPanel {
         workspace: &Workspace,
         thread_store: Entity<ThreadStore>,
         context_store: Entity<assistant_context_editor::ContextStore>,
-        tools: Arc<ToolWorkingSet>,
         window: &mut Window,
         cx: &mut Context<Self>,
     ) -> Self {
@@ -179,9 +177,7 @@ impl AssistantPanel {
                 ActiveThread::new(
                     thread.clone(),
                     thread_store.clone(),
-                    project.downgrade(),
                     language_registry,
-                    tools.clone(),
                     window,
                     cx,
                 )
@@ -191,7 +187,6 @@ impl AssistantPanel {
             context_editor: None,
             configuration: None,
             configuration_subscription: None,
-            tools,
             local_timezone: UtcOffset::from_whole_seconds(
                 chrono::Local::now().offset().local_minus_utc(),
             )
@@ -246,9 +241,7 @@ impl AssistantPanel {
             ActiveThread::new(
                 thread.clone(),
                 self.thread_store.clone(),
-                self.project.downgrade(),
                 self.language_registry.clone(),
-                self.tools.clone(),
                 window,
                 cx,
             )
@@ -381,9 +374,7 @@ impl AssistantPanel {
                     ActiveThread::new(
                         thread.clone(),
                         this.thread_store.clone(),
-                        this.project.downgrade(),
                         this.language_registry.clone(),
-                        this.tools.clone(),
                         window,
                         cx,
                     )

crates/assistant2/src/thread.rs 🔗

@@ -5,13 +5,14 @@ use assistant_tool::ToolWorkingSet;
 use chrono::{DateTime, Utc};
 use collections::{BTreeMap, HashMap, HashSet};
 use futures::StreamExt as _;
-use gpui::{App, Context, EventEmitter, SharedString, Task};
+use gpui::{App, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
 use language_model::{
     LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
     LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
     LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
     Role, StopReason,
 };
+use project::Project;
 use serde::{Deserialize, Serialize};
 use util::{post_inc, TryFutureExt as _};
 use uuid::Uuid;
@@ -71,12 +72,17 @@ pub struct Thread {
     context_by_message: HashMap<MessageId, Vec<ContextId>>,
     completion_count: usize,
     pending_completions: Vec<PendingCompletion>,
+    project: WeakEntity<Project>,
     tools: Arc<ToolWorkingSet>,
     tool_use: ToolUseState,
 }
 
 impl Thread {
-    pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut Context<Self>) -> Self {
+    pub fn new(
+        project: Entity<Project>,
+        tools: Arc<ToolWorkingSet>,
+        _cx: &mut Context<Self>,
+    ) -> Self {
         Self {
             id: ThreadId::new(),
             updated_at: Utc::now(),
@@ -88,6 +94,7 @@ impl Thread {
             context_by_message: HashMap::default(),
             completion_count: 0,
             pending_completions: Vec::new(),
+            project: project.downgrade(),
             tools,
             tool_use: ToolUseState::new(),
         }
@@ -96,6 +103,7 @@ impl Thread {
     pub fn from_saved(
         id: ThreadId,
         saved: SavedThread,
+        project: Entity<Project>,
         tools: Arc<ToolWorkingSet>,
         _cx: &mut Context<Self>,
     ) -> Self {
@@ -127,6 +135,7 @@ impl Thread {
             context_by_message: HashMap::default(),
             completion_count: 0,
             pending_completions: Vec::new(),
+            project: project.downgrade(),
             tools,
             tool_use,
         }
@@ -550,6 +559,23 @@ impl Thread {
         });
     }
 
+    pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) {
+        let pending_tool_uses = self
+            .pending_tool_uses()
+            .into_iter()
+            .filter(|tool_use| tool_use.status.is_idle())
+            .cloned()
+            .collect::<Vec<_>>();
+
+        for tool_use in pending_tool_uses {
+            if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
+                let task = tool.run(tool_use.input, self.project.clone(), cx);
+
+                self.insert_tool_output(tool_use.id.clone(), task, cx);
+            }
+        }
+    }
+
     pub fn insert_tool_output(
         &mut self,
         tool_use_id: LanguageModelToolUseId,
@@ -576,6 +602,23 @@ impl Thread {
             .run_pending_tool(tool_use_id, insert_output_task);
     }
 
+    pub fn send_tool_results_to_model(
+        &mut self,
+        model: Arc<dyn LanguageModel>,
+        cx: &mut Context<Self>,
+    ) {
+        // Insert a user message to contain the tool results.
+        self.insert_user_message(
+            // TODO: Sending up a user message without any content results in the model sending back
+            // responses that also don't have any content. We currently don't handle this case well,
+            // so for now we provide some text to keep the model on track.
+            "Here are the tool results.",
+            Vec::new(),
+            cx,
+        );
+        self.send_to_model(model, RequestKind::Chat, true, cx);
+    }
+
     /// Cancels the last pending completion, if there are any pending.
     ///
     /// Returns whether a completion was canceled.

crates/assistant2/src/thread_store.rs 🔗

@@ -26,7 +26,6 @@ pub fn init(cx: &mut App) {
 }
 
 pub struct ThreadStore {
-    #[allow(unused)]
     project: Entity<Project>,
     tools: Arc<ToolWorkingSet>,
     context_server_manager: Entity<ContextServerManager>,
@@ -78,7 +77,7 @@ impl ThreadStore {
     }
 
     pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
-        cx.new(|cx| Thread::new(self.tools.clone(), cx))
+        cx.new(|cx| Thread::new(self.project.clone(), self.tools.clone(), cx))
     }
 
     pub fn open_thread(
@@ -96,7 +95,15 @@ impl ThreadStore {
                 .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
 
             this.update(&mut cx, |this, cx| {
-                cx.new(|cx| Thread::from_saved(id.clone(), thread, this.tools.clone(), cx))
+                cx.new(|cx| {
+                    Thread::from_saved(
+                        id.clone(),
+                        thread,
+                        this.project.clone(),
+                        this.tools.clone(),
+                        cx,
+                    )
+                })
             })
         })
     }