acp.rs

   1pub use acp::ToolCallId;
   2use agent_servers::AgentServer;
   3use agentic_coding_protocol::{self as acp, UserMessageChunk};
   4use anyhow::{Context as _, Result, anyhow};
   5use assistant_tool::ActionLog;
   6use buffer_diff::BufferDiff;
   7use editor::{MultiBuffer, PathKey};
   8use futures::{FutureExt, channel::oneshot, future::BoxFuture};
   9use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  10use itertools::Itertools;
  11use language::{
  12    Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point,
  13    text_diff,
  14};
  15use markdown::Markdown;
  16use project::{AgentLocation, Project};
  17use std::collections::HashMap;
  18use std::error::Error;
  19use std::fmt::{Formatter, Write};
  20use std::{
  21    fmt::Display,
  22    mem,
  23    path::{Path, PathBuf},
  24    sync::Arc,
  25};
  26use ui::{App, IconName};
  27use util::ResultExt;
  28
  29#[derive(Clone, Debug, Eq, PartialEq)]
  30pub struct UserMessage {
  31    pub content: Entity<Markdown>,
  32}
  33
  34impl UserMessage {
  35    pub fn from_acp(
  36        message: &acp::SendUserMessageParams,
  37        language_registry: Arc<LanguageRegistry>,
  38        cx: &mut App,
  39    ) -> Self {
  40        let mut md_source = String::new();
  41
  42        for chunk in &message.chunks {
  43            match chunk {
  44                UserMessageChunk::Text { text } => md_source.push_str(&text),
  45                UserMessageChunk::Path { path } => {
  46                    write!(&mut md_source, "{}", MentionPath(&path)).unwrap()
  47                }
  48            }
  49        }
  50
  51        Self {
  52            content: cx
  53                .new(|cx| Markdown::new(md_source.into(), Some(language_registry), None, cx)),
  54        }
  55    }
  56
  57    fn to_markdown(&self, cx: &App) -> String {
  58        format!("## User\n\n{}\n\n", self.content.read(cx).source())
  59    }
  60}
  61
  62#[derive(Debug)]
  63pub struct MentionPath<'a>(&'a Path);
  64
  65impl<'a> MentionPath<'a> {
  66    const PREFIX: &'static str = "@file:";
  67
  68    pub fn new(path: &'a Path) -> Self {
  69        MentionPath(path)
  70    }
  71
  72    pub fn try_parse(url: &'a str) -> Option<Self> {
  73        let path = url.strip_prefix(Self::PREFIX)?;
  74        Some(MentionPath(Path::new(path)))
  75    }
  76
  77    pub fn path(&self) -> &Path {
  78        self.0
  79    }
  80}
  81
  82impl Display for MentionPath<'_> {
  83    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  84        write!(
  85            f,
  86            "[@{}]({}{})",
  87            self.0.file_name().unwrap_or_default().display(),
  88            Self::PREFIX,
  89            self.0.display()
  90        )
  91    }
  92}
  93
  94#[derive(Clone, Debug, Eq, PartialEq)]
  95pub struct AssistantMessage {
  96    pub chunks: Vec<AssistantMessageChunk>,
  97}
  98
  99impl AssistantMessage {
 100    fn to_markdown(&self, cx: &App) -> String {
 101        format!(
 102            "## Assistant\n\n{}\n\n",
 103            self.chunks
 104                .iter()
 105                .map(|chunk| chunk.to_markdown(cx))
 106                .join("\n\n")
 107        )
 108    }
 109}
 110
 111#[derive(Clone, Debug, Eq, PartialEq)]
 112pub enum AssistantMessageChunk {
 113    Text { chunk: Entity<Markdown> },
 114    Thought { chunk: Entity<Markdown> },
 115}
 116
 117impl AssistantMessageChunk {
 118    pub fn from_acp(
 119        chunk: acp::AssistantMessageChunk,
 120        language_registry: Arc<LanguageRegistry>,
 121        cx: &mut App,
 122    ) -> Self {
 123        match chunk {
 124            acp::AssistantMessageChunk::Text { text } => Self::Text {
 125                chunk: cx.new(|cx| Markdown::new(text.into(), Some(language_registry), None, cx)),
 126            },
 127            acp::AssistantMessageChunk::Thought { thought } => Self::Thought {
 128                chunk: cx
 129                    .new(|cx| Markdown::new(thought.into(), Some(language_registry), None, cx)),
 130            },
 131        }
 132    }
 133
 134    pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
 135        Self::Text {
 136            chunk: cx.new(|cx| {
 137                Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
 138            }),
 139        }
 140    }
 141
 142    fn to_markdown(&self, cx: &App) -> String {
 143        match self {
 144            Self::Text { chunk } => chunk.read(cx).source().to_string(),
 145            Self::Thought { chunk } => {
 146                format!("<thinking>\n{}\n</thinking>", chunk.read(cx).source())
 147            }
 148        }
 149    }
 150}
 151
 152#[derive(Debug)]
 153pub enum AgentThreadEntry {
 154    UserMessage(UserMessage),
 155    AssistantMessage(AssistantMessage),
 156    ToolCall(ToolCall),
 157}
 158
 159impl AgentThreadEntry {
 160    fn to_markdown(&self, cx: &App) -> String {
 161        match self {
 162            Self::UserMessage(message) => message.to_markdown(cx),
 163            Self::AssistantMessage(message) => message.to_markdown(cx),
 164            Self::ToolCall(too_call) => too_call.to_markdown(cx),
 165        }
 166    }
 167
 168    pub fn diff(&self) -> Option<&Diff> {
 169        if let AgentThreadEntry::ToolCall(ToolCall {
 170            content: Some(ToolCallContent::Diff { diff }),
 171            ..
 172        }) = self
 173        {
 174            Some(&diff)
 175        } else {
 176            None
 177        }
 178    }
 179
 180    pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
 181        if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
 182            Some(locations)
 183        } else {
 184            None
 185        }
 186    }
 187}
 188
 189#[derive(Debug)]
 190pub struct ToolCall {
 191    pub id: acp::ToolCallId,
 192    pub label: Entity<Markdown>,
 193    pub icon: IconName,
 194    pub content: Option<ToolCallContent>,
 195    pub status: ToolCallStatus,
 196    pub locations: Vec<acp::ToolCallLocation>,
 197}
 198
 199impl ToolCall {
 200    fn to_markdown(&self, cx: &App) -> String {
 201        let mut markdown = format!(
 202            "**Tool Call: {}**\nStatus: {}\n\n",
 203            self.label.read(cx).source(),
 204            self.status
 205        );
 206        if let Some(content) = &self.content {
 207            markdown.push_str(content.to_markdown(cx).as_str());
 208            markdown.push_str("\n\n");
 209        }
 210        markdown
 211    }
 212}
 213
 214#[derive(Debug)]
 215pub enum ToolCallStatus {
 216    WaitingForConfirmation {
 217        confirmation: ToolCallConfirmation,
 218        respond_tx: oneshot::Sender<acp::ToolCallConfirmationOutcome>,
 219    },
 220    Allowed {
 221        status: acp::ToolCallStatus,
 222    },
 223    Rejected,
 224    Canceled,
 225}
 226
 227impl Display for ToolCallStatus {
 228    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 229        write!(
 230            f,
 231            "{}",
 232            match self {
 233                ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
 234                ToolCallStatus::Allowed { status } => match status {
 235                    acp::ToolCallStatus::Running => "Running",
 236                    acp::ToolCallStatus::Finished => "Finished",
 237                    acp::ToolCallStatus::Error => "Error",
 238                },
 239                ToolCallStatus::Rejected => "Rejected",
 240                ToolCallStatus::Canceled => "Canceled",
 241            }
 242        )
 243    }
 244}
 245
 246#[derive(Debug)]
 247pub enum ToolCallConfirmation {
 248    Edit {
 249        description: Option<Entity<Markdown>>,
 250    },
 251    Execute {
 252        command: String,
 253        root_command: String,
 254        description: Option<Entity<Markdown>>,
 255    },
 256    Mcp {
 257        server_name: String,
 258        tool_name: String,
 259        tool_display_name: String,
 260        description: Option<Entity<Markdown>>,
 261    },
 262    Fetch {
 263        urls: Vec<SharedString>,
 264        description: Option<Entity<Markdown>>,
 265    },
 266    Other {
 267        description: Entity<Markdown>,
 268    },
 269}
 270
 271impl ToolCallConfirmation {
 272    pub fn from_acp(
 273        confirmation: acp::ToolCallConfirmation,
 274        language_registry: Arc<LanguageRegistry>,
 275        cx: &mut App,
 276    ) -> Self {
 277        let to_md = |description: String, cx: &mut App| -> Entity<Markdown> {
 278            cx.new(|cx| {
 279                Markdown::new(
 280                    description.into(),
 281                    Some(language_registry.clone()),
 282                    None,
 283                    cx,
 284                )
 285            })
 286        };
 287
 288        match confirmation {
 289            acp::ToolCallConfirmation::Edit { description } => Self::Edit {
 290                description: description.map(|description| to_md(description, cx)),
 291            },
 292            acp::ToolCallConfirmation::Execute {
 293                command,
 294                root_command,
 295                description,
 296            } => Self::Execute {
 297                command,
 298                root_command,
 299                description: description.map(|description| to_md(description, cx)),
 300            },
 301            acp::ToolCallConfirmation::Mcp {
 302                server_name,
 303                tool_name,
 304                tool_display_name,
 305                description,
 306            } => Self::Mcp {
 307                server_name,
 308                tool_name,
 309                tool_display_name,
 310                description: description.map(|description| to_md(description, cx)),
 311            },
 312            acp::ToolCallConfirmation::Fetch { urls, description } => Self::Fetch {
 313                urls: urls.iter().map(|url| url.into()).collect(),
 314                description: description.map(|description| to_md(description, cx)),
 315            },
 316            acp::ToolCallConfirmation::Other { description } => Self::Other {
 317                description: to_md(description, cx),
 318            },
 319        }
 320    }
 321}
 322
 323#[derive(Debug)]
 324pub enum ToolCallContent {
 325    Markdown { markdown: Entity<Markdown> },
 326    Diff { diff: Diff },
 327}
 328
 329impl ToolCallContent {
 330    pub fn from_acp(
 331        content: acp::ToolCallContent,
 332        language_registry: Arc<LanguageRegistry>,
 333        cx: &mut App,
 334    ) -> Self {
 335        match content {
 336            acp::ToolCallContent::Markdown { markdown } => Self::Markdown {
 337                markdown: cx.new(|cx| Markdown::new_text(markdown.into(), cx)),
 338            },
 339            acp::ToolCallContent::Diff { diff } => Self::Diff {
 340                diff: Diff::from_acp(diff, language_registry, cx),
 341            },
 342        }
 343    }
 344
 345    fn to_markdown(&self, cx: &App) -> String {
 346        match self {
 347            Self::Markdown { markdown } => markdown.read(cx).source().to_string(),
 348            Self::Diff { diff } => diff.to_markdown(cx),
 349        }
 350    }
 351}
 352
 353#[derive(Debug)]
 354pub struct Diff {
 355    pub multibuffer: Entity<MultiBuffer>,
 356    pub path: PathBuf,
 357    pub new_buffer: Entity<Buffer>,
 358    pub old_buffer: Entity<Buffer>,
 359    _task: Task<Result<()>>,
 360}
 361
 362impl Diff {
 363    pub fn from_acp(
 364        diff: acp::Diff,
 365        language_registry: Arc<LanguageRegistry>,
 366        cx: &mut App,
 367    ) -> Self {
 368        let acp::Diff {
 369            path,
 370            old_text,
 371            new_text,
 372        } = diff;
 373
 374        let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
 375
 376        let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
 377        let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
 378        let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
 379        let old_buffer_snapshot = old_buffer.read(cx).snapshot();
 380        let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
 381        let diff_task = buffer_diff.update(cx, |diff, cx| {
 382            diff.set_base_text(
 383                old_buffer_snapshot,
 384                Some(language_registry.clone()),
 385                new_buffer_snapshot,
 386                cx,
 387            )
 388        });
 389
 390        let task = cx.spawn({
 391            let multibuffer = multibuffer.clone();
 392            let path = path.clone();
 393            let new_buffer = new_buffer.clone();
 394            async move |cx| {
 395                diff_task.await?;
 396
 397                multibuffer
 398                    .update(cx, |multibuffer, cx| {
 399                        let hunk_ranges = {
 400                            let buffer = new_buffer.read(cx);
 401                            let diff = buffer_diff.read(cx);
 402                            diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
 403                                .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
 404                                .collect::<Vec<_>>()
 405                        };
 406
 407                        multibuffer.set_excerpts_for_path(
 408                            PathKey::for_buffer(&new_buffer, cx),
 409                            new_buffer.clone(),
 410                            hunk_ranges,
 411                            editor::DEFAULT_MULTIBUFFER_CONTEXT,
 412                            cx,
 413                        );
 414                        multibuffer.add_diff(buffer_diff.clone(), cx);
 415                    })
 416                    .log_err();
 417
 418                if let Some(language) = language_registry
 419                    .language_for_file_path(&path)
 420                    .await
 421                    .log_err()
 422                {
 423                    new_buffer.update(cx, |buffer, cx| buffer.set_language(Some(language), cx))?;
 424                }
 425
 426                anyhow::Ok(())
 427            }
 428        });
 429
 430        Self {
 431            multibuffer,
 432            path,
 433            new_buffer,
 434            old_buffer,
 435            _task: task,
 436        }
 437    }
 438
 439    fn to_markdown(&self, cx: &App) -> String {
 440        let buffer_text = self
 441            .multibuffer
 442            .read(cx)
 443            .all_buffers()
 444            .iter()
 445            .map(|buffer| buffer.read(cx).text())
 446            .join("\n");
 447        format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text)
 448    }
 449}
 450
