acp.rs

  1mod server;
  2mod thread_view;
  3
  4use agentic_coding_protocol::{self as acp, Role};
  5use anyhow::Result;
  6use chrono::{DateTime, Utc};
  7use futures::channel::oneshot;
  8use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
  9use language::LanguageRegistry;
 10use markdown::Markdown;
 11use project::Project;
 12use std::{mem, ops::Range, path::PathBuf, sync::Arc};
 13use ui::App;
 14use util::{ResultExt, debug_panic};
 15
 16pub use server::AcpServer;
 17pub use thread_view::AcpThreadView;
 18
   /// Identifier of an agent thread; a newtype over `SharedString`.
 19#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 20pub struct ThreadId(SharedString);
 21
   /// Version number attached to file content; a newtype over `u64`.
   ///
   /// NOTE(review): how versions are assigned and compared is not visible in
   /// this file — confirm with the code that produces them.
 22#[derive(Copy, Clone, Debug, PartialEq, Eq)]
 23pub struct FileVersion(u64);
 24
   /// Summary information about an agent thread.
 25#[derive(Debug)]
 26pub struct AgentThreadSummary {
       /// Identifier of the thread being summarized.
 27    pub id: ThreadId,
       /// Title of the thread.
 28    pub title: String,
       /// UTC timestamp of when the thread was created.
 29    pub created_at: DateTime<Utc>,
 30}
 31
   /// A file's path together with a version and its text content.
 32#[derive(Clone, Debug, PartialEq, Eq)]
 33pub struct FileContent {
       /// Path of the file.
 34    pub path: PathBuf,
       /// Version the content corresponds to.
 35    pub version: FileVersion,
       /// Text of the file.
 36    pub content: SharedString,
 37}
 38
   /// A single message in a thread: who authored it and the chunks it is made of.
 39#[derive(Clone, Debug, Eq, PartialEq)]
 40pub struct Message {
       /// Author role, using the protocol's `Role` type.
 41    pub role: acp::Role,
       /// Ordered pieces of content that make up the message.
 42    pub chunks: Vec<MessageChunk>,
 43}
 44
 45impl Message {
 46    fn into_acp(self, cx: &App) -> acp::Message {
 47        acp::Message {
 48            role: self.role,
 49            chunks: self
 50                .chunks
 51                .into_iter()
 52                .map(|chunk| chunk.into_acp(cx))
 53                .collect(),
 54        }
 55    }
 56}
 57
   /// One piece of a [`Message`].
   ///
   /// Only `Text` has a protocol conversion in this file; converting any other
   /// variant with `into_acp` currently hits `todo!()`.
 58#[derive(Clone, Debug, Eq, PartialEq)]
 59pub enum MessageChunk {
       /// Text held in a `Markdown` entity.
 60    Text {
 61        chunk: Entity<Markdown>,
 62    },
       /// A single file and its content.
 63    File {
 64        content: FileContent,
 65    },
       /// A directory path plus the content of files associated with it.
 66    Directory {
 67        path: PathBuf,
 68        contents: Vec<FileContent>,
 69    },
       /// A named symbol inside a file.
       ///
       /// NOTE(review): the unit of `range` (bytes, characters, lines) is not
       /// established in this file — confirm before relying on it.
 70    Symbol {
 71        path: PathBuf,
 72        range: Range<u64>,
 73        version: FileVersion,
 74        name: SharedString,
 75        content: SharedString,
 76    },
       /// Content associated with a URL.
 77    Fetch {
 78        url: SharedString,
 79        content: SharedString,
 80    },
 81}
 82
 83impl MessageChunk {
 84    pub fn from_acp(
 85        chunk: acp::MessageChunk,
 86        language_registry: Arc<LanguageRegistry>,
 87        cx: &mut App,
 88    ) -> Self {
 89        match chunk {
 90            acp::MessageChunk::Text { chunk } => MessageChunk::Text {
 91                chunk: cx.new(|cx| Markdown::new(chunk.into(), Some(language_registry), None, cx)),
 92            },
 93        }
 94    }
 95
 96    pub fn into_acp(self, cx: &App) -> acp::MessageChunk {
 97        match self {
 98            MessageChunk::Text { chunk } => acp::MessageChunk::Text {
 99                chunk: chunk.read(cx).source().to_string(),
100            },
101            MessageChunk::File { .. } => todo!(),
102            MessageChunk::Directory { .. } => todo!(),
103            MessageChunk::Symbol { .. } => todo!(),
104            MessageChunk::Fetch { .. } => todo!(),
105        }
106    }
107
108    pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
109        MessageChunk::Text {
110            chunk: cx.new(|cx| {
111                Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
112            }),
113        }
114    }
115}
116
   /// What a thread entry holds: either a message or a tool call.
