acp_thread.rs

   1mod connection;
   2pub use connection::*;
   3
   4use agent_client_protocol as acp;
   5use anyhow::{Context as _, Result};
   6use assistant_tool::ActionLog;
   7use buffer_diff::BufferDiff;
   8use editor::{Bias, MultiBuffer, PathKey};
   9use futures::{FutureExt, channel::oneshot, future::BoxFuture};
  10use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
  11use itertools::Itertools;
  12use language::{
  13    Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point,
  14    text_diff,
  15};
  16use markdown::Markdown;
  17use project::{AgentLocation, Project};
  18use std::collections::HashMap;
  19use std::error::Error;
  20use std::fmt::Formatter;
  21use std::process::ExitStatus;
  22use std::rc::Rc;
  23use std::{
  24    fmt::Display,
  25    mem,
  26    path::{Path, PathBuf},
  27    sync::Arc,
  28};
  29use ui::App;
  30use util::ResultExt;
  31
  32#[derive(Debug)]
  33pub struct UserMessage {
  34    pub content: ContentBlock,
  35}
  36
  37impl UserMessage {
  38    pub fn from_acp(
  39        message: impl IntoIterator<Item = acp::ContentBlock>,
  40        language_registry: Arc<LanguageRegistry>,
  41        cx: &mut App,
  42    ) -> Self {
  43        let mut content = ContentBlock::Empty;
  44        for chunk in message {
  45            content.append(chunk, &language_registry, cx)
  46        }
  47        Self { content: content }
  48    }
  49
  50    fn to_markdown(&self, cx: &App) -> String {
  51        format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
  52    }
  53}
  54
  55#[derive(Debug)]
  56pub struct MentionPath<'a>(&'a Path);
  57
  58impl<'a> MentionPath<'a> {
  59    const PREFIX: &'static str = "@file:";
  60
  61    pub fn new(path: &'a Path) -> Self {
  62        MentionPath(path)
  63    }
  64
  65    pub fn try_parse(url: &'a str) -> Option<Self> {
  66        let path = url.strip_prefix(Self::PREFIX)?;
  67        Some(MentionPath(Path::new(path)))
  68    }
  69
  70    pub fn path(&self) -> &Path {
  71        self.0
  72    }
  73}
  74
  75impl Display for MentionPath<'_> {
  76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  77        write!(
  78            f,
  79            "[@{}]({}{})",
  80            self.0.file_name().unwrap_or_default().display(),
  81            Self::PREFIX,
  82            self.0.display()
  83        )
  84    }
  85}
  86
  87#[derive(Debug, PartialEq)]
  88pub struct AssistantMessage {
  89    pub chunks: Vec<AssistantMessageChunk>,
  90}
  91
  92impl AssistantMessage {
  93    pub fn to_markdown(&self, cx: &App) -> String {
  94        format!(
  95            "## Assistant\n\n{}\n\n",
  96            self.chunks
  97                .iter()
  98                .map(|chunk| chunk.to_markdown(cx))
  99                .join("\n\n")
 100        )
 101    }
 102}
 103
 104#[derive(Debug, PartialEq)]
 105pub enum AssistantMessageChunk {
 106    Message { block: ContentBlock },
 107    Thought { block: ContentBlock },
 108}
 109
 110impl AssistantMessageChunk {
 111    pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
 112        Self::Message {
 113            block: ContentBlock::new(chunk.into(), language_registry, cx),
 114        }
 115    }
 116
 117    fn to_markdown(&self, cx: &App) -> String {
 118        match self {
 119            Self::Message { block } => block.to_markdown(cx).to_string(),
 120            Self::Thought { block } => {
 121                format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
 122            }
 123        }
 124    }
 125}
 126
 127#[derive(Debug)]
 128pub enum AgentThreadEntry {
 129    UserMessage(UserMessage),
 130    AssistantMessage(AssistantMessage),
 131    ToolCall(ToolCall),
 132}
 133
 134impl AgentThreadEntry {
 135    fn to_markdown(&self, cx: &App) -> String {
 136        match self {
 137            Self::UserMessage(message) => message.to_markdown(cx),
 138            Self::AssistantMessage(message) => message.to_markdown(cx),
 139            Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
 140        }
 141    }
 142
 143    pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
 144        if let AgentThreadEntry::ToolCall(call) = self {
 145            itertools::Either::Left(call.diffs())
 146        } else {
 147            itertools::Either::Right(std::iter::empty())
 148        }
 149    }
 150
 151    pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
 152        if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
 153            Some(locations)
 154        } else {
 155            None
 156        }
 157    }
 158}
 159
 160#[derive(Debug)]
 161pub struct ToolCall {
 162    pub id: acp::ToolCallId,
 163    pub label: Entity<Markdown>,
 164    pub kind: acp::ToolKind,
 165    pub content: Vec<ToolCallContent>,
 166    pub status: ToolCallStatus,
 167    pub locations: Vec<acp::ToolCallLocation>,
 168    pub raw_input: Option<serde_json::Value>,
 169}
 170
 171impl ToolCall {
 172    fn from_acp(
 173        tool_call: acp::ToolCall,
 174        status: ToolCallStatus,
 175        language_registry: Arc<LanguageRegistry>,
 176        cx: &mut App,
 177    ) -> Self {
 178        Self {
 179            id: tool_call.id,
 180            label: cx.new(|cx| {
 181                Markdown::new(
 182                    tool_call.title.into(),
 183                    Some(language_registry.clone()),
 184                    None,
 185                    cx,
 186                )
 187            }),
 188            kind: tool_call.kind,
 189            content: tool_call
 190                .content
 191                .into_iter()
 192                .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
 193                .collect(),
 194            locations: tool_call.locations,
 195            status,
 196            raw_input: tool_call.raw_input,
 197        }
 198    }
 199
 200    fn update(
 201        &mut self,
 202        fields: acp::ToolCallUpdateFields,
 203        language_registry: Arc<LanguageRegistry>,
 204        cx: &mut App,
 205    ) {
 206        let acp::ToolCallUpdateFields {
 207            kind,
 208            status,
 209            title,
 210            content,
 211            locations,
 212            raw_input,
 213        } = fields;
 214
 215        if let Some(kind) = kind {
 216            self.kind = kind;
 217        }
 218
 219        if let Some(status) = status {
 220            self.status = ToolCallStatus::Allowed { status };
 221        }
 222
 223        if let Some(title) = title {
 224            self.label = cx.new(|cx| Markdown::new_text(title.into(), cx));
 225        }
 226
 227        if let Some(content) = content {
 228            self.content = content
 229                .into_iter()
 230                .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
 231                .collect();
 232        }
 233
 234        if let Some(locations) = locations {
 235            self.locations = locations;
 236        }
 237
 238        if let Some(raw_input) = raw_input {
 239            self.raw_input = Some(raw_input);
 240        }
 241    }
 242
 243    pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
 244        self.content.iter().filter_map(|content| match content {
 245            ToolCallContent::ContentBlock { .. } => None,
 246            ToolCallContent::Diff { diff } => Some(diff),
 247        })
 248    }
 249
 250    fn to_markdown(&self, cx: &App) -> String {
 251        let mut markdown = format!(
 252            "**Tool Call: {}**\nStatus: {}\n\n",
 253            self.label.read(cx).source(),
 254            self.status
 255        );
 256        for content in &self.content {
 257            markdown.push_str(content.to_markdown(cx).as_str());
 258            markdown.push_str("\n\n");
 259        }
 260        markdown
 261    }
 262}
 263
 264#[derive(Debug)]
 265pub enum ToolCallStatus {
 266    WaitingForConfirmation {
 267        options: Vec<acp::PermissionOption>,
 268        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
 269    },
 270    Allowed {
 271        status: acp::ToolCallStatus,
 272    },
 273    Rejected,
 274    Canceled,
 275}
 276
 277impl Display for ToolCallStatus {
 278    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 279        write!(
 280            f,
 281            "{}",
 282            match self {
 283                ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
 284                ToolCallStatus::Allowed { status } => match status {
 285                    acp::ToolCallStatus::Pending => "Pending",
 286                    acp::ToolCallStatus::InProgress => "In Progress",
 287                    acp::ToolCallStatus::Completed => "Completed",
 288                    acp::ToolCallStatus::Failed => "Failed",
 289                },
 290                ToolCallStatus::Rejected => "Rejected",
 291                ToolCallStatus::Canceled => "Canceled",
 292            }
 293        )
 294    }
 295}
 296
 297#[derive(Debug, PartialEq, Clone)]
 298pub enum ContentBlock {
 299    Empty,
 300    Markdown { markdown: Entity<Markdown> },
 301}
 302
 303impl ContentBlock {
 304    pub fn new(
 305        block: acp::ContentBlock,
 306        language_registry: &Arc<LanguageRegistry>,
 307        cx: &mut App,
 308    ) -> Self {
 309        let mut this = Self::Empty;
 310        this.append(block, language_registry, cx);
 311        this
 312    }
 313
 314    pub fn new_combined(
 315        blocks: impl IntoIterator<Item = acp::ContentBlock>,
 316        language_registry: Arc<LanguageRegistry>,
 317        cx: &mut App,
 318    ) -> Self {
 319        let mut this = Self::Empty;
 320        for block in blocks {
 321            this.append(block, &language_registry, cx);
 322        }
 323        this
 324    }
 325
 326    pub fn append(
 327        &mut self,
 328        block: acp::ContentBlock,
 329        language_registry: &Arc<LanguageRegistry>,
 330        cx: &mut App,
 331    ) {
 332        let new_content = match block {
 333            acp::ContentBlock::Text(text_content) => text_content.text.clone(),
 334            acp::ContentBlock::ResourceLink(resource_link) => {
 335                if let Some(path) = resource_link.uri.strip_prefix("file://") {
 336                    format!("{}", MentionPath(path.as_ref()))
 337                } else {
 338                    resource_link.uri.clone()
 339                }
 340            }
 341            acp::ContentBlock::Image(_)
 342            | acp::ContentBlock::Audio(_)
 343            | acp::ContentBlock::Resource(_) => String::new(),
 344        };
 345
 346        match self {
 347            ContentBlock::Empty => {
 348                *self = ContentBlock::Markdown {
 349                    markdown: cx.new(|cx| {
 350                        Markdown::new(
 351                            new_content.into(),
 352                            Some(language_registry.clone()),
 353                            None,
 354                            cx,
 355                        )
 356                    }),
 357                };
 358            }
 359            ContentBlock::Markdown { markdown } => {
 360                markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
 361            }
 362        }
 363    }
 364
 365    fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
 366        match self {
 367            ContentBlock::Empty => "",
 368            ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
 369        }
 370    }
 371
 372    pub fn markdown(&self) -> Option<&Entity<Markdown>> {
 373        match self {
 374            ContentBlock::Empty => None,
 375            ContentBlock::Markdown { markdown } => Some(markdown),
 376        }
 377    }
 378}
 379
 380#[derive(Debug)]
 381pub enum ToolCallContent {
 382    ContentBlock { content: ContentBlock },
 383    Diff { diff: Diff },
 384}
 385
 386impl ToolCallContent {
 387    pub fn from_acp(
 388        content: acp::ToolCallContent,
 389        language_registry: Arc<LanguageRegistry>,
 390        cx: &mut App,
 391    ) -> Self {
 392        match content {
 393            acp::ToolCallContent::Content { content } => Self::ContentBlock {
 394                content: ContentBlock::new(content, &language_registry, cx),
 395            },
 396            acp::ToolCallContent::Diff { diff } => Self::Diff {
 397                diff: Diff::from_acp(diff, language_registry, cx),
 398            },
 399        }
 400    }
 401
 402    pub fn to_markdown(&self, cx: &App) -> String {
 403        match self {
 404            Self::ContentBlock { content } => content.to_markdown(cx).to_string(),
 405            Self::Diff { diff } => diff.to_markdown(cx),
 406        }
 407    }
 408}
 409
 410#[derive(Debug)]
 411pub struct Diff {
 412    pub multibuffer: Entity<MultiBuffer>,
 413    pub path: PathBuf,
 414    pub new_buffer: Entity<Buffer>,
 415    pub old_buffer: Entity<Buffer>,
 416    _task: Task<Result<()>>,
 417}
 418
 419impl Diff {
 420    pub fn from_acp(
 421        diff: acp::Diff,
 422        language_registry: Arc<LanguageRegistry>,
 423        cx: &mut App,
 424    ) -> Self {
 425        let acp::Diff {
 426            path,
 427            old_text,
 428            new_text,
 429        } = diff;
 430
 431        let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
 432
 433        let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
 434        let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
 435        let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
 436        let old_buffer_snapshot = old_buffer.read(cx).snapshot();
 437        let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
 438        let diff_task = buffer_diff.update(cx, |diff, cx| {
 439            diff.set_base_text(
 440                old_buffer_snapshot,
 441                Some(language_registry.clone()),
 442                new_buffer_snapshot,
 443                cx,
 444            )
 445        });
 446
 447        let task = cx.spawn({
 448            let multibuffer = multibuffer.clone();
 449            let path = path.clone();
 450            let new_buffer = new_buffer.clone();
 451            async move |cx| {
 452                diff_task.await?;
 453
 454                multibuffer
 455                    .update(cx, |multibuffer, cx| {
 456                        let hunk_ranges = {
 457                            let buffer = new_buffer.read(cx);
 458                            let diff = buffer_diff.read(cx);
 459                            diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
 460                                .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
 461                                .collect::<Vec<_>>()
 462                        };
 463
 464                        multibuffer.set_excerpts_for_path(
 465                            PathKey::for_buffer(&new_buffer, cx),
 466                            new_buffer.clone(),
 467                            hunk_ranges,
 468                            editor::DEFAULT_MULTIBUFFER_CONTEXT,
 469                            cx,
 470                        );
 471                        multibuffer.add_diff(buffer_diff.clone(), cx);
 472                    })
 473                    .log_err();
 474
 475                if let Some(language) = language_registry
 476                    .language_for_file_path(&path)
 477                    .await
 478                    .log_err()
 479                {
 480                    new_buffer.update(cx, |buffer, cx| buffer.set_language(Some(language), cx))?;
 481                }
 482
 483                anyhow::Ok(())
 484            }
 485        });
 486
 487        Self {
 488            multibuffer,
 489            path,
 490            new_buffer,
 491            old_buffer,
 492            _task: task,
 493        }
 494    }
 495
 496    fn to_markdown(&self, cx: &App) -> String {
 497        let buffer_text = self
 498            .multibuffer
 499            .read(cx)
 500            .all_buffers()
 501            .iter()
 502            .map(|buffer| buffer.read(cx).text())
 503            .join("\n");
 504        format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text)
 505    }
 506}
 507
 508#[derive(Debug, Default)]
 509pub struct Plan {
 510    pub entries: Vec<PlanEntry>,
 511}
 512
 513#[derive(Debug)]
 514pub struct PlanStats<'a> {
 515    pub in_progress_entry: Option<&'a PlanEntry>,
 516    pub pending: u32,
 517    pub completed: u32,
 518}
 519
 520impl Plan {
 521    pub fn is_empty(&self) -> bool {
 522        self.entries.is_empty()
 523    }
 524
 525    pub fn stats(&self) -> PlanStats<'_> {
 526        let mut stats = PlanStats {
 527            in_progress_entry: None,
 528            pending: 0,
 529            completed: 0,
 530        };
 531
 532        for entry in &self.entries {
 533            match &entry.status {
 534                acp::PlanEntryStatus::Pending => {
 535                    stats.pending += 1;
 536                }
 537                acp::PlanEntryStatus::InProgress => {
 538                    stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
 539                }
 540                acp::PlanEntryStatus::Completed => {
 541                    stats.completed += 1;
 542                }
 543            }
 544        }
 545
 546        stats
 547    }
 548}
 549
 550#[derive(Debug)]
 551pub struct PlanEntry {
 552    pub content: Entity<Markdown>,
 553    pub priority: acp::PlanEntryPriority,
 554    pub status: acp::PlanEntryStatus,
 555}
 556
 557impl PlanEntry {
 558    pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
 559        Self {
 560            content: cx.new(|cx| Markdown::new_text(entry.content.into(), cx)),
 561            priority: entry.priority,
 562            status: entry.status,
 563        }
 564    }
 565}
 566
 567pub struct AcpThread {
 568    title: SharedString,
 569    entries: Vec<AgentThreadEntry>,
 570    plan: Plan,
 571    project: Entity<Project>,
 572    action_log: Entity<ActionLog>,
 573    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
 574    send_task: Option<Task<()>>,
 575    connection: Rc<dyn AgentConnection>,
 576    session_id: acp::SessionId,
 577}
 578
 579pub enum AcpThreadEvent {
 580    NewEntry,
 581    EntryUpdated(usize),
 582    ToolAuthorizationRequired,
 583    Stopped,
 584    Error,
 585    ServerExited(ExitStatus),
 586}
 587
 588impl EventEmitter<AcpThreadEvent> for AcpThread {}
 589
 590#[derive(PartialEq, Eq)]
 591pub enum ThreadStatus {
 592    Idle,
 593    WaitingForToolConfirmation,
 594    Generating,
 595}
 596
 597#[derive(Debug, Clone)]
 598pub enum LoadError {
 599    Unsupported {
 600        error_message: SharedString,
 601        upgrade_message: SharedString,
 602        upgrade_command: String,
 603    },
 604    Exited(i32),
 605    Other(SharedString),
 606}
 607
 608impl Display for LoadError {
 609    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 610        match self {
 611            LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
 612            LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
 613            LoadError::Other(msg) => write!(f, "{}", msg),
 614        }
 615    }
 616}
 617
 618impl Error for LoadError {}
 619
 620impl AcpThread {
 621    pub fn new(
 622        title: impl Into<SharedString>,
 623        connection: Rc<dyn AgentConnection>,
 624        project: Entity<Project>,
 625        session_id: acp::SessionId,
 626        cx: &mut Context<Self>,
 627    ) -> Self {
 628        let action_log = cx.new(|_| ActionLog::new(project.clone()));
 629
 630        Self {
 631            action_log,
 632            shared_buffers: Default::default(),
 633            entries: Default::default(),
 634            plan: Default::default(),
 635            title: title.into(),
 636            project,
 637            send_task: None,
 638            connection,
 639            session_id,
 640        }
 641    }
 642
 643    pub fn action_log(&self) -> &Entity<ActionLog> {
 644        &self.action_log
 645    }
 646
 647    pub fn project(&self) -> &Entity<Project> {
 648        &self.project
 649    }
 650
 651    pub fn title(&self) -> SharedString {
 652        self.title.clone()
 653    }
 654
 655    pub fn entries(&self) -> &[AgentThreadEntry] {
 656        &self.entries
 657    }
 658
 659    pub fn status(&self) -> ThreadStatus {
 660        if self.send_task.is_some() {
 661            if self.waiting_for_tool_confirmation() {
 662                ThreadStatus::WaitingForToolConfirmation
 663            } else {
 664                ThreadStatus::Generating
 665            }
 666        } else {
 667            ThreadStatus::Idle
 668        }
 669    }
 670
 671    pub fn has_pending_edit_tool_calls(&self) -> bool {
 672        for entry in self.entries.iter().rev() {
 673            match entry {
 674                AgentThreadEntry::UserMessage(_) => return false,
 675                AgentThreadEntry::ToolCall(
 676                    call @ ToolCall {
 677                        status:
 678                            ToolCallStatus::Allowed {
 679                                status:
 680                                    acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending,
 681                            },
 682                        ..
 683                    },
 684                ) if call.diffs().next().is_some() => {
 685                    return true;
 686                }
 687                AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
 688            }
 689        }
 690
 691        false
 692    }
 693
 694    pub fn used_tools_since_last_user_message(&self) -> bool {
 695        for entry in self.entries.iter().rev() {
 696            match entry {
 697                AgentThreadEntry::UserMessage(..) => return false,
 698                AgentThreadEntry::AssistantMessage(..) => continue,
 699                AgentThreadEntry::ToolCall(..) => return true,
 700            }
 701        }
 702
 703        false
 704    }
 705
 706    pub fn handle_session_update(
 707        &mut self,
 708        update: acp::SessionUpdate,
 709        cx: &mut Context<Self>,
 710    ) -> Result<()> {
 711        match update {
 712            acp::SessionUpdate::UserMessageChunk { content } => {
 713                self.push_user_content_block(content, cx);
 714            }
 715            acp::SessionUpdate::AgentMessageChunk { content } => {
 716                self.push_assistant_content_block(content, false, cx);
 717            }
 718            acp::SessionUpdate::AgentThoughtChunk { content } => {
 719                self.push_assistant_content_block(content, true, cx);
 720            }
 721            acp::SessionUpdate::ToolCall(tool_call) => {
 722                self.upsert_tool_call(tool_call, cx);
 723            }
 724            acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
 725                self.update_tool_call(tool_call_update, cx)?;
 726            }
 727            acp::SessionUpdate::Plan(plan) => {
 728                self.update_plan(plan, cx);
 729            }
 730        }
 731        Ok(())
 732    }
 733
 734    pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
 735        let language_registry = self.project.read(cx).languages().clone();
 736        let entries_len = self.entries.len();
 737
 738        if let Some(last_entry) = self.entries.last_mut()
 739            && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
 740        {
 741            content.append(chunk, &language_registry, cx);
 742            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
 743        } else {
 744            let content = ContentBlock::new(chunk, &language_registry, cx);
 745            self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
 746        }
 747    }
 748
 749    pub fn push_assistant_content_block(
 750        &mut self,
 751        chunk: acp::ContentBlock,
 752        is_thought: bool,
 753        cx: &mut Context<Self>,
 754    ) {
 755        let language_registry = self.project.read(cx).languages().clone();
 756        let entries_len = self.entries.len();
 757        if let Some(last_entry) = self.entries.last_mut()
 758            && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
 759        {
 760            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
 761            match (chunks.last_mut(), is_thought) {
 762                (Some(AssistantMessageChunk::Message { block }), false)
 763                | (Some(AssistantMessageChunk::Thought { block }), true) => {
 764                    block.append(chunk, &language_registry, cx)
 765                }
 766                _ => {
 767                    let block = ContentBlock::new(chunk, &language_registry, cx);
 768                    if is_thought {
 769                        chunks.push(AssistantMessageChunk::Thought { block })
 770                    } else {
 771                        chunks.push(AssistantMessageChunk::Message { block })
 772                    }
 773                }
 774            }
 775        } else {
 776            let block = ContentBlock::new(chunk, &language_registry, cx);
 777            let chunk = if is_thought {
 778                AssistantMessageChunk::Thought { block }
 779            } else {
 780                AssistantMessageChunk::Message { block }
 781            };
 782
 783            self.push_entry(
 784                AgentThreadEntry::AssistantMessage(AssistantMessage {
 785                    chunks: vec![chunk],
 786                }),
 787                cx,
 788            );
 789        }
 790    }
 791
 792    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
 793        self.entries.push(entry);
 794        cx.emit(AcpThreadEvent::NewEntry);
 795    }
 796
 797    pub fn update_tool_call(
 798        &mut self,
 799        update: acp::ToolCallUpdate,
 800        cx: &mut Context<Self>,
 801    ) -> Result<()> {
 802        let languages = self.project.read(cx).languages().clone();
 803
 804        let (ix, current_call) = self
 805            .tool_call_mut(&update.id)
 806            .context("Tool call not found")?;
 807        current_call.update(update.fields, languages, cx);
 808
 809        cx.emit(AcpThreadEvent::EntryUpdated(ix));
 810
 811        Ok(())
 812    }
 813
 814    /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
 815    pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
 816        let status = ToolCallStatus::Allowed {
 817            status: tool_call.status,
 818        };
 819        self.upsert_tool_call_inner(tool_call, status, cx)
 820    }
 821
 822    pub fn upsert_tool_call_inner(
 823        &mut self,
 824        tool_call: acp::ToolCall,
 825        status: ToolCallStatus,
 826        cx: &mut Context<Self>,
 827    ) {
 828        let language_registry = self.project.read(cx).languages().clone();
 829        let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
 830
 831        let location = call.locations.last().cloned();
 832
 833        if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
 834            *current_call = call;
 835
 836            cx.emit(AcpThreadEvent::EntryUpdated(ix));
 837        } else {
 838            self.push_entry(AgentThreadEntry::ToolCall(call), cx);
 839        }
 840
 841        if let Some(location) = location {
 842            self.set_project_location(location, cx)
 843        }
 844    }
 845
 846    fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
 847        // The tool call we are looking for is typically the last one, or very close to the end.
 848        // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
 849        self.entries
 850            .iter_mut()
 851            .enumerate()
 852            .rev()
 853            .find_map(|(index, tool_call)| {
 854                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
 855                    && &tool_call.id == id
 856                {
 857                    Some((index, tool_call))
 858                } else {
 859                    None
 860                }
 861            })
 862    }
 863
 864    pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context<Self>) {
 865        self.project.update(cx, |project, cx| {
 866            let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
 867                return;
 868            };
 869            let buffer = project.open_buffer(path, cx);
 870            cx.spawn(async move |project, cx| {
 871                let buffer = buffer.await?;
 872
 873                project.update(cx, |project, cx| {
 874                    let position = if let Some(line) = location.line {
 875                        let snapshot = buffer.read(cx).snapshot();
 876                        let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
 877                        snapshot.anchor_before(point)
 878                    } else {
 879                        Anchor::MIN
 880                    };
 881
 882                    project.set_agent_location(
 883                        Some(AgentLocation {
 884                            buffer: buffer.downgrade(),
 885                            position,
 886                        }),
 887                        cx,
 888                    );
 889                })
 890            })
 891            .detach_and_log_err(cx);
 892        });
 893    }
 894
 895    pub fn request_tool_call_permission(
 896        &mut self,
 897        tool_call: acp::ToolCall,
 898        options: Vec<acp::PermissionOption>,
 899        cx: &mut Context<Self>,
 900    ) -> oneshot::Receiver<acp::PermissionOptionId> {
 901        let (tx, rx) = oneshot::channel();
 902
 903        let status = ToolCallStatus::WaitingForConfirmation {
 904            options,
 905            respond_tx: tx,
 906        };
 907
 908        self.upsert_tool_call_inner(tool_call, status, cx);
 909        cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
 910        rx
 911    }
 912
 913    pub fn authorize_tool_call(
 914        &mut self,
 915        id: acp::ToolCallId,
 916        option_id: acp::PermissionOptionId,
 917        option_kind: acp::PermissionOptionKind,
 918        cx: &mut Context<Self>,
 919    ) {
 920        let Some((ix, call)) = self.tool_call_mut(&id) else {
 921            return;
 922        };
 923
 924        let new_status = match option_kind {
 925            acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
 926                ToolCallStatus::Rejected
 927            }
 928            acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
 929                ToolCallStatus::Allowed {
 930                    status: acp::ToolCallStatus::InProgress,
 931                }
 932            }
 933        };
 934
 935        let curr_status = mem::replace(&mut call.status, new_status);
 936
 937        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
 938            respond_tx.send(option_id).log_err();
 939        } else if cfg!(debug_assertions) {
 940            panic!("tried to authorize an already authorized tool call");
 941        }
 942
 943        cx.emit(AcpThreadEvent::EntryUpdated(ix));
 944    }
 945
 946    /// Returns true if the last turn is awaiting tool authorization
 947    pub fn waiting_for_tool_confirmation(&self) -> bool {
 948        for entry in self.entries.iter().rev() {
 949            match &entry {
 950                AgentThreadEntry::ToolCall(call) => match call.status {
 951                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
 952                    ToolCallStatus::Allowed { .. }
 953                    | ToolCallStatus::Rejected
 954                    | ToolCallStatus::Canceled => continue,
 955                },
 956                AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
 957                    // Reached the beginning of the turn
 958                    return false;
 959                }
 960            }
 961        }
 962        false
 963    }
 964
    /// Returns the thread's current plan.
    pub fn plan(&self) -> &Plan {
        &self.plan
    }
 968
 969    pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
 970        self.plan = Plan {
 971            entries: request
 972                .entries
 973                .into_iter()
 974                .map(|entry| PlanEntry::from_acp(entry, cx))
 975                .collect(),
 976        };
 977
 978        cx.notify();
 979    }
 980
 981    fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
 982        self.plan
 983            .entries
 984            .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
 985        cx.notify();
 986    }
 987
 988    #[cfg(any(test, feature = "test-support"))]
 989    pub fn send_raw(
 990        &mut self,
 991        message: &str,
 992        cx: &mut Context<Self>,
 993    ) -> BoxFuture<'static, Result<()>> {
 994        self.send(
 995            vec![acp::ContentBlock::Text(acp::TextContent {
 996                text: message.to_string(),
 997                annotations: None,
 998            })],
 999            cx,
1000        )
1001    }
1002
    /// Sends `message` to the agent as a new user turn.
    ///
    /// The message is appended to the thread as a user entry, completed plan
    /// entries are cleared, and any send task still stored on the thread is
    /// cancelled and awaited before the prompt is forwarded to the connection.
    ///
    /// The returned future resolves once the prompt result is available: a
    /// prompt error is returned after emitting `AcpThreadEvent::Error`; in every
    /// other case (a successful prompt, or the result channel being dropped
    /// without a value) `AcpThreadEvent::Stopped` is emitted and `Ok(())` is
    /// returned.
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        // Record the outgoing message as a user entry before prompting.
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            cx,
        );
        self.push_entry(
            AgentThreadEntry::UserMessage(UserMessage { content: block }),
            cx,
        );
        self.clear_completed_plan_entries(cx);

        // Carries the prompt result from the send task to the returned future.
        let (tx, rx) = oneshot::channel();
        // Takes the previously stored send task (if any) and yields a task that
        // waits for it to finish.
        let cancel_task = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            async {
                // Let the previous turn wind down before prompting again.
                cancel_task.await;

                let result = this
                    .update(cx, |this, cx| {
                        this.connection.prompt(
                            acp::PromptRequest {
                                prompt: message,
                                session_id: this.session_id.clone(),
                            },
                            cx,
                        )
                    })?
                    .await;
                tx.send(result).log_err();
                // NOTE(review): this clears whatever task is stored at this point.
                // If a newer `send` has already replaced `send_task` while this
                // one was being cancelled, it looks like the newer task would be
                // taken (and dropped) here — confirm.
                this.update(cx, |this, _cx| this.send_task.take())?;
                anyhow::Ok(())
            }
            .await
            .log_err();
        }));

        cx.spawn(async move |this, cx| match rx.await {
            Ok(Err(e)) => {
                this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error))
                    .log_err();
                Err(e)?
            }
            // Either the prompt succeeded or the sender was dropped without a value.
            _ => {
                this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
                    .log_err();
                Ok(())
            }
        })
        .boxed()
    }
