acp_thread.rs

   1mod connection;
   2pub use connection::*;
   3
   4pub use acp::ToolCallId;
   5use agentic_coding_protocol::{
   6    self as acp, AgentRequest, ProtocolVersion, ToolCallConfirmationOutcome, ToolCallLocation,
   7    UserMessageChunk,
   8};
   9use anyhow::{Context as _, Result};
  10use assistant_tool::ActionLog;
  11use buffer_diff::BufferDiff;
  12use editor::{Bias, MultiBuffer, PathKey};
  13use futures::{FutureExt, channel::oneshot, future::BoxFuture};
  14use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  15use itertools::Itertools;
  16use language::{
  17    Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point,
  18    text_diff,
  19};
  20use markdown::Markdown;
  21use project::{AgentLocation, Project};
  22use std::collections::HashMap;
  23use std::error::Error;
  24use std::fmt::{Formatter, Write};
  25use std::{
  26    fmt::Display,
  27    mem,
  28    path::{Path, PathBuf},
  29    sync::Arc,
  30};
  31use ui::{App, IconName};
  32use util::ResultExt;
  33
  34#[derive(Clone, Debug, Eq, PartialEq)]
  35pub struct UserMessage {
  36    pub content: Entity<Markdown>,
  37}
  38
  39impl UserMessage {
  40    pub fn from_acp(
  41        message: &acp::SendUserMessageParams,
  42        language_registry: Arc<LanguageRegistry>,
  43        cx: &mut App,
  44    ) -> Self {
  45        let mut md_source = String::new();
  46
  47        for chunk in &message.chunks {
  48            match chunk {
  49                UserMessageChunk::Text { text } => md_source.push_str(&text),
  50                UserMessageChunk::Path { path } => {
  51                    write!(&mut md_source, "{}", MentionPath(&path)).unwrap()
  52                }
  53            }
  54        }
  55
  56        Self {
  57            content: cx
  58                .new(|cx| Markdown::new(md_source.into(), Some(language_registry), None, cx)),
  59        }
  60    }
  61
  62    fn to_markdown(&self, cx: &App) -> String {
  63        format!("## User\n\n{}\n\n", self.content.read(cx).source())
  64    }
  65}
  66
  67#[derive(Debug)]
  68pub struct MentionPath<'a>(&'a Path);
  69
  70impl<'a> MentionPath<'a> {
  71    const PREFIX: &'static str = "@file:";
  72
  73    pub fn new(path: &'a Path) -> Self {
  74        MentionPath(path)
  75    }
  76
  77    pub fn try_parse(url: &'a str) -> Option<Self> {
  78        let path = url.strip_prefix(Self::PREFIX)?;
  79        Some(MentionPath(Path::new(path)))
  80    }
  81
  82    pub fn path(&self) -> &Path {
  83        self.0
  84    }
  85}
  86
  87impl Display for MentionPath<'_> {
  88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  89        write!(
  90            f,
  91            "[@{}]({}{})",
  92            self.0.file_name().unwrap_or_default().display(),
  93            Self::PREFIX,
  94            self.0.display()
  95        )
  96    }
  97}
  98
  99#[derive(Clone, Debug, Eq, PartialEq)]
 100pub struct AssistantMessage {
 101    pub chunks: Vec<AssistantMessageChunk>,
 102}
 103
 104impl AssistantMessage {
 105    pub fn to_markdown(&self, cx: &App) -> String {
 106        format!(
 107            "## Assistant\n\n{}\n\n",
 108            self.chunks
 109                .iter()
 110                .map(|chunk| chunk.to_markdown(cx))
 111                .join("\n\n")
 112        )
 113    }
 114}
 115
 116#[derive(Clone, Debug, Eq, PartialEq)]
 117pub enum AssistantMessageChunk {
 118    Text { chunk: Entity<Markdown> },
 119    Thought { chunk: Entity<Markdown> },
 120}
 121
 122impl AssistantMessageChunk {
 123    pub fn from_acp(
 124        chunk: acp::AssistantMessageChunk,
 125        language_registry: Arc<LanguageRegistry>,
 126        cx: &mut App,
 127    ) -> Self {
 128        match chunk {
 129            acp::AssistantMessageChunk::Text { text } => Self::Text {
 130                chunk: cx.new(|cx| Markdown::new(text.into(), Some(language_registry), None, cx)),
 131            },
 132            acp::AssistantMessageChunk::Thought { thought } => Self::Thought {
 133                chunk: cx
 134                    .new(|cx| Markdown::new(thought.into(), Some(language_registry), None, cx)),
 135            },
 136        }
 137    }
 138
 139    pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
 140        Self::Text {
 141            chunk: cx.new(|cx| {
 142                Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
 143            }),
 144        }
 145    }
 146
 147    fn to_markdown(&self, cx: &App) -> String {
 148        match self {
 149            Self::Text { chunk } => chunk.read(cx).source().to_string(),
 150            Self::Thought { chunk } => {
 151                format!("<thinking>\n{}\n</thinking>", chunk.read(cx).source())
 152            }
 153        }
 154    }
 155}
 156
 157#[derive(Debug)]
 158pub enum AgentThreadEntry {
 159    UserMessage(UserMessage),
 160    AssistantMessage(AssistantMessage),
 161    ToolCall(ToolCall),
 162}
 163
 164impl AgentThreadEntry {
 165    fn to_markdown(&self, cx: &App) -> String {
 166        match self {
 167            Self::UserMessage(message) => message.to_markdown(cx),
 168            Self::AssistantMessage(message) => message.to_markdown(cx),
 169            Self::ToolCall(too_call) => too_call.to_markdown(cx),
 170        }
 171    }
 172
 173    pub fn diff(&self) -> Option<&Diff> {
 174        if let AgentThreadEntry::ToolCall(ToolCall {
 175            content: Some(ToolCallContent::Diff { diff }),
 176            ..
 177        }) = self
 178        {
 179            Some(&diff)
 180        } else {
 181            None
 182        }
 183    }
 184
 185    pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
 186        if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
 187            Some(locations)
 188        } else {
 189            None
 190        }
 191    }
 192}
 193
 194#[derive(Debug)]
 195pub struct ToolCall {
 196    pub id: acp::ToolCallId,
 197    pub label: Entity<Markdown>,
 198    pub icon: IconName,
 199    pub content: Option<ToolCallContent>,
 200    pub status: ToolCallStatus,
 201    pub locations: Vec<acp::ToolCallLocation>,
 202}
 203
 204impl ToolCall {
 205    fn to_markdown(&self, cx: &App) -> String {
 206        let mut markdown = format!(
 207            "**Tool Call: {}**\nStatus: {}\n\n",
 208            self.label.read(cx).source(),
 209            self.status
 210        );
 211        if let Some(content) = &self.content {
 212            markdown.push_str(content.to_markdown(cx).as_str());
 213            markdown.push_str("\n\n");
 214        }
 215        markdown
 216    }
 217}
 218
 219#[derive(Debug)]
 220pub enum ToolCallStatus {
 221    WaitingForConfirmation {
 222        confirmation: ToolCallConfirmation,
 223        respond_tx: oneshot::Sender<acp::ToolCallConfirmationOutcome>,
 224    },
 225    Allowed {
 226        status: acp::ToolCallStatus,
 227    },
 228    Rejected,
 229    Canceled,
 230}
 231
 232impl Display for ToolCallStatus {
 233    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 234        write!(
 235            f,
 236            "{}",
 237            match self {
 238                ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
 239                ToolCallStatus::Allowed { status } => match status {
 240                    acp::ToolCallStatus::Running => "Running",
 241                    acp::ToolCallStatus::Finished => "Finished",
 242                    acp::ToolCallStatus::Error => "Error",
 243                },
 244                ToolCallStatus::Rejected => "Rejected",
 245                ToolCallStatus::Canceled => "Canceled",
 246            }
 247        )
 248    }
 249}
 250
 251#[derive(Debug)]
 252pub enum ToolCallConfirmation {
 253    Edit {
 254        description: Option<Entity<Markdown>>,
 255    },
 256    Execute {
 257        command: String,
 258        root_command: String,
 259        description: Option<Entity<Markdown>>,
 260    },
 261    Mcp {
 262        server_name: String,
 263        tool_name: String,
 264        tool_display_name: String,
 265        description: Option<Entity<Markdown>>,
 266    },
 267    Fetch {
 268        urls: Vec<SharedString>,
 269        description: Option<Entity<Markdown>>,
 270    },
 271    Other {
 272        description: Entity<Markdown>,
 273    },
 274}
 275
 276impl ToolCallConfirmation {
 277    pub fn from_acp(
 278        confirmation: acp::ToolCallConfirmation,
 279        language_registry: Arc<LanguageRegistry>,
 280        cx: &mut App,
 281    ) -> Self {
 282        let to_md = |description: String, cx: &mut App| -> Entity<Markdown> {
 283            cx.new(|cx| {
 284                Markdown::new(
 285                    description.into(),
 286                    Some(language_registry.clone()),
 287                    None,
 288                    cx,
 289                )
 290            })
 291        };
 292
 293        match confirmation {
 294            acp::ToolCallConfirmation::Edit { description } => Self::Edit {
 295                description: description.map(|description| to_md(description, cx)),
 296            },
 297            acp::ToolCallConfirmation::Execute {
 298                command,
 299                root_command,
 300                description,
 301            } => Self::Execute {
 302                command,
 303                root_command,
 304                description: description.map(|description| to_md(description, cx)),
 305            },
 306            acp::ToolCallConfirmation::Mcp {
 307                server_name,
 308                tool_name,
 309                tool_display_name,
 310                description,
 311            } => Self::Mcp {
 312                server_name,
 313                tool_name,
 314                tool_display_name,
 315                description: description.map(|description| to_md(description, cx)),
 316            },
 317            acp::ToolCallConfirmation::Fetch { urls, description } => Self::Fetch {
 318                urls: urls.iter().map(|url| url.into()).collect(),
 319                description: description.map(|description| to_md(description, cx)),
 320            },
 321            acp::ToolCallConfirmation::Other { description } => Self::Other {
 322                description: to_md(description, cx),
 323            },
 324        }
 325    }
 326}
 327
 328#[derive(Debug)]
 329pub enum ToolCallContent {
 330    Markdown { markdown: Entity<Markdown> },
 331    Diff { diff: Diff },
 332}
 333
 334impl ToolCallContent {
 335    pub fn from_acp(
 336        content: acp::ToolCallContent,
 337        language_registry: Arc<LanguageRegistry>,
 338        cx: &mut App,
 339    ) -> Self {
 340        match content {
 341            acp::ToolCallContent::Markdown { markdown } => Self::Markdown {
 342                markdown: cx.new(|cx| Markdown::new_text(markdown.into(), cx)),
 343            },
 344            acp::ToolCallContent::Diff { diff } => Self::Diff {
 345                diff: Diff::from_acp(diff, language_registry, cx),
 346            },
 347        }
 348    }
 349
 350    fn to_markdown(&self, cx: &App) -> String {
 351        match self {
 352            Self::Markdown { markdown } => markdown.read(cx).source().to_string(),
 353            Self::Diff { diff } => diff.to_markdown(cx),
 354        }
 355    }
 356}
 357
