assistant2: Add support for using tools (#21190)

Marshall Bowers created

This PR adds rudimentary support for using tools to `assistant2`. There
are currently no visual affordances for tool use.

This is gated behind the `assistant-tool-use` feature flag.

<img width="1079" alt="Screenshot 2024-11-25 at 7 21 31 PM"
src="https://github.com/user-attachments/assets/64d6ca29-c592-4474-8e9d-c344f855bc63">

Release Notes:

- N/A

Change summary

Cargo.lock                                |   3 
crates/assistant/src/context.rs           |  12 -
crates/assistant2/Cargo.toml              |   3 
crates/assistant2/src/assistant_panel.rs  |  61 +++++++
crates/assistant2/src/message_editor.rs   |  19 ++
crates/assistant2/src/thread.rs           | 190 ++++++++++++++++++++++--
crates/assistant_tools/src/now_tool.rs    |   2 
crates/feature_flags/src/feature_flags.rs |  10 +
8 files changed, 263 insertions(+), 37 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -455,6 +455,8 @@ name = "assistant2"
 version = "0.1.0"
 dependencies = [
  "anyhow",
+ "assistant_tool",
+ "collections",
  "command_palette_hooks",
  "editor",
  "feature_flags",
@@ -463,6 +465,7 @@ dependencies = [
  "language_model",
  "language_model_selector",
  "proto",
+ "serde_json",
  "settings",
  "smol",
  "theme",

crates/assistant/src/context.rs 🔗

@@ -15,7 +15,7 @@ use assistant_tool::ToolWorkingSet;
 use client::{self, proto, telemetry::Telemetry};
 use clock::ReplicaId;
 use collections::{HashMap, HashSet};
-use feature_flags::{FeatureFlag, FeatureFlagAppExt};
+use feature_flags::{FeatureFlagAppExt, ToolUseFeatureFlag};
 use fs::{Fs, RemoveOptions};
 use futures::{future::Shared, FutureExt, StreamExt};
 use gpui::{
@@ -3201,16 +3201,6 @@ pub enum PendingSlashCommandStatus {
     Error(String),
 }
 
-pub(crate) struct ToolUseFeatureFlag;
-
-impl FeatureFlag for ToolUseFeatureFlag {
-    const NAME: &'static str = "assistant-tool-use";
-
-    fn enabled_for_staff() -> bool {
-        false
-    }
-}
-
 #[derive(Debug, Clone)]
 pub struct PendingToolUse {
     pub id: Arc<str>,

crates/assistant2/Cargo.toml 🔗

@@ -14,6 +14,8 @@ doctest = false
 
 [dependencies]
 anyhow.workspace = true
+assistant_tool.workspace = true
+collections.workspace = true
 command_palette_hooks.workspace = true
 editor.workspace = true
 feature_flags.workspace = true
@@ -23,6 +25,7 @@ language_model.workspace = true
 language_model_selector.workspace = true
 proto.workspace = true
 settings.workspace = true
+serde_json.workspace = true
 smol.workspace = true
 theme.workspace = true
 ui.workspace = true

crates/assistant2/src/assistant_panel.rs 🔗

@@ -1,4 +1,7 @@
+use std::sync::Arc;
+
 use anyhow::Result;
+use assistant_tool::ToolWorkingSet;
 use gpui::{
     prelude::*, px, Action, AppContext, AsyncWindowContext, EventEmitter, FocusHandle,
     FocusableView, Model, Pixels, Subscription, Task, View, ViewContext, WeakView, WindowContext,
@@ -10,7 +13,7 @@ use workspace::dock::{DockPosition, Panel, PanelEvent};
 use workspace::Workspace;
 
 use crate::message_editor::MessageEditor;
-use crate::thread::Thread;
+use crate::thread::{Thread, ThreadEvent};
 use crate::{NewThread, ToggleFocus, ToggleModelSelector};
 
 pub fn init(cx: &mut AppContext) {
@@ -25,8 +28,10 @@ pub fn init(cx: &mut AppContext) {
 }
 
 pub struct AssistantPanel {
+    workspace: WeakView<Workspace>,
     thread: Model<Thread>,
     message_editor: View<MessageEditor>,
+    tools: Arc<ToolWorkingSet>,
     _subscriptions: Vec<Subscription>,
 }
 
@@ -36,26 +41,36 @@ impl AssistantPanel {
         cx: AsyncWindowContext,
     ) -> Task<Result<View<Self>>> {
         cx.spawn(|mut cx| async move {
+            let tools = Arc::new(ToolWorkingSet::default());
             workspace.update(&mut cx, |workspace, cx| {
-                cx.new_view(|cx| Self::new(workspace, cx))
+                cx.new_view(|cx| Self::new(workspace, tools, cx))
             })
         })
     }
 
-    fn new(_workspace: &Workspace, cx: &mut ViewContext<Self>) -> Self {
-        let thread = cx.new_model(Thread::new);
-        let subscriptions = vec![cx.observe(&thread, |_, _, cx| cx.notify())];
+    fn new(workspace: &Workspace, tools: Arc<ToolWorkingSet>, cx: &mut ViewContext<Self>) -> Self {
+        let thread = cx.new_model(|cx| Thread::new(tools.clone(), cx));
+        let subscriptions = vec![
+            cx.observe(&thread, |_, _, cx| cx.notify()),
+            cx.subscribe(&thread, Self::handle_thread_event),
+        ];
 
         Self {
+            workspace: workspace.weak_handle(),
             thread: thread.clone(),
             message_editor: cx.new_view(|cx| MessageEditor::new(thread, cx)),
+            tools,
             _subscriptions: subscriptions,
         }
     }
 
     fn new_thread(&mut self, cx: &mut ViewContext<Self>) {
-        let thread = cx.new_model(Thread::new);
-        let subscriptions = vec![cx.observe(&thread, |_, _, cx| cx.notify())];
+        let tools = self.thread.read(cx).tools().clone();
+        let thread = cx.new_model(|cx| Thread::new(tools, cx));
+        let subscriptions = vec![
+            cx.observe(&thread, |_, _, cx| cx.notify()),
+            cx.subscribe(&thread, Self::handle_thread_event),
+        ];
 
         self.message_editor = cx.new_view(|cx| MessageEditor::new(thread.clone(), cx));
         self.thread = thread;
@@ -63,6 +78,38 @@ impl AssistantPanel {
 
         self.message_editor.focus_handle(cx).focus(cx);
     }
+
+    fn handle_thread_event(
+        &mut self,
+        _: Model<Thread>,
+        event: &ThreadEvent,
+        cx: &mut ViewContext<Self>,
+    ) {
+        match event {
+            ThreadEvent::StreamedCompletion => {}
+            ThreadEvent::UsePendingTools => {
+                let pending_tool_uses = self
+                    .thread
+                    .read(cx)
+                    .pending_tool_uses()
+                    .into_iter()
+                    .filter(|tool_use| tool_use.status.is_idle())
+                    .cloned()
+                    .collect::<Vec<_>>();
+
+                for tool_use in pending_tool_uses {
+                    if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
+                        let task = tool.run(tool_use.input, self.workspace.clone(), cx);
+
+                        self.thread.update(cx, |thread, cx| {
+                            thread.insert_tool_output(tool_use.id.clone(), task, cx);
+                        });
+                    }
+                }
+            }
+            ThreadEvent::ToolFinished { .. } => {}
+        }
+    }
 }
 
 impl FocusableView for AssistantPanel {

crates/assistant2/src/message_editor.rs 🔗

@@ -1,6 +1,7 @@
 use editor::{Editor, EditorElement, EditorStyle};
+use feature_flags::{FeatureFlagAppExt, ToolUseFeatureFlag};
 use gpui::{AppContext, FocusableView, Model, TextStyle, View};
-use language_model::LanguageModelRegistry;
+use language_model::{LanguageModelRegistry, LanguageModelRequestTool};
 use settings::Settings;
 use theme::ThemeSettings;
 use ui::{prelude::*, ButtonLike, ElevationIndex, KeyBinding};
@@ -55,7 +56,21 @@ impl MessageEditor {
 
         self.thread.update(cx, |thread, cx| {
             thread.insert_user_message(user_message);
-            let request = thread.to_completion_request(request_kind, cx);
+            let mut request = thread.to_completion_request(request_kind, cx);
+
+            if cx.has_flag::<ToolUseFeatureFlag>() {
+                request.tools = thread
+                    .tools()
+                    .tools(cx)
+                    .into_iter()
+                    .map(|tool| LanguageModelRequestTool {
+                        name: tool.name(),
+                        description: tool.description(),
+                        input_schema: tool.input_schema(),
+                    })
+                    .collect();
+            }
+
             thread.stream_completion(request, model, cx)
         });
 

crates/assistant2/src/thread.rs 🔗

@@ -1,12 +1,16 @@
 use std::sync::Arc;
 
-use futures::StreamExt as _;
+use anyhow::Result;
+use assistant_tool::ToolWorkingSet;
+use collections::HashMap;
+use futures::future::Shared;
+use futures::{FutureExt as _, StreamExt as _};
 use gpui::{AppContext, EventEmitter, ModelContext, Task};
 use language_model::{
     LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
-    MessageContent, Role, StopReason,
+    LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason,
 };
-use util::{post_inc, ResultExt as _};
+use util::post_inc;
 
 #[derive(Debug, Clone, Copy)]
 pub enum RequestKind {
@@ -14,14 +18,12 @@ pub enum RequestKind {
 }
 
 /// A message in a [`Thread`].
+#[derive(Debug)]
 pub struct Message {
     pub role: Role,
     pub text: String,
-}
-
-struct PendingCompletion {
-    id: usize,
-    _task: Task<()>,
+    pub tool_uses: Vec<LanguageModelToolUse>,
+    pub tool_results: Vec<LanguageModelToolResult>,
 }
 
 /// A thread of conversation with the LLM.
@@ -29,14 +31,20 @@ pub struct Thread {
     messages: Vec<Message>,
     completion_count: usize,
     pending_completions: Vec<PendingCompletion>,
+    tools: Arc<ToolWorkingSet>,
+    pending_tool_uses_by_id: HashMap<Arc<str>, PendingToolUse>,
+    completed_tool_uses_by_id: HashMap<Arc<str>, String>,
 }
 
 impl Thread {
-    pub fn new(_cx: &mut ModelContext<Self>) -> Self {
+    pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut ModelContext<Self>) -> Self {
         Self {
+            tools,
             messages: Vec::new(),
             completion_count: 0,
             pending_completions: Vec::new(),
+            pending_tool_uses_by_id: HashMap::default(),
+            completed_tool_uses_by_id: HashMap::default(),
         }
     }
 
@@ -44,11 +52,31 @@ impl Thread {
         self.messages.iter()
     }
 
+    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
+        &self.tools
+    }
+
+    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
+        self.pending_tool_uses_by_id.values().collect()
+    }
+
     pub fn insert_user_message(&mut self, text: impl Into<String>) {
-        self.messages.push(Message {
+        let mut message = Message {
             role: Role::User,
             text: text.into(),
-        });
+            tool_uses: Vec::new(),
+            tool_results: Vec::new(),
+        };
+
+        for (tool_use_id, tool_output) in self.completed_tool_uses_by_id.drain() {
+            message.tool_results.push(LanguageModelToolResult {
+                tool_use_id: tool_use_id.to_string(),
+                content: tool_output,
+                is_error: false,
+            });
+        }
+
+        self.messages.push(message);
     }
 
     pub fn to_completion_request(
@@ -70,9 +98,23 @@ impl Thread {
                 cache: false,
             };
 
-            request_message
-                .content
-                .push(MessageContent::Text(message.text.clone()));
+            for tool_result in &message.tool_results {
+                request_message
+                    .content
+                    .push(MessageContent::ToolResult(tool_result.clone()));
+            }
+
+            if !message.text.is_empty() {
+                request_message
+                    .content
+                    .push(MessageContent::Text(message.text.clone()));
+            }
+
+            for tool_use in &message.tool_uses {
+                request_message
+                    .content
+                    .push(MessageContent::ToolUse(tool_use.clone()));
+            }
 
             request.messages.push(request_message);
         }
@@ -103,6 +145,8 @@ impl Thread {
                                 thread.messages.push(Message {
                                     role: Role::Assistant,
                                     text: String::new(),
+                                    tool_uses: Vec::new(),
+                                    tool_results: Vec::new(),
                                 });
                             }
                             LanguageModelCompletionEvent::Stop(reason) => {
@@ -115,7 +159,24 @@ impl Thread {
                                     }
                                 }
                             }
-                            LanguageModelCompletionEvent::ToolUse(_tool_use) => {}
+                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
+                                if let Some(last_message) = thread.messages.last_mut() {
+                                    if last_message.role == Role::Assistant {
+                                        last_message.tool_uses.push(tool_use.clone());
+                                    }
+                                }
+
+                                let tool_use_id: Arc<str> = tool_use.id.into();
+                                thread.pending_tool_uses_by_id.insert(
+                                    tool_use_id.clone(),
+                                    PendingToolUse {
+                                        id: tool_use_id,
+                                        name: tool_use.name,
+                                        input: tool_use.input,
+                                        status: PendingToolUseStatus::Idle,
+                                    },
+                                );
+                            }
                         }
 
                         cx.emit(ThreadEvent::StreamedCompletion);
@@ -135,7 +196,35 @@ impl Thread {
             };
 
             let result = stream_completion.await;
-            let _ = result.log_err();
+
+            thread
+                .update(&mut cx, |_thread, cx| {
+                    let error_message = if let Some(error) = result.as_ref().err() {
+                        let error_message = error
+                            .chain()
+                            .map(|err| err.to_string())
+                            .collect::<Vec<_>>()
+                            .join("\n");
+                        Some(error_message)
+                    } else {
+                        None
+                    };
+
+                    if let Some(error_message) = error_message {
+                        eprintln!("Completion failed: {error_message:?}");
+                    }
+
+                    if let Ok(stop_reason) = result {
+                        match stop_reason {
+                            StopReason::ToolUse => {
+                                cx.emit(ThreadEvent::UsePendingTools);
+                            }
+                            StopReason::EndTurn => {}
+                            StopReason::MaxTokens => {}
+                        }
+                    }
+                })
+                .ok();
         });
 
         self.pending_completions.push(PendingCompletion {
@@ -143,11 +232,80 @@ impl Thread {
             _task: task,
         });
     }
+
+    pub fn insert_tool_output(
+        &mut self,
+        tool_use_id: Arc<str>,
+        output: Task<Result<String>>,
+        cx: &mut ModelContext<Self>,
+    ) {
+        let insert_output_task = cx.spawn(|thread, mut cx| {
+            let tool_use_id = tool_use_id.clone();
+            async move {
+                let output = output.await;
+                thread
+                    .update(&mut cx, |thread, cx| match output {
+                        Ok(output) => {
+                            thread
+                                .completed_tool_uses_by_id
+                                .insert(tool_use_id.clone(), output);
+
+                            cx.emit(ThreadEvent::ToolFinished { tool_use_id });
+                        }
+                        Err(err) => {
+                            if let Some(tool_use) =
+                                thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
+                            {
+                                tool_use.status = PendingToolUseStatus::Error(err.to_string());
+                            }
+                        }
+                    })
+                    .ok();
+            }
+        });
+
+        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
+            tool_use.status = PendingToolUseStatus::Running {
+                _task: insert_output_task.shared(),
+            };
+        }
+    }
 }
 
 #[derive(Debug, Clone)]
 pub enum ThreadEvent {
     StreamedCompletion,
+    UsePendingTools,
+    ToolFinished {
+        #[allow(unused)]
+        tool_use_id: Arc<str>,
+    },
 }
 
 impl EventEmitter<ThreadEvent> for Thread {}
+
+struct PendingCompletion {
+    id: usize,
+    _task: Task<()>,
+}
+
+#[derive(Debug, Clone)]
+pub struct PendingToolUse {
+    pub id: Arc<str>,
+    pub name: String,
+    pub input: serde_json::Value,
+    pub status: PendingToolUseStatus,
+}
+
+#[derive(Debug, Clone)]
+pub enum PendingToolUseStatus {
+    Idle,
+    Running { _task: Shared<Task<()>> },
+    Error(#[allow(unused)] String),
+}
+
+impl PendingToolUseStatus {
+    pub fn is_idle(&self) -> bool {
+        matches!(self, PendingToolUseStatus::Idle)
+    }
+}

crates/assistant_tools/src/now_tool.rs 🔗

@@ -30,7 +30,7 @@ impl Tool for NowTool {
     }
 
     fn description(&self) -> String {
-        "Returns the current datetime in RFC 3339 format.".into()
+        "Returns the current datetime in RFC 3339 format. Only use this tool when the user specifically asks for it or the current task would benefit from knowing the current datetime.".into()
     }
 
     fn input_schema(&self) -> serde_json::Value {

crates/feature_flags/src/feature_flags.rs 🔗

@@ -49,6 +49,16 @@ impl FeatureFlag for Assistant2FeatureFlag {
     }
 }
 
+pub struct ToolUseFeatureFlag;
+
+impl FeatureFlag for ToolUseFeatureFlag {
+    const NAME: &'static str = "assistant-tool-use";
+
+    fn enabled_for_staff() -> bool {
+        false
+    }
+}
+
 pub struct Remoting {}
 impl FeatureFlag for Remoting {
     const NAME: &'static str = "remoting";