1059
1060    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1061        let Some(send_task) = self.send_task.take() else {
1062            return Task::ready(());
1063        };
1064
1065        for entry in self.entries.iter_mut() {
1066            if let AgentThreadEntry::ToolCall(call) = entry {
1067                let cancel = matches!(
1068                    call.status,
1069                    ToolCallStatus::WaitingForConfirmation { .. }
1070                        | ToolCallStatus::Allowed {
1071                            status: acp::ToolCallStatus::InProgress
1072                        }
1073                );
1074
1075                if cancel {
1076                    call.status = ToolCallStatus::Canceled;
1077                }
1078            }
1079        }
1080
1081        self.connection.cancel(&self.session_id, cx);
1082
1083        // Wait for the send task to complete
1084        cx.foreground_executor().spawn(send_task)
1085    }
1086
1087    pub fn read_text_file(
1088        &self,
1089        path: PathBuf,
1090        line: Option<u32>,
1091        limit: Option<u32>,
1092        reuse_shared_snapshot: bool,
1093        cx: &mut Context<Self>,
1094    ) -> Task<Result<String>> {
1095        let project = self.project.clone();
1096        let action_log = self.action_log.clone();
1097        cx.spawn(async move |this, cx| {
1098            let load = project.update(cx, |project, cx| {
1099                let path = project
1100                    .project_path_for_absolute_path(&path, cx)
1101                    .context("invalid path")?;
1102                anyhow::Ok(project.open_buffer(path, cx))
1103            });
1104            let buffer = load??.await?;
1105
1106            let snapshot = if reuse_shared_snapshot {
1107                this.read_with(cx, |this, _| {
1108                    this.shared_buffers.get(&buffer.clone()).cloned()
1109                })
1110                .log_err()
1111                .flatten()
1112            } else {
1113                None
1114            };
1115
1116            let snapshot = if let Some(snapshot) = snapshot {
1117                snapshot
1118            } else {
1119                action_log.update(cx, |action_log, cx| {
1120                    action_log.buffer_read(buffer.clone(), cx);
1121                })?;
1122                project.update(cx, |project, cx| {
1123                    let position = buffer
1124                        .read(cx)
1125                        .snapshot()
1126                        .anchor_before(Point::new(line.unwrap_or_default(), 0));
1127                    project.set_agent_location(
1128                        Some(AgentLocation {
1129                            buffer: buffer.downgrade(),
1130                            position,
1131                        }),
1132                        cx,
1133                    );
1134                })?;
1135
1136                buffer.update(cx, |buffer, _| buffer.snapshot())?
1137            };
1138
1139            this.update(cx, |this, _| {
1140                let text = snapshot.text();
1141                this.shared_buffers.insert(buffer.clone(), snapshot);
1142                if line.is_none() && limit.is_none() {
1143                    return Ok(text);
1144                }
1145                let limit = limit.unwrap_or(u32::MAX) as usize;
1146                let Some(line) = line else {
1147                    return Ok(text.lines().take(limit).collect::<String>());
1148                };
1149
1150                let count = text.lines().count();
1151                if count < line as usize {
1152                    anyhow::bail!("There are only {} lines", count);
1153                }
1154                Ok(text
1155                    .lines()
1156                    .skip(line as usize + 1)
1157                    .take(limit)
1158                    .collect::<String>())
1159            })?
1160        })
1161    }
1162
    /// Replaces the contents of the file at `path` with `content` on behalf of
    /// the agent and saves the buffer.
    ///
    /// The new text is diffed against the snapshot stored for this buffer in
    /// `shared_buffers`, or against the buffer's current snapshot when none is
    /// stored. The resulting edits are applied to the buffer and reported to the
    /// action log, and the project's agent location is moved to the end of the
    /// last edit (`Anchor::MIN` when there are no edits).
    ///
    /// # Errors
    ///
    /// Fails if `path` cannot be resolved to a project path, if the buffer cannot
    /// be opened, or if saving the buffer fails.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Prefer the snapshot recorded by an earlier read of this buffer as
            // the base of the diff.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Compute the diff on the background executor and express every
            // changed range as anchors in the base snapshot.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;
            cx.update(|cx| {
                // Must run before `edits` is moved into `buffer.edit` below.
                project.update(cx, |project, cx| {
                    project.set_agent_location(
                        Some(AgentLocation {
                            buffer: buffer.downgrade(),
                            position: edits
                                .last()
                                .map(|(range, _)| range.end)
                                .unwrap_or(Anchor::MIN),
                        }),
                        cx,
                    );
                });

                // The action log is told about the read before the edit is
                // applied, and about the edit afterwards.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });
                buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
            })?;
            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1230
    /// Renders the thread as markdown by concatenating the markdown of every
    /// entry, in order.
    pub fn to_markdown(&self, cx: &App) -> String {
        self.entries.iter().map(|e| e.to_markdown(cx)).collect()
    }
