1mod server;
2mod thread_view;
3
4use agentic_coding_protocol::{self as acp, Role};
5use anyhow::{Context as _, Result};
6use buffer_diff::BufferDiff;
7use chrono::{DateTime, Utc};
8use editor::MultiBuffer;
9use futures::channel::oneshot;
10use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
11use language::{Buffer, LanguageRegistry};
12use markdown::Markdown;
13use project::Project;
14use std::{mem, ops::Range, path::PathBuf, sync::Arc};
15use ui::{App, IconName};
16use util::{ResultExt, debug_panic};
17
18pub use server::AcpServer;
19pub use thread_view::AcpThreadView;
20
21#[derive(Debug, Clone, PartialEq, Eq, Hash)]
22pub struct ThreadId(SharedString);
23
/// Version tag attached to file contents referenced from message chunks
/// (see [`FileContent`] and [`MessageChunk::Symbol`]).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct FileVersion(u64);
26
27#[derive(Debug)]
28pub struct AgentThreadSummary {
29 pub id: ThreadId,
30 pub title: String,
31 pub created_at: DateTime<Utc>,
32}
33
34#[derive(Clone, Debug, PartialEq, Eq)]
35pub struct FileContent {
36 pub path: PathBuf,
37 pub version: FileVersion,
38 pub content: SharedString,
39}
40
41#[derive(Clone, Debug, Eq, PartialEq)]
42pub struct Message {
43 pub role: acp::Role,
44 pub chunks: Vec<MessageChunk>,
45}
46
47impl Message {
48 fn into_acp(self, cx: &App) -> acp::Message {
49 acp::Message {
50 role: self.role,
51 chunks: self
52 .chunks
53 .into_iter()
54 .map(|chunk| chunk.into_acp(cx))
55 .collect(),
56 }
57 }
58}
59
60#[derive(Clone, Debug, Eq, PartialEq)]
61pub enum MessageChunk {
62 Text {
63 chunk: Entity<Markdown>,
64 },
65 File {
66 content: FileContent,
67 },
68 Directory {
69 path: PathBuf,
70 contents: Vec<FileContent>,
71 },
72 Symbol {
73 path: PathBuf,
74 range: Range<u64>,
75 version: FileVersion,
76 name: SharedString,
77 content: SharedString,
78 },
79 Fetch {
80 url: SharedString,
81 content: SharedString,
82 },
83}
84
85impl MessageChunk {
86 pub fn from_acp(
87 chunk: acp::MessageChunk,
88 language_registry: Arc<LanguageRegistry>,
89 cx: &mut App,
90 ) -> Self {
91 match chunk {
92 acp::MessageChunk::Text { chunk } => MessageChunk::Text {
93 chunk: cx.new(|cx| Markdown::new(chunk.into(), Some(language_registry), None, cx)),
94 },
95 }
96 }
97
98 pub fn into_acp(self, cx: &App) -> acp::MessageChunk {
99 match self {
100 MessageChunk::Text { chunk } => acp::MessageChunk::Text {
101 chunk: chunk.read(cx).source().to_string(),
102 },
103 MessageChunk::File { .. } => todo!(),
104 MessageChunk::Directory { .. } => todo!(),
105 MessageChunk::Symbol { .. } => todo!(),
106 MessageChunk::Fetch { .. } => todo!(),
107 }
108 }
109
110 pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
111 MessageChunk::Text {
112 chunk: cx.new(|cx| {
113 Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
114 }),
115 }
116 }
117}
118
119#[derive(Debug)]
120pub enum AgentThreadEntryContent {
121 Message(Message),
122 ToolCall(ToolCall),
123}
124
125#[derive(Debug)]
126pub struct ToolCall {
127 id: ToolCallId,
128 label: Entity<Markdown>,
129 icon: IconName,
130 status: ToolCallStatus,
131}
132
133#[derive(Debug)]
134pub enum ToolCallStatus {
135 WaitingForConfirmation {
136 confirmation: acp::ToolCallConfirmation,
137 respond_tx: oneshot::Sender<acp::ToolCallConfirmationOutcome>,
138 },
139 Allowed {
140 status: acp::ToolCallStatus,
141 content: Option<ToolCallContent>,
142 },
143 Rejected,
144}
145
146#[derive(Debug)]
147pub enum ToolCallContent {
148 Markdown {
149 markdown: Entity<Markdown>,
150 },
151 Diff {
152 path: PathBuf,
153 diff: Entity<BufferDiff>,
154 buffer: Entity<MultiBuffer>,
155 _task: Task<Result<()>>,
156 },
157}
158
/// A [`ThreadEntryId`] that is known to refer to a tool-call entry.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct ToolCallId(ThreadEntryId);