117#[derive(Debug)]
118pub enum AgentThreadEntryContent {
       /// A user or assistant message.
119    Message(Message),
       /// A tool call and its authorization state.
120    ToolCall(ToolCall),
121}
122
   /// Authorization state of a tool call.
   ///
   /// `AcpThread::authorize_tool_call` moves a call out of
   /// `WaitingForConfirmation` into `Allowed` or `Rejected`.
123#[derive(Debug)]
124pub enum ToolCall {
       /// The call has not been allowed or rejected yet.
       ///
       /// `respond_tx` is the channel the decision is sent on once the call is
       /// authorized.
125    WaitingForConfirmation {
126        id: ToolCallId,
127        tool_name: Entity<Markdown>,
128        description: Entity<Markdown>,
129        respond_tx: oneshot::Sender<bool>,
130    },
131    // todo! Running?
       /// The call was authorized with `allowed == true`.
132    Allowed,
       /// The call was authorized with `allowed == false`.
133    Rejected,
134}
135
136/// A `ThreadEntryId` that is known to be a ToolCall
137#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
138pub struct ToolCallId(ThreadEntryId);
139
140impl ToolCallId {
141    pub fn as_u64(&self) -> u64 {
142        self.0.0
143    }
144}
145
146#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
147pub struct ThreadEntryId(pub u64);
148
149impl ThreadEntryId {
150    pub fn post_inc(&mut self) -> Self {
151        let id = *self;
152        self.0 += 1;
153        id
154    }
155}
156
   /// An entry of a thread: its id plus the content it holds.
157#[derive(Debug)]
158pub struct ThreadEntry {
       /// Id assigned when the entry was added to the thread.
159    pub id: ThreadEntryId,
       /// The message or tool call stored in this entry.
160    pub content: AgentThreadEntryContent,
161}
162
   /// A conversation thread with an agent reached through an `AcpServer`.
   ///
   /// Holds the ordered entries (messages and tool calls) and hands out
   /// sequential entry ids.
163pub struct AcpThread {
       /// Identifier passed to the server when sending messages.
164    id: ThreadId,
       /// Id that the next pushed entry will receive.
165    next_entry_id: ThreadEntryId,
       /// Entries in the order they were added.
166    entries: Vec<ThreadEntry>,
       /// Server that messages are sent to.
167    server: Arc<AcpServer>,
       /// Title returned by `title()`.
168    title: SharedString,
       /// Project whose language registry is used when creating `Markdown` entities.
169    project: Entity<Project>,
170}
171
   /// Events emitted by [`AcpThread`] when its entries change.
172enum AcpThreadEvent {
       /// A new entry was appended (emitted by `push_entry`).
173    NewEntry,
       /// An existing entry changed; the payload is the entry's index in `entries`.
174    EntryUpdated(usize),
175}
176
   // Lets `AcpThread` emit `AcpThreadEvent`s through `cx.emit`.