/// A conversation with an ACP agent: the transcript plus the connection and
/// background tasks that drive it.
///
/// NOTE: fields drop in declaration order; keep the order unless that is intended to change.
pub struct AcpThread {
    /// Transcript entries in insertion order. Tool call ids are indices into this vector.
    entries: Vec<AgentThreadEntry>,
    /// Returned by `title()`; initialized to "ACP Thread".
    title: SharedString,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    // NOTE(review): not read in the code visible here; presumably snapshots of buffers
    // shared with the agent — confirm.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// The in-flight send started by `send`; `None` when the thread is idle.
    send_task: Option<Task<()>>,
    /// Protocol connection to the agent; cloned into request futures.
    connection: Arc<acp::AgentConnection>,
    /// Resolves when the agent process exits; `None` for threads built by `fake`.
    child_status: Option<Task<Result<()>>>,
    /// Drives the connection's I/O future.
    _io_task: Task<()>,
}

/// Events emitted by [`AcpThread`] when its transcript changes.
pub enum AcpThreadEvent {
    /// An entry was appended to the transcript.
    NewEntry,
    /// The entry at this index was modified in place.
    EntryUpdated(usize),
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}
 469
 470#[derive(PartialEq, Eq)]
 471pub enum ThreadStatus {
 472    Idle,
 473    WaitingForToolConfirmation,
 474    Generating,
 475}
 476
 477#[derive(Debug, Clone)]
 478pub enum LoadError {
 479    Unsupported { current_version: SharedString },
 480    Exited(i32),
 481    Other(SharedString),
 482}
 483
 484impl Display for LoadError {
 485    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 486        match self {
 487            LoadError::Unsupported { current_version } => {
 488                write!(
 489                    f,
 490                    "Your installed version of Gemini {} doesn't support the Agentic Coding Protocol (ACP).",
 491                    current_version
 492                )
 493            }
 494            LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
 495            LoadError::Other(msg) => write!(f, "{}", msg),
 496        }
 497    }
 498}
 499
 500impl Error for LoadError {}
 501
 502impl AcpThread {
 503    pub async fn spawn(
 504        server: impl AgentServer + 'static,
 505        root_dir: &Path,
 506        project: Entity<Project>,
 507        cx: &mut AsyncApp,
 508    ) -> Result<Entity<Self>> {
 509        let command = match server.command(&project, cx).await {
 510            Ok(command) => command,
 511            Err(e) => return Err(anyhow!(LoadError::Other(format!("{e}").into()))),
 512        };
 513
 514        let mut child = util::command::new_smol_command(&command.path)
 515            .args(command.args.iter())
 516            .current_dir(root_dir)
 517            .stdin(std::process::Stdio::piped())
 518            .stdout(std::process::Stdio::piped())
 519            .stderr(std::process::Stdio::inherit())
 520            .kill_on_drop(true)
 521            .spawn()?;
 522
 523        let stdin = child.stdin.take().unwrap();
 524        let stdout = child.stdout.take().unwrap();
 525
 526        cx.new(|cx| {
 527            let foreground_executor = cx.foreground_executor().clone();
 528
 529            let (connection, io_fut) = acp::AgentConnection::connect_to_agent(
 530                AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()),
 531                stdin,
 532                stdout,
 533                move |fut| foreground_executor.spawn(fut).detach(),
 534            );
 535
 536            let io_task = cx.background_spawn(async move {
 537                io_fut.await.log_err();
 538            });
 539
 540            let child_status = cx.background_spawn(async move {
 541                match child.status().await {
 542                    Err(e) => Err(anyhow!(e)),
 543                    Ok(result) if result.success() => Ok(()),
 544                    Ok(result) => {
 545                        if let Some(version) = server.version(&command).await.log_err()
 546                            && !version.supported
 547                        {
 548                            Err(anyhow!(LoadError::Unsupported {
 549                                current_version: version.current_version
 550                            }))
 551                        } else {
 552                            Err(anyhow!(LoadError::Exited(result.code().unwrap_or(-127))))
 553                        }
 554                    }
 555                }
 556            });
 557
 558            let action_log = cx.new(|_| ActionLog::new(project.clone()));
 559
 560            Self {
 561                action_log,
 562                shared_buffers: Default::default(),
 563                entries: Default::default(),
 564                title: "ACP Thread".into(),
 565                project,
 566                send_task: None,
 567                connection: Arc::new(connection),
 568                child_status: Some(child_status),
 569                _io_task: io_task,
 570            }
 571        })
 572    }
 573
 574    pub fn action_log(&self) -> &Entity<ActionLog> {
 575        &self.action_log
 576    }
 577
 578    pub fn project(&self) -> &Entity<Project> {
 579        &self.project
 580    }
 581
 582    #[cfg(test)]
 583    pub fn fake(
 584        stdin: async_pipe::PipeWriter,
 585        stdout: async_pipe::PipeReader,
 586        project: Entity<Project>,
 587        cx: &mut Context<Self>,
 588    ) -> Self {
 589        let foreground_executor = cx.foreground_executor().clone();
 590
 591        let (connection, io_fut) = acp::AgentConnection::connect_to_agent(
 592            AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()),
 593            stdin,
 594            stdout,
 595            move |fut| {
 596                foreground_executor.spawn(fut).detach();
 597            },
 598        );
 599
 600        let io_task = cx.background_spawn({
 601            async move {
 602                io_fut.await.log_err();
 603            }
 604        });
 605
 606        let action_log = cx.new(|_| ActionLog::new(project.clone()));
 607
 608        Self {
 609            action_log,
 610            shared_buffers: Default::default(),
 611            entries: Default::default(),
 612            title: "ACP Thread".into(),
 613            project,
 614            send_task: None,
 615            connection: Arc::new(connection),
 616            child_status: None,
 617            _io_task: io_task,
 618        }
 619    }
 620
 621    pub fn title(&self) -> SharedString {
 622        self.title.clone()
 623    }
 624
 625    pub fn entries(&self) -> &[AgentThreadEntry] {
 626        &self.entries
 627    }
 628
 629    pub fn status(&self) -> ThreadStatus {
 630        if self.send_task.is_some() {
 631            if self.waiting_for_tool_confirmation() {
 632                ThreadStatus::WaitingForToolConfirmation
 633            } else {
 634                ThreadStatus::Generating
 635            }
 636        } else {
 637            ThreadStatus::Idle
 638        }
 639    }
 640
 641    pub fn has_pending_edit_tool_calls(&self) -> bool {
 642        for entry in self.entries.iter().rev() {
 643            match entry {
 644                AgentThreadEntry::UserMessage(_) => return false,
 645                AgentThreadEntry::ToolCall(ToolCall {
 646                    status:
 647                        ToolCallStatus::Allowed {
 648                            status: acp::ToolCallStatus::Running,
 649                            ..
 650                        },
 651                    content: Some(ToolCallContent::Diff { .. }),
 652                    ..
 653                }) => return true,
 654                AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
 655            }
 656        }
 657
 658        false
 659    }
 660
 661    pub fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
 662        self.entries.push(entry);
 663        cx.emit(AcpThreadEvent::NewEntry);
 664    }
 665
 666    pub fn push_assistant_chunk(
 667        &mut self,
 668        chunk: acp::AssistantMessageChunk,
 669        cx: &mut Context<Self>,
 670    ) {
 671        let entries_len = self.entries.len();
 672        if let Some(last_entry) = self.entries.last_mut()
 673            && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
 674        {
 675            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
 676
 677            match (chunks.last_mut(), &chunk) {
 678                (
 679                    Some(AssistantMessageChunk::Text { chunk: old_chunk }),
 680                    acp::AssistantMessageChunk::Text { text: new_chunk },
 681                )
 682                | (
 683                    Some(AssistantMessageChunk::Thought { chunk: old_chunk }),
 684                    acp::AssistantMessageChunk::Thought { thought: new_chunk },
 685                ) => {
 686                    old_chunk.update(cx, |old_chunk, cx| {
 687                        old_chunk.append(&new_chunk, cx);
 688                    });
 689                }
 690                _ => {
 691                    chunks.push(AssistantMessageChunk::from_acp(
 692                        chunk,
 693                        self.project.read(cx).languages().clone(),
 694                        cx,
 695                    ));
 696                }
 697            }
 698        } else {
 699            let chunk = AssistantMessageChunk::from_acp(
 700                chunk,
 701                self.project.read(cx).languages().clone(),
 702                cx,
 703            );
 704
 705            self.push_entry(
 706                AgentThreadEntry::AssistantMessage(AssistantMessage {
 707                    chunks: vec![chunk],
 708                }),
 709                cx,
 710            );
 711        }
 712    }
 713
 714    pub fn request_tool_call(
 715        &mut self,
 716        tool_call: acp::RequestToolCallConfirmationParams,
 717        cx: &mut Context<Self>,
 718    ) -> ToolCallRequest {
 719        let (tx, rx) = oneshot::channel();
 720
 721        let status = ToolCallStatus::WaitingForConfirmation {
 722            confirmation: ToolCallConfirmation::from_acp(
 723                tool_call.confirmation,
 724                self.project.read(cx).languages().clone(),
 725                cx,
 726            ),
 727            respond_tx: tx,
 728        };
 729
 730        let id = self.insert_tool_call(tool_call.tool_call, status, cx);
 731        ToolCallRequest { id, outcome: rx }
 732    }
 733
 734    pub fn push_tool_call(
 735        &mut self,
 736        request: acp::PushToolCallParams,
 737        cx: &mut Context<Self>,
 738    ) -> acp::ToolCallId {
 739        let status = ToolCallStatus::Allowed {
 740            status: acp::ToolCallStatus::Running,
 741        };
 742
 743        self.insert_tool_call(request, status, cx)
 744    }
 745
 746    fn insert_tool_call(
 747        &mut self,
 748        tool_call: acp::PushToolCallParams,
 749        status: ToolCallStatus,
 750        cx: &mut Context<Self>,
 751    ) -> acp::ToolCallId {
 752        let language_registry = self.project.read(cx).languages().clone();
 753        let id = acp::ToolCallId(self.entries.len() as u64);
 754        let call = ToolCall {
 755            id,
 756            label: cx.new(|cx| {
 757                Markdown::new(
 758                    tool_call.label.into(),
 759                    Some(language_registry.clone()),
 760                    None,
 761                    cx,
 762                )
 763            }),
 764            icon: acp_icon_to_ui_icon(tool_call.icon),
 765            content: tool_call
 766                .content
 767                .map(|content| ToolCallContent::from_acp(content, language_registry, cx)),
 768            locations: tool_call.locations,
 769            status,
 770        };
 771
 772        self.push_entry(AgentThreadEntry::ToolCall(call), cx);
 773
 774        id
 775    }
 776
 777    pub fn authorize_tool_call(
 778        &mut self,
 779        id: acp::ToolCallId,
 780        outcome: acp::ToolCallConfirmationOutcome,
 781        cx: &mut Context<Self>,
 782    ) {
 783        let Some((ix, call)) = self.tool_call_mut(id) else {
 784            return;
 785        };
 786
 787        let new_status = if outcome == acp::ToolCallConfirmationOutcome::Reject {
 788            ToolCallStatus::Rejected
 789        } else {
 790            ToolCallStatus::Allowed {
 791                status: acp::ToolCallStatus::Running,
 792            }
 793        };
 794
 795        let curr_status = mem::replace(&mut call.status, new_status);
 796
 797        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
 798            respond_tx.send(outcome).log_err();
 799        } else if cfg!(debug_assertions) {
 800            panic!("tried to authorize an already authorized tool call");
 801        }
 802
 803        cx.emit(AcpThreadEvent::EntryUpdated(ix));
 804    }
 805
 806    pub fn update_tool_call(
 807        &mut self,
 808        id: acp::ToolCallId,
 809        new_status: acp::ToolCallStatus,
 810        new_content: Option<acp::ToolCallContent>,
 811        cx: &mut Context<Self>,
 812    ) -> Result<()> {
 813        let language_registry = self.project.read(cx).languages().clone();
 814        let (ix, call) = self.tool_call_mut(id).context("Entry not found")?;
 815
 816        call.content = new_content
 817            .map(|new_content| ToolCallContent::from_acp(new_content, language_registry, cx));
 818
 819        match &mut call.status {
 820            ToolCallStatus::Allowed { status } => {
 821                *status = new_status;
 822            }
 823            ToolCallStatus::WaitingForConfirmation { .. } => {
 824                anyhow::bail!("Tool call hasn't been authorized yet")
 825            }
 826            ToolCallStatus::Rejected => {
 827                anyhow::bail!("Tool call was rejected and therefore can't be updated")
 828            }
 829            ToolCallStatus::Canceled => {
 830                call.status = ToolCallStatus::Allowed { status: new_status };
 831            }
 832        }
 833
 834        cx.emit(AcpThreadEvent::EntryUpdated(ix));
 835        Ok(())
 836    }
 837
 838    fn tool_call_mut(&mut self, id: acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
 839        let entry = self.entries.get_mut(id.0 as usize);
 840        debug_assert!(
 841            entry.is_some(),
 842            "We shouldn't give out ids to entries that don't exist"
 843        );
 844        match entry {
 845            Some(AgentThreadEntry::ToolCall(call)) if call.id == id => Some((id.0 as usize, call)),
 846            _ => {
 847                if cfg!(debug_assertions) {
 848                    panic!("entry is not a tool call");
 849                }
 850                None
 851            }
 852        }
 853    }
 854
 855    /// Returns true if the last turn is awaiting tool authorization
 856    pub fn waiting_for_tool_confirmation(&self) -> bool {
 857        for entry in self.entries.iter().rev() {
 858            match &entry {
 859                AgentThreadEntry::ToolCall(call) => match call.status {
 860                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
 861                    ToolCallStatus::Allowed { .. }
 862                    | ToolCallStatus::Rejected
 863                    | ToolCallStatus::Canceled => continue,
 864                },
 865                AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
 866                    // Reached the beginning of the turn
 867                    return false;
 868                }
 869            }
 870        }
 871        false
 872    }
 873
 874    pub fn initialize(
 875        &self,
 876    ) -> impl use<> + Future<Output = Result<acp::InitializeResponse, acp::Error>> {
 877        let connection = self.connection.clone();
 878        async move { connection.initialize().await }
 879    }
 880
 881    pub fn authenticate(&self) -> impl use<> + Future<Output = Result<(), acp::Error>> {
 882        let connection = self.connection.clone();
 883        async move { connection.request(acp::AuthenticateParams).await }
 884    }
 885
 886    #[cfg(test)]
 887    pub fn send_raw(
 888        &mut self,
 889        message: &str,
 890        cx: &mut Context<Self>,
 891    ) -> BoxFuture<'static, Result<(), acp::Error>> {
 892        self.send(
 893            acp::SendUserMessageParams {
 894                chunks: vec![acp::UserMessageChunk::Text {
 895                    text: message.to_string(),
 896                }],
 897            },
 898            cx,
 899        )
 900    }
 901
 902    pub fn send(
 903        &mut self,
 904        message: acp::SendUserMessageParams,
 905        cx: &mut Context<Self>,
 906    ) -> BoxFuture<'static, Result<(), acp::Error>> {
 907        let agent = self.connection.clone();
 908        self.push_entry(
 909            AgentThreadEntry::UserMessage(UserMessage::from_acp(
 910                &message,
 911                self.project.read(cx).languages().clone(),
 912                cx,
 913            )),
 914            cx,
 915        );
 916
 917        let (tx, rx) = oneshot::channel();
 918        let cancel = self.cancel(cx);
 919
 920        self.send_task = Some(cx.spawn(async move |this, cx| {
 921            cancel.await.log_err();
 922
 923            let result = agent.request(message).await;
 924            tx.send(result).log_err();
 925            this.update(cx, |this, _cx| this.send_task.take()).log_err();
 926        }));
 927
 928        async move {
 929            match rx.await {
 930                Ok(Err(e)) => Err(e)?,
 931                _ => Ok(()),
 932            }
 933        }
 934        .boxed()
 935    }
 936
 937    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<Result<(), acp::Error>> {
 938        let agent = self.connection.clone();
 939
 940        if self.send_task.take().is_some() {
 941            cx.spawn(async move |this, cx| {
 942                agent.request(acp::CancelSendMessageParams).await?;
 943
 944                this.update(cx, |this, _cx| {
 945                    for entry in this.entries.iter_mut() {
 946                        if let AgentThreadEntry::ToolCall(call) = entry {
 947                            let cancel = matches!(
 948                                call.status,
 949                                ToolCallStatus::WaitingForConfirmation { .. }
 950                                    | ToolCallStatus::Allowed {
 951                                        status: acp::ToolCallStatus::Running
 952                                    }
 953                            );
 954
 955                            if cancel {
 956                                let curr_status =
 957                                    mem::replace(&mut call.status, ToolCallStatus::Canceled);
 958
 959                                if let ToolCallStatus::WaitingForConfirmation {
 960                                    respond_tx, ..
 961                                } = curr_status
 962                                {
 963                                    respond_tx
 964                                        .send(acp::ToolCallConfirmationOutcome::Cancel)
 965                                        .ok();
 966                                }
 967                            }
 968                        }
 969                    }
 970                })?;
 971                Ok(())
 972            })
 973        } else {
 974            Task::ready(Ok(()))
 975        }
 976    }
 977
    /// Opens the file at `request.path` in the project and returns its full text.
    ///
    /// Side effects, in order:
    /// 1. The buffer is reported to the action log as read.
    /// 2. The project's agent location moves to the start of row `request.line` (row 0 when
    ///    absent).
    /// 3. The snapshot whose text is returned is stored in `shared_buffers`, so a later
    ///    [`Self::write_text_file`] can diff against exactly what the agent saw.
    ///
    /// NOTE(review): `request.limit` is never consulted, and `request.line` only affects the
    /// agent location. The whole file text is returned regardless. Confirm whether the
    /// protocol expects a line window here.
    ///
    /// # Errors
    /// Fails with "invalid path" when the path does not map to a project path. Also
    /// propagates failures to open the buffer or to update the project/thread entities.
    pub fn read_text_file(
        &self,
        request: acp::ReadTextFileParams,
        cx: &mut Context<Self>,
    ) -> Task<Result<String>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&request.path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            // First `?`: the project update could not run. Second `?`: the path was invalid.
            // Then wait for the buffer to finish opening.
            let buffer = load??.await?;

            action_log.update(cx, |action_log, cx| {
                action_log.buffer_read(buffer.clone(), cx);
            })?;
            // Point the agent location at column 0 of the requested row.
            project.update(cx, |project, cx| {
                let position = buffer
                    .read(cx)
                    .snapshot()
                    .anchor_before(Point::new(request.line.unwrap_or_default(), 0));
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position,
                    }),
                    cx,
                );
            })?;
            let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot())?;
            // Remember the exact snapshot handed to the agent; `write_text_file` diffs
            // against it.
            this.update(cx, |this, _| {
                let text = snapshot.text();
                this.shared_buffers.insert(buffer.clone(), snapshot);
                text
            })
        })
    }
