Cargo.lock 🔗
@@ -164,6 +164,7 @@ dependencies = [
"libc",
"log",
"nix 0.29.0",
+ "parking_lot",
"paths",
"project",
"schemars",
Agus Zubiaga , Ben Brandt , and Richard Feldman created
Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
Co-authored-by: Richard Feldman <oss@rtfeldman.com>
Cargo.lock | 1
crates/acp_thread/src/acp_thread.rs | 17 ++
crates/agent_servers/Cargo.toml | 3
crates/agent_servers/src/agent_servers.rs | 2
crates/agent_servers/src/claude.rs | 9 +
crates/agent_servers/src/gemini.rs | 94 +++++++++++++++
crates/agent_servers/src/stdio_agent_server.rs | 120 --------------------
crates/agent_ui/src/acp/thread_view.rs | 4
crates/zed/src/zed.rs | 2
9 files changed, 119 insertions(+), 133 deletions(-)
@@ -164,6 +164,7 @@ dependencies = [
"libc",
"log",
"nix 0.29.0",
+ "parking_lot",
"paths",
"project",
"schemars",
@@ -49,7 +49,7 @@ impl UserMessage {
}
fn to_markdown(&self, cx: &App) -> String {
- format!("## User\n{}\n", self.content.to_markdown(cx))
+ format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
}
}
@@ -1711,7 +1711,7 @@ mod tests {
.unwrap();
drop(end_turn_tx);
- request.await.unwrap();
+ assert!(request.await.unwrap_err().to_string().contains("canceled"));
thread.read_with(cx, |thread, _| {
assert!(matches!(
@@ -1816,6 +1816,7 @@ mod tests {
struct FakeAgent {
server: Entity<FakeAcpServer>,
cx: AsyncApp,
+ cancel_tx: Rc<RefCell<Option<oneshot::Sender<()>>>>,
}
impl acp_old::Agent for FakeAgent {
@@ -1834,6 +1835,9 @@ mod tests {
}
async fn cancel_send_message(&self) -> Result<(), acp_old::Error> {
+ if let Some(cancel_tx) = self.cancel_tx.take() {
+ cancel_tx.send(()).log_err();
+ }
Ok(())
}
@@ -1841,6 +1845,9 @@ mod tests {
&self,
request: acp_old::SendUserMessageParams,
) -> Result<(), acp_old::Error> {
+ let (cancel_tx, cancel_rx) = oneshot::channel();
+ self.cancel_tx.replace(Some(cancel_tx));
+
let mut cx = self.cx.clone();
let handler = self
.server
@@ -1848,7 +1855,10 @@ mod tests {
.ok()
.flatten();
if let Some(handler) = handler {
- handler(request, self.server.clone(), self.cx.clone()).await
+ select! {
+ _ = cancel_rx.fuse() => Err(anyhow::anyhow!("Message sending canceled").into()),
+ _ = handler(request, self.server.clone(), self.cx.clone()).fuse() => Ok(()),
+ }
} else {
Err(anyhow::anyhow!("No handler for on_user_message").into())
}
@@ -1860,6 +1870,7 @@ mod tests {
let agent = FakeAgent {
server: cx.entity(),
cx: cx.to_async(),
+ cancel_tx: Default::default(),
};
let foreground_executor = cx.foreground_executor().clone();
@@ -27,12 +27,14 @@ futures.workspace = true
gpui.workspace = true
itertools.workspace = true
log.workspace = true
+parking_lot.workspace = true
paths.workspace = true
project.workspace = true
schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
+shlex.workspace = true
smol.workspace = true
strum.workspace = true
tempfile.workspace = true
@@ -41,7 +43,6 @@ util.workspace = true
uuid.workspace = true
watch.workspace = true
which.workspace = true
-shlex.workspace = true
workspace-hack.workspace = true
[target.'cfg(unix)'.dependencies]
@@ -2,7 +2,6 @@ mod claude;
mod codex;
mod gemini;
mod settings;
-mod stdio_agent_server;
#[cfg(test)]
mod e2e_tests;
@@ -11,7 +10,6 @@ pub use claude::*;
pub use codex::*;
pub use gemini::*;
pub use settings::*;
-pub use stdio_agent_server::*;
use acp_thread::AgentConnection;
use anyhow::Result;
@@ -291,8 +291,15 @@ impl AgentConnection for ClaudeAgentConnection {
.log_err()
.is_some()
{
+ let end_turn_tx = session.end_turn_tx.clone();
cx.foreground_executor()
- .spawn(async move { done_rx.await? })
+ .spawn(async move {
+ done_rx.await??;
+ if let Some(end_turn_tx) = end_turn_tx.take() {
+ end_turn_tx.send(Ok(())).ok();
+ }
+ anyhow::Ok(())
+ })
.detach_and_log_err(cx);
}
}
@@ -1,9 +1,17 @@
-use crate::stdio_agent_server::StdioAgentServer;
-use crate::{AgentServerCommand, AgentServerVersion};
+use anyhow::anyhow;
+use std::cell::RefCell;
+use std::path::Path;
+use std::rc::Rc;
+use util::ResultExt as _;
+
+use crate::{AgentServer, AgentServerCommand, AgentServerVersion};
+use acp_thread::{AgentConnection, LoadError, OldAcpAgentConnection, OldAcpClientDelegate};
+use agentic_coding_protocol as acp_old;
use anyhow::{Context as _, Result};
-use gpui::{AsyncApp, Entity};
+use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
use project::Project;
use settings::SettingsStore;
+use ui::App;
use crate::AllAgentServersSettings;
@@ -12,7 +20,7 @@ pub struct Gemini;
const ACP_ARG: &str = "--experimental-acp";
-impl StdioAgentServer for Gemini {
+impl AgentServer for Gemini {
fn name(&self) -> &'static str {
"Gemini"
}
@@ -29,6 +37,84 @@ impl StdioAgentServer for Gemini {
ui::IconName::AiGemini
}
+ fn connect(
+ &self,
+ root_dir: &Path,
+ project: &Entity<Project>,
+ cx: &mut App,
+ ) -> Task<Result<Rc<dyn AgentConnection>>> {
+ let root_dir = root_dir.to_path_buf();
+ let project = project.clone();
+ let this = self.clone();
+ let name = self.name();
+
+ cx.spawn(async move |cx| {
+ let command = this.command(&project, cx).await?;
+
+ let mut child = util::command::new_smol_command(&command.path)
+ .args(command.args.iter())
+ .current_dir(root_dir)
+ .stdin(std::process::Stdio::piped())
+ .stdout(std::process::Stdio::piped())
+ .stderr(std::process::Stdio::inherit())
+ .kill_on_drop(true)
+ .spawn()?;
+
+ let stdin = child.stdin.take().unwrap();
+ let stdout = child.stdout.take().unwrap();
+
+ let foreground_executor = cx.foreground_executor().clone();
+
+ let thread_rc = Rc::new(RefCell::new(WeakEntity::new_invalid()));
+
+ let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent(
+ OldAcpClientDelegate::new(thread_rc.clone(), cx.clone()),
+ stdin,
+ stdout,
+ move |fut| foreground_executor.spawn(fut).detach(),
+ );
+
+ let io_task = cx.background_spawn(async move {
+ io_fut.await.log_err();
+ });
+
+ let child_status = cx.background_spawn(async move {
+ let result = match child.status().await {
+ Err(e) => Err(anyhow!(e)),
+ Ok(result) if result.success() => Ok(()),
+ Ok(result) => {
+ if let Some(AgentServerVersion::Unsupported {
+ error_message,
+ upgrade_message,
+ upgrade_command,
+ }) = this.version(&command).await.log_err()
+ {
+ Err(anyhow!(LoadError::Unsupported {
+ error_message,
+ upgrade_message,
+ upgrade_command
+ }))
+ } else {
+ Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127))))
+ }
+ }
+ };
+ drop(io_task);
+ result
+ });
+
+ let connection: Rc<dyn AgentConnection> = Rc::new(OldAcpAgentConnection {
+ name,
+ connection,
+ child_status,
+ });
+
+ Ok(connection)
+ })
+ }
+}
+
+impl Gemini {
async fn command(
&self,
project: &Entity<Project>,
@@ -1,120 +0,0 @@
-use crate::{AgentServer, AgentServerCommand, AgentServerVersion};
-use acp_thread::{AgentConnection, LoadError, OldAcpAgentConnection, OldAcpClientDelegate};
-use agentic_coding_protocol as acp_old;
-use anyhow::{Result, anyhow};
-use gpui::{App, AsyncApp, Entity, Task, WeakEntity, prelude::*};
-use project::Project;
-use std::{cell::RefCell, path::Path, rc::Rc};
-use util::ResultExt;
-
-pub trait StdioAgentServer: Send + Clone {
- fn logo(&self) -> ui::IconName;
- fn name(&self) -> &'static str;
- fn empty_state_headline(&self) -> &'static str;
- fn empty_state_message(&self) -> &'static str;
-
- fn command(
- &self,
- project: &Entity<Project>,
- cx: &mut AsyncApp,
- ) -> impl Future<Output = Result<AgentServerCommand>>;
-
- fn version(
- &self,
- command: &AgentServerCommand,
- ) -> impl Future<Output = Result<AgentServerVersion>> + Send;
-}
-
-impl<T: StdioAgentServer + 'static> AgentServer for T {
- fn name(&self) -> &'static str {
- self.name()
- }
-
- fn empty_state_headline(&self) -> &'static str {
- self.empty_state_headline()
- }
-
- fn empty_state_message(&self) -> &'static str {
- self.empty_state_message()
- }
-
- fn logo(&self) -> ui::IconName {
- self.logo()
- }
-
- fn connect(
- &self,
- root_dir: &Path,
- project: &Entity<Project>,
- cx: &mut App,
- ) -> Task<Result<Rc<dyn AgentConnection>>> {
- let root_dir = root_dir.to_path_buf();
- let project = project.clone();
- let this = self.clone();
- let name = self.name();
-
- cx.spawn(async move |cx| {
- let command = this.command(&project, cx).await?;
-
- let mut child = util::command::new_smol_command(&command.path)
- .args(command.args.iter())
- .current_dir(root_dir)
- .stdin(std::process::Stdio::piped())
- .stdout(std::process::Stdio::piped())
- .stderr(std::process::Stdio::inherit())
- .kill_on_drop(true)
- .spawn()?;
-
- let stdin = child.stdin.take().unwrap();
- let stdout = child.stdout.take().unwrap();
-
- let foreground_executor = cx.foreground_executor().clone();
-
- let thread_rc = Rc::new(RefCell::new(WeakEntity::new_invalid()));
-
- let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent(
- OldAcpClientDelegate::new(thread_rc.clone(), cx.clone()),
- stdin,
- stdout,
- move |fut| foreground_executor.spawn(fut).detach(),
- );
-
- let io_task = cx.background_spawn(async move {
- io_fut.await.log_err();
- });
-
- let child_status = cx.background_spawn(async move {
- let result = match child.status().await {
- Err(e) => Err(anyhow!(e)),
- Ok(result) if result.success() => Ok(()),
- Ok(result) => {
- if let Some(AgentServerVersion::Unsupported {
- error_message,
- upgrade_message,
- upgrade_command,
- }) = this.version(&command).await.log_err()
- {
- Err(anyhow!(LoadError::Unsupported {
- error_message,
- upgrade_message,
- upgrade_command
- }))
- } else {
- Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127))))
- }
- }
- };
- drop(io_task);
- result
- });
-
- let connection: Rc<dyn AgentConnection> = Rc::new(OldAcpAgentConnection {
- name,
- connection,
- child_status,
- });
-
- Ok(connection)
- })
- }
-}
@@ -68,6 +68,7 @@ pub struct AcpThreadView {
plan_expanded: bool,
editor_expanded: bool,
message_history: Rc<RefCell<MessageHistory<Vec<acp::ContentBlock>>>>,
+ _cancel_task: Option<Task<()>>,
}
enum ThreadState {
@@ -183,6 +184,7 @@ impl AcpThreadView {
plan_expanded: false,
editor_expanded: false,
message_history,
+ _cancel_task: None,
}
}
@@ -299,7 +301,7 @@ impl AcpThreadView {
self.last_error.take();
if let Some(thread) = self.thread() {
- thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
+ self._cancel_task = Some(thread.update(cx, |thread, cx| thread.cancel(cx)));
}
}
@@ -19,7 +19,7 @@ use collections::VecDeque;
use debugger_ui::debugger_panel::DebugPanel;
use editor::ProposedChangesEditorToolbar;
use editor::{Editor, MultiBuffer};
-use feature_flags::{FeatureFlagAppExt, PanicFeatureFlag};
+use feature_flags::FeatureFlagAppExt;
use futures::future::Either;
use futures::{StreamExt, channel::mpsc, select_biased};
use git_ui::git_panel::GitPanel;