acp.rs

  1mod server;
  2mod thread_view;
  3
  4use agentic_coding_protocol::{self as acp, Role};
  5use anyhow::{Context as _, Result};
  6use chrono::{DateTime, Utc};
  7use futures::channel::oneshot;
  8use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
  9use language::LanguageRegistry;
 10use markdown::Markdown;
 11use project::Project;
 12use std::{mem, ops::Range, path::PathBuf, sync::Arc};
 13use ui::App;
 14use util::{ResultExt, debug_panic};
 15
 16pub use server::AcpServer;
 17pub use thread_view::AcpThreadView;
 18
 19#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 20pub struct ThreadId(SharedString);
 21
 22#[derive(Copy, Clone, Debug, PartialEq, Eq)]
 23pub struct FileVersion(u64);
 24
 25#[derive(Debug)]
 26pub struct AgentThreadSummary {
 27    pub id: ThreadId,
 28    pub title: String,
 29    pub created_at: DateTime<Utc>,
 30}
 31
 32#[derive(Clone, Debug, PartialEq, Eq)]
 33pub struct FileContent {
 34    pub path: PathBuf,
 35    pub version: FileVersion,
 36    pub content: SharedString,
 37}
 38
 39#[derive(Clone, Debug, Eq, PartialEq)]
 40pub struct Message {
 41    pub role: acp::Role,
 42    pub chunks: Vec<MessageChunk>,
 43}
 44
 45impl Message {
 46    fn into_acp(self, cx: &App) -> acp::Message {
 47        acp::Message {
 48            role: self.role,
 49            chunks: self
 50                .chunks
 51                .into_iter()
 52                .map(|chunk| chunk.into_acp(cx))
 53                .collect(),
 54        }
 55    }
 56}
 57
 58#[derive(Clone, Debug, Eq, PartialEq)]
 59pub enum MessageChunk {
 60    Text {
 61        chunk: Entity<Markdown>,
 62    },
 63    File {
 64        content: FileContent,
 65    },
 66    Directory {
 67        path: PathBuf,
 68        contents: Vec<FileContent>,
 69    },
 70    Symbol {
 71        path: PathBuf,
 72        range: Range<u64>,
 73        version: FileVersion,
 74        name: SharedString,
 75        content: SharedString,
 76    },
 77    Fetch {
 78        url: SharedString,
 79        content: SharedString,
 80    },
 81}
 82
 83impl MessageChunk {
 84    pub fn from_acp(
 85        chunk: acp::MessageChunk,
 86        language_registry: Arc<LanguageRegistry>,
 87        cx: &mut App,
 88    ) -> Self {
 89        match chunk {
 90            acp::MessageChunk::Text { chunk } => MessageChunk::Text {
 91                chunk: cx.new(|cx| Markdown::new(chunk.into(), Some(language_registry), None, cx)),
 92            },
 93        }
 94    }
 95
 96    pub fn into_acp(self, cx: &App) -> acp::MessageChunk {
 97        match self {
 98            MessageChunk::Text { chunk } => acp::MessageChunk::Text {
 99                chunk: chunk.read(cx).source().to_string(),
100            },
101            MessageChunk::File { .. } => todo!(),
102            MessageChunk::Directory { .. } => todo!(),
103            MessageChunk::Symbol { .. } => todo!(),
104            MessageChunk::Fetch { .. } => todo!(),
105        }
106    }
107
108    pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
109        MessageChunk::Text {
110            chunk: cx.new(|cx| {
111                Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
112            }),
113        }
114    }
115}
116
117#[derive(Debug)]
118pub enum AgentThreadEntryContent {
119    Message(Message),
120    ToolCall(ToolCall),
121}
122
123#[derive(Debug)]
124pub struct ToolCall {
125    id: ToolCallId,
126    tool_name: Entity<Markdown>,
127    status: ToolCallStatus,
128}
129
130#[derive(Debug)]
131pub enum ToolCallStatus {
132    WaitingForConfirmation {
133        confirmation: acp::ToolCallConfirmation,
134        respond_tx: oneshot::Sender<acp::ToolCallConfirmationOutcome>,
135    },
136    // todo! Running?
137    Allowed {
138        // todo! should this be variants in crate::ToolCallStatus instead?
139        status: acp::ToolCallStatus,
140        content: Option<Entity<Markdown>>,
141    },
142    Rejected,
143}
144
145/// A `ThreadEntryId` that is known to be a ToolCall
146#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
147pub struct ToolCallId(ThreadEntryId);
148
149impl ToolCallId {
150    pub fn as_u64(&self) -> u64 {
151        self.0.0
152    }
153}
154
/// Identifier of an entry in an [`AcpThread`].
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct ThreadEntryId(pub u64);