1018
    /// Replaces the contents of the file at `path` with `content`, then saves it.
    ///
    /// The buffer is not overwritten wholesale. `content` is diffed against the snapshot
    /// stored for this buffer by [`Self::read_text_file`], or against the current buffer
    /// snapshot if none was stored. The resulting edits are anchored in that snapshot, so
    /// changes made to the buffer after the read (e.g. by the user) are kept alongside the
    /// agent's edits.
    ///
    /// The agent location moves to the end of the last edit (or `Anchor::MIN` when there are
    /// no edits). The action log is told the buffer was read before the edit and edited
    /// after it.
    ///
    /// NOTE(review): `shared_buffers` is not refreshed here. A second write without an
    /// intervening read therefore diffs against the same earlier snapshot. Confirm that is
    /// intended.
    ///
    /// # Errors
    /// Fails with "invalid path" when `path` does not map to a project path. Also propagates
    /// failures to open or save the buffer, or to update the entities involved.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Diff base: what the agent last read, falling back to the live buffer.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Compute the diff off the main thread. The edit ranges are anchors into the
            // base snapshot, so they still resolve after concurrent buffer changes.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;
            cx.update(|cx| {
                project.update(cx, |project, cx| {
                    project.set_agent_location(
                        Some(AgentLocation {
                            buffer: buffer.downgrade(),
                            position: edits
                                .last()
                                .map(|(range, _)| range.end)
                                .unwrap_or(Anchor::MIN),
                        }),
                        cx,
                    );
                });

                // Mark the pre-edit state as read, apply the edits, then record them as
                // edited, so the action log sees the edits relative to the read state.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });
                buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
            })?;
            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1086
