1mod server;
2mod thread_view;
3
4use agentic_coding_protocol::{self as acp, Role};
5use anyhow::{Context as _, Result};
6use buffer_diff::BufferDiff;
7use chrono::{DateTime, Utc};
8use editor::MultiBuffer;
9use futures::channel::oneshot;
10use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
11use language::{Buffer, LanguageRegistry};
12use markdown::Markdown;
13use project::Project;
14use std::{mem, ops::Range, path::PathBuf, sync::Arc};
15use ui::{App, IconName};
16use util::{ResultExt, debug_panic};
17
18pub use server::AcpServer;
19pub use thread_view::AcpThreadView;
20
/// Identifier of an agent thread; a newtype around [`SharedString`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ThreadId(SharedString);
23
/// Numeric version tag carried by [`FileContent`] and by
/// [`MessageChunk::Symbol`].
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct FileVersion(u64);
26
/// Summary information about a thread: its id, title and creation time.
#[derive(Debug)]
pub struct AgentThreadSummary {
    /// Id of the summarized thread.
    pub id: ThreadId,
    /// Title of the thread.
    pub title: String,
    /// Creation timestamp, in UTC.
    pub created_at: DateTime<Utc>,
}
33
/// The content of a file together with its path and version.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct FileContent {
    /// Path of the file.
    pub path: PathBuf,
    /// Version tag associated with `content`.
    pub version: FileVersion,
    /// Text of the file.
    pub content: SharedString,
}
40
/// A single message in a thread: who authored it and the chunks it is made of.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Message {
    /// Role the message is attributed to (e.g. `Role::User`, `Role::Assistant`).
    pub role: acp::Role,
    /// Ordered pieces of content that make up the message.
    pub chunks: Vec<MessageChunk>,
}
46
47impl Message {
48 fn into_acp(self, cx: &App) -> acp::Message {
49 acp::Message {
50 role: self.role,
51 chunks: self
52 .chunks
53 .into_iter()
54 .map(|chunk| chunk.into_acp(cx))
55 .collect(),
56 }
57 }
58}
59
/// One piece of a [`Message`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MessageChunk {
    /// Text, held as a [`Markdown`] entity.
    Text {
        chunk: Entity<Markdown>,
    },
    /// The content of a single file.
    File {
        content: FileContent,
    },
    /// A directory path together with a list of file contents.
    Directory {
        path: PathBuf,
        contents: Vec<FileContent>,
    },
    /// A named range within a file, with the file's version and the text
    /// content attached.
    Symbol {
        path: PathBuf,
        range: Range<u64>,
        version: FileVersion,
        name: SharedString,
        content: SharedString,
    },
    /// A URL and associated content (presumably what was fetched from `url` —
    /// confirm against the producer of this variant).
    Fetch {
        url: SharedString,
        content: SharedString,
    },
}
84
85impl MessageChunk {
86 pub fn from_acp(
87 chunk: acp::MessageChunk,
88 language_registry: Arc<LanguageRegistry>,
89 cx: &mut App,
90 ) -> Self {
91 match chunk {
92 acp::MessageChunk::Text { chunk } => MessageChunk::Text {
93 chunk: cx.new(|cx| Markdown::new(chunk.into(), Some(language_registry), None, cx)),
94 },
95 }
96 }
97
98 pub fn into_acp(self, cx: &App) -> acp::MessageChunk {
99 match self {
100 MessageChunk::Text { chunk } => acp::MessageChunk::Text {
101 chunk: chunk.read(cx).source().to_string(),
102 },
103 MessageChunk::File { .. } => todo!(),
104 MessageChunk::Directory { .. } => todo!(),
105 MessageChunk::Symbol { .. } => todo!(),
106 MessageChunk::Fetch { .. } => todo!(),
107 }
108 }
109
110 pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
111 MessageChunk::Text {
112 chunk: cx.new(|cx| {
113 Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
114 }),
115 }
116 }
117}
118
/// What a [`ThreadEntry`] holds: either a message or a tool call.
#[derive(Debug)]
pub enum AgentThreadEntryContent {
    /// A user or assistant message.
    Message(Message),
    /// A tool call and its current status.
    ToolCall(ToolCall),
}
124
/// A tool call recorded in a thread.
#[derive(Debug)]
pub struct ToolCall {
    /// Id of this tool call (wraps the id of the entry that holds it).
    id: ToolCallId,
    /// Label of the tool call, held as markdown.
    label: Entity<Markdown>,
    /// Icon associated with the tool call.
    icon: IconName,
    /// Current authorization/execution status.
    status: ToolCallStatus,
}
132
/// Authorization and execution state of a [`ToolCall`].
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The call has not been authorized yet.
    WaitingForConfirmation {
        /// The confirmation request received for this call.
        confirmation: acp::ToolCallConfirmation,
        /// Channel over which the confirmation outcome is delivered once the
        /// call is authorized or rejected.
        respond_tx: oneshot::Sender<acp::ToolCallConfirmationOutcome>,
    },
    // todo! Running?
    /// The call is allowed to run (or never required confirmation).
    Allowed {
        // todo! should this be variants in crate::ToolCallStatus instead?
        /// Protocol-level status of the call.
        status: acp::ToolCallStatus,
        /// Content attached to the call, if any.
        content: Option<ToolCallContent>,
    },
    /// The call was rejected.
    Rejected,
}
147
/// Content attached to an allowed [`ToolCall`].
#[derive(Debug)]
pub enum ToolCallContent {
    /// Markdown content.
    Markdown {
        markdown: Entity<Markdown>,
    },
    /// A diff for the file at `path`.
    Diff {
        /// Path of the file the diff applies to.
        path: PathBuf,
        /// The diff entity.
        diff: Entity<BufferDiff>,
        /// Multibuffer the diff is attached to.
        buffer: Entity<MultiBuffer>,
        /// Task stored alongside the diff (presumably held so the background
        /// work is not dropped while this content exists — confirm against
        /// gpui `Task` semantics).
        _task: Task<Result<()>>,
    },
}
160
/// A [`ThreadEntryId`] that is known to refer to a [`ToolCall`] entry.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ToolCallId(ThreadEntryId);
164
165impl ToolCallId {
166 pub fn as_u64(&self) -> u64 {
167 self.0.0
168 }
169}
170
/// Numeric identifier of a thread entry.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ThreadEntryId(pub u64);

