Cargo.lock 🔗
@@ -20,7 +20,9 @@ dependencies = [
"itertools 0.14.0",
"language",
"markdown",
+ "parking_lot",
"project",
+ "rand 0.8.5",
"serde",
"serde_json",
"settings",
Ben Brandt created
Release Notes:
- N/A
Cargo.lock | 2
crates/acp_thread/Cargo.toml | 2
crates/acp_thread/src/acp_thread.rs | 148 ++++++++++++++++++++++++++++++
3 files changed, 149 insertions(+), 3 deletions(-)
@@ -20,7 +20,9 @@ dependencies = [
"itertools 0.14.0",
"language",
"markdown",
+ "parking_lot",
"project",
+ "rand 0.8.5",
"serde",
"serde_json",
"settings",
@@ -41,7 +41,9 @@ async-pipe.workspace = true
env_logger.workspace = true
gpui = { workspace = true, "features" = ["test-support"] }
indoc.workspace = true
+parking_lot.workspace = true
project = { workspace = true, "features" = ["test-support"] }
+rand.workspace = true
tempfile.workspace = true
util.workspace = true
settings.workspace = true
@@ -671,7 +671,18 @@ impl AcpThread {
for entry in self.entries.iter().rev() {
match entry {
AgentThreadEntry::UserMessage(_) => return false,
- AgentThreadEntry::ToolCall(call) if call.diffs().next().is_some() => return true,
+ AgentThreadEntry::ToolCall(
+ call @ ToolCall {
+ status:
+ ToolCallStatus::Allowed {
+ status:
+ acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending,
+ },
+ ..
+ },
+ ) if call.diffs().next().is_some() => {
+ return true;
+ }
AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
}
}
@@ -1231,10 +1242,15 @@ mod tests {
use agentic_coding_protocol as acp_old;
use anyhow::anyhow;
use async_pipe::{PipeReader, PipeWriter};
- use futures::{channel::mpsc, future::LocalBoxFuture, select};
- use gpui::{AsyncApp, TestAppContext};
+ use futures::{
+ channel::mpsc,
+ future::{LocalBoxFuture, try_join_all},
+ select,
+ };
+ use gpui::{AsyncApp, TestAppContext, WeakEntity};
use indoc::indoc;
use project::FakeFs;
+ use rand::Rng as _;
use serde_json::json;
use settings::SettingsStore;
use smol::{future::BoxedLocal, stream::StreamExt as _};
@@ -1562,6 +1578,42 @@ mod tests {
});
}
+ #[gpui::test]
+ async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
+ init_test(cx);
+ let fs = FakeFs::new(cx.background_executor.clone());
+ fs.insert_tree(path!("/test"), json!({})).await;
+ let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
+
+ let connection = Rc::new(StubAgentConnection::new(vec![
+ acp::SessionUpdate::ToolCall(acp::ToolCall {
+ id: acp::ToolCallId("test".into()),
+ label: "Label".into(),
+ kind: acp::ToolKind::Edit,
+ status: acp::ToolCallStatus::Completed,
+ content: vec![acp::ToolCallContent::Diff {
+ diff: acp::Diff {
+ path: "/test/test.txt".into(),
+ old_text: None,
+ new_text: "foo".into(),
+ },
+ }],
+ locations: vec![],
+ raw_input: None,
+ }),
+ ]));
+
+ let thread = connection
+ .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
+ .await
+ .unwrap();
+ cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
+ .await
+ .unwrap();
+
+ assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
+ }
+
async fn run_until_first_tool_call(
thread: &Entity<AcpThread>,
cx: &mut TestAppContext,
@@ -1589,6 +1641,96 @@ mod tests {
}
}
+ #[derive(Clone, Default)]
+ struct StubAgentConnection {
+ sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
+ permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
+ updates: Vec<acp::SessionUpdate>,
+ }
+
+ impl StubAgentConnection {
+ fn new(updates: Vec<acp::SessionUpdate>) -> Self {
+ Self {
+ updates,
+ permission_requests: HashMap::default(),
+ sessions: Arc::default(),
+ }
+ }
+ }
+
+ impl AgentConnection for StubAgentConnection {
+ fn name(&self) -> &'static str {
+ "StubAgentConnection"
+ }
+
+ fn new_thread(
+ self: Rc<Self>,
+ project: Entity<Project>,
+ _cwd: &Path,
+ cx: &mut gpui::AsyncApp,
+ ) -> Task<gpui::Result<Entity<AcpThread>>> {
+ let session_id = acp::SessionId(
+ rand::thread_rng()
+ .sample_iter(&rand::distributions::Alphanumeric)
+ .take(7)
+ .map(char::from)
+ .collect::<String>()
+ .into(),
+ );
+ let thread = cx
+ .new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx))
+ .unwrap();
+ self.sessions.lock().insert(session_id, thread.downgrade());
+ Task::ready(Ok(thread))
+ }
+
+ fn authenticate(&self, _cx: &mut App) -> Task<gpui::Result<()>> {
+ unimplemented!()
+ }
+
+ fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<gpui::Result<()>> {
+ let sessions = self.sessions.lock();
+ let thread = sessions.get(¶ms.session_id).unwrap();
+ let mut tasks = vec![];
+ for update in &self.updates {
+ let thread = thread.clone();
+ let update = update.clone();
+ let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update
+ && let Some(options) = self.permission_requests.get(&tool_call.id)
+ {
+ Some((tool_call.clone(), options.clone()))
+ } else {
+ None
+ };
+ let task = cx.spawn(async move |cx| {
+ if let Some((tool_call, options)) = permission_request {
+ let permission = thread.update(cx, |thread, cx| {
+ thread.request_tool_call_permission(
+ tool_call.clone(),
+ options.clone(),
+ cx,
+ )
+ })?;
+ permission.await?;
+ }
+ thread.update(cx, |thread, cx| {
+ thread.handle_session_update(update.clone(), cx).unwrap();
+ })?;
+ anyhow::Ok(())
+ });
+ tasks.push(task);
+ }
+ cx.spawn(async move |_| {
+ try_join_all(tasks).await?;
+ Ok(())
+ })
+ }
+
+ fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
+ unimplemented!()
+ }
+ }
+
pub fn fake_acp_thread(
project: Entity<Project>,
cx: &mut TestAppContext,