Detailed changes
@@ -7,10 +7,8 @@ name = "acp_thread"
version = "0.1.0"
dependencies = [
"agent-client-protocol",
- "agentic-coding-protocol",
"anyhow",
"assistant_tool",
- "async-pipe",
"buffer_diff",
"editor",
"env_logger 0.11.8",
@@ -139,10 +137,14 @@ dependencies = [
[[package]]
name = "agent-client-protocol"
-version = "0.0.11"
+version = "0.0.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "72ec54650c1fc2d63498bab47eeeaa9eddc7d239d53f615b797a0e84f7ccc87b"
+checksum = "22c5180e40d31a9998ffa5f8eb067667f0870908a4aeed65a6a299e2d1d95443"
dependencies = [
+ "anyhow",
+ "futures 0.3.31",
+ "log",
+ "parking_lot",
"schemars",
"serde",
"serde_json",
@@ -177,6 +179,7 @@ dependencies = [
"smol",
"strum 0.27.1",
"tempfile",
+ "thiserror 2.0.12",
"ui",
"util",
"uuid",
@@ -9572,9 +9575,9 @@ dependencies = [
[[package]]
name = "lock_api"
-version = "0.4.12"
+version = "0.4.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17"
+checksum = "96936507f153605bddfcda068dd804796c84324ed2510809e5b2a624c81da765"
dependencies = [
"autocfg",
"scopeguard",
@@ -11288,9 +11291,9 @@ checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba"
[[package]]
name = "parking_lot"
-version = "0.12.3"
+version = "0.12.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27"
+checksum = "70d58bf43669b5795d1576d0641cfb6fbb2057bf629506267a92807158584a13"
dependencies = [
"lock_api",
"parking_lot_core",
@@ -11298,9 +11301,9 @@ dependencies = [
[[package]]
name = "parking_lot_core"
-version = "0.9.10"
+version = "0.9.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8"
+checksum = "bc838d2a56b5b1a6c25f55575dfc605fabb63bb2365f6c2353ef9159aa69e4a5"
dependencies = [
"cfg-if",
"libc",
@@ -421,7 +421,7 @@ zlog_settings = { path = "crates/zlog_settings" }
#
agentic-coding-protocol = "0.0.10"
-agent-client-protocol = "0.0.11"
+agent-client-protocol = "0.0.17"
aho-corasick = "1.1"
alacritty_terminal = { git = "https://github.com/zed-industries/alacritty.git", branch = "add-hush-login-flag" }
any_vec = "0.14"
@@ -17,7 +17,6 @@ test-support = ["gpui/test-support", "project/test-support"]
[dependencies]
agent-client-protocol.workspace = true
-agentic-coding-protocol.workspace = true
anyhow.workspace = true
assistant_tool.workspace = true
buffer_diff.workspace = true
@@ -37,7 +36,6 @@ util.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
-async-pipe.workspace = true
env_logger.workspace = true
gpui = { workspace = true, "features" = ["test-support"] }
indoc.workspace = true
@@ -1,7 +1,5 @@
mod connection;
-mod old_acp_support;
pub use connection::*;
-pub use old_acp_support::*;
use agent_client_protocol as acp;
use anyhow::{Context as _, Result};
@@ -391,7 +389,7 @@ impl ToolCallContent {
cx: &mut App,
) -> Self {
match content {
- acp::ToolCallContent::ContentBlock(content) => Self::ContentBlock {
+ acp::ToolCallContent::Content { content } => Self::ContentBlock {
content: ContentBlock::new(content, &language_registry, cx),
},
acp::ToolCallContent::Diff { diff } => Self::Diff {
@@ -619,6 +617,7 @@ impl Error for LoadError {}
impl AcpThread {
pub fn new(
+ title: impl Into<SharedString>,
connection: Rc<dyn AgentConnection>,
project: Entity<Project>,
session_id: acp::SessionId,
@@ -631,7 +630,7 @@ impl AcpThread {
shared_buffers: Default::default(),
entries: Default::default(),
plan: Default::default(),
- title: connection.name().into(),
+ title: title.into(),
project,
send_task: None,
connection,
@@ -708,14 +707,14 @@ impl AcpThread {
cx: &mut Context<Self>,
) -> Result<()> {
match update {
- acp::SessionUpdate::UserMessage(content_block) => {
- self.push_user_content_block(content_block, cx);
+ acp::SessionUpdate::UserMessageChunk { content } => {
+ self.push_user_content_block(content, cx);
}
- acp::SessionUpdate::AgentMessageChunk(content_block) => {
- self.push_assistant_content_block(content_block, false, cx);
+ acp::SessionUpdate::AgentMessageChunk { content } => {
+ self.push_assistant_content_block(content, false, cx);
}
- acp::SessionUpdate::AgentThoughtChunk(content_block) => {
- self.push_assistant_content_block(content_block, true, cx);
+ acp::SessionUpdate::AgentThoughtChunk { content } => {
+ self.push_assistant_content_block(content, true, cx);
}
acp::SessionUpdate::ToolCall(tool_call) => {
self.upsert_tool_call(tool_call, cx);
@@ -984,10 +983,6 @@ impl AcpThread {
cx.notify();
}
- pub fn authenticate(&self, cx: &mut App) -> impl use<> + Future<Output = Result<()>> {
- self.connection.authenticate(cx)
- }
-
#[cfg(any(test, feature = "test-support"))]
pub fn send_raw(
&mut self,
@@ -1029,7 +1024,7 @@ impl AcpThread {
let result = this
.update(cx, |this, cx| {
this.connection.prompt(
- acp::PromptArguments {
+ acp::PromptRequest {
prompt: message,
session_id: this.session_id.clone(),
},
@@ -1239,21 +1234,15 @@ impl AcpThread {
#[cfg(test)]
mod tests {
use super::*;
- use agentic_coding_protocol as acp_old;
use anyhow::anyhow;
- use async_pipe::{PipeReader, PipeWriter};
- use futures::{
- channel::mpsc,
- future::{LocalBoxFuture, try_join_all},
- select,
- };
+ use futures::{channel::mpsc, future::LocalBoxFuture, select};
use gpui::{AsyncApp, TestAppContext, WeakEntity};
use indoc::indoc;
use project::FakeFs;
use rand::Rng as _;
use serde_json::json;
use settings::SettingsStore;
- use smol::{future::BoxedLocal, stream::StreamExt as _};
+ use smol::stream::StreamExt as _;
use std::{cell::RefCell, rc::Rc, time::Duration};
use util::path;
@@ -1274,7 +1263,15 @@ mod tests {
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, [], cx).await;
- let (thread, _fake_server) = fake_acp_thread(project, cx);
+ let connection = Rc::new(FakeAgentConnection::new());
+ let thread = cx
+ .spawn(async move |mut cx| {
+ connection
+ .new_thread(project, Path::new(path!("/test")), &mut cx)
+ .await
+ })
+ .await
+ .unwrap();
// Test creating a new user message
thread.update(cx, |thread, cx| {
@@ -1354,34 +1351,40 @@ mod tests {
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, [], cx).await;
- let (thread, fake_server) = fake_acp_thread(project, cx);
-
- fake_server.update(cx, |fake_server, _| {
- fake_server.on_user_message(move |_, server, mut cx| async move {
- server
- .update(&mut cx, |server, _| {
- server.send_to_zed(acp_old::StreamAssistantMessageChunkParams {
- chunk: acp_old::AssistantMessageChunk::Thought {
- thought: "Thinking ".into(),
- },
- })
- })?
- .await
- .unwrap();
- server
- .update(&mut cx, |server, _| {
- server.send_to_zed(acp_old::StreamAssistantMessageChunkParams {
- chunk: acp_old::AssistantMessageChunk::Thought {
- thought: "hard!".into(),
- },
- })
- })?
- .await
- .unwrap();
+ let connection = Rc::new(FakeAgentConnection::new().on_user_message(
+ |_, thread, mut cx| {
+ async move {
+ thread.update(&mut cx, |thread, cx| {
+ thread
+ .handle_session_update(
+ acp::SessionUpdate::AgentThoughtChunk {
+ content: "Thinking ".into(),
+ },
+ cx,
+ )
+ .unwrap();
+ thread
+ .handle_session_update(
+ acp::SessionUpdate::AgentThoughtChunk {
+ content: "hard!".into(),
+ },
+ cx,
+ )
+ .unwrap();
+ })
+ }
+ .boxed_local()
+ },
+ ));
- Ok(())
+ let thread = cx
+ .spawn(async move |mut cx| {
+ connection
+ .new_thread(project, Path::new(path!("/test")), &mut cx)
+ .await
})
- });
+ .await
+ .unwrap();
thread
.update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
@@ -1414,7 +1417,38 @@ mod tests {
fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
.await;
let project = Project::test(fs.clone(), [], cx).await;
- let (thread, fake_server) = fake_acp_thread(project.clone(), cx);
+ let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
+ let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
+ let connection = Rc::new(FakeAgentConnection::new().on_user_message(
+ move |_, thread, mut cx| {
+ let read_file_tx = read_file_tx.clone();
+ async move {
+ let content = thread
+ .update(&mut cx, |thread, cx| {
+ thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
+ })
+ .unwrap()
+ .await
+ .unwrap();
+ assert_eq!(content, "one\ntwo\nthree\n");
+ read_file_tx.take().unwrap().send(()).unwrap();
+ thread
+ .update(&mut cx, |thread, cx| {
+ thread.write_text_file(
+ path!("/tmp/foo").into(),
+ "one\ntwo\nthree\nfour\nfive\n".to_string(),
+ cx,
+ )
+ })
+ .unwrap()
+ .await
+ .unwrap();
+ Ok(())
+ }
+ .boxed_local()
+ },
+ ));
+
let (worktree, pathbuf) = project
.update(cx, |project, cx| {
project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
@@ -1428,38 +1462,10 @@ mod tests {
.await
.unwrap();
- let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
- let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
-
- fake_server.update(cx, |fake_server, _| {
- fake_server.on_user_message(move |_, server, mut cx| {
- let read_file_tx = read_file_tx.clone();
- async move {
- let content = server
- .update(&mut cx, |server, _| {
- server.send_to_zed(acp_old::ReadTextFileParams {
- path: path!("/tmp/foo").into(),
- line: None,
- limit: None,
- })
- })?
- .await
- .unwrap();
- assert_eq!(content.content, "one\ntwo\nthree\n");
- read_file_tx.take().unwrap().send(()).unwrap();
- server
- .update(&mut cx, |server, _| {
- server.send_to_zed(acp_old::WriteTextFileParams {
- path: path!("/tmp/foo").into(),
- content: "one\ntwo\nthree\nfour\nfive\n".to_string(),
- })
- })?
- .await
- .unwrap();
- Ok(())
- }
- })
- });
+ let thread = cx
+ .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
+ .await
+ .unwrap();
let request = thread.update(cx, |thread, cx| {
thread.send_raw("Extend the count in /tmp/foo", cx)
@@ -1486,36 +1492,44 @@ mod tests {
let fs = FakeFs::new(cx.executor());
let project = Project::test(fs, [], cx).await;
- let (thread, fake_server) = fake_acp_thread(project, cx);
-
- let (end_turn_tx, end_turn_rx) = oneshot::channel::<()>();
+ let id = acp::ToolCallId("test".into());
- let tool_call_id = Rc::new(RefCell::new(None));
- let end_turn_rx = Rc::new(RefCell::new(Some(end_turn_rx)));
- fake_server.update(cx, |fake_server, _| {
- let tool_call_id = tool_call_id.clone();
- fake_server.on_user_message(move |_, server, mut cx| {
- let end_turn_rx = end_turn_rx.clone();
- let tool_call_id = tool_call_id.clone();
+ let connection = Rc::new(FakeAgentConnection::new().on_user_message({
+ let id = id.clone();
+ move |_, thread, mut cx| {
+ let id = id.clone();
async move {
- let tool_call_result = server
- .update(&mut cx, |server, _| {
- server.send_to_zed(acp_old::PushToolCallParams {
- label: "Fetch".to_string(),
- icon: acp_old::Icon::Globe,
- content: None,
- locations: vec![],
- })
- })?
- .await
+ thread
+ .update(&mut cx, |thread, cx| {
+ thread.handle_session_update(
+ acp::SessionUpdate::ToolCall(acp::ToolCall {
+ id: id.clone(),
+ label: "Label".into(),
+ kind: acp::ToolKind::Fetch,
+ status: acp::ToolCallStatus::InProgress,
+ content: vec![],
+ locations: vec![],
+ raw_input: None,
+ }),
+ cx,
+ )
+ })
+ .unwrap()
.unwrap();
- *tool_call_id.clone().borrow_mut() = Some(tool_call_result.id);
- end_turn_rx.take().unwrap().await.ok();
-
Ok(())
}
+ .boxed_local()
+ }
+ }));
+
+ let thread = cx
+ .spawn(async move |mut cx| {
+ connection
+ .new_thread(project, Path::new(path!("/test")), &mut cx)
+ .await
})
- });
+ .await
+ .unwrap();
let request = thread.update(cx, |thread, cx| {
thread.send_raw("Fetch https://example.com", cx)
@@ -1536,8 +1550,6 @@ mod tests {
));
});
- cx.run_until_parked();
-
thread.update(cx, |thread, cx| thread.cancel(cx)).await;
thread.read_with(cx, |thread, _| {
@@ -1550,19 +1562,22 @@ mod tests {
));
});
- fake_server
- .update(cx, |fake_server, _| {
- fake_server.send_to_zed(acp_old::UpdateToolCallParams {
- tool_call_id: tool_call_id.borrow().unwrap(),
- status: acp_old::ToolCallStatus::Finished,
- content: None,
- })
+ thread
+ .update(cx, |thread, cx| {
+ thread.handle_session_update(
+ acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
+ id,
+ fields: acp::ToolCallUpdateFields {
+ status: Some(acp::ToolCallStatus::Completed),
+ ..Default::default()
+ },
+ }),
+ cx,
+ )
})
- .await
.unwrap();
- drop(end_turn_tx);
- assert!(request.await.unwrap_err().to_string().contains("canceled"));
+ request.await.unwrap();
thread.read_with(cx, |thread, _| {
assert!(matches!(
@@ -1585,23 +1600,37 @@ mod tests {
fs.insert_tree(path!("/test"), json!({})).await;
let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
- let connection = Rc::new(StubAgentConnection::new(vec![
- acp::SessionUpdate::ToolCall(acp::ToolCall {
- id: acp::ToolCallId("test".into()),
- label: "Label".into(),
- kind: acp::ToolKind::Edit,
- status: acp::ToolCallStatus::Completed,
- content: vec![acp::ToolCallContent::Diff {
- diff: acp::Diff {
- path: "/test/test.txt".into(),
- old_text: None,
- new_text: "foo".into(),
- },
- }],
- locations: vec![],
- raw_input: None,
- }),
- ]));
+ let connection = Rc::new(FakeAgentConnection::new().on_user_message({
+ move |_, thread, mut cx| {
+ async move {
+ thread
+ .update(&mut cx, |thread, cx| {
+ thread.handle_session_update(
+ acp::SessionUpdate::ToolCall(acp::ToolCall {
+ id: acp::ToolCallId("test".into()),
+ label: "Label".into(),
+ kind: acp::ToolKind::Edit,
+ status: acp::ToolCallStatus::Completed,
+ content: vec![acp::ToolCallContent::Diff {
+ diff: acp::Diff {
+ path: "/test/test.txt".into(),
+ old_text: None,
+ new_text: "foo".into(),
+ },
+ }],
+ locations: vec![],
+ raw_input: None,
+ }),
+ cx,
+ )
+ })
+ .unwrap()
+ .unwrap();
+ Ok(())
+ }
+ .boxed_local()
+ }
+ }));
let thread = connection
.new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
@@ -1642,25 +1671,53 @@ mod tests {
}
#[derive(Clone, Default)]
- struct StubAgentConnection {
+ struct FakeAgentConnection {
+ auth_methods: Vec<acp::AuthMethod>,
sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
- permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
- updates: Vec<acp::SessionUpdate>,
+ on_user_message: Option<
+ Rc<
+ dyn Fn(
+ acp::PromptRequest,
+ WeakEntity<AcpThread>,
+ AsyncApp,
+ ) -> LocalBoxFuture<'static, Result<()>>
+ + 'static,
+ >,
+ >,
}
- impl StubAgentConnection {
- fn new(updates: Vec<acp::SessionUpdate>) -> Self {
+ impl FakeAgentConnection {
+ fn new() -> Self {
Self {
- updates,
- permission_requests: HashMap::default(),
+ auth_methods: Vec::new(),
+ on_user_message: None,
sessions: Arc::default(),
}
}
+
+ #[expect(unused)]
+ fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
+ self.auth_methods = auth_methods;
+ self
+ }
+
+ fn on_user_message(
+ mut self,
+ handler: impl Fn(
+ acp::PromptRequest,
+ WeakEntity<AcpThread>,
+ AsyncApp,
+ ) -> LocalBoxFuture<'static, Result<()>>
+ + 'static,
+ ) -> Self {
+ self.on_user_message.replace(Rc::new(handler));
+ self
+ }
}
- impl AgentConnection for StubAgentConnection {
- fn name(&self) -> &'static str {
- "StubAgentConnection"
+ impl AgentConnection for FakeAgentConnection {
+ fn auth_methods(&self) -> &[acp::AuthMethod] {
+ &self.auth_methods
}
fn new_thread(
@@ -1678,222 +1735,43 @@ mod tests {
.into(),
);
let thread = cx
- .new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx))
+ .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
.unwrap();
self.sessions.lock().insert(session_id, thread.downgrade());
Task::ready(Ok(thread))
}
- fn authenticate(&self, _cx: &mut App) -> Task<gpui::Result<()>> {
- unimplemented!()
+ fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
+ if self.auth_methods().iter().any(|m| m.id == method) {
+ Task::ready(Ok(()))
+ } else {
+ Task::ready(Err(anyhow!("Invalid Auth Method")))
+ }
}
- fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<gpui::Result<()>> {
+ fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<gpui::Result<()>> {
let sessions = self.sessions.lock();
let thread = sessions.get(¶ms.session_id).unwrap();
- let mut tasks = vec![];
- for update in &self.updates {
+ if let Some(handler) = &self.on_user_message {
+ let handler = handler.clone();
let thread = thread.clone();
- let update = update.clone();
- let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update
- && let Some(options) = self.permission_requests.get(&tool_call.id)
- {
- Some((tool_call.clone(), options.clone()))
- } else {
- None
- };
- let task = cx.spawn(async move |cx| {
- if let Some((tool_call, options)) = permission_request {
- let permission = thread.update(cx, |thread, cx| {
- thread.request_tool_call_permission(
- tool_call.clone(),
- options.clone(),
- cx,
- )
- })?;
- permission.await?;
- }
- thread.update(cx, |thread, cx| {
- thread.handle_session_update(update.clone(), cx).unwrap();
- })?;
- anyhow::Ok(())
- });
- tasks.push(task);
- }
- cx.spawn(async move |_| {
- try_join_all(tasks).await?;
- Ok(())
- })
- }
-
- fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
- unimplemented!()
- }
- }
-
- pub fn fake_acp_thread(
- project: Entity<Project>,
- cx: &mut TestAppContext,
- ) -> (Entity<AcpThread>, Entity<FakeAcpServer>) {
- let (stdin_tx, stdin_rx) = async_pipe::pipe();
- let (stdout_tx, stdout_rx) = async_pipe::pipe();
-
- let thread = cx.new(|cx| {
- let foreground_executor = cx.foreground_executor().clone();
- let thread_rc = Rc::new(RefCell::new(cx.entity().downgrade()));
-
- let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent(
- OldAcpClientDelegate::new(thread_rc.clone(), cx.to_async()),
- stdin_tx,
- stdout_rx,
- move |fut| {
- foreground_executor.spawn(fut).detach();
- },
- );
-
- let io_task = cx.background_spawn({
- async move {
- io_fut.await.log_err();
- Ok(())
- }
- });
- let connection = OldAcpAgentConnection {
- name: "test",
- connection,
- child_status: io_task,
- current_thread: thread_rc,
- };
-
- AcpThread::new(
- Rc::new(connection),
- project,
- acp::SessionId("test".into()),
- cx,
- )
- });
- let agent = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx)));
- (thread, agent)
- }
-
- pub struct FakeAcpServer {
- connection: acp_old::ClientConnection,
-
- _io_task: Task<()>,
- on_user_message: Option<
- Rc<
- dyn Fn(
- acp_old::SendUserMessageParams,
- Entity<FakeAcpServer>,
- AsyncApp,
- ) -> LocalBoxFuture<'static, Result<(), acp_old::Error>>,
- >,
- >,
- }
-
- #[derive(Clone)]
- struct FakeAgent {
- server: Entity<FakeAcpServer>,
- cx: AsyncApp,
- cancel_tx: Rc<RefCell<Option<oneshot::Sender<()>>>>,
- }
-
- impl acp_old::Agent for FakeAgent {
- async fn initialize(
- &self,
- params: acp_old::InitializeParams,
- ) -> Result<acp_old::InitializeResponse, acp_old::Error> {
- Ok(acp_old::InitializeResponse {
- protocol_version: params.protocol_version,
- is_authenticated: true,
- })
- }
-
- async fn authenticate(&self) -> Result<(), acp_old::Error> {
- Ok(())
- }
-
- async fn cancel_send_message(&self) -> Result<(), acp_old::Error> {
- if let Some(cancel_tx) = self.cancel_tx.take() {
- cancel_tx.send(()).log_err();
- }
- Ok(())
- }
-
- async fn send_user_message(
- &self,
- request: acp_old::SendUserMessageParams,
- ) -> Result<(), acp_old::Error> {
- let (cancel_tx, cancel_rx) = oneshot::channel();
- self.cancel_tx.replace(Some(cancel_tx));
-
- let mut cx = self.cx.clone();
- let handler = self
- .server
- .update(&mut cx, |server, _| server.on_user_message.clone())
- .ok()
- .flatten();
- if let Some(handler) = handler {
- select! {
- _ = cancel_rx.fuse() => Err(anyhow::anyhow!("Message sending canceled").into()),
- _ = handler(request, self.server.clone(), self.cx.clone()).fuse() => Ok(()),
- }
+ cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
} else {
- Err(anyhow::anyhow!("No handler for on_user_message").into())
- }
- }
- }
-
- impl FakeAcpServer {
- fn new(stdin: PipeReader, stdout: PipeWriter, cx: &Context<Self>) -> Self {
- let agent = FakeAgent {
- server: cx.entity(),
- cx: cx.to_async(),
- cancel_tx: Default::default(),
- };
- let foreground_executor = cx.foreground_executor().clone();
-
- let (connection, io_fut) = acp_old::ClientConnection::connect_to_client(
- agent.clone(),
- stdout,
- stdin,
- move |fut| {
- foreground_executor.spawn(fut).detach();
- },
- );
- FakeAcpServer {
- connection: connection,
- on_user_message: None,
- _io_task: cx.background_spawn(async move {
- io_fut.await.log_err();
- }),
+ Task::ready(Ok(()))
}
}
- fn on_user_message<F>(
- &mut self,
- handler: impl for<'a> Fn(
- acp_old::SendUserMessageParams,
- Entity<FakeAcpServer>,
- AsyncApp,
- ) -> F
- + 'static,
- ) where
- F: Future<Output = Result<(), acp_old::Error>> + 'static,
- {
- self.on_user_message
- .replace(Rc::new(move |request, server, cx| {
- handler(request, server, cx).boxed_local()
- }));
- }
+ fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
+ let sessions = self.sessions.lock();
+ let thread = sessions.get(&session_id).unwrap().clone();
- fn send_to_zed<T: acp_old::ClientRequest + 'static>(
- &self,
- message: T,
- ) -> BoxedLocal<Result<T::Response>> {
- self.connection
- .request(message)
- .map(|f| f.map_err(|err| anyhow!(err)))
- .boxed_local()
+ cx.spawn(async move |cx| {
+ thread
+ .update(cx, |thread, cx| thread.cancel(cx))
+ .unwrap()
+ .await
+ })
+ .detach();
}
}
}
@@ -1,6 +1,6 @@
-use std::{path::Path, rc::Rc};
+use std::{error::Error, fmt, path::Path, rc::Rc};
-use agent_client_protocol as acp;
+use agent_client_protocol::{self as acp};
use anyhow::Result;
use gpui::{AsyncApp, Entity, Task};
use project::Project;
@@ -9,8 +9,6 @@ use ui::App;
use crate::AcpThread;
pub trait AgentConnection {
- fn name(&self) -> &'static str;
-
fn new_thread(
self: Rc<Self>,
project: Entity<Project>,
@@ -18,9 +16,21 @@ pub trait AgentConnection {
cx: &mut AsyncApp,
) -> Task<Result<Entity<AcpThread>>>;
- fn authenticate(&self, cx: &mut App) -> Task<Result<()>>;
+ fn auth_methods(&self) -> &[acp::AuthMethod];
+
+ fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
- fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<Result<()>>;
+ fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>>;
fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
}
+
+#[derive(Debug)]
+pub struct AuthRequired;
+
+impl Error for AuthRequired {}
+impl fmt::Display for AuthRequired {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(f, "AuthRequired")
+ }
+}
@@ -25,6 +25,7 @@ collections.workspace = true
context_server.workspace = true
futures.workspace = true
gpui.workspace = true
+indoc.workspace = true
itertools.workspace = true
log.workspace = true
paths.workspace = true
@@ -37,11 +38,11 @@ settings.workspace = true
smol.workspace = true
strum.workspace = true
tempfile.workspace = true
+thiserror.workspace = true
ui.workspace = true
util.workspace = true
uuid.workspace = true
watch.workspace = true
-indoc.workspace = true
which.workspace = true
workspace-hack.workspace = true
@@ -0,0 +1,34 @@
+use std::{path::Path, rc::Rc};
+
+use crate::AgentServerCommand;
+use acp_thread::AgentConnection;
+use anyhow::Result;
+use gpui::AsyncApp;
+use thiserror::Error;
+
+mod v0;
+mod v1;
+
+#[derive(Debug, Error)]
+#[error("Unsupported version")]
+pub struct UnsupportedVersion;
+
+pub async fn connect(
+ server_name: &'static str,
+ command: AgentServerCommand,
+ root_dir: &Path,
+ cx: &mut AsyncApp,
+) -> Result<Rc<dyn AgentConnection>> {
+ let conn = v1::AcpConnection::stdio(server_name, command.clone(), &root_dir, cx).await;
+
+ match conn {
+ Ok(conn) => Ok(Rc::new(conn) as _),
+ Err(err) if err.is::<UnsupportedVersion>() => {
+ // Consider re-using initialize response and subprocess when adding another version here
+ let conn: Rc<dyn AgentConnection> =
+ Rc::new(v0::AcpConnection::stdio(server_name, command, &root_dir, cx).await?);
+ Ok(conn)
+ }
+ Err(err) => Err(err),
+ }
+}
@@ -1,18 +1,19 @@
// Translates old acp agents into the new schema
use agent_client_protocol as acp;
use agentic_coding_protocol::{self as acp_old, AgentRequest as _};
-use anyhow::{Context as _, Result};
+use anyhow::{Context as _, Result, anyhow};
use futures::channel::oneshot;
use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
use project::Project;
-use std::{cell::RefCell, error::Error, fmt, path::Path, rc::Rc};
+use std::{cell::RefCell, path::Path, rc::Rc};
use ui::App;
use util::ResultExt as _;
-use crate::{AcpThread, AgentConnection};
+use crate::AgentServerCommand;
+use acp_thread::{AcpThread, AgentConnection, AuthRequired};
#[derive(Clone)]
-pub struct OldAcpClientDelegate {
+struct OldAcpClientDelegate {
thread: Rc<RefCell<WeakEntity<AcpThread>>>,
cx: AsyncApp,
next_tool_call_id: Rc<RefCell<u64>>,
@@ -20,7 +21,7 @@ pub struct OldAcpClientDelegate {
}
impl OldAcpClientDelegate {
- pub fn new(thread: Rc<RefCell<WeakEntity<AcpThread>>>, cx: AsyncApp) -> Self {
+ fn new(thread: Rc<RefCell<WeakEntity<AcpThread>>>, cx: AsyncApp) -> Self {
Self {
thread,
cx,
@@ -351,28 +352,71 @@ fn into_new_plan_status(status: acp_old::PlanEntryStatus) -> acp::PlanEntryStatu
}
}
-#[derive(Debug)]
-pub struct Unauthenticated;
-
-impl Error for Unauthenticated {}
-impl fmt::Display for Unauthenticated {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(f, "Unauthenticated")
- }
-}
-
-pub struct OldAcpAgentConnection {
+pub struct AcpConnection {
pub name: &'static str,
pub connection: acp_old::AgentConnection,
- pub child_status: Task<Result<()>>,
+ pub _child_status: Task<Result<()>>,
pub current_thread: Rc<RefCell<WeakEntity<AcpThread>>>,
}
-impl AgentConnection for OldAcpAgentConnection {
- fn name(&self) -> &'static str {
- self.name
+impl AcpConnection {
+ pub fn stdio(
+ name: &'static str,
+ command: AgentServerCommand,
+ root_dir: &Path,
+ cx: &mut AsyncApp,
+ ) -> Task<Result<Self>> {
+ let root_dir = root_dir.to_path_buf();
+
+ cx.spawn(async move |cx| {
+ let mut child = util::command::new_smol_command(&command.path)
+ .args(command.args.iter())
+ .current_dir(root_dir)
+ .stdin(std::process::Stdio::piped())
+ .stdout(std::process::Stdio::piped())
+ .stderr(std::process::Stdio::inherit())
+ .kill_on_drop(true)
+ .spawn()?;
+
+ let stdin = child.stdin.take().unwrap();
+ let stdout = child.stdout.take().unwrap();
+
+ let foreground_executor = cx.foreground_executor().clone();
+
+ let thread_rc = Rc::new(RefCell::new(WeakEntity::new_invalid()));
+
+ let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent(
+ OldAcpClientDelegate::new(thread_rc.clone(), cx.clone()),
+ stdin,
+ stdout,
+ move |fut| foreground_executor.spawn(fut).detach(),
+ );
+
+ let io_task = cx.background_spawn(async move {
+ io_fut.await.log_err();
+ });
+
+ let child_status = cx.background_spawn(async move {
+ let result = match child.status().await {
+ Err(e) => Err(anyhow!(e)),
+ Ok(result) if result.success() => Ok(()),
+ Ok(result) => Err(anyhow!(result)),
+ };
+ drop(io_task);
+ result
+ });
+
+ Ok(Self {
+ name,
+ connection,
+ _child_status: child_status,
+ current_thread: thread_rc,
+ })
+ })
}
+}
+impl AgentConnection for AcpConnection {
fn new_thread(
self: Rc<Self>,
project: Entity<Project>,
@@ -391,13 +435,13 @@ impl AgentConnection for OldAcpAgentConnection {
let result = acp_old::InitializeParams::response_from_any(result)?;
if !result.is_authenticated {
- anyhow::bail!(Unauthenticated)
+ anyhow::bail!(AuthRequired)
}
cx.update(|cx| {
let thread = cx.new(|cx| {
let session_id = acp::SessionId("acp-old-no-id".into());
- AcpThread::new(self.clone(), project, session_id, cx)
+ AcpThread::new(self.name, self.clone(), project, session_id, cx)
});
current_thread.replace(thread.downgrade());
thread
@@ -405,7 +449,11 @@ impl AgentConnection for OldAcpAgentConnection {
})
}
- fn authenticate(&self, cx: &mut App) -> Task<Result<()>> {
+ fn auth_methods(&self) -> &[acp::AuthMethod] {
+ &[]
+ }
+
+ fn authenticate(&self, _method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
let task = self
.connection
.request_any(acp_old::AuthenticateParams.into_any());
@@ -415,7 +463,7 @@ impl AgentConnection for OldAcpAgentConnection {
})
}
- fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<Result<()>> {
+ fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
let chunks = params
.prompt
.into_iter()
@@ -0,0 +1,254 @@
+use agent_client_protocol::{self as acp, Agent as _};
+use collections::HashMap;
+use futures::channel::oneshot;
+use project::Project;
+use std::cell::RefCell;
+use std::path::Path;
+use std::rc::Rc;
+
+use anyhow::{Context as _, Result};
+use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
+
+use crate::{AgentServerCommand, acp::UnsupportedVersion};
+use acp_thread::{AcpThread, AgentConnection, AuthRequired};
+
+pub struct AcpConnection {
+ server_name: &'static str,
+ connection: Rc<acp::ClientSideConnection>,
+ sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
+ auth_methods: Vec<acp::AuthMethod>,
+ _io_task: Task<Result<()>>,
+ _child: smol::process::Child,
+}
+
+pub struct AcpSession {
+ thread: WeakEntity<AcpThread>,
+}
+
+const MINIMUM_SUPPORTED_VERSION: acp::ProtocolVersion = acp::V1;
+
+impl AcpConnection {
+ pub async fn stdio(
+ server_name: &'static str,
+ command: AgentServerCommand,
+ root_dir: &Path,
+ cx: &mut AsyncApp,
+ ) -> Result<Self> {
+ let mut child = util::command::new_smol_command(&command.path)
+ .args(command.args.iter().map(|arg| arg.as_str()))
+ .envs(command.env.iter().flatten())
+ .current_dir(root_dir)
+ .stdin(std::process::Stdio::piped())
+ .stdout(std::process::Stdio::piped())
+ .stderr(std::process::Stdio::inherit())
+ .kill_on_drop(true)
+ .spawn()?;
+
+ let stdout = child.stdout.take().expect("Failed to take stdout");
+ let stdin = child.stdin.take().expect("Failed to take stdin");
+
+ let sessions = Rc::new(RefCell::new(HashMap::default()));
+
+ let client = ClientDelegate {
+ sessions: sessions.clone(),
+ cx: cx.clone(),
+ };
+ let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
+ let foreground_executor = cx.foreground_executor().clone();
+ move |fut| {
+ foreground_executor.spawn(fut).detach();
+ }
+ });
+
+ let io_task = cx.background_spawn(io_task);
+
+ let response = connection
+ .initialize(acp::InitializeRequest {
+ protocol_version: acp::VERSION,
+ client_capabilities: acp::ClientCapabilities {
+ fs: acp::FileSystemCapability {
+ read_text_file: true,
+ write_text_file: true,
+ },
+ },
+ })
+ .await?;
+
+ if response.protocol_version < MINIMUM_SUPPORTED_VERSION {
+ return Err(UnsupportedVersion.into());
+ }
+
+ Ok(Self {
+ auth_methods: response.auth_methods,
+ connection: connection.into(),
+ server_name,
+ sessions,
+ _child: child,
+ _io_task: io_task,
+ })
+ }
+}
+
+impl AgentConnection for AcpConnection {
+ fn new_thread(
+ self: Rc<Self>,
+ project: Entity<Project>,
+ cwd: &Path,
+ cx: &mut AsyncApp,
+ ) -> Task<Result<Entity<AcpThread>>> {
+ let conn = self.connection.clone();
+ let sessions = self.sessions.clone();
+ let cwd = cwd.to_path_buf();
+ cx.spawn(async move |cx| {
+ let response = conn
+ .new_session(acp::NewSessionRequest {
+ mcp_servers: vec![],
+ cwd,
+ })
+ .await?;
+
+ let Some(session_id) = response.session_id else {
+ anyhow::bail!(AuthRequired);
+ };
+
+ let thread = cx.new(|cx| {
+ AcpThread::new(
+ self.server_name,
+ self.clone(),
+ project,
+ session_id.clone(),
+ cx,
+ )
+ })?;
+
+ let session = AcpSession {
+ thread: thread.downgrade(),
+ };
+ sessions.borrow_mut().insert(session_id, session);
+
+ Ok(thread)
+ })
+ }
+
+ fn auth_methods(&self) -> &[acp::AuthMethod] {
+ &self.auth_methods
+ }
+
+ fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
+ let conn = self.connection.clone();
+ cx.foreground_executor().spawn(async move {
+ let result = conn
+ .authenticate(acp::AuthenticateRequest {
+ method_id: method_id.clone(),
+ })
+ .await?;
+
+ Ok(result)
+ })
+ }
+
+ fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
+ let conn = self.connection.clone();
+ cx.foreground_executor()
+ .spawn(async move { Ok(conn.prompt(params).await?) })
+ }
+
+ fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
+ let conn = self.connection.clone();
+ let params = acp::CancelledNotification {
+ session_id: session_id.clone(),
+ };
+ cx.foreground_executor()
+ .spawn(async move { conn.cancelled(params).await })
+ .detach();
+ }
+}
+
+struct ClientDelegate {
+ sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
+ cx: AsyncApp,
+}
+
+impl acp::Client for ClientDelegate {
+ async fn request_permission(
+ &self,
+ arguments: acp::RequestPermissionRequest,
+ ) -> Result<acp::RequestPermissionResponse, acp::Error> {
+ let cx = &mut self.cx.clone();
+ let rx = self
+ .sessions
+ .borrow()
+ .get(&arguments.session_id)
+ .context("Failed to get session")?
+ .thread
+ .update(cx, |thread, cx| {
+ thread.request_tool_call_permission(arguments.tool_call, arguments.options, cx)
+ })?;
+
+ let result = rx.await;
+
+ let outcome = match result {
+ Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
+ Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
+ };
+
+ Ok(acp::RequestPermissionResponse { outcome })
+ }
+
+ async fn write_text_file(
+ &self,
+ arguments: acp::WriteTextFileRequest,
+ ) -> Result<(), acp::Error> {
+ let cx = &mut self.cx.clone();
+ let task = self
+ .sessions
+ .borrow()
+ .get(&arguments.session_id)
+ .context("Failed to get session")?
+ .thread
+ .update(cx, |thread, cx| {
+ thread.write_text_file(arguments.path, arguments.content, cx)
+ })?;
+
+ task.await?;
+
+ Ok(())
+ }
+
+ async fn read_text_file(
+ &self,
+ arguments: acp::ReadTextFileRequest,
+ ) -> Result<acp::ReadTextFileResponse, acp::Error> {
+ let cx = &mut self.cx.clone();
+ let task = self
+ .sessions
+ .borrow()
+ .get(&arguments.session_id)
+ .context("Failed to get session")?
+ .thread
+ .update(cx, |thread, cx| {
+ thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
+ })?;
+
+ let content = task.await?;
+
+ Ok(acp::ReadTextFileResponse { content })
+ }
+
+ async fn session_notification(
+ &self,
+ notification: acp::SessionNotification,
+ ) -> Result<(), acp::Error> {
+ let cx = &mut self.cx.clone();
+ let sessions = self.sessions.borrow();
+ let session = sessions
+ .get(¬ification.session_id)
+ .context("Failed to get session")?;
+
+ session.thread.update(cx, |thread, cx| {
+ thread.handle_session_update(notification.update, cx)
+ })??;
+
+ Ok(())
+ }
+}
@@ -1,14 +1,12 @@
+mod acp;
mod claude;
-mod codex;
mod gemini;
-mod mcp_server;
mod settings;
#[cfg(test)]
mod e2e_tests;
pub use claude::*;
-pub use codex::*;
pub use gemini::*;
pub use settings::*;
@@ -38,7 +36,6 @@ pub trait AgentServer: Send {
fn connect(
&self,
- // these will go away when old_acp is fully removed
root_dir: &Path,
project: &Entity<Project>,
cx: &mut App,
@@ -70,10 +70,6 @@ struct ClaudeAgentConnection {
}
impl AgentConnection for ClaudeAgentConnection {
- fn name(&self) -> &'static str {
- ClaudeCode.name()
- }
-
fn new_thread(
self: Rc<Self>,
project: Entity<Project>,
@@ -168,8 +164,9 @@ impl AgentConnection for ClaudeAgentConnection {
}
});
- let thread =
- cx.new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx))?;
+ let thread = cx.new(|cx| {
+ AcpThread::new("Claude Code", self.clone(), project, session_id.clone(), cx)
+ })?;
thread_tx.send(thread.downgrade())?;
@@ -186,11 +183,15 @@ impl AgentConnection for ClaudeAgentConnection {
})
}
- fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
+ fn auth_methods(&self) -> &[acp::AuthMethod] {
+ &[]
+ }
+
+ fn authenticate(&self, _: acp::AuthMethodId, _cx: &mut App) -> Task<Result<()>> {
Task::ready(Err(anyhow!("Authentication not supported")))
}
- fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<Result<()>> {
+ fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
let sessions = self.sessions.borrow();
let Some(session) = sessions.get(¶ms.session_id) else {
return Task::ready(Err(anyhow!(
@@ -1,319 +0,0 @@
-use agent_client_protocol as acp;
-use anyhow::anyhow;
-use collections::HashMap;
-use context_server::listener::McpServerTool;
-use context_server::types::requests;
-use context_server::{ContextServer, ContextServerCommand, ContextServerId};
-use futures::channel::{mpsc, oneshot};
-use project::Project;
-use settings::SettingsStore;
-use smol::stream::StreamExt as _;
-use std::cell::RefCell;
-use std::rc::Rc;
-use std::{path::Path, sync::Arc};
-use util::ResultExt;
-
-use anyhow::{Context, Result};
-use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
-
-use crate::mcp_server::ZedMcpServer;
-use crate::{AgentServer, AgentServerCommand, AllAgentServersSettings, mcp_server};
-use acp_thread::{AcpThread, AgentConnection};
-
-#[derive(Clone)]
-pub struct Codex;
-
-impl AgentServer for Codex {
- fn name(&self) -> &'static str {
- "Codex"
- }
-
- fn empty_state_headline(&self) -> &'static str {
- "Welcome to Codex"
- }
-
- fn empty_state_message(&self) -> &'static str {
- "What can I help with?"
- }
-
- fn logo(&self) -> ui::IconName {
- ui::IconName::AiOpenAi
- }
-
- fn connect(
- &self,
- _root_dir: &Path,
- project: &Entity<Project>,
- cx: &mut App,
- ) -> Task<Result<Rc<dyn AgentConnection>>> {
- let project = project.clone();
- let working_directory = project.read(cx).active_project_directory(cx);
- cx.spawn(async move |cx| {
- let settings = cx.read_global(|settings: &SettingsStore, _| {
- settings.get::<AllAgentServersSettings>(None).codex.clone()
- })?;
-
- let Some(command) =
- AgentServerCommand::resolve("codex", &["mcp"], settings, &project, cx).await
- else {
- anyhow::bail!("Failed to find codex binary");
- };
-
- let client: Arc<ContextServer> = ContextServer::stdio(
- ContextServerId("codex-mcp-server".into()),
- ContextServerCommand {
- path: command.path,
- args: command.args,
- env: command.env,
- },
- working_directory,
- )
- .into();
- ContextServer::start(client.clone(), cx).await?;
-
- let (notification_tx, mut notification_rx) = mpsc::unbounded();
- client
- .client()
- .context("Failed to subscribe")?
- .on_notification(acp::SESSION_UPDATE_METHOD_NAME, {
- move |notification, _cx| {
- let notification_tx = notification_tx.clone();
- log::trace!(
- "ACP Notification: {}",
- serde_json::to_string_pretty(¬ification).unwrap()
- );
-
- if let Some(notification) =
- serde_json::from_value::<acp::SessionNotification>(notification)
- .log_err()
- {
- notification_tx.unbounded_send(notification).ok();
- }
- }
- });
-
- let sessions = Rc::new(RefCell::new(HashMap::default()));
-
- let notification_handler_task = cx.spawn({
- let sessions = sessions.clone();
- async move |cx| {
- while let Some(notification) = notification_rx.next().await {
- CodexConnection::handle_session_notification(
- notification,
- sessions.clone(),
- cx,
- )
- }
- }
- });
-
- let connection = CodexConnection {
- client,
- sessions,
- _notification_handler_task: notification_handler_task,
- };
- Ok(Rc::new(connection) as _)
- })
- }
-}
-
-struct CodexConnection {
- client: Arc<context_server::ContextServer>,
- sessions: Rc<RefCell<HashMap<acp::SessionId, CodexSession>>>,
- _notification_handler_task: Task<()>,
-}
-
-struct CodexSession {
- thread: WeakEntity<AcpThread>,
- cancel_tx: Option<oneshot::Sender<()>>,
- _mcp_server: ZedMcpServer,
-}
-
-impl AgentConnection for CodexConnection {
- fn name(&self) -> &'static str {
- "Codex"
- }
-
- fn new_thread(
- self: Rc<Self>,
- project: Entity<Project>,
- cwd: &Path,
- cx: &mut AsyncApp,
- ) -> Task<Result<Entity<AcpThread>>> {
- let client = self.client.client();
- let sessions = self.sessions.clone();
- let cwd = cwd.to_path_buf();
- cx.spawn(async move |cx| {
- let client = client.context("MCP server is not initialized yet")?;
- let (mut thread_tx, thread_rx) = watch::channel(WeakEntity::new_invalid());
-
- let mcp_server = ZedMcpServer::new(thread_rx, cx).await?;
-
- let response = client
- .request::<requests::CallTool>(context_server::types::CallToolParams {
- name: acp::NEW_SESSION_TOOL_NAME.into(),
- arguments: Some(serde_json::to_value(acp::NewSessionArguments {
- mcp_servers: [(
- mcp_server::SERVER_NAME.to_string(),
- mcp_server.server_config()?,
- )]
- .into(),
- client_tools: acp::ClientTools {
- request_permission: Some(acp::McpToolId {
- mcp_server: mcp_server::SERVER_NAME.into(),
- tool_name: mcp_server::RequestPermissionTool::NAME.into(),
- }),
- read_text_file: Some(acp::McpToolId {
- mcp_server: mcp_server::SERVER_NAME.into(),
- tool_name: mcp_server::ReadTextFileTool::NAME.into(),
- }),
- write_text_file: Some(acp::McpToolId {
- mcp_server: mcp_server::SERVER_NAME.into(),
- tool_name: mcp_server::WriteTextFileTool::NAME.into(),
- }),
- },
- cwd,
- })?),
- meta: None,
- })
- .await?;
-
- if response.is_error.unwrap_or_default() {
- return Err(anyhow!(response.text_contents()));
- }
-
- let result = serde_json::from_value::<acp::NewSessionOutput>(
- response.structured_content.context("Empty response")?,
- )?;
-
- let thread =
- cx.new(|cx| AcpThread::new(self.clone(), project, result.session_id.clone(), cx))?;
-
- thread_tx.send(thread.downgrade())?;
-
- let session = CodexSession {
- thread: thread.downgrade(),
- cancel_tx: None,
- _mcp_server: mcp_server,
- };
- sessions.borrow_mut().insert(result.session_id, session);
-
- Ok(thread)
- })
- }
-
- fn authenticate(&self, _cx: &mut App) -> Task<Result<()>> {
- Task::ready(Err(anyhow!("Authentication not supported")))
- }
-
- fn prompt(
- &self,
- params: agent_client_protocol::PromptArguments,
- cx: &mut App,
- ) -> Task<Result<()>> {
- let client = self.client.client();
- let sessions = self.sessions.clone();
-
- cx.foreground_executor().spawn(async move {
- let client = client.context("MCP server is not initialized yet")?;
-
- let (new_cancel_tx, cancel_rx) = oneshot::channel();
- {
- let mut sessions = sessions.borrow_mut();
- let session = sessions
- .get_mut(¶ms.session_id)
- .context("Session not found")?;
- session.cancel_tx.replace(new_cancel_tx);
- }
-
- let result = client
- .request_with::<requests::CallTool>(
- context_server::types::CallToolParams {
- name: acp::PROMPT_TOOL_NAME.into(),
- arguments: Some(serde_json::to_value(params)?),
- meta: None,
- },
- Some(cancel_rx),
- None,
- )
- .await;
-
- if let Err(err) = &result
- && err.is::<context_server::client::RequestCanceled>()
- {
- return Ok(());
- }
-
- let response = result?;
-
- if response.is_error.unwrap_or_default() {
- return Err(anyhow!(response.text_contents()));
- }
-
- Ok(())
- })
- }
-
- fn cancel(&self, session_id: &agent_client_protocol::SessionId, _cx: &mut App) {
- let mut sessions = self.sessions.borrow_mut();
-
- if let Some(cancel_tx) = sessions
- .get_mut(session_id)
- .and_then(|session| session.cancel_tx.take())
- {
- cancel_tx.send(()).ok();
- }
- }
-}
-
-impl CodexConnection {
- pub fn handle_session_notification(
- notification: acp::SessionNotification,
- threads: Rc<RefCell<HashMap<acp::SessionId, CodexSession>>>,
- cx: &mut AsyncApp,
- ) {
- let threads = threads.borrow();
- let Some(thread) = threads
- .get(¬ification.session_id)
- .and_then(|session| session.thread.upgrade())
- else {
- log::error!(
- "Thread not found for session ID: {}",
- notification.session_id
- );
- return;
- };
-
- thread
- .update(cx, |thread, cx| {
- thread.handle_session_update(notification.update, cx)
- })
- .log_err();
- }
-}
-
-impl Drop for CodexConnection {
- fn drop(&mut self) {
- self.client.stop().log_err();
- }
-}
-
-#[cfg(test)]
-pub(crate) mod tests {
- use super::*;
- use crate::AgentServerCommand;
- use std::path::Path;
-
- crate::common_e2e_tests!(Codex, allow_option_id = "approve");
-
- pub fn local_command() -> AgentServerCommand {
- let cli_path = Path::new(env!("CARGO_MANIFEST_DIR"))
- .join("../../../codex/codex-rs/target/debug/codex");
-
- AgentServerCommand {
- path: cli_path,
- args: vec![],
- env: None,
- }
- }
-}
@@ -375,9 +375,6 @@ pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
gemini: Some(AgentServerSettings {
command: crate::gemini::tests::local_command(),
}),
- codex: Some(AgentServerSettings {
- command: crate::codex::tests::local_command(),
- }),
},
cx,
);
@@ -1,14 +1,10 @@
-use anyhow::anyhow;
-use std::cell::RefCell;
use std::path::Path;
use std::rc::Rc;
-use util::ResultExt as _;
-use crate::{AgentServer, AgentServerCommand, AgentServerVersion};
-use acp_thread::{AgentConnection, LoadError, OldAcpAgentConnection, OldAcpClientDelegate};
-use agentic_coding_protocol as acp_old;
-use anyhow::{Context as _, Result};
-use gpui::{AppContext as _, AsyncApp, Entity, Task, WeakEntity};
+use crate::{AgentServer, AgentServerCommand};
+use acp_thread::AgentConnection;
+use anyhow::Result;
+use gpui::{Entity, Task};
use project::Project;
use settings::SettingsStore;
use ui::App;
@@ -43,144 +39,23 @@ impl AgentServer for Gemini {
project: &Entity<Project>,
cx: &mut App,
) -> Task<Result<Rc<dyn AgentConnection>>> {
- let root_dir = root_dir.to_path_buf();
let project = project.clone();
- let this = self.clone();
- let name = self.name();
-
+ let root_dir = root_dir.to_path_buf();
+ let server_name = self.name();
cx.spawn(async move |cx| {
- let command = this.command(&project, cx).await?;
-
- let mut child = util::command::new_smol_command(&command.path)
- .args(command.args.iter())
- .current_dir(root_dir)
- .stdin(std::process::Stdio::piped())
- .stdout(std::process::Stdio::piped())
- .stderr(std::process::Stdio::inherit())
- .kill_on_drop(true)
- .spawn()?;
-
- let stdin = child.stdin.take().unwrap();
- let stdout = child.stdout.take().unwrap();
-
- let foreground_executor = cx.foreground_executor().clone();
-
- let thread_rc = Rc::new(RefCell::new(WeakEntity::new_invalid()));
-
- let (connection, io_fut) = acp_old::AgentConnection::connect_to_agent(
- OldAcpClientDelegate::new(thread_rc.clone(), cx.clone()),
- stdin,
- stdout,
- move |fut| foreground_executor.spawn(fut).detach(),
- );
-
- let io_task = cx.background_spawn(async move {
- io_fut.await.log_err();
- });
-
- let child_status = cx.background_spawn(async move {
- let result = match child.status().await {
- Err(e) => Err(anyhow!(e)),
- Ok(result) if result.success() => Ok(()),
- Ok(result) => {
- if let Some(AgentServerVersion::Unsupported {
- error_message,
- upgrade_message,
- upgrade_command,
- }) = this.version(&command).await.log_err()
- {
- Err(anyhow!(LoadError::Unsupported {
- error_message,
- upgrade_message,
- upgrade_command
- }))
- } else {
- Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127))))
- }
- }
- };
- drop(io_task);
- result
- });
-
- let connection: Rc<dyn AgentConnection> = Rc::new(OldAcpAgentConnection {
- name,
- connection,
- child_status,
- current_thread: thread_rc,
- });
-
- Ok(connection)
- })
- }
-}
-
-impl Gemini {
- async fn command(
- &self,
- project: &Entity<Project>,
- cx: &mut AsyncApp,
- ) -> Result<AgentServerCommand> {
- let settings = cx.read_global(|settings: &SettingsStore, _| {
- settings.get::<AllAgentServersSettings>(None).gemini.clone()
- })?;
+ let settings = cx.read_global(|settings: &SettingsStore, _| {
+ settings.get::<AllAgentServersSettings>(None).gemini.clone()
+ })?;
- if let Some(command) =
- AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await
- {
- return Ok(command);
- };
+ let Some(command) =
+ AgentServerCommand::resolve("gemini", &[ACP_ARG], settings, &project, cx).await
+ else {
+ anyhow::bail!("Failed to find gemini binary");
+ };
- let (fs, node_runtime) = project.update(cx, |project, _| {
- (project.fs().clone(), project.node_runtime().cloned())
- })?;
- let node_runtime = node_runtime.context("gemini not found on path")?;
-
- let directory = ::paths::agent_servers_dir().join("gemini");
- fs.create_dir(&directory).await?;
- node_runtime
- .npm_install_packages(&directory, &[("@google/gemini-cli", "latest")])
- .await?;
- let path = directory.join("node_modules/.bin/gemini");
-
- Ok(AgentServerCommand {
- path,
- args: vec![ACP_ARG.into()],
- env: None,
+ crate::acp::connect(server_name, command, &root_dir, cx).await
})
}
-
- async fn version(&self, command: &AgentServerCommand) -> Result<AgentServerVersion> {
- let version_fut = util::command::new_smol_command(&command.path)
- .args(command.args.iter())
- .arg("--version")
- .kill_on_drop(true)
- .output();
-
- let help_fut = util::command::new_smol_command(&command.path)
- .args(command.args.iter())
- .arg("--help")
- .kill_on_drop(true)
- .output();
-
- let (version_output, help_output) = futures::future::join(version_fut, help_fut).await;
-
- let current_version = String::from_utf8(version_output?.stdout)?;
- let supported = String::from_utf8(help_output?.stdout)?.contains(ACP_ARG);
-
- if supported {
- Ok(AgentServerVersion::Supported)
- } else {
- Ok(AgentServerVersion::Unsupported {
- error_message: format!(
- "Your installed version of Gemini {} doesn't support the Agentic Coding Protocol (ACP).",
- current_version
- ).into(),
- upgrade_message: "Upgrade Gemini to Latest".into(),
- upgrade_command: "npm install -g @google/gemini-cli@latest".into(),
- })
- }
- }
}
#[cfg(test)]
@@ -199,7 +74,7 @@ pub(crate) mod tests {
AgentServerCommand {
path: "node".into(),
- args: vec![cli_path, ACP_ARG.into()],
+ args: vec![cli_path],
env: None,
}
}
@@ -1,207 +0,0 @@
-use acp_thread::AcpThread;
-use agent_client_protocol as acp;
-use anyhow::Result;
-use context_server::listener::{McpServerTool, ToolResponse};
-use context_server::types::{
- Implementation, InitializeParams, InitializeResponse, ProtocolVersion, ServerCapabilities,
- ToolsCapabilities, requests,
-};
-use futures::channel::oneshot;
-use gpui::{App, AsyncApp, Task, WeakEntity};
-use indoc::indoc;
-
-pub struct ZedMcpServer {
- server: context_server::listener::McpServer,
-}
-
-pub const SERVER_NAME: &str = "zed";
-
-impl ZedMcpServer {
- pub async fn new(
- thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
- cx: &AsyncApp,
- ) -> Result<Self> {
- let mut mcp_server = context_server::listener::McpServer::new(cx).await?;
- mcp_server.handle_request::<requests::Initialize>(Self::handle_initialize);
-
- mcp_server.add_tool(RequestPermissionTool {
- thread_rx: thread_rx.clone(),
- });
- mcp_server.add_tool(ReadTextFileTool {
- thread_rx: thread_rx.clone(),
- });
- mcp_server.add_tool(WriteTextFileTool {
- thread_rx: thread_rx.clone(),
- });
-
- Ok(Self { server: mcp_server })
- }
-
- pub fn server_config(&self) -> Result<acp::McpServerConfig> {
- #[cfg(not(test))]
- let zed_path = anyhow::Context::context(
- std::env::current_exe(),
- "finding current executable path for use in mcp_server",
- )?;
-
- #[cfg(test)]
- let zed_path = crate::e2e_tests::get_zed_path();
-
- Ok(acp::McpServerConfig {
- command: zed_path,
- args: vec![
- "--nc".into(),
- self.server.socket_path().display().to_string(),
- ],
- env: None,
- })
- }
-
- fn handle_initialize(_: InitializeParams, cx: &App) -> Task<Result<InitializeResponse>> {
- cx.foreground_executor().spawn(async move {
- Ok(InitializeResponse {
- protocol_version: ProtocolVersion("2025-06-18".into()),
- capabilities: ServerCapabilities {
- experimental: None,
- logging: None,
- completions: None,
- prompts: None,
- resources: None,
- tools: Some(ToolsCapabilities {
- list_changed: Some(false),
- }),
- },
- server_info: Implementation {
- name: SERVER_NAME.into(),
- version: "0.1.0".into(),
- },
- meta: None,
- })
- })
- }
-}
-
-// Tools
-
-#[derive(Clone)]
-pub struct RequestPermissionTool {
- thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
-}
-
-impl McpServerTool for RequestPermissionTool {
- type Input = acp::RequestPermissionArguments;
- type Output = acp::RequestPermissionOutput;
-
- const NAME: &'static str = "Confirmation";
-
- fn description(&self) -> &'static str {
- indoc! {"
- Request permission for tool calls.
-
- This tool is meant to be called programmatically by the agent loop, not the LLM.
- "}
- }
-
- async fn run(
- &self,
- input: Self::Input,
- cx: &mut AsyncApp,
- ) -> Result<ToolResponse<Self::Output>> {
- let mut thread_rx = self.thread_rx.clone();
- let Some(thread) = thread_rx.recv().await?.upgrade() else {
- anyhow::bail!("Thread closed");
- };
-
- let result = thread
- .update(cx, |thread, cx| {
- thread.request_tool_call_permission(input.tool_call, input.options, cx)
- })?
- .await;
-
- let outcome = match result {
- Ok(option_id) => acp::RequestPermissionOutcome::Selected { option_id },
- Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Canceled,
- };
-
- Ok(ToolResponse {
- content: vec![],
- structured_content: acp::RequestPermissionOutput { outcome },
- })
- }
-}
-
-#[derive(Clone)]
-pub struct ReadTextFileTool {
- thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
-}
-
-impl McpServerTool for ReadTextFileTool {
- type Input = acp::ReadTextFileArguments;
- type Output = acp::ReadTextFileOutput;
-
- const NAME: &'static str = "Read";
-
- fn description(&self) -> &'static str {
- "Reads the content of the given file in the project including unsaved changes."
- }
-
- async fn run(
- &self,
- input: Self::Input,
- cx: &mut AsyncApp,
- ) -> Result<ToolResponse<Self::Output>> {
- let mut thread_rx = self.thread_rx.clone();
- let Some(thread) = thread_rx.recv().await?.upgrade() else {
- anyhow::bail!("Thread closed");
- };
-
- let content = thread
- .update(cx, |thread, cx| {
- thread.read_text_file(input.path, input.line, input.limit, false, cx)
- })?
- .await?;
-
- Ok(ToolResponse {
- content: vec![],
- structured_content: acp::ReadTextFileOutput { content },
- })
- }
-}
-
-#[derive(Clone)]
-pub struct WriteTextFileTool {
- thread_rx: watch::Receiver<WeakEntity<AcpThread>>,
-}
-
-impl McpServerTool for WriteTextFileTool {
- type Input = acp::WriteTextFileArguments;
- type Output = ();
-
- const NAME: &'static str = "Write";
-
- fn description(&self) -> &'static str {
- "Write to a file replacing its contents"
- }
-
- async fn run(
- &self,
- input: Self::Input,
- cx: &mut AsyncApp,
- ) -> Result<ToolResponse<Self::Output>> {
- let mut thread_rx = self.thread_rx.clone();
- let Some(thread) = thread_rx.recv().await?.upgrade() else {
- anyhow::bail!("Thread closed");
- };
-
- thread
- .update(cx, |thread, cx| {
- thread.write_text_file(input.path, input.content, cx)
- })?
- .await?;
-
- Ok(ToolResponse {
- content: vec![],
- structured_content: (),
- })
- }
-}
@@ -13,7 +13,6 @@ pub fn init(cx: &mut App) {
pub struct AllAgentServersSettings {
pub gemini: Option<AgentServerSettings>,
pub claude: Option<AgentServerSettings>,
- pub codex: Option<AgentServerSettings>,
}
#[derive(Deserialize, Serialize, Clone, JsonSchema, Debug)]
@@ -30,21 +29,13 @@ impl settings::Settings for AllAgentServersSettings {
fn load(sources: SettingsSources<Self::FileContent>, _: &mut App) -> Result<Self> {
let mut settings = AllAgentServersSettings::default();
- for AllAgentServersSettings {
- gemini,
- claude,
- codex,
- } in sources.defaults_and_customizations()
- {
+ for AllAgentServersSettings { gemini, claude } in sources.defaults_and_customizations() {
if gemini.is_some() {
settings.gemini = gemini.clone();
}
if claude.is_some() {
settings.claude = claude.clone();
}
- if codex.is_some() {
- settings.codex = codex.clone();
- }
}
Ok(settings)
@@ -246,7 +246,7 @@ impl AcpThreadView {
{
Err(e) => {
let mut cx = cx.clone();
- if e.downcast_ref::<acp_thread::Unauthenticated>().is_some() {
+ if e.is::<acp_thread::AuthRequired>() {
this.update(&mut cx, |this, cx| {
this.thread_state = ThreadState::Unauthenticated { connection };
cx.notify();
@@ -719,13 +719,18 @@ impl AcpThreadView {
Some(entry.diffs().map(|diff| diff.multibuffer.clone()))
}
- fn authenticate(&mut self, window: &mut Window, cx: &mut Context<Self>) {
+ fn authenticate(
+ &mut self,
+ method: acp::AuthMethodId,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
let ThreadState::Unauthenticated { ref connection } = self.thread_state else {
return;
};
self.last_error.take();
- let authenticate = connection.authenticate(cx);
+ let authenticate = connection.authenticate(method, cx);
self.auth_task = Some(cx.spawn_in(window, {
let project = self.project.clone();
let agent = self.agent.clone();
@@ -2424,22 +2429,26 @@ impl Render for AcpThreadView {
.on_action(cx.listener(Self::next_history_message))
.on_action(cx.listener(Self::open_agent_diff))
.child(match &self.thread_state {
- ThreadState::Unauthenticated { .. } => {
- v_flex()
- .p_2()
- .flex_1()
- .items_center()
- .justify_center()
- .child(self.render_pending_auth_state())
- .child(
- h_flex().mt_1p5().justify_center().child(
- Button::new("sign-in", format!("Sign in to {}", self.agent.name()))
- .on_click(cx.listener(|this, _, window, cx| {
- this.authenticate(window, cx)
- })),
- ),
- )
- }
+ ThreadState::Unauthenticated { connection } => v_flex()
+ .p_2()
+ .flex_1()
+ .items_center()
+ .justify_center()
+ .child(self.render_pending_auth_state())
+ .child(h_flex().mt_1p5().justify_center().children(
+ connection.auth_methods().into_iter().map(|method| {
+ Button::new(
+ SharedString::from(method.id.0.clone()),
+ method.label.clone(),
+ )
+ .on_click({
+ let method_id = method.id.clone();
+ cx.listener(move |this, _, window, cx| {
+ this.authenticate(method_id.clone(), window, cx)
+ })
+ })
+ }),
+ )),
ThreadState::Loading { .. } => v_flex().flex_1().child(self.render_empty_state(cx)),
ThreadState::LoadError(e) => v_flex()
.p_2()
@@ -2878,8 +2887,8 @@ mod tests {
}
impl AgentConnection for StubAgentConnection {
- fn name(&self) -> &'static str {
- "StubAgentConnection"
+ fn auth_methods(&self) -> &[acp::AuthMethod] {
+ &[]
}
fn new_thread(
@@ -2897,17 +2906,21 @@ mod tests {
.into(),
);
let thread = cx
- .new(|cx| AcpThread::new(self.clone(), project, session_id.clone(), cx))
+ .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
.unwrap();
self.sessions.lock().insert(session_id, thread.downgrade());
Task::ready(Ok(thread))
}
- fn authenticate(&self, _cx: &mut App) -> Task<gpui::Result<()>> {
+ fn authenticate(
+ &self,
+ _method_id: acp::AuthMethodId,
+ _cx: &mut App,
+ ) -> Task<gpui::Result<()>> {
unimplemented!()
}
- fn prompt(&self, params: acp::PromptArguments, cx: &mut App) -> Task<gpui::Result<()>> {
+ fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<gpui::Result<()>> {
let sessions = self.sessions.lock();
let thread = sessions.get(¶ms.session_id).unwrap();
let mut tasks = vec![];
@@ -2954,10 +2967,6 @@ mod tests {
struct SaboteurAgentConnection;
impl AgentConnection for SaboteurAgentConnection {
- fn name(&self) -> &'static str {
- "SaboteurAgentConnection"
- }
-
fn new_thread(
self: Rc<Self>,
project: Entity<Project>,
@@ -2965,15 +2974,31 @@ mod tests {
cx: &mut gpui::AsyncApp,
) -> Task<gpui::Result<Entity<AcpThread>>> {
Task::ready(Ok(cx
- .new(|cx| AcpThread::new(self, project, SessionId("test".into()), cx))
+ .new(|cx| {
+ AcpThread::new(
+ "SaboteurAgentConnection",
+ self,
+ project,
+ SessionId("test".into()),
+ cx,
+ )
+ })
.unwrap()))
}
- fn authenticate(&self, _cx: &mut App) -> Task<gpui::Result<()>> {
+ fn auth_methods(&self) -> &[acp::AuthMethod] {
+ &[]
+ }
+
+ fn authenticate(
+ &self,
+ _method_id: acp::AuthMethodId,
+ _cx: &mut App,
+ ) -> Task<gpui::Result<()>> {
unimplemented!()
}
- fn prompt(&self, _params: acp::PromptArguments, _cx: &mut App) -> Task<gpui::Result<()>> {
+ fn prompt(&self, _params: acp::PromptRequest, _cx: &mut App) -> Task<gpui::Result<()>> {
Task::ready(Err(anyhow::anyhow!("Error prompting")))
}
@@ -1987,20 +1987,6 @@ impl AgentPanel {
);
}),
)
- .item(
- ContextMenuEntry::new("New Codex Thread")
- .icon(IconName::AiOpenAi)
- .icon_color(Color::Muted)
- .handler(move |window, cx| {
- window.dispatch_action(
- NewExternalAgentThread {
- agent: Some(crate::ExternalAgent::Codex),
- }
- .boxed_clone(),
- cx,
- );
- }),
- )
});
menu
}))
@@ -2662,25 +2648,6 @@ impl AgentPanel {
)
},
),
- )
- .child(
- NewThreadButton::new(
- "new-codex-thread-btn",
- "New Codex Thread",
- IconName::AiOpenAi,
- )
- .on_click(
- |window, cx| {
- window.dispatch_action(
- Box::new(NewExternalAgentThread {
- agent: Some(
- crate::ExternalAgent::Codex,
- ),
- }),
- cx,
- )
- },
- ),
),
)
}),
@@ -150,7 +150,6 @@ enum ExternalAgent {
#[default]
Gemini,
ClaudeCode,
- Codex,
}
impl ExternalAgent {
@@ -158,7 +157,6 @@ impl ExternalAgent {
match self {
ExternalAgent::Gemini => Rc::new(agent_servers::Gemini),
ExternalAgent::ClaudeCode => Rc::new(agent_servers::ClaudeCode),
- ExternalAgent::Codex => Rc::new(agent_servers::Codex),
}
}
}
@@ -441,14 +441,12 @@ impl Client {
Ok(())
}
- #[allow(unused)]
- pub fn on_notification<F>(&self, method: &'static str, f: F)
- where
- F: 'static + Send + FnMut(Value, AsyncApp),
- {
- self.notification_handlers
- .lock()
- .insert(method, Box::new(f));
+ pub fn on_notification(
+ &self,
+ method: &'static str,
+ f: Box<dyn 'static + Send + FnMut(Value, AsyncApp)>,
+ ) {
+ self.notification_handlers.lock().insert(method, f);
}
}
@@ -95,8 +95,28 @@ impl ContextServer {
self.client.read().clone()
}
- pub async fn start(self: Arc<Self>, cx: &AsyncApp) -> Result<()> {
- let client = match &self.configuration {
+ pub async fn start(&self, cx: &AsyncApp) -> Result<()> {
+ self.initialize(self.new_client(cx)?).await
+ }
+
+ /// Starts the context server, making sure handlers are registered before initialization happens
+ pub async fn start_with_handlers(
+ &self,
+ notification_handlers: Vec<(
+ &'static str,
+ Box<dyn 'static + Send + FnMut(serde_json::Value, AsyncApp)>,
+ )>,
+ cx: &AsyncApp,
+ ) -> Result<()> {
+ let client = self.new_client(cx)?;
+ for (method, handler) in notification_handlers {
+ client.on_notification(method, handler);
+ }
+ self.initialize(client).await
+ }
+
+ fn new_client(&self, cx: &AsyncApp) -> Result<Client> {
+ Ok(match &self.configuration {
ContextServerTransport::Stdio(command, working_directory) => Client::stdio(
client::ContextServerId(self.id.0.clone()),
client::ModelContextServerBinary {
@@ -113,8 +133,7 @@ impl ContextServer {
transport.clone(),
cx.clone(),
)?,
- };
- self.initialize(client).await
+ })
}
async fn initialize(&self, client: Client) -> Result<()> {
@@ -83,14 +83,18 @@ impl McpServer {
}
pub fn add_tool<T: McpServerTool + Clone + 'static>(&mut self, tool: T) {
- let output_schema = schemars::schema_for!(T::Output);
- let unit_schema = schemars::schema_for!(());
+ let mut settings = schemars::generate::SchemaSettings::draft07();
+ settings.inline_subschemas = true;
+ let mut generator = settings.into_generator();
+
+ let output_schema = generator.root_schema_for::<T::Output>();
+ let unit_schema = generator.root_schema_for::<T::Output>();
let registered_tool = RegisteredTool {
tool: Tool {
name: T::NAME.into(),
description: Some(tool.description().into()),
- input_schema: schemars::schema_for!(T::Input).into(),
+ input_schema: generator.root_schema_for::<T::Input>().into(),
output_schema: if output_schema == unit_schema {
None
} else {
@@ -115,10 +115,11 @@ impl InitializedContextServerProtocol {
self.inner.notify(T::METHOD, params)
}
- pub fn on_notification<F>(&self, method: &'static str, f: F)
- where
- F: 'static + Send + FnMut(Value, AsyncApp),
- {
+ pub fn on_notification(
+ &self,
+ method: &'static str,
+ f: Box<dyn 'static + Send + FnMut(Value, AsyncApp)>,
+ ) {
self.inner.on_notification(method, f);
}
}