1087    pub fn child_status(&mut self) -> Option<Task<Result<()>>> {
1088        self.child_status.take()
1089    }
1090
1091    pub fn to_markdown(&self, cx: &App) -> String {
1092        self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1093    }
1094}
1095
/// Receives ACP client requests from the agent and applies them to an [`AcpThread`].
struct AcpClientDelegate {
    /// Weak handle, so the delegate does not keep the thread alive.
    thread: WeakEntity<AcpThread>,
    /// Async app context used to reach the thread from the protocol handlers.
    cx: AsyncApp,
}
1101
1102impl AcpClientDelegate {
1103    fn new(thread: WeakEntity<AcpThread>, cx: AsyncApp) -> Self {
1104        Self { thread, cx }
1105    }
1106}
1107
1108impl acp::Client for AcpClientDelegate {
1109    async fn stream_assistant_message_chunk(
1110        &self,
1111        params: acp::StreamAssistantMessageChunkParams,
1112    ) -> Result<(), acp::Error> {
1113        let cx = &mut self.cx.clone();
1114
1115        cx.update(|cx| {
1116            self.thread
1117                .update(cx, |thread, cx| {
1118                    thread.push_assistant_chunk(params.chunk, cx)
1119                })
1120                .ok();
1121        })?;
1122
1123        Ok(())
1124    }
1125
1126    async fn request_tool_call_confirmation(
1127        &self,
1128        request: acp::RequestToolCallConfirmationParams,
1129    ) -> Result<acp::RequestToolCallConfirmationResponse, acp::Error> {
1130        let cx = &mut self.cx.clone();
1131        let ToolCallRequest { id, outcome } = cx
1132            .update(|cx| {
1133                self.thread
1134                    .update(cx, |thread, cx| thread.request_tool_call(request, cx))
1135            })?
1136            .context("Failed to update thread")?;
1137
1138        Ok(acp::RequestToolCallConfirmationResponse {
1139            id,
1140            outcome: outcome.await.map_err(acp::Error::into_internal_error)?,
1141        })
1142    }
1143
1144    async fn push_tool_call(
1145        &self,
1146        request: acp::PushToolCallParams,
1147    ) -> Result<acp::PushToolCallResponse, acp::Error> {
1148        let cx = &mut self.cx.clone();
1149        let id = cx
1150            .update(|cx| {
1151                self.thread
1152                    .update(cx, |thread, cx| thread.push_tool_call(request, cx))
1153            })?
1154            .context("Failed to update thread")?;
1155
1156        Ok(acp::PushToolCallResponse { id })
1157    }
1158
1159    async fn update_tool_call(&self, request: acp::UpdateToolCallParams) -> Result<(), acp::Error> {
1160        let cx = &mut self.cx.clone();
1161
1162        cx.update(|cx| {
1163            self.thread.update(cx, |thread, cx| {
1164                thread.update_tool_call(request.tool_call_id, request.status, request.content, cx)
1165            })
1166        })?
1167        .context("Failed to update thread")??;
1168
1169        Ok(())
1170    }
1171
1172    async fn read_text_file(
1173        &self,
1174        request: acp::ReadTextFileParams,
1175    ) -> Result<acp::ReadTextFileResponse, acp::Error> {
1176        let content = self
1177            .cx
1178            .update(|cx| {
1179                self.thread
1180                    .update(cx, |thread, cx| thread.read_text_file(request, cx))
1181            })?
1182            .context("Failed to update thread")?
1183            .await?;
1184        Ok(acp::ReadTextFileResponse { content })
1185    }
1186
1187    async fn write_text_file(&self, request: acp::WriteTextFileParams) -> Result<(), acp::Error> {
1188        self.cx
1189            .update(|cx| {
1190                self.thread.update(cx, |thread, cx| {
1191                    thread.write_text_file(request.path, request.content, cx)
1192                })
1193            })?
1194            .context("Failed to update thread")?
1195            .await?;
1196
1197        Ok(())
1198    }
1199}
1200
1201fn acp_icon_to_ui_icon(icon: acp::Icon) -> IconName {
1202    match icon {
1203        acp::Icon::FileSearch => IconName::ToolSearch,
1204        acp::Icon::Folder => IconName::ToolFolder,
1205        acp::Icon::Globe => IconName::ToolWeb,
1206        acp::Icon::Hammer => IconName::ToolHammer,
1207        acp::Icon::LightBulb => IconName::ToolBulb,
1208        acp::Icon::Pencil => IconName::ToolPencil,
1209        acp::Icon::Regex => IconName::ToolRegex,
1210        acp::Icon::Terminal => IconName::ToolTerminal,
1211    }
1212}
1213
/// A tool call that was added to the thread and is awaiting a confirmation outcome.
pub struct ToolCallRequest {
    /// Identifier of the tool call, echoed back to the agent in the confirmation response.
    pub id: acp::ToolCallId,
    /// Resolves with the outcome chosen for this tool call. It is awaited by
    /// `AcpClientDelegate::request_tool_call_confirmation` before replying to the agent.
    pub outcome: oneshot::Receiver<acp::ToolCallConfirmationOutcome>,
}
1218
1219#[cfg(test)]
1220mod tests {
1221    use super::*;
1222    use agent_servers::{AgentServerCommand, AgentServerVersion};
1223    use async_pipe::{PipeReader, PipeWriter};
1224    use futures::{channel::mpsc, future::LocalBoxFuture, select};
1225    use gpui::{AsyncApp, TestAppContext};
1226    use indoc::indoc;
1227    use project::FakeFs;
1228    use serde_json::json;
1229    use settings::SettingsStore;
1230    use smol::{future::BoxedLocal, stream::StreamExt as _};
1231    use std::{cell::RefCell, env, path::Path, rc::Rc, time::Duration};
1232    use util::path;
1233
1234    fn init_test(cx: &mut TestAppContext) {
1235        env_logger::try_init().ok();
1236        cx.update(|cx| {
1237            let settings_store = SettingsStore::test(cx);
1238            cx.set_global(settings_store);
1239            Project::init_settings(cx);
1240            language::init(cx);
1241        });
1242    }
1243
    // The fake agent streams two `Thought` chunks ("Thinking " and "hard!") in reply to one
    // user message. The Markdown rendering must show them merged into a single `<thinking>`
    // section under one "## Assistant" heading.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let (thread, fake_server) = fake_acp_thread(project, cx);

        fake_server.update(cx, |fake_server, _| {
            fake_server.on_user_message(move |_, server, mut cx| async move {
                server
                    .update(&mut cx, |server, _| {
                        server.send_to_zed(acp::StreamAssistantMessageChunkParams {
                            chunk: acp::AssistantMessageChunk::Thought {
                                thought: "Thinking ".into(),
                            },
                        })
                    })?
                    .await
                    .unwrap();
                server
                    .update(&mut cx, |server, _| {
                        server.send_to_zed(acp::StreamAssistantMessageChunkParams {
                            chunk: acp::AssistantMessageChunk::Thought {
                                thought: "hard!".into(),
                            },
                        })
                    })?
                    .await
                    .unwrap();

                Ok(())
            })
        });

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
1301
    // Sequence:
    // 1. The agent reads `/tmp/foo`.
    // 2. The user prepends "zero\n" to the open buffer.
    // 3. The agent writes back an extended version of what it read.
    // The write must be applied as a diff so the user's edit survives, and the merged result
    // must also be saved to disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let (thread, fake_server) = fake_acp_thread(project.clone(), cx);
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        // Fired by the fake agent right after its read. This lets the test slip the user's
        // edit in between the agent's read and its write.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));

        fake_server.update(cx, |fake_server, _| {
            fake_server.on_user_message(move |_, server, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = server
                        .update(&mut cx, |server, _| {
                            server.send_to_zed(acp::ReadTextFileParams {
                                path: path!("/tmp/foo").into(),
                                line: None,
                                limit: None,
                            })
                        })?
                        .await
                        .unwrap();
                    assert_eq!(content.content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    server
                        .update(&mut cx, |server, _| {
                            server.send_to_zed(acp::WriteTextFileParams {
                                path: path!("/tmp/foo").into(),
                                content: "one\ntwo\nthree\nfour\nfive\n".to_string(),
                            })
                        })?
                        .await
                        .unwrap();
                    Ok(())
                }
            })
        });

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        // Simulate the user editing the buffer after the agent's read.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        // Both the user's line and the agent's appended lines are present, in memory and on disk.
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
1375
    // A running tool call that is canceled locally must still accept a later `Finished`
    // update from the agent. That update moves it to `Allowed { status: Finished }`.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let (thread, fake_server) = fake_acp_thread(project, cx);

        // The fake agent's turn stays open until `end_turn_tx` is dropped.
        let (end_turn_tx, end_turn_rx) = oneshot::channel::<()>();

        // Filled in by the fake agent with the id Zed assigns to its tool call.
        let tool_call_id = Rc::new(RefCell::new(None));
        let end_turn_rx = Rc::new(RefCell::new(Some(end_turn_rx)));
        fake_server.update(cx, |fake_server, _| {
            let tool_call_id = tool_call_id.clone();
            fake_server.on_user_message(move |_, server, mut cx| {
                let end_turn_rx = end_turn_rx.clone();
                let tool_call_id = tool_call_id.clone();
                async move {
                    let tool_call_result = server
                        .update(&mut cx, |server, _| {
                            server.send_to_zed(acp::PushToolCallParams {
                                label: "Fetch".to_string(),
                                icon: acp::Icon::Globe,
                                content: None,
                                locations: vec![],
                            })
                        })?
                        .await
                        .unwrap();
                    *tool_call_id.clone().borrow_mut() = Some(tool_call_result.id);
                    end_turn_rx.take().unwrap().await.ok();

                    Ok(())
                }
            })
        });

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::Running,
                        ..
                    },
                    ..
                })
            ));
        });

        cx.run_until_parked();

        // Cancel locally: the running tool call must become `Canceled`.
        thread
            .update(cx, |thread, cx| thread.cancel(cx))
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // The agent reports the tool call as finished anyway.
        fake_server
            .update(cx, |fake_server, _| {
                fake_server.send_to_zed(acp::UpdateToolCallParams {
                    tool_call_id: tool_call_id.borrow().unwrap(),
                    status: acp::ToolCallStatus::Finished,
                    content: None,
                })
            })
            .await
            .unwrap();

        drop(end_turn_tx);
        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::Finished,
                        ..
                    },
                    ..
                })
            ));
        });
    }
