1mod server;
2mod thread_view;
3
4use agentic_coding_protocol::{self as acp, Role};
5use anyhow::Result;
6use chrono::{DateTime, Utc};
7use futures::channel::oneshot;
8use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
9use language::LanguageRegistry;
10use markdown::Markdown;
11use project::Project;
12use std::{mem, ops::Range, path::PathBuf, sync::Arc};
13use ui::App;
14use util::{ResultExt, debug_panic};
15
16pub use server::AcpServer;
17pub use thread_view::AcpThreadView;
18
/// Identifies an [`AcpThread`]; `AcpThread::send` passes it to the server
/// together with every outgoing message.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ThreadId(SharedString);
21
/// Version tag attached to file contents in [`FileContent`] and
/// `MessageChunk::Symbol`.
///
/// NOTE(review): what the number tracks (buffer version, monotonic counter, …)
/// is not established in this module — confirm with whoever produces it.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct FileVersion(u64);
24
/// Summary of an agent thread: its id, title, and creation time (in UTC).
///
/// NOTE(review): not constructed in this module; presumably produced when
/// listing existing threads — confirm against the server implementation.
#[derive(Debug)]
pub struct AgentThreadSummary {
    pub id: ThreadId,
    pub title: String,
    pub created_at: DateTime<Utc>,
}
31
/// The text of a file at a particular version, as carried by the `File` and
/// `Directory` variants of [`MessageChunk`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct FileContent {
    /// Path of the file.
    pub path: PathBuf,
    /// Version of the file that `content` corresponds to.
    pub version: FileVersion,
    /// The file's text.
    pub content: SharedString,
}
38
39#[derive(Clone, Debug, Eq, PartialEq)]
40pub struct Message {
41 pub role: acp::Role,
42 pub chunks: Vec<MessageChunk>,
43}
44
45impl Message {
46 fn into_acp(self, cx: &App) -> acp::Message {
47 acp::Message {
48 role: self.role,
49 chunks: self
50 .chunks
51 .into_iter()
52 .map(|chunk| chunk.into_acp(cx))
53 .collect(),
54 }
55 }
56}
57
58#[derive(Clone, Debug, Eq, PartialEq)]
59pub enum MessageChunk {
60 Text {
61 chunk: Entity<Markdown>,
62 },
63 File {
64 content: FileContent,
65 },
66 Directory {
67 path: PathBuf,
68 contents: Vec<FileContent>,
69 },
70 Symbol {
71 path: PathBuf,
72 range: Range<u64>,
73 version: FileVersion,
74 name: SharedString,
75 content: SharedString,
76 },
77 Fetch {
78 url: SharedString,
79 content: SharedString,
80 },
81}
82
83impl MessageChunk {
84 pub fn from_acp(
85 chunk: acp::MessageChunk,
86 language_registry: Arc<LanguageRegistry>,
87 cx: &mut App,
88 ) -> Self {
89 match chunk {
90 acp::MessageChunk::Text { chunk } => MessageChunk::Text {
91 chunk: cx.new(|cx| Markdown::new(chunk.into(), Some(language_registry), None, cx)),
92 },
93 }
94 }
95
96 pub fn into_acp(self, cx: &App) -> acp::MessageChunk {
97 match self {
98 MessageChunk::Text { chunk } => acp::MessageChunk::Text {
99 chunk: chunk.read(cx).source().to_string(),
100 },
101 MessageChunk::File { .. } => todo!(),
102 MessageChunk::Directory { .. } => todo!(),
103 MessageChunk::Symbol { .. } => todo!(),
104 MessageChunk::Fetch { .. } => todo!(),
105 }
106 }
107
108 pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
109 MessageChunk::Text {
110 chunk: cx.new(|cx| {
111 Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
112 }),
113 }
114 }
115}
116
/// Payload of a [`ThreadEntry`]: either a chat message or a tool call.
#[derive(Debug)]
pub enum AgentThreadEntryContent {
    Message(Message),
    ToolCall(ToolCall),
}
122
/// Lifecycle of a tool call requested by the agent.
///
/// `AcpThread::push_tool_call` creates calls in `WaitingForConfirmation`;
/// `AcpThread::authorize_tool_call` moves them to `Allowed` or `Rejected` and
/// sends the decision through `respond_tx`.
#[derive(Debug)]
pub enum ToolCall {
    /// Awaiting the user's decision; the requester presumably waits on the
    /// receiving end of `respond_tx`.
    WaitingForConfirmation {
        /// Id of the thread entry that holds this call.
        id: ToolCallId,
        /// Tool name/title, rendered as markdown.
        tool_name: Entity<Markdown>,
        /// Description of the call, rendered as markdown.
        description: Entity<Markdown>,
        /// Receives `true` when the call is allowed, `false` when rejected.
        respond_tx: oneshot::Sender<bool>,
    },
    // todo! Running?
    /// The user allowed the call.
    Allowed,
    /// The user rejected the call.
    Rejected,
}
135
/// A [`ThreadEntryId`] that is known to refer to a tool-call entry.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ToolCallId(ThreadEntryId);