1234
    /// Emits `AcpThreadEvent::ServerExited` carrying the given exit `status`.
    pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
        cx.emit(AcpThreadEvent::ServerExited(status));
    }
1238}
1239
1240#[cfg(test)]
1241mod tests {
1242    use super::*;
1243    use anyhow::anyhow;
1244    use futures::{channel::mpsc, future::LocalBoxFuture, select};
1245    use gpui::{AsyncApp, TestAppContext, WeakEntity};
1246    use indoc::indoc;
1247    use project::FakeFs;
1248    use rand::Rng as _;
1249    use serde_json::json;
1250    use settings::SettingsStore;
1251    use smol::stream::StreamExt as _;
1252    use std::{cell::RefCell, rc::Rc, time::Duration};
1253
1254    use util::path;
1255
    // Shared setup for the tests in this module: initializes logging, installs
    // a test `SettingsStore` as a global, then runs the project settings and
    // language initialization.
    fn init_test(cx: &mut TestAppContext) {
        // The result is discarded on purpose, so a failed logger init is ignored.
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }
1265
1266    #[gpui::test]
1267    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1268        init_test(cx);
1269
1270        let fs = FakeFs::new(cx.executor());
1271        let project = Project::test(fs, [], cx).await;
1272        let connection = Rc::new(FakeAgentConnection::new());
1273        let thread = cx
1274            .spawn(async move |mut cx| {
1275                connection
1276                    .new_thread(project, Path::new(path!("/test")), &mut cx)
1277                    .await
1278            })
1279            .await
1280            .unwrap();
1281
1282        // Test creating a new user message
1283        thread.update(cx, |thread, cx| {
1284            thread.push_user_content_block(
1285                acp::ContentBlock::Text(acp::TextContent {
1286                    annotations: None,
1287                    text: "Hello, ".to_string(),
1288                }),
1289                cx,
1290            );
1291        });
1292
1293        thread.update(cx, |thread, cx| {
1294            assert_eq!(thread.entries.len(), 1);
1295            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1296                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1297            } else {
1298                panic!("Expected UserMessage");
1299            }
1300        });
1301
1302        // Test appending to existing user message
1303        thread.update(cx, |thread, cx| {
1304            thread.push_user_content_block(
1305                acp::ContentBlock::Text(acp::TextContent {
1306                    annotations: None,
1307                    text: "world!".to_string(),
1308                }),
1309                cx,
1310            );
1311        });
1312
1313        thread.update(cx, |thread, cx| {
1314            assert_eq!(thread.entries.len(), 1);
1315            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1316                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1317            } else {
1318                panic!("Expected UserMessage");
1319            }
1320        });
1321
1322        // Test creating new user message after assistant message
1323        thread.update(cx, |thread, cx| {
1324            thread.push_assistant_content_block(
1325                acp::ContentBlock::Text(acp::TextContent {
1326                    annotations: None,
1327                    text: "Assistant response".to_string(),
1328                }),
1329                false,
1330                cx,
1331            );
1332        });
1333
1334        thread.update(cx, |thread, cx| {
1335            thread.push_user_content_block(
1336                acp::ContentBlock::Text(acp::TextContent {
1337                    annotations: None,
1338                    text: "New user message".to_string(),
1339                }),
1340                cx,
1341            );
1342        });
1343
1344        thread.update(cx, |thread, cx| {
1345            assert_eq!(thread.entries.len(), 3);
1346            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1347                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1348            } else {
1349                panic!("Expected UserMessage at index 2");
1350            }
1351        });
1352    }
1353
    // Checks that two consecutive agent thought chunks are rendered as a single
    // `<thinking>` section in the thread's markdown.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        // The fake agent answers every prompt with two thought chunks.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "Thinking ".into(),
                                },
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "hard!".into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        // Both chunks must end up inside one thinking block.
        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
