acp.rs

   1mod server;
   2mod thread_view;
   3
   4use agentic_coding_protocol::{self as acp};
   5use anyhow::{Context as _, Result};
   6use buffer_diff::BufferDiff;
   7use chrono::{DateTime, Utc};
   8use editor::{MultiBuffer, PathKey};
   9use futures::channel::oneshot;
  10use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
  11use language::{Anchor, Buffer, Capability, LanguageRegistry, OffsetRangeExt as _};
  12use markdown::Markdown;
  13use parking_lot::Mutex;
  14use parking_lot::Mutex;
  15use project::Project;
  16use std::{mem, ops::Range, path::PathBuf, process::ExitStatus, sync::Arc};
  17use ui::{App, IconName};
  18use util::{ResultExt, debug_panic};
  19
  20pub use server::AcpServer;
  21pub use thread_view::AcpThreadView;
  22
  23#[derive(Debug, Clone, PartialEq, Eq, Hash)]
  24pub struct ThreadId(SharedString);
  25
  26#[derive(Copy, Clone, Debug, PartialEq, Eq)]
  27pub struct FileVersion(u64);
  28
  29#[derive(Debug)]
  30pub struct AgentThreadSummary {
  31    pub id: ThreadId,
  32    pub title: String,
  33    pub created_at: DateTime<Utc>,
  34}
  35
  36#[derive(Clone, Debug, PartialEq, Eq)]
  37pub struct FileContent {
  38    pub path: PathBuf,
  39    pub version: FileVersion,
  40    pub content: SharedString,
  41}
  42
  43#[derive(Clone, Debug, Eq, PartialEq)]
  44pub struct UserMessage {
  45    pub chunks: Vec<UserMessageChunk>,
  46}
  47
  48impl UserMessage {
  49    fn into_acp(self, cx: &App) -> acp::UserMessage {
  50        acp::UserMessage {
  51            chunks: self
  52                .chunks
  53                .into_iter()
  54                .map(|chunk| chunk.into_acp(cx))
  55                .collect(),
  56        }
  57    }
  58}
  59
  60#[derive(Clone, Debug, Eq, PartialEq)]
  61pub enum UserMessageChunk {
  62    Text {
  63        chunk: Entity<Markdown>,
  64    },
  65    File {
  66        content: FileContent,
  67    },
  68    Directory {
  69        path: PathBuf,
  70        contents: Vec<FileContent>,
  71    },
  72    Symbol {
  73        path: PathBuf,
  74        range: Range<u64>,
  75        version: FileVersion,
  76        name: SharedString,
  77        content: SharedString,
  78    },
  79    Fetch {
  80        url: SharedString,
  81        content: SharedString,
  82    },
  83}
  84
  85impl UserMessageChunk {
  86    pub fn into_acp(self, cx: &App) -> acp::UserMessageChunk {
  87        match self {
  88            Self::Text { chunk } => acp::UserMessageChunk::Text {
  89                chunk: chunk.read(cx).source().to_string(),
  90            },
  91            Self::File { .. } => todo!(),
  92            Self::Directory { .. } => todo!(),
  93            Self::Symbol { .. } => todo!(),
  94            Self::Fetch { .. } => todo!(),
  95        }
  96    }
  97
  98    pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
  99        Self::Text {
 100            chunk: cx.new(|cx| {
 101                Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
 102            }),
 103        }
 104    }
 105}
 106
 107#[derive(Clone, Debug, Eq, PartialEq)]
 108pub struct AssistantMessage {
 109    pub chunks: Vec<AssistantMessageChunk>,
 110}
 111
 112#[derive(Clone, Debug, Eq, PartialEq)]
 113pub enum AssistantMessageChunk {
 114    Text { chunk: Entity<Markdown> },
 115    Thought { chunk: Entity<Markdown> },
 116}
 117
 118impl AssistantMessageChunk {
 119    pub fn from_acp(
 120        chunk: acp::AssistantMessageChunk,
 121        language_registry: Arc<LanguageRegistry>,
 122        cx: &mut App,
 123    ) -> Self {
 124        match chunk {
 125            acp::AssistantMessageChunk::Text { chunk } => Self::Text {
 126                chunk: cx.new(|cx| Markdown::new(chunk.into(), Some(language_registry), None, cx)),
 127            },
 128            acp::AssistantMessageChunk::Thought { chunk } => Self::Thought {
 129                chunk: cx.new(|cx| Markdown::new(chunk.into(), Some(language_registry), None, cx)),
 130            },
 131        }
 132    }
 133
 134    pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
 135        Self::Text {
 136            chunk: cx.new(|cx| {
 137                Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
 138            }),
 139        }
 140    }
 141}
 142
 143#[derive(Debug)]
 144pub enum AgentThreadEntryContent {
 145    UserMessage(UserMessage),
 146    AssistantMessage(AssistantMessage),
 147    ToolCall(ToolCall),
 148}
 149
 150#[derive(Debug)]
 151pub struct ToolCall {
 152    id: ToolCallId,
 153    label: Entity<Markdown>,
 154    icon: IconName,
 155    content: Option<ToolCallContent>,
 156    status: ToolCallStatus,
 157}
 158
 159#[derive(Debug)]
 160pub enum ToolCallStatus {
 161    WaitingForConfirmation {
 162        confirmation: ToolCallConfirmation,
 163        respond_tx: oneshot::Sender<acp::ToolCallConfirmationOutcome>,
 164    },
 165    Allowed {
 166        status: acp::ToolCallStatus,
 167    },
 168    Rejected,
 169    Canceled,
 170}
 171
 172#[derive(Debug)]
 173pub enum ToolCallConfirmation {
 174    Edit {
 175        description: Option<Entity<Markdown>>,
 176    },
 177    Execute {
 178        command: String,
 179        root_command: String,
 180        description: Option<Entity<Markdown>>,
 181    },
 182    Mcp {
 183        server_name: String,
 184        tool_name: String,
 185        tool_display_name: String,
 186        description: Option<Entity<Markdown>>,
 187    },
 188    Fetch {
 189        urls: Vec<String>,
 190        description: Option<Entity<Markdown>>,
 191    },
 192    Other {
 193        description: Entity<Markdown>,
 194    },
 195}
 196
 197impl ToolCallConfirmation {
 198    pub fn from_acp(
 199        confirmation: acp::ToolCallConfirmation,
 200        language_registry: Arc<LanguageRegistry>,
 201        cx: &mut App,
 202    ) -> Self {
 203        let to_md = |description: String, cx: &mut App| -> Entity<Markdown> {
 204            cx.new(|cx| {
 205                Markdown::new(
 206                    description.into(),
 207                    Some(language_registry.clone()),
 208                    None,
 209                    cx,
 210                )
 211            })
 212        };
 213
 214        match confirmation {
 215            acp::ToolCallConfirmation::Edit { description } => Self::Edit {
 216                description: description.map(|description| to_md(description, cx)),
 217            },
 218            acp::ToolCallConfirmation::Execute {
 219                command,
 220                root_command,
 221                description,
 222            } => Self::Execute {
 223                command,
 224                root_command,
 225                description: description.map(|description| to_md(description, cx)),
 226            },
 227            acp::ToolCallConfirmation::Mcp {
 228                server_name,
 229                tool_name,
 230                tool_display_name,
 231                description,
 232            } => Self::Mcp {
 233                server_name,
 234                tool_name,
 235                tool_display_name,
 236                description: description.map(|description| to_md(description, cx)),
 237            },
 238            acp::ToolCallConfirmation::Fetch { urls, description } => Self::Fetch {
 239                urls,
 240                description: description.map(|description| to_md(description, cx)),
 241            },
 242            acp::ToolCallConfirmation::Other { description } => Self::Other {
 243                description: to_md(description, cx),
 244            },
 245        }
 246    }
 247}
 248
 249#[derive(Debug)]
 250pub enum ToolCallContent {
 251    Markdown { markdown: Entity<Markdown> },
 252    Diff { diff: Diff },
 253}
 254
 255impl ToolCallContent {
 256    pub fn from_acp(
 257        content: acp::ToolCallContent,
 258        language_registry: Arc<LanguageRegistry>,
 259        cx: &mut App,
 260    ) -> Self {
 261        match content {
 262            acp::ToolCallContent::Markdown { markdown } => Self::Markdown {
 263                markdown: cx.new(|cx| Markdown::new_text(markdown.into(), cx)),
 264            },
 265            acp::ToolCallContent::Diff { diff } => Self::Diff {
 266                diff: Diff::from_acp(diff, language_registry, cx),
 267            },
 268        }
 269    }
 270}
 271
