1use agent_client_protocol::{self as acp, Agent as _};
2use collections::HashMap;
3use futures::channel::oneshot;
4use project::Project;
5use std::cell::RefCell;
6use std::path::Path;
7use std::rc::Rc;
8
9use anyhow::{Context as _, Result};
10use gpui::{App, AppContext as _, AsyncApp, Entity, Task, WeakEntity};
11
12use crate::AgentServerCommand;
13use acp_thread::{AcpThread, AgentConnection, AuthRequired};
14
15pub struct AcpConnection {
16 server_name: &'static str,
17 connection: Rc<acp::ClientSideConnection>,
18 sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
19 auth_methods: Vec<acp::AuthMethod>,
20 _io_task: Task<Result<()>>,
21}
22
23pub struct AcpSession {
24 thread: WeakEntity<AcpThread>,
25}
26
27impl AcpConnection {
28 pub async fn stdio(
29 server_name: &'static str,
30 command: AgentServerCommand,
31 root_dir: &Path,
32 cx: &mut AsyncApp,
33 ) -> Result<Self> {
34 let mut child = util::command::new_smol_command(&command.path)
35 .args(command.args.iter().map(|arg| arg.as_str()))
36 .envs(command.env.iter().flatten())
37 .current_dir(root_dir)
38 .stdin(std::process::Stdio::piped())
39 .stdout(std::process::Stdio::piped())
40 .stderr(std::process::Stdio::inherit())
41 .kill_on_drop(true)
42 .spawn()?;
43
44 let stdout = child.stdout.take().expect("Failed to take stdout");
45 let stdin = child.stdin.take().expect("Failed to take stdin");
46
47 let sessions = Rc::new(RefCell::new(HashMap::default()));
48
49 let client = ClientDelegate {
50 sessions: sessions.clone(),
51 cx: cx.clone(),
52 };
53 let (connection, io_task) = acp::ClientSideConnection::new(client, stdin, stdout, {
54 let foreground_executor = cx.foreground_executor().clone();
55 move |fut| {
56 foreground_executor.spawn(fut).detach();
57 }
58 });
59
60 let io_task = cx.background_spawn(io_task);
61
62 let response = connection
63 .initialize(acp::InitializeRequest {
64 protocol_version: acp::VERSION,
65 client_capabilities: acp::ClientCapabilities {
66 fs: acp::FileSystemCapability {
67 read_text_file: true,
68 write_text_file: true,
69 },
70 },
71 })
72 .await?;
73
74 // todo! check version
75
76 Ok(Self {
77 auth_methods: response.auth_methods,
78 connection: connection.into(),
79 server_name,
80 sessions,
81 _io_task: io_task,
82 })
83 }
84}
85
86impl AgentConnection for AcpConnection {
87 fn new_thread(
88 self: Rc<Self>,
89 project: Entity<Project>,
90 cwd: &Path,
91 cx: &mut AsyncApp,
92 ) -> Task<Result<Entity<AcpThread>>> {
93 let conn = self.connection.clone();
94 let sessions = self.sessions.clone();
95 let cwd = cwd.to_path_buf();
96 cx.spawn(async move |cx| {
97 let response = conn
98 .new_session(acp::NewSessionRequest {
99 // todo! Zed MCP server?
100 mcp_servers: vec![],
101 cwd,
102 })
103 .await?;
104
105 let Some(session_id) = response.session_id else {
106 anyhow::bail!(AuthRequired);
107 };
108
109 let thread = cx.new(|cx| {
110 AcpThread::new(
111 self.server_name,
112 self.clone(),
113 project,
114 session_id.clone(),
115 cx,
116 )
117 })?;
118
119 let session = AcpSession {
120 thread: thread.downgrade(),
121 };
122 sessions.borrow_mut().insert(session_id, session);
123
124 Ok(thread)
125 })
126 }
127
128 fn auth_methods(&self) -> &[acp::AuthMethod] {
129 &self.auth_methods
130 }
131
132 fn authenticate(&self, method_id: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>> {
133 let conn = self.connection.clone();
134 cx.foreground_executor().spawn(async move {
135 let result = conn
136 .authenticate(acp::AuthenticateRequest {
137 method_id: method_id.clone(),
138 })
139 .await?;
140
141 Ok(result)
142 })
143 }
144
145 fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<Result<()>> {
146 let conn = self.connection.clone();
147 cx.foreground_executor()
148 .spawn(async move { Ok(conn.prompt(params).await?) })
149 }
150
151 fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
152 let conn = self.connection.clone();
153 let params = acp::CancelledNotification {
154 session_id: session_id.clone(),
155 };
156 cx.foreground_executor()
157 .spawn(async move { conn.cancelled(params).await })
158 .detach();
159 }
160}
161
162struct ClientDelegate {
163 sessions: Rc<RefCell<HashMap<acp::SessionId, AcpSession>>>,
164 cx: AsyncApp,
165}
166
167impl acp::Client for ClientDelegate {
168 async fn request_permission(
169 &self,
170 arguments: acp::RequestPermissionRequest,
171 ) -> Result<acp::RequestPermissionResponse, acp::Error> {
172 let cx = &mut self.cx.clone();
173 let result = self
174 .sessions
175 .borrow()
176 .get(&arguments.session_id)
177 .context("Failed to get session")?
178 .thread
179 .update(cx, |thread, cx| {
180 thread.request_tool_call_permission(arguments.tool_call, arguments.options, cx)
181 })?
182 .await;
183
184 let outcome = match result {
185 Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
186 Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
187 };
188
189 Ok(acp::RequestPermissionResponse { outcome })
190 }
191
192 async fn write_text_file(
193 &self,
194 arguments: acp::WriteTextFileRequest,
195 ) -> Result<(), acp::Error> {
196 let cx = &mut self.cx.clone();
197 self.sessions
198 .borrow()
199 .get(&arguments.session_id)
200 .context("Failed to get session")?
201 .thread
202 .update(cx, |thread, cx| {
203 thread.write_text_file(arguments.path, arguments.content, cx)
204 })?
205 .await?;
206
207 Ok(())
208 }
209
210 async fn read_text_file(
211 &self,
212 arguments: acp::ReadTextFileRequest,
213 ) -> Result<acp::ReadTextFileResponse, acp::Error> {
214 let cx = &mut self.cx.clone();
215 let content = self
216 .sessions
217 .borrow()
218 .get(&arguments.session_id)
219 .context("Failed to get session")?
220 .thread
221 .update(cx, |thread, cx| {
222 thread.read_text_file(arguments.path, arguments.line, arguments.limit, false, cx)
223 })?
224 .await?;
225
226 Ok(acp::ReadTextFileResponse { content })
227 }
228
229 async fn session_notification(
230 &self,
231 notification: acp::SessionNotification,
232 ) -> Result<(), acp::Error> {
233 let cx = &mut self.cx.clone();
234 let sessions = self.sessions.borrow();
235 let session = sessions
236 .get(¬ification.session_id)
237 .context("Failed to get session")?;
238
239 session.thread.update(cx, |thread, cx| {
240 thread.handle_session_update(notification.update, cx)
241 })??;
242
243 Ok(())
244 }
245}