impl ToolCallId {
    /// Returns the raw numeric id of the underlying thread entry.
    pub fn as_u64(&self) -> u64 {
        let Self(ThreadEntryId(raw)) = *self;
        raw
    }
}

/// Identifier of an entry within an `AcpThread`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ThreadEntryId(pub u64);

impl ThreadEntryId {
    /// Advances `self` to the following id and hands back the value it held
    /// beforehand (a postfix `id++`).
    pub fn post_inc(&mut self) -> Self {
        let successor = Self(self.0 + 1);
        mem::replace(self, successor)
    }
}
156
/// One entry in an [`AcpThread`]'s history.
#[derive(Debug)]
pub struct ThreadEntry {
    /// Assigned by `AcpThread` sequentially from 0 as entries are appended, so
    /// it matches the entry's index in the thread.
    pub id: ThreadEntryId,
    pub content: AgentThreadEntryContent,
}
162
/// A conversation with an ACP agent: an append-only list of entries plus
/// the handles needed to talk to the server and render markdown.
pub struct AcpThread {
    /// Id the server knows this thread by.
    id: ThreadId,
    /// Id that the next appended entry will receive.
    next_entry_id: ThreadEntryId,
    /// Entries in insertion order; an entry's id equals its index.
    entries: Vec<ThreadEntry>,
    /// Connection used to send user messages.
    server: Arc<AcpServer>,
    /// Display title.
    title: SharedString,
    /// Project whose language registry is used for markdown chunks.
    project: Entity<Project>,
}
171
/// Events emitted by [`AcpThread`] when its entries change.
enum AcpThreadEvent {
    /// An entry was appended (emitted by `AcpThread::push_entry`).
    NewEntry,
    /// The entry at this index in the thread's entries was modified.
    EntryUpdated(usize),
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}
178
179impl AcpThread {
180 pub fn new(
181 server: Arc<AcpServer>,
182 thread_id: ThreadId,
183 entries: Vec<AgentThreadEntryContent>,
184 project: Entity<Project>,
185 _: &mut Context<Self>,
186 ) -> Self {
187 let mut next_entry_id = ThreadEntryId(0);
188 Self {
189 title: "A new agent2 thread".into(),
190 entries: entries
191 .into_iter()
192 .map(|entry| ThreadEntry {
193 id: next_entry_id.post_inc(),
194 content: entry,
195 })
196 .collect(),
197 server,
198 id: thread_id,
199 next_entry_id,
200 project,
201 }
202 }
203
204 pub fn title(&self) -> SharedString {
205 self.title.clone()
206 }
207
208 pub fn entries(&self) -> &[ThreadEntry] {
209 &self.entries
210 }
211
212 pub fn push_entry(
213 &mut self,
214 entry: AgentThreadEntryContent,
215 cx: &mut Context<Self>,
216 ) -> ThreadEntryId {
217 let id = self.next_entry_id.post_inc();
218 self.entries.push(ThreadEntry { id, content: entry });
219 cx.emit(AcpThreadEvent::NewEntry);
220 id
221 }
222
223 pub fn push_assistant_chunk(&mut self, chunk: acp::MessageChunk, cx: &mut Context<Self>) {
224 let entries_len = self.entries.len();
225 if let Some(last_entry) = self.entries.last_mut()
226 && let AgentThreadEntryContent::Message(Message {
227 ref mut chunks,
228 role: Role::Assistant,
229 }) = last_entry.content
230 {
231 cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
232
233 if let (
234 Some(MessageChunk::Text { chunk: old_chunk }),
235 acp::MessageChunk::Text { chunk: new_chunk },
236 ) = (chunks.last_mut(), &chunk)
237 {
238 old_chunk.update(cx, |old_chunk, cx| {
239 old_chunk.append(&new_chunk, cx);
240 });
241 } else {
242 chunks.push(MessageChunk::from_acp(
243 chunk,
244 self.project.read(cx).languages().clone(),
245 cx,
246 ));
247 }
248
249 return;
250 }
251
252 let chunk = MessageChunk::from_acp(chunk, self.project.read(cx).languages().clone(), cx);
253
254 self.push_entry(
255 AgentThreadEntryContent::Message(Message {
256 role: Role::Assistant,
257 chunks: vec![chunk],
258 }),
259 cx,
260 );
261 }
262
263 pub fn push_tool_call(
264 &mut self,
265 title: String,
266 description: String,
267 respond_tx: oneshot::Sender<bool>,
268 cx: &mut Context<Self>,
269 ) -> ToolCallId {
270 let language_registry = self.project.read(cx).languages().clone();
271
272 let entry_id = self.push_entry(
273 AgentThreadEntryContent::ToolCall(ToolCall::WaitingForConfirmation {
274 // todo! clean up id creation
275 id: ToolCallId(ThreadEntryId(self.entries.len() as u64)),
276 tool_name: cx.new(|cx| {
277 Markdown::new(title.into(), Some(language_registry.clone()), None, cx)
278 }),
279 description: cx.new(|cx| {
280 Markdown::new(
281 description.into(),
282 Some(language_registry.clone()),
283 None,
284 cx,
285 )
286 }),
287 respond_tx,
288 }),
289 cx,
290 );
291
292 ToolCallId(entry_id)
293 }
294
295 pub fn authorize_tool_call(&mut self, id: ToolCallId, allowed: bool, cx: &mut Context<Self>) {
296 let Some(entry) = self.entry_mut(id.0) else {
297 return;
298 };
299
300 let AgentThreadEntryContent::ToolCall(call) = &mut entry.content else {
301 debug_panic!("expected ToolCall");
302 return;
303 };
304
305 let new_state = if allowed {
306 ToolCall::Allowed
307 } else {
308 ToolCall::Rejected
309 };
310
311 let call = mem::replace(call, new_state);
312
313 if let ToolCall::WaitingForConfirmation { respond_tx, .. } = call {
314 respond_tx.send(allowed).log_err();
315 } else {
316 debug_panic!("tried to authorize an already authorized tool call");
317 }
318
319 cx.emit(AcpThreadEvent::EntryUpdated(id.0.0 as usize));
320 }
321
322 fn entry_mut(&mut self, id: ThreadEntryId) -> Option<&mut ThreadEntry> {
323 let entry = self.entries.get_mut(id.0 as usize);
324 debug_assert!(
325 entry.is_some(),
326 "We shouldn't give out ids to entries that don't exist"
327 );
328 entry
329 }
330
331 pub fn send(&mut self, message: &str, cx: &mut Context<Self>) -> Task<Result<()>> {
332 let agent = self.server.clone();
333 let id = self.id.clone();
334 let chunk = MessageChunk::from_str(message, self.project.read(cx).languages().clone(), cx);
335 let message = Message {
336 role: Role::User,
337 chunks: vec![chunk],
338 };
339 self.push_entry(AgentThreadEntryContent::Message(message.clone()), cx);
340 let acp_message = message.into_acp(cx);
341 cx.spawn(async move |_, cx| {
342 agent.send_message(id, acp_message, cx).await?;
343 Ok(())
344 })
345 }
346}
347
#[cfg(test)]
mod tests {
    //! Integration tests that drive a real Gemini CLI over ACP (see
    //! `gemini_acp_server`). They need a `gemini-cli` checkout next to the
    //! repository and, presumably, network access plus `GEMINI_API_KEY`.