/// A read-only view of a file change proposed by the agent.
///
/// `multibuffer` shows excerpts of the new text around each changed hunk,
/// with a [`BufferDiff`] against the old text attached.
#[derive(Debug)]
pub struct Diff {
    multibuffer: Entity<MultiBuffer>,
    /// Path reported by the agent. Also used to pick a language for the new
    /// buffer.
    path: PathBuf,
    /// Background work that computes the diff and fills `multibuffer`. It is
    /// stored only to keep it alive as long as the `Diff` exists (gpui tasks
    /// are presumably canceled on drop — confirm).
    _task: Task<Result<()>>,
}

impl Diff {
    /// Builds a [`Diff`] from the protocol representation.
    ///
    /// The old and new texts are loaded into standalone local buffers. A
    /// missing `old_text` is diffed as an empty string (presumably a newly
    /// created file). The diff itself is computed asynchronously. Once it is
    /// ready, the multibuffer gets excerpts covering every hunk (padded by
    /// `editor::DEFAULT_MULTIBUFFER_CONTEXT`) plus the diff, and the new
    /// buffer is assigned a language detected from `path`.
    pub fn from_acp(
        diff: acp::Diff,
        language_registry: Arc<LanguageRegistry>,
        cx: &mut App,
    ) -> Self {
        let acp::Diff {
            path,
            old_text,
            new_text,
        } = diff;

        // Headerless and read-only: this view is for display, not editing.
        let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));

        let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
        // No old text: diff against an empty base.
        let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
        let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
        let old_buffer_snapshot = old_buffer.read(cx).snapshot();
        // The diff is keyed to the new buffer; the old text becomes its base.
        let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
        let diff_task = buffer_diff.update(cx, |diff, cx| {
            diff.set_base_text(
                old_buffer_snapshot,
                Some(language_registry.clone()),
                new_buffer_snapshot,
                cx,
            )
        });

        let task = cx.spawn({
            let multibuffer = multibuffer.clone();
            let path = path.clone();
            async move |cx| {
                // A failure to compute the diff aborts the task with that error.
                diff_task.await?;

                // Populate the multibuffer. Failures here are logged, not
                // propagated, so language detection below still runs.
                multibuffer
                    .update(cx, |multibuffer, cx| {
                        // Row ranges (in the new buffer) of every diff hunk.
                        let hunk_ranges = {
                            let buffer = new_buffer.read(cx);
                            let diff = buffer_diff.read(cx);
                            diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
                                .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
                                .collect::<Vec<_>>()
                        };

                        multibuffer.set_excerpts_for_path(
                            PathKey::for_buffer(&new_buffer, cx),
                            new_buffer.clone(),
                            hunk_ranges,
                            editor::DEFAULT_MULTIBUFFER_CONTEXT,
                            cx,
                        );
                        multibuffer.add_diff(buffer_diff.clone(), cx);
                    })
                    .log_err();

                // Language detection is best-effort: lookup errors are logged
                // and the buffer is left without a language.
                if let Some(language) = language_registry
                    .language_for_file_path(&path)
                    .await
                    .log_err()
                {
                    new_buffer.update(cx, |buffer, cx| buffer.set_language(Some(language), cx))?;
                }

                anyhow::Ok(())
            }
        });

        Self {
            multibuffer,
            path,
            _task: task,
        }
    }
}
 353
 354/// A `ThreadEntryId` that is known to be a ToolCall
 355#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
 356pub struct ToolCallId(ThreadEntryId);
 357
 358impl ToolCallId {
 359    pub fn as_u64(&self) -> u64 {
 360        self.0.0
 361    }
 362}
 363
 364#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
 365pub struct ThreadEntryId(pub u64);
 366
 367impl ThreadEntryId {
 368    pub fn post_inc(&mut self) -> Self {
 369        let id = *self;
 370        self.0 += 1;
 371        id
 372    }
 373}
 374
 375#[derive(Debug)]
 376pub struct ThreadEntry {
 377    pub id: ThreadEntryId,
 378    pub content: AgentThreadEntryContent,
 379}
 380