1417
    // Checks that a user edit made between the agent's read and its write of
    // the same file is preserved: the buffer and the file on disk must contain
    // both the user's inserted line and the agent's appended lines.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test body once the agent has finished reading the file.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        // The fake agent reads the file, signals the test, then rewrites the file
        // with two extra lines.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(())
                }
                .boxed_local()
            },
        ));

        // Open the same file as a buffer so the test can edit it like a user.
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        // Wait for the agent's read, then insert a line at the top of the buffer
        // before letting the agent's write run.
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
1494
    // Checks the status of a tool call across a cancellation: it starts as
    // allowed and in progress, becomes `Canceled` when the thread is cancelled,
    // and is reported as allowed and completed once a later update with a
    // `Completed` status arrives for it.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        // The fake agent answers every prompt by starting one in-progress tool call.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(())
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // Entry 0 is the user message; entry 1 is the tool call.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::InProgress,
                        ..
                    },
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // Deliver a completion update for the call after it was canceled.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::Completed,
                        ..
                    },
                    ..
                })
            ));
        });
    }
1601
    // Checks that an edit tool call that is reported as already completed does
    // not count as a pending edit tool call.
    #[gpui::test]
    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/test"), json!({})).await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

        // The fake agent answers every prompt with one completed edit tool call
        // that carries a diff.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |_, thread, mut cx| {
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: acp::ToolCallId("test".into()),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Edit,
                                    status: acp::ToolCallStatus::Completed,
                                    content: vec![acp::ToolCallContent::Diff {
                                        diff: acp::Diff {
                                            path: "/test/test.txt".into(),
                                            old_text: None,
                                            new_text: "foo".into(),
                                        },
                                    }],
                                    locations: vec![],
                                    raw_input: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(())
                }
                .boxed_local()
            }
        }));

        let thread = connection
            .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
            .await
            .unwrap();
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
            .await
            .unwrap();

        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
    }
