Centralize `always_allow` logic when authorizing agent2 tools (#35988)

Antonio Scandurra , Cole Miller , Bennet Bo Fenner , Agus Zubiaga , and Ben Brandt created

Release Notes:

- N/A

---------

Co-authored-by: Cole Miller <cole@zed.dev>
Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
Co-authored-by: Agus Zubiaga <agus@zed.dev>
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>

Change summary

crates/agent2/src/tests/mod.rs            | 93 ++++++++++++++++++++++++
crates/agent2/src/tests/test_tools.rs     |  4 
crates/agent2/src/thread.rs               | 40 ++++++++--
crates/agent2/src/tools/edit_file_tool.rs | 16 ++--
crates/agent2/src/tools/open_tool.rs      |  2 
crates/agent2/src/tools/terminal_tool.rs  | 18 ----
crates/fs/src/fs.rs                       |  3 
7 files changed, 136 insertions(+), 40 deletions(-)

Detailed changes

crates/agent2/src/tests/mod.rs 🔗

@@ -4,9 +4,11 @@ use action_log::ActionLog;
 use agent_client_protocol::{self as acp};
 use anyhow::Result;
 use client::{Client, UserStore};
-use fs::FakeFs;
+use fs::{FakeFs, Fs};
 use futures::channel::mpsc::UnboundedReceiver;
-use gpui::{AppContext, Entity, Task, TestAppContext, http_client::FakeHttpClient};
+use gpui::{
+    App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
+};
 use indoc::indoc;
 use language_model::{
     LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
@@ -19,6 +21,7 @@ use reqwest_client::ReqwestClient;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use serde_json::json;
+use settings::SettingsStore;
 use smol::stream::StreamExt;
 use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
 use util::path;
@@ -282,6 +285,63 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
             })
         ]
     );
+
+    // Simulate yet another tool call.
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+        LanguageModelToolUse {
+            id: "tool_id_3".into(),
+            name: ToolRequiringPermission.name().into(),
+            raw_input: "{}".into(),
+            input: json!({}),
+            is_input_complete: true,
+        },
+    ));
+    fake_model.end_last_completion_stream();
+
+    // Respond by always allowing tools.
+    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
+    tool_call_auth_3
+        .response
+        .send(tool_call_auth_3.options[0].id.clone())
+        .unwrap();
+    cx.run_until_parked();
+    let completion = fake_model.pending_completions().pop().unwrap();
+    let message = completion.messages.last().unwrap();
+    assert_eq!(
+        message.content,
+        vec![MessageContent::ToolResult(LanguageModelToolResult {
+            tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
+            tool_name: ToolRequiringPermission.name().into(),
+            is_error: false,
+            content: "Allowed".into(),
+            output: Some("Allowed".into())
+        })]
+    );
+
+    // Simulate a final tool call, ensuring we don't trigger authorization.
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+        LanguageModelToolUse {
+            id: "tool_id_4".into(),
+            name: ToolRequiringPermission.name().into(),
+            raw_input: "{}".into(),
+            input: json!({}),
+            is_input_complete: true,
+        },
+    ));
+    fake_model.end_last_completion_stream();
+    cx.run_until_parked();
+    let completion = fake_model.pending_completions().pop().unwrap();
+    let message = completion.messages.last().unwrap();
+    assert_eq!(
+        message.content,
+        vec![MessageContent::ToolResult(LanguageModelToolResult {
+            tool_use_id: "tool_id_4".into(),
+            tool_name: ToolRequiringPermission.name().into(),
+            is_error: false,
+            content: "Allowed".into(),
+            output: Some("Allowed".into())
+        })]
+    );
 }
 
 #[gpui::test]
@@ -773,13 +833,17 @@ impl TestModel {
 
 async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
     cx.executor().allow_parking();
+
+    let fs = FakeFs::new(cx.background_executor.clone());
+
     cx.update(|cx| {
         settings::init(cx);
+        watch_settings(fs.clone(), cx);
         Project::init_settings(cx);
+        agent_settings::init(cx);
     });
     let templates = Templates::new();
 
-    let fs = FakeFs::new(cx.background_executor.clone());
     fs.insert_tree(path!("/test"), json!({})).await;
     let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
 
@@ -841,3 +905,26 @@ fn init_logger() {
         env_logger::init();
     }
 }