impl ThreadEntryId {
    /// Advances this id by one and returns the value it held before the
    /// increment (post-increment semantics).
    pub fn post_inc(&mut self) -> Self {
        let next = Self(self.0 + 1);
        mem::replace(self, next)
    }
}
181
/// An entry of a thread: its id plus its content.
#[derive(Debug)]
pub struct ThreadEntry {
    /// Id assigned to this entry when it was added to the thread.
    pub id: ThreadEntryId,
    /// The message or tool call held by this entry.
    pub content: AgentThreadEntryContent,
}
187
/// A conversation thread backed by an [`AcpServer`].
pub struct AcpThread {
    /// Id of this thread; passed to the server when sending messages.
    id: ThreadId,
    /// Id that will be assigned to the next pushed entry.
    next_entry_id: ThreadEntryId,
    /// Entries of the thread, in insertion order.
    entries: Vec<ThreadEntry>,
    /// Server that messages are sent to.
    server: Arc<AcpServer>,
    /// Title of the thread.
    title: SharedString,
    /// Project whose language registry is used when creating markdown and
    /// diff content.
    project: Entity<Project>,
}
196
/// Events emitted by [`AcpThread`].
enum AcpThreadEvent {
    /// A new entry was appended to the thread.
    NewEntry,
    /// The entry at the given index of the thread's entry list changed.
    EntryUpdated(usize),
}

