Detailed changes
@@ -455,6 +455,8 @@ name = "assistant2"
version = "0.1.0"
dependencies = [
"anyhow",
+ "assistant_tool",
+ "collections",
"command_palette_hooks",
"editor",
"feature_flags",
@@ -463,6 +465,7 @@ dependencies = [
"language_model",
"language_model_selector",
"proto",
+ "serde_json",
"settings",
"smol",
"theme",
@@ -15,7 +15,7 @@ use assistant_tool::ToolWorkingSet;
use client::{self, proto, telemetry::Telemetry};
use clock::ReplicaId;
use collections::{HashMap, HashSet};
-use feature_flags::{FeatureFlag, FeatureFlagAppExt};
+use feature_flags::{FeatureFlagAppExt, ToolUseFeatureFlag};
use fs::{Fs, RemoveOptions};
use futures::{future::Shared, FutureExt, StreamExt};
use gpui::{
@@ -3201,16 +3201,6 @@ pub enum PendingSlashCommandStatus {
Error(String),
}
-pub(crate) struct ToolUseFeatureFlag;
-
-impl FeatureFlag for ToolUseFeatureFlag {
- const NAME: &'static str = "assistant-tool-use";
-
- fn enabled_for_staff() -> bool {
- false
- }
-}
-
#[derive(Debug, Clone)]
pub struct PendingToolUse {
pub id: Arc<str>,
@@ -14,6 +14,8 @@ doctest = false
[dependencies]
anyhow.workspace = true
+assistant_tool.workspace = true
+collections.workspace = true
command_palette_hooks.workspace = true
editor.workspace = true
feature_flags.workspace = true
@@ -23,6 +25,7 @@ language_model.workspace = true
language_model_selector.workspace = true
proto.workspace = true
settings.workspace = true
+serde_json.workspace = true
smol.workspace = true
theme.workspace = true
ui.workspace = true
@@ -1,4 +1,7 @@
+use std::sync::Arc;
+
use anyhow::Result;
+use assistant_tool::ToolWorkingSet;
use gpui::{
prelude::*, px, Action, AppContext, AsyncWindowContext, EventEmitter, FocusHandle,
FocusableView, Model, Pixels, Subscription, Task, View, ViewContext, WeakView, WindowContext,
@@ -10,7 +13,7 @@ use workspace::dock::{DockPosition, Panel, PanelEvent};
use workspace::Workspace;
use crate::message_editor::MessageEditor;
-use crate::thread::Thread;
+use crate::thread::{Thread, ThreadEvent};
use crate::{NewThread, ToggleFocus, ToggleModelSelector};
pub fn init(cx: &mut AppContext) {
@@ -25,8 +28,10 @@ pub fn init(cx: &mut AppContext) {
}
pub struct AssistantPanel {
+ workspace: WeakView<Workspace>,
thread: Model<Thread>,
message_editor: View<MessageEditor>,
+ tools: Arc<ToolWorkingSet>,
_subscriptions: Vec<Subscription>,
}
@@ -36,26 +41,36 @@ impl AssistantPanel {
cx: AsyncWindowContext,
) -> Task<Result<View<Self>>> {
cx.spawn(|mut cx| async move {
+ let tools = Arc::new(ToolWorkingSet::default());
workspace.update(&mut cx, |workspace, cx| {
- cx.new_view(|cx| Self::new(workspace, cx))
+ cx.new_view(|cx| Self::new(workspace, tools, cx))
})
})
}
- fn new(_workspace: &Workspace, cx: &mut ViewContext<Self>) -> Self {
- let thread = cx.new_model(Thread::new);
- let subscriptions = vec![cx.observe(&thread, |_, _, cx| cx.notify())];
+ fn new(workspace: &Workspace, tools: Arc<ToolWorkingSet>, cx: &mut ViewContext<Self>) -> Self {
+ let thread = cx.new_model(|cx| Thread::new(tools.clone(), cx));
+ let subscriptions = vec![
+ cx.observe(&thread, |_, _, cx| cx.notify()),
+ cx.subscribe(&thread, Self::handle_thread_event),
+ ];
Self {
+ workspace: workspace.weak_handle(),
thread: thread.clone(),
message_editor: cx.new_view(|cx| MessageEditor::new(thread, cx)),
+ tools,
_subscriptions: subscriptions,
}
}
fn new_thread(&mut self, cx: &mut ViewContext<Self>) {
- let thread = cx.new_model(Thread::new);
- let subscriptions = vec![cx.observe(&thread, |_, _, cx| cx.notify())];
+ let tools = self.thread.read(cx).tools().clone();
+ let thread = cx.new_model(|cx| Thread::new(tools, cx));
+ let subscriptions = vec![
+ cx.observe(&thread, |_, _, cx| cx.notify()),
+ cx.subscribe(&thread, Self::handle_thread_event),
+ ];
self.message_editor = cx.new_view(|cx| MessageEditor::new(thread.clone(), cx));
self.thread = thread;
@@ -63,6 +78,38 @@ impl AssistantPanel {
self.message_editor.focus_handle(cx).focus(cx);
}
+
+ fn handle_thread_event(
+ &mut self,
+ _: Model<Thread>,
+ event: &ThreadEvent,
+ cx: &mut ViewContext<Self>,
+ ) {
+ match event {
+ ThreadEvent::StreamedCompletion => {}
+ ThreadEvent::UsePendingTools => {
+ let pending_tool_uses = self
+ .thread
+ .read(cx)
+ .pending_tool_uses()
+ .into_iter()
+ .filter(|tool_use| tool_use.status.is_idle())
+ .cloned()
+ .collect::<Vec<_>>();
+
+ for tool_use in pending_tool_uses {
+ if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
+ let task = tool.run(tool_use.input, self.workspace.clone(), cx);
+
+ self.thread.update(cx, |thread, cx| {
+ thread.insert_tool_output(tool_use.id.clone(), task, cx);
+ });
+ }
+ }
+ }
+ ThreadEvent::ToolFinished { .. } => {}
+ }
+ }
}
impl FocusableView for AssistantPanel {
@@ -1,6 +1,7 @@
use editor::{Editor, EditorElement, EditorStyle};
+use feature_flags::{FeatureFlagAppExt, ToolUseFeatureFlag};
use gpui::{AppContext, FocusableView, Model, TextStyle, View};
-use language_model::LanguageModelRegistry;
+use language_model::{LanguageModelRegistry, LanguageModelRequestTool};
use settings::Settings;
use theme::ThemeSettings;
use ui::{prelude::*, ButtonLike, ElevationIndex, KeyBinding};
@@ -55,7 +56,21 @@ impl MessageEditor {
self.thread.update(cx, |thread, cx| {
thread.insert_user_message(user_message);
- let request = thread.to_completion_request(request_kind, cx);
+ let mut request = thread.to_completion_request(request_kind, cx);
+
+ if cx.has_flag::<ToolUseFeatureFlag>() {
+ request.tools = thread
+ .tools()
+ .tools(cx)
+ .into_iter()
+ .map(|tool| LanguageModelRequestTool {
+ name: tool.name(),
+ description: tool.description(),
+ input_schema: tool.input_schema(),
+ })
+ .collect();
+ }
+
thread.stream_completion(request, model, cx)
});
@@ -1,12 +1,16 @@
use std::sync::Arc;
-use futures::StreamExt as _;
+use anyhow::Result;
+use assistant_tool::ToolWorkingSet;
+use collections::HashMap;
+use futures::future::Shared;
+use futures::{FutureExt as _, StreamExt as _};
use gpui::{AppContext, EventEmitter, ModelContext, Task};
use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
- MessageContent, Role, StopReason,
+ LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason,
};
-use util::{post_inc, ResultExt as _};
+use util::post_inc;
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
@@ -14,14 +18,12 @@ pub enum RequestKind {
}
/// A message in a [`Thread`].
+#[derive(Debug)]
pub struct Message {
pub role: Role,
pub text: String,
-}
-
-struct PendingCompletion {
- id: usize,
- _task: Task<()>,
+ pub tool_uses: Vec<LanguageModelToolUse>,
+ pub tool_results: Vec<LanguageModelToolResult>,
}
/// A thread of conversation with the LLM.
@@ -29,14 +31,20 @@ pub struct Thread {
messages: Vec<Message>,
completion_count: usize,
pending_completions: Vec<PendingCompletion>,
+ tools: Arc<ToolWorkingSet>,
+ pending_tool_uses_by_id: HashMap<Arc<str>, PendingToolUse>,
+ completed_tool_uses_by_id: HashMap<Arc<str>, String>,
}
impl Thread {
- pub fn new(_cx: &mut ModelContext<Self>) -> Self {
+ pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut ModelContext<Self>) -> Self {
Self {
+ tools,
messages: Vec::new(),
completion_count: 0,
pending_completions: Vec::new(),
+ pending_tool_uses_by_id: HashMap::default(),
+ completed_tool_uses_by_id: HashMap::default(),
}
}
@@ -44,11 +52,31 @@ impl Thread {
self.messages.iter()
}
+ pub fn tools(&self) -> &Arc<ToolWorkingSet> {
+ &self.tools
+ }
+
+ pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
+ self.pending_tool_uses_by_id.values().collect()
+ }
+
pub fn insert_user_message(&mut self, text: impl Into<String>) {
- self.messages.push(Message {
+ let mut message = Message {
role: Role::User,
text: text.into(),
- });
+ tool_uses: Vec::new(),
+ tool_results: Vec::new(),
+ };
+
+ for (tool_use_id, tool_output) in self.completed_tool_uses_by_id.drain() {
+ message.tool_results.push(LanguageModelToolResult {
+ tool_use_id: tool_use_id.to_string(),
+ content: tool_output,
+ is_error: false,
+ });
+ }
+
+ self.messages.push(message);
}
pub fn to_completion_request(
@@ -70,9 +98,23 @@ impl Thread {
cache: false,
};
- request_message
- .content
- .push(MessageContent::Text(message.text.clone()));
+ for tool_result in &message.tool_results {
+ request_message
+ .content
+ .push(MessageContent::ToolResult(tool_result.clone()));
+ }
+
+ if !message.text.is_empty() {
+ request_message
+ .content
+ .push(MessageContent::Text(message.text.clone()));
+ }
+
+ for tool_use in &message.tool_uses {
+ request_message
+ .content
+ .push(MessageContent::ToolUse(tool_use.clone()));
+ }
request.messages.push(request_message);
}
@@ -103,6 +145,8 @@ impl Thread {
thread.messages.push(Message {
role: Role::Assistant,
text: String::new(),
+ tool_uses: Vec::new(),
+ tool_results: Vec::new(),
});
}
LanguageModelCompletionEvent::Stop(reason) => {
@@ -115,7 +159,24 @@ impl Thread {
}
}
}
- LanguageModelCompletionEvent::ToolUse(_tool_use) => {}
+ LanguageModelCompletionEvent::ToolUse(tool_use) => {
+ if let Some(last_message) = thread.messages.last_mut() {
+ if last_message.role == Role::Assistant {
+ last_message.tool_uses.push(tool_use.clone());
+ }
+ }
+
+ let tool_use_id: Arc<str> = tool_use.id.into();
+ thread.pending_tool_uses_by_id.insert(
+ tool_use_id.clone(),
+ PendingToolUse {
+ id: tool_use_id,
+ name: tool_use.name,
+ input: tool_use.input,
+ status: PendingToolUseStatus::Idle,
+ },
+ );
+ }
}
cx.emit(ThreadEvent::StreamedCompletion);
@@ -135,7 +196,35 @@ impl Thread {
};
let result = stream_completion.await;
- let _ = result.log_err();
+
+ thread
+ .update(&mut cx, |_thread, cx| {
+ let error_message = if let Some(error) = result.as_ref().err() {
+ let error_message = error
+ .chain()
+ .map(|err| err.to_string())
+ .collect::<Vec<_>>()
+ .join("\n");
+ Some(error_message)
+ } else {
+ None
+ };
+
+ if let Some(error_message) = error_message {
+ eprintln!("Completion failed: {error_message:?}");
+ }
+
+ if let Ok(stop_reason) = result {
+ match stop_reason {
+ StopReason::ToolUse => {
+ cx.emit(ThreadEvent::UsePendingTools);
+ }
+ StopReason::EndTurn => {}
+ StopReason::MaxTokens => {}
+ }
+ }
+ })
+ .ok();
});
self.pending_completions.push(PendingCompletion {
@@ -143,11 +232,80 @@ impl Thread {
_task: task,
});
}
+
+ pub fn insert_tool_output(
+ &mut self,
+ tool_use_id: Arc<str>,
+ output: Task<Result<String>>,
+ cx: &mut ModelContext<Self>,
+ ) {
+ let insert_output_task = cx.spawn(|thread, mut cx| {
+ let tool_use_id = tool_use_id.clone();
+ async move {
+ let output = output.await;
+ thread
+ .update(&mut cx, |thread, cx| match output {
+ Ok(output) => {
+ thread
+ .completed_tool_uses_by_id
+ .insert(tool_use_id.clone(), output);
+
+ cx.emit(ThreadEvent::ToolFinished { tool_use_id });
+ }
+ Err(err) => {
+ if let Some(tool_use) =
+ thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
+ {
+ tool_use.status = PendingToolUseStatus::Error(err.to_string());
+ }
+ }
+ })
+ .ok();
+ }
+ });
+
+ if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
+ tool_use.status = PendingToolUseStatus::Running {
+ _task: insert_output_task.shared(),
+ };
+ }
+ }
}
#[derive(Debug, Clone)]
pub enum ThreadEvent {
StreamedCompletion,
+ UsePendingTools,
+ ToolFinished {
+ #[allow(unused)]
+ tool_use_id: Arc<str>,
+ },
}
impl EventEmitter<ThreadEvent> for Thread {}
+
+struct PendingCompletion {
+ id: usize,
+ _task: Task<()>,
+}
+
+#[derive(Debug, Clone)]
+pub struct PendingToolUse {
+ pub id: Arc<str>,
+ pub name: String,
+ pub input: serde_json::Value,
+ pub status: PendingToolUseStatus,
+}
+
+#[derive(Debug, Clone)]
+pub enum PendingToolUseStatus {
+ Idle,
+ Running { _task: Shared<Task<()>> },
+ Error(#[allow(unused)] String),
+}
+
+impl PendingToolUseStatus {
+ pub fn is_idle(&self) -> bool {
+ matches!(self, PendingToolUseStatus::Idle)
+ }
+}
@@ -30,7 +30,7 @@ impl Tool for NowTool {
}
fn description(&self) -> String {
- "Returns the current datetime in RFC 3339 format.".into()
+ "Returns the current datetime in RFC 3339 format. Only use this tool when the user specifically asks for it or the current task would benefit from knowing the current datetime.".into()
}
fn input_schema(&self) -> serde_json::Value {
@@ -49,6 +49,16 @@ impl FeatureFlag for Assistant2FeatureFlag {
}
}
+pub struct ToolUseFeatureFlag;
+
+impl FeatureFlag for ToolUseFeatureFlag {
+ const NAME: &'static str = "assistant-tool-use";
+
+ fn enabled_for_staff() -> bool {
+ false
+ }
+}
+
pub struct Remoting {}
impl FeatureFlag for Remoting {
const NAME: &'static str = "remoting";