acp.rs

  1mod server;
  2mod thread_view;
  3
  4use agentic_coding_protocol::{self as acp, Role};
  5use anyhow::Result;
  6use chrono::{DateTime, Utc};
  7use futures::channel::oneshot;
  8use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
  9use language::LanguageRegistry;
 10use markdown::Markdown;
 11use project::Project;
 12use std::{mem, ops::Range, path::PathBuf, sync::Arc};
 13use ui::App;
 14use util::{ResultExt, debug_panic};
 15
 16pub use server::AcpServer;
 17pub use thread_view::AcpThreadView;
 18
/// Identifier of an ACP thread; it is passed to `AcpServer::send_message` together with each
/// user message sent from an [`AcpThread`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ThreadId(SharedString);
 21
/// Version tag attached to file contents in [`FileContent`] and [`MessageChunk::Symbol`].
///
/// NOTE(review): how versions are assigned is not visible in this file — confirm with the producer.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct FileVersion(u64);
 24
/// Summary information about a thread: its id, title, and creation time (UTC).
#[derive(Debug)]
pub struct AgentThreadSummary {
    pub id: ThreadId,
    pub title: String,
    pub created_at: DateTime<Utc>,
}
 31
/// The contents of a file at a particular [`FileVersion`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct FileContent {
    /// Path of the file.
    pub path: PathBuf,
    /// Version of the file that `content` was taken from.
    pub version: FileVersion,
    /// The file's text.
    pub content: SharedString,
}
 38
/// A message in a thread: who sent it and the chunks that make up its body.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Message {
    /// Sender of the message (e.g. `Role::User` or `Role::Assistant`).
    pub role: acp::Role,
    /// Body of the message, in order.
    pub chunks: Vec<MessageChunk>,
}
 44
 45impl Message {
 46    fn into_acp(self, cx: &App) -> acp::Message {
 47        acp::Message {
 48            role: self.role,
 49            chunks: self
 50                .chunks
 51                .into_iter()
 52                .map(|chunk| chunk.into_acp(cx))
 53                .collect(),
 54        }
 55    }
 56}
 57
