Detailed changes
@@ -4,9 +4,11 @@ use action_log::ActionLog;
use agent_client_protocol::{self as acp};
use anyhow::Result;
use client::{Client, UserStore};
-use fs::FakeFs;
+use fs::{FakeFs, Fs};
use futures::channel::mpsc::UnboundedReceiver;
-use gpui::{AppContext, Entity, Task, TestAppContext, http_client::FakeHttpClient};
+use gpui::{
+ App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
+};
use indoc::indoc;
use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
@@ -19,6 +21,7 @@ use reqwest_client::ReqwestClient;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
+use settings::SettingsStore;
use smol::stream::StreamExt;
use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
use util::path;
@@ -282,6 +285,63 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
})
]
);
+
+ // Simulate yet another tool call.
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: "tool_id_3".into(),
+ name: ToolRequiringPermission.name().into(),
+ raw_input: "{}".into(),
+ input: json!({}),
+ is_input_complete: true,
+ },
+ ));
+ fake_model.end_last_completion_stream();
+
+ // Respond by always allowing tools.
+ let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
+ tool_call_auth_3
+ .response
+ .send(tool_call_auth_3.options[0].id.clone())
+ .unwrap();
+ cx.run_until_parked();
+ let completion = fake_model.pending_completions().pop().unwrap();
+ let message = completion.messages.last().unwrap();
+ assert_eq!(
+ message.content,
+ vec![MessageContent::ToolResult(LanguageModelToolResult {
+ tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
+ tool_name: ToolRequiringPermission.name().into(),
+ is_error: false,
+ content: "Allowed".into(),
+ output: Some("Allowed".into())
+ })]
+ );
+
+ // Simulate a final tool call, ensuring we don't trigger authorization.
+ fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: "tool_id_4".into(),
+ name: ToolRequiringPermission.name().into(),
+ raw_input: "{}".into(),
+ input: json!({}),
+ is_input_complete: true,
+ },
+ ));
+ fake_model.end_last_completion_stream();
+ cx.run_until_parked();
+ let completion = fake_model.pending_completions().pop().unwrap();
+ let message = completion.messages.last().unwrap();
+ assert_eq!(
+ message.content,
+ vec![MessageContent::ToolResult(LanguageModelToolResult {
+ tool_use_id: "tool_id_4".into(),
+ tool_name: ToolRequiringPermission.name().into(),
+ is_error: false,
+ content: "Allowed".into(),
+ output: Some("Allowed".into())
+ })]
+ );
}
#[gpui::test]
@@ -773,13 +833,17 @@ impl TestModel {
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
cx.executor().allow_parking();
+
+ let fs = FakeFs::new(cx.background_executor.clone());
+
cx.update(|cx| {
settings::init(cx);
+ watch_settings(fs.clone(), cx);
Project::init_settings(cx);
+ agent_settings::init(cx);
});
let templates = Templates::new();
- let fs = FakeFs::new(cx.background_executor.clone());
fs.insert_tree(path!("/test"), json!({})).await;
let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
@@ -841,3 +905,26 @@ fn init_logger() {
env_logger::init();
}
}
+
+fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
+ let fs = fs.clone();
+ cx.spawn({
+ async move |cx| {
+ let mut new_settings_content_rx = settings::watch_config_file(
+ cx.background_executor(),
+ fs,
+ paths::settings_file().clone(),
+ );
+
+ while let Some(new_settings_content) = new_settings_content_rx.next().await {
+ cx.update(|cx| {
+ SettingsStore::update_global(cx, |settings, cx| {
+ settings.set_user_settings(&new_settings_content, cx)
+ })
+ })
+ .ok();
+ }
+ }
+ })
+ .detach();
+}
@@ -110,9 +110,9 @@ impl AgentTool for ToolRequiringPermission {
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Task<Result<String>> {
- let auth_check = event_stream.authorize("Authorize?".into());
+ let authorize = event_stream.authorize("Authorize?", cx);
cx.foreground_executor().spawn(async move {
- auth_check.await?;
+ authorize.await?;
Ok("Allowed".to_string())
})
}
@@ -1,10 +1,12 @@
use crate::{SystemPromptTemplate, Template, Templates};
use action_log::ActionLog;
use agent_client_protocol as acp;
+use agent_settings::AgentSettings;
use anyhow::{Context as _, Result, anyhow};
use assistant_tool::adapt_schema_to_format;
use cloud_llm_client::{CompletionIntent, CompletionMode};
use collections::HashMap;
+use fs::Fs;
use futures::{
channel::{mpsc, oneshot},
stream::FuturesUnordered,
@@ -21,8 +23,9 @@ use project::Project;
use prompt_store::ProjectContext;
use schemars::{JsonSchema, Schema};
use serde::{Deserialize, Serialize};
+use settings::{Settings, update_settings_file};
use smol::stream::StreamExt;
-use std::{cell::RefCell, collections::BTreeMap, fmt::Write, future::Future, rc::Rc, sync::Arc};
+use std::{cell::RefCell, collections::BTreeMap, fmt::Write, rc::Rc, sync::Arc};
use util::{ResultExt, markdown::MarkdownCodeBlock};
#[derive(Debug, Clone)]
@@ -506,8 +509,9 @@ impl Thread {
}));
};
+ let fs = self.project.read(cx).fs().clone();
let tool_event_stream =
- ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone());
+ ToolCallEventStream::new(&tool_use, tool.kind(), event_stream.clone(), Some(fs));
tool_event_stream.update_fields(acp::ToolCallUpdateFields {
status: Some(acp::ToolCallStatus::InProgress),
..Default::default()
@@ -884,6 +888,7 @@ pub struct ToolCallEventStream {
kind: acp::ToolKind,
input: serde_json::Value,
stream: AgentResponseEventStream,
+ fs: Option<Arc<dyn Fs>>,
}
impl ToolCallEventStream {
@@ -902,6 +907,7 @@ impl ToolCallEventStream {
},
acp::ToolKind::Other,
AgentResponseEventStream(events_tx),
+ None,
);
(stream, ToolCallEventStreamReceiver(events_rx))
@@ -911,12 +917,14 @@ impl ToolCallEventStream {
tool_use: &LanguageModelToolUse,
kind: acp::ToolKind,
stream: AgentResponseEventStream,
+ fs: Option<Arc<dyn Fs>>,
) -> Self {
Self {
tool_use_id: tool_use.id.clone(),
kind,
input: tool_use.input.clone(),
stream,
+ fs,
}
}
@@ -951,7 +959,11 @@ impl ToolCallEventStream {
.ok();
}
- pub fn authorize(&self, title: String) -> impl use<> + Future<Output = Result<()>> {
+ pub fn authorize(&self, title: impl Into<String>, cx: &mut App) -> Task<Result<()>> {
+ if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
+ return Task::ready(Ok(()));
+ }
+
let (response_tx, response_rx) = oneshot::channel();
self.stream
.0
@@ -959,7 +971,7 @@ impl ToolCallEventStream {
ToolCallAuthorization {
tool_call: AgentResponseEventStream::initial_tool_call(
&self.tool_use_id,
- title,
+ title.into(),
self.kind.clone(),
self.input.clone(),
),
@@ -984,12 +996,22 @@ impl ToolCallEventStream {
},
)))
.ok();
- async move {
- match response_rx.await?.0.as_ref() {
- "allow" | "always_allow" => Ok(()),
- _ => Err(anyhow!("Permission to run tool denied by user")),
+ let fs = self.fs.clone();
+ cx.spawn(async move |cx| match response_rx.await?.0.as_ref() {
+ "always_allow" => {
+ if let Some(fs) = fs.clone() {
+ cx.update(|cx| {
+ update_settings_file::<AgentSettings>(fs, cx, |settings, _| {
+ settings.set_always_allow_tool_actions(true);
+ });
+ })?;
+ }
+
+ Ok(())
}
- }
+ "allow" => Ok(()),
+ _ => Err(anyhow!("Permission to run tool denied by user")),
+ })
}
}
@@ -133,7 +133,7 @@ impl EditFileTool {
&self,
input: &EditFileToolInput,
event_stream: &ToolCallEventStream,
- cx: &App,
+ cx: &mut App,
) -> Task<Result<()>> {
if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
return Task::ready(Ok(()));
@@ -147,8 +147,9 @@ impl EditFileTool {
.components()
.any(|component| component.as_os_str() == local_settings_folder.as_os_str())
{
- return cx.foreground_executor().spawn(
- event_stream.authorize(format!("{} (local settings)", input.display_description)),
+ return event_stream.authorize(
+ format!("{} (local settings)", input.display_description),
+ cx,
);
}
@@ -156,9 +157,9 @@ impl EditFileTool {
// so check for that edge case too.
if let Ok(canonical_path) = std::fs::canonicalize(&input.path) {
if canonical_path.starts_with(paths::config_dir()) {
- return cx.foreground_executor().spawn(
- event_stream
- .authorize(format!("{} (global settings)", input.display_description)),
+ return event_stream.authorize(
+ format!("{} (global settings)", input.display_description),
+ cx,
);
}
}
@@ -173,8 +174,7 @@ impl EditFileTool {
if project_path.is_some() {
Task::ready(Ok(()))
} else {
- cx.foreground_executor()
- .spawn(event_stream.authorize(input.display_description.clone()))
+ event_stream.authorize(&input.display_description, cx)
}
}
}
@@ -65,7 +65,7 @@ impl AgentTool for OpenTool {
) -> Task<Result<Self::Output>> {
// If path_or_url turns out to be a path in the project, make it absolute.
let abs_path = to_absolute_path(&input.path_or_url, self.project.clone(), cx);
- let authorize = event_stream.authorize(self.initial_title(Ok(input.clone())).to_string());
+ let authorize = event_stream.authorize(self.initial_title(Ok(input.clone())), cx);
cx.background_spawn(async move {
authorize.await?;
@@ -5,7 +5,6 @@ use gpui::{App, AppContext, Entity, SharedString, Task};
use project::{Project, terminals::TerminalKind};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
-use settings::Settings;
use std::{
path::{Path, PathBuf},
sync::Arc,
@@ -61,21 +60,6 @@ impl TerminalTool {
determine_shell: determine_shell.shared(),
}
}
-
- fn authorize(
- &self,
- input: &TerminalToolInput,
- event_stream: &ToolCallEventStream,
- cx: &App,
- ) -> Task<Result<()>> {
- if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
- return Task::ready(Ok(()));
- }
-
- // TODO: do we want to have a special title here?
- cx.foreground_executor()
- .spawn(event_stream.authorize(self.initial_title(Ok(input.clone())).to_string()))
- }
}
impl AgentTool for TerminalTool {
@@ -152,7 +136,7 @@ impl AgentTool for TerminalTool {
env
});
- let authorize = self.authorize(&input, &event_stream, cx);
+ let authorize = event_stream.authorize(self.initial_title(Ok(input.clone())), cx);
cx.spawn({
async move |cx| {
@@ -2172,6 +2172,9 @@ impl Fs for FakeFs {
async fn atomic_write(&self, path: PathBuf, data: String) -> Result<()> {
self.simulate_random_delay().await;
let path = normalize_path(path.as_path());
+ if let Some(path) = path.parent() {
+ self.create_dir(path).await?;
+ }
self.write_file_internal(path, data.into_bytes(), true)?;
Ok(())
}