1mod server;
2mod thread_view;
3
4use agentic_coding_protocol::{self as acp, Role};
5use anyhow::{Context as _, Result};
6use chrono::{DateTime, Utc};
7use futures::channel::oneshot;
8use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
9use language::LanguageRegistry;
10use markdown::Markdown;
11use project::Project;
12use std::{mem, ops::Range, path::PathBuf, sync::Arc};
13use ui::{App, IconName};
14use util::{ResultExt, debug_panic};
15
16pub use server::AcpServer;
17pub use thread_view::AcpThreadView;
18
19#[derive(Debug, Clone, PartialEq, Eq, Hash)]
20pub struct ThreadId(SharedString);
21
22#[derive(Copy, Clone, Debug, PartialEq, Eq)]
23pub struct FileVersion(u64);
24
25#[derive(Debug)]
26pub struct AgentThreadSummary {
27 pub id: ThreadId,
28 pub title: String,
29 pub created_at: DateTime<Utc>,
30}
31
32#[derive(Clone, Debug, PartialEq, Eq)]
33pub struct FileContent {
34 pub path: PathBuf,
35 pub version: FileVersion,
36 pub content: SharedString,
37}
38
39#[derive(Clone, Debug, Eq, PartialEq)]
40pub struct Message {
41 pub role: acp::Role,
42 pub chunks: Vec<MessageChunk>,
43}
44
45impl Message {
46 fn into_acp(self, cx: &App) -> acp::Message {
47 acp::Message {
48 role: self.role,
49 chunks: self
50 .chunks
51 .into_iter()
52 .map(|chunk| chunk.into_acp(cx))
53 .collect(),
54 }
55 }
56}
57
58#[derive(Clone, Debug, Eq, PartialEq)]
59pub enum MessageChunk {
60 Text {
61 chunk: Entity<Markdown>,
62 },
63 File {
64 content: FileContent,
65 },
66 Directory {
67 path: PathBuf,
68 contents: Vec<FileContent>,
69 },
70 Symbol {
71 path: PathBuf,
72 range: Range<u64>,
73 version: FileVersion,
74 name: SharedString,
75 content: SharedString,
76 },
77 Fetch {
78 url: SharedString,
79 content: SharedString,
80 },
81}
82
83impl MessageChunk {
84 pub fn from_acp(
85 chunk: acp::MessageChunk,
86 language_registry: Arc<LanguageRegistry>,
87 cx: &mut App,
88 ) -> Self {
89 match chunk {
90 acp::MessageChunk::Text { chunk } => MessageChunk::Text {
91 chunk: cx.new(|cx| Markdown::new(chunk.into(), Some(language_registry), None, cx)),
92 },
93 }
94 }
95
96 pub fn into_acp(self, cx: &App) -> acp::MessageChunk {
97 match self {
98 MessageChunk::Text { chunk } => acp::MessageChunk::Text {
99 chunk: chunk.read(cx).source().to_string(),
100 },
101 MessageChunk::File { .. } => todo!(),
102 MessageChunk::Directory { .. } => todo!(),
103 MessageChunk::Symbol { .. } => todo!(),
104 MessageChunk::Fetch { .. } => todo!(),
105 }
106 }
107
108 pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
109 MessageChunk::Text {
110 chunk: cx.new(|cx| {
111 Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
112 }),
113 }
114 }
115}
116
117#[derive(Debug)]
118pub enum AgentThreadEntryContent {
119 Message(Message),
120 ToolCall(ToolCall),
121}
122
123#[derive(Debug)]
124pub struct ToolCall {
125 id: ToolCallId,
126 label: Entity<Markdown>,
127 icon: IconName,
128 status: ToolCallStatus,
129}
130
131#[derive(Debug)]
132pub enum ToolCallStatus {
133 WaitingForConfirmation {
134 confirmation: acp::ToolCallConfirmation,
135 respond_tx: oneshot::Sender<acp::ToolCallConfirmationOutcome>,
136 },
137 // todo! Running?
138 Allowed {
139 // todo! should this be variants in crate::ToolCallStatus instead?
140 status: acp::ToolCallStatus,
141 content: Option<Entity<Markdown>>,
142 },
143 Rejected,
144}
145
/// A [`ThreadEntryId`] that is known to refer to a tool-call entry.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ToolCallId(ThreadEntryId);

impl ToolCallId {
    /// Returns the raw numeric value of the underlying entry id.
    pub fn as_u64(&self) -> u64 {
        let ToolCallId(ThreadEntryId(raw)) = *self;
        raw
    }
}

/// Sequential identifier of an entry within a thread.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ThreadEntryId(pub u64);