1476
    // Runs against a real Gemini CLI agent (only with the `gemini` feature). A plain prompt
    // must yield exactly one user entry followed by one assistant entry.
    #[gpui::test]
    #[cfg_attr(not(feature = "gemini"), ignore)]
    async fn test_gemini_basic(cx: &mut TestAppContext) {
        init_test(cx);

        // The external agent process runs on real I/O, so the executor may park.
        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.entries.len(), 2);
            assert!(matches!(
                thread.entries[0],
                AgentThreadEntry::UserMessage(_)
            ));
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::AssistantMessage(_)
            ));
        });
    }
1504
    // Runs against a real Gemini CLI agent. A `Path` chunk in the user message must let the
    // agent read `foo.rs`. Expected entries:
    // 1. the user message,
    // 2. one tool call,
    // 3. an assistant message that quotes the file's `println!` text.
    #[gpui::test]
    #[cfg_attr(not(feature = "gemini"), ignore)]
    async fn test_gemini_path_mentions(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();
        // Real files on disk: the external agent reads them itself.
        let tempdir = tempfile::tempdir().unwrap();
        std::fs::write(
            tempdir.path().join("foo.rs"),
            indoc! {"
                fn main() {
                    println!(\"Hello, world!\");
                }
            "},
        )
        .expect("failed to write file");
        let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
        let thread = gemini_acp_thread(project.clone(), tempdir.path(), cx).await;
        thread
            .update(cx, |thread, cx| {
                thread.send(
                    acp::SendUserMessageParams {
                        chunks: vec![
                            acp::UserMessageChunk::Text {
                                text: "Read the file ".into(),
                            },
                            acp::UserMessageChunk::Path {
                                path: Path::new("foo.rs").into(),
                            },
                            acp::UserMessageChunk::Text {
                                text: " and tell me what the content of the println! is".into(),
                            },
                        ],
                    },
                    cx,
                )
            })
            .await
            .unwrap();

        thread.read_with(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            assert!(matches!(
                thread.entries[0],
                AgentThreadEntry::UserMessage(_)
            ));
            assert!(matches!(thread.entries[1], AgentThreadEntry::ToolCall(_)));
            let AgentThreadEntry::AssistantMessage(assistant_message) = &thread.entries[2] else {
                panic!("Expected AssistantMessage")
            };
            assert!(
                assistant_message.to_markdown(cx).contains("Hello, world!"),
                "unexpected assistant message: {:?}",
                assistant_message.to_markdown(cx)
            );
        });
    }
1562
    // Runs against a real Gemini CLI agent. Asking it to read a project file must produce an
    // allowed tool call at entry 2, followed by an assistant message at entry 3.
    #[gpui::test]
    #[cfg_attr(not(feature = "gemini"), ignore)]
    async fn test_gemini_tool_call(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/private/tmp"),
            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
        )
        .await;
        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
        let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
        thread
            .update(cx, |thread, cx| {
                thread.send_raw(
                    "Read the '/private/tmp/foo' file and tell me what you see.",
                    cx,
                )
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, _cx| {
            assert!(matches!(
                &thread.entries()[2],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed { .. },
                    ..
                })
            ));

            assert!(matches!(
                thread.entries[3],
                AgentThreadEntry::AssistantMessage(_)
            ));
        });
    }