// Lets `AcpThread` emit `AcpThreadEvent`s through its context.
impl EventEmitter<AcpThreadEvent> for AcpThread {}
203
204impl AcpThread {
205 pub fn new(
206 server: Arc<AcpServer>,
207 thread_id: ThreadId,
208 entries: Vec<AgentThreadEntryContent>,
209 project: Entity<Project>,
210 _: &mut Context<Self>,
211 ) -> Self {
212 let mut next_entry_id = ThreadEntryId(0);
213 Self {
214 title: "A new agent2 thread".into(),
215 entries: entries
216 .into_iter()
217 .map(|entry| ThreadEntry {
218 id: next_entry_id.post_inc(),
219 content: entry,
220 })
221 .collect(),
222 server,
223 id: thread_id,
224 next_entry_id,
225 project,
226 }
227 }
228
229 pub fn title(&self) -> SharedString {
230 self.title.clone()
231 }
232
233 pub fn entries(&self) -> &[ThreadEntry] {
234 &self.entries
235 }
236
237 pub fn push_entry(
238 &mut self,
239 entry: AgentThreadEntryContent,
240 cx: &mut Context<Self>,
241 ) -> ThreadEntryId {
242 let id = self.next_entry_id.post_inc();
243 self.entries.push(ThreadEntry { id, content: entry });
244 cx.emit(AcpThreadEvent::NewEntry);
245 id
246 }
247
248 pub fn push_assistant_chunk(&mut self, chunk: acp::MessageChunk, cx: &mut Context<Self>) {
249 let entries_len = self.entries.len();
250 if let Some(last_entry) = self.entries.last_mut()
251 && let AgentThreadEntryContent::Message(Message {
252 ref mut chunks,
253 role: Role::Assistant,
254 }) = last_entry.content
255 {
256 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
257
258 if let (
259 Some(MessageChunk::Text { chunk: old_chunk }),
260 acp::MessageChunk::Text { chunk: new_chunk },
261 ) = (chunks.last_mut(), &chunk)
262 {
263 old_chunk.update(cx, |old_chunk, cx| {
264 old_chunk.append(&new_chunk, cx);
265 });
266 } else {
267 chunks.push(MessageChunk::from_acp(
268 chunk,
269 self.project.read(cx).languages().clone(),
270 cx,
271 ));
272 }
273
274 return;
275 }
276
277 let chunk = MessageChunk::from_acp(chunk, self.project.read(cx).languages().clone(), cx);
278
279 self.push_entry(
280 AgentThreadEntryContent::Message(Message {
281 role: Role::Assistant,
282 chunks: vec![chunk],
283 }),
284 cx,
285 );
286 }
287
288 pub fn request_tool_call(
289 &mut self,
290 label: String,
291 icon: acp::Icon,
292 confirmation: acp::ToolCallConfirmation,
293 cx: &mut Context<Self>,
294 ) -> ToolCallRequest {
295 let (tx, rx) = oneshot::channel();
296
297 let status = ToolCallStatus::WaitingForConfirmation {
298 confirmation,
299 respond_tx: tx,
300 };
301
302 let id = self.insert_tool_call(label, status, icon, cx);
303 ToolCallRequest { id, outcome: rx }
304 }
305
306 pub fn push_tool_call(
307 &mut self,
308 label: String,
309 icon: acp::Icon,
310 cx: &mut Context<Self>,
311 ) -> ToolCallId {
312 let status = ToolCallStatus::Allowed {
313 status: acp::ToolCallStatus::Running,
314 content: None,
315 };
316
317 self.insert_tool_call(label, status, icon, cx)
318 }
319
320 fn insert_tool_call(
321 &mut self,
322 label: String,
323 status: ToolCallStatus,
324 icon: acp::Icon,
325 cx: &mut Context<Self>,
326 ) -> ToolCallId {
327 let language_registry = self.project.read(cx).languages().clone();
328
329 let entry_id = self.push_entry(
330 AgentThreadEntryContent::ToolCall(ToolCall {
331 // todo! clean up id creation
332 id: ToolCallId(ThreadEntryId(self.entries.len() as u64)),
333 label: cx.new(|cx| {
334 Markdown::new(label.into(), Some(language_registry.clone()), None, cx)
335 }),
336 icon: acp_icon_to_ui_icon(icon),
337 status,
338 }),
339 cx,
340 );
341
342 ToolCallId(entry_id)
343 }
344
345 pub fn authorize_tool_call(
346 &mut self,
347 id: ToolCallId,
348 outcome: acp::ToolCallConfirmationOutcome,
349 cx: &mut Context<Self>,
350 ) {
351 let Some(entry) = self.entry_mut(id.0) else {
352 return;
353 };
354
355 let AgentThreadEntryContent::ToolCall(call) = &mut entry.content else {
356 debug_panic!("expected ToolCall");
357 return;
358 };
359
360 let new_status = if outcome == acp::ToolCallConfirmationOutcome::Reject {
361 ToolCallStatus::Rejected
362 } else {
363 ToolCallStatus::Allowed {
364 status: acp::ToolCallStatus::Running,
365 content: None,
366 }
367 };
368
369 let curr_status = mem::replace(&mut call.status, new_status);
370
371 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
372 respond_tx.send(outcome).log_err();
373 } else {
374 debug_panic!("tried to authorize an already authorized tool call");
375 }
376
377 cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
378 }
379
380 pub fn update_tool_call(
381 &mut self,
382 id: ToolCallId,
383 new_status: acp::ToolCallStatus,
384 new_content: Option<acp::ToolCallContent>,
385 cx: &mut Context<Self>,
386 ) -> Result<()> {
387 let language_registry = self.project.read(cx).languages().clone();
388 let entry = self.entry_mut(id.0).context("Entry not found")?;
389
390 match &mut entry.content {
391 AgentThreadEntryContent::ToolCall(call) => match &mut call.status {
392 ToolCallStatus::Allowed { content, status } => {
393 *content = new_content.map(|new_content| match new_content {
394 acp::ToolCallContent::Markdown { markdown } => ToolCallContent::Markdown {
395 markdown: cx.new(|cx| {
396 Markdown::new(
397 markdown.into(),
398 Some(language_registry.clone()),
399 None,
400 cx,
401 )
402 }),
403 },
404 acp::ToolCallContent::Diff {
405 path,
406 old_text,
407 new_text,
408 } => {
409 let buffer = cx.new(|cx| Buffer::local(new_text, cx));
410 let text_snapshot = buffer.read(cx).text_snapshot();
411 let buffer_diff = cx.new(|cx| BufferDiff::new(&text_snapshot, cx));
412
413 let multibuffer = cx.new(|cx| {
414 let mut multibuffer = MultiBuffer::singleton(buffer.clone(), cx);
415 multibuffer.add_diff(buffer_diff.clone(), cx);
416 multibuffer
417 });
418
419 ToolCallContent::Diff {
420 path: path.clone(),
421 diff: buffer_diff.clone(),
422 buffer: multibuffer,
423 _task: cx.spawn(async move |_this, cx| {
424 let diff_snapshot = BufferDiff::update_diff(
425 buffer_diff.clone(),
426 text_snapshot.clone(),
427 old_text.map(|o| o.into()),
428 true,
429 true,
430 None,
431 Some(language_registry.clone()),
432 cx,
433 )
434 .await?;
435
436 buffer_diff.update(cx, |diff, cx| {
437 diff.set_snapshot(diff_snapshot, &text_snapshot, cx)
438 })?;
439
440 if let Some(language) = language_registry
441 .language_for_file_path(&path)
442 .await
443 .log_err()
444 {
445 buffer.update(cx, |buffer, cx| {
446 buffer.set_language(Some(language), cx)
447 })?;
448 }
449
450 anyhow::Ok(())
451 }),
452 }
453 }
454 });
455 *status = new_status;
456 }
457 ToolCallStatus::WaitingForConfirmation { .. } => {
458 anyhow::bail!("Tool call hasn't been authorized yet")
459 }
460 ToolCallStatus::Rejected => {
461 anyhow::bail!("Tool call was rejected and therefore can't be updated")
462 }
463 },
464 _ => anyhow::bail!("Entry is not a tool call"),
465 }
466
467 cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
468 Ok(())
469 }
470
471 fn entry_mut(&mut self, id: ThreadEntryId) -> Option<&mut ThreadEntry> {
472 let entry = self.entries.get_mut(id.0 as usize);
473 debug_assert!(
474 entry.is_some(),
475 "We shouldn't give out ids to entries that don't exist"
476 );
477 entry
478 }
479
480 /// Returns true if the last turn is awaiting tool authorization
481 pub fn waiting_for_tool_confirmation(&self) -> bool {
482 for entry in self.entries.iter().rev() {
483 match &entry.content {
484 AgentThreadEntryContent::ToolCall(call) => match call.status {
485 ToolCallStatus::WaitingForConfirmation { .. } => return true,
486 ToolCallStatus::Allowed { .. } | ToolCallStatus::Rejected => continue,
487 },
488 AgentThreadEntryContent::Message(_) => {
489 // Reached the beginning of the turn
490 return false;
491 }
492 }
493 }
494 false
495 }
496
497 pub fn send(&mut self, message: &str, cx: &mut Context<Self>) -> Task<Result<()>> {
498 let agent = self.server.clone();
499 let id = self.id.clone();
500 let chunk = MessageChunk::from_str(message, self.project.read(cx).languages().clone(), cx);
501 let message = Message {
502 role: Role::User,
503 chunks: vec![chunk],
504 };
505 self.push_entry(AgentThreadEntryContent::Message(message.clone()), cx);
506 let acp_message = message.into_acp(cx);
507 cx.spawn(async move |_, cx| {
508 agent.send_message(id, acp_message, cx).await?;
509 Ok(())
510 })
511 }
512}
513
514fn acp_icon_to_ui_icon(icon: acp::Icon) -> IconName {
515 match icon {
516 acp::Icon::FileSearch => IconName::FileSearch,
517 acp::Icon::Folder => IconName::Folder,
518 acp::Icon::Globe => IconName::Globe,
519 acp::Icon::Hammer => IconName::Hammer,
520 acp::Icon::LightBulb => IconName::LightBulb,
521 acp::Icon::Pencil => IconName::Pencil,
522 acp::Icon::Regex => IconName::Regex,
523 acp::Icon::Terminal => IconName::Terminal,
524 }
525}
526
/// Handle returned by `AcpThread::request_tool_call` for a tool call that is
/// waiting for confirmation.
pub struct ToolCallRequest {
    /// Id of the tool call that was inserted.
    pub id: ToolCallId,
    /// Receives the outcome once the tool call is authorized or rejected.
    pub outcome: oneshot::Receiver<acp::ToolCallConfirmationOutcome>,
}
531
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{FutureExt as _, channel::mpsc, select};
    use gpui::{AsyncApp, TestAppContext};
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use smol::stream::StreamExt as _;
    use std::{env, path::Path, process::Stdio, time::Duration};
    use util::path;

    /// Installs the globals the tests rely on: logging, a test settings
    /// store, project settings and the language subsystem.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    /// Sending a plain message yields a user entry followed by an assistant
    /// entry.
    #[gpui::test]
    async fn test_gemini_basic(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        thread
            .update(cx, |thread, cx| thread.send("Hello from Zed!", cx))
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.entries.len(), 2);
            assert!(matches!(
                thread.entries[0].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::User,
                    ..
                })
            ));
            assert!(matches!(
                thread.entries[1].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::Assistant,
                    ..
                })
            ));
        });
    }

    /// Asking the agent to read a file yields an allowed tool call followed
    /// by an assistant message.
    #[gpui::test]
    async fn test_gemini_tool_call(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/private/tmp"),
            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
        )
        .await;
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        thread
            .update(cx, |thread, cx| {
                thread.send(
                    "Read the '/private/tmp/foo' file and tell me what you see.",
                    cx,
                )
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, _cx| {
            assert!(matches!(
                &thread.entries()[1].content,
                AgentThreadEntryContent::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed { .. },
                    ..
                })
            ));

            assert!(matches!(
                thread.entries[2].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::Assistant,
                    ..
                })
            ));
        });
    }

    /// A command execution first waits for confirmation; once allowed, the
    /// turn completes and the tool call carries the command output.
    #[gpui::test]
    async fn test_gemini_tool_call_with_confirmation(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        let full_turn = thread.update(cx, |thread, cx| {
            thread.send(r#"Run `echo "Hello, world!"`"#, cx)
        });

        run_until_tool_call(&thread, cx).await;

        let tool_call_id = thread.read_with(cx, |thread, _cx| {
            let AgentThreadEntryContent::ToolCall(ToolCall {
                id,
                status:
                    ToolCallStatus::WaitingForConfirmation {
                        confirmation: acp::ToolCallConfirmation::Execute { root_command, .. },
                        ..
                    },
                ..
            }) = &thread.entries()[1].content
            else {
                panic!("expected a tool call waiting for an execute confirmation");
            };

            assert_eq!(root_command, "echo");

            *id
        });

        thread.update(cx, |thread, cx| {
            thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);

            assert!(matches!(
                &thread.entries()[1].content,
                AgentThreadEntryContent::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed { .. },
                    ..
                })
            ));
        });

        full_turn.await.unwrap();

        thread.read_with(cx, |thread, cx| {
            let AgentThreadEntryContent::ToolCall(ToolCall {
                status:
                    ToolCallStatus::Allowed {
                        content: Some(ToolCallContent::Markdown { markdown }),
                        ..
                    },
                ..
            }) = &thread.entries()[1].content
            else {
                panic!("expected an allowed tool call with markdown content");
            };

            markdown.read_with(cx, |md, _cx| {
                assert!(
                    md.source().contains("Hello, world!"),
                    r#"Expected '{}' to contain "Hello, world!""#,
                    md.source()
                );
            });
        });
    }

    /// Waits until an event arrives while the thread contains a tool call
    /// entry, panicking after ten seconds.
    async fn run_until_tool_call(thread: &Entity<AcpThread>, cx: &mut TestAppContext) {
        let (mut tx, mut rx) = mpsc::channel::<()>(1);

        let subscription = cx.update(|cx| {
            cx.subscribe(thread, move |thread, _, cx| {
                if thread
                    .read(cx)
                    .entries
                    .iter()
                    .any(|e| matches!(e.content, AgentThreadEntryContent::ToolCall(_)))
                {
                    // Several events can fire before the receiver is polled.
                    // One notification is enough, so a full (or closed)
                    // channel is not an error here.
                    tx.try_send(()).ok();
                }
            })
        });

        select! {
            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
                panic!("Timeout waiting for tool call")
            }
            _ = rx.next().fuse() => {
                drop(subscription);
            }
        }
    }

    /// Spawns the gemini CLI in ACP mode, speaking over stdio, and wraps it
    /// in an [`AcpServer`].
    ///
    /// The CLI is expected in a `gemini-cli` checkout located relative to
    /// this crate's manifest directory. `GEMINI_API_KEY` is forwarded to the
    /// child when it is set.
    ///
    /// # Errors
    ///
    /// Fails if the `node` process cannot be spawned.
    pub fn gemini_acp_server(project: Entity<Project>, mut cx: AsyncApp) -> Result<Arc<AcpServer>> {
        let cli_path =
            Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini-cli/packages/cli");
        let mut command = util::command::new_smol_command("node");
        command
            .arg(cli_path)
            .arg("--acp")
            .current_dir("/private/tmp")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .kill_on_drop(true);

        if let Ok(gemini_key) = env::var("GEMINI_API_KEY") {
            command.env("GEMINI_API_KEY", gemini_key);
        }

        let child = command.spawn()?;

        Ok(AcpServer::stdio(child, project, &mut cx))
    }
}