acp.rs

   1mod server;
   2mod thread_view;
   3
   4use agentic_coding_protocol::{self as acp};
   5use anyhow::{Context as _, Result};
   6use buffer_diff::BufferDiff;
   7use chrono::{DateTime, Utc};
   8use editor::{MultiBuffer, PathKey};
   9use futures::channel::oneshot;
  10use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
  11use language::{Anchor, Buffer, Capability, LanguageRegistry, OffsetRangeExt as _};
  12use markdown::Markdown;
  13use project::Project;
  14use std::{mem, ops::Range, path::PathBuf, sync::Arc};
  15use ui::{App, IconName};
  16use util::{ResultExt, debug_panic};
  17
  18pub use server::AcpServer;
  19pub use thread_view::AcpThreadView;
  20
  21#[derive(Debug, Clone, PartialEq, Eq, Hash)]
  22pub struct ThreadId(SharedString);
  23
  24#[derive(Copy, Clone, Debug, PartialEq, Eq)]
  25pub struct FileVersion(u64);
  26
  27#[derive(Debug)]
  28pub struct AgentThreadSummary {
  29    pub id: ThreadId,
  30    pub title: String,
  31    pub created_at: DateTime<Utc>,
  32}
  33
  34#[derive(Clone, Debug, PartialEq, Eq)]
  35pub struct FileContent {
  36    pub path: PathBuf,
  37    pub version: FileVersion,
  38    pub content: SharedString,
  39}
  40
  41#[derive(Clone, Debug, Eq, PartialEq)]
  42pub struct UserMessage {
  43    pub chunks: Vec<UserMessageChunk>,
  44}
  45
  46impl UserMessage {
  47    fn into_acp(self, cx: &App) -> acp::UserMessage {
  48        acp::UserMessage {
  49            chunks: self
  50                .chunks
  51                .into_iter()
  52                .map(|chunk| chunk.into_acp(cx))
  53                .collect(),
  54        }
  55    }
  56}
  57
  58#[derive(Clone, Debug, Eq, PartialEq)]
  59pub enum UserMessageChunk {
  60    Text {
  61        chunk: Entity<Markdown>,
  62    },
  63    File {
  64        content: FileContent,
  65    },
  66    Directory {
  67        path: PathBuf,
  68        contents: Vec<FileContent>,
  69    },
  70    Symbol {
  71        path: PathBuf,
  72        range: Range<u64>,
  73        version: FileVersion,
  74        name: SharedString,
  75        content: SharedString,
  76    },
  77    Fetch {
  78        url: SharedString,
  79        content: SharedString,
  80    },
  81}
  82
  83impl UserMessageChunk {
  84    pub fn into_acp(self, cx: &App) -> acp::UserMessageChunk {
  85        match self {
  86            Self::Text { chunk } => acp::UserMessageChunk::Text {
  87                chunk: chunk.read(cx).source().to_string(),
  88            },
  89            Self::File { .. } => todo!(),
  90            Self::Directory { .. } => todo!(),
  91            Self::Symbol { .. } => todo!(),
  92            Self::Fetch { .. } => todo!(),
  93        }
  94    }
  95
  96    pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
  97        Self::Text {
  98            chunk: cx.new(|cx| {
  99                Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
 100            }),
 101        }
 102    }
 103}
 104
 105#[derive(Clone, Debug, Eq, PartialEq)]
 106pub struct AssistantMessage {
 107    pub chunks: Vec<AssistantMessageChunk>,
 108}
 109
 110#[derive(Clone, Debug, Eq, PartialEq)]
 111pub enum AssistantMessageChunk {
 112    Text { chunk: Entity<Markdown> },
 113    Thought { chunk: Entity<Markdown> },
 114}
 115
 116impl AssistantMessageChunk {
 117    pub fn from_acp(
 118        chunk: acp::AssistantMessageChunk,
 119        language_registry: Arc<LanguageRegistry>,
 120        cx: &mut App,
 121    ) -> Self {
 122        match chunk {
 123            acp::AssistantMessageChunk::Text { chunk } => Self::Text {
 124                chunk: cx.new(|cx| Markdown::new(chunk.into(), Some(language_registry), None, cx)),
 125            },
 126            acp::AssistantMessageChunk::Thought { chunk } => Self::Thought {
 127                chunk: cx.new(|cx| Markdown::new(chunk.into(), Some(language_registry), None, cx)),
 128            },
 129        }
 130    }
 131
 132    pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
 133        Self::Text {
 134            chunk: cx.new(|cx| {
 135                Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
 136            }),
 137        }
 138    }
 139}
 140
 141#[derive(Debug)]
 142pub enum AgentThreadEntryContent {
 143    UserMessage(UserMessage),
 144    AssistantMessage(AssistantMessage),
 145    ToolCall(ToolCall),
 146}
 147
 148#[derive(Debug)]
 149pub struct ToolCall {
 150    id: ToolCallId,
 151    label: Entity<Markdown>,
 152    icon: IconName,
 153    content: Option<ToolCallContent>,
 154    status: ToolCallStatus,
 155}
 156
 157#[derive(Debug)]
 158pub enum ToolCallStatus {
 159    WaitingForConfirmation {
 160        confirmation: ToolCallConfirmation,
 161        respond_tx: oneshot::Sender<acp::ToolCallConfirmationOutcome>,
 162    },
 163    Allowed {
 164        status: acp::ToolCallStatus,
 165    },
 166    Rejected,
 167    Canceled,
 168}
 169
 170#[derive(Debug)]
 171pub enum ToolCallConfirmation {
 172    Edit {
 173        description: Option<Entity<Markdown>>,
 174    },
 175    Execute {
 176        command: String,
 177        root_command: String,
 178        description: Option<Entity<Markdown>>,
 179    },
 180    Mcp {
 181        server_name: String,
 182        tool_name: String,
 183        tool_display_name: String,
 184        description: Option<Entity<Markdown>>,
 185    },
 186    Fetch {
 187        urls: Vec<String>,
 188        description: Option<Entity<Markdown>>,
 189    },
 190    Other {
 191        description: Entity<Markdown>,
 192    },
 193}
 194
 195impl ToolCallConfirmation {
 196    pub fn from_acp(
 197        confirmation: acp::ToolCallConfirmation,
 198        language_registry: Arc<LanguageRegistry>,
 199        cx: &mut App,
 200    ) -> Self {
 201        let to_md = |description: String, cx: &mut App| -> Entity<Markdown> {
 202            cx.new(|cx| {
 203                Markdown::new(
 204                    description.into(),
 205                    Some(language_registry.clone()),
 206                    None,
 207                    cx,
 208                )
 209            })
 210        };
 211
 212        match confirmation {
 213            acp::ToolCallConfirmation::Edit { description } => Self::Edit {
 214                description: description.map(|description| to_md(description, cx)),
 215            },
 216            acp::ToolCallConfirmation::Execute {
 217                command,
 218                root_command,
 219                description,
 220            } => Self::Execute {
 221                command,
 222                root_command,
 223                description: description.map(|description| to_md(description, cx)),
 224            },
 225            acp::ToolCallConfirmation::Mcp {
 226                server_name,
 227                tool_name,
 228                tool_display_name,
 229                description,
 230            } => Self::Mcp {
 231                server_name,
 232                tool_name,
 233                tool_display_name,
 234                description: description.map(|description| to_md(description, cx)),
 235            },
 236            acp::ToolCallConfirmation::Fetch { urls, description } => Self::Fetch {
 237                urls,
 238                description: description.map(|description| to_md(description, cx)),
 239            },
 240            acp::ToolCallConfirmation::Other { description } => Self::Other {
 241                description: to_md(description, cx),
 242            },
 243        }
 244    }
 245}
 246
 247#[derive(Debug)]
 248pub enum ToolCallContent {
 249    Markdown { markdown: Entity<Markdown> },
 250    Diff { diff: Diff },
 251}
 252
 253impl ToolCallContent {
 254    pub fn from_acp(
 255        content: acp::ToolCallContent,
 256        language_registry: Arc<LanguageRegistry>,
 257        cx: &mut App,
 258    ) -> Self {
 259        match content {
 260            acp::ToolCallContent::Markdown { markdown } => Self::Markdown {
 261                markdown: cx.new(|cx| Markdown::new_text(markdown.into(), cx)),
 262            },
 263            acp::ToolCallContent::Diff { diff } => Self::Diff {
 264                diff: Diff::from_acp(diff, language_registry, cx),
 265            },
 266        }
 267    }
 268}
 269
 270#[derive(Debug)]
 271pub struct Diff {
 272    multibuffer: Entity<MultiBuffer>,
 273    path: PathBuf,
 274    _task: Task<Result<()>>,
 275}
 276
 277impl Diff {
 278    pub fn from_acp(
 279        diff: acp::Diff,
 280        language_registry: Arc<LanguageRegistry>,
 281        cx: &mut App,
 282    ) -> Self {
 283        let acp::Diff {
 284            path,
 285            old_text,
 286            new_text,
 287        } = diff;
 288
 289        let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
 290
 291        let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
 292        let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
 293        let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
 294        let old_buffer_snapshot = old_buffer.read(cx).snapshot();
 295        let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
 296        let diff_task = buffer_diff.update(cx, |diff, cx| {
 297            diff.set_base_text(
 298                old_buffer_snapshot,
 299                Some(language_registry.clone()),
 300                new_buffer_snapshot,
 301                cx,
 302            )
 303        });
 304
 305        let task = cx.spawn({
 306            let multibuffer = multibuffer.clone();
 307            let path = path.clone();
 308            async move |cx| {
 309                diff_task.await?;
 310
 311                multibuffer
 312                    .update(cx, |multibuffer, cx| {
 313                        let hunk_ranges = {
 314                            let buffer = new_buffer.read(cx);
 315                            let diff = buffer_diff.read(cx);
 316                            diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
 317                                .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
 318                                .collect::<Vec<_>>()
 319                        };
 320
 321                        multibuffer.set_excerpts_for_path(
 322                            PathKey::for_buffer(&new_buffer, cx),
 323                            new_buffer.clone(),
 324                            hunk_ranges,
 325                            editor::DEFAULT_MULTIBUFFER_CONTEXT,
 326                            cx,
 327                        );
 328                        multibuffer.add_diff(buffer_diff.clone(), cx);
 329                    })
 330                    .log_err();
 331
 332                if let Some(language) = language_registry
 333                    .language_for_file_path(&path)
 334                    .await
 335                    .log_err()
 336                {
 337                    new_buffer.update(cx, |buffer, cx| buffer.set_language(Some(language), cx))?;
 338                }
 339
 340                anyhow::Ok(())
 341            }
 342        });
 343
 344        Self {
 345            multibuffer,
 346            path,
 347            _task: task,
 348        }
 349    }
 350}
 351
 352/// A `ThreadEntryId` that is known to be a ToolCall
 353#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
 354pub struct ToolCallId(ThreadEntryId);
 355
 356impl ToolCallId {
 357    pub fn as_u64(&self) -> u64 {
 358        self.0.0
 359    }
 360}
 361
 362#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
 363pub struct ThreadEntryId(pub u64);
 364
 365impl ThreadEntryId {
 366    pub fn post_inc(&mut self) -> Self {
 367        let id = *self;
 368        self.0 += 1;
 369        id
 370    }
 371}
 372
 373#[derive(Debug)]
 374pub struct ThreadEntry {
 375    pub id: ThreadEntryId,
 376    pub content: AgentThreadEntryContent,
 377}
 378
 379pub struct AcpThread {
 380    id: ThreadId,
 381    next_entry_id: ThreadEntryId,
 382    entries: Vec<ThreadEntry>,
 383    server: Arc<AcpServer>,
 384    title: SharedString,
 385    project: Entity<Project>,
 386    send_task: Option<Task<()>>,
 387}
 388
 389enum AcpThreadEvent {
 390    NewEntry,
 391    EntryUpdated(usize),
 392}
 393
 394#[derive(PartialEq, Eq)]
 395pub enum ThreadStatus {
 396    Idle,
 397    WaitingForToolConfirmation,
 398    Generating,
 399}
 400
 401impl EventEmitter<AcpThreadEvent> for AcpThread {}
 402
 403impl AcpThread {
 404    pub fn new(
 405        server: Arc<AcpServer>,
 406        thread_id: ThreadId,
 407        entries: Vec<AgentThreadEntryContent>,
 408        project: Entity<Project>,
 409        _: &mut Context<Self>,
 410    ) -> Self {
 411        let mut next_entry_id = ThreadEntryId(0);
 412        Self {
 413            title: "ACP Thread".into(),
 414            entries: entries
 415                .into_iter()
 416                .map(|entry| ThreadEntry {
 417                    id: next_entry_id.post_inc(),
 418                    content: entry,
 419                })
 420                .collect(),
 421            server,
 422            id: thread_id,
 423            next_entry_id,
 424            project,
 425            send_task: None,
 426        }
 427    }
 428
 429    pub fn title(&self) -> SharedString {
 430        self.title.clone()
 431    }
 432
 433    pub fn entries(&self) -> &[ThreadEntry] {
 434        &self.entries
 435    }
 436
 437    pub fn status(&self) -> ThreadStatus {
 438        if self.send_task.is_some() {
 439            if self.waiting_for_tool_confirmation() {
 440                ThreadStatus::WaitingForToolConfirmation
 441            } else {
 442                ThreadStatus::Generating
 443            }
 444        } else {
 445            ThreadStatus::Idle
 446        }
 447    }
 448
 449    pub fn push_entry(
 450        &mut self,
 451        entry: AgentThreadEntryContent,
 452        cx: &mut Context<Self>,
 453    ) -> ThreadEntryId {
 454        let id = self.next_entry_id.post_inc();
 455        self.entries.push(ThreadEntry { id, content: entry });
 456        cx.emit(AcpThreadEvent::NewEntry);
 457        id
 458    }
 459
 460    pub fn push_assistant_chunk(
 461        &mut self,
 462        chunk: acp::AssistantMessageChunk,
 463        cx: &mut Context<Self>,
 464    ) {
 465        let entries_len = self.entries.len();
 466        if let Some(last_entry) = self.entries.last_mut()
 467            && let AgentThreadEntryContent::AssistantMessage(AssistantMessage { ref mut chunks }) =
 468                last_entry.content
 469        {
 470            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
 471
 472            match (chunks.last_mut(), &chunk) {
 473                (
 474                    Some(AssistantMessageChunk::Text { chunk: old_chunk }),
 475                    acp::AssistantMessageChunk::Text { chunk: new_chunk },
 476                )
 477                | (
 478                    Some(AssistantMessageChunk::Thought { chunk: old_chunk }),
 479                    acp::AssistantMessageChunk::Thought { chunk: new_chunk },
 480                ) => {
 481                    old_chunk.update(cx, |old_chunk, cx| {
 482                        old_chunk.append(&new_chunk, cx);
 483                    });
 484                }
 485                _ => {
 486                    chunks.push(AssistantMessageChunk::from_acp(
 487                        chunk,
 488                        self.project.read(cx).languages().clone(),
 489                        cx,
 490                    ));
 491                }
 492            }
 493        } else {
 494            let chunk = AssistantMessageChunk::from_acp(
 495                chunk,
 496                self.project.read(cx).languages().clone(),
 497                cx,
 498            );
 499
 500            self.push_entry(
 501                AgentThreadEntryContent::AssistantMessage(AssistantMessage {
 502                    chunks: vec![chunk],
 503                }),
 504                cx,
 505            );
 506        }
 507    }
 508
 509    pub fn request_tool_call(
 510        &mut self,
 511        label: String,
 512        icon: acp::Icon,
 513        content: Option<acp::ToolCallContent>,
 514        confirmation: acp::ToolCallConfirmation,
 515        cx: &mut Context<Self>,
 516    ) -> ToolCallRequest {
 517        let (tx, rx) = oneshot::channel();
 518
 519        let status = ToolCallStatus::WaitingForConfirmation {
 520            confirmation: ToolCallConfirmation::from_acp(
 521                confirmation,
 522                self.project.read(cx).languages().clone(),
 523                cx,
 524            ),
 525            respond_tx: tx,
 526        };
 527
 528        let id = self.insert_tool_call(label, status, icon, content, cx);
 529        ToolCallRequest { id, outcome: rx }
 530    }
 531
 532    pub fn push_tool_call(
 533        &mut self,
 534        label: String,
 535        icon: acp::Icon,
 536        content: Option<acp::ToolCallContent>,
 537        cx: &mut Context<Self>,
 538    ) -> ToolCallId {
 539        let status = ToolCallStatus::Allowed {
 540            status: acp::ToolCallStatus::Running,
 541        };
 542
 543        self.insert_tool_call(label, status, icon, content, cx)
 544    }
 545
 546    fn insert_tool_call(
 547        &mut self,
 548        label: String,
 549        status: ToolCallStatus,
 550        icon: acp::Icon,
 551        content: Option<acp::ToolCallContent>,
 552        cx: &mut Context<Self>,
 553    ) -> ToolCallId {
 554        let language_registry = self.project.read(cx).languages().clone();
 555
 556        let entry_id = self.push_entry(
 557            AgentThreadEntryContent::ToolCall(ToolCall {
 558                // todo! clean up id creation
 559                id: ToolCallId(ThreadEntryId(self.entries.len() as u64)),
 560                label: cx.new(|cx| {
 561                    Markdown::new(label.into(), Some(language_registry.clone()), None, cx)
 562                }),
 563                icon: acp_icon_to_ui_icon(icon),
 564                content: content
 565                    .map(|content| ToolCallContent::from_acp(content, language_registry, cx)),
 566                status,
 567            }),
 568            cx,
 569        );
 570
 571        ToolCallId(entry_id)
 572    }
 573
 574    pub fn authorize_tool_call(
 575        &mut self,
 576        id: ToolCallId,
 577        outcome: acp::ToolCallConfirmationOutcome,
 578        cx: &mut Context<Self>,
 579    ) {
 580        let Some(entry) = self.entry_mut(id.0) else {
 581            return;
 582        };
 583
 584        let AgentThreadEntryContent::ToolCall(call) = &mut entry.content else {
 585            debug_panic!("expected ToolCall");
 586            return;
 587        };
 588
 589        let new_status = if outcome == acp::ToolCallConfirmationOutcome::Reject {
 590            ToolCallStatus::Rejected
 591        } else {
 592            ToolCallStatus::Allowed {
 593                status: acp::ToolCallStatus::Running,
 594            }
 595        };
 596
 597        let curr_status = mem::replace(&mut call.status, new_status);
 598
 599        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
 600            respond_tx.send(outcome).log_err();
 601        } else {
 602            debug_panic!("tried to authorize an already authorized tool call");
 603        }
 604
 605        cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
 606    }
 607
 608    pub fn update_tool_call(
 609        &mut self,
 610        id: ToolCallId,
 611        new_status: acp::ToolCallStatus,
 612        new_content: Option<acp::ToolCallContent>,
 613        cx: &mut Context<Self>,
 614    ) -> Result<()> {
 615        let language_registry = self.project.read(cx).languages().clone();
 616        let entry = self.entry_mut(id.0).context("Entry not found")?;
 617
 618        match &mut entry.content {
 619            AgentThreadEntryContent::ToolCall(call) => {
 620                call.content = new_content.map(|new_content| {
 621                    ToolCallContent::from_acp(new_content, language_registry, cx)
 622                });
 623
 624                match &mut call.status {
 625                    ToolCallStatus::Allowed { status } => {
 626                        *status = new_status;
 627                    }
 628                    ToolCallStatus::WaitingForConfirmation { .. } => {
 629                        anyhow::bail!("Tool call hasn't been authorized yet")
 630                    }
 631                    ToolCallStatus::Rejected => {
 632                        anyhow::bail!("Tool call was rejected and therefore can't be updated")
 633                    }
 634                    ToolCallStatus::Canceled => {
 635                        // todo! test this case with fake server
 636                        call.status = ToolCallStatus::Allowed { status: new_status };
 637                    }
 638                }
 639            }
 640            _ => anyhow::bail!("Entry is not a tool call"),
 641        }
 642
 643        cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
 644        Ok(())
 645    }
 646
 647    fn entry_mut(&mut self, id: ThreadEntryId) -> Option<&mut ThreadEntry> {
 648        let entry = self.entries.get_mut(id.0 as usize);
 649        debug_assert!(
 650            entry.is_some(),
 651            "We shouldn't give out ids to entries that don't exist"
 652        );
 653        entry
 654    }
 655
 656    /// Returns true if the last turn is awaiting tool authorization
 657    pub fn waiting_for_tool_confirmation(&self) -> bool {
 658        // todo!("should we use a hashmap?")
 659        for entry in self.entries.iter().rev() {
 660            match &entry.content {
 661                AgentThreadEntryContent::ToolCall(call) => match call.status {
 662                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
 663                    ToolCallStatus::Allowed { .. }
 664                    | ToolCallStatus::Rejected
 665                    | ToolCallStatus::Canceled => continue,
 666                },
 667                AgentThreadEntryContent::UserMessage(_)
 668                | AgentThreadEntryContent::AssistantMessage(_) => {
 669                    // Reached the beginning of the turn
 670                    return false;
 671                }
 672            }
 673        }
 674        false
 675    }
 676
 677    pub fn send(
 678        &mut self,
 679        message: &str,
 680        cx: &mut Context<Self>,
 681    ) -> impl use<> + Future<Output = Result<()>> {
 682        let agent = self.server.clone();
 683        let id = self.id.clone();
 684        let chunk =
 685            UserMessageChunk::from_str(message, self.project.read(cx).languages().clone(), cx);
 686        let message = UserMessage {
 687            chunks: vec![chunk],
 688        };
 689        self.push_entry(AgentThreadEntryContent::UserMessage(message.clone()), cx);
 690        let acp_message = message.into_acp(cx);
 691
 692        let (tx, rx) = oneshot::channel();
 693        let cancel = self.cancel(cx);
 694
 695        self.send_task = Some(cx.spawn(async move |this, cx| {
 696            cancel.await.log_err();
 697
 698            let result = agent.send_message(id, acp_message, cx).await;
 699            tx.send(result).log_err();
 700            this.update(cx, |this, _cx| this.send_task.take()).log_err();
 701        }));
 702
 703        async move {
 704            match rx.await {
 705                Ok(result) => result,
 706                Err(_) => Ok(()),
 707            }
 708        }
 709    }
 710
 711    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 712        let agent = self.server.clone();
 713        let id = self.id.clone();
 714
 715        if self.send_task.take().is_some() {
 716            cx.spawn(async move |this, cx| {
 717                agent.cancel_send_message(id, cx).await?;
 718
 719                this.update(cx, |this, _cx| {
 720                    for entry in this.entries.iter_mut() {
 721                        if let AgentThreadEntryContent::ToolCall(call) = &mut entry.content {
 722                            let cancel = matches!(
 723                                call.status,
 724                                ToolCallStatus::WaitingForConfirmation { .. }
 725                                    | ToolCallStatus::Allowed {
 726                                        status: acp::ToolCallStatus::Running
 727                                    }
 728                            );
 729
 730                            if cancel {
 731                                let curr_status =
 732                                    mem::replace(&mut call.status, ToolCallStatus::Canceled);
 733
 734                                if let ToolCallStatus::WaitingForConfirmation {
 735                                    respond_tx, ..
 736                                } = curr_status
 737                                {
 738                                    respond_tx
 739                                        .send(acp::ToolCallConfirmationOutcome::Cancel)
 740                                        .ok();
 741                                }
 742                            }
 743                        }
 744                    }
 745                })
 746            })
 747        } else {
 748            Task::ready(Ok(()))
 749        }
 750    }
 751}
 752
 753fn acp_icon_to_ui_icon(icon: acp::Icon) -> IconName {
 754    match icon {
 755        acp::Icon::FileSearch => IconName::FileSearch,
 756        acp::Icon::Folder => IconName::Folder,
 757        acp::Icon::Globe => IconName::Globe,
 758        acp::Icon::Hammer => IconName::Hammer,
 759        acp::Icon::LightBulb => IconName::LightBulb,
 760        acp::Icon::Pencil => IconName::Pencil,
 761        acp::Icon::Regex => IconName::Regex,
 762        acp::Icon::Terminal => IconName::Terminal,
 763    }
 764}
 765
 766pub struct ToolCallRequest {
 767    pub id: ToolCallId,
 768    pub outcome: oneshot::Receiver<acp::ToolCallConfirmationOutcome>,
 769}
 770
 771#[cfg(test)]
 772mod tests {
 773    use super::*;
 774    use async_pipe::{PipeReader, PipeWriter};
 775    use async_trait::async_trait;
 776    use futures::{FutureExt as _, channel::mpsc, future::LocalBoxFuture, select};
 777    use gpui::{AsyncApp, TestAppContext};
 778    use project::FakeFs;
 779    use serde_json::json;
 780    use settings::SettingsStore;
 781    use smol::stream::StreamExt as _;
 782    use std::{env, path::Path, process::Stdio, rc::Rc, time::Duration};
 783    use util::path;
 784
 785    fn init_test(cx: &mut TestAppContext) {
 786        env_logger::try_init().ok();
 787        cx.update(|cx| {
 788            let settings_store = SettingsStore::test(cx);
 789            cx.set_global(settings_store);
 790            Project::init_settings(cx);
 791            language::init(cx);
 792        });
 793    }
 794
 795    #[gpui::test]
 796    async fn test_message_receipt(cx: &mut TestAppContext) {
 797        init_test(cx);
 798
 799        cx.executor().allow_parking();
 800
 801        let fs = FakeFs::new(cx.executor());
 802        let project = Project::test(fs, [], cx).await;
 803        let (server, fake_server) = fake_acp_server(project, cx);
 804
 805        server.initialize().await.unwrap();
 806
 807        fake_server.update(cx, |fake_server, _| {
 808            fake_server.on_user_message(move |params, server, mut cx| async move {
 809                server
 810                    .update(&mut cx, |server, cx| {
 811                        server.send_to_zed(
 812                            acp::StreamAssistantMessageChunkParams {
 813                                thread_id: params.thread_id.clone(),
 814                                chunk: acp::AssistantMessageChunk::Thought {
 815                                    chunk: "Thinking ".into(),
 816                                },
 817                            },
 818                            cx,
 819                        )
 820                    })?
 821                    .await
 822                    .unwrap();
 823                server
 824                    .update(&mut cx, |server, cx| {
 825                        server.send_to_zed(
 826                            acp::StreamAssistantMessageChunkParams {
 827                                thread_id: params.thread_id,
 828                                chunk: acp::AssistantMessageChunk::Thought {
 829                                    chunk: "hard!".into(),
 830                                },
 831                            },
 832                            cx,
 833                        )
 834                    })?
 835                    .await
 836                    .unwrap();
 837
 838                Ok(acp::SendUserMessageResponse)
 839            })
 840        })
 841    }
 842
 843    #[gpui::test]
 844    async fn test_gemini_basic(cx: &mut TestAppContext) {
 845        init_test(cx);
 846
 847        cx.executor().allow_parking();
 848
 849        let fs = FakeFs::new(cx.executor());
 850        let project = Project::test(fs, [], cx).await;
 851        let server = gemini_acp_server(project.clone(), cx).await;
 852        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
 853        thread
 854            .update(cx, |thread, cx| thread.send("Hello from Zed!", cx))
 855            .await
 856            .unwrap();
 857
 858        thread.read_with(cx, |thread, _| {
 859            assert_eq!(thread.entries.len(), 2);
 860            assert!(matches!(
 861                thread.entries[0].content,
 862                AgentThreadEntryContent::UserMessage(_)
 863            ));
 864            assert!(matches!(
 865                thread.entries[1].content,
 866                AgentThreadEntryContent::AssistantMessage(_)
 867            ));
 868        });
 869    }
 870
 871    #[gpui::test]
 872    async fn test_gemini_tool_call(cx: &mut TestAppContext) {
 873        init_test(cx);
 874
 875        cx.executor().allow_parking();
 876
 877        let fs = FakeFs::new(cx.executor());
 878        fs.insert_tree(
 879            path!("/private/tmp"),
 880            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
 881        )
 882        .await;
 883        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
 884        let server = gemini_acp_server(project.clone(), cx).await;
 885        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
 886        thread
 887            .update(cx, |thread, cx| {
 888                thread.send(
 889                    "Read the '/private/tmp/foo' file and tell me what you see.",
 890                    cx,
 891                )
 892            })
 893            .await
 894            .unwrap();
 895        thread.read_with(cx, |thread, _cx| {
 896            assert!(matches!(
 897                &thread.entries()[2].content,
 898                AgentThreadEntryContent::ToolCall(ToolCall {
 899                    status: ToolCallStatus::Allowed { .. },
 900                    ..
 901                })
 902            ));
 903
 904            assert!(matches!(
 905                thread.entries[3].content,
 906                AgentThreadEntryContent::AssistantMessage(_)
 907            ));
 908        });
 909    }
 910
 911    #[gpui::test]
 912    async fn test_gemini_tool_call_with_confirmation(cx: &mut TestAppContext) {
 913        init_test(cx);
 914
 915        cx.executor().allow_parking();
 916
 917        let fs = FakeFs::new(cx.executor());
 918        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
 919        let server = gemini_acp_server(project.clone(), cx).await;
 920        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
 921        let full_turn = thread.update(cx, |thread, cx| {
 922            thread.send(r#"Run `echo "Hello, world!"`"#, cx)
 923        });
 924
 925        run_until_first_tool_call(&thread, cx).await;
 926
 927        let tool_call_id = thread.read_with(cx, |thread, _cx| {
 928            let AgentThreadEntryContent::ToolCall(ToolCall {
 929                id,
 930                status:
 931                    ToolCallStatus::WaitingForConfirmation {
 932                        confirmation: ToolCallConfirmation::Execute { root_command, .. },
 933                        ..
 934                    },
 935                ..
 936            }) = &thread.entries()[2].content
 937            else {
 938                panic!();
 939            };
 940
 941            assert_eq!(root_command, "echo");
 942
 943            *id
 944        });
 945
 946        thread.update(cx, |thread, cx| {
 947            thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);
 948
 949            assert!(matches!(
 950                &thread.entries()[2].content,
 951                AgentThreadEntryContent::ToolCall(ToolCall {
 952                    status: ToolCallStatus::Allowed { .. },
 953                    ..
 954                })
 955            ));
 956        });
 957
 958        full_turn.await.unwrap();
 959
 960        thread.read_with(cx, |thread, cx| {
 961            let AgentThreadEntryContent::ToolCall(ToolCall {
 962                content: Some(ToolCallContent::Markdown { markdown }),
 963                status: ToolCallStatus::Allowed { .. },
 964                ..
 965            }) = &thread.entries()[2].content
 966            else {
 967                panic!();
 968            };
 969
 970            markdown.read_with(cx, |md, _cx| {
 971                assert!(
 972                    md.source().contains("Hello, world!"),
 973                    r#"Expected '{}' to contain "Hello, world!""#,
 974                    md.source()
 975                );
 976            });
 977        });
 978    }
 979
 980    #[gpui::test]
 981    async fn test_gemini_cancel(cx: &mut TestAppContext) {
 982        init_test(cx);
 983
 984        cx.executor().allow_parking();
 985
 986        let fs = FakeFs::new(cx.executor());
 987        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
 988        let server = gemini_acp_server(project.clone(), cx).await;
 989        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
 990        let full_turn = thread.update(cx, |thread, cx| {
 991            thread.send(r#"Run `echo "Hello, world!"`"#, cx)
 992        });
 993
 994        let first_tool_call_ix = run_until_first_tool_call(&thread, cx).await;
 995
 996        thread.read_with(cx, |thread, _cx| {
 997            let AgentThreadEntryContent::ToolCall(ToolCall {
 998                id,
 999                status:
1000                    ToolCallStatus::WaitingForConfirmation {
1001                        confirmation: ToolCallConfirmation::Execute { root_command, .. },
1002                        ..
1003                    },
1004                ..
1005            }) = &thread.entries()[first_tool_call_ix].content
1006            else {
1007                panic!("{:?}", thread.entries()[1].content);
1008            };
1009
1010            assert_eq!(root_command, "echo");
1011
1012            *id
1013        });
1014
1015        thread
1016            .update(cx, |thread, cx| thread.cancel(cx))
1017            .await
1018            .unwrap();
1019        full_turn.await.unwrap();
1020        thread.read_with(cx, |thread, _| {
1021            let AgentThreadEntryContent::ToolCall(ToolCall {
1022                status: ToolCallStatus::Canceled,
1023                ..
1024            }) = &thread.entries()[first_tool_call_ix].content
1025            else {
1026                panic!();
1027            };
1028        });
1029
1030        thread
1031            .update(cx, |thread, cx| {
1032                thread.send(r#"Stop running and say goodbye to me."#, cx)
1033            })
1034            .await
1035            .unwrap();
1036        thread.read_with(cx, |thread, _| {
1037            assert!(matches!(
1038                &thread.entries().last().unwrap().content,
1039                AgentThreadEntryContent::AssistantMessage(..),
1040            ))
1041        });
1042    }
1043
1044    async fn run_until_first_tool_call(
1045        thread: &Entity<AcpThread>,
1046        cx: &mut TestAppContext,
1047    ) -> usize {
1048        let (mut tx, mut rx) = mpsc::channel::<usize>(1);
1049
1050        let subscription = cx.update(|cx| {
1051            cx.subscribe(thread, move |thread, _, cx| {
1052                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
1053                    if matches!(entry.content, AgentThreadEntryContent::ToolCall(_)) {
1054                        return tx.try_send(ix).unwrap();
1055                    }
1056                }
1057            })
1058        });
1059
1060        select! {
1061            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
1062                panic!("Timeout waiting for tool call")
1063            }
1064            ix = rx.next().fuse() => {
1065                drop(subscription);
1066                ix.unwrap()
1067            }
1068        }
1069    }
1070
1071    pub async fn gemini_acp_server(
1072        project: Entity<Project>,
1073        cx: &mut TestAppContext,
1074    ) -> Arc<AcpServer> {
1075        let cli_path =
1076            Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini-cli/packages/cli");
1077        let mut command = util::command::new_smol_command("node");
1078        command
1079            .arg(cli_path)
1080            .arg("--acp")
1081            .current_dir("/private/tmp")
1082            .stdin(Stdio::piped())
1083            .stdout(Stdio::piped())
1084            .stderr(Stdio::inherit())
1085            .kill_on_drop(true);
1086
1087        if let Ok(gemini_key) = std::env::var("GEMINI_API_KEY") {
1088            command.env("GEMINI_API_KEY", gemini_key);
1089        }
1090
1091        let child = command.spawn().unwrap();
1092        let server = cx.update(|cx| AcpServer::stdio(child, project, cx));
1093        server.initialize().await.unwrap();
1094        server
1095    }
1096
1097    pub fn fake_acp_server(
1098        project: Entity<Project>,
1099        cx: &mut TestAppContext,
1100    ) -> (Arc<AcpServer>, Entity<FakeAcpServer>) {
1101        let (stdin_tx, stdin_rx) = async_pipe::pipe();
1102        let (stdout_tx, stdout_rx) = async_pipe::pipe();
1103        let server = cx.update(|cx| AcpServer::fake(stdin_tx, stdout_rx, project, cx));
1104        let agent = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx)));
1105        (server, agent)
1106    }
1107
1108    pub struct FakeAcpServer {
1109        connection: acp::ClientConnection,
1110        _handler_task: Task<()>,
1111        _io_task: Task<()>,
1112        on_user_message: Option<
1113            Rc<
1114                dyn Fn(
1115                    acp::SendUserMessageParams,
1116                    Entity<FakeAcpServer>,
1117                    AsyncApp,
1118                )
1119                    -> LocalBoxFuture<'static, Result<acp::SendUserMessageResponse>>,
1120            >,
1121        >,
1122    }
1123
1124    #[derive(Clone)]
1125    struct FakeAgent {
1126        server: Entity<FakeAcpServer>,
1127        cx: AsyncApp,
1128    }
1129
1130    #[async_trait(?Send)]
1131    impl acp::Agent for FakeAgent {
1132        async fn initialize(
1133            &self,
1134            _request: acp::InitializeParams,
1135        ) -> Result<acp::InitializeResponse> {
1136            Ok(acp::InitializeResponse {
1137                is_authenticated: true,
1138            })
1139        }
1140
1141        async fn authenticate(
1142            &self,
1143            _request: acp::AuthenticateParams,
1144        ) -> Result<acp::AuthenticateResponse> {
1145            Ok(acp::AuthenticateResponse)
1146        }
1147
1148        async fn create_thread(
1149            &self,
1150            _request: acp::CreateThreadParams,
1151        ) -> Result<acp::CreateThreadResponse> {
1152            Ok(acp::CreateThreadResponse {
1153                thread_id: acp::ThreadId("test-thread".into()),
1154            })
1155        }
1156
1157        async fn send_user_message(
1158            &self,
1159            request: acp::SendUserMessageParams,
1160        ) -> Result<acp::SendUserMessageResponse> {
1161            let mut cx = self.cx.clone();
1162            let handler = self
1163                .server
1164                .update(&mut cx, |server, _| server.on_user_message.clone())
1165                .ok()
1166                .flatten();
1167            if let Some(handler) = handler {
1168                handler(request, self.server.clone(), self.cx.clone()).await
1169            } else {
1170                anyhow::bail!("No handler for on_user_message")
1171            }
1172        }
1173    }
1174
1175    impl FakeAcpServer {
1176        fn new(stdin: PipeReader, stdout: PipeWriter, cx: &Context<Self>) -> Self {
1177            let agent = FakeAgent {
1178                server: cx.entity(),
1179                cx: cx.to_async(),
1180            };
1181
1182            let (connection, handler_fut, io_fut) =
1183                acp::ClientConnection::connect_to_client(agent.clone(), stdout, stdin);
1184            FakeAcpServer {
1185                connection: connection,
1186                on_user_message: None,
1187                _handler_task: cx.foreground_executor().spawn(handler_fut),
1188                _io_task: cx.background_spawn(async move {
1189                    io_fut.await.log_err();
1190                }),
1191            }
1192        }
1193
1194        fn on_user_message<F>(
1195            &mut self,
1196            handler: impl for<'a> Fn(acp::SendUserMessageParams, Entity<FakeAcpServer>, AsyncApp) -> F
1197            + 'static,
1198        ) where
1199            F: Future<Output = Result<acp::SendUserMessageResponse>> + 'static,
1200        {
1201            self.on_user_message
1202                .replace(Rc::new(move |request, server, cx| {
1203                    handler(request, server, cx).boxed_local()
1204                }));
1205        }
1206
1207        fn send_to_zed<T: acp::ClientRequest>(
1208            &self,
1209            message: T,
1210            cx: &Context<Self>,
1211        ) -> Task<Result<T::Response, acp::Error>> {
1212            let future = self.connection.request(message);
1213            cx.foreground_executor().spawn(future)
1214        }
1215    }
1216}