/// A read-only view of the change from an optional old text to a new text,
/// as reported by the agent in a tool call.
#[derive(Debug)]
pub struct Diff {
    /// Read-only multibuffer, populated by the task spawned in `from_acp` with
    /// excerpts of `new_buffer` around each diff hunk.
    pub multibuffer: Entity<MultiBuffer>,
    /// Path reported by the agent; used for language detection and markdown export.
    pub path: PathBuf,
    /// Local buffer containing the new text.
    pub new_buffer: Entity<Buffer>,
    /// Local buffer containing the old text (empty if the agent sent none).
    pub old_buffer: Entity<Buffer>,
    /// Task spawned in `from_acp`; stored so it is owned by this `Diff`.
    _task: Task<Result<()>>,
}

impl Diff {
    /// Builds a `Diff` from an ACP diff payload.
    ///
    /// Creates local buffers for the old and new text, starts computing a
    /// `BufferDiff` of the new text against the old text as base, and spawns a
    /// task that, once the diff is ready, fills `multibuffer` with excerpts
    /// around each hunk and sets the new buffer's language from `path`.
    pub fn from_acp(
        diff: acp::Diff,
        language_registry: Arc<LanguageRegistry>,
        cx: &mut App,
    ) -> Self {
        let acp::Diff {
            path,
            old_text,
            new_text,
        } = diff;

        let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));

        let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
        // A missing old text is treated as an empty base text.
        let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
        let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
        let old_buffer_snapshot = old_buffer.read(cx).snapshot();
        let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
        // `set_base_text` returns a future that is awaited in the task below.
        let diff_task = buffer_diff.update(cx, |diff, cx| {
            diff.set_base_text(
                old_buffer_snapshot,
                Some(language_registry.clone()),
                new_buffer_snapshot,
                cx,
            )
        });

        let task = cx.spawn({
            let multibuffer = multibuffer.clone();
            let path = path.clone();
            let new_buffer = new_buffer.clone();
            async move |cx| {
                diff_task.await?;

                // Show only the regions of the new buffer that contain hunks
                // (plus the default context lines). A failure here is logged
                // so that language detection below still runs.
                multibuffer
                    .update(cx, |multibuffer, cx| {
                        let hunk_ranges = {
                            let buffer = new_buffer.read(cx);
                            let diff = buffer_diff.read(cx);
                            diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
                                .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
                                .collect::<Vec<_>>()
                        };

                        multibuffer.set_excerpts_for_path(
                            PathKey::for_buffer(&new_buffer, cx),
                            new_buffer.clone(),
                            hunk_ranges,
                            editor::DEFAULT_MULTIBUFFER_CONTEXT,
                            cx,
                        );
                        multibuffer.add_diff(buffer_diff.clone(), cx);
                    })
                    .log_err();

                // Best effort: a failed language lookup is logged and ignored.
                if let Some(language) = language_registry
                    .language_for_file_path(&path)
                    .await
                    .log_err()
                {
                    new_buffer.update(cx, |buffer, cx| buffer.set_language(Some(language), cx))?;
                }

                anyhow::Ok(())
            }
        });

        Self {
            multibuffer,
            path,
            new_buffer,
            old_buffer,
            _task: task,
        }
    }

    /// Renders the path followed by the full text of every buffer in the
    /// multibuffer (joined by newlines) inside a fenced code block.
    fn to_markdown(&self, cx: &App) -> String {
        let buffer_text = self
            .multibuffer
            .read(cx)
            .all_buffers()
            .iter()
            .map(|buffer| buffer.read(cx).text())
            .join("\n");
        format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text)
    }
}
 455