1651
    // Waits until `thread` contains a tool call entry and returns the index of
    // the first one.
    //
    // A subscription on the thread scans its entries on every event it emits.
    // Panics if no tool call shows up within 10 seconds.
    async fn run_until_first_tool_call(
        thread: &Entity<AcpThread>,
        cx: &mut TestAppContext,
    ) -> usize {
        let (mut tx, mut rx) = mpsc::channel::<usize>(1);

        let subscription = cx.update(|cx| {
            cx.subscribe(thread, move |thread, _, cx| {
                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
                        // Report the first tool call found and stop scanning.
                        return tx.try_send(ix).unwrap();
                    }
                }
            })
        });

        select! {
            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
                panic!("Timeout waiting for tool call")
            }
            ix = rx.next().fuse() => {
                // Stop listening once the first index has been received.
                drop(subscription);
                ix.unwrap()
            }
        }
    }
1678
    /// In-memory `AgentConnection` used by the tests in this module.
    #[derive(Clone, Default)]
    struct FakeAgentConnection {
        /// Auth methods reported by `auth_methods` and accepted by `authenticate`.
        auth_methods: Vec<acp::AuthMethod>,
        /// Weak handles to the threads created by `new_thread`, keyed by session id.
        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
        /// Optional handler run for every prompt; it receives the request, the
        /// session's thread and an async app context.
        on_user_message: Option<
            Rc<
                dyn Fn(
                        acp::PromptRequest,
                        WeakEntity<AcpThread>,
                        AsyncApp,
                    ) -> LocalBoxFuture<'static, Result<()>>
                    + 'static,
            >,
        >,
    }