/// A conversation with an ACP agent: an ordered list of entries (user
/// messages, assistant messages and tool calls), plus the task that drives
/// the current send, if any.
pub struct AcpThread {
    /// Id that the next pushed entry will receive.
    next_entry_id: ThreadEntryId,
    /// Entries in insertion order. `entry_mut` uses an entry's id as its index
    /// into this vec, so ids and indices must stay in sync.
    entries: Vec<ThreadEntry>,
    server: Arc<AcpServer>,
    title: SharedString,
    project: Entity<Project>,
    /// In-flight send. `Some` while a turn is being generated or is waiting
    /// for tool confirmation.
    send_task: Option<Task<()>>,

    // NOTE(review): the fields below are never initialized by `AcpThread::new`,
    // whose struct literal omits them, so the constructor does not compile as
    // written. This presumably belongs to an in-progress move of the agent
    // connection into the thread — confirm.
    connection: Arc<acp::AgentConnection>,
    exit_status: Arc<Mutex<Option<ExitStatus>>>,
    _handler_task: Task<()>,
    _io_task: Task<()>,
}
 394
 395enum AcpThreadEvent {
 396    NewEntry,
 397    EntryUpdated(usize),
 398}
 399
 400#[derive(PartialEq, Eq)]
 401pub enum ThreadStatus {
 402    Idle,
 403    WaitingForToolConfirmation,
 404    Generating,
 405}
 406
 407impl EventEmitter<AcpThreadEvent> for AcpThread {}
 408
 409impl AcpThread {
 410    pub fn new(
 411        server: Arc<AcpServer>,
 412        entries: Vec<AgentThreadEntryContent>,
 413        project: Entity<Project>,
 414        _: &mut Context<Self>,
 415    ) -> Self {
 416        let mut next_entry_id = ThreadEntryId(0);
 417        Self {
 418            title: "ACP Thread".into(),
 419            entries: entries
 420                .into_iter()
 421                .map(|entry| ThreadEntry {
 422                    id: next_entry_id.post_inc(),
 423                    content: entry,
 424                })
 425                .collect(),
 426            server,
 427            next_entry_id,
 428            project,
 429            send_task: None,
 430        }
 431    }
 432
 433    pub fn title(&self) -> SharedString {
 434        self.title.clone()
 435    }
 436
 437    pub fn entries(&self) -> &[ThreadEntry] {
 438        &self.entries
 439    }
 440
 441    pub fn status(&self) -> ThreadStatus {
 442        if self.send_task.is_some() {
 443            if self.waiting_for_tool_confirmation() {
 444                ThreadStatus::WaitingForToolConfirmation
 445            } else {
 446                ThreadStatus::Generating
 447            }
 448        } else {
 449            ThreadStatus::Idle
 450        }
 451    }
 452
 453    pub fn push_entry(
 454        &mut self,
 455        entry: AgentThreadEntryContent,
 456        cx: &mut Context<Self>,
 457    ) -> ThreadEntryId {
 458        let id = self.next_entry_id.post_inc();
 459        self.entries.push(ThreadEntry { id, content: entry });
 460        cx.emit(AcpThreadEvent::NewEntry);
 461        id
 462    }
 463
 464    pub fn push_assistant_chunk(
 465        &mut self,
 466        chunk: acp::AssistantMessageChunk,
 467        cx: &mut Context<Self>,
 468    ) {
 469        let entries_len = self.entries.len();
 470        if let Some(last_entry) = self.entries.last_mut()
 471            && let AgentThreadEntryContent::AssistantMessage(AssistantMessage { ref mut chunks }) =
 472                last_entry.content
 473        {
 474            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
 475
 476            match (chunks.last_mut(), &chunk) {
 477                (
 478                    Some(AssistantMessageChunk::Text { chunk: old_chunk }),
 479                    acp::AssistantMessageChunk::Text { chunk: new_chunk },
 480                )
 481                | (
 482                    Some(AssistantMessageChunk::Thought { chunk: old_chunk }),
 483                    acp::AssistantMessageChunk::Thought { chunk: new_chunk },
 484                ) => {
 485                    old_chunk.update(cx, |old_chunk, cx| {
 486                        old_chunk.append(&new_chunk, cx);
 487                    });
 488                }
 489                _ => {
 490                    chunks.push(AssistantMessageChunk::from_acp(
 491                        chunk,
 492                        self.project.read(cx).languages().clone(),
 493                        cx,
 494                    ));
 495                }
 496            }
 497        } else {
 498            let chunk = AssistantMessageChunk::from_acp(
 499                chunk,
 500                self.project.read(cx).languages().clone(),
 501                cx,
 502            );
 503
 504            self.push_entry(
 505                AgentThreadEntryContent::AssistantMessage(AssistantMessage {
 506                    chunks: vec![chunk],
 507                }),
 508                cx,
 509            );
 510        }
 511    }
 512
 513    pub fn request_tool_call(
 514        &mut self,
 515        label: String,
 516        icon: acp::Icon,
 517        content: Option<acp::ToolCallContent>,
 518        confirmation: acp::ToolCallConfirmation,
 519        cx: &mut Context<Self>,
 520    ) -> ToolCallRequest {
 521        let (tx, rx) = oneshot::channel();
 522
 523        let status = ToolCallStatus::WaitingForConfirmation {
 524            confirmation: ToolCallConfirmation::from_acp(
 525                confirmation,
 526                self.project.read(cx).languages().clone(),
 527                cx,
 528            ),
 529            respond_tx: tx,
 530        };
 531
 532        let id = self.insert_tool_call(label, status, icon, content, cx);
 533        ToolCallRequest { id, outcome: rx }
 534    }
 535
 536    pub fn push_tool_call(
 537        &mut self,
 538        label: String,
 539        icon: acp::Icon,
 540        content: Option<acp::ToolCallContent>,
 541        cx: &mut Context<Self>,
 542    ) -> ToolCallId {
 543        let status = ToolCallStatus::Allowed {
 544            status: acp::ToolCallStatus::Running,
 545        };
 546
 547        self.insert_tool_call(label, status, icon, content, cx)
 548    }
 549
 550    fn insert_tool_call(
 551        &mut self,
 552        label: String,
 553        status: ToolCallStatus,
 554        icon: acp::Icon,
 555        content: Option<acp::ToolCallContent>,
 556        cx: &mut Context<Self>,
 557    ) -> ToolCallId {
 558        let language_registry = self.project.read(cx).languages().clone();
 559
 560        let entry_id = self.push_entry(
 561            AgentThreadEntryContent::ToolCall(ToolCall {
 562                // todo! clean up id creation
 563                id: ToolCallId(ThreadEntryId(self.entries.len() as u64)),
 564                label: cx.new(|cx| {
 565                    Markdown::new(label.into(), Some(language_registry.clone()), None, cx)
 566                }),
 567                icon: acp_icon_to_ui_icon(icon),
 568                content: content
 569                    .map(|content| ToolCallContent::from_acp(content, language_registry, cx)),
 570                status,
 571            }),
 572            cx,
 573        );
 574
 575        ToolCallId(entry_id)
 576    }
 577
 578    pub fn authorize_tool_call(
 579        &mut self,
 580        id: ToolCallId,
 581        outcome: acp::ToolCallConfirmationOutcome,
 582        cx: &mut Context<Self>,
 583    ) {
 584        let Some(entry) = self.entry_mut(id.0) else {
 585            return;
 586        };
 587
 588        let AgentThreadEntryContent::ToolCall(call) = &mut entry.content else {
 589            debug_panic!("expected ToolCall");
 590            return;
 591        };
 592
 593        let new_status = if outcome == acp::ToolCallConfirmationOutcome::Reject {
 594            ToolCallStatus::Rejected
 595        } else {
 596            ToolCallStatus::Allowed {
 597                status: acp::ToolCallStatus::Running,
 598            }
 599        };
 600
 601        let curr_status = mem::replace(&mut call.status, new_status);
 602
 603        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
 604            respond_tx.send(outcome).log_err();
 605        } else {
 606            debug_panic!("tried to authorize an already authorized tool call");
 607        }
 608
 609        cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
 610    }
 611
 612    pub fn update_tool_call(
 613        &mut self,
 614        id: ToolCallId,
 615        new_status: acp::ToolCallStatus,
 616        new_content: Option<acp::ToolCallContent>,
 617        cx: &mut Context<Self>,
 618    ) -> Result<()> {
 619        let language_registry = self.project.read(cx).languages().clone();
 620        let entry = self.entry_mut(id.0).context("Entry not found")?;
 621
 622        match &mut entry.content {
 623            AgentThreadEntryContent::ToolCall(call) => {
 624                call.content = new_content.map(|new_content| {
 625                    ToolCallContent::from_acp(new_content, language_registry, cx)
 626                });
 627
 628                match &mut call.status {
 629                    ToolCallStatus::Allowed { status } => {
 630                        *status = new_status;
 631                    }
 632                    ToolCallStatus::WaitingForConfirmation { .. } => {
 633                        anyhow::bail!("Tool call hasn't been authorized yet")
 634                    }
 635                    ToolCallStatus::Rejected => {
 636                        anyhow::bail!("Tool call was rejected and therefore can't be updated")
 637                    }
 638                    ToolCallStatus::Canceled => {
 639                        // todo! test this case with fake server
 640                        call.status = ToolCallStatus::Allowed { status: new_status };
 641                    }
 642                }
 643            }
 644            _ => anyhow::bail!("Entry is not a tool call"),
 645        }
 646
 647        cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
 648        Ok(())
 649    }
 650
 651    fn entry_mut(&mut self, id: ThreadEntryId) -> Option<&mut ThreadEntry> {
 652        let entry = self.entries.get_mut(id.0 as usize);
 653        debug_assert!(
 654            entry.is_some(),
 655            "We shouldn't give out ids to entries that don't exist"
 656        );
 657        entry
 658    }
 659
 660    /// Returns true if the last turn is awaiting tool authorization
 661    pub fn waiting_for_tool_confirmation(&self) -> bool {
 662        // todo!("should we use a hashmap?")
 663        for entry in self.entries.iter().rev() {
 664            match &entry.content {
 665                AgentThreadEntryContent::ToolCall(call) => match call.status {
 666                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
 667                    ToolCallStatus::Allowed { .. }
 668                    | ToolCallStatus::Rejected
 669                    | ToolCallStatus::Canceled => continue,
 670                },
 671                AgentThreadEntryContent::UserMessage(_)
 672                | AgentThreadEntryContent::AssistantMessage(_) => {
 673                    // Reached the beginning of the turn
 674                    return false;
 675                }
 676            }
 677        }
 678        false
 679    }
 680
 681    pub fn send(
 682        &mut self,
 683        message: &str,
 684        cx: &mut Context<Self>,
 685    ) -> impl use<> + Future<Output = Result<()>> {
 686        let agent = self.server.clone();
 687        let chunk =
 688            UserMessageChunk::from_str(message, self.project.read(cx).languages().clone(), cx);
 689        let message = UserMessage {
 690            chunks: vec![chunk],
 691        };
 692        self.push_entry(AgentThreadEntryContent::UserMessage(message.clone()), cx);
 693        let acp_message = message.into_acp(cx);
 694
 695        let (tx, rx) = oneshot::channel();
 696        let cancel = self.cancel(cx);
 697
 698        self.send_task = Some(cx.spawn(async move |this, cx| {
 699            cancel.await.log_err();
 700
 701            let result = agent.send_message(acp_message, cx).await;
 702            tx.send(result).log_err();
 703            this.update(cx, |this, _cx| this.send_task.take()).log_err();
 704        }));
 705
 706        async move {
 707            match rx.await {
 708                Ok(result) => result,
 709                Err(_) => Ok(()),
 710            }
 711        }
 712    }
 713
 714    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 715        let agent = self.server.clone();
 716
 717        if self.send_task.take().is_some() {
 718            cx.spawn(async move |this, cx| {
 719                agent.cancel_send_message(cx).await?;
 720
 721                this.update(cx, |this, _cx| {
 722                    for entry in this.entries.iter_mut() {
 723                        if let AgentThreadEntryContent::ToolCall(call) = &mut entry.content {
 724                            let cancel = matches!(
 725                                call.status,
 726                                ToolCallStatus::WaitingForConfirmation { .. }
 727                                    | ToolCallStatus::Allowed {
 728                                        status: acp::ToolCallStatus::Running
 729                                    }
 730                            );
 731
 732                            if cancel {
 733                                let curr_status =
 734                                    mem::replace(&mut call.status, ToolCallStatus::Canceled);
 735
 736                                if let ToolCallStatus::WaitingForConfirmation {
 737                                    respond_tx, ..
 738                                } = curr_status
 739                                {
 740                                    respond_tx
 741                                        .send(acp::ToolCallConfirmationOutcome::Cancel)
 742                                        .ok();
 743                                }
 744                            }
 745                        }
 746                    }
 747                })
 748            })
 749        } else {
 750            Task::ready(Ok(()))
 751        }
 752    }
 753
 754    #[cfg(test)]
 755    pub fn to_string(&self, cx: &App) -> String {
 756        let mut result = String::new();
 757        for entry in &self.entries {
 758            match &entry.content {
 759                AgentThreadEntryContent::UserMessage(user_message) => {
 760                    result.push_str("# User\n");
 761                    for chunk in &user_message.chunks {
 762                        match chunk {
 763                            UserMessageChunk::Text { chunk } => {
 764                                result.push_str(chunk.read(cx).source());
 765                                result.push('\n');
 766                            }
 767                            _ => unimplemented!(),
 768                        }
 769                    }
 770                }
 771                AgentThreadEntryContent::AssistantMessage(assistant_message) => {
 772                    result.push_str("# Assistant\n");
 773                    for chunk in &assistant_message.chunks {
 774                        match chunk {
 775                            AssistantMessageChunk::Text { chunk } => {
 776                                result.push_str(chunk.read(cx).source());
 777                                result.push('\n')
 778                            }
 779                            AssistantMessageChunk::Thought { chunk } => {
 780                                result.push_str("<thinking>\n");
 781                                result.push_str(chunk.read(cx).source());
 782                                result.push_str("\n</thinking>\n");
 783                            }
 784                        }
 785                    }
 786                }
 787                AgentThreadEntryContent::ToolCall(_tool_call) => unimplemented!(),
 788            }
 789        }
 790        result
 791    }
 792}
 793
 794fn acp_icon_to_ui_icon(icon: acp::Icon) -> IconName {
 795    match icon {
 796        acp::Icon::FileSearch => IconName::FileSearch,
 797        acp::Icon::Folder => IconName::Folder,
 798        acp::Icon::Globe => IconName::Globe,
 799        acp::Icon::Hammer => IconName::Hammer,
 800        acp::Icon::LightBulb => IconName::LightBulb,
 801        acp::Icon::Pencil => IconName::Pencil,
 802        acp::Icon::Regex => IconName::Regex,
 803        acp::Icon::Terminal => IconName::Terminal,
 804    }
 805}
 806
 807pub struct ToolCallRequest {
 808    pub id: ToolCallId,
 809    pub outcome: oneshot::Receiver<acp::ToolCallConfirmationOutcome>,
 810}
 811
 812#[cfg(test)]
 813mod tests {
 814    use super::*;
 815    use async_pipe::{PipeReader, PipeWriter};
 816    use async_trait::async_trait;
 817    use futures::{FutureExt as _, channel::mpsc, future::LocalBoxFuture, select};
 818    use gpui::{AsyncApp, TestAppContext};
 819    use indoc::indoc;
 820    use project::FakeFs;
 821    use serde_json::json;
 822    use settings::SettingsStore;
 823    use smol::{future::BoxedLocal, stream::StreamExt as _};
 824    use std::{env, path::Path, process::Stdio, rc::Rc, time::Duration};
 825    use util::path;
 826
 827    fn init_test(cx: &mut TestAppContext) {
 828        env_logger::try_init().ok();
 829        cx.update(|cx| {
 830            let settings_store = SettingsStore::test(cx);
 831            cx.set_global(settings_store);
 832            Project::init_settings(cx);
 833            language::init(cx);
 834        });
 835    }
 836
 837    #[gpui::test]
 838    async fn test_thinking_concatenation(cx: &mut TestAppContext) {
 839        init_test(cx);
 840
 841        cx.executor().allow_parking();
 842
 843        let fs = FakeFs::new(cx.executor());
 844        let project = Project::test(fs, [], cx).await;
 845        let (server, fake_server) = fake_acp_server(project, cx);
 846
 847        server.initialize().await.unwrap();
 848
 849        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
 850
 851        fake_server.update(cx, |fake_server, _| {
 852            fake_server.on_user_message(move |params, server, mut cx| async move {
 853                server
 854                    .update(&mut cx, |server, _| {
 855                        server.send_to_zed(acp::StreamAssistantMessageChunkParams {
 856                            chunk: acp::AssistantMessageChunk::Thought {
 857                                chunk: "Thinking ".into(),
 858                            },
 859                        })
 860                    })?
 861                    .await
 862                    .unwrap();
 863                server
 864                    .update(&mut cx, |server, _| {
 865                        server.send_to_zed(acp::StreamAssistantMessageChunkParams {
 866                            chunk: acp::AssistantMessageChunk::Thought {
 867                                chunk: "hard!".into(),
 868                            },
 869                        })
 870                    })?
 871                    .await
 872                    .unwrap();
 873
 874                Ok(acp::SendUserMessageResponse)
 875            })
 876        });
 877
 878        thread
 879            .update(cx, |thread, cx| thread.send("Hello from Zed!", cx))
 880            .await
 881            .unwrap();
 882
 883        let output = thread.read_with(cx, |thread, cx| thread.to_string(cx));
 884        assert_eq!(
 885            output,
 886            indoc! {r#"
 887            # User
 888            Hello from Zed!
 889            # Assistant
 890            <thinking>
 891            Thinking hard!
 892            </thinking>
 893            "#}
 894        );
 895    }
 896
 897    #[gpui::test]
 898    async fn test_gemini_basic(cx: &mut TestAppContext) {
 899        init_test(cx);
 900
 901        cx.executor().allow_parking();
 902
 903        let fs = FakeFs::new(cx.executor());
 904        let project = Project::test(fs, [], cx).await;
 905        let server = gemini_acp_server(project.clone(), cx).await;
 906        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
 907        thread
 908            .update(cx, |thread, cx| thread.send("Hello from Zed!", cx))
 909            .await
 910            .unwrap();
 911
 912        thread.read_with(cx, |thread, _| {
 913            assert_eq!(thread.entries.len(), 2);
 914            assert!(matches!(
 915                thread.entries[0].content,
 916                AgentThreadEntryContent::UserMessage(_)
 917            ));
 918            assert!(matches!(
 919                thread.entries[1].content,
 920                AgentThreadEntryContent::AssistantMessage(_)
 921            ));
 922        });
 923    }
 924
 925    #[gpui::test]
 926    async fn test_gemini_tool_call(cx: &mut TestAppContext) {
 927        init_test(cx);
 928
 929        cx.executor().allow_parking();
 930
 931        let fs = FakeFs::new(cx.executor());
 932        fs.insert_tree(
 933            path!("/private/tmp"),
 934            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
 935        )
 936        .await;
 937        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
 938        let server = gemini_acp_server(project.clone(), cx).await;
 939        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
 940        thread
 941            .update(cx, |thread, cx| {
 942                thread.send(
 943                    "Read the '/private/tmp/foo' file and tell me what you see.",
 944                    cx,
 945                )
 946            })
 947            .await
 948            .unwrap();
 949        thread.read_with(cx, |thread, _cx| {
 950            assert!(matches!(
 951                &thread.entries()[2].content,
 952                AgentThreadEntryContent::ToolCall(ToolCall {
 953                    status: ToolCallStatus::Allowed { .. },
 954                    ..
 955                })
 956            ));
 957
 958            assert!(matches!(
 959                thread.entries[3].content,
 960                AgentThreadEntryContent::AssistantMessage(_)
 961            ));
 962        });
 963    }
 964
 965    #[gpui::test]
 966    async fn test_gemini_tool_call_with_confirmation(cx: &mut TestAppContext) {
 967        init_test(cx);
 968
 969        cx.executor().allow_parking();
 970
 971        let fs = FakeFs::new(cx.executor());
 972        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
 973        let server = gemini_acp_server(project.clone(), cx).await;
 974        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
 975        let full_turn = thread.update(cx, |thread, cx| {
 976            thread.send(r#"Run `echo "Hello, world!"`"#, cx)
 977        });
 978
 979        run_until_first_tool_call(&thread, cx).await;
 980
 981        let tool_call_id = thread.read_with(cx, |thread, _cx| {
 982            let AgentThreadEntryContent::ToolCall(ToolCall {
 983                id,
 984                status:
 985                    ToolCallStatus::WaitingForConfirmation {
 986                        confirmation: ToolCallConfirmation::Execute { root_command, .. },
 987                        ..
 988                    },
 989                ..
 990            }) = &thread.entries()[2].content
 991            else {
 992                panic!();
 993            };
 994
 995            assert_eq!(root_command, "echo");
 996
 997            *id
 998        });
 999
1000        thread.update(cx, |thread, cx| {
1001            thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);
1002
1003            assert!(matches!(
1004                &thread.entries()[2].content,
1005                AgentThreadEntryContent::ToolCall(ToolCall {
1006                    status: ToolCallStatus::Allowed { .. },
1007                    ..
1008                })
1009            ));
1010        });
1011
1012        full_turn.await.unwrap();
1013
1014        thread.read_with(cx, |thread, cx| {
1015            let AgentThreadEntryContent::ToolCall(ToolCall {
1016                content: Some(ToolCallContent::Markdown { markdown }),
1017                status: ToolCallStatus::Allowed { .. },
1018                ..
1019            }) = &thread.entries()[2].content
1020            else {
1021                panic!();
1022            };
1023
1024            markdown.read_with(cx, |md, _cx| {
1025                assert!(
1026                    md.source().contains("Hello, world!"),
1027                    r#"Expected '{}' to contain "Hello, world!""#,
1028                    md.source()
1029                );
1030            });
1031        });
1032    }
1033
1034    #[gpui::test]
1035    async fn test_gemini_cancel(cx: &mut TestAppContext) {
1036        init_test(cx);
1037
1038        cx.executor().allow_parking();
1039
1040        let fs = FakeFs::new(cx.executor());
1041        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
1042        let server = gemini_acp_server(project.clone(), cx).await;
1043        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
1044        let full_turn = thread.update(cx, |thread, cx| {
1045            thread.send(r#"Run `echo "Hello, world!"`"#, cx)
1046        });
1047
1048        let first_tool_call_ix = run_until_first_tool_call(&thread, cx).await;
1049
1050        thread.read_with(cx, |thread, _cx| {
1051            let AgentThreadEntryContent::ToolCall(ToolCall {
1052                id,
1053                status:
1054                    ToolCallStatus::WaitingForConfirmation {
1055                        confirmation: ToolCallConfirmation::Execute { root_command, .. },
1056                        ..
1057                    },
1058                ..
1059            }) = &thread.entries()[first_tool_call_ix].content
1060            else {
1061                panic!("{:?}", thread.entries()[1].content);
1062            };
1063
1064            assert_eq!(root_command, "echo");
1065
1066            *id
1067        });
1068
1069        thread
1070            .update(cx, |thread, cx| thread.cancel(cx))
1071            .await
1072            .unwrap();
1073        full_turn.await.unwrap();
1074        thread.read_with(cx, |thread, _| {
1075            let AgentThreadEntryContent::ToolCall(ToolCall {
1076                status: ToolCallStatus::Canceled,
1077                ..
1078            }) = &thread.entries()[first_tool_call_ix].content
1079            else {
1080                panic!();
1081            };
1082        });
1083
1084        thread
1085            .update(cx, |thread, cx| {
1086                thread.send(r#"Stop running and say goodbye to me."#, cx)
1087            })
1088            .await
1089            .unwrap();
1090        thread.read_with(cx, |thread, _| {
1091            assert!(matches!(
1092                &thread.entries().last().unwrap().content,
1093                AgentThreadEntryContent::AssistantMessage(..),
1094            ))
1095        });
1096    }
1097
1098    async fn run_until_first_tool_call(
1099        thread: &Entity<AcpThread>,
1100        cx: &mut TestAppContext,
1101    ) -> usize {
1102        let (mut tx, mut rx) = mpsc::channel::<usize>(1);
1103
1104        let subscription = cx.update(|cx| {
1105            cx.subscribe(thread, move |thread, _, cx| {
1106                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
1107                    if matches!(entry.content, AgentThreadEntryContent::ToolCall(_)) {
1108                        return tx.try_send(ix).unwrap();
1109                    }
1110                }
1111            })
1112        });
1113
1114        select! {
1115            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
1116                panic!("Timeout waiting for tool call")
1117            }
1118            ix = rx.next().fuse() => {
1119                drop(subscription);
1120                ix.unwrap()
1121            }
1122        }
1123    }
1124
1125    pub async fn gemini_acp_server(
1126        project: Entity<Project>,
1127        cx: &mut TestAppContext,
1128    ) -> Arc<AcpServer> {
1129        let cli_path =
1130            Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini-cli/packages/cli");
1131        let mut command = util::command::new_smol_command("node");
1132        command
1133            .arg(cli_path)
1134            .arg("--acp")
1135            .current_dir("/private/tmp")
1136            .stdin(Stdio::piped())
1137            .stdout(Stdio::piped())
1138            .stderr(Stdio::inherit())
1139            .kill_on_drop(true);
1140
1141        if let Ok(gemini_key) = std::env::var("GEMINI_API_KEY") {
1142            command.env("GEMINI_API_KEY", gemini_key);
1143        }
1144
1145        let child = command.spawn().unwrap();
1146        let server = cx.update(|cx| AcpServer::stdio(child, project, cx));
1147        server.initialize().await.unwrap();
1148        server
1149    }
1150
1151    pub fn fake_acp_server(
1152        project: Entity<Project>,
1153        cx: &mut TestAppContext,
1154    ) -> (Entity<Thread>, Arc<AcpServer>, Entity<FakeAcpServer>) {
1155        let (stdin_tx, stdin_rx) = async_pipe::pipe();
1156        let (stdout_tx, stdout_rx) = async_pipe::pipe();
1157        let server = cx.update(|cx| AcpServer::fake(stdin_tx, stdout_rx, project, cx));
1158        let thread = server.thread.upgrade().unwrap();
1159        let agent = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx)));
1160        (server, agent)
1161    }
1162
1163    pub struct FakeAcpServer {
1164        connection: acp::ClientConnection,
1165        _handler_task: Task<()>,
1166        _io_task: Task<()>,
1167        on_user_message: Option<
1168            Rc<
1169                dyn Fn(
1170                    acp::SendUserMessageParams,
1171                    Entity<FakeAcpServer>,
1172                    AsyncApp,
1173                )
1174                    -> LocalBoxFuture<'static, Result<acp::SendUserMessageResponse>>,
1175            >,
1176        >,
1177    }
1178
1179    #[derive(Clone)]
1180    struct FakeAgent {
1181        server: Entity<FakeAcpServer>,
1182        cx: AsyncApp,
1183    }
1184
1185    #[async_trait(?Send)]
1186    impl acp::Agent for FakeAgent {
1187        async fn initialize(
1188            &self,
1189            _request: acp::InitializeParams,
1190        ) -> Result<acp::InitializeResponse> {
1191            Ok(acp::InitializeResponse {
1192                is_authenticated: true,
1193            })
1194        }
1195
1196        async fn authenticate(
1197            &self,
1198            _request: acp::AuthenticateParams,
1199        ) -> Result<acp::AuthenticateResponse> {
1200            Ok(acp::AuthenticateResponse)
1201        }
1202
1203        async fn send_user_message(
1204            &self,
1205            request: acp::SendUserMessageParams,
1206        ) -> Result<acp::SendUserMessageResponse> {
1207            let mut cx = self.cx.clone();
1208            let handler = self
1209                .server
1210                .update(&mut cx, |server, _| server.on_user_message.clone())
1211                .ok()
1212                .flatten();
1213            if let Some(handler) = handler {
1214                handler(request, self.server.clone(), self.cx.clone()).await
1215            } else {
1216                anyhow::bail!("No handler for on_user_message")
1217            }
1218        }
1219    }
1220
1221    impl FakeAcpServer {
1222        fn new(stdin: PipeReader, stdout: PipeWriter, cx: &Context<Self>) -> Self {
1223            let agent = FakeAgent {
1224                server: cx.entity(),
1225                cx: cx.to_async(),
1226            };
1227
1228            let (connection, handler_fut, io_fut) =
1229                acp::ClientConnection::connect_to_client(agent.clone(), stdout, stdin);
1230            FakeAcpServer {
1231                connection: connection,
1232                on_user_message: None,
1233                _handler_task: cx.foreground_executor().spawn(handler_fut),
1234                _io_task: cx.background_spawn(async move {
1235                    io_fut.await.log_err();
1236                }),
1237            }
1238        }
1239
1240        fn on_user_message<F>(
1241            &mut self,
1242            handler: impl for<'a> Fn(acp::SendUserMessageParams, Entity<FakeAcpServer>, AsyncApp) -> F
1243            + 'static,
1244        ) where
1245            F: Future<Output = Result<acp::SendUserMessageResponse>> + 'static,
1246        {
1247            self.on_user_message
1248                .replace(Rc::new(move |request, server, cx| {
1249                    handler(request, server, cx).boxed_local()
1250                }));
1251        }
1252
1253        fn send_to_zed<T: acp::ClientRequest>(
1254            &self,
1255            message: T,
1256        ) -> BoxedLocal<Result<T::Response, acp::Error>> {
1257            self.connection.request(message).boxed_local()
1258        }
1259    }
1260}