/// A conversation with an agent over an [`AgentConnection`].
pub struct AcpThread {
    /// Thread entries in insertion order; tool call ids are indices into this list.
    entries: Vec<AgentThreadEntry>,
    /// Title returned by `title()`.
    title: SharedString,
    /// Project the thread operates on (language registry, agent location).
    project: Entity<Project>,
    /// Action log created for `project` in `new`.
    action_log: Entity<ActionLog>,
    /// Buffer snapshots keyed by buffer entity.
    /// NOTE(review): not touched in the visible code; presumably used by
    /// `read_text_file` when `reuse_shared_snapshot` is set — confirm.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// In-flight task started by `send`; `Some` means the thread is not idle.
    send_task: Option<Task<()>>,
    /// Connection used for every agent request.
    connection: Arc<dyn AgentConnection>,
    /// Optional task supplied to `new`.
    /// NOTE(review): presumably tracks the agent child process — confirm.
    child_status: Option<Task<Result<()>>>,
}

/// Events emitted by [`AcpThread`].
pub enum AcpThreadEvent {
    /// An entry was appended.
    NewEntry,
    /// The entry at this index changed.
    EntryUpdated(usize),
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}
 473
 474#[derive(PartialEq, Eq)]
 475pub enum ThreadStatus {
 476    Idle,
 477    WaitingForToolConfirmation,
 478    Generating,
 479}
 480
 481#[derive(Debug, Clone)]
 482pub enum LoadError {
 483    Unsupported {
 484        error_message: SharedString,
 485        upgrade_message: SharedString,
 486        upgrade_command: String,
 487    },
 488    Exited(i32),
 489    Other(SharedString),
 490}
 491
 492impl Display for LoadError {
 493    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 494        match self {
 495            LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
 496            LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
 497            LoadError::Other(msg) => write!(f, "{}", msg),
 498        }
 499    }
 500}
 501
 502impl Error for LoadError {}
 503
 504impl AcpThread {
 505    pub fn new(
 506        connection: impl AgentConnection + 'static,
 507        title: SharedString,
 508        child_status: Option<Task<Result<()>>>,
 509        project: Entity<Project>,
 510        cx: &mut Context<Self>,
 511    ) -> Self {
 512        let action_log = cx.new(|_| ActionLog::new(project.clone()));
 513
 514        Self {
 515            action_log,
 516            shared_buffers: Default::default(),
 517            entries: Default::default(),
 518            title,
 519            project,
 520            send_task: None,
 521            connection: Arc::new(connection),
 522            child_status,
 523        }
 524    }
 525
 526    /// Send a request to the agent and wait for a response.
 527    pub fn request<R: AgentRequest + 'static>(
 528        &self,
 529        params: R,
 530    ) -> impl use<R> + Future<Output = Result<R::Response>> {
 531        let params = params.into_any();
 532        let result = self.connection.request_any(params);
 533        async move {
 534            let result = result.await?;
 535            Ok(R::response_from_any(result)?)
 536        }
 537    }
 538
 539    pub fn action_log(&self) -> &Entity<ActionLog> {
 540        &self.action_log
 541    }
 542
 543    pub fn project(&self) -> &Entity<Project> {
 544        &self.project
 545    }
 546
 547    pub fn title(&self) -> SharedString {
 548        self.title.clone()
 549    }
 550
 551    pub fn entries(&self) -> &[AgentThreadEntry] {
 552        &self.entries
 553    }
 554
 555    pub fn status(&self) -> ThreadStatus {
 556        if self.send_task.is_some() {
 557            if self.waiting_for_tool_confirmation() {
 558                ThreadStatus::WaitingForToolConfirmation
 559            } else {
 560                ThreadStatus::Generating
 561            }
 562        } else {
 563            ThreadStatus::Idle
 564        }
 565    }
 566
 567    pub fn has_pending_edit_tool_calls(&self) -> bool {
 568        for entry in self.entries.iter().rev() {
 569            match entry {
 570                AgentThreadEntry::UserMessage(_) => return false,
 571                AgentThreadEntry::ToolCall(ToolCall {
 572                    status:
 573                        ToolCallStatus::Allowed {
 574                            status: acp::ToolCallStatus::Running,
 575                            ..
 576                        },
 577                    content: Some(ToolCallContent::Diff { .. }),
 578                    ..
 579                }) => return true,
 580                AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
 581            }
 582        }
 583
 584        false
 585    }
 586
 587    pub fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
 588        self.entries.push(entry);
 589        cx.emit(AcpThreadEvent::NewEntry);
 590    }
 591
    /// Appends a streamed assistant chunk to the thread.
    ///
    /// If the last entry is an assistant message whose last chunk has the same
    /// kind (text or thought), the new text is appended to that chunk in place.
    /// If the last entry is an assistant message but the kinds differ (or it
    /// has no chunks), a new chunk is pushed onto it. Otherwise a new assistant
    /// message entry is created.
    pub fn push_assistant_chunk(
        &mut self,
        chunk: acp::AssistantMessageChunk,
        cx: &mut Context<Self>,
    ) {
        let entries_len = self.entries.len();
        if let Some(last_entry) = self.entries.last_mut()
            && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
        {
            // The last entry is modified in place; report its index.
            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));

            match (chunks.last_mut(), &chunk) {
                // Same kind as the trailing chunk: extend it instead of
                // starting a new one.
                (
                    Some(AssistantMessageChunk::Text { chunk: old_chunk }),
                    acp::AssistantMessageChunk::Text { text: new_chunk },
                )
                | (
                    Some(AssistantMessageChunk::Thought { chunk: old_chunk }),
                    acp::AssistantMessageChunk::Thought { thought: new_chunk },
                ) => {
                    old_chunk.update(cx, |old_chunk, cx| {
                        old_chunk.append(&new_chunk, cx);
                    });
                }
                // Kind changed (or no chunks yet): start a new chunk.
                _ => {
                    chunks.push(AssistantMessageChunk::from_acp(
                        chunk,
                        self.project.read(cx).languages().clone(),
                        cx,
                    ));
                }
            }
        } else {
            // The last entry is not an assistant message: start a new one.
            let chunk = AssistantMessageChunk::from_acp(
                chunk,
                self.project.read(cx).languages().clone(),
                cx,
            );

            self.push_entry(
                AgentThreadEntry::AssistantMessage(AssistantMessage {
                    chunks: vec![chunk],
                }),
                cx,
            );
        }
    }
 639
 640    pub fn request_new_tool_call(
 641        &mut self,
 642        tool_call: acp::RequestToolCallConfirmationParams,
 643        cx: &mut Context<Self>,
 644    ) -> ToolCallRequest {
 645        let (tx, rx) = oneshot::channel();
 646
 647        let status = ToolCallStatus::WaitingForConfirmation {
 648            confirmation: ToolCallConfirmation::from_acp(
 649                tool_call.confirmation,
 650                self.project.read(cx).languages().clone(),
 651                cx,
 652            ),
 653            respond_tx: tx,
 654        };
 655
 656        let id = self.insert_tool_call(tool_call.tool_call, status, cx);
 657        ToolCallRequest { id, outcome: rx }
 658    }
 659
 660    pub fn request_tool_call_confirmation(
 661        &mut self,
 662        tool_call_id: ToolCallId,
 663        confirmation: acp::ToolCallConfirmation,
 664        cx: &mut Context<Self>,
 665    ) -> Result<ToolCallRequest> {
 666        let project = self.project.read(cx).languages().clone();
 667        let Some((_, call)) = self.tool_call_mut(tool_call_id) else {
 668            anyhow::bail!("Tool call not found");
 669        };
 670
 671        let (tx, rx) = oneshot::channel();
 672
 673        call.status = ToolCallStatus::WaitingForConfirmation {
 674            confirmation: ToolCallConfirmation::from_acp(confirmation, project, cx),
 675            respond_tx: tx,
 676        };
 677
 678        Ok(ToolCallRequest {
 679            id: tool_call_id,
 680            outcome: rx,
 681        })
 682    }
 683
 684    pub fn push_tool_call(
 685        &mut self,
 686        request: acp::PushToolCallParams,
 687        cx: &mut Context<Self>,
 688    ) -> acp::ToolCallId {
 689        let status = ToolCallStatus::Allowed {
 690            status: acp::ToolCallStatus::Running,
 691        };
 692
 693        self.insert_tool_call(request, status, cx)
 694    }
 695
 696    fn insert_tool_call(
 697        &mut self,
 698        tool_call: acp::PushToolCallParams,
 699        status: ToolCallStatus,
 700        cx: &mut Context<Self>,
 701    ) -> acp::ToolCallId {
 702        let language_registry = self.project.read(cx).languages().clone();
 703        let id = acp::ToolCallId(self.entries.len() as u64);
 704        let call = ToolCall {
 705            id,
 706            label: cx.new(|cx| {
 707                Markdown::new(
 708                    tool_call.label.into(),
 709                    Some(language_registry.clone()),
 710                    None,
 711                    cx,
 712                )
 713            }),
 714            icon: acp_icon_to_ui_icon(tool_call.icon),
 715            content: tool_call
 716                .content
 717                .map(|content| ToolCallContent::from_acp(content, language_registry, cx)),
 718            locations: tool_call.locations,
 719            status,
 720        };
 721
 722        let location = call.locations.last().cloned();
 723        if let Some(location) = location {
 724            self.set_project_location(location, cx)
 725        }
 726
 727        self.push_entry(AgentThreadEntry::ToolCall(call), cx);
 728
 729        id
 730    }
 731
 732    pub fn authorize_tool_call(
 733        &mut self,
 734        id: acp::ToolCallId,
 735        outcome: acp::ToolCallConfirmationOutcome,
 736        cx: &mut Context<Self>,
 737    ) {
 738        let Some((ix, call)) = self.tool_call_mut(id) else {
 739            return;
 740        };
 741
 742        let new_status = if outcome == acp::ToolCallConfirmationOutcome::Reject {
 743            ToolCallStatus::Rejected
 744        } else {
 745            ToolCallStatus::Allowed {
 746                status: acp::ToolCallStatus::Running,
 747            }
 748        };
 749
 750        let curr_status = mem::replace(&mut call.status, new_status);
 751
 752        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
 753            respond_tx.send(outcome).log_err();
 754        } else if cfg!(debug_assertions) {
 755            panic!("tried to authorize an already authorized tool call");
 756        }
 757
 758        cx.emit(AcpThreadEvent::EntryUpdated(ix));
 759    }
 760
 761    pub fn update_tool_call(
 762        &mut self,
 763        id: acp::ToolCallId,
 764        new_status: acp::ToolCallStatus,
 765        new_content: Option<acp::ToolCallContent>,
 766        cx: &mut Context<Self>,
 767    ) -> Result<()> {
 768        let language_registry = self.project.read(cx).languages().clone();
 769        let (ix, call) = self.tool_call_mut(id).context("Entry not found")?;
 770
 771        call.content = new_content
 772            .map(|new_content| ToolCallContent::from_acp(new_content, language_registry, cx));
 773
 774        match &mut call.status {
 775            ToolCallStatus::Allowed { status } => {
 776                *status = new_status;
 777            }
 778            ToolCallStatus::WaitingForConfirmation { .. } => {
 779                anyhow::bail!("Tool call hasn't been authorized yet")
 780            }
 781            ToolCallStatus::Rejected => {
 782                anyhow::bail!("Tool call was rejected and therefore can't be updated")
 783            }
 784            ToolCallStatus::Canceled => {
 785                call.status = ToolCallStatus::Allowed { status: new_status };
 786            }
 787        }
 788
 789        let location = call.locations.last().cloned();
 790        if let Some(location) = location {
 791            self.set_project_location(location, cx)
 792        }
 793
 794        cx.emit(AcpThreadEvent::EntryUpdated(ix));
 795        Ok(())
 796    }
 797
 798    fn tool_call_mut(&mut self, id: acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
 799        let entry = self.entries.get_mut(id.0 as usize);
 800        debug_assert!(
 801            entry.is_some(),
 802            "We shouldn't give out ids to entries that don't exist"
 803        );
 804        match entry {
 805            Some(AgentThreadEntry::ToolCall(call)) if call.id == id => Some((id.0 as usize, call)),
 806            _ => {
 807                if cfg!(debug_assertions) {
 808                    panic!("entry is not a tool call");
 809                }
 810                None
 811            }
 812        }
 813    }
 814
    /// Moves the project's agent location to `location`.
    ///
    /// The absolute path is resolved to a project path (returning early if it
    /// is not part of the project), the buffer is opened, and the agent
    /// location is set to the start of `location.line` (clipped to the buffer)
    /// or to the start of the buffer when no line is given. The work runs in a
    /// detached task; errors are logged.
    pub fn set_project_location(&self, location: ToolCallLocation, cx: &mut Context<Self>) {
        self.project.update(cx, |project, cx| {
            let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
                return;
            };
            let buffer = project.open_buffer(path, cx);
            cx.spawn(async move |project, cx| {
                let buffer = buffer.await?;

                project.update(cx, |project, cx| {
                    let position = if let Some(line) = location.line {
                        let snapshot = buffer.read(cx).snapshot();
                        // Clip so an out-of-range line still maps to a valid point.
                        let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
                        snapshot.anchor_before(point)
                    } else {
                        Anchor::MIN
                    };

                    project.set_agent_location(
                        Some(AgentLocation {
                            buffer: buffer.downgrade(),
                            position,
                        }),
                        cx,
                    );
                })
            })
            .detach_and_log_err(cx);
        });
    }
 845
 846    /// Returns true if the last turn is awaiting tool authorization
 847    pub fn waiting_for_tool_confirmation(&self) -> bool {
 848        for entry in self.entries.iter().rev() {
 849            match &entry {
 850                AgentThreadEntry::ToolCall(call) => match call.status {
 851                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
 852                    ToolCallStatus::Allowed { .. }
 853                    | ToolCallStatus::Rejected
 854                    | ToolCallStatus::Canceled => continue,
 855                },
 856                AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
 857                    // Reached the beginning of the turn
 858                    return false;
 859                }
 860            }
 861        }
 862        false
 863    }
 864
 865    pub fn initialize(&self) -> impl use<> + Future<Output = Result<acp::InitializeResponse>> {
 866        self.request(acp::InitializeParams {
 867            protocol_version: ProtocolVersion::latest(),
 868        })
 869    }
 870
 871    pub fn authenticate(&self) -> impl use<> + Future<Output = Result<()>> {
 872        self.request(acp::AuthenticateParams)
 873    }
 874
 875    #[cfg(any(test, feature = "test-support"))]
 876    pub fn send_raw(
 877        &mut self,
 878        message: &str,
 879        cx: &mut Context<Self>,
 880    ) -> BoxFuture<'static, Result<(), acp::Error>> {
 881        self.send(
 882            acp::SendUserMessageParams {
 883                chunks: vec![acp::UserMessageChunk::Text {
 884                    text: message.to_string(),
 885                }],
 886            },
 887            cx,
 888        )
 889    }
 890
    /// Sends a user message to the agent.
    ///
    /// The message is appended as a user entry immediately. Any in-flight send
    /// is cancelled first; the request then runs on a new `send_task`, which
    /// clears itself when the request completes. The returned future resolves
    /// to the request's error if it failed, and to `Ok(())` otherwise —
    /// including when the send task is dropped before reporting a result.
    pub fn send(
        &mut self,
        message: acp::SendUserMessageParams,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<(), acp::Error>> {
        self.push_entry(
            AgentThreadEntry::UserMessage(UserMessage::from_acp(
                &message,
                self.project.read(cx).languages().clone(),
                cx,
            )),
            cx,
        );

        let (tx, rx) = oneshot::channel();
        let cancel = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            async {
                // Wait for the previous send (if any) to be cancelled; a
                // failure there is only logged.
                cancel.await.log_err();

                let result = this.update(cx, |this, _| this.request(message))?.await;
                // The receiver may already be gone; that is not an error here.
                tx.send(result).log_err();
                // The send has finished: clear the task so `status()` reports Idle.
                this.update(cx, |this, _cx| this.send_task.take())?;
                anyhow::Ok(())
            }
            .await
            .log_err();
        }));

        async move {
            match rx.await {
                Ok(Err(e)) => Err(e)?,
                // Success, or the sender was dropped without a result.
                _ => Ok(()),
            }
        }
        .boxed()
    }
 929
 930    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<Result<(), acp::Error>> {
 931        if self.send_task.take().is_some() {
 932            let request = self.request(acp::CancelSendMessageParams);
 933            cx.spawn(async move |this, cx| {
 934                request.await?;
 935                this.update(cx, |this, _cx| {
 936                    for entry in this.entries.iter_mut() {
 937                        if let AgentThreadEntry::ToolCall(call) = entry {
 938                            let cancel = matches!(
 939                                call.status,
 940                                ToolCallStatus::WaitingForConfirmation { .. }
 941                                    | ToolCallStatus::Allowed {
 942                                        status: acp::ToolCallStatus::Running
 943                                    }
 944                            );
 945
 946                            if cancel {
 947                                let curr_status =
 948                                    mem::replace(&mut call.status, ToolCallStatus::Canceled);
 949
 950                                if let ToolCallStatus::WaitingForConfirmation {
 951                                    respond_tx, ..
 952                                } = curr_status
 953                                {
 954                                    respond_tx
 955                                        .send(acp::ToolCallConfirmationOutcome::Cancel)
 956                                        .ok();
 957                                }
 958                            }
 959                        }
 960                    }
 961                })?;
 962                Ok(())
 963            })
 964        } else {
 965            Task::ready(Ok(()))
 966        }
 967    }
 968
 969    pub fn read_text_file(
 970        &self,
 971        request: acp::ReadTextFileParams,
 972        reuse_shared_snapshot: bool,
 973        cx: &mut Context<Self>,
 974    ) -> Task<Result<String>> {
 975        let project = self.project.clone();
 976        let action_log = self.action_log.clone();
 977        cx.spawn(async move |this, cx| {
 978            let load = project.update(cx, |project, cx| {
 979                let path = project
 980                    .project_path_for_absolute_path(&request.path, cx)
 981                    .context("invalid path")?;
 982                anyhow::Ok(project.open_buffer(path, cx))
 983            });
 984            let buffer = load??.await?;
 985
 986            let snapshot = if reuse_shared_snapshot {
 987                this.read_with(cx, |this, _| {
 988                    this.shared_buffers.get(&buffer.clone()).cloned()
 989                })
 990                .log_err()
 991                .flatten()
 992            } else {
 993                None
 994            };
 995
 996            let snapshot = if let Some(snapshot) = snapshot {
 997                snapshot
 998            } else {
 999                action_log.update(cx, |action_log, cx| {
1000                    action_log.buffer_read(buffer.clone(), cx);
1001                })?;
1002                project.update(cx, |project, cx| {
1003                    let position = buffer
1004                        .read(cx)
1005                        .snapshot()
1006                        .anchor_before(Point::new(request.line.unwrap_or_default(), 0));
1007                    project.set_agent_location(
1008                        Some(AgentLocation {
1009                            buffer: buffer.downgrade(),
1010                            position,
1011                        }),
1012                        cx,
1013                    );
1014                })?;
1015
1016                buffer.update(cx, |buffer, _| buffer.snapshot())?
1017            };
1018
1019            this.update(cx, |this, _| {
1020                let text = snapshot.text();
1021                this.shared_buffers.insert(buffer.clone(), snapshot);
1022                if request.line.is_none() && request.limit.is_none() {
1023                    return Ok(text);
1024                }
1025                let limit = request.limit.unwrap_or(u32::MAX) as usize;
1026                let Some(line) = request.line else {
1027                    return Ok(text.lines().take(limit).collect::<String>());
1028                };
1029
1030                let count = text.lines().count();
1031                if count < line as usize {
1032                    anyhow::bail!("There are only {} lines", count);
1033                }
1034                Ok(text
1035                    .lines()
1036                    .skip(line as usize + 1)
1037                    .take(limit)
1038                    .collect::<String>())
1039            })?
1040        })
1041    }
1042
    /// Replaces the contents of the file at `path` with `content` on behalf of
    /// the agent, then saves the buffer.
    ///
    /// The new text is diffed against the snapshot stored in `shared_buffers`
    /// for this buffer, or against the buffer's current snapshot if none is
    /// stored. The resulting edits are expressed as anchors into that snapshot,
    /// so they can be applied on top of edits made to the buffer after the
    /// snapshot was taken.
    ///
    /// The agent location is moved to the end of the last edit, or to the start
    /// of the buffer when there are no edits. The action log is told about the
    /// read and the edit.
    ///
    /// NOTE(review): `shared_buffers` is not updated after this write, so a
    /// later write with no intervening read diffs against the pre-write
    /// snapshot. Confirm that this is intended.
    ///
    /// # Errors
    /// Fails if:
    /// - the path is not inside the project,
    /// - the buffer cannot be opened,
    /// - the thread or app has been dropped, or
    /// - saving fails.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Prefer the snapshot previously handed to the agent, so the diff
            // is relative to what the agent actually saw.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Compute the diff off the main thread.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;
            cx.update(|cx| {
                project.update(cx, |project, cx| {
                    project.set_agent_location(
                        Some(AgentLocation {
                            buffer: buffer.downgrade(),
                            position: edits
                                .last()
                                .map(|(range, _)| range.end)
                                .unwrap_or(Anchor::MIN),
                        }),
                        cx,
                    );
                });

                // Record the read before applying the edit, and the edit after
                // it, bracketing the change for the action log.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });
                buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
            })?;
            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1110