1602
1603    #[gpui::test]
1604    #[cfg_attr(not(feature = "gemini"), ignore)]
1605    async fn test_gemini_tool_call_with_confirmation(cx: &mut TestAppContext) {
1606        init_test(cx);
1607
1608        cx.executor().allow_parking();
1609
1610        let fs = FakeFs::new(cx.executor());
1611        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
1612        let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
1613        let full_turn = thread.update(cx, |thread, cx| {
1614            thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx)
1615        });
1616
1617        run_until_first_tool_call(&thread, cx).await;
1618
1619        let tool_call_id = thread.read_with(cx, |thread, _cx| {
1620            let AgentThreadEntry::ToolCall(ToolCall {
1621                id,
1622                status:
1623                    ToolCallStatus::WaitingForConfirmation {
1624                        confirmation: ToolCallConfirmation::Execute { root_command, .. },
1625                        ..
1626                    },
1627                ..
1628            }) = &thread.entries()[2]
1629            else {
1630                panic!();
1631            };
1632
1633            assert_eq!(root_command, "echo");
1634
1635            *id
1636        });
1637
1638        thread.update(cx, |thread, cx| {
1639            thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);
1640
1641            assert!(matches!(
1642                &thread.entries()[2],
1643                AgentThreadEntry::ToolCall(ToolCall {
1644                    status: ToolCallStatus::Allowed { .. },
1645                    ..
1646                })
1647            ));
1648        });
1649
1650        full_turn.await.unwrap();
1651
1652        thread.read_with(cx, |thread, cx| {
1653            let AgentThreadEntry::ToolCall(ToolCall {
1654                content: Some(ToolCallContent::Markdown { markdown }),
1655                status: ToolCallStatus::Allowed { .. },
1656                ..
1657            }) = &thread.entries()[2]
1658            else {
1659                panic!();
1660            };
1661
1662            markdown.read_with(cx, |md, _cx| {
1663                assert!(
1664                    md.source().contains("Hello, world!"),
1665                    r#"Expected '{}' to contain "Hello, world!""#,
1666                    md.source()
1667                );
1668            });
1669        });
1670    }
1671
1672    #[gpui::test]
1673    #[cfg_attr(not(feature = "gemini"), ignore)]
1674    async fn test_gemini_cancel(cx: &mut TestAppContext) {
1675        init_test(cx);
1676
1677        cx.executor().allow_parking();
1678
1679        let fs = FakeFs::new(cx.executor());
1680        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
1681        let thread = gemini_acp_thread(project.clone(), "/private/tmp", cx).await;
1682        let full_turn = thread.update(cx, |thread, cx| {
1683            thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx)
1684        });
1685
1686        let first_tool_call_ix = run_until_first_tool_call(&thread, cx).await;
1687
1688        thread.read_with(cx, |thread, _cx| {
1689            let AgentThreadEntry::ToolCall(ToolCall {
1690                id,
1691                status:
1692                    ToolCallStatus::WaitingForConfirmation {
1693                        confirmation: ToolCallConfirmation::Execute { root_command, .. },
1694                        ..
1695                    },
1696                ..
1697            }) = &thread.entries()[first_tool_call_ix]
1698            else {
1699                panic!("{:?}", thread.entries()[1]);
1700            };
1701
1702            assert_eq!(root_command, "echo");
1703
1704            *id
1705        });
1706
1707        thread
1708            .update(cx, |thread, cx| thread.cancel(cx))
1709            .await
1710            .unwrap();
1711        full_turn.await.unwrap();
1712        thread.read_with(cx, |thread, _| {
1713            let AgentThreadEntry::ToolCall(ToolCall {
1714                status: ToolCallStatus::Canceled,
1715                ..
1716            }) = &thread.entries()[first_tool_call_ix]
1717            else {
1718                panic!();
1719            };
1720        });
1721
1722        thread
1723            .update(cx, |thread, cx| {
1724                thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
1725            })
1726            .await
1727            .unwrap();
1728        thread.read_with(cx, |thread, _| {
1729            assert!(matches!(
1730                &thread.entries().last().unwrap(),
1731                AgentThreadEntry::AssistantMessage(..),
1732            ))
1733        });
1734    }
1735
1736    async fn run_until_first_tool_call(
1737        thread: &Entity<AcpThread>,
1738        cx: &mut TestAppContext,
1739    ) -> usize {
1740        let (mut tx, mut rx) = mpsc::channel::<usize>(1);
1741
1742        let subscription = cx.update(|cx| {
1743            cx.subscribe(thread, move |thread, _, cx| {
1744                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
1745                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
1746                        return tx.try_send(ix).unwrap();
1747                    }
1748                }
1749            })
1750        });
1751
1752        select! {
1753            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
1754                panic!("Timeout waiting for tool call")
1755            }
1756            ix = rx.next().fuse() => {
1757                drop(subscription);
1758                ix.unwrap()
1759            }
1760        }
1761    }
1762
    // Spawns an `AcpThread` backed by a locally checked-out Gemini CLI and completes the
    // `initialize` handshake before returning it. The CLI runs as `node <cli> --acp`, where
    // `<cli>` is `../../../gemini-cli/packages/cli` relative to this crate's manifest dir.
    pub async fn gemini_acp_thread(
        project: Entity<Project>,
        current_dir: impl AsRef<Path>,
        cx: &mut TestAppContext,
    ) -> Entity<AcpThread> {
        // Test-only agent server definition pointing at the sibling checkout.
        struct DevGemini;

        impl agent_servers::AgentServer for DevGemini {
            async fn command(
                &self,
                _project: &Entity<Project>,
                _cx: &mut AsyncApp,
            ) -> Result<agent_servers::AgentServerCommand> {
                let cli_path = Path::new(env!("CARGO_MANIFEST_DIR"))
                    .join("../../../gemini-cli/packages/cli")
                    .to_string_lossy()
                    .to_string();

                Ok(AgentServerCommand {
                    path: "node".into(),
                    args: vec![cli_path, "--acp".into()],
                    env: None,
                })
            }

            // Reports a fixed version that is always marked as supported.
            async fn version(
                &self,
                _command: &agent_servers::AgentServerCommand,
            ) -> Result<AgentServerVersion> {
                Ok(AgentServerVersion {
                    current_version: "0.1.0".into(),
                    supported: true,
                })
            }
        }

        let thread = AcpThread::spawn(DevGemini, current_dir.as_ref(), project, &mut cx.to_async())
            .await
            .unwrap();

        thread
            .update(cx, |thread, _| thread.initialize())
            .await
            .unwrap();
        thread
    }
