1mod server;
2mod thread_view;
3
4use agentic_coding_protocol::{self as acp, Role};
5use anyhow::Result;
6use chrono::{DateTime, Utc};
7use futures::channel::oneshot;
8use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
9use language::LanguageRegistry;
10use markdown::Markdown;
11use project::Project;
12use std::{mem, ops::Range, path::PathBuf, sync::Arc};
13use ui::App;
14use util::{ResultExt, debug_panic};
15
16pub use server::AcpServer;
17pub use thread_view::AcpThreadView;
18
19#[derive(Debug, Clone, PartialEq, Eq, Hash)]
20pub struct ThreadId(SharedString);
21
/// Version stamp attached to a file snapshot (`FileContent`) or symbol reference.
///
/// NOTE(review): the meaning of the wrapped counter is not defined in this module —
/// presumably a buffer/file version supplied by the caller; confirm before relying on it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct FileVersion(u64);
24
25#[derive(Debug)]
26pub struct AgentThreadSummary {
27 pub id: ThreadId,
28 pub title: String,
29 pub created_at: DateTime<Utc>,
30}
31
32#[derive(Clone, Debug, PartialEq, Eq)]
33pub struct FileContent {
34 pub path: PathBuf,
35 pub version: FileVersion,
36 pub content: SharedString,
37}
38
39#[derive(Clone, Debug, Eq, PartialEq)]
40pub struct Message {
41 pub role: acp::Role,
42 pub chunks: Vec<MessageChunk>,
43}
44
45impl Message {
46 fn into_acp(self, cx: &App) -> acp::Message {
47 acp::Message {
48 role: self.role,
49 chunks: self
50 .chunks
51 .into_iter()
52 .map(|chunk| chunk.into_acp(cx))
53 .collect(),
54 }
55 }
56}
57
58#[derive(Clone, Debug, Eq, PartialEq)]
59pub enum MessageChunk {
60 Text {
61 chunk: Entity<Markdown>,
62 },
63 File {
64 content: FileContent,
65 },
66 Directory {
67 path: PathBuf,
68 contents: Vec<FileContent>,
69 },
70 Symbol {
71 path: PathBuf,
72 range: Range<u64>,
73 version: FileVersion,
74 name: SharedString,
75 content: SharedString,
76 },
77 Fetch {
78 url: SharedString,
79 content: SharedString,
80 },
81}
82
83impl MessageChunk {
84 pub fn from_acp(
85 chunk: acp::MessageChunk,
86 language_registry: Arc<LanguageRegistry>,
87 cx: &mut App,
88 ) -> Self {
89 match chunk {
90 acp::MessageChunk::Text { chunk } => MessageChunk::Text {
91 chunk: cx.new(|cx| Markdown::new(chunk.into(), Some(language_registry), None, cx)),
92 },
93 }
94 }
95
96 pub fn into_acp(self, cx: &App) -> acp::MessageChunk {
97 match self {
98 MessageChunk::Text { chunk } => acp::MessageChunk::Text {
99 chunk: chunk.read(cx).source().to_string(),
100 },
101 MessageChunk::File { .. } => todo!(),
102 MessageChunk::Directory { .. } => todo!(),
103 MessageChunk::Symbol { .. } => todo!(),
104 MessageChunk::Fetch { .. } => todo!(),
105 }
106 }
107
108 pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
109 MessageChunk::Text {
110 chunk: cx.new(|cx| {
111 Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
112 }),
113 }
114 }
115}
116
117#[derive(Debug)]
118pub enum AgentThreadEntryContent {
119 Message(Message),
120 ToolCall(ToolCall),
121}
122
123#[derive(Debug)]
124pub struct ToolCall {
125 id: ToolCallId,
126 tool_name: Entity<Markdown>,
127 status: ToolCallStatus,
128}
129
130#[derive(Debug)]
131pub enum ToolCallStatus {
132 WaitingForConfirmation {
133 description: Entity<Markdown>,
134 respond_tx: oneshot::Sender<bool>,
135 },
136 // todo! Running?
137 Allowed,
138 Rejected,
139}
140
/// A [`ThreadEntryId`] that is known to refer to a tool-call entry.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)]
pub struct ToolCallId(ThreadEntryId);

