1mod server;
2mod thread_view;
3
4use agentic_coding_protocol::{self as acp, Role};
5use anyhow::{Context as _, Result};
6use chrono::{DateTime, Utc};
7use futures::channel::oneshot;
8use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
9use language::LanguageRegistry;
10use markdown::Markdown;
11use project::Project;
12use std::{mem, ops::Range, path::PathBuf, sync::Arc};
13use ui::App;
14use util::{ResultExt, debug_panic};
15
16pub use server::AcpServer;
17pub use thread_view::AcpThreadView;
18
/// Identifier of a thread, wrapping a [`SharedString`].
///
/// Used as the `id` of both [`AgentThreadSummary`] and [`AcpThread`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ThreadId(SharedString);
21
/// Version tag attached to file content, stored as a `u64`.
///
/// NOTE(review): how versions are assigned and compared is not visible in this file —
/// confirm against the code that produces them.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct FileVersion(u64);
24
/// Summary of a thread: its id, title, and creation time.
#[derive(Debug)]
pub struct AgentThreadSummary {
    /// Id of the summarized thread.
    pub id: ThreadId,
    /// Title of the thread.
    pub title: String,
    /// Creation timestamp, in UTC.
    pub created_at: DateTime<Utc>,
}
31
/// Content of a file together with its path and version.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct FileContent {
    /// Path of the file.
    pub path: PathBuf,
    /// Version tag associated with `content`.
    pub version: FileVersion,
    /// The file's text.
    pub content: SharedString,
}
38
/// A message in a thread: the role of its author and the chunks that make up its body.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Message {
    /// Who authored the message (for example `Role::User` or `Role::Assistant`).
    pub role: acp::Role,
    /// The message body, in order.
    pub chunks: Vec<MessageChunk>,
}
44
45impl Message {
46 fn into_acp(self, cx: &App) -> acp::Message {
47 acp::Message {
48 role: self.role,
49 chunks: self
50 .chunks
51 .into_iter()
52 .map(|chunk| chunk.into_acp(cx))
53 .collect(),
54 }
55 }
56}
57
/// One piece of a [`Message`] body.
///
/// Only `Text` can currently be converted to and from `acp::MessageChunk`; `into_acp` is
/// `todo!()` for every other variant.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MessageChunk {
    /// Text, held as a [`Markdown`] entity.
    Text {
        chunk: Entity<Markdown>,
    },
    /// A single file's content.
    File {
        content: FileContent,
    },
    /// A directory path together with the content of files.
    ///
    /// NOTE(review): whether `contents` is recursive or limited to direct children is not
    /// visible here — confirm with the producer.
    Directory {
        path: PathBuf,
        contents: Vec<FileContent>,
    },
    /// A named symbol located in a range of a file at a given version.
    ///
    /// NOTE(review): the unit of `range` (bytes, characters, ...) is not visible here — confirm.
    Symbol {
        path: PathBuf,
        range: Range<u64>,
        version: FileVersion,
        name: SharedString,
        content: SharedString,
    },
    /// A URL together with content associated with it.
    Fetch {
        url: SharedString,
        content: SharedString,
    },
}
82
83impl MessageChunk {
84 pub fn from_acp(
85 chunk: acp::MessageChunk,
86 language_registry: Arc<LanguageRegistry>,
87 cx: &mut App,
88 ) -> Self {
89 match chunk {
90 acp::MessageChunk::Text { chunk } => MessageChunk::Text {
91 chunk: cx.new(|cx| Markdown::new(chunk.into(), Some(language_registry), None, cx)),
92 },
93 }
94 }
95
96 pub fn into_acp(self, cx: &App) -> acp::MessageChunk {
97 match self {
98 MessageChunk::Text { chunk } => acp::MessageChunk::Text {
99 chunk: chunk.read(cx).source().to_string(),
100 },
101 MessageChunk::File { .. } => todo!(),
102 MessageChunk::Directory { .. } => todo!(),
103 MessageChunk::Symbol { .. } => todo!(),
104 MessageChunk::Fetch { .. } => todo!(),
105 }
106 }
107
108 pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
109 MessageChunk::Text {
110 chunk: cx.new(|cx| {
111 Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
112 }),
113 }
114 }
115}
116
/// What a [`ThreadEntry`] holds: either a message or a tool call.
#[derive(Debug)]
pub enum AgentThreadEntryContent {
    /// A user or assistant message.
    Message(Message),
    /// A tool call and its current status.
    ToolCall(ToolCall),
}
122
/// A tool call recorded as an entry in an [`AcpThread`].
#[derive(Debug)]
pub struct ToolCall {
    /// Id of this tool call.
    id: ToolCallId,
    /// Label of the call, held as a [`Markdown`] entity.
    label: Entity<Markdown>,
    /// Whether the call is awaiting confirmation, allowed, or rejected.
    status: ToolCallStatus,
}
129
/// Authorization state of a [`ToolCall`].
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The call is waiting for an authorization decision.
    ///
    /// `AcpThread::authorize_tool_call` sends the chosen outcome through `respond_tx`.
    WaitingForConfirmation {
        confirmation: acp::ToolCallConfirmation,
        respond_tx: oneshot::Sender<acp::ToolCallConfirmationOutcome>,
    },
    // todo! Running?
    /// The call may run; `status` is the protocol-level status and `content` is the optional
    /// markdown set by `AcpThread::update_tool_call`.
    Allowed {
        // todo! should this be variants in crate::ToolCallStatus instead?
        status: acp::ToolCallStatus,
        content: Option<Entity<Markdown>>,
    },
    /// The call was authorized with the `Reject` outcome.
    Rejected,
}
144
/// A `ThreadEntryId` that is known to be a ToolCall
///
/// `AcpThread` builds these from the id of the entry it pushed for the tool call.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ToolCallId(ThreadEntryId);
148
149impl ToolCallId {
150 pub fn as_u64(&self) -> u64 {
151 self.0.0
152 }
153}
154
/// Numeric id of a [`ThreadEntry`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ThreadEntryId(pub u64);