1809
1810    pub fn fake_acp_thread(
1811        project: Entity<Project>,
1812        cx: &mut TestAppContext,
1813    ) -> (Entity<AcpThread>, Entity<FakeAcpServer>) {
1814        let (stdin_tx, stdin_rx) = async_pipe::pipe();
1815        let (stdout_tx, stdout_rx) = async_pipe::pipe();
1816        let thread = cx.update(|cx| cx.new(|cx| AcpThread::fake(stdin_tx, stdout_rx, project, cx)));
1817        let agent = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx)));
1818        (thread, agent)
1819    }
1820
1821    pub struct FakeAcpServer {
1822        connection: acp::ClientConnection,
1823        _io_task: Task<()>,
1824        on_user_message: Option<
1825            Rc<
1826                dyn Fn(
1827                    acp::SendUserMessageParams,
1828                    Entity<FakeAcpServer>,
1829                    AsyncApp,
1830                ) -> LocalBoxFuture<'static, Result<(), acp::Error>>,
1831            >,
1832        >,
1833    }
1834
    /// `acp::Agent` implementation backing a [`FakeAcpServer`].
    ///
    /// Its `send_user_message` is delegated to the handler registered on `server`.
    #[derive(Clone)]
    struct FakeAgent {
        /// Server entity whose registered `on_user_message` handler serves requests.
        server: Entity<FakeAcpServer>,
        /// Async context used to access `server` from the async trait methods.
        cx: AsyncApp,
    }