+
+fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
+    let fs = fs.clone();
+    cx.spawn({
+        async move |cx| {
+            let mut new_settings_content_rx = settings::watch_config_file(
+                cx.background_executor(),
+                fs,
+                paths::settings_file().clone(),
+            );
+
+            while let Some(new_settings_content) = new_settings_content_rx.next().await {
+                cx.update(|cx| {
+                    SettingsStore::update_global(cx, |settings, cx| {
+                        settings.set_user_settings(&new_settings_content, cx)
+                    })
+                })
+                .ok();
+            }
+        }
+    })
+    .detach();
+}

crates/agent2/src/tests/test_tools.rs 🔗

@@ -110,9 +110,9 @@ impl AgentTool for ToolRequiringPermission {
         event_stream: ToolCallEventStream,
         cx: &mut App,
     ) -> Task<Result<String>> {
-        let auth_check = event_stream.authorize("Authorize?".into());
+        let authorize = event_stream.authorize("Authorize?", cx);
         cx.foreground_executor().spawn(async move {
-            auth_check.await?;
+            authorize.await?;
             Ok("Allowed".to_string())
         })
     }

crates/agent2/src/thread.rs 🔗

@@ -1,10 +1,12 @@
 use crate::{SystemPromptTemplate, Template, Templates};
 use action_log::ActionLog;
 use agent_client_protocol as acp;
+use agent_settings::AgentSettings;
 use anyhow::{Context as _, Result, anyhow};
 use assistant_tool::adapt_schema_to_format;
 use cloud_llm_client::{CompletionIntent, CompletionMode};
 use collections::HashMap;
+use fs::Fs;
 use futures::{
     channel::{mpsc, oneshot},
     stream::FuturesUnordered,
@@ -21,8 +23,9 @@ use project::Project;
 use prompt_store::ProjectContext;
 use schemars::{JsonSchema, Schema};
 use serde::{Deserialize, Serialize};
+use settings::{Settings, update_settings_file};
 use smol::stream::StreamExt;
-use std::{cell::RefCell, collections::BTreeMap, fmt::Write, future::Future, rc::Rc, sync::Arc};
+use std::{cell::RefCell, collections::BTreeMap, fmt::Write, rc::Rc, sync::Arc};
 use util::{ResultExt, markdown::MarkdownCodeBlock};
 
 #[derive(Debug, Clone)]
@@ -506,8 +509,9 @@ impl Thread {
             }));
         };
 
+        let fs = self.project.read(cx).fs().clone();
         let tool_event_stream =
-            ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone());
+            ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone(), Some(fs));
         tool_event_stream.update_fields(acp::ToolCallUpdateFields {
             status: Some(acp::ToolCallStatus::InProgress),
             ..Default::default()
@@ -884,6 +888,7 @@ pub struct ToolCallEventStream {
     kind: acp::ToolKind,
     input: serde_json::Value,
     stream: AgentResponseEventStream,
+    fs: Option<Arc<dyn Fs>>,
 }
 
 impl ToolCallEventStream {
@@ -902,6 +907,7 @@ impl ToolCallEventStream {
             },
             acp::ToolKind::Other,
             AgentResponseEventStream(events_tx),
+            None,
         );
 
         (stream, ToolCallEventStreamReceiver(events_rx))
@@ -911,12 +917,14 @@ impl ToolCallEventStream {
         tool_use: &LanguageModelToolUse,
         kind: acp::ToolKind,
         stream: AgentResponseEventStream,
+        fs: Option<Arc<dyn Fs>>,
     ) -> Self {
         Self {
             tool_use_id: tool_use.id.clone(),
             kind,
             input: tool_use.input.clone(),
             stream,
+            fs,
         }
     }
 
@@ -951,7 +959,11 @@ impl ToolCallEventStream {
             .ok();
     }
 
-    pub fn authorize(&self, title: String) -> impl use<> + Future<Output = Result<()>> {
+    pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
+        if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
+            return Task::ready(Ok(()));
+        }
+
         let (response_tx, response_rx) = oneshot::channel();
         self.stream
             .0
@@ -959,7 +971,7 @@ impl ToolCallEventStream {
                 ToolCallAuthorization {
                     tool_call: AgentResponseEventStream::initial_tool_call(
                         &self.tool_use_id,
-                        title,
+                        title.into(),
                         self.kind.clone(),
                         self.input.clone(),
                     ),
@@ -984,12 +996,22 @@ impl ToolCallEventStream {
                 },
             )))
             .ok();
-        async move {
-            match response_rx.await?.0.as_ref() {
-                "allow" | "always_allow" => Ok(()),
-                _ => Err(anyhow!("Permission to run tool denied by user")),
+        let fs = self.fs.clone();
+        cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
+            "always_allow" => {
+                if let Some(fs) = fs.clone() {
+                    cx.update(|cx| {
+                        update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
+                            settings.set_always_allow_tool_actions(true);
+                        });
+                    })?;
+                }
+
+                Ok(())
             }
-        }
+            "allow" => Ok(()),
+            _ => Err(anyhow!("Permission to run tool denied by user")),
+        })
     }
 }
 

