diff --git a/Cargo.lock b/Cargo.lock index 4537d440ccd97b4934164eee3f0bdf0230edd4ec..ad6c40bcf20ccc8cf770313d19deb831d864be3f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -150,7 +150,9 @@ dependencies = [ "indoc", "itertools 0.14.0", "language", + "libc", "log", + "nix 0.29.0", "paths", "project", "schemars", @@ -162,6 +164,7 @@ dependencies = [ "tempfile", "ui", "util", + "uuid", "watch", "which 6.0.3", "workspace-hack", diff --git a/crates/agent_servers/Cargo.toml b/crates/agent_servers/Cargo.toml index f3df25f70914e95c63265e4c711ca76fd66050b1..4714245b94fd9b519cfe1987817d51ff6ecbc7fd 100644 --- a/crates/agent_servers/Cargo.toml +++ b/crates/agent_servers/Cargo.toml @@ -37,10 +37,15 @@ strum.workspace = true tempfile.workspace = true ui.workspace = true util.workspace = true +uuid.workspace = true watch.workspace = true which.workspace = true workspace-hack.workspace = true +[target.'cfg(unix)'.dependencies] +libc.workspace = true +nix.workspace = true + [dev-dependencies] env_logger.workspace = true language.workspace = true diff --git a/crates/agent_servers/src/claude.rs b/crates/agent_servers/src/claude.rs index 8b3d93a122d07448ddbfb9daf2dfc9226fb11545..835efbd6552423e7e5bcd1d321ca193581e5ab0a 100644 --- a/crates/agent_servers/src/claude.rs +++ b/crates/agent_servers/src/claude.rs @@ -4,10 +4,13 @@ mod tools; use collections::HashMap; use project::Project; use settings::SettingsStore; +use smol::process::Child; use std::cell::RefCell; use std::fmt::Display; use std::path::Path; +use std::pin::pin; use std::rc::Rc; +use uuid::Uuid; use agentic_coding_protocol::{ self as acp, AnyAgentRequest, AnyAgentResult, Client, ProtocolVersion, @@ -16,7 +19,7 @@ use agentic_coding_protocol::{ use anyhow::{Result, anyhow}; use futures::channel::oneshot; use futures::future::LocalBoxFuture; -use futures::{AsyncBufReadExt, AsyncWriteExt}; +use futures::{AsyncBufReadExt, AsyncWriteExt, SinkExt}; use futures::{ AsyncRead, AsyncWrite, FutureExt, StreamExt, channel::mpsc::{self, UnboundedReceiver, UnboundedSender}, @@ -69,13 +72,12 @@ impl AgentServer for ClaudeCode { let (mut delegate_tx, delegate_rx) = watch::channel(None); let tool_id_map = Rc::new(RefCell::new(HashMap::default())); - let permission_mcp_server = - ClaudeMcpServer::new(delegate_rx, tool_id_map.clone(), cx).await?; + let mcp_server = ClaudeMcpServer::new(delegate_rx, tool_id_map.clone(), cx).await?; let mut mcp_servers = HashMap::default(); mcp_servers.insert( mcp_server::SERVER_NAME.to_string(), - permission_mcp_server.server_config()?, + mcp_server.server_config()?, ); let mcp_config = McpConfig { mcp_servers }; @@ -98,50 +100,58 @@ impl AgentServer for ClaudeCode { anyhow::bail!("Failed to find claude binary"); }; - let mut child = util::command::new_smol_command(&command.path) - .args( - [ - "--input-format", - "stream-json", - "--output-format", - "stream-json", - "--print", - "--verbose", - "--mcp-config", - mcp_config_path.to_string_lossy().as_ref(), - "--permission-prompt-tool", - &format!( - "mcp__{}__{}", - mcp_server::SERVER_NAME, - mcp_server::PERMISSION_TOOL - ), - "--allowedTools", - "mcp__zed__Read,mcp__zed__Edit", - "--disallowedTools", - "Read,Edit", - ] - .into_iter() - .chain(command.args.iter().map(|arg| arg.as_str())), - ) - .current_dir(root_dir) - .stdin(std::process::Stdio::piped()) - .stdout(std::process::Stdio::piped()) - .stderr(std::process::Stdio::inherit()) - .kill_on_drop(true) - .spawn()?; - - let stdin = child.stdin.take().unwrap(); - let stdout = child.stdout.take().unwrap(); - let (incoming_message_tx, mut incoming_message_rx) = mpsc::unbounded(); let (outgoing_tx, outgoing_rx) = mpsc::unbounded(); + let (cancel_tx, mut cancel_rx) = mpsc::unbounded::>>(); + + let session_id = Uuid::new_v4(); + + log::trace!("Starting session with id: {}", session_id); - let io_task = - ClaudeAgentConnection::handle_io(outgoing_rx, incoming_message_tx, stdin, stdout); cx.background_spawn(async move { - io_task.await.log_err(); + let mut outgoing_rx = Some(outgoing_rx); + let mut mode = ClaudeSessionMode::Start; + + loop { + let mut child = + spawn_claude(&command, mode, session_id, &mcp_config_path, &root_dir) + .await?; + mode = ClaudeSessionMode::Resume; + + let pid = child.id(); + log::trace!("Spawned (pid: {})", pid); + + let mut io_fut = pin!( + ClaudeAgentConnection::handle_io( + outgoing_rx.take().unwrap(), + incoming_message_tx.clone(), + child.stdin.take().unwrap(), + child.stdout.take().unwrap(), + ) + .fuse() + ); + + select_biased! { + done_tx = cancel_rx.next() => { + if let Some(done_tx) = done_tx { + log::trace!("Interrupted (pid: {})", pid); + let result = send_interrupt(pid as i32); + outgoing_rx.replace(io_fut.await?); + done_tx.send(result).log_err(); + continue; + } + } + result = io_fut => { + result?; + } + } + + log::trace!("Stopped (pid: {})", pid); + break; + } + drop(mcp_config_path); - drop(child); + anyhow::Ok(()) }) .detach(); @@ -171,17 +181,32 @@ impl AgentServer for ClaudeCode { delegate, outgoing_tx, end_turn_tx, + cancel_tx, + session_id, _handler_task: handler_task, _mcp_server: None, }; - connection._mcp_server = Some(permission_mcp_server); + connection._mcp_server = Some(mcp_server); acp_thread::AcpThread::new(connection, title, None, project.clone(), cx) }) }) } } +#[cfg(unix)] +fn send_interrupt(pid: libc::pid_t) -> anyhow::Result<()> { + let pid = nix::unistd::Pid::from_raw(pid); + + nix::sys::signal::kill(pid, nix::sys::signal::SIGINT) + .map_err(|e| anyhow!("Failed to interrupt process: {}", e)) +} + +#[cfg(windows)] +fn send_interrupt(_pid: i32) -> anyhow::Result<()> { + panic!("Cancel not implemented on Windows") +} + impl AgentConnection for ClaudeAgentConnection { /// Send a request to the agent and wait for a response. fn request_any( @@ -191,6 +216,8 @@ impl AgentConnection for ClaudeAgentConnection { let delegate = self.delegate.clone(); let end_turn_tx = self.end_turn_tx.clone(); let outgoing_tx = self.outgoing_tx.clone(); + let mut cancel_tx = self.cancel_tx.clone(); + let session_id = self.session_id; async move { match params { // todo: consider sending an empty request so we get the init response? @@ -229,26 +256,83 @@ impl AgentConnection for ClaudeAgentConnection { stop_sequence: None, usage: None, }, - session_id: None, + session_id: Some(session_id), })?; rx.await??; Ok(AnyAgentResult::SendUserMessageResponse( acp::SendUserMessageResponse, )) } - AnyAgentRequest::CancelSendMessageParams(_) => Ok( - AnyAgentResult::CancelSendMessageResponse(acp::CancelSendMessageResponse), - ), + AnyAgentRequest::CancelSendMessageParams(_) => { + let (done_tx, done_rx) = oneshot::channel(); + cancel_tx.send(done_tx).await?; + done_rx.await??; + + Ok(AnyAgentResult::CancelSendMessageResponse( + acp::CancelSendMessageResponse, + )) + } } } .boxed_local() } } +#[derive(Clone, Copy)] +enum ClaudeSessionMode { + Start, + Resume, +} + +async fn spawn_claude( + command: &AgentServerCommand, + mode: ClaudeSessionMode, + session_id: Uuid, + mcp_config_path: &Path, + root_dir: &Path, +) -> Result { + let child = util::command::new_smol_command(&command.path) + .args([ + "--input-format", + "stream-json", + "--output-format", + "stream-json", + "--print", + "--verbose", + "--mcp-config", + mcp_config_path.to_string_lossy().as_ref(), + "--permission-prompt-tool", + &format!( + "mcp__{}__{}", + mcp_server::SERVER_NAME, + mcp_server::PERMISSION_TOOL + ), + "--allowedTools", + "mcp__zed__Read,mcp__zed__Edit", + "--disallowedTools", + "Read,Edit", + ]) + .args(match mode { + ClaudeSessionMode::Start => ["--session-id".to_string(), session_id.to_string()], + ClaudeSessionMode::Resume => ["--resume".to_string(), session_id.to_string()], + }) + .args(command.args.iter().map(|arg| arg.as_str())) + .current_dir(root_dir) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::inherit()) + .kill_on_drop(true) + .spawn()?; + + Ok(child) +} + struct ClaudeAgentConnection { delegate: AcpClientDelegate, + session_id: Uuid, outgoing_tx: UnboundedSender, end_turn_tx: Rc>>>>, + cancel_tx: UnboundedSender>>, _mcp_server: Option, _handler_task: Task<()>, } @@ -350,7 +434,7 @@ impl ClaudeAgentConnection { incoming_tx: UnboundedSender, mut outgoing_bytes: impl Unpin + AsyncWrite, incoming_bytes: impl Unpin + AsyncRead, - ) -> Result<()> { + ) -> Result> { let mut output_reader = BufReader::new(incoming_bytes); let mut outgoing_line = Vec::new(); let mut incoming_line = String::new(); @@ -384,7 +468,8 @@ impl ClaudeAgentConnection { } } } - Ok(()) + + Ok(outgoing_rx) } } @@ -507,14 +592,14 @@ enum SdkMessage { Assistant { message: Message, // from Anthropic SDK #[serde(skip_serializing_if = "Option::is_none")] - session_id: Option, + session_id: Option, }, // A user message User { message: Message, // from Anthropic SDK #[serde(skip_serializing_if = "Option::is_none")] - session_id: Option, + session_id: Option, }, // Emitted as the last message in a conversation