    use super::*;
    use futures::{FutureExt as _, channel::mpsc, select};
    use gpui::{AsyncApp, TestAppContext};
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use smol::stream::StreamExt;
    use std::{env, path::Path, process::Stdio, time::Duration};
    use util::path;

    /// Shared setup: logging, a test `SettingsStore`, and project/language init.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    /// A plain prompt yields exactly one user entry followed by one assistant entry.
    #[gpui::test]
    async fn test_gemini_basic(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        thread
            .update(cx, |thread, cx| thread.send("Hello from Zed!", cx))
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.entries.len(), 2);
            assert!(matches!(
                thread.entries[0].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::User,
                    ..
                })
            ));
            assert!(matches!(
                thread.entries[1].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::Assistant,
                    ..
                })
            ));
        });
    }

    /// Asking to read a file produces a `read_file` tool call that, once
    /// allowed, lets the turn finish with an assistant reply.
    #[gpui::test]
    async fn test_gemini_tool_call(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/private/tmp"),
            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
        )
        .await;
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        let full_turn = thread.update(cx, |thread, cx| {
            thread.send(
                "Read the '/private/tmp/foo' file and tell me what you see.",
                cx,
            )
        });

        run_until_tool_call(&thread, cx).await;

        let tool_call_id = thread.read_with(cx, |thread, cx| {
            let AgentThreadEntryContent::ToolCall(ToolCall::WaitingForConfirmation {
                id,
                tool_name,
                description,
                ..
            }) = &thread.entries().last().unwrap().content
            else {
                panic!();
            };

            tool_name.read_with(cx, |md, _cx| {
                assert_eq!(md.source(), "read_file");
            });

            description.read_with(cx, |md, _cx| {
                assert!(
                    md.source().contains("foo"),
                    "Expected description to contain 'foo', but got {}",
                    md.source()
                );
            });
            *id
        });

        thread.update(cx, |thread, cx| {
            thread.authorize_tool_call(tool_call_id, true, cx);
            assert!(matches!(
                thread.entries().last().unwrap().content,
                AgentThreadEntryContent::ToolCall(ToolCall::Allowed)
            ));
        });

        full_turn.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(thread.entries.len() >= 3, "{:?}", &thread.entries);
            assert!(matches!(
                thread.entries[0].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::User,
                    ..
                })
            ));
            assert!(matches!(
                thread.entries[1].content,
                AgentThreadEntryContent::ToolCall(ToolCall::Allowed)
            ));
            assert!(matches!(
                thread.entries[2].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::Assistant,
                    ..
                })
            ));
        });
    }

    /// Waits until `thread` contains any tool-call entry, panicking after 5 seconds.
    async fn run_until_tool_call(thread: &Entity<AcpThread>, cx: &mut TestAppContext) {
        let (mut tx, mut rx) = mpsc::channel(1);

        let subscription = cx.update(|cx| {
            cx.subscribe(thread, move |thread, _, cx| {
                if thread
                    .read(cx)
                    .entries
                    .iter()
                    .any(|e| matches!(e.content, AgentThreadEntryContent::ToolCall(_)))
                {
                    // Every event after the first tool call also matches. One
                    // queued signal is enough, so a full channel is not an
                    // error; unwrapping here could panic on the extra events.
                    tx.try_send(()).ok();
                }
            })
        });

        select! {
            _ = cx.executor().timer(Duration::from_secs(5)).fuse() => {
                panic!("Timeout waiting for tool call")
            }
            _ = rx.next().fuse() => {
                drop(subscription);
            }
        }
    }

    /// Spawns the Gemini CLI (`node <gemini-cli>/packages/cli --acp`) in
    /// `/private/tmp`, forwarding `GEMINI_API_KEY` when set. Wraps it in an
    /// [`AcpServer`] speaking over stdio.
    ///
    /// # Errors
    ///
    /// Returns an error if the `node` process cannot be spawned.
    pub fn gemini_acp_server(project: Entity<Project>, mut cx: AsyncApp) -> Result<Arc<AcpServer>> {
        let cli_path =
            Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini-cli/packages/cli");
        let mut command = util::command::new_smol_command("node");
        command
            .arg(cli_path)
            .arg("--acp")
            .args(["--model", "gemini-2.5-flash"])
            .current_dir("/private/tmp")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .kill_on_drop(true);

        if let Ok(gemini_key) = env::var("GEMINI_API_KEY") {
            command.env("GEMINI_API_KEY", gemini_key);
        }

        // Propagate spawn failures through the declared `Result` instead of panicking here.
        let child = command.spawn()?;

        Ok(AcpServer::stdio(child, project, &mut cx))
    }
}