/// One piece of a [`Message`] body.
///
/// Only `Text` can currently be converted to the ACP wire format; the other variants reach
/// `todo!()` in [`MessageChunk::into_acp`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MessageChunk {
    /// Markdown text, held in a `Markdown` entity so it can be appended to while streaming.
    Text {
        chunk: Entity<Markdown>,
    },
    /// The contents of a single file.
    File {
        content: FileContent,
    },
    /// A directory and the contents of files within it.
    Directory {
        path: PathBuf,
        contents: Vec<FileContent>,
    },
    /// A named symbol within a file.
    Symbol {
        path: PathBuf,
        // NOTE(review): presumably an offset range within the file's content — confirm units.
        range: Range<u64>,
        version: FileVersion,
        name: SharedString,
        content: SharedString,
    },
    /// Content fetched from a URL.
    Fetch {
        url: SharedString,
        content: SharedString,
    },
}
 82
 83impl MessageChunk {
 84    pub fn from_acp(
 85        chunk: acp::MessageChunk,
 86        language_registry: Arc<LanguageRegistry>,
 87        cx: &mut App,
 88    ) -> Self {
 89        match chunk {
 90            acp::MessageChunk::Text { chunk } => MessageChunk::Text {
 91                chunk: cx.new(|cx| Markdown::new(chunk.into(), Some(language_registry), None, cx)),
 92            },
 93        }
 94    }
 95
 96    pub fn into_acp(self, cx: &App) -> acp::MessageChunk {
 97        match self {
 98            MessageChunk::Text { chunk } => acp::MessageChunk::Text {
 99                chunk: chunk.read(cx).source().to_string(),
100            },
101            MessageChunk::File { .. } => todo!(),
102            MessageChunk::Directory { .. } => todo!(),
103            MessageChunk::Symbol { .. } => todo!(),
104            MessageChunk::Fetch { .. } => todo!(),
105        }
106    }
107
108    pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
109        MessageChunk::Text {
110            chunk: cx.new(|cx| {
111                Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
112            }),
113        }
114    }
115}
116
/// What a [`ThreadEntry`] holds: either a message or a tool call.
#[derive(Debug)]
pub enum AgentThreadEntryContent {
    Message(Message),
    ToolCall(ToolCall),
}
122
/// State of a tool call requested by the agent.
#[derive(Debug)]
pub enum ToolCall {
    /// The call is waiting for the user to allow or reject it.
    WaitingForConfirmation {
        /// Id of the thread entry that holds this tool call.
        id: ToolCallId,
        /// Tool name, rendered as Markdown.
        tool_name: Entity<Markdown>,
        /// Description of the call, rendered as Markdown.
        description: Entity<Markdown>,
        /// Receives the user's decision (`true` = allowed) from `AcpThread::authorize_tool_call`.
        respond_tx: oneshot::Sender<bool>,
    },
    // todo! Running?
    /// The user allowed the call.
    Allowed,
    /// The user rejected the call.
    Rejected,
}
135
/// A `ThreadEntryId` that is known to be a ToolCall.
///
/// Returned by `AcpThread::push_tool_call` and accepted by `AcpThread::authorize_tool_call`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ToolCallId(ThreadEntryId);
139
/// Identifier of an entry within an [`AcpThread`].
///
/// Ids are handed out sequentially with [`ThreadEntryId::post_inc`], starting from 0. `AcpThread`
/// also uses the inner value as the entry's index into its entry list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ThreadEntryId(pub u64);
142
143impl ThreadEntryId {
144    pub fn post_inc(&mut self) -> Self {
145        let id = *self;
146        self.0 += 1;
147        id
148    }
149}
150
/// A single entry of an [`AcpThread`], tagged with its id.
#[derive(Debug)]
pub struct ThreadEntry {
    pub id: ThreadEntryId,
    pub content: AgentThreadEntryContent,
}
156
/// A conversation with an agent over ACP: the ordered entries (messages and tool calls)
/// exchanged with an [`AcpServer`].
pub struct AcpThread {
    /// Identifier passed to `AcpServer::send_message` for this thread.
    id: ThreadId,
    /// Id the next pushed entry will receive.
    next_entry_id: ThreadEntryId,
    /// Entries in the order they were pushed.
    entries: Vec<ThreadEntry>,
    /// Server that user messages are sent to.
    server: Arc<AcpServer>,
    /// Display title; initialized to a placeholder in `AcpThread::new`.
    title: SharedString,
    /// Project whose language registry is used when creating Markdown for entries.
    project: Entity<Project>,
}
165
/// Events emitted by [`AcpThread`] when its entries change.
enum AcpThreadEvent {
    /// An entry was appended (emitted by `AcpThread::push_entry`).
    NewEntry,
    /// The entry at this index in the thread's entry list was modified.
    EntryUpdated(usize),
}
170
// Allows `AcpThread` to emit `AcpThreadEvent`s via `cx.emit`, which `cx.subscribe` subscribers receive.
impl EventEmitter<AcpThreadEvent> for AcpThread {}
172
173impl AcpThread {
174    pub fn new(
175        server: Arc<AcpServer>,
176        thread_id: ThreadId,
177        entries: Vec<AgentThreadEntryContent>,
178        project: Entity<Project>,
179        _: &mut Context<Self>,
180    ) -> Self {
181        let mut next_entry_id = ThreadEntryId(0);
182        Self {
183            title: "A new agent2 thread".into(),
184            entries: entries
185                .into_iter()
186                .map(|entry| ThreadEntry {
187                    id: next_entry_id.post_inc(),
188                    content: entry,
189                })
190                .collect(),
191            server,
192            id: thread_id,
193            next_entry_id,
194            project,
195        }
196    }
197
198    pub fn title(&self) -> SharedString {
199        self.title.clone()
200    }
201
202    pub fn entries(&self) -> &[ThreadEntry] {
203        &self.entries
204    }
205
206    pub fn push_entry(
207        &mut self,
208        entry: AgentThreadEntryContent,
209        cx: &mut Context<Self>,
210    ) -> ThreadEntryId {
211        let id = self.next_entry_id.post_inc();
212        self.entries.push(ThreadEntry { id, content: entry });
213        cx.emit(AcpThreadEvent::NewEntry);
214        id
215    }
216
217    pub fn push_assistant_chunk(&mut self, chunk: acp::MessageChunk, cx: &mut Context<Self>) {
218        let entries_len = self.entries.len();
219        if let Some(last_entry) = self.entries.last_mut()
220            && let AgentThreadEntryContent::Message(Message {
221                ref mut chunks,
222                role: Role::Assistant,
223            }) = last_entry.content
224        {
225            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
226
227            if let (
228                Some(MessageChunk::Text { chunk: old_chunk }),
229                acp::MessageChunk::Text { chunk: new_chunk },
230            ) = (chunks.last_mut(), &chunk)
231            {
232                old_chunk.update(cx, |old_chunk, cx| {
233                    old_chunk.append(&new_chunk, cx);
234                });
235            } else {
236                chunks.push(MessageChunk::from_acp(
237                    chunk,
238                    self.project.read(cx).languages().clone(),
239                    cx,
240                ));
241            }
242
243            return;
244        }
245
246        let chunk = MessageChunk::from_acp(chunk, self.project.read(cx).languages().clone(), cx);
247
248        self.push_entry(
249            AgentThreadEntryContent::Message(Message {
250                role: Role::Assistant,
251                chunks: vec![chunk],
252            }),
253            cx,
254        );
255    }
256
257    pub fn push_tool_call(
258        &mut self,
259        title: String,
260        description: String,
261        respond_tx: oneshot::Sender<bool>,
262        cx: &mut Context<Self>,
263    ) -> ToolCallId {
264        let language_registry = self.project.read(cx).languages().clone();
265
266        let entry_id = self.push_entry(
267            AgentThreadEntryContent::ToolCall(ToolCall::WaitingForConfirmation {
268                // todo! clean up id creation
269                id: ToolCallId(ThreadEntryId(self.entries.len() as u64)),
270                tool_name: cx.new(|cx| {
271                    Markdown::new(title.into(), Some(language_registry.clone()), None, cx)
272                }),
273                description: cx.new(|cx| {
274                    Markdown::new(
275                        description.into(),
276                        Some(language_registry.clone()),
277                        None,
278                        cx,
279                    )
280                }),
281                respond_tx,
282            }),
283            cx,
284        );
285
286        ToolCallId(entry_id)
287    }
288
289    pub fn authorize_tool_call(&mut self, id: ToolCallId, allowed: bool, cx: &mut Context<Self>) {
290        let Some(entry) = self.entry_mut(id.0) else {
291            return;
292        };
293
294        let AgentThreadEntryContent::ToolCall(call) = &mut entry.content else {
295            debug_panic!("expected ToolCall");
296            return;
297        };
298
299        let new_state = if allowed {
300            ToolCall::Allowed
301        } else {
302            ToolCall::Rejected
303        };
304
305        let call = mem::replace(call, new_state);
306
307        if let ToolCall::WaitingForConfirmation { respond_tx, .. } = call {
308            respond_tx.send(allowed).log_err();
309        } else {
310            debug_panic!("tried to authorize an already authorized tool call");
311        }
312
313        cx.emit(AcpThreadEvent::EntryUpdated(id.0.0 as usize));
314    }
315
316    fn entry_mut(&mut self, id: ThreadEntryId) -> Option<&mut ThreadEntry> {
317        let entry = self.entries.get_mut(id.0 as usize);
318        debug_assert!(
319            entry.is_some(),
320            "We shouldn't give out ids to entries that don't exist"
321        );
322        entry
323    }
324
325    pub fn send(&mut self, message: &str, cx: &mut Context<Self>) -> Task<Result<()>> {
326        let agent = self.server.clone();
327        let id = self.id.clone();
328        let chunk = MessageChunk::from_str(message, self.project.read(cx).languages().clone(), cx);
329        let message = Message {
330            role: Role::User,
331            chunks: vec![chunk],
332        };
333        self.push_entry(AgentThreadEntryContent::Message(message.clone()), cx);
334        let acp_message = message.into_acp(cx);
335        cx.spawn(async move |_, cx| {
336            agent.send_message(id, acp_message, cx).await?;
337            Ok(())
338        })
339    }
340}
341
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{FutureExt as _, channel::mpsc, select};
    use gpui::{AsyncApp, TestAppContext};
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use smol::stream::StreamExt;
    use std::{env, path::Path, process::Stdio, time::Duration};
    use util::path;

    /// Installs the globals the tests need: a test settings store, project settings, and
    /// languages.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    /// Sends one user message to a live Gemini CLI agent (see `gemini_acp_server`). Checks that
    /// the thread ends up with the user message followed by an assistant message.
    #[gpui::test]
    async fn test_gemini_basic(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        thread
            .update(cx, |thread, cx| thread.send("Hello from Zed!", cx))
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.entries.len(), 2);
            assert!(matches!(
                thread.entries[0].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::User,
                    ..
                })
            ));
            assert!(matches!(
                thread.entries[1].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::Assistant,
                    ..
                })
            ));
        });
    }

    /// Asks the live agent to read a file. The test waits for the resulting `read_file` tool
    /// call, allows it, and then checks the final sequence of entries: user message, resolved
    /// tool call, assistant reply.
    #[gpui::test]
    async fn test_gemini_tool_call(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/private/tmp"),
            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
        )
        .await;
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        let full_turn = thread.update(cx, |thread, cx| {
            thread.send(
                "Read the '/private/tmp/foo' file and tell me what you see.",
                cx,
            )
        });

        run_until_tool_call(&thread, cx).await;

        let tool_call_id = thread.read_with(cx, |thread, cx| {
            let AgentThreadEntryContent::ToolCall(ToolCall::WaitingForConfirmation {
                id,
                tool_name,
                description,
                ..
            }) = &thread.entries().last().unwrap().content
            else {
                panic!();
            };

            tool_name.read_with(cx, |md, _cx| {
                assert_eq!(md.source(), "read_file");
            });

            description.read_with(cx, |md, _cx| {
                assert!(
                    md.source().contains("foo"),
                    "Expected description to contain 'foo', but got {}",
                    md.source()
                );
            });
            *id
        });

        thread.update(cx, |thread, cx| {
            thread.authorize_tool_call(tool_call_id, true, cx);
            assert!(matches!(
                thread.entries().last().unwrap().content,
                AgentThreadEntryContent::ToolCall(ToolCall::Allowed)
            ));
        });

        full_turn.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(thread.entries.len() >= 3, "{:?}", &thread.entries);
            assert!(matches!(
                thread.entries[0].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::User,
                    ..
                })
            ));
            assert!(matches!(
                thread.entries[1].content,
                AgentThreadEntryContent::ToolCall(ToolCall::Allowed)
            ));
            assert!(matches!(
                thread.entries[2].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::Assistant,
                    ..
                })
            ));
        });
    }

    /// Waits until `thread` contains any tool-call entry, and panics if none shows up within
    /// 5 seconds.
    async fn run_until_tool_call(thread: &Entity<AcpThread>, cx: &mut TestAppContext) {
        let (mut tx, mut rx) = mpsc::channel(1);

        let subscription = cx.update(|cx| {
            cx.subscribe(thread, move |thread, _, cx| {
                if thread
                    .read(cx)
                    .entries
                    .iter()
                    .any(|e| matches!(e.content, AgentThreadEntryContent::ToolCall(_)))
                {
                    // Every event after the first tool call also passes this check, so the
                    // channel may already be full. One signal is enough, so don't panic inside
                    // the event handler when the send fails.
                    tx.try_send(()).ok();
                }
            })
        });

        select! {
            _ = cx.executor().timer(Duration::from_secs(5)).fuse() => {
                panic!("Timeout waiting for tool call")
            }
            _ = rx.next().fuse() => {
                drop(subscription);
            }
        }
    }

    /// Spawns the Gemini CLI in `--acp` mode (model `gemini-2.5-flash`) from a `gemini-cli`
    /// checkout at `../../../gemini-cli/packages/cli` relative to this crate. It connects to the
    /// CLI over stdio. `GEMINI_API_KEY` is forwarded when set.
    ///
    /// Panics if `node` cannot be spawned.
    pub fn gemini_acp_server(project: Entity<Project>, mut cx: AsyncApp) -> Result<Arc<AcpServer>> {
        let cli_path =
            Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini-cli/packages/cli");
        let mut command = util::command::new_smol_command("node");
        command
            .arg(cli_path)
            .arg("--acp")
            .args(["--model", "gemini-2.5-flash"])
            .current_dir("/private/tmp")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .kill_on_drop(true);

        if let Ok(gemini_key) = env::var("GEMINI_API_KEY") {
            command.env("GEMINI_API_KEY", gemini_key);
        }

        let child = command.spawn().unwrap();

        Ok(AcpServer::stdio(child, project, &mut cx))
    }
}