impl ThreadEntryId {
    /// Post-increment: returns the current id and advances `self` to the next one.
    pub fn post_inc(&mut self) -> Self {
        let next = Self(self.0 + 1);
        std::mem::replace(self, next)
    }
}
165
166#[derive(Debug)]
167pub struct ThreadEntry {
168    pub id: ThreadEntryId,
169    pub content: AgentThreadEntryContent,
170}
171
172pub struct AcpThread {
173    id: ThreadId,
174    next_entry_id: ThreadEntryId,
175    entries: Vec<ThreadEntry>,
176    server: Arc<AcpServer>,
177    title: SharedString,
178    project: Entity<Project>,
179}
180
181enum AcpThreadEvent {
182    NewEntry,
183    EntryUpdated(usize),
184}
185
186impl EventEmitter<AcpThreadEvent> for AcpThread {}
187
188impl AcpThread {
189    pub fn new(
190        server: Arc<AcpServer>,
191        thread_id: ThreadId,
192        entries: Vec<AgentThreadEntryContent>,
193        project: Entity<Project>,
194        _: &mut Context<Self>,
195    ) -> Self {
196        let mut next_entry_id = ThreadEntryId(0);
197        Self {
198            title: "A new agent2 thread".into(),
199            entries: entries
200                .into_iter()
201                .map(|entry| ThreadEntry {
202                    id: next_entry_id.post_inc(),
203                    content: entry,
204                })
205                .collect(),
206            server,
207            id: thread_id,
208            next_entry_id,
209            project,
210        }
211    }
212
213    pub fn title(&self) -> SharedString {
214        self.title.clone()
215    }
216
217    pub fn entries(&self) -> &[ThreadEntry] {
218        &self.entries
219    }
220
221    pub fn push_entry(
222        &mut self,
223        entry: AgentThreadEntryContent,
224        cx: &mut Context<Self>,
225    ) -> ThreadEntryId {
226        let id = self.next_entry_id.post_inc();
227        self.entries.push(ThreadEntry { id, content: entry });
228        cx.emit(AcpThreadEvent::NewEntry);
229        id
230    }
231
232    pub fn push_assistant_chunk(&mut self, chunk: acp::MessageChunk, cx: &mut Context<Self>) {
233        let entries_len = self.entries.len();
234        if let Some(last_entry) = self.entries.last_mut()
235            && let AgentThreadEntryContent::Message(Message {
236                ref mut chunks,
237                role: Role::Assistant,
238            }) = last_entry.content
239        {
240            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
241
242            if let (
243                Some(MessageChunk::Text { chunk: old_chunk }),
244                acp::MessageChunk::Text { chunk: new_chunk },
245            ) = (chunks.last_mut(), &chunk)
246            {
247                old_chunk.update(cx, |old_chunk, cx| {
248                    old_chunk.append(&new_chunk, cx);
249                });
250            } else {
251                chunks.push(MessageChunk::from_acp(
252                    chunk,
253                    self.project.read(cx).languages().clone(),
254                    cx,
255                ));
256            }
257
258            return;
259        }
260
261        let chunk = MessageChunk::from_acp(chunk, self.project.read(cx).languages().clone(), cx);
262
263        self.push_entry(
264            AgentThreadEntryContent::Message(Message {
265                role: Role::Assistant,
266                chunks: vec![chunk],
267            }),
268            cx,
269        );
270    }
271
272    pub fn request_tool_call(
273        &mut self,
274        title: String,
275        confirmation: acp::ToolCallConfirmation,
276        cx: &mut Context<Self>,
277    ) -> ToolCallRequest {
278        let (tx, rx) = oneshot::channel();
279
280        let status = ToolCallStatus::WaitingForConfirmation {
281            confirmation,
282            respond_tx: tx,
283        };
284
285        let id = self.insert_tool_call(title, status, cx);
286        ToolCallRequest { id, outcome: rx }
287    }
288
289    pub fn push_tool_call(&mut self, title: String, cx: &mut Context<Self>) -> ToolCallId {
290        let status = ToolCallStatus::Allowed {
291            status: acp::ToolCallStatus::Running,
292            content: None,
293        };
294
295        self.insert_tool_call(title, status, cx)
296    }
297
298    fn insert_tool_call(
299        &mut self,
300        title: String,
301        status: ToolCallStatus,
302        cx: &mut Context<Self>,
303    ) -> ToolCallId {
304        let language_registry = self.project.read(cx).languages().clone();
305
306        let entry_id = self.push_entry(
307            AgentThreadEntryContent::ToolCall(ToolCall {
308                // todo! clean up id creation
309                id: ToolCallId(ThreadEntryId(self.entries.len() as u64)),
310                tool_name: cx.new(|cx| {
311                    Markdown::new(title.into(), Some(language_registry.clone()), None, cx)
312                }),
313                status,
314            }),
315            cx,
316        );
317
318        ToolCallId(entry_id)
319    }
320
321    pub fn authorize_tool_call(
322        &mut self,
323        id: ToolCallId,
324        outcome: acp::ToolCallConfirmationOutcome,
325        cx: &mut Context<Self>,
326    ) {
327        let Some(entry) = self.entry_mut(id.0) else {
328            return;
329        };
330
331        let AgentThreadEntryContent::ToolCall(call) = &mut entry.content else {
332            debug_panic!("expected ToolCall");
333            return;
334        };
335
336        let new_status = if outcome == acp::ToolCallConfirmationOutcome::Reject {
337            ToolCallStatus::Rejected
338        } else {
339            ToolCallStatus::Allowed {
340                status: acp::ToolCallStatus::Running,
341                content: None,
342            }
343        };
344
345        let curr_status = mem::replace(&mut call.status, new_status);
346
347        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
348            respond_tx.send(outcome).log_err();
349        } else {
350            debug_panic!("tried to authorize an already authorized tool call");
351        }
352
353        cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
354    }
355
356    pub fn update_tool_call(
357        &mut self,
358        id: ToolCallId,
359        new_status: acp::ToolCallStatus,
360        new_content: Option<acp::ToolCallContent>,
361        cx: &mut Context<Self>,
362    ) -> Result<()> {
363        let language_registry = self.project.read(cx).languages().clone();
364        let entry = self.entry_mut(id.0).context("Entry not found")?;
365
366        match &mut entry.content {
367            AgentThreadEntryContent::ToolCall(call) => match &mut call.status {
368                ToolCallStatus::Allowed { content, status } => {
369                    *content = new_content.map(|new_content| {
370                        let acp::ToolCallContent::Markdown { markdown } = new_content;
371
372                        cx.new(|cx| {
373                            Markdown::new(markdown.into(), Some(language_registry), None, cx)
374                        })
375                    });
376
377                    *status = new_status;
378                }
379                ToolCallStatus::WaitingForConfirmation { .. } => {
380                    anyhow::bail!("Tool call hasn't been authorized yet")
381                }
382                ToolCallStatus::Rejected => {
383                    anyhow::bail!("Tool call was rejected and therefore can't be updated")
384                }
385            },
386            _ => anyhow::bail!("Entry is not a tool call"),
387        }
388
389        cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
390        Ok(())
391    }
392
393    fn entry_mut(&mut self, id: ThreadEntryId) -> Option<&mut ThreadEntry> {
394        let entry = self.entries.get_mut(id.0 as usize);
395        debug_assert!(
396            entry.is_some(),
397            "We shouldn't give out ids to entries that don't exist"
398        );
399        entry
400    }
401
402    /// Returns true if the last turn is awaiting tool authorization
403    pub fn waiting_for_tool_confirmation(&self) -> bool {
404        for entry in self.entries.iter().rev() {
405            match &entry.content {
406                AgentThreadEntryContent::ToolCall(call) => match call.status {
407                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
408                    ToolCallStatus::Allowed { .. } | ToolCallStatus::Rejected => continue,
409                },
410                AgentThreadEntryContent::Message(_) => {
411                    // Reached the beginning of the turn
412                    return false;
413                }
414            }
415        }
416        false
417    }
418
419    pub fn send(&mut self, message: &str, cx: &mut Context<Self>) -> Task<Result<()>> {
420        let agent = self.server.clone();
421        let id = self.id.clone();
422        let chunk = MessageChunk::from_str(message, self.project.read(cx).languages().clone(), cx);
423        let message = Message {
424            role: Role::User,
425            chunks: vec![chunk],
426        };
427        self.push_entry(AgentThreadEntryContent::Message(message.clone()), cx);
428        let acp_message = message.into_acp(cx);
429        cx.spawn(async move |_, cx| {
430            agent.send_message(id, acp_message, cx).await?;
431            Ok(())
432        })
433    }
434}
435
436pub struct ToolCallRequest {
437    pub id: ToolCallId,
438    pub outcome: oneshot::Receiver<acp::ToolCallConfirmationOutcome>,
439}
440
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{FutureExt as _, channel::mpsc, select};
    use gpui::{AsyncApp, TestAppContext};
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use smol::stream::StreamExt;
    use std::{env, path::Path, process::Stdio, time::Duration};
    use util::path;

    /// Installs the global settings and language support these tests rely on.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    /// Sends a plain message to a real Gemini CLI process and expects one user
    /// message followed by one assistant message.
    #[gpui::test]
    async fn test_gemini_basic(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        thread
            .update(cx, |thread, cx| thread.send("Hello from Zed!", cx))
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.entries.len(), 2);
            assert!(matches!(
                thread.entries[0].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::User,
                    ..
                })
            ));
            assert!(matches!(
                thread.entries[1].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::Assistant,
                    ..
                })
            ));
        });
    }

    /// Asks Gemini to read a file and checks that the resulting `read_file` tool call
    /// ends up allowed and is followed by an assistant reply.
    #[gpui::test]
    async fn test_gemini_tool_call(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/private/tmp"),
            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
        )
        .await;
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let server = gemini_acp_server(project.clone(), cx.to_async()).unwrap();
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        let full_turn = thread.update(cx, |thread, cx| {
            thread.send(
                "Read the '/private/tmp/foo' file and tell me what you see.",
                cx,
            )
        });

        run_until_tool_call(&thread, cx).await;

        let (tool_call_id, needs_confirmation) = thread.read_with(cx, |thread, cx| {
            let AgentThreadEntryContent::ToolCall(ToolCall {
                id,
                tool_name,
                status,
            }) = &thread.entries().last().unwrap().content
            else {
                panic!("expected the last entry to be a tool call");
            };

            tool_name.read_with(cx, |md, _cx| {
                assert_eq!(md.source(), "read_file");
            });

            // todo!
            // description.read_with(cx, |md, _cx| {
            //     assert!(
            //         md.source().contains("foo"),
            //         "Expected description to contain 'foo', but got {}",
            //         md.source()
            //     );
            // });
            (
                *id,
                matches!(status, ToolCallStatus::WaitingForConfirmation { .. }),
            )
        });

        // The server may run the tool straight away (`Allowed`) or ask for permission
        // first. `authorize_tool_call` treats authorizing a call that isn't waiting as
        // a bug (`debug_panic!`), so only authorize when confirmation was requested.
        thread.update(cx, |thread, cx| {
            if needs_confirmation {
                thread.authorize_tool_call(
                    tool_call_id,
                    acp::ToolCallConfirmationOutcome::Allow,
                    cx,
                );
            }
            assert!(matches!(
                thread.entries().last().unwrap().content,
                AgentThreadEntryContent::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed { .. },
                    ..
                })
            ));
        });

        full_turn.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(thread.entries.len() >= 3, "{:?}", &thread.entries);
            assert!(matches!(
                thread.entries[0].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::User,
                    ..
                })
            ));
            assert!(matches!(
                thread.entries[1].content,
                AgentThreadEntryContent::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed { .. },
                    ..
                })
            ));
            assert!(matches!(
                thread.entries[2].content,
                AgentThreadEntryContent::Message(Message {
                    role: Role::Assistant,
                    ..
                })
            ));
        });
    }

    /// Waits (up to 5 seconds) until `thread` contains at least one tool call.
    ///
    /// Panics on timeout.
    async fn run_until_tool_call(thread: &Entity<AcpThread>, cx: &mut TestAppContext) {
        let (mut tx, mut rx) = mpsc::channel(1);

        let subscription = cx.update(|cx| {
            cx.subscribe(thread, move |thread, _, cx| {
                if thread
                    .read(cx)
                    .entries
                    .iter()
                    .any(|e| matches!(e.content, AgentThreadEntryContent::ToolCall(_)))
                {
                    // Every later event also passes this check, so the bounded channel
                    // may already be full; a single notification is all we need.
                    tx.try_send(()).ok();
                }
            })
        });

        select! {
            _ = cx.executor().timer(Duration::from_secs(5)).fuse() => {
                panic!("Timeout waiting for tool call")
            }
            _ = rx.next().fuse() => {
                drop(subscription);
            }
        }
    }

    /// Spawns the Gemini CLI in ACP mode and connects an [`AcpServer`] to its stdio.
    ///
    /// The CLI is expected at `../../../gemini-cli/packages/cli` relative to this
    /// crate's manifest directory; `GEMINI_API_KEY` is forwarded when set. Returns an
    /// error if the `node` process cannot be spawned.
    pub fn gemini_acp_server(project: Entity<Project>, mut cx: AsyncApp) -> Result<Arc<AcpServer>> {
        let cli_path =
            Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini-cli/packages/cli");
        let mut command = util::command::new_smol_command("node");
        command
            .arg(cli_path)
            .arg("--acp")
            .args(["--model", "gemini-2.5-flash"])
            .current_dir("/private/tmp")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::inherit())
            .kill_on_drop(true);

        if let Ok(gemini_key) = env::var("GEMINI_API_KEY") {
            command.env("GEMINI_API_KEY", gemini_key);
        }

        let child = command.spawn()?;

        Ok(AcpServer::stdio(child, project, &mut cx))
    }
}