impl ToolCallId {
    /// Returns the raw numeric id of the underlying thread entry.
    pub fn as_u64(&self) -> u64 {
        let ToolCallId(ThreadEntryId(raw)) = *self;
        raw
    }
}

/// Sequential identifier of an entry within a thread.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)]
pub struct ThreadEntryId(pub u64);

impl ThreadEntryId {
    /// Advances this counter by one and returns the value it held beforehand (like `i++`).
    pub fn post_inc(&mut self) -> Self {
        let advanced = Self(self.0 + 1);
        std::mem::replace(self, advanced)
    }
}
161
162#[derive(Debug)]
163pub struct ThreadEntry {
164 pub id: ThreadEntryId,
165 pub content: AgentThreadEntryContent,
166}
167
168pub struct AcpThread {
169 id: ThreadId,
170 next_entry_id: ThreadEntryId,
171 entries: Vec<ThreadEntry>,
172 server: Arc<AcpServer>,
173 title: SharedString,
174 project: Entity<Project>,
175}
176
177enum AcpThreadEvent {
178 NewEntry,
179 EntryUpdated(usize),
180}
181
182impl EventEmitter<AcpThreadEvent> for AcpThread {}
183
184impl AcpThread {
185 pub fn new(
186 server: Arc<AcpServer>,
187 thread_id: ThreadId,
188 entries: Vec<AgentThreadEntryContent>,
189 project: Entity<Project>,
190 _: &mut Context<Self>,
191 ) -> Self {
192 let mut next_entry_id = ThreadEntryId(0);
193 Self {
194 title: "A new agent2 thread".into(),
195 entries: entries
196 .into_iter()
197 .map(|entry| ThreadEntry {
198 id: next_entry_id.post_inc(),
199 content: entry,
200 })
201 .collect(),
202 server,
203 id: thread_id,
204 next_entry_id,
205 project,
206 }
207 }
208
209 pub fn title(&self) -> SharedString {
210 self.title.clone()
211 }
212
213 pub fn entries(&self) -> &[ThreadEntry] {
214 &self.entries
215 }
216
217 pub fn push_entry(
218 &mut self,
219 entry: AgentThreadEntryContent,
220 cx: &mut Context<Self>,
221 ) -> ThreadEntryId {
222 let id = self.next_entry_id.post_inc();
223 self.entries.push(ThreadEntry { id, content: entry });
224 cx.emit(AcpThreadEvent::NewEntry);
225 id
226 }
227
228 pub fn push_assistant_chunk(&mut self, chunk: acp::MessageChunk, cx: &mut Context<Self>) {
229 let entries_len = self.entries.len();
230 if let Some(last_entry) = self.entries.last_mut()
231 && let AgentThreadEntryContent::Message(Message {
232 ref mut chunks,
233 role: Role::Assistant,
234 }) = last_entry.content
235 {
236 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
237
238 if let (
239 Some(MessageChunk::Text { chunk: old_chunk }),
240 acp::MessageChunk::Text { chunk: new_chunk },
241 ) = (chunks.last_mut(), &chunk)
242 {
243 old_chunk.update(cx, |old_chunk, cx| {
244 old_chunk.append(&new_chunk, cx);
245 });
246 } else {
247 chunks.push(MessageChunk::from_acp(
248 chunk,
249 self.project.read(cx).languages().clone(),
250 cx,
251 ));
252 }
253
254 return;
255 }
256
257 let chunk = MessageChunk::from_acp(chunk, self.project.read(cx).languages().clone(), cx);
258
259 self.push_entry(
260 AgentThreadEntryContent::Message(Message {
261 role: Role::Assistant,
262 chunks: vec![chunk],
263 }),
264 cx,
265 );
266 }
267
268 pub fn push_tool_call(
269 &mut self,
270 title: String,
271 description: String,
272 respond_tx: oneshot::Sender<bool>,
273 cx: &mut Context<Self>,
274 ) -> ToolCallId {
275 let language_registry = self.project.read(cx).languages().clone();
276
277 let entry_id = self.push_entry(
278 AgentThreadEntryContent::ToolCall(ToolCall {
279 // todo! clean up id creation
280 id: ToolCallId(ThreadEntryId(self.entries.len() as u64)),
281 tool_name: cx.new(|cx| {
282 Markdown::new(title.into(), Some(language_registry.clone()), None, cx)
283 }),
284 status: ToolCallStatus::WaitingForConfirmation {
285 description: cx.new(|cx| {
286 Markdown::new(
287 description.into(),
288 Some(language_registry.clone()),
289 None,
290 cx,
291 )
292 }),
293 respond_tx,
294 },
295 }),
296 cx,
297 );
298
299 ToolCallId(entry_id)
300 }
301
302 pub fn authorize_tool_call(&mut self, id: ToolCallId, allowed: bool, cx: &mut Context<Self>) {
303 let Some(entry) = self.entry_mut(id.0) else {
304 return;
305 };
306
307 let AgentThreadEntryContent::ToolCall(call) = &mut entry.content else {
308 debug_panic!("expected ToolCall");
309 return;
310 };
311
312 let new_status = if allowed {
313 ToolCallStatus::Allowed
314 } else {
315 ToolCallStatus::Rejected
316 };
317
318 let curr_status = mem::replace(&mut call.status, new_status);
319
320 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
321 respond_tx.send(allowed).log_err();
322 } else {
323 debug_panic!("tried to authorize an already authorized tool call");
324 }
325
326 cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
327 }
328
329 fn entry_mut(&mut self, id: ThreadEntryId) -> Option<&mut ThreadEntry> {
330 let entry = self.entries.get_mut(id.0 as usize);
331 debug_assert!(
332 entry.is_some(),
333 "We shouldn't give out ids to entries that don't exist"
334 );
335 entry
336 }
337
338 /// Returns true if the last turn is awaiting tool authorization
339 pub fn waiting_for_tool_confirmation(&self) -> bool {
340 for entry in self.entries.iter().rev() {
341 match &entry.content {
342 AgentThreadEntryContent::ToolCall(call) => match call.status {
343 ToolCallStatus::WaitingForConfirmation { .. } => return true,
344 ToolCallStatus::Allowed | ToolCallStatus::Rejected => continue,
345 },
346 AgentThreadEntryContent::Message(_) => {
347 // Reached the beginning of the turn
348 return false;
349 }
350 }
351 }
352 false
353 }
354
355 pub fn send(&mut self, message: &str, cx: &mut Context<Self>) -> Task<Result<()>> {
356 let agent = self.server.clone();
357 let id = self.id.clone();
358 let chunk = MessageChunk::from_str(message, self.project.read(cx).languages().clone(), cx);
359 let message = Message {
360 role: Role::User,
361 chunks: vec![chunk],
362 };
363 self.push_entry(AgentThreadEntryContent::Message(message.clone()), cx);
364 let acp_message = message.into_acp(cx);
365 cx.spawn(async move |_, cx| {
366 agent.send_message(id, acp_message, cx).await?;
367 Ok(())
368 })
369 }
370}
371
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{FutureExt as _, channel::mpsc, select};
    use gpui::{AsyncApp, TestAppContext};
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use smol::stream::StreamExt;
    use std::{env, path::Path, process::Stdio, time::Duration};
    use util::path;

    /// Installs the global state these tests need: a logger (if not already installed), a
    /// test `SettingsStore`, project settings, and language support.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    /// Sends one message and expects exactly a user entry followed by an assistant entry.
    ///
    /// NOTE(review): talks to a real Gemini CLI process (see `gemini_acp_server`).
    #[gpui::test]
    async fn test_gemini_basic(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        thread
            .update(cx, |thread, cx| thread.send("Hello from Zed!", cx))
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.entries.len(), 2);
            assert!(matches!(
                thread.entries[0].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::User,
                    ..
                })
            ));
            assert!(matches!(
                thread.entries[1].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::Assistant,
                    ..
                })
            ));
        });
    }

    /// Asks the agent to read a file, approves the resulting `read_file` tool call, and
    /// checks the user -> tool call -> assistant entry sequence.
    #[gpui::test]
    async fn test_gemini_tool_call(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/private/tmp"),
            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
        )
        .await;
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        let full_turn = thread.update(cx, |thread, cx| {
            thread.send(
                "Read the '/private/tmp/foo' file and tell me what you see.",
                cx,
            )
        });

        run_until_tool_call(&thread, cx).await;

        let tool_call_id = thread.read_with(cx, |thread, cx| {
            let AgentThreadEntryContent::ToolCall(ToolCall {
                id,
                tool_name,
                status: ToolCallStatus::WaitingForConfirmation { description, .. },
            }) = &thread.entries().last().unwrap().content
            else {
                panic!();
            };

            tool_name.read_with(cx, |md, _cx| {
                assert_eq!(md.source(), "read_file");
            });

            description.read_with(cx, |md, _cx| {
                assert!(
                    md.source().contains("foo"),
                    "Expected description to contain 'foo', but got {}",
                    md.source()
                );
            });
            *id
        });

        thread.update(cx, |thread, cx| {
            thread.authorize_tool_call(tool_call_id, true, cx);
            assert!(matches!(
                thread.entries().last().unwrap().content,
                AgentThreadEntryContent::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed,
                    ..
                })
            ));
        });

        full_turn.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(thread.entries.len() >= 3, "{:?}", &thread.entries);
            assert!(matches!(
                thread.entries[0].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::User,
                    ..
                })
            ));
            assert!(matches!(
                thread.entries[1].content,
                AgentThreadEntryContent::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed,
                    ..
                })
            ));
            assert!(matches!(
                thread.entries[2].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::Assistant,
                    ..
                })
            ));
        });
    }

    /// Waits (up to 5 seconds) until the thread contains at least one tool call entry.
    ///
    /// Panics on timeout.
    async fn run_until_tool_call(thread: &Entity<AcpThread>, cx: &mut TestAppContext) {
        let (mut tx, mut rx) = mpsc::channel(1);

        let subscription = cx.update(|cx| {
            cx.subscribe(thread, move |thread, _, cx| {
                if thread
                    .read(cx)
                    .entries
                    .iter()
                    .any(|e| matches!(e.content, AgentThreadEntryContent::ToolCall(_)))
                {
                    // Every event after the first tool call lands here. The receiver only
                    // needs one notification, and the bounded channel can fill up before it
                    // is polled, so a full channel is expected and must not panic.
                    tx.try_send(()).ok();
                }
            })
        });

        select! {
            _ = cx.executor().timer(Duration::from_secs(5)).fuse() => {
                panic!("Timeout waiting for tool call")
            }
            _ = rx.next().fuse() => {
                drop(subscription);
            }
        }
    }

    /// Spawns the Gemini CLI in ACP mode and connects an `AcpServer` to it over stdio.
    ///
    /// Expects a `gemini-cli` checkout next to the Zed repository and `node` on `PATH`.
    /// `GEMINI_API_KEY` is forwarded when set. NOTE(review): the working directory is the
    /// macOS-specific `/private/tmp`. Returns an error if the process cannot be spawned.
    pub fn gemini_acp_server(project: Entity<Project>, mut cx: AsyncApp) -> Result<Arc<AcpServer>> {
        let cli_path =
            Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini-cli/packages/cli");
        let mut command = util::command::new_smol_command("node");
        command
            .arg(cli_path)
            .arg("--acp")
            .args(["--model", "gemini-2.5-flash"])
            .current_dir("/private/tmp")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .kill_on_drop(true);

        if let Ok(gemini_key) = env::var("GEMINI_API_KEY") {
            command.env("GEMINI_API_KEY", gemini_key);
        }

        let child = command.spawn()?;

        Ok(AcpServer::stdio(child, project, &mut cx))
    }
}