impl ToolCallId {
    /// Returns the raw numeric value of the underlying entry id.
    pub fn as_u64(&self) -> u64 {
        let ToolCallId(ThreadEntryId(raw)) = *self;
        raw
    }
}

/// Identifier of an entry within a thread; handed out sequentially.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct ThreadEntryId(pub u64);

impl ThreadEntryId {
    /// Returns the current id and advances `self` to the following one,
    /// like a postfix `id++`.
    pub fn post_inc(&mut self) -> Self {
        let following = ThreadEntryId(self.0 + 1);
        mem::replace(self, following)
    }
}
179
180#[derive(Debug)]
181pub struct ThreadEntry {
182 pub id: ThreadEntryId,
183 pub content: AgentThreadEntryContent,
184}
185
186pub struct AcpThread {
187 id: ThreadId,
188 next_entry_id: ThreadEntryId,
189 entries: Vec<ThreadEntry>,
190 server: Arc<AcpServer>,
191 title: SharedString,
192 project: Entity<Project>,
193}
194
195enum AcpThreadEvent {
196 NewEntry,
197 EntryUpdated(usize),
198}
199
200impl EventEmitter<AcpThreadEvent> for AcpThread {}
201
202impl AcpThread {
203 pub fn new(
204 server: Arc<AcpServer>,
205 thread_id: ThreadId,
206 entries: Vec<AgentThreadEntryContent>,
207 project: Entity<Project>,
208 _: &mut Context<Self>,
209 ) -> Self {
210 let mut next_entry_id = ThreadEntryId(0);
211 Self {
212 title: "A new agent2 thread".into(),
213 entries: entries
214 .into_iter()
215 .map(|entry| ThreadEntry {
216 id: next_entry_id.post_inc(),
217 content: entry,
218 })
219 .collect(),
220 server,
221 id: thread_id,
222 next_entry_id,
223 project,
224 }
225 }
226
227 pub fn title(&self) -> SharedString {
228 self.title.clone()
229 }
230
231 pub fn entries(&self) -> &[ThreadEntry] {
232 &self.entries
233 }
234
235 pub fn push_entry(
236 &mut self,
237 entry: AgentThreadEntryContent,
238 cx: &mut Context<Self>,
239 ) -> ThreadEntryId {
240 let id = self.next_entry_id.post_inc();
241 self.entries.push(ThreadEntry { id, content: entry });
242 cx.emit(AcpThreadEvent::NewEntry);
243 id
244 }
245
246 pub fn push_assistant_chunk(&mut self, chunk: acp::MessageChunk, cx: &mut Context<Self>) {
247 let entries_len = self.entries.len();
248 if let Some(last_entry) = self.entries.last_mut()
249 && let AgentThreadEntryContent::Message(Message {
250 ref mut chunks,
251 role: Role::Assistant,
252 }) = last_entry.content
253 {
254 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
255
256 if let (
257 Some(MessageChunk::Text { chunk: old_chunk }),
258 acp::MessageChunk::Text { chunk: new_chunk },
259 ) = (chunks.last_mut(), &chunk)
260 {
261 old_chunk.update(cx, |old_chunk, cx| {
262 old_chunk.append(&new_chunk, cx);
263 });
264 } else {
265 chunks.push(MessageChunk::from_acp(
266 chunk,
267 self.project.read(cx).languages().clone(),
268 cx,
269 ));
270 }
271
272 return;
273 }
274
275 let chunk = MessageChunk::from_acp(chunk, self.project.read(cx).languages().clone(), cx);
276
277 self.push_entry(
278 AgentThreadEntryContent::Message(Message {
279 role: Role::Assistant,
280 chunks: vec![chunk],
281 }),
282 cx,
283 );
284 }
285
286 pub fn request_tool_call(
287 &mut self,
288 label: String,
289 icon: acp::Icon,
290 confirmation: acp::ToolCallConfirmation,
291 cx: &mut Context<Self>,
292 ) -> ToolCallRequest {
293 let (tx, rx) = oneshot::channel();
294
295 let status = ToolCallStatus::WaitingForConfirmation {
296 confirmation,
297 respond_tx: tx,
298 };
299
300 let id = self.insert_tool_call(label, status, icon, cx);
301 ToolCallRequest { id, outcome: rx }
302 }
303
304 pub fn push_tool_call(
305 &mut self,
306 label: String,
307 icon: acp::Icon,
308 cx: &mut Context<Self>,
309 ) -> ToolCallId {
310 let status = ToolCallStatus::Allowed {
311 status: acp::ToolCallStatus::Running,
312 content: None,
313 };
314
315 self.insert_tool_call(label, status, icon, cx)
316 }
317
318 fn insert_tool_call(
319 &mut self,
320 label: String,
321 status: ToolCallStatus,
322 icon: acp::Icon,
323 cx: &mut Context<Self>,
324 ) -> ToolCallId {
325 let language_registry = self.project.read(cx).languages().clone();
326
327 let entry_id = self.push_entry(
328 AgentThreadEntryContent::ToolCall(ToolCall {
329 // todo! clean up id creation
330 id: ToolCallId(ThreadEntryId(self.entries.len() as u64)),
331 label: cx.new(|cx| {
332 Markdown::new(label.into(), Some(language_registry.clone()), None, cx)
333 }),
334 icon: acp_icon_to_ui_icon(icon),
335 status,
336 }),
337 cx,
338 );
339
340 ToolCallId(entry_id)
341 }
342
343 pub fn authorize_tool_call(
344 &mut self,
345 id: ToolCallId,
346 outcome: acp::ToolCallConfirmationOutcome,
347 cx: &mut Context<Self>,
348 ) {
349 let Some(entry) = self.entry_mut(id.0) else {
350 return;
351 };
352
353 let AgentThreadEntryContent::ToolCall(call) = &mut entry.content else {
354 debug_panic!("expected ToolCall");
355 return;
356 };
357
358 let new_status = if outcome == acp::ToolCallConfirmationOutcome::Reject {
359 ToolCallStatus::Rejected
360 } else {
361 ToolCallStatus::Allowed {
362 status: acp::ToolCallStatus::Running,
363 content: None,
364 }
365 };
366
367 let curr_status = mem::replace(&mut call.status, new_status);
368
369 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
370 respond_tx.send(outcome).log_err();
371 } else {
372 debug_panic!("tried to authorize an already authorized tool call");
373 }
374
375 cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
376 }
377
378 pub fn update_tool_call(
379 &mut self,
380 id: ToolCallId,
381 new_status: acp::ToolCallStatus,
382 new_content: Option<acp::ToolCallContent>,
383 cx: &mut Context<Self>,
384 ) -> Result<()> {
385 let language_registry = self.project.read(cx).languages().clone();
386 let entry = self.entry_mut(id.0).context("Entry not found")?;
387
388 match &mut entry.content {
389 AgentThreadEntryContent::ToolCall(call) => match &mut call.status {
390 ToolCallStatus::Allowed { content, status } => {
391 *content = new_content.map(|new_content| match new_content {
392 acp::ToolCallContent::Markdown { markdown } => ToolCallContent::Markdown {
393 markdown: cx.new(|cx| {
394 Markdown::new(
395 markdown.into(),
396 Some(language_registry.clone()),
397 None,
398 cx,
399 )
400 }),
401 },
402 acp::ToolCallContent::Diff {
403 path,
404 old_text,
405 new_text,
406 } => {
407 let buffer = cx.new(|cx| Buffer::local(new_text, cx));
408 let text_snapshot = buffer.read(cx).text_snapshot();
409 let buffer_diff = cx.new(|cx| BufferDiff::new(&text_snapshot, cx));
410
411 let multibuffer = cx.new(|cx| {
412 let mut multibuffer = MultiBuffer::singleton(buffer.clone(), cx);
413 multibuffer.add_diff(buffer_diff.clone(), cx);
414 multibuffer
415 });
416
417 ToolCallContent::Diff {
418 path: path.clone(),
419 diff: buffer_diff.clone(),
420 buffer: multibuffer,
421 _task: cx.spawn(async move |_this, cx| {
422 let diff_snapshot = BufferDiff::update_diff(
423 buffer_diff.clone(),
424 text_snapshot.clone(),
425 old_text.map(|o| o.into()),
426 true,
427 true,
428 None,
429 Some(language_registry.clone()),
430 cx,
431 )
432 .await?;
433
434 buffer_diff.update(cx, |diff, cx| {
435 diff.set_snapshot(diff_snapshot, &text_snapshot, cx)
436 })?;
437
438 if let Some(language) = language_registry
439 .language_for_file_path(&path)
440 .await
441 .log_err()
442 {
443 buffer.update(cx, |buffer, cx| {
444 buffer.set_language(Some(language), cx)
445 })?;
446 }
447
448 anyhow::Ok(())
449 }),
450 }
451 }
452 });
453 *status = new_status;
454 }
455 ToolCallStatus::WaitingForConfirmation { .. } => {
456 anyhow::bail!("Tool call hasn't been authorized yet")
457 }
458 ToolCallStatus::Rejected => {
459 anyhow::bail!("Tool call was rejected and therefore can't be updated")
460 }
461 },
462 _ => anyhow::bail!("Entry is not a tool call"),
463 }
464
465 cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
466 Ok(())
467 }
468
469 fn entry_mut(&mut self, id: ThreadEntryId) -> Option<&mut ThreadEntry> {
470 let entry = self.entries.get_mut(id.0 as usize);
471 debug_assert!(
472 entry.is_some(),
473 "We shouldn't give out ids to entries that don't exist"
474 );
475 entry
476 }
477
478 /// Returns true if the last turn is awaiting tool authorization
479 pub fn waiting_for_tool_confirmation(&self) -> bool {
480 for entry in self.entries.iter().rev() {
481 match &entry.content {
482 AgentThreadEntryContent::ToolCall(call) => match call.status {
483 ToolCallStatus::WaitingForConfirmation { .. } => return true,
484 ToolCallStatus::Allowed { .. } | ToolCallStatus::Rejected => continue,
485 },
486 AgentThreadEntryContent::Message(_) => {
487 // Reached the beginning of the turn
488 return false;
489 }
490 }
491 }
492 false
493 }
494
495 pub fn send(&mut self, message: &str, cx: &mut Context<Self>) -> Task<Result<()>> {
496 let agent = self.server.clone();
497 let id = self.id.clone();
498 let chunk = MessageChunk::from_str(message, self.project.read(cx).languages().clone(), cx);
499 let message = Message {
500 role: Role::User,
501 chunks: vec![chunk],
502 };
503 self.push_entry(AgentThreadEntryContent::Message(message.clone()), cx);
504 let acp_message = message.into_acp(cx);
505 cx.spawn(async move |_, cx| {
506 agent.send_message(id, acp_message, cx).await?;
507 Ok(())
508 })
509 }
510}
511
512fn acp_icon_to_ui_icon(icon: acp::Icon) -> IconName {
513 match icon {
514 acp::Icon::FileSearch => IconName::FileSearch,
515 acp::Icon::Folder => IconName::Folder,
516 acp::Icon::Globe => IconName::Globe,
517 acp::Icon::Hammer => IconName::Hammer,
518 acp::Icon::LightBulb => IconName::LightBulb,
519 acp::Icon::Pencil => IconName::Pencil,
520 acp::Icon::Regex => IconName::Regex,
521 acp::Icon::Terminal => IconName::Terminal,
522 }
523}
524
525pub struct ToolCallRequest {
526 pub id: ToolCallId,
527 pub outcome: oneshot::Receiver<acp::ToolCallConfirmationOutcome>,
528}
529
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{FutureExt as _, channel::mpsc, select};
    use gpui::{AsyncApp, TestAppContext};
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use smol::stream::StreamExt as _;
    use std::{env, path::Path, process::Stdio, time::Duration};
    use util::path;

    /// Installs the logger, test settings store, project settings and languages.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    #[gpui::test]
    async fn test_gemini_basic(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        thread
            .update(cx, |thread, cx| thread.send("Hello from Zed!", cx))
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.entries.len(), 2);
            assert!(matches!(
                thread.entries[0].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::User,
                    ..
                })
            ));
            assert!(matches!(
                thread.entries[1].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::Assistant,
                    ..
                })
            ));
        });
    }

    #[gpui::test]
    async fn test_gemini_tool_call(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/private/tmp"),
            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
        )
        .await;
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        thread
            .update(cx, |thread, cx| {
                thread.send(
                    "Read the '/private/tmp/foo' file and tell me what you see.",
                    cx,
                )
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, _cx| {
            assert!(matches!(
                &thread.entries()[1].content,
                AgentThreadEntryContent::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed { .. },
                    ..
                })
            ));

            assert!(matches!(
                thread.entries[2].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::Assistant,
                    ..
                })
            ));
        });
    }

    #[gpui::test]
    async fn test_gemini_tool_call_with_confirmation(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        let full_turn = thread.update(cx, |thread, cx| {
            thread.send(r#"Run `echo "Hello, world!"`"#, cx)
        });

        run_until_tool_call(&thread, cx).await;

        let tool_call_id = thread.read_with(cx, |thread, _cx| {
            let AgentThreadEntryContent::ToolCall(ToolCall {
                id,
                status:
                    ToolCallStatus::WaitingForConfirmation {
                        confirmation: acp::ToolCallConfirmation::Execute { root_command, .. },
                        ..
                    },
                ..
            }) = &thread.entries()[1].content
            else {
                panic!();
            };

            assert_eq!(root_command, "echo");

            *id
        });

        thread.update(cx, |thread, cx| {
            thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);

            assert!(matches!(
                &thread.entries()[1].content,
                AgentThreadEntryContent::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed { .. },
                    ..
                })
            ));
        });

        full_turn.await.unwrap();

        thread.read_with(cx, |thread, cx| {
            let AgentThreadEntryContent::ToolCall(ToolCall {
                status:
                    ToolCallStatus::Allowed {
                        content: Some(ToolCallContent::Markdown { markdown }),
                        ..
                    },
                ..
            }) = &thread.entries()[1].content
            else {
                panic!();
            };

            markdown.read_with(cx, |md, _cx| {
                assert!(
                    md.source().contains("Hello, world!"),
                    r#"Expected '{}' to contain "Hello, world!""#,
                    md.source()
                );
            });
        });
    }

    /// Waits until `thread` contains at least one tool-call entry, panicking
    /// if none shows up within 10 seconds.
    async fn run_until_tool_call(thread: &Entity<AcpThread>, cx: &mut TestAppContext) {
        let (mut tx, mut rx) = mpsc::channel::<()>(1);

        let subscription = cx.update(|cx| {
            cx.subscribe(thread, move |thread, _, cx| {
                if thread
                    .read(cx)
                    .entries
                    .iter()
                    .any(|e| matches!(e.content, AgentThreadEntryContent::ToolCall(_)))
                {
                    // More events can arrive before the receiver is polled. Once
                    // the channel is full it already holds the wake-up we need,
                    // so a `Full` error is harmless and must not panic here.
                    tx.try_send(()).ok();
                }
            })
        });

        select! {
            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
                panic!("Timeout waiting for tool call")
            }
            _ = rx.next().fuse() => {
                drop(subscription);
            }
        }
    }

    /// Spawns the Gemini CLI (from a sibling `gemini-cli` checkout) in ACP mode,
    /// forwarding `GEMINI_API_KEY` when set, and wraps it in an [`AcpServer`].
    pub fn gemini_acp_server(project: Entity<Project>, cx: AsyncApp) -> Result<Arc<AcpServer>> {
        let cli_path =
            Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini-cli/packages/cli");
        let mut command = util::command::new_smol_command("node");
        command
            .arg(cli_path)
            .arg("--acp")
            .current_dir("/private/tmp")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .kill_on_drop(true);

        if let Ok(gemini_key) = env::var("GEMINI_API_KEY") {
            command.env("GEMINI_API_KEY", gemini_key);
        }

        let child = command.spawn().unwrap();

        cx.update(|cx| AcpServer::stdio(child, project, cx))
    }
}