Detailed changes
@@ -3,7 +3,7 @@ use crate::thread::{
ThreadEvent, ThreadFeedback,
};
use crate::thread_store::ThreadStore;
-use crate::tool_use::{ToolUse, ToolUseStatus};
+use crate::tool_use::{PendingToolUseStatus, ToolType, ToolUse, ToolUseStatus};
use crate::ui::ContextPill;
use collections::HashMap;
use editor::{Editor, MultiBuffer};
@@ -471,11 +471,18 @@ impl ActiveThread {
for tool_use in tool_uses {
self.render_tool_use_label_markdown(
- tool_use.id,
+ tool_use.id.clone(),
tool_use.ui_text.clone(),
window,
cx,
);
+ self.render_scripting_tool_use_markdown(
+ tool_use.id,
+ tool_use.name.as_ref(),
+ tool_use.input.clone(),
+ window,
+ cx,
+ );
}
}
ThreadEvent::ToolFinished {
@@ -491,13 +498,6 @@ impl ActiveThread {
window,
cx,
);
- self.render_scripting_tool_use_markdown(
- tool_use.id.clone(),
- tool_use.name.as_ref(),
- tool_use.input.clone(),
- window,
- cx,
- );
}
if self.thread.read(cx).all_tools_finished() {
@@ -996,29 +996,31 @@ impl ActiveThread {
)
.child(div().p_2().child(message_content)),
),
- Role::Assistant => v_flex()
- .id(("message-container", ix))
- .ml_2()
- .pl_2()
- .border_l_1()
- .border_color(cx.theme().colors().border_variant)
- .child(message_content)
- .when(
- !tool_uses.is_empty() || !scripting_tool_uses.is_empty(),
- |parent| {
- parent.child(
- v_flex()
- .children(
- tool_uses
- .into_iter()
- .map(|tool_use| self.render_tool_use(tool_use, cx)),
- )
- .children(scripting_tool_uses.into_iter().map(|tool_use| {
- self.render_scripting_tool_use(tool_use, window, cx)
- })),
- )
- },
- ),
+ Role::Assistant => {
+ v_flex()
+ .id(("message-container", ix))
+ .ml_2()
+ .pl_2()
+ .border_l_1()
+ .border_color(cx.theme().colors().border_variant)
+ .child(message_content)
+ .when(
+ !tool_uses.is_empty() || !scripting_tool_uses.is_empty(),
+ |parent| {
+ parent.child(
+ v_flex()
+ .children(
+ tool_uses
+ .into_iter()
+ .map(|tool_use| self.render_tool_use(tool_use, cx)),
+ )
+ .children(scripting_tool_uses.into_iter().map(|tool_use| {
+ self.render_scripting_tool_use(tool_use, cx)
+ })),
+ )
+ },
+ )
+ }
Role::System => div().id(("message-container", ix)).py_1().px_2().child(
v_flex()
.bg(colors.editor_background)
@@ -1379,7 +1381,8 @@ impl ActiveThread {
)
.child({
let (icon_name, color, animated) = match &tool_use.status {
- ToolUseStatus::Pending => {
+ ToolUseStatus::Pending
+ | ToolUseStatus::NeedsConfirmation => {
(IconName::Warning, Color::Warning, false)
}
ToolUseStatus::Running => {
@@ -1500,6 +1503,14 @@ impl ActiveThread {
),
),
ToolUseStatus::Pending => container,
+ ToolUseStatus::NeedsConfirmation => container.child(
+ content_container().child(
+ Label::new("Asking Permission")
+ .size(LabelSize::Small)
+ .color(Color::Muted)
+ .buffer_font(cx),
+ ),
+ ),
}),
)
}),
@@ -1509,7 +1520,6 @@ impl ActiveThread {
fn render_scripting_tool_use(
&self,
tool_use: ToolUse,
- window: &Window,
cx: &mut Context<Self>,
) -> impl IntoElement {
let is_open = self
@@ -1555,13 +1565,25 @@ impl ActiveThread {
}
}),
))
- .child(div().text_ui_sm(cx).child(render_markdown(
- tool_use.ui_text.clone(),
- self.language_registry.clone(),
- window,
- cx,
- )))
- .truncate(),
+ .child(
+ h_flex()
+ .gap_1p5()
+ .child(
+ Icon::new(IconName::Terminal)
+ .size(IconSize::XSmall)
+ .color(Color::Muted),
+ )
+ .child(
+ div()
+ .text_ui_sm(cx)
+ .children(
+ self.rendered_tool_use_labels
+ .get(&tool_use.id)
+ .cloned(),
+ )
+ .truncate(),
+ ),
+ ),
)
.child(
Label::new(match tool_use.status {
@@ -1569,6 +1591,7 @@ impl ActiveThread {
ToolUseStatus::Running => "Running",
ToolUseStatus::Finished(_) => "Finished",
ToolUseStatus::Error(_) => "Error",
+ ToolUseStatus::NeedsConfirmation => "Asking Permission",
})
.size(LabelSize::XSmall)
.buffer_font(cx),
@@ -1620,6 +1643,13 @@ impl ActiveThread {
.child(Label::new(err)),
),
ToolUseStatus::Pending | ToolUseStatus::Running => parent,
+ ToolUseStatus::NeedsConfirmation => parent.child(
+ v_flex()
+ .gap_0p5()
+ .py_1()
+ .px_2p5()
+ .child(Label::new("Asking Permission")),
+ ),
}),
)
}),
@@ -1682,6 +1712,45 @@ impl ActiveThread {
.into_any()
}
+ fn handle_allow_tool(
+ &mut self,
+ tool_use_id: LanguageModelToolUseId,
+ _: &ClickEvent,
+ _window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ if let Some(PendingToolUseStatus::NeedsConfirmation(c)) = self
+ .thread
+ .read(cx)
+ .pending_tool(&tool_use_id)
+ .map(|tool_use| tool_use.status.clone())
+ {
+ self.thread.update(cx, |thread, cx| {
+ thread.run_tool(
+ c.tool_use_id.clone(),
+ c.ui_text.clone(),
+ c.input.clone(),
+ &c.messages,
+ c.tool_type.clone(),
+ cx,
+ );
+ });
+ }
+ }
+
+ fn handle_deny_tool(
+ &mut self,
+ tool_use_id: LanguageModelToolUseId,
+ tool_type: ToolType,
+ _: &ClickEvent,
+ _window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ self.thread.update(cx, |thread, cx| {
+ thread.deny_tool_use(tool_use_id, tool_type, cx);
+ });
+ }
+
fn handle_open_rules(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
let Some(system_prompt_context) = self.thread.read(cx).system_prompt_context().as_ref()
else {
@@ -1704,12 +1773,82 @@ impl ActiveThread {
task.detach();
}
}
+
+ fn render_confirmations<'a>(
+ &'a mut self,
+ cx: &'a mut Context<Self>,
+ ) -> impl Iterator<Item = AnyElement> + 'a {
+ let thread = self.thread.read(cx);
+
+ thread
+ .tools_needing_confirmation()
+ .map(|(tool_type, tool)| {
+ div()
+ .m_3()
+ .p_2()
+ .bg(cx.theme().colors().editor_background)
+ .border_1()
+ .border_color(cx.theme().colors().border)
+ .rounded_lg()
+ .child(
+ v_flex()
+ .gap_1()
+ .child(
+ v_flex()
+ .gap_0p5()
+ .child(
+ Label::new("The agent wants to run this action:")
+ .color(Color::Muted),
+ )
+ .child(div().p_3().child(Label::new(&tool.ui_text))),
+ )
+ .child(
+ h_flex()
+ .gap_1()
+ .child({
+ let tool_id = tool.id.clone();
+ Button::new("allow-tool-action", "Allow").on_click(
+ cx.listener(move |this, event, window, cx| {
+ this.handle_allow_tool(
+ tool_id.clone(),
+ event,
+ window,
+ cx,
+ )
+ }),
+ )
+ })
+ .child({
+ let tool_id = tool.id.clone();
+ Button::new("deny-tool", "Deny").on_click(cx.listener(
+ move |this, event, window, cx| {
+ this.handle_deny_tool(
+ tool_id.clone(),
+ tool_type.clone(),
+ event,
+ window,
+ cx,
+ )
+ },
+ ))
+ }),
+ )
+ .child(
+ Label::new("Note: A future release will introduce a way to remember your answers to these. In the meantime, you can avoid these prompts by adding \"assistant\": { \"always_allow_tool_actions\": true } to your settings.json.")
+ .color(Color::Muted)
+ .size(LabelSize::Small),
+ ),
+ )
+ .into_any()
+ })
+ }
}
impl Render for ActiveThread {
- fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
+ fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
v_flex()
.size_full()
.child(list(self.list_state.clone()).flex_grow())
+ .children(self.render_confirmations(cx))
}
}
@@ -3,14 +3,15 @@ use std::io::Write;
use std::sync::Arc;
use anyhow::{Context as _, Result};
-use assistant_tool::{ActionLog, ToolWorkingSet};
+use assistant_settings::AssistantSettings;
+use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
use chrono::{DateTime, Utc};
use collections::{BTreeMap, HashMap, HashSet};
use fs::Fs;
use futures::future::Shared;
use futures::{FutureExt, StreamExt as _};
use git;
-use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task};
+use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
use language_model::{
LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
@@ -24,6 +25,7 @@ use prompt_store::{
};
use scripting_tool::{ScriptingSession, ScriptingTool};
use serde::{Deserialize, Serialize};
+use settings::Settings;
use util::{maybe, post_inc, ResultExt as _, TryFutureExt as _};
use uuid::Uuid;
@@ -32,7 +34,7 @@ use crate::thread_store::{
SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
SerializedToolUse,
};
-use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
+use crate::tool_use::{PendingToolUse, PendingToolUseStatus, ToolType, ToolUse, ToolUseState};
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
@@ -350,6 +352,44 @@ impl Thread {
&self.tools
}
+ pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
+ self.tool_use
+ .pending_tool_uses()
+ .into_iter()
+ .find(|tool_use| &tool_use.id == id)
+ .or_else(|| {
+ self.scripting_tool_use
+ .pending_tool_uses()
+ .into_iter()
+ .find(|tool_use| &tool_use.id == id)
+ })
+ }
+
+ pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = (ToolType, &PendingToolUse)> {
+ self.tool_use
+ .pending_tool_uses()
+ .into_iter()
+ .filter_map(|tool_use| {
+ if let PendingToolUseStatus::NeedsConfirmation(confirmation) = &tool_use.status {
+ Some((confirmation.tool_type.clone(), tool_use))
+ } else {
+ None
+ }
+ })
+ .chain(
+ self.scripting_tool_use
+ .pending_tool_uses()
+ .into_iter()
+ .filter_map(|tool_use| {
+ if tool_use.status.needs_confirmation() {
+ Some((ToolType::ScriptingTool, tool_use))
+ } else {
+ None
+ }
+ }),
+ )
+ }
+
pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
self.checkpoints_by_message.get(&id).cloned()
}
@@ -1178,6 +1218,7 @@ impl Thread {
cx: &mut Context<Self>,
) -> impl IntoIterator<Item = PendingToolUse> {
let request = self.to_completion_request(RequestKind::Chat, cx);
+ let messages = Arc::new(request.messages);
let pending_tool_uses = self
.tool_use
.pending_tool_uses()
@@ -1188,18 +1229,33 @@ impl Thread {
for tool_use in pending_tool_uses.iter() {
if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
- let task = tool.run(
- tool_use.input.clone(),
- &request.messages,
- self.project.clone(),
- self.action_log.clone(),
- cx,
- );
-
- self.insert_tool_output(
+ if tool.needs_confirmation()
+ && !AssistantSettings::get_global(cx).always_allow_tool_actions
+ {
+ self.tool_use.confirm_tool_use(
+ tool_use.id.clone(),
+ tool_use.ui_text.clone(),
+ tool_use.input.clone(),
+ messages.clone(),
+ ToolType::NonScriptingTool(tool),
+ );
+ } else {
+ self.run_tool(
+ tool_use.id.clone(),
+ tool_use.ui_text.clone(),
+ tool_use.input.clone(),
+ &messages,
+ ToolType::NonScriptingTool(tool),
+ cx,
+ );
+ }
+ } else if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
+ self.run_tool(
tool_use.id.clone(),
- tool_use.ui_text.clone().into(),
- task,
+ tool_use.ui_text.clone(),
+ tool_use.input.clone(),
+ &messages,
+ ToolType::NonScriptingTool(tool),
cx,
);
}
@@ -1214,36 +1270,13 @@ impl Thread {
.collect::<Vec<_>>();
for scripting_tool_use in pending_scripting_tool_uses.iter() {
- let task = match ScriptingTool::deserialize_input(scripting_tool_use.input.clone()) {
- Err(err) => Task::ready(Err(err.into())),
- Ok(input) => {
- let (script_id, script_task) =
- self.scripting_session.update(cx, move |session, cx| {
- session.run_script(input.lua_script, cx)
- });
-
- let session = self.scripting_session.clone();
- cx.spawn(async move |_, cx| {
- script_task.await;
-
- let message = session.read_with(cx, |session, _cx| {
- // Using a id to get the script output seems impractical.
- // Why not just include it in the Task result?
- // This is because we'll later report the script state as it runs,
- session
- .get(script_id)
- .output_message_for_llm()
- .expect("Script shouldn't still be running")
- })?;
-
- Ok(message)
- })
- }
- };
-
- let ui_text: SharedString = scripting_tool_use.name.clone().into();
-
- self.insert_scripting_tool_output(scripting_tool_use.id.clone(), ui_text, task, cx);
+ self.scripting_tool_use.confirm_tool_use(
+ scripting_tool_use.id.clone(),
+ scripting_tool_use.ui_text.clone(),
+ scripting_tool_use.input.clone(),
+ messages.clone(),
+ ToolType::ScriptingTool,
+ );
}
pending_tool_uses
@@ -1251,17 +1284,49 @@ impl Thread {
.chain(pending_scripting_tool_uses)
}
- pub fn insert_tool_output(
+ pub fn run_tool(
&mut self,
tool_use_id: LanguageModelToolUseId,
- ui_text: SharedString,
- output: Task<Result<String>>,
- cx: &mut Context<Self>,
+ ui_text: impl Into<SharedString>,
+ input: serde_json::Value,
+ messages: &[LanguageModelRequestMessage],
+ tool_type: ToolType,
+ cx: &mut Context<'_, Thread>,
) {
- let insert_output_task = cx.spawn({
- let tool_use_id = tool_use_id.clone();
- async move |thread, cx| {
- let output = output.await;
+ match tool_type {
+ ToolType::ScriptingTool => {
+ let task = self.spawn_scripting_tool_use(tool_use_id.clone(), input, cx);
+ self.scripting_tool_use
+ .run_pending_tool(tool_use_id, ui_text.into(), task);
+ }
+ ToolType::NonScriptingTool(tool) => {
+ let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
+ self.tool_use
+ .run_pending_tool(tool_use_id, ui_text.into(), task);
+ }
+ }
+ }
+
+ fn spawn_tool_use(
+ &mut self,
+ tool_use_id: LanguageModelToolUseId,
+ messages: &[LanguageModelRequestMessage],
+ input: serde_json::Value,
+ tool: Arc<dyn Tool>,
+ cx: &mut Context<Thread>,
+ ) -> Task<()> {
+ let run_tool = tool.run(
+ input,
+ messages,
+ self.project.clone(),
+ self.action_log.clone(),
+ cx,
+ );
+
+ cx.spawn({
+ async move |thread: WeakEntity<Thread>, cx| {
+ let output = run_tool.await;
+
thread
.update(cx, |thread, cx| {
let pending_tool_use = thread
@@ -1276,23 +1341,46 @@ impl Thread {
})
.ok();
}
- });
-
- self.tool_use
- .run_pending_tool(tool_use_id, ui_text, insert_output_task);
+ })
}
- pub fn insert_scripting_tool_output(
+ fn spawn_scripting_tool_use(
&mut self,
tool_use_id: LanguageModelToolUseId,
- ui_text: SharedString,
- output: Task<Result<String>>,
- cx: &mut Context<Self>,
- ) {
- let insert_output_task = cx.spawn({
+ input: serde_json::Value,
+ cx: &mut Context<Thread>,
+ ) -> Task<()> {
+ let task = match ScriptingTool::deserialize_input(input) {
+ Err(err) => Task::ready(Err(err.into())),
+ Ok(input) => {
+ let (script_id, script_task) =
+ self.scripting_session.update(cx, move |session, cx| {
+ session.run_script(input.lua_script, cx)
+ });
+
+ let session = self.scripting_session.clone();
+ cx.spawn(async move |_, cx| {
+ script_task.await;
+
+ let message = session.read_with(cx, |session, _cx| {
+ // Using a id to get the script output seems impractical.
+ // Why not just include it in the Task result?
+ // This is because we'll later report the script state as it runs,
+ session
+ .get(script_id)
+ .output_message_for_llm()
+ .expect("Script shouldn't still be running")
+ })?;
+
+ Ok(message)
+ })
+ }
+ };
+
+ cx.spawn({
let tool_use_id = tool_use_id.clone();
async move |thread, cx| {
- let output = output.await;
+ let output = task.await;
thread
.update(cx, |thread, cx| {
let pending_tool_use = thread
@@ -1307,10 +1395,7 @@ impl Thread {
})
.ok();
}
- });
-
- self.scripting_tool_use
- .run_pending_tool(tool_use_id, ui_text, insert_output_task);
+ })
}
pub fn attach_tool_results(
@@ -1568,6 +1653,30 @@ impl Thread {
pub fn cumulative_token_usage(&self) -> TokenUsage {
self.cumulative_token_usage.clone()
}
+
+ pub fn deny_tool_use(
+ &mut self,
+ tool_use_id: LanguageModelToolUseId,
+ tool_type: ToolType,
+ cx: &mut Context<Self>,
+ ) {
+ let err = Err(anyhow::anyhow!(
+ "Permission to run tool action denied by user"
+ ));
+
+ if let ToolType::ScriptingTool = tool_type {
+ self.scripting_tool_use
+ .insert_tool_output(tool_use_id.clone(), err);
+ } else {
+ self.tool_use.insert_tool_output(tool_use_id.clone(), err);
+ }
+
+ cx.emit(ThreadEvent::ToolFinished {
+ tool_use_id,
+ pending_tool_use: None,
+ canceled: true,
+ });
+ }
}
#[derive(Debug, Clone)]
@@ -1,7 +1,7 @@
use std::sync::Arc;
use anyhow::Result;
-use assistant_tool::ToolWorkingSet;
+use assistant_tool::{Tool, ToolWorkingSet};
use collections::HashMap;
use futures::future::Shared;
use futures::FutureExt as _;
@@ -10,6 +10,7 @@ use language_model::{
LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse,
LanguageModelToolUseId, MessageContent, Role,
};
+use scripting_tool::ScriptingTool;
use crate::thread::MessageId;
use crate::thread_store::SerializedMessage;
@@ -25,6 +26,7 @@ pub struct ToolUse {
#[derive(Debug, Clone)]
pub enum ToolUseStatus {
+ NeedsConfirmation,
Pending,
Running,
Finished(SharedString),
@@ -163,16 +165,19 @@ impl ToolUseState {
}
if let Some(pending_tool_use) = self.pending_tool_uses_by_id.get(&tool_use.id) {
- return match pending_tool_use.status {
+ match pending_tool_use.status {
PendingToolUseStatus::Idle => ToolUseStatus::Pending,
+ PendingToolUseStatus::NeedsConfirmation { .. } => {
+ ToolUseStatus::NeedsConfirmation
+ }
PendingToolUseStatus::Running { .. } => ToolUseStatus::Running,
PendingToolUseStatus::Error(ref err) => {
ToolUseStatus::Error(err.clone().into())
}
- };
+ }
+ } else {
+ ToolUseStatus::Pending
}
-
- ToolUseStatus::Pending
})();
tool_uses.push(ToolUse {
@@ -195,6 +200,8 @@ impl ToolUseState {
) -> SharedString {
if let Some(tool) = self.tools.tool(tool_name, cx) {
tool.ui_text(input).into()
+ } else if tool_name == ScriptingTool::NAME {
+ "Run Lua Script".into()
} else {
"Unknown tool".into()
}
@@ -272,6 +279,28 @@ impl ToolUseState {
}
}
+ pub fn confirm_tool_use(
+ &mut self,
+ tool_use_id: LanguageModelToolUseId,
+ ui_text: impl Into<Arc<str>>,
+ input: serde_json::Value,
+ messages: Arc<Vec<LanguageModelRequestMessage>>,
+ tool_type: ToolType,
+ ) {
+ if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
+ let ui_text = ui_text.into();
+ tool_use.ui_text = ui_text.clone();
+ let confirmation = Confirmation {
+ tool_use_id,
+ input,
+ messages,
+ tool_type,
+ ui_text,
+ };
+ tool_use.status = PendingToolUseStatus::NeedsConfirmation(Arc::new(confirmation));
+ }
+ }
+
pub fn insert_tool_output(
&mut self,
tool_use_id: LanguageModelToolUseId,
@@ -369,9 +398,25 @@ pub struct PendingToolUse {
pub status: PendingToolUseStatus,
}
+#[derive(Debug, Clone)]
+pub enum ToolType {
+ ScriptingTool,
+ NonScriptingTool(Arc<dyn Tool>),
+}
+
+#[derive(Debug, Clone)]
+pub struct Confirmation {
+ pub tool_use_id: LanguageModelToolUseId,
+ pub input: serde_json::Value,
+ pub ui_text: Arc<str>,
+ pub messages: Arc<Vec<LanguageModelRequestMessage>>,
+ pub tool_type: ToolType,
+}
+
#[derive(Debug, Clone)]
pub enum PendingToolUseStatus {
Idle,
+ NeedsConfirmation(Arc<Confirmation>),
Running { _task: Shared<Task<()>> },
Error(#[allow(unused)] Arc<str>),
}
@@ -384,4 +429,8 @@ impl PendingToolUseStatus {
pub fn is_error(&self) -> bool {
matches!(self, PendingToolUseStatus::Error(_))
}
+
+ pub fn needs_confirmation(&self) -> bool {
+ matches!(self, PendingToolUseStatus::NeedsConfirmation { .. })
+ }
}
@@ -72,6 +72,7 @@ pub struct AssistantSettings {
pub using_outdated_settings_version: bool,
pub enable_experimental_live_diffs: bool,
pub profiles: IndexMap<Arc<str>, AgentProfile>,
+ pub always_allow_tool_actions: bool,
}
impl AssistantSettings {
@@ -173,6 +174,7 @@ impl AssistantSettingsContent {
inline_alternatives: None,
enable_experimental_live_diffs: None,
profiles: None,
+ always_allow_tool_actions: None,
},
VersionedAssistantSettingsContent::V2(settings) => settings.clone(),
},
@@ -195,6 +197,7 @@ impl AssistantSettingsContent {
inline_alternatives: None,
enable_experimental_live_diffs: None,
profiles: None,
+ always_allow_tool_actions: None,
},
}
}
@@ -325,6 +328,7 @@ impl Default for VersionedAssistantSettingsContent {
inline_alternatives: None,
enable_experimental_live_diffs: None,
profiles: None,
+ always_allow_tool_actions: None,
})
}
}
@@ -363,6 +367,11 @@ pub struct AssistantSettingsContentV2 {
enable_experimental_live_diffs: Option<bool>,
#[schemars(skip)]
profiles: Option<IndexMap<Arc<str>, AgentProfileContent>>,
+ /// Whenever a tool action would normally wait for your confirmation
+ /// that you allow it, always choose to allow it.
+ ///
+ /// Default: false
+ always_allow_tool_actions: Option<bool>,
}
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
@@ -499,6 +508,10 @@ impl Settings for AssistantSettings {
&mut settings.enable_experimental_live_diffs,
value.enable_experimental_live_diffs,
);
+ merge(
+ &mut settings.always_allow_tool_actions,
+ value.always_allow_tool_actions,
+ );
if let Some(profiles) = value.profiles {
settings
@@ -579,6 +592,7 @@ mod tests {
default_height: None,
enable_experimental_live_diffs: None,
profiles: None,
+ always_allow_tool_actions: None,
}),
)
},
@@ -1,14 +1,14 @@
mod tool_registry;
mod tool_working_set;
-use std::sync::Arc;
-
use anyhow::Result;
use collections::{HashMap, HashSet};
use gpui::{App, Context, Entity, SharedString, Task};
use language::Buffer;
use language_model::LanguageModelRequestMessage;
use project::Project;
+use std::fmt::{self, Debug, Formatter};
+use std::sync::Arc;
pub use crate::tool_registry::*;
pub use crate::tool_working_set::*;
@@ -38,6 +38,10 @@ pub trait Tool: 'static + Send + Sync {
ToolSource::Native
}
+ /// Returns true iff the tool needs the users's confirmation
+ /// before having permission to run.
+ fn needs_confirmation(&self) -> bool;
+
/// Returns the JSON schema that describes the tool's input.
fn input_schema(&self) -> serde_json::Value {
serde_json::Value::Object(serde_json::Map::default())
@@ -57,6 +61,12 @@ pub trait Tool: 'static + Send + Sync {
) -> Task<Result<String>>;
}
+impl Debug for dyn Tool {
+ fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+ f.debug_struct("Tool").field("name", &self.name()).finish()
+ }
+}
+
/// Tracks actions performed by tools in a thread
#[derive(Debug)]
pub struct ActionLog {
@@ -23,6 +23,10 @@ impl Tool for BashTool {
"bash".to_string()
}
+ fn needs_confirmation(&self) -> bool {
+ true
+ }
+
fn description(&self) -> String {
include_str!("./bash_tool/description.md").to_string()
}
@@ -30,6 +30,10 @@ impl Tool for DeletePathTool {
"delete-path".into()
}
+ fn needs_confirmation(&self) -> bool {
+ true
+ }
+
fn description(&self) -> String {
include_str!("./delete_path_tool/description.md").into()
}
@@ -37,6 +37,10 @@ impl Tool for DiagnosticsTool {
"diagnostics".into()
}
+ fn needs_confirmation(&self) -> bool {
+ false
+ }
+
fn description(&self) -> String {
include_str!("./diagnostics_tool/description.md").into()
}
@@ -79,6 +79,10 @@ impl Tool for EditFilesTool {
"edit-files".into()
}
+ fn needs_confirmation(&self) -> bool {
+ true
+ }
+
fn description(&self) -> String {
include_str!("./edit_files_tool/description.md").into()
}
@@ -145,30 +149,22 @@ impl Tool for EditFilesTool {
struct EditToolRequest {
parser: EditActionParser,
- editor_response: EditorResponse,
+ output: String,
+ changed_buffers: HashSet<Entity<language::Buffer>>,
+ bad_searches: Vec<BadSearch>,
project: Entity<Project>,
action_log: Entity<ActionLog>,
tool_log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
}
-enum EditorResponse {
- /// The editor model hasn't produced any actions yet.
- /// If we don't have any by the end, we'll return its message to the architect model.
- Message(String),
- /// The editor model produced at least one action.
- Actions {
- applied: Vec<AppliedAction>,
- search_errors: Vec<SearchError>,
- },
-}
-
-struct AppliedAction {
- source: String,
- buffer: Entity<language::Buffer>,
+#[derive(Debug)]
+enum DiffResult {
+ BadSearch(BadSearch),
+ Diff(language::Diff),
}
#[derive(Debug)]
-enum SearchError {
+enum BadSearch {
NoMatch {
file_path: String,
search: String,
@@ -234,7 +230,10 @@ impl EditToolRequest {
let mut request = Self {
parser: EditActionParser::new(),
- editor_response: EditorResponse::Message(String::with_capacity(256)),
+ // we start with the success header so we don't need to shift the output in the common case
+ output: Self::SUCCESS_OUTPUT_HEADER.to_string(),
+ changed_buffers: HashSet::default(),
+ bad_searches: Vec::new(),
action_log,
project,
tool_log,
@@ -251,12 +250,6 @@ impl EditToolRequest {
async fn process_response_chunk(&mut self, chunk: &str, cx: &mut AsyncApp) -> Result<()> {
let new_actions = self.parser.parse_chunk(chunk);
- if let EditorResponse::Message(ref mut message) = self.editor_response {
- if new_actions.is_empty() {
- message.push_str(chunk);
- }
- }
-
if let Some((ref log, req_id)) = self.tool_log {
log.update(cx, |log, cx| {
log.push_editor_response_chunk(req_id, chunk, &new_actions, cx)
@@ -287,11 +280,6 @@ impl EditToolRequest {
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
.await?;
- enum DiffResult {
- Diff(language::Diff),
- SearchError(SearchError),
- }
-
let result = match action {
EditAction::Replace {
old,
@@ -301,39 +289,7 @@ impl EditToolRequest {
let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
cx.background_executor()
- .spawn(async move {
- if snapshot.is_empty() {
- let exists = snapshot
- .file()
- .map_or(false, |file| file.disk_state().exists());
-
- let error = SearchError::EmptyBuffer {
- file_path: file_path.display().to_string(),
- exists,
- search: old,
- };
-
- return anyhow::Ok(DiffResult::SearchError(error));
- }
-
- let replace_result =
- // Try to match exactly
- replace_exact(&old, &new, &snapshot)
- .await
- // If that fails, try being flexible about indentation
- .or_else(|| replace_with_flexible_indent(&old, &new, &snapshot));
-
- let Some(diff) = replace_result else {
- let error = SearchError::NoMatch {
- search: old,
- file_path: file_path.display().to_string(),
- };
-
- return Ok(DiffResult::SearchError(error));
- };
-
- Ok(DiffResult::Diff(diff))
- })
+ .spawn(Self::replace_diff(old, new, file_path, snapshot))
.await
}
EditAction::Write { content, .. } => Ok(DiffResult::Diff(
@@ -344,179 +300,177 @@ impl EditToolRequest {
}?;
match result {
- DiffResult::SearchError(error) => {
- self.push_search_error(error);
+ DiffResult::BadSearch(invalid_replace) => {
+ self.bad_searches.push(invalid_replace);
}
DiffResult::Diff(diff) => {
let _clock = buffer.update(cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
- self.push_applied_action(AppliedAction { source, buffer });
+ write!(&mut self.output, "\n\n{}", source)?;
+ self.changed_buffers.insert(buffer);
}
}
- anyhow::Ok(())
+ Ok(())
}
- fn push_search_error(&mut self, error: SearchError) {
- match &mut self.editor_response {
- EditorResponse::Message(_) => {
- self.editor_response = EditorResponse::Actions {
- applied: Vec::new(),
- search_errors: vec![error],
- };
- }
- EditorResponse::Actions { search_errors, .. } => {
- search_errors.push(error);
- }
+ async fn replace_diff(
+ old: String,
+ new: String,
+ file_path: std::path::PathBuf,
+ snapshot: language::BufferSnapshot,
+ ) -> Result<DiffResult> {
+ if snapshot.is_empty() {
+ let exists = snapshot
+ .file()
+ .map_or(false, |file| file.disk_state().exists());
+
+ return Ok(DiffResult::BadSearch(BadSearch::EmptyBuffer {
+ file_path: file_path.display().to_string(),
+ exists,
+ search: old,
+ }));
}
- }
- fn push_applied_action(&mut self, action: AppliedAction) {
- match &mut self.editor_response {
- EditorResponse::Message(_) => {
- self.editor_response = EditorResponse::Actions {
- applied: vec![action],
- search_errors: Vec::new(),
- };
- }
- EditorResponse::Actions { applied, .. } => {
- applied.push(action);
- }
- }
+ let result =
+ // Try to match exactly
+ replace_exact(&old, &new, &snapshot)
+ .await
+ // If that fails, try being flexible about indentation
+ .or_else(|| replace_with_flexible_indent(&old, &new, &snapshot));
+
+ let Some(diff) = result else {
+ return anyhow::Ok(DiffResult::BadSearch(BadSearch::NoMatch {
+ search: old,
+ file_path: file_path.display().to_string(),
+ }));
+ };
+
+ anyhow::Ok(DiffResult::Diff(diff))
}
+ const SUCCESS_OUTPUT_HEADER: &str = "Successfully applied. Here's a list of changes:";
+ const ERROR_OUTPUT_HEADER_NO_EDITS: &str = "I couldn't apply any edits!";
+ const ERROR_OUTPUT_HEADER_WITH_EDITS: &str =
+ "Errors occurred. First, here's a list of the edits we managed to apply:";
+
async fn finalize(self, cx: &mut AsyncApp) -> Result<String> {
- match self.editor_response {
- EditorResponse::Message(message) => Err(anyhow!(
- "No edits were applied! You might need to provide more context.\n\n{}",
- message
- )),
- EditorResponse::Actions {
- applied,
- search_errors,
- } => {
- let mut output = String::with_capacity(1024);
-
- let parse_errors = self.parser.errors();
- let has_errors = !search_errors.is_empty() || !parse_errors.is_empty();
-
- if has_errors {
- let error_count = search_errors.len() + parse_errors.len();
-
- if applied.is_empty() {
- writeln!(
- &mut output,
- "{} errors occurred! No edits were applied.",
- error_count,
- )?;
- } else {
- writeln!(
- &mut output,
- "{} errors occurred, but {} edits were correctly applied.",
- error_count,
- applied.len(),
- )?;
-
- writeln!(
- &mut output,
- "# {} SEARCH/REPLACE block(s) applied:\n\nDo not re-send these since they are already applied!\n",
- applied.len()
- )?;
- }
- } else {
- write!(
- &mut output,
- "Successfully applied! Here's a list of applied edits:"
- )?;
- }
+ let changed_buffer_count = self.changed_buffers.len();
- let mut changed_buffers = HashSet::default();
+ // Save each buffer once at the end
+ for buffer in &self.changed_buffers {
+ self.project
+ .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
+ .await?;
+ }
- for action in applied {
- changed_buffers.insert(action.buffer);
- write!(&mut output, "\n\n{}", action.source)?;
- }
+ self.action_log
+ .update(cx, |log, cx| log.buffer_edited(self.changed_buffers, cx))
+ .log_err();
- for buffer in &changed_buffers {
- self.project
- .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
- .await?;
- }
+ let errors = self.parser.errors();
+
+ if errors.is_empty() && self.bad_searches.is_empty() {
+ if changed_buffer_count == 0 {
+ return Err(anyhow!(
+ "The instructions didn't lead to any changes. You might need to consult the file contents first."
+ ));
+ }
+
+ Ok(self.output)
+ } else {
+ let mut output = self.output;
+
+ if output.is_empty() {
+ output.replace_range(
+ 0..Self::SUCCESS_OUTPUT_HEADER.len(),
+ Self::ERROR_OUTPUT_HEADER_NO_EDITS,
+ );
+ } else {
+ output.replace_range(
+ 0..Self::SUCCESS_OUTPUT_HEADER.len(),
+ Self::ERROR_OUTPUT_HEADER_WITH_EDITS,
+ );
+ }
- self.action_log
- .update(cx, |log, cx| log.buffer_edited(changed_buffers.clone(), cx))
- .log_err();
-
- if !search_errors.is_empty() {
- writeln!(
- &mut output,
- "\n\n## {} SEARCH/REPLACE block(s) failed to match:\n",
- search_errors.len()
- )?;
-
- for error in search_errors {
- match error {
- SearchError::NoMatch { file_path, search } => {
- writeln!(
- &mut output,
- "### No exact match in: `{}`\n```\n{}\n```\n",
- file_path, search,
- )?;
- }
- SearchError::EmptyBuffer {
- file_path,
- exists: true,
- search,
- } => {
- writeln!(
- &mut output,
- "### No match because `{}` is empty:\n```\n{}\n```\n",
- file_path, search,
- )?;
- }
- SearchError::EmptyBuffer {
- file_path,
- exists: false,
- search,
- } => {
- writeln!(
- &mut output,
- "### No match because `{}` does not exist:\n```\n{}\n```\n",
- file_path, search,
- )?;
- }
+ if !self.bad_searches.is_empty() {
+ writeln!(
+ &mut output,
+ "\n\n# {} SEARCH/REPLACE block(s) failed to match:\n",
+ self.bad_searches.len()
+ )?;
+
+ for bad_search in self.bad_searches {
+ match bad_search {
+ BadSearch::NoMatch { file_path, search } => {
+ writeln!(
+ &mut output,
+ "## No exact match in: `{}`\n```\n{}\n```\n",
+ file_path, search,
+ )?;
+ }
+ BadSearch::EmptyBuffer {
+ file_path,
+ exists: true,
+ search,
+ } => {
+ writeln!(
+ &mut output,
+ "## No match because `{}` is empty:\n```\n{}\n```\n",
+ file_path, search,
+ )?;
+ }
+ BadSearch::EmptyBuffer {
+ file_path,
+ exists: false,
+ search,
+ } => {
+ writeln!(
+ &mut output,
+ "## No match because `{}` does not exist:\n```\n{}\n```\n",
+ file_path, search,
+ )?;
}
}
-
- write!(&mut output,
- "The SEARCH section must exactly match an existing block of lines including all white \
- space, comments, indentation, docstrings, etc."
- )?;
}
- if !parse_errors.is_empty() {
- writeln!(
- &mut output,
- "\n\n## {} SEARCH/REPLACE blocks failed to parse:",
- parse_errors.len()
- )?;
+ write!(&mut output,
+ "The SEARCH section must exactly match an existing block of lines including all white \
+ space, comments, indentation, docstrings, etc."
+ )?;
+ }
- for error in parse_errors {
- writeln!(&mut output, "- {}", error)?;
- }
+ if !errors.is_empty() {
+ writeln!(
+ &mut output,
+ "\n\n# {} SEARCH/REPLACE blocks failed to parse:",
+ errors.len()
+ )?;
+
+ for error in errors {
+ writeln!(&mut output, "- {}", error)?;
}
+ }
- if has_errors {
- writeln!(&mut output,
- "\n\nYou can fix errors by running the tool again. You can include instructions, \
- but errors are part of the conversation so you don't need to repeat them.",
- )?;
+ if changed_buffer_count > 0 {
+ writeln!(
+ &mut output,
+ "\n\nThe other SEARCH/REPLACE blocks were applied successfully. Do not re-send them!",
+ )?;
+ }
- Err(anyhow!(output))
+ writeln!(
+ &mut output,
+ "{}You can fix errors by running the tool again. You can include instructions, \
+ but errors are part of the conversation so you don't need to repeat them.",
+ if changed_buffer_count == 0 {
+ "\n\n"
} else {
- Ok(output)
+ ""
}
- }
+ )?;
+
+ Err(anyhow!(output))
}
}
}
@@ -113,6 +113,10 @@ impl Tool for FetchTool {
"fetch".to_string()
}
+ fn needs_confirmation(&self) -> bool {
+ true
+ }
+
fn description(&self) -> String {
include_str!("./fetch_tool/description.md").to_string()
}
@@ -31,7 +31,7 @@ pub struct ListDirectoryToolInput {
///
/// If you wanna list contents in the directory `foo/baz`, you should use the path `foo/baz`.
/// </example>
- pub path: Arc<Path>,
+ pub path: String,
}
pub struct ListDirectoryTool;
@@ -41,6 +41,10 @@ impl Tool for ListDirectoryTool {
"list-directory".into()
}
+ fn needs_confirmation(&self) -> bool {
+ false
+ }
+
fn description(&self) -> String {
include_str!("./list_directory_tool/description.md").into()
}
@@ -52,7 +56,7 @@ impl Tool for ListDirectoryTool {
fn ui_text(&self, input: &serde_json::Value) -> String {
match serde_json::from_value::<ListDirectoryToolInput>(input.clone()) {
- Ok(input) => format!("List the `{}` directory's contents", input.path.display()),
+ Ok(input) => format!("List the `{}` directory's contents", input.path),
Err(_) => "List directory".to_string(),
}
}
@@ -70,11 +74,29 @@ impl Tool for ListDirectoryTool {
Err(err) => return Task::ready(Err(anyhow!(err))),
};
+ // Sometimes models will return these even though we tell it to give a path and not a glob.
+ // When this happens, just list the root worktree directories.
+ if matches!(input.path.as_str(), "." | "" | "./" | "*") {
+ let output = project
+ .read(cx)
+ .worktrees(cx)
+ .filter_map(|worktree| {
+ worktree.read(cx).root_entry().and_then(|entry| {
+ if entry.is_dir() {
+ entry.path.to_str()
+ } else {
+ None
+ }
+ })
+ })
+ .collect::<Vec<_>>()
+ .join("\n");
+
+ return Task::ready(Ok(output));
+ }
+
let Some(project_path) = project.read(cx).find_project_path(&input.path, cx) else {
- return Task::ready(Err(anyhow!(
- "Path {} not found in project",
- input.path.display()
- )));
+ return Task::ready(Err(anyhow!("Path {} not found in project", input.path)));
};
let Some(worktree) = project
.read(cx)
@@ -85,11 +107,11 @@ impl Tool for ListDirectoryTool {
let worktree = worktree.read(cx);
let Some(entry) = worktree.entry_for_path(&project_path.path) else {
- return Task::ready(Err(anyhow!("Path not found: {}", input.path.display())));
+ return Task::ready(Err(anyhow!("Path not found: {}", input.path)));
};
if !entry.is_dir() {
- return Task::ready(Err(anyhow!("{} is not a directory.", input.path.display())));
+ return Task::ready(Err(anyhow!("{} is not a directory.", input.path)));
}
let mut output = String::new();
@@ -102,7 +124,7 @@ impl Tool for ListDirectoryTool {
.unwrap();
}
if output.is_empty() {
- return Task::ready(Ok(format!("{} is empty.", input.path.display())));
+ return Task::ready(Ok(format!("{} is empty.", input.path)));
}
Task::ready(Ok(output))
}
@@ -31,6 +31,10 @@ impl Tool for NowTool {
"now".into()
}
+ fn needs_confirmation(&self) -> bool {
+ false
+ }
+
fn description(&self) -> String {
"Returns the current datetime in RFC 3339 format. Only use this tool when the user specifically asks for it or the current task would benefit from knowing the current datetime.".into()
}
@@ -39,6 +39,10 @@ impl Tool for PathSearchTool {
"path-search".into()
}
+ fn needs_confirmation(&self) -> bool {
+ false
+ }
+
fn description(&self) -> String {
include_str!("./path_search_tool/description.md").into()
}
@@ -44,6 +44,10 @@ impl Tool for ReadFileTool {
"read-file".into()
}
+ fn needs_confirmation(&self) -> bool {
+ false
+ }
+
fn description(&self) -> String {
include_str!("./read_file_tool/description.md").into()
}
@@ -41,6 +41,10 @@ impl Tool for RegexSearchTool {
"regex-search".into()
}
+ fn needs_confirmation(&self) -> bool {
+ false
+ }
+
fn description(&self) -> String {
include_str!("./regex_search_tool/description.md").into()
}
@@ -22,6 +22,10 @@ impl Tool for ThinkingTool {
"thinking".to_string()
}
+ fn needs_confirmation(&self) -> bool {
+ false
+ }
+
fn description(&self) -> String {
include_str!("./thinking_tool/description.md").to_string()
}
@@ -44,6 +44,10 @@ impl Tool for ContextServerTool {
}
}
+ fn needs_confirmation(&self) -> bool {
+ true
+ }
+
fn input_schema(&self) -> serde_json::Value {
match &self.tool.input_schema {
serde_json::Value::Null => {