impl ThreadEntryId {
    /// Advances this id by one and returns the value it held before the increment.
    pub fn post_inc(&mut self) -> Self {
        let successor = Self(self.0 + 1);
        std::mem::replace(self, successor)
    }
}
165
/// One entry of an [`AcpThread`]: an id plus the entry's content.
#[derive(Debug)]
pub struct ThreadEntry {
    /// Id of the entry. `AcpThread` hands these out sequentially and also uses the value as
    /// the entry's index when looking it up.
    pub id: ThreadEntryId,
    /// The message or tool call stored in this entry.
    pub content: AgentThreadEntryContent,
}
171
/// A conversation thread made of messages and tool calls, tied to an [`AcpServer`].
pub struct AcpThread {
    /// Id passed to the server when sending messages.
    id: ThreadId,
    /// Id that will be given to the next pushed entry.
    next_entry_id: ThreadEntryId,
    /// Entries in the order they were added.
    entries: Vec<ThreadEntry>,
    /// Server that user messages are sent to.
    server: Arc<AcpServer>,
    /// Title returned by `title()`.
    title: SharedString,
    /// Project whose language registry is used when creating markdown entities.
    project: Entity<Project>,
}
180
/// Events emitted by [`AcpThread`] when its entries change.
enum AcpThreadEvent {
    /// A new entry was appended.
    NewEntry,
    /// The entry at the given index was modified.
    EntryUpdated(usize),
}
185
// `AcpThread` reports changes to its entries by emitting `AcpThreadEvent`s via `cx.emit`.
impl EventEmitter<AcpThreadEvent> for AcpThread {}
187
188impl AcpThread {
189 pub fn new(
190 server: Arc<AcpServer>,
191 thread_id: ThreadId,
192 entries: Vec<AgentThreadEntryContent>,
193 project: Entity<Project>,
194 _: &mut Context<Self>,
195 ) -> Self {
196 let mut next_entry_id = ThreadEntryId(0);
197 Self {
198 title: "A new agent2 thread".into(),
199 entries: entries
200 .into_iter()
201 .map(|entry| ThreadEntry {
202 id: next_entry_id.post_inc(),
203 content: entry,
204 })
205 .collect(),
206 server,
207 id: thread_id,
208 next_entry_id,
209 project,
210 }
211 }
212
213 pub fn title(&self) -> SharedString {
214 self.title.clone()
215 }
216
217 pub fn entries(&self) -> &[ThreadEntry] {
218 &self.entries
219 }
220
221 pub fn push_entry(
222 &mut self,
223 entry: AgentThreadEntryContent,
224 cx: &mut Context<Self>,
225 ) -> ThreadEntryId {
226 let id = self.next_entry_id.post_inc();
227 self.entries.push(ThreadEntry { id, content: entry });
228 cx.emit(AcpThreadEvent::NewEntry);
229 id
230 }
231
232 pub fn push_assistant_chunk(&mut self, chunk: acp::MessageChunk, cx: &mut Context<Self>) {
233 let entries_len = self.entries.len();
234 if let Some(last_entry) = self.entries.last_mut()
235 && let AgentThreadEntryContent::Message(Message {
236 ref mut chunks,
237 role: Role::Assistant,
238 }) = last_entry.content
239 {
240 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
241
242 if let (
243 Some(MessageChunk::Text { chunk: old_chunk }),
244 acp::MessageChunk::Text { chunk: new_chunk },
245 ) = (chunks.last_mut(), &chunk)
246 {
247 old_chunk.update(cx, |old_chunk, cx| {
248 old_chunk.append(&new_chunk, cx);
249 });
250 } else {
251 chunks.push(MessageChunk::from_acp(
252 chunk,
253 self.project.read(cx).languages().clone(),
254 cx,
255 ));
256 }
257
258 return;
259 }
260
261 let chunk = MessageChunk::from_acp(chunk, self.project.read(cx).languages().clone(), cx);
262
263 self.push_entry(
264 AgentThreadEntryContent::Message(Message {
265 role: Role::Assistant,
266 chunks: vec![chunk],
267 }),
268 cx,
269 );
270 }
271
272 pub fn request_tool_call(
273 &mut self,
274 label: String,
275 confirmation: acp::ToolCallConfirmation,
276 cx: &mut Context<Self>,
277 ) -> ToolCallRequest {
278 let (tx, rx) = oneshot::channel();
279
280 let status = ToolCallStatus::WaitingForConfirmation {
281 confirmation,
282 respond_tx: tx,
283 };
284
285 let id = self.insert_tool_call(label, status, cx);
286 ToolCallRequest { id, outcome: rx }
287 }
288
289 pub fn push_tool_call(&mut self, label: String, cx: &mut Context<Self>) -> ToolCallId {
290 let status = ToolCallStatus::Allowed {
291 status: acp::ToolCallStatus::Running,
292 content: None,
293 };
294
295 self.insert_tool_call(label, status, cx)
296 }
297
298 fn insert_tool_call(
299 &mut self,
300 label: String,
301 status: ToolCallStatus,
302 cx: &mut Context<Self>,
303 ) -> ToolCallId {
304 let language_registry = self.project.read(cx).languages().clone();
305
306 let entry_id = self.push_entry(
307 AgentThreadEntryContent::ToolCall(ToolCall {
308 // todo! clean up id creation
309 id: ToolCallId(ThreadEntryId(self.entries.len() as u64)),
310 label: cx.new(|cx| {
311 Markdown::new(label.into(), Some(language_registry.clone()), None, cx)
312 }),
313 status,
314 }),
315 cx,
316 );
317
318 ToolCallId(entry_id)
319 }
320
321 pub fn authorize_tool_call(
322 &mut self,
323 id: ToolCallId,
324 outcome: acp::ToolCallConfirmationOutcome,
325 cx: &mut Context<Self>,
326 ) {
327 let Some(entry) = self.entry_mut(id.0) else {
328 return;
329 };
330
331 let AgentThreadEntryContent::ToolCall(call) = &mut entry.content else {
332 debug_panic!("expected ToolCall");
333 return;
334 };
335
336 let new_status = if outcome == acp::ToolCallConfirmationOutcome::Reject {
337 ToolCallStatus::Rejected
338 } else {
339 ToolCallStatus::Allowed {
340 status: acp::ToolCallStatus::Running,
341 content: None,
342 }
343 };
344
345 let curr_status = mem::replace(&mut call.status, new_status);
346
347 if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
348 respond_tx.send(outcome).log_err();
349 } else {
350 debug_panic!("tried to authorize an already authorized tool call");
351 }
352
353 cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
354 }
355
356 pub fn update_tool_call(
357 &mut self,
358 id: ToolCallId,
359 new_status: acp::ToolCallStatus,
360 new_content: Option<acp::ToolCallContent>,
361 cx: &mut Context<Self>,
362 ) -> Result<()> {
363 let language_registry = self.project.read(cx).languages().clone();
364 let entry = self.entry_mut(id.0).context("Entry not found")?;
365
366 match &mut entry.content {
367 AgentThreadEntryContent::ToolCall(call) => match &mut call.status {
368 ToolCallStatus::Allowed { content, status } => {
369 *content = new_content.map(|new_content| {
370 let acp::ToolCallContent::Markdown { markdown } = new_content;
371
372 cx.new(|cx| {
373 Markdown::new(markdown.into(), Some(language_registry), None, cx)
374 })
375 });
376
377 *status = new_status;
378 }
379 ToolCallStatus::WaitingForConfirmation { .. } => {
380 anyhow::bail!("Tool call hasn't been authorized yet")
381 }
382 ToolCallStatus::Rejected => {
383 anyhow::bail!("Tool call was rejected and therefore can't be updated")
384 }
385 },
386 _ => anyhow::bail!("Entry is not a tool call"),
387 }
388
389 cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
390 Ok(())
391 }
392
393 fn entry_mut(&mut self, id: ThreadEntryId) -> Option<&mut ThreadEntry> {
394 let entry = self.entries.get_mut(id.0 as usize);
395 debug_assert!(
396 entry.is_some(),
397 "We shouldn't give out ids to entries that don't exist"
398 );
399 entry
400 }
401
402 /// Returns true if the last turn is awaiting tool authorization
403 pub fn waiting_for_tool_confirmation(&self) -> bool {
404 for entry in self.entries.iter().rev() {
405 match &entry.content {
406 AgentThreadEntryContent::ToolCall(call) => match call.status {
407 ToolCallStatus::WaitingForConfirmation { .. } => return true,
408 ToolCallStatus::Allowed { .. } | ToolCallStatus::Rejected => continue,
409 },
410 AgentThreadEntryContent::Message(_) => {
411 // Reached the beginning of the turn
412 return false;
413 }
414 }
415 }
416 false
417 }
418
419 pub fn send(&mut self, message: &str, cx: &mut Context<Self>) -> Task<Result<()>> {
420 let agent = self.server.clone();
421 let id = self.id.clone();
422 let chunk = MessageChunk::from_str(message, self.project.read(cx).languages().clone(), cx);
423 let message = Message {
424 role: Role::User,
425 chunks: vec![chunk],
426 };
427 self.push_entry(AgentThreadEntryContent::Message(message.clone()), cx);
428 let acp_message = message.into_acp(cx);
429 cx.spawn(async move |_, cx| {
430 agent.send_message(id, acp_message, cx).await?;
431 Ok(())
432 })
433 }
434}
435
/// Handle returned by `AcpThread::request_tool_call` for a tool call awaiting confirmation.
pub struct ToolCallRequest {
    /// Id of the tool call that was added to the thread.
    pub id: ToolCallId,
    /// Receives the outcome once `AcpThread::authorize_tool_call` is called for `id`.
    pub outcome: oneshot::Receiver<acp::ToolCallConfirmationOutcome>,
}
440
// These tests launch a real `gemini-cli` through `node` (see `gemini_acp_server`), so they
// need a local checkout of gemini-cli next to this repository.
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Context as _;
    use futures::{FutureExt as _, channel::mpsc, select};
    use gpui::{AsyncApp, TestAppContext};
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use smol::stream::StreamExt as _;
    use std::{env, path::Path, process::Stdio, time::Duration};
    use util::path;

    /// Installs logging and the globals (`SettingsStore`, project settings, languages) that
    /// the tests rely on.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    /// A plain message produces a user entry followed by an assistant entry.
    #[gpui::test]
    async fn test_gemini_basic(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        thread
            .update(cx, |thread, cx| thread.send("Hello from Zed!", cx))
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.entries.len(), 2);
            assert!(matches!(
                thread.entries[0].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::User,
                    ..
                })
            ));
            assert!(matches!(
                thread.entries[1].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::Assistant,
                    ..
                })
            ));
        });
    }

    /// Asking to read a file yields an allowed `ReadFile` tool call followed by an assistant
    /// message.
    #[gpui::test]
    async fn test_gemini_tool_call(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/private/tmp"),
            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
        )
        .await;
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        thread
            .update(cx, |thread, cx| {
                thread.send(
                    "Read the '/private/tmp/foo' file and tell me what you see.",
                    cx,
                )
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            let AgentThreadEntryContent::ToolCall(ToolCall {
                label,
                status: ToolCallStatus::Allowed { .. },
                ..
            }) = &thread.entries()[1].content
            else {
                panic!(
                    "expected entry 1 to be an allowed tool call, got {:?}",
                    thread.entries()[1].content
                );
            };

            label.read_with(cx, |md, _cx| {
                assert_eq!(md.source(), "ReadFile");
            });

            assert!(matches!(
                thread.entries[2].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::Assistant,
                    ..
                })
            ));
        });
    }

    /// A shell command first waits for confirmation, becomes allowed once authorized, and
    /// ends up with content containing the command's output.
    #[gpui::test]
    async fn test_gemini_tool_call_with_confirmation(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        let full_turn = thread.update(cx, |thread, cx| {
            thread.send(r#"Run `echo "Hello, world!"`"#, cx)
        });

        run_until_tool_call(&thread, cx).await;

        let tool_call_id = thread.read_with(cx, |thread, cx| {
            let AgentThreadEntryContent::ToolCall(ToolCall {
                id,
                label,
                status:
                    ToolCallStatus::WaitingForConfirmation {
                        confirmation: acp::ToolCallConfirmation::Execute { root_command, .. },
                        ..
                    },
            }) = &thread.entries()[1].content
            else {
                panic!(
                    "expected entry 1 to be an execute tool call awaiting confirmation, got {:?}",
                    thread.entries()[1].content
                );
            };

            assert_eq!(root_command, "echo");

            label.read_with(cx, |md, _cx| {
                assert_eq!(md.source(), "Shell");
            });

            *id
        });

        thread.update(cx, |thread, cx| {
            thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);

            assert!(matches!(
                &thread.entries()[1].content,
                AgentThreadEntryContent::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed { .. },
                    ..
                })
            ));
        });

        full_turn.await.unwrap();

        thread.read_with(cx, |thread, cx| {
            let AgentThreadEntryContent::ToolCall(ToolCall {
                status: ToolCallStatus::Allowed { content, .. },
                ..
            }) = &thread.entries()[1].content
            else {
                panic!(
                    "expected entry 1 to be an allowed tool call, got {:?}",
                    thread.entries()[1].content
                );
            };

            content.as_ref().unwrap().read_with(cx, |md, _cx| {
                assert!(
                    md.source().contains("Hello, world!"),
                    r#"Expected '{}' to contain "Hello, world!""#,
                    md.source()
                );
            });
        });
    }

    /// Returns true if any entry of `thread` is a tool call.
    fn has_tool_call(thread: &AcpThread) -> bool {
        thread
            .entries
            .iter()
            .any(|entry| matches!(entry.content, AgentThreadEntryContent::ToolCall(_)))
    }

    /// Waits until `thread` contains a tool call, panicking after a 10 second timeout.
    async fn run_until_tool_call(thread: &Entity<AcpThread>, cx: &mut TestAppContext) {
        // Nothing to wait for if the tool call arrived before we got here; no further event
        // is guaranteed to fire in that case.
        if thread.read_with(cx, |thread, _| has_tool_call(thread)) {
            return;
        }

        let (mut tx, mut rx) = mpsc::channel::<()>(1);

        let subscription = cx.update(|cx| {
            cx.subscribe(thread, move |thread, _, cx| {
                if has_tool_call(thread.read(cx)) {
                    // Every event after the tool call exists lands here again. Only the first
                    // signal matters, so a full or closed channel is not an error.
                    tx.try_send(()).ok();
                }
            })
        });

        select! {
            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
                panic!("Timeout waiting for tool call")
            }
            _ = rx.next().fuse() => {
                drop(subscription);
            }
        }
    }

    /// Spawns `node <gemini-cli>/packages/cli --acp` in `/private/tmp` and connects to it
    /// over stdio.
    ///
    /// `GEMINI_API_KEY` is forwarded to the child process when it is set.
    ///
    /// # Errors
    ///
    /// Fails when the `node` process cannot be spawned.
    pub fn gemini_acp_server(project: Entity<Project>, mut cx: AsyncApp) -> Result<Arc<AcpServer>> {
        let cli_path =
            Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini-cli/packages/cli");
        let mut command = util::command::new_smol_command("node");
        command
            .arg(cli_path)
            .arg("--acp")
            .current_dir("/private/tmp")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .kill_on_drop(true);

        if let Ok(gemini_key) = env::var("GEMINI_API_KEY") {
            command.env("GEMINI_API_KEY", gemini_key);
        }

        let child = command
            .spawn()
            .context("failed to spawn `node` for the gemini-cli ACP server")?;

        Ok(AcpServer::stdio(child, project, &mut cx))
    }
}