1111    pub fn child_status(&mut self) -> Option<Task<Result<()>>> {
1112        self.child_status.take()
1113    }
1114
1115    pub fn to_markdown(&self, cx: &App) -> String {
1116        self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1117    }
1118}
1119
1120#[derive(Clone)]
1121pub struct AcpClientDelegate {
1122    thread: WeakEntity<AcpThread>,
1123    cx: AsyncApp,
1124    // sent_buffer_versions: HashMap<Entity<Buffer>, HashMap<u64, BufferSnapshot>>,
1125}
1126
1127impl AcpClientDelegate {
1128    pub fn new(thread: WeakEntity<AcpThread>, cx: AsyncApp) -> Self {
1129        Self { thread, cx }
1130    }
1131
1132    pub async fn request_existing_tool_call_confirmation(
1133        &self,
1134        tool_call_id: ToolCallId,
1135        confirmation: acp::ToolCallConfirmation,
1136    ) -> Result<ToolCallConfirmationOutcome> {
1137        let cx = &mut self.cx.clone();
1138        let ToolCallRequest { outcome, .. } = cx
1139            .update(|cx| {
1140                self.thread.update(cx, |thread, cx| {
1141                    thread.request_tool_call_confirmation(tool_call_id, confirmation, cx)
1142                })
1143            })?
1144            .context("Failed to update thread")??;
1145
1146        Ok(outcome.await?)
1147    }
1148
1149    pub async fn read_text_file_reusing_snapshot(
1150        &self,
1151        request: acp::ReadTextFileParams,
1152    ) -> Result<acp::ReadTextFileResponse, acp::Error> {
1153        let content = self
1154            .cx
1155            .update(|cx| {
1156                self.thread
1157                    .update(cx, |thread, cx| thread.read_text_file(request, true, cx))
1158            })?
1159            .context("Failed to update thread")?
1160            .await?;
1161        Ok(acp::ReadTextFileResponse { content })
1162    }
1163}
1164
1165impl acp::Client for AcpClientDelegate {
1166    async fn stream_assistant_message_chunk(
1167        &self,
1168        params: acp::StreamAssistantMessageChunkParams,
1169    ) -> Result<(), acp::Error> {
1170        let cx = &mut self.cx.clone();
1171
1172        cx.update(|cx| {
1173            self.thread
1174                .update(cx, |thread, cx| {
1175                    thread.push_assistant_chunk(params.chunk, cx)
1176                })
1177                .ok();
1178        })?;
1179
1180        Ok(())
1181    }
1182
1183    async fn request_tool_call_confirmation(
1184        &self,
1185        request: acp::RequestToolCallConfirmationParams,
1186    ) -> Result<acp::RequestToolCallConfirmationResponse, acp::Error> {
1187        let cx = &mut self.cx.clone();
1188        let ToolCallRequest { id, outcome } = cx
1189            .update(|cx| {
1190                self.thread
1191                    .update(cx, |thread, cx| thread.request_new_tool_call(request, cx))
1192            })?
1193            .context("Failed to update thread")?;
1194
1195        Ok(acp::RequestToolCallConfirmationResponse {
1196            id,
1197            outcome: outcome.await.map_err(acp::Error::into_internal_error)?,
1198        })
1199    }
1200
1201    async fn push_tool_call(
1202        &self,
1203        request: acp::PushToolCallParams,
1204    ) -> Result<acp::PushToolCallResponse, acp::Error> {
1205        let cx = &mut self.cx.clone();
1206        let id = cx
1207            .update(|cx| {
1208                self.thread
1209                    .update(cx, |thread, cx| thread.push_tool_call(request, cx))
1210            })?
1211            .context("Failed to update thread")?;
1212
1213        Ok(acp::PushToolCallResponse { id })
1214    }
1215
1216    async fn update_tool_call(&self, request: acp::UpdateToolCallParams) -> Result<(), acp::Error> {
1217        let cx = &mut self.cx.clone();
1218
1219        cx.update(|cx| {
1220            self.thread.update(cx, |thread, cx| {
1221                thread.update_tool_call(request.tool_call_id, request.status, request.content, cx)
1222            })
1223        })?
1224        .context("Failed to update thread")??;
1225
1226        Ok(())
1227    }
1228
1229    async fn read_text_file(
1230        &self,
1231        request: acp::ReadTextFileParams,
1232    ) -> Result<acp::ReadTextFileResponse, acp::Error> {
1233        let content = self
1234            .cx
1235            .update(|cx| {
1236                self.thread
1237                    .update(cx, |thread, cx| thread.read_text_file(request, false, cx))
1238            })?
1239            .context("Failed to update thread")?
1240            .await?;
1241        Ok(acp::ReadTextFileResponse { content })
1242    }
1243
1244    async fn write_text_file(&self, request: acp::WriteTextFileParams) -> Result<(), acp::Error> {
1245        self.cx
1246            .update(|cx| {
1247                self.thread.update(cx, |thread, cx| {
1248                    thread.write_text_file(request.path, request.content, cx)
1249                })
1250            })?
1251            .context("Failed to update thread")?
1252            .await?;
1253
1254        Ok(())
1255    }
1256}
1257
1258fn acp_icon_to_ui_icon(icon: acp::Icon) -> IconName {
1259    match icon {
1260        acp::Icon::FileSearch => IconName::ToolSearch,
1261        acp::Icon::Folder => IconName::ToolFolder,
1262        acp::Icon::Globe => IconName::ToolWeb,
1263        acp::Icon::Hammer => IconName::ToolHammer,
1264        acp::Icon::LightBulb => IconName::ToolBulb,
1265        acp::Icon::Pencil => IconName::ToolPencil,
1266        acp::Icon::Regex => IconName::ToolRegex,
1267        acp::Icon::Terminal => IconName::ToolTerminal,
1268    }
1269}
1270
/// A tool call registered with the thread that is awaiting the user's confirmation.
pub struct ToolCallRequest {
    /// Identifier of the tool call; it is reported back to the agent.
    pub id: acp::ToolCallId,
    /// Resolves with the user's decision. Presumably paired with the
    /// `respond_tx` held in `ToolCallStatus::WaitingForConfirmation`, which
    /// `cancel` answers with `Cancel`. Verify where the channel is created.
    pub outcome: oneshot::Receiver<acp::ToolCallConfirmationOutcome>,
}
1275
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use async_pipe::{PipeReader, PipeWriter};
    use futures::{channel::mpsc, future::LocalBoxFuture, select};
    use gpui::{AsyncApp, TestAppContext};
    use indoc::indoc;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use smol::{future::BoxedLocal, stream::StreamExt as _};
    use std::{cell::RefCell, rc::Rc, time::Duration};
    use util::path;

    /// Installs the global settings store, project settings and language state
    /// that every test in this module relies on.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    // Consecutive `Thought` chunks must be merged into a single `<thinking>` block.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let (thread, fake_server) = fake_acp_thread(project, cx);

        fake_server.update(cx, |fake_server, _| {
            fake_server.on_user_message(move |_, server, mut cx| async move {
                server
                    .update(&mut cx, |server, _| {
                        server.send_to_zed(acp::StreamAssistantMessageChunkParams {
                            chunk: acp::AssistantMessageChunk::Thought {
                                thought: "Thinking ".into(),
                            },
                        })
                    })?
                    .await
                    .unwrap();
                server
                    .update(&mut cx, |server, _| {
                        server.send_to_zed(acp::StreamAssistantMessageChunkParams {
                            chunk: acp::AssistantMessageChunk::Thought {
                                thought: "hard!".into(),
                            },
                        })
                    })?
                    .await
                    .unwrap();

                Ok(())
            })
        });

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }

    // An agent write is diffed against the snapshot the agent read. A user
    // edit made in between therefore survives, both in the buffer and on disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let (thread, fake_server) = fake_acp_thread(project.clone(), cx);
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));

        fake_server.update(cx, |fake_server, _| {
            fake_server.on_user_message(move |_, server, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = server
                        .update(&mut cx, |server, _| {
                            server.send_to_zed(acp::ReadTextFileParams {
                                path: path!("/tmp/foo").into(),
                                line: None,
                                limit: None,
                            })
                        })?
                        .await
                        .unwrap();
                    assert_eq!(content.content, "one\ntwo\nthree\n");
                    // Let the test make its "user" edit before the agent writes.
                    read_file_tx.take().unwrap().send(()).unwrap();
                    server
                        .update(&mut cx, |server, _| {
                            server.send_to_zed(acp::WriteTextFileParams {
                                path: path!("/tmp/foo").into(),
                                content: "one\ntwo\nthree\nfour\nfive\n".to_string(),
                            })
                        })?
                        .await
                        .unwrap();
                    Ok(())
                }
            })
        });

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }

    // A tool call canceled locally can still be moved to `Finished` by a
    // later update from the agent.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let (thread, fake_server) = fake_acp_thread(project, cx);

        let (end_turn_tx, end_turn_rx) = oneshot::channel::<()>();

        let tool_call_id = Rc::new(RefCell::new(None));
        let end_turn_rx = Rc::new(RefCell::new(Some(end_turn_rx)));
        fake_server.update(cx, |fake_server, _| {
            let tool_call_id = tool_call_id.clone();
            fake_server.on_user_message(move |_, server, mut cx| {
                let end_turn_rx = end_turn_rx.clone();
                let tool_call_id = tool_call_id.clone();
                async move {
                    let tool_call_result = server
                        .update(&mut cx, |server, _| {
                            server.send_to_zed(acp::PushToolCallParams {
                                label: "Fetch".to_string(),
                                icon: acp::Icon::Globe,
                                content: None,
                                locations: vec![],
                            })
                        })?
                        .await
                        .unwrap();
                    *tool_call_id.borrow_mut() = Some(tool_call_result.id);
                    // Keep the turn open until the test drops `end_turn_tx`.
                    end_turn_rx.take().unwrap().await.ok();

                    Ok(())
                }
            })
        });

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::Running,
                        ..
                    },
                    ..
                })
            ));
        });

        cx.run_until_parked();

        thread
            .update(cx, |thread, cx| thread.cancel(cx))
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        fake_server
            .update(cx, |fake_server, _| {
                fake_server.send_to_zed(acp::UpdateToolCallParams {
                    tool_call_id: tool_call_id.borrow().unwrap(),
                    status: acp::ToolCallStatus::Finished,
                    content: None,
                })
            })
            .await
            .unwrap();

        drop(end_turn_tx);
        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::Finished,
                        ..
                    },
                    ..
                })
            ));
        });
    }

    /// Waits (up to 10 seconds) until `thread` contains a tool-call entry and
    /// returns the index of the first one. Panics on timeout.
    async fn run_until_first_tool_call(
        thread: &Entity<AcpThread>,
        cx: &mut TestAppContext,
    ) -> usize {
        let (mut tx, mut rx) = mpsc::channel::<usize>(1);

        let subscription = cx.update(|cx| {
            cx.subscribe(thread, move |thread, _, cx| {
                let first_tool_call = thread
                    .read(cx)
                    .entries
                    .iter()
                    .position(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)));
                if let Some(ix) = first_tool_call {
                    // Several events may fire before the receiver is polled,
                    // which can fill this bounded channel. Only the first index
                    // matters, so a full channel is not an error.
                    tx.try_send(ix).ok();
                }
            })
        });

        select! {
            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
                panic!("Timeout waiting for tool call")
            }
            ix = rx.next().fuse() => {
                drop(subscription);
                ix.unwrap()
            }
        }
    }

    /// Builds an `AcpThread` connected through in-memory pipes to a
    /// `FakeAcpServer` that plays the agent side of the protocol.
    pub fn fake_acp_thread(
        project: Entity<Project>,
        cx: &mut TestAppContext,
    ) -> (Entity<AcpThread>, Entity<FakeAcpServer>) {
        let (stdin_tx, stdin_rx) = async_pipe::pipe();
        let (stdout_tx, stdout_rx) = async_pipe::pipe();

        let thread = cx.new(|cx| {
            let foreground_executor = cx.foreground_executor().clone();
            let (connection, io_fut) = acp::AgentConnection::connect_to_agent(
                AcpClientDelegate::new(cx.entity().downgrade(), cx.to_async()),
                stdin_tx,
                stdout_rx,
                move |fut| {
                    foreground_executor.spawn(fut).detach();
                },
            );

            let io_task = cx.background_spawn({
                async move {
                    io_fut.await.log_err();
                    Ok(())
                }
            });
            AcpThread::new(connection, "Test".into(), Some(io_task), project, cx)
        });
        let agent = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx)));
        (thread, agent)
    }

    /// Agent side of the fake connection. Each test installs a handler that
    /// scripts the agent's behavior for an incoming user message.
    pub struct FakeAcpServer {
        connection: acp::ClientConnection,

        _io_task: Task<()>,
        on_user_message: Option<
            Rc<
                dyn Fn(
                    acp::SendUserMessageParams,
                    Entity<FakeAcpServer>,
                    AsyncApp,
                ) -> LocalBoxFuture<'static, Result<(), acp::Error>>,
            >,
        >,
    }

    /// `acp::Agent` implementation that delegates user messages to the
    /// handler registered on its `FakeAcpServer`.
    #[derive(Clone)]
    struct FakeAgent {
        server: Entity<FakeAcpServer>,
        cx: AsyncApp,
    }

    impl acp::Agent for FakeAgent {
        async fn initialize(
            &self,
            params: acp::InitializeParams,
        ) -> Result<acp::InitializeResponse, acp::Error> {
            Ok(acp::InitializeResponse {
                protocol_version: params.protocol_version,
                is_authenticated: true,
            })
        }

        async fn authenticate(&self) -> Result<(), acp::Error> {
            Ok(())
        }

        async fn cancel_send_message(&self) -> Result<(), acp::Error> {
            Ok(())
        }

        async fn send_user_message(
            &self,
            request: acp::SendUserMessageParams,
        ) -> Result<(), acp::Error> {
            let mut cx = self.cx.clone();
            let handler = self
                .server
                .update(&mut cx, |server, _| server.on_user_message.clone())
                .ok()
                .flatten();
            if let Some(handler) = handler {
                handler(request, self.server.clone(), self.cx.clone()).await
            } else {
                Err(anyhow::anyhow!("No handler for on_user_message").into())
            }
        }
    }

    impl FakeAcpServer {
        /// Connects a `FakeAgent` to the client over the given pipes.
        fn new(stdin: PipeReader, stdout: PipeWriter, cx: &Context<Self>) -> Self {
            let agent = FakeAgent {
                server: cx.entity(),
                cx: cx.to_async(),
            };
            let foreground_executor = cx.foreground_executor().clone();

            let (connection, io_fut) = acp::ClientConnection::connect_to_client(
                agent.clone(),
                stdout,
                stdin,
                move |fut| {
                    foreground_executor.spawn(fut).detach();
                },
            );
            FakeAcpServer {
                connection,
                on_user_message: None,
                _io_task: cx.background_spawn(async move {
                    io_fut.await.log_err();
                }),
            }
        }

        /// Installs (or replaces) the handler invoked for each user message.
        fn on_user_message<F>(
            &mut self,
            handler: impl Fn(acp::SendUserMessageParams, Entity<FakeAcpServer>, AsyncApp) -> F
            + 'static,
        ) where
            F: Future<Output = Result<(), acp::Error>> + 'static,
        {
            self.on_user_message
                .replace(Rc::new(move |request, server, cx| {
                    handler(request, server, cx).boxed_local()
                }));
        }

        /// Sends a client-side request (e.g. read/write file) to the thread
        /// and returns its response.
        fn send_to_zed<T: acp::ClientRequest + 'static>(
            &self,
            message: T,
        ) -> BoxedLocal<Result<T::Response>> {
            self.connection
                .request(message)
                .map(|f| f.map_err(|err| anyhow!(err)))
                .boxed_local()
        }
    }
}