177impl EventEmitter<AcpThreadEvent> for AcpThread {}
178
179impl AcpThread {
180    pub fn new(
181        server: Arc<AcpServer>,
182        thread_id: ThreadId,
183        entries: Vec<AgentThreadEntryContent>,
184        project: Entity<Project>,
185        _: &mut Context<Self>,
186    ) -> Self {
187        let mut next_entry_id = ThreadEntryId(0);
188        Self {
189            title: "A new agent2 thread".into(),
190            entries: entries
191                .into_iter()
192                .map(|entry| ThreadEntry {
193                    id: next_entry_id.post_inc(),
194                    content: entry,
195                })
196                .collect(),
197            server,
198            id: thread_id,
199            next_entry_id,
200            project,
201        }
202    }
203
       /// Returns a clone of the thread's title.
204    pub fn title(&self) -> SharedString {
205        self.title.clone()
206    }
207
       /// Returns the thread's entries in the order they were added.
208    pub fn entries(&self) -> &[ThreadEntry] {
209        &self.entries
210    }
211
212    pub fn push_entry(
213        &mut self,
214        entry: AgentThreadEntryContent,
215        cx: &mut Context<Self>,
216    ) -> ThreadEntryId {
217        let id = self.next_entry_id.post_inc();
218        self.entries.push(ThreadEntry { id, content: entry });
219        cx.emit(AcpThreadEvent::NewEntry);
220        id
221    }
222
223    pub fn push_assistant_chunk(&mut self, chunk: acp::MessageChunk, cx: &mut Context<Self>) {
224        let entries_len = self.entries.len();
225        if let Some(last_entry) = self.entries.last_mut()
226            && let AgentThreadEntryContent::Message(Message {
227                ref mut chunks,
228                role: Role::Assistant,
229            }) = last_entry.content
230        {
231            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
232
233            if let (
234                Some(MessageChunk::Text { chunk: old_chunk }),
235                acp::MessageChunk::Text { chunk: new_chunk },
236            ) = (chunks.last_mut(), &chunk)
237            {
238                old_chunk.update(cx, |old_chunk, cx| {
239                    old_chunk.append(&new_chunk, cx);
240                });
241            } else {
242                chunks.push(MessageChunk::from_acp(
243                    chunk,
244                    self.project.read(cx).languages().clone(),
245                    cx,
246                ));
247            }
248
249            return;
250        }
251
252        let chunk = MessageChunk::from_acp(chunk, self.project.read(cx).languages().clone(), cx);
253
254        self.push_entry(
255            AgentThreadEntryContent::Message(Message {
256                role: Role::Assistant,
257                chunks: vec![chunk],
258            }),
259            cx,
260        );
261    }
262
263    pub fn push_tool_call(
264        &mut self,
265        title: String,
266        description: String,
267        respond_tx: oneshot::Sender<bool>,
268        cx: &mut Context<Self>,
269    ) -> ToolCallId {
270        let language_registry = self.project.read(cx).languages().clone();
271
272        let entry_id = self.push_entry(
273            AgentThreadEntryContent::ToolCall(ToolCall::WaitingForConfirmation {
274                // todo! clean up id creation
275                id: ToolCallId(ThreadEntryId(self.entries.len() as u64)),
276                tool_name: cx.new(|cx| {
277                    Markdown::new(title.into(), Some(language_registry.clone()), None, cx)
278                }),
279                description: cx.new(|cx| {
280                    Markdown::new(
281                        description.into(),
282                        Some(language_registry.clone()),
283                        None,
284                        cx,
285                    )
286                }),
287                respond_tx,
288            }),
289            cx,
290        );
291
292        ToolCallId(entry_id)
293    }
294
295    pub fn authorize_tool_call(&mut self, id: ToolCallId, allowed: bool, cx: &mut Context<Self>) {
296        let Some(entry) = self.entry_mut(id.0) else {
297            return;
298        };
299
300        let AgentThreadEntryContent::ToolCall(call) = &mut entry.content else {
301            debug_panic!("expected ToolCall");
302            return;
303        };
304
305        let new_state = if allowed {
306            ToolCall::Allowed
307        } else {
308            ToolCall::Rejected
309        };
310
311        let call = mem::replace(call, new_state);
312
313        if let ToolCall::WaitingForConfirmation { respond_tx, .. } = call {
314            respond_tx.send(allowed).log_err();
315        } else {
316            debug_panic!("tried to authorize an already authorized tool call");
317        }
318
319        cx.emit(AcpThreadEvent::EntryUpdated(id.0.0 as usize));
320    }
321
322    fn entry_mut(&mut self, id: ThreadEntryId) -> Option<&mut ThreadEntry> {
323        let entry = self.entries.get_mut(id.0 as usize);
324        debug_assert!(
325            entry.is_some(),
326            "We shouldn't give out ids to entries that don't exist"
327        );
328        entry
329    }
330
331    pub fn send(&mut self, message: &str, cx: &mut Context<Self>) -> Task<Result<()>> {
332        let agent = self.server.clone();
333        let id = self.id.clone();
334        let chunk = MessageChunk::from_str(message, self.project.read(cx).languages().clone(), cx);
335        let message = Message {
336            role: Role::User,
337            chunks: vec![chunk],
338        };
339        self.push_entry(AgentThreadEntryContent::Message(message.clone()), cx);
340        let acp_message = message.into_acp(cx);
341        cx.spawn(async move |_, cx| {
342            agent.send_message(id, acp_message, cx).await?;
343            Ok(())
344        })
345    }
346}
347
348#[cfg(test)]
349mod tests {
350    use super::*;
351    use futures::{FutureExt as _, channel::mpsc, select};
352    use gpui::{AsyncApp, TestAppContext};
353    use project::FakeFs;
354    use serde_json::json;
355    use settings::SettingsStore;
356    use smol::stream::StreamExt;
357    use std::{env, path::Path, process::Stdio, time::Duration};
358    use util::path;
359
360    fn init_test(cx: &mut TestAppContext) {
361        env_logger::try_init().ok();
362        cx.update(|cx| {
363            let settings_store = SettingsStore::test(cx);
364            cx.set_global(settings_store);
365            Project::init_settings(cx);
366            language::init(cx);
367        });
368    }
369
370    #[gpui::test]
371    async fn test_gemini_basic(cx: &mut TestAppContext) {
372        init_test(cx);
373
374        cx.executor().allow_parking();
375
376        let fs = FakeFs::new(cx.executor());
377        let project = Project::test(fs, [], cx).await;
378        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
379        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
380        thread
381            .update(cx, |thread, cx| thread.send("Hello from Zed!", cx))
382            .await
383            .unwrap();
384
385        thread.read_with(cx, |thread, _| {
386            assert_eq!(thread.entries.len(), 2);
387            assert!(matches!(
388                thread.entries[0].content,
389                AgentThreadEntryContent::Message(Message {
390                    role: Role::User,
391                    ..
392                })
393            ));
394            assert!(matches!(
395                thread.entries[1].content,
396                AgentThreadEntryContent::Message(Message {
397                    role: Role::Assistant,
398                    ..
399                })
400            ));
401        });
402    }
403
       /// Asks a real Gemini CLI process to read a file, authorizes the resulting
       /// tool call, and checks the shape of the thread afterwards.
404    #[gpui::test]
405    async fn test_gemini_tool_call(cx: &mut TestAppContext) {
406        init_test(cx);
407
408        cx.executor().allow_parking();
409
           // Populate the fake file system with the file the prompt refers to.
410        let fs = FakeFs::new(cx.executor());
411        fs.insert_tree(
412            path!("/private/tmp"),
413            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
414        )
415        .await;
416        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
417        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
418        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
           // Keep the send task around: it is awaited only after the tool call
           // has been authorized.
419        let full_turn = thread.update(cx, |thread, cx| {
420            thread.send(
421                "Read the '/private/tmp/foo' file and tell me what you see.",
422                cx,
423            )
424        });
425
426        run_until_tool_call(&thread, cx).await;
427
           // The last entry must be a pending tool call for `read_file` whose
           // description mentions the requested file.
428        let tool_call_id = thread.read_with(cx, |thread, cx| {
429            let AgentThreadEntryContent::ToolCall(ToolCall::WaitingForConfirmation {
430                id,
431                tool_name,
432                description,
433                ..
434            }) = &thread.entries().last().unwrap().content
435            else {
436                panic!();
437            };
438
439            tool_name.read_with(cx, |md, _cx| {
440                assert_eq!(md.source(), "read_file");
441            });
442
443            description.read_with(cx, |md, _cx| {
444                assert!(
445                    md.source().contains("foo"),
446                    "Expected description to contain 'foo', but got {}",
447                    md.source()
448                );
449            });
450            *id
451        });
452
           // Allowing the call must flip the entry to `Allowed` right away.
453        thread.update(cx, |thread, cx| {
454            thread.authorize_tool_call(tool_call_id, true, cx);
455            assert!(matches!(
456                thread.entries().last().unwrap().content,
457                AgentThreadEntryContent::ToolCall(ToolCall::Allowed)
458            ));
459        });
460
461        full_turn.await.unwrap();
462
           // Expected order: user message, allowed tool call, assistant message.
463        thread.read_with(cx, |thread, _| {
464            assert!(thread.entries.len() >= 3, "{:?}", &thread.entries);
465            assert!(matches!(
466                thread.entries[0].content,
467                AgentThreadEntryContent::Message(Message {
468                    role: Role::User,
469                    ..
470                })
471            ));
472            assert!(matches!(
473                thread.entries[1].content,
474                AgentThreadEntryContent::ToolCall(ToolCall::Allowed)
475            ));
476            assert!(matches!(
477                thread.entries[2].content,
478                AgentThreadEntryContent::Message(Message {
479                    role: Role::Assistant,
480                    ..
481                })
482            ));
483        });
484    }
485
486    async fn run_until_tool_call(thread: &Entity<AcpThread>, cx: &mut TestAppContext) {
487        let (mut tx, mut rx) = mpsc::channel(1);
488
489        let subscription = cx.update(|cx| {
490            cx.subscribe(thread, move |thread, _, cx| {
491                if thread
492                    .read(cx)
493                    .entries
494                    .iter()
495                    .any(|e| matches!(e.content, AgentThreadEntryContent::ToolCall(_)))
496                {
497                    tx.try_send(()).unwrap();
498                }
499            })
500        });
501
502        select! {
503            _ = cx.executor().timer(Duration::from_secs(5)).fuse() => {
504                panic!("Timeout waiting for tool call")
505            }
506            _ = rx.next().fuse() => {
507                drop(subscription);
508            }
509        }
510    }
511
512    pub fn gemini_acp_server(project: Entity<Project>, mut cx: AsyncApp) -> Result<Arc<AcpServer>> {
513        let cli_path =
514            Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini-cli/packages/cli");
515        let mut command = util::command::new_smol_command("node");
516        command
517            .arg(cli_path)
518            .arg("--acp")
519            .args(["--model", "gemini-2.5-flash"])
520            .current_dir("/private/tmp")
521            .stdin(Stdio::piped())
522            .stdout(Stdio::piped())
523            .stderr(Stdio::inherit())
524            .kill_on_drop(true);
525
526        if let Ok(gemini_key) = std::env::var("GEMINI_API_KEY") {
527            command.env("GEMINI_API_KEY", gemini_key);
528        }
529
530        let child = command.spawn().unwrap();
531
532        Ok(AcpServer::stdio(child, project, &mut cx))
533    }
534}