1694
1695    impl FakeAgentConnection {
1696        fn new() -> Self {
1697            Self {
1698                auth_methods: Vec::new(),
1699                on_user_message: None,
1700                sessions: Arc::default(),
1701            }
1702        }
1703
1704        #[expect(unused)]
1705        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
1706            self.auth_methods = auth_methods;
1707            self
1708        }
1709
1710        fn on_user_message(
1711            mut self,
1712            handler: impl Fn(
1713                acp::PromptRequest,
1714                WeakEntity<AcpThread>,
1715                AsyncApp,
1716            ) -> LocalBoxFuture<'static, Result<()>>
1717            + 'static,
1718        ) -> Self {
1719            self.on_user_message.replace(Rc::new(handler));
1720            self
1721        }
1722    }
1723
1724    impl AgentConnection for FakeAgentConnection {
1725        fn auth_methods(&self) -> &[acp::AuthMethod] {
1726            &self.auth_methods
1727        }
1728
1729        fn new_thread(
1730            self: Rc<Self>,
1731            project: Entity<Project>,
1732            _cwd: &Path,
1733            cx: &mut gpui::AsyncApp,
1734        ) -> Task<gpui::Result<Entity<AcpThread>>> {
1735            let session_id = acp::SessionId(
1736                rand::thread_rng()
1737                    .sample_iter(&rand::distributions::Alphanumeric)
1738                    .take(7)
1739                    .map(char::from)
1740                    .collect::<String>()
1741                    .into(),
1742            );
1743            let thread = cx
1744                .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
1745                .unwrap();
1746            self.sessions.lock().insert(session_id, thread.downgrade());
1747            Task::ready(Ok(thread))
1748        }
1749
1750        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
1751            if self.auth_methods().iter().any(|m| m.id == method) {
1752                Task::ready(Ok(()))
1753            } else {
1754                Task::ready(Err(anyhow!("Invalid Auth Method")))
1755            }
1756        }
1757
1758        fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<gpui::Result<()>> {
1759            let sessions = self.sessions.lock();
1760            let thread = sessions.get(&params.session_id).unwrap();
1761            if let Some(handler) = &self.on_user_message {
1762                let handler = handler.clone();
1763                let thread = thread.clone();
1764                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
1765            } else {
1766                Task::ready(Ok(()))
1767            }
1768        }
1769
1770        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1771            let sessions = self.sessions.lock();
1772            let thread = sessions.get(&session_id).unwrap().clone();
1773
1774            cx.spawn(async move |cx| {
1775                thread
1776                    .update(cx, |thread, cx| thread.cancel(cx))
1777                    .unwrap()
1778                    .await
1779            })
1780            .detach();
1781        }
1782    }
1783}