impl ThreadEntryId {
    /// Advances this id by one and returns the value it had before the call
    /// (like C's `id++`).
    pub fn post_inc(&mut self) -> Self {
        let next = Self(self.0 + 1);
        mem::replace(self, next)
    }
}
166
167#[derive(Debug)]
168pub struct ThreadEntry {
169 pub id: ThreadEntryId,
170 pub content: AgentThreadEntryContent,
171}
172
173pub struct AcpThread {
174 id: ThreadId,
175 next_entry_id: ThreadEntryId,
176 entries: Vec<ThreadEntry>,
177 server: Arc<AcpServer>,
178 title: SharedString,
179 project: Entity<Project>,
180}
181
182enum AcpThreadEvent {
183 NewEntry,
184 EntryUpdated(usize),
185}
186
187impl EventEmitter<AcpThreadEvent> for AcpThread {}
188
189impl AcpThread {
190 pub fn new(
191 server: Arc<AcpServer>,
192 thread_id: ThreadId,
193 entries: Vec<AgentThreadEntryContent>,
194 project: Entity<Project>,
195 _: &mut Context<Self>,
196 ) -> Self {
197 let mut next_entry_id = ThreadEntryId(0);
198 Self {
199 title: "A new agent2 thread".into(),
200 entries: entries
201 .into_iter()
202 .map(|entry| ThreadEntry {
203 id: next_entry_id.post_inc(),
204 content: entry,
205 })
206 .collect(),
207 server,
208 id: thread_id,
209 next_entry_id,
210 project,
211 }
212 }
213
214 pub fn title(&self) -> SharedString {
215 self.title.clone()
216 }
217
218 pub fn entries(&self) -> &[ThreadEntry] {
219 &self.entries
220 }
221
222 pub fn push_entry(
223 &mut self,
224 entry: AgentThreadEntryContent,
225 cx: &mut Context<Self>,
226 ) -> ThreadEntryId {
227 let id = self.next_entry_id.post_inc();
228 self.entries.push(ThreadEntry { id, content: entry });
229 cx.emit(AcpThreadEvent::NewEntry);
230 id
231 }
232
233 pub fn push_assistant_chunk(&mut self, chunk: acp::MessageChunk, cx: &mut Context<Self>) {
234 let entries_len = self.entries.len();
235 if let Some(last_entry) = self.entries.last_mut()
236 && let AgentThreadEntryContent::Message(Message {
237 ref mut chunks,
238 role: Role::Assistant,
239 }) = last_entry.content
240 {
241 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
242
243 if let (
244 Some(MessageChunk::Text { chunk: old_chunk }),
245 acp::MessageChunk::Text { chunk: new_chunk },
246 ) = (chunks.last_mut(), &chunk)
247 {
248 old_chunk.update(cx, |old_chunk, cx| {
249 old_chunk.append(&new_chunk, cx);
250 });
251 } else {
252 chunks.push(MessageChunk::from_acp(
253 chunk,
254 self.project.read(cx).languages().clone(),
255 cx,
256 ));
257 }
258
259 return;
260 }
261
262 let chunk = MessageChunk::from_acp(chunk, self.project.read(cx).languages().clone(), cx);
263
264 self.push_entry(
265 AgentThreadEntryContent::Message(Message {
266 role: Role::Assistant,
267 chunks: vec![chunk],
268 }),
269 cx,
270 );
271 }
272
273 pub fn request_tool_call(
274 &mut self,
275 label: String,
276 icon: acp::Icon,
277 confirmation: acp::ToolCallConfirmation,
278 cx: &mut Context<Self>,
279 ) -> ToolCallRequest {
280 let (tx, rx) = oneshot::channel();
281
282 let status = ToolCallStatus::WaitingForConfirmation {
283 confirmation,
284 respond_tx: tx,
285 };
286
287 let id = self.insert_tool_call(label, status, icon, cx);
288 ToolCallRequest { id, outcome: rx }
289 }
290
291 pub fn push_tool_call(
292 &mut self,
293 label: String,
294 icon: acp::Icon,
295 cx: &mut Context<Self>,
296 ) -> ToolCallId {
297 let status = ToolCallStatus::Allowed {
298 status: acp::ToolCallStatus::Running,
299 content: None,
300 };
301
302 self.insert_tool_call(label, status, icon, cx)
303 }
304
305 fn insert_tool_call(
306 &mut self,
307 label: String,
308 status: ToolCallStatus,
309 icon: acp::Icon,
310 cx: &mut Context<Self>,
311 ) -> ToolCallId {
312 let language_registry = self.project.read(cx).languages().clone();
313
314 let entry_id = self.push_entry(
315 AgentThreadEntryContent::ToolCall(ToolCall {
316 // todo! clean up id creation
317 id: ToolCallId(ThreadEntryId(self.entries.len() as u64)),
318 label: cx.new(|cx| {
319 Markdown::new(label.into(), Some(language_registry.clone()), None, cx)
320 }),
321 icon: acp_icon_to_ui_icon(icon),
322 status,
323 }),
324 cx,
325 );
326
327 ToolCallId(entry_id)
328 }
329
330 pub fn authorize_tool_call(
331 &mut self,
332 id: ToolCallId,
333 outcome: acp::ToolCallConfirmationOutcome,
334 cx: &mut Context<Self>,
335 ) {
336 let Some(entry) = self.entry_mut(id.0) else {
337 return;
338 };
339
340 let AgentThreadEntryContent::ToolCall(call) = &mut entry.content else {
341 debug_panic!("expected ToolCall");
342 return;
343 };
344
345 let new_status = if outcome == acp::ToolCallConfirmationOutcome::Reject {
346 ToolCallStatus::Rejected
347 } else {
348 ToolCallStatus::Allowed {
349 status: acp::ToolCallStatus::Running,
350 content: None,
351 }
352 };
353
354 let curr_status = mem::replace(&mut call.status, new_status);
355
356 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
357 respond_tx.send(outcome).log_err();
358 } else {
359 debug_panic!("tried to authorize an already authorized tool call");
360 }
361
362 cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
363 }
364
365 pub fn update_tool_call(
366 &mut self,
367 id: ToolCallId,
368 new_status: acp::ToolCallStatus,
369 new_content: Option<acp::ToolCallContent>,
370 cx: &mut Context<Self>,
371 ) -> Result<()> {
372 let language_registry = self.project.read(cx).languages().clone();
373 let entry = self.entry_mut(id.0).context("Entry not found")?;
374
375 match &mut entry.content {
376 AgentThreadEntryContent::ToolCall(call) => match &mut call.status {
377 ToolCallStatus::Allowed { content, status } => {
378 *content = new_content.map(|new_content| {
379 let acp::ToolCallContent::Markdown { markdown } = new_content;
380
381 cx.new(|cx| {
382 Markdown::new(markdown.into(), Some(language_registry), None, cx)
383 })
384 });
385
386 *status = new_status;
387 }
388 ToolCallStatus::WaitingForConfirmation { .. } => {
389 anyhow::bail!("Tool call hasn't been authorized yet")
390 }
391 ToolCallStatus::Rejected => {
392 anyhow::bail!("Tool call was rejected and therefore can't be updated")
393 }
394 },
395 _ => anyhow::bail!("Entry is not a tool call"),
396 }
397
398 cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
399 Ok(())
400 }
401
402 fn entry_mut(&mut self, id: ThreadEntryId) -> Option<&mut ThreadEntry> {
403 let entry = self.entries.get_mut(id.0 as usize);
404 debug_assert!(
405 entry.is_some(),
406 "We shouldn't give out ids to entries that don't exist"
407 );
408 entry
409 }
410
411 /// Returns true if the last turn is awaiting tool authorization
412 pub fn waiting_for_tool_confirmation(&self) -> bool {
413 for entry in self.entries.iter().rev() {
414 match &entry.content {
415 AgentThreadEntryContent::ToolCall(call) => match call.status {
416 ToolCallStatus::WaitingForConfirmation { .. } => return true,
417 ToolCallStatus::Allowed { .. } | ToolCallStatus::Rejected => continue,
418 },
419 AgentThreadEntryContent::Message(_) => {
420 // Reached the beginning of the turn
421 return false;
422 }
423 }
424 }
425 false
426 }
427
428 pub fn send(&mut self, message: &str, cx: &mut Context<Self>) -> Task<Result<()>> {
429 let agent = self.server.clone();
430 let id = self.id.clone();
431 let chunk = MessageChunk::from_str(message, self.project.read(cx).languages().clone(), cx);
432 let message = Message {
433 role: Role::User,
434 chunks: vec![chunk],
435 };
436 self.push_entry(AgentThreadEntryContent::Message(message.clone()), cx);
437 let acp_message = message.into_acp(cx);
438 cx.spawn(async move |_, cx| {
439 agent.send_message(id, acp_message, cx).await?;
440 Ok(())
441 })
442 }
443}
444
445fn acp_icon_to_ui_icon(icon: acp::Icon) -> IconName {
446 match icon {
447 acp::Icon::FileSearch => IconName::FileSearch,
448 acp::Icon::Folder => IconName::Folder,
449 acp::Icon::Globe => IconName::Globe,
450 acp::Icon::Hammer => IconName::Hammer,
451 acp::Icon::LightBulb => IconName::LightBulb,
452 acp::Icon::Pencil => IconName::Pencil,
453 acp::Icon::Regex => IconName::Regex,
454 acp::Icon::Terminal => IconName::Terminal,
455 }
456}
457
458pub struct ToolCallRequest {
459 pub id: ToolCallId,
460 pub outcome: oneshot::Receiver<acp::ToolCallConfirmationOutcome>,
461}
462
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{FutureExt as _, channel::mpsc, select};
    use gpui::{AsyncApp, TestAppContext};
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use smol::stream::StreamExt as _;
    use std::{env, path::Path, process::Stdio, time::Duration};
    use util::path;

    /// Installs the globals the tests need: a test settings store, project settings and
    /// language support.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    #[gpui::test]
    async fn test_gemini_basic(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        thread
            .update(cx, |thread, cx| thread.send("Hello from Zed!", cx))
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.entries.len(), 2);
            assert!(matches!(
                thread.entries[0].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::User,
                    ..
                })
            ));
            assert!(matches!(
                thread.entries[1].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::Assistant,
                    ..
                })
            ));
        });
    }

    #[gpui::test]
    async fn test_gemini_tool_call(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/private/tmp"),
            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
        )
        .await;
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        thread
            .update(cx, |thread, cx| {
                thread.send(
                    "Read the '/private/tmp/foo' file and tell me what you see.",
                    cx,
                )
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, _cx| {
            assert!(matches!(
                &thread.entries()[1].content,
                AgentThreadEntryContent::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed { .. },
                    ..
                })
            ));

            assert!(matches!(
                thread.entries[2].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::Assistant,
                    ..
                })
            ));
        });
    }

    #[gpui::test]
    async fn test_gemini_tool_call_with_confirmation(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        let full_turn = thread.update(cx, |thread, cx| {
            thread.send(r#"Run `echo "Hello, world!"`"#, cx)
        });

        run_until_tool_call(&thread, cx).await;

        let tool_call_id = thread.read_with(cx, |thread, _cx| {
            let AgentThreadEntryContent::ToolCall(ToolCall {
                id,
                status:
                    ToolCallStatus::WaitingForConfirmation {
                        confirmation: acp::ToolCallConfirmation::Execute { root_command, .. },
                        ..
                    },
                ..
            }) = &thread.entries()[1].content
            else {
                panic!();
            };

            assert_eq!(root_command, "echo");

            *id
        });

        thread.update(cx, |thread, cx| {
            thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);

            assert!(matches!(
                &thread.entries()[1].content,
                AgentThreadEntryContent::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed { .. },
                    ..
                })
            ));
        });

        full_turn.await.unwrap();

        thread.read_with(cx, |thread, cx| {
            let AgentThreadEntryContent::ToolCall(ToolCall {
                status: ToolCallStatus::Allowed { content, .. },
                ..
            }) = &thread.entries()[1].content
            else {
                panic!();
            };

            content.as_ref().unwrap().read_with(cx, |md, _cx| {
                assert!(
                    md.source().contains("Hello, world!"),
                    r#"Expected '{}' to contain "Hello, world!""#,
                    md.source()
                );
            });
        });
    }

    /// Waits until `thread` contains at least one tool-call entry, panicking after 10s.
    async fn run_until_tool_call(thread: &Entity<AcpThread>, cx: &mut TestAppContext) {
        let (mut tx, mut rx) = mpsc::channel::<()>(1);

        let subscription = cx.update(|cx| {
            cx.subscribe(thread, move |thread, _, cx| {
                if thread
                    .read(cx)
                    .entries
                    .iter()
                    .any(|e| matches!(e.content, AgentThreadEntryContent::ToolCall(_)))
                {
                    // Every event after the first tool call also lands here. One pending
                    // signal is enough, so a full (or already closed) channel is expected
                    // and must not panic.
                    tx.try_send(()).ok();
                }
            })
        });

        select! {
            // Fully qualified: `Timer` is both a `Future` and a `Stream`, so a plain `.fuse()`
            // would be ambiguous with `smol::stream::StreamExt::fuse`.
            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
                panic!("Timeout waiting for tool call")
            }
            _ = rx.next().fuse() => {
                drop(subscription);
            }
        }
    }

    /// Spawns the gemini CLI (`node <repo>/../gemini-cli/packages/cli --acp`) in
    /// `/private/tmp` and connects to it over stdio.
    ///
    /// Forwards `GEMINI_API_KEY` from the environment when set.
    ///
    /// # Errors
    ///
    /// Returns an error if the `node` process cannot be spawned.
    pub fn gemini_acp_server(project: Entity<Project>, mut cx: AsyncApp) -> Result<Arc<AcpServer>> {
        let cli_path =
            Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini-cli/packages/cli");
        let mut command = util::command::new_smol_command("node");
        command
            .arg(cli_path)
            .arg("--acp")
            .current_dir("/private/tmp")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .kill_on_drop(true);

        if let Ok(gemini_key) = std::env::var("GEMINI_API_KEY") {
            command.env("GEMINI_API_KEY", gemini_key);
        }

        // Propagate spawn failures (e.g. `node` missing) through the `Result` instead of
        // panicking inside a fallible function.
        let child = command.spawn()?;

        Ok(AcpServer::stdio(child, project, &mut cx))
    }
}