1840
1841    impl acp::Agent for FakeAgent {
1842        async fn initialize(&self) -> Result<acp::InitializeResponse, acp::Error> {
1843            Ok(acp::InitializeResponse {
1844                is_authenticated: true,
1845            })
1846        }
1847
1848        async fn authenticate(&self) -> Result<(), acp::Error> {
1849            Ok(())
1850        }
1851
1852        async fn cancel_send_message(&self) -> Result<(), acp::Error> {
1853            Ok(())
1854        }
1855
1856        async fn send_user_message(
1857            &self,
1858            request: acp::SendUserMessageParams,
1859        ) -> Result<(), acp::Error> {
1860            let mut cx = self.cx.clone();
1861            let handler = self
1862                .server
1863                .update(&mut cx, |server, _| server.on_user_message.clone())
1864                .ok()
1865                .flatten();
1866            if let Some(handler) = handler {
1867                handler(request, self.server.clone(), self.cx.clone()).await
1868            } else {
1869                Err(anyhow::anyhow!("No handler for on_user_message").into())
1870            }
1871        }
1872    }
1873
1874    impl FakeAcpServer {
1875        fn new(stdin: PipeReader, stdout: PipeWriter, cx: &Context<Self>) -> Self {
1876            let agent = FakeAgent {
1877                server: cx.entity(),
1878                cx: cx.to_async(),
1879            };
1880            let foreground_executor = cx.foreground_executor().clone();
1881
1882            let (connection, io_fut) = acp::ClientConnection::connect_to_client(
1883                agent.clone(),
1884                stdout,
1885                stdin,
1886                move |fut| {
1887                    foreground_executor.spawn(fut).detach();
1888                },
1889            );
1890            FakeAcpServer {
1891                connection: connection,
1892                on_user_message: None,
1893                _io_task: cx.background_spawn(async move {
1894                    io_fut.await.log_err();
1895                }),
1896            }
1897        }
1898
1899        fn on_user_message<F>(
1900            &mut self,
1901            handler: impl for<'a> Fn(acp::SendUserMessageParams, Entity<FakeAcpServer>, AsyncApp) -> F
1902            + 'static,
1903        ) where
1904            F: Future<Output = Result<(), acp::Error>> + 'static,
1905        {
1906            self.on_user_message
1907                .replace(Rc::new(move |request, server, cx| {
1908                    handler(request, server, cx).boxed_local()
1909                }));
1910        }
1911
1912        fn send_to_zed<T: acp::ClientRequest + 'static>(
1913            &self,
1914            message: T,
1915        ) -> BoxedLocal<Result<T::Response>> {
1916            self.connection
1917                .request(message)
1918                .map(|f| f.map_err(|err| anyhow!(err)))
1919                .boxed_local()
1920        }
1921    }
1922}