crates/agent2/src/tools/edit_file_tool.rs 🔗

@@ -133,7 +133,7 @@ impl EditFileTool {
         &self,
         input: &EditFileToolInput,
         event_stream: &ToolCallEventStream,
-        cx: &App,
+        cx: &mut App,
     ) -> Task<Result<()>> {
         if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
             return Task::ready(Ok(()));
@@ -147,8 +147,9 @@ impl EditFileTool {
             .components()
             .any(|component| component.as_os_str() == local_settings_folder.as_os_str())
         {
-            return cx.foreground_executor().spawn(
-                event_stream.authorize(format!("{} (local settings)", input.display_description)),
+            return event_stream.authorize(
+                format!("{} (local settings)", input.display_description),
+                cx,
             );
         }
 
@@ -156,9 +157,9 @@ impl EditFileTool {
         // so check for that edge case too.
         if let Ok(canonical_path) = std::fs::canonicalize(&input.path) {
             if canonical_path.starts_with(paths::config_dir()) {
-                return cx.foreground_executor().spawn(
-                    event_stream
-                        .authorize(format!("{} (global settings)", input.display_description)),
+                return event_stream.authorize(
+                    format!("{} (global settings)", input.display_description),
+                    cx,
                 );
             }
         }
@@ -173,8 +174,7 @@ impl EditFileTool {
         if project_path.is_some() {
             Task::ready(Ok(()))
         } else {
-            cx.foreground_executor()
-                .spawn(event_stream.authorize(input.display_description.clone()))
+            event_stream.authorize(&input.display_description, cx)
         }
     }
 }

crates/agent2/src/tools/open_tool.rs 🔗

@@ -65,7 +65,7 @@ impl AgentTool for OpenTool {
     ) -> Task<Result<Self::Output>> {
         // If path_or_url turns out to be a path in the project, make it absolute.
         let abs_path = to_absolute_path(&input.path_or_url, self.project.clone(), cx);
-        let authorize = event_stream.authorize(self.initial_title(Ok(input.clone())).to_string());
+        let authorize = event_stream.authorize(self.initial_title(Ok(input.clone())), cx);
         cx.background_spawn(async move {
             authorize.await?;
 

crates/agent2/src/tools/terminal_tool.rs 🔗

@@ -5,7 +5,6 @@ use gpui::{App, AppContext, Entity, SharedString, Task};
 use project::{Project, terminals::TerminalKind};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
-use settings::Settings;
 use std::{
     path::{Path, PathBuf},
     sync::Arc,
@@ -61,21 +60,6 @@ impl TerminalTool {
             determine_shell: determine_shell.shared(),
         }
     }
-
-    fn authorize(
-        &self,
-        input: &TerminalToolInput,
-        event_stream: &ToolCallEventStream,
-        cx: &App,
-    ) -> Task<Result<()>> {
-        if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
-            return Task::ready(Ok(()));
-        }
-
-        // TODO: do we want to have a special title here?
-        cx.foreground_executor()
-            .spawn(event_stream.authorize(self.initial_title(Ok(input.clone())).to_string()))
-    }
 }
 
 impl AgentTool for TerminalTool {
@@ -152,7 +136,7 @@ impl AgentTool for TerminalTool {
             env
         });
 
-        let authorize = self.authorize(&input, &event_stream, cx);
+        let authorize = event_stream.authorize(self.initial_title(Ok(input.clone())), cx);
 
         cx.spawn({
             async move |cx| {

crates/fs/src/fs.rs 🔗

@@ -2172,6 +2172,9 @@ impl Fs for FakeFs {
     async fn atomic_write(&self, path: PathBuf, data: String) -> Result<()> {
         self.simulate_random_delay().await;
         let path = normalize_path(path.as_path());
+        if let Some(path) = path.parent() {
+            self.create_dir(path).await?;
+        }
         self.write_file_internal(path, data.into_bytes(), true)?;
         Ok(())
     }