acp_thread.rs

   1mod connection;
   2pub use connection::*;
   3
   4use agent_client_protocol as acp;
   5use anyhow::{Context as _, Result};
   6use assistant_tool::ActionLog;
   7use buffer_diff::BufferDiff;
   8use editor::{Bias, MultiBuffer, PathKey};
   9use futures::{FutureExt, channel::oneshot, future::BoxFuture};
  10use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
  11use itertools::Itertools;
  12use language::{
  13    Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point,
  14    text_diff,
  15};
  16use markdown::Markdown;
  17use project::{AgentLocation, Project};
  18use std::collections::HashMap;
  19use std::error::Error;
  20use std::fmt::Formatter;
  21use std::rc::Rc;
  22use std::{
  23    fmt::Display,
  24    mem,
  25    path::{Path, PathBuf},
  26    sync::Arc,
  27};
  28use ui::App;
  29use util::ResultExt;
  30
  31#[derive(Debug)]
  32pub struct UserMessage {
  33    pub content: ContentBlock,
  34}
  35
  36impl UserMessage {
  37    pub fn from_acp(
  38        message: impl IntoIterator<Item = acp::ContentBlock>,
  39        language_registry: Arc<LanguageRegistry>,
  40        cx: &mut App,
  41    ) -> Self {
  42        let mut content = ContentBlock::Empty;
  43        for chunk in message {
  44            content.append(chunk, &language_registry, cx)
  45        }
  46        Self { content: content }
  47    }
  48
  49    fn to_markdown(&self, cx: &App) -> String {
  50        format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
  51    }
  52}
  53
  /// A borrowed file path that round-trips through the `@file:` mention syntax.
#[derive(Debug)]
pub struct MentionPath<'a>(&'a Path);

impl<'a> MentionPath<'a> {
    /// Marker placed in front of the path in a mention link target.
    const PREFIX: &'static str = "@file:";

    /// Wraps `path` without copying it.
    pub fn new(path: &'a Path) -> Self {
        Self(path)
    }

    /// Parses a link target of the form `@file:<path>`.
    ///
    /// Returns `None` when `url` does not start with the `@file:` marker.
    pub fn try_parse(url: &'a str) -> Option<Self> {
        url.strip_prefix(Self::PREFIX)
            .map(|rest| Self(Path::new(rest)))
    }

    /// Returns the wrapped path.
    pub fn path(&self) -> &Path {
        self.0
    }
}

impl Display for MentionPath<'_> {
    /// Formats the mention as the markdown link `[@<file name>](@file:<full path>)`.
    ///
    /// A path without a final component (e.g. `/`) yields an empty link label.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let file_name = self.0.file_name().unwrap_or_default();
        write!(
            f,
            "[@{}]({}{})",
            file_name.to_string_lossy(),
            Self::PREFIX,
            self.0.display()
        )
    }
}
  85
  86#[derive(Debug, PartialEq)]
  87pub struct AssistantMessage {
  88    pub chunks: Vec<AssistantMessageChunk>,
  89}
  90
  91impl AssistantMessage {
  92    pub fn to_markdown(&self, cx: &App) -> String {
  93        format!(
  94            "## Assistant\n\n{}\n\n",
  95            self.chunks
  96                .iter()
  97                .map(|chunk| chunk.to_markdown(cx))
  98                .join("\n\n")
  99        )
 100    }
 101}
 102
 103#[derive(Debug, PartialEq)]
 104pub enum AssistantMessageChunk {
 105    Message { block: ContentBlock },
 106    Thought { block: ContentBlock },
 107}
 108
 109impl AssistantMessageChunk {
 110    pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
 111        Self::Message {
 112            block: ContentBlock::new(chunk.into(), language_registry, cx),
 113        }
 114    }
 115
 116    fn to_markdown(&self, cx: &App) -> String {
 117        match self {
 118            Self::Message { block } => block.to_markdown(cx).to_string(),
 119            Self::Thought { block } => {
 120                format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
 121            }
 122        }
 123    }
 124}
 125
 126#[derive(Debug)]
 127pub enum AgentThreadEntry {
 128    UserMessage(UserMessage),
 129    AssistantMessage(AssistantMessage),
 130    ToolCall(ToolCall),
 131}
 132
 133impl AgentThreadEntry {
 134    fn to_markdown(&self, cx: &App) -> String {
 135        match self {
 136            Self::UserMessage(message) => message.to_markdown(cx),
 137            Self::AssistantMessage(message) => message.to_markdown(cx),
 138            Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
 139        }
 140    }
 141
 142    pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
 143        if let AgentThreadEntry::ToolCall(call) = self {
 144            itertools::Either::Left(call.diffs())
 145        } else {
 146            itertools::Either::Right(std::iter::empty())
 147        }
 148    }
 149
 150    pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
 151        if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
 152            Some(locations)
 153        } else {
 154            None
 155        }
 156    }
 157}
 158
 159#[derive(Debug)]
 160pub struct ToolCall {
 161    pub id: acp::ToolCallId,
 162    pub label: Entity<Markdown>,
 163    pub kind: acp::ToolKind,
 164    pub content: Vec<ToolCallContent>,
 165    pub status: ToolCallStatus,
 166    pub locations: Vec<acp::ToolCallLocation>,
 167    pub raw_input: Option<serde_json::Value>,
 168}
 169
 170impl ToolCall {
 171    fn from_acp(
 172        tool_call: acp::ToolCall,
 173        status: ToolCallStatus,
 174        language_registry: Arc<LanguageRegistry>,
 175        cx: &mut App,
 176    ) -> Self {
 177        Self {
 178            id: tool_call.id,
 179            label: cx.new(|cx| {
 180                Markdown::new(
 181                    tool_call.label.into(),
 182                    Some(language_registry.clone()),
 183                    None,
 184                    cx,
 185                )
 186            }),
 187            kind: tool_call.kind,
 188            content: tool_call
 189                .content
 190                .into_iter()
 191                .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
 192                .collect(),
 193            locations: tool_call.locations,
 194            status,
 195            raw_input: tool_call.raw_input,
 196        }
 197    }
 198
 199    fn update(
 200        &mut self,
 201        fields: acp::ToolCallUpdateFields,
 202        language_registry: Arc<LanguageRegistry>,
 203        cx: &mut App,
 204    ) {
 205        let acp::ToolCallUpdateFields {
 206            kind,
 207            status,
 208            label,
 209            content,
 210            locations,
 211            raw_input,
 212        } = fields;
 213
 214        if let Some(kind) = kind {
 215            self.kind = kind;
 216        }
 217
 218        if let Some(status) = status {
 219            self.status = ToolCallStatus::Allowed { status };
 220        }
 221
 222        if let Some(label) = label {
 223            self.label = cx.new(|cx| Markdown::new_text(label.into(), cx));
 224        }
 225
 226        if let Some(content) = content {
 227            self.content = content
 228                .into_iter()
 229                .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
 230                .collect();
 231        }
 232
 233        if let Some(locations) = locations {
 234            self.locations = locations;
 235        }
 236
 237        if let Some(raw_input) = raw_input {
 238            self.raw_input = Some(raw_input);
 239        }
 240    }
 241
 242    pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
 243        self.content.iter().filter_map(|content| match content {
 244            ToolCallContent::ContentBlock { .. } => None,
 245            ToolCallContent::Diff { diff } => Some(diff),
 246        })
 247    }
 248
 249    fn to_markdown(&self, cx: &App) -> String {
 250        let mut markdown = format!(
 251            "**Tool Call: {}**\nStatus: {}\n\n",
 252            self.label.read(cx).source(),
 253            self.status
 254        );
 255        for content in &self.content {
 256            markdown.push_str(content.to_markdown(cx).as_str());
 257            markdown.push_str("\n\n");
 258        }
 259        markdown
 260    }
 261}
 262
 263#[derive(Debug)]
 264pub enum ToolCallStatus {
 265    WaitingForConfirmation {
 266        options: Vec<acp::PermissionOption>,
 267        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
 268    },
 269    Allowed {
 270        status: acp::ToolCallStatus,
 271    },
 272    Rejected,
 273    Canceled,
 274}
 275
 276impl Display for ToolCallStatus {
 277    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 278        write!(
 279            f,
 280            "{}",
 281            match self {
 282                ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
 283                ToolCallStatus::Allowed { status } => match status {
 284                    acp::ToolCallStatus::Pending => "Pending",
 285                    acp::ToolCallStatus::InProgress => "In Progress",
 286                    acp::ToolCallStatus::Completed => "Completed",
 287                    acp::ToolCallStatus::Failed => "Failed",
 288                },
 289                ToolCallStatus::Rejected => "Rejected",
 290                ToolCallStatus::Canceled => "Canceled",
 291            }
 292        )
 293    }
 294}
 295
 296#[derive(Debug, PartialEq, Clone)]
 297pub enum ContentBlock {
 298    Empty,
 299    Markdown { markdown: Entity<Markdown> },
 300}
 301
 302impl ContentBlock {
 303    pub fn new(
 304        block: acp::ContentBlock,
 305        language_registry: &Arc<LanguageRegistry>,
 306        cx: &mut App,
 307    ) -> Self {
 308        let mut this = Self::Empty;
 309        this.append(block, language_registry, cx);
 310        this
 311    }
 312
 313    pub fn new_combined(
 314        blocks: impl IntoIterator<Item = acp::ContentBlock>,
 315        language_registry: Arc<LanguageRegistry>,
 316        cx: &mut App,
 317    ) -> Self {
 318        let mut this = Self::Empty;
 319        for block in blocks {
 320            this.append(block, &language_registry, cx);
 321        }
 322        this
 323    }
 324
 325    pub fn append(
 326        &mut self,
 327        block: acp::ContentBlock,
 328        language_registry: &Arc<LanguageRegistry>,
 329        cx: &mut App,
 330    ) {
 331        let new_content = match block {
 332            acp::ContentBlock::Text(text_content) => text_content.text.clone(),
 333            acp::ContentBlock::ResourceLink(resource_link) => {
 334                if let Some(path) = resource_link.uri.strip_prefix("file://") {
 335                    format!("{}", MentionPath(path.as_ref()))
 336                } else {
 337                    resource_link.uri.clone()
 338                }
 339            }
 340            acp::ContentBlock::Image(_)
 341            | acp::ContentBlock::Audio(_)
 342            | acp::ContentBlock::Resource(_) => String::new(),
 343        };
 344
 345        match self {
 346            ContentBlock::Empty => {
 347                *self = ContentBlock::Markdown {
 348                    markdown: cx.new(|cx| {
 349                        Markdown::new(
 350                            new_content.into(),
 351                            Some(language_registry.clone()),
 352                            None,
 353                            cx,
 354                        )
 355                    }),
 356                };
 357            }
 358            ContentBlock::Markdown { markdown } => {
 359                markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
 360            }
 361        }
 362    }
 363
 364    fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
 365        match self {
 366            ContentBlock::Empty => "",
 367            ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
 368        }
 369    }
 370
 371    pub fn markdown(&self) -> Option<&Entity<Markdown>> {
 372        match self {
 373            ContentBlock::Empty => None,
 374            ContentBlock::Markdown { markdown } => Some(markdown),
 375        }
 376    }
 377}
 378
 379#[derive(Debug)]
 380pub enum ToolCallContent {
 381    ContentBlock { content: ContentBlock },
 382    Diff { diff: Diff },
 383}
 384
 385impl ToolCallContent {
 386    pub fn from_acp(
 387        content: acp::ToolCallContent,
 388        language_registry: Arc<LanguageRegistry>,
 389        cx: &mut App,
 390    ) -> Self {
 391        match content {
 392            acp::ToolCallContent::Content { content } => Self::ContentBlock {
 393                content: ContentBlock::new(content, &language_registry, cx),
 394            },
 395            acp::ToolCallContent::Diff { diff } => Self::Diff {
 396                diff: Diff::from_acp(diff, language_registry, cx),
 397            },
 398        }
 399    }
 400
 401    pub fn to_markdown(&self, cx: &App) -> String {
 402        match self {
 403            Self::ContentBlock { content } => content.to_markdown(cx).to_string(),
 404            Self::Diff { diff } => diff.to_markdown(cx),
 405        }
 406    }
 407}
 408
 409#[derive(Debug)]
 410pub struct Diff {
 411    pub multibuffer: Entity<MultiBuffer>,
 412    pub path: PathBuf,
 413    pub new_buffer: Entity<Buffer>,
 414    pub old_buffer: Entity<Buffer>,
 415    _task: Task<Result<()>>,
 416}
 417
 418impl Diff {
 419    pub fn from_acp(
 420        diff: acp::Diff,
 421        language_registry: Arc<LanguageRegistry>,
 422        cx: &mut App,
 423    ) -> Self {
 424        let acp::Diff {
 425            path,
 426            old_text,
 427            new_text,
 428        } = diff;
 429
 430        let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
 431
 432        let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
 433        let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
 434        let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
 435        let old_buffer_snapshot = old_buffer.read(cx).snapshot();
 436        let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
 437        let diff_task = buffer_diff.update(cx, |diff, cx| {
 438            diff.set_base_text(
 439                old_buffer_snapshot,
 440                Some(language_registry.clone()),
 441                new_buffer_snapshot,
 442                cx,
 443            )
 444        });
 445
 446        let task = cx.spawn({
 447            let multibuffer = multibuffer.clone();
 448            let path = path.clone();
 449            let new_buffer = new_buffer.clone();
 450            async move |cx| {
 451                diff_task.await?;
 452
 453                multibuffer
 454                    .update(cx, |multibuffer, cx| {
 455                        let hunk_ranges = {
 456                            let buffer = new_buffer.read(cx);
 457                            let diff = buffer_diff.read(cx);
 458                            diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
 459                                .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
 460                                .collect::<Vec<_>>()
 461                        };
 462
 463                        multibuffer.set_excerpts_for_path(
 464                            PathKey::for_buffer(&new_buffer, cx),
 465                            new_buffer.clone(),
 466                            hunk_ranges,
 467                            editor::DEFAULT_MULTIBUFFER_CONTEXT,
 468                            cx,
 469                        );
 470                        multibuffer.add_diff(buffer_diff.clone(), cx);
 471                    })
 472                    .log_err();
 473
 474                if let Some(language) = language_registry
 475                    .language_for_file_path(&path)
 476                    .await
 477                    .log_err()
 478                {
 479                    new_buffer.update(cx, |buffer, cx| buffer.set_language(Some(language), cx))?;
 480                }
 481
 482                anyhow::Ok(())
 483            }
 484        });
 485
 486        Self {
 487            multibuffer,
 488            path,
 489            new_buffer,
 490            old_buffer,
 491            _task: task,
 492        }
 493    }
 494
 495    fn to_markdown(&self, cx: &App) -> String {
 496        let buffer_text = self
 497            .multibuffer
 498            .read(cx)
 499            .all_buffers()
 500            .iter()
 501            .map(|buffer| buffer.read(cx).text())
 502            .join("\n");
 503        format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text)
 504    }
 505}
 506
 507#[derive(Debug, Default)]
 508pub struct Plan {
 509    pub entries: Vec<PlanEntry>,
 510}
 511
 512#[derive(Debug)]
 513pub struct PlanStats<'a> {
 514    pub in_progress_entry: Option<&'a PlanEntry>,
 515    pub pending: u32,
 516    pub completed: u32,
 517}
 518
 519impl Plan {
 520    pub fn is_empty(&self) -> bool {
 521        self.entries.is_empty()
 522    }
 523
 524    pub fn stats(&self) -> PlanStats<'_> {
 525        let mut stats = PlanStats {
 526            in_progress_entry: None,
 527            pending: 0,
 528            completed: 0,
 529        };
 530
 531        for entry in &self.entries {
 532            match &entry.status {
 533                acp::PlanEntryStatus::Pending => {
 534                    stats.pending += 1;
 535                }
 536                acp::PlanEntryStatus::InProgress => {
 537                    stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
 538                }
 539                acp::PlanEntryStatus::Completed => {
 540                    stats.completed += 1;
 541                }
 542            }
 543        }
 544
 545        stats
 546    }
 547}
 548
 549#[derive(Debug)]
 550pub struct PlanEntry {
 551    pub content: Entity<Markdown>,
 552    pub priority: acp::PlanEntryPriority,
 553    pub status: acp::PlanEntryStatus,
 554}
 555
 556impl PlanEntry {
 557    pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
 558        Self {
 559            content: cx.new(|cx| Markdown::new_text(entry.content.into(), cx)),
 560            priority: entry.priority,
 561            status: entry.status,
 562        }
 563    }
 564}
 565
 /// Client-side state of one ACP agent session: the transcript, the latest plan, and the
/// in-flight prompt turn.
pub struct AcpThread {
    title: SharedString,
    /// Transcript entries; new entries are appended at the end by `push_entry`.
    entries: Vec<AgentThreadEntry>,
    /// Latest plan reported by the agent; `update_plan` replaces it wholesale.
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    // NOTE(review): not read in the visible part of this file; presumably snapshots of
    // buffers shared with the agent — confirm against the rest of the file.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// The in-flight prompt turn, if any; `status()` reports `Idle` when this is `None`.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
}
 577
 /// Notifications emitted by `AcpThread` to its observers.
///
/// Derives `Debug`/`PartialEq` so events can be logged and compared in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AcpThreadEvent {
    /// An entry was appended to the end of the transcript.
    NewEntry,
    /// The entry at the given index was modified in place.
    EntryUpdated(usize),
    /// A tool call is now waiting for the user to grant or deny permission.
    ToolAuthorizationRequired,
    Stopped,
    Error,
}
 585
 // Allows `AcpThread` to emit `AcpThreadEvent`s through `cx.emit`.
impl EventEmitter<AcpThreadEvent> for AcpThread {}
 587
 /// Coarse activity state of a thread, as returned by `AcpThread::status`.
///
/// A small fieldless enum, so it is `Copy`; `Debug` aids logging and test assertions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThreadStatus {
    /// No prompt turn is in flight.
    Idle,
    /// A turn is in flight and a tool call is waiting for user authorization.
    WaitingForToolConfirmation,
    /// A turn is in flight and not blocked on tool authorization.
    Generating,
}
 594
 595#[derive(Debug, Clone)]
 596pub enum LoadError {
 597    Unsupported {
 598        error_message: SharedString,
 599        upgrade_message: SharedString,
 600        upgrade_command: String,
 601    },
 602    Exited(i32),
 603    Other(SharedString),
 604}
 605
 606impl Display for LoadError {
 607    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 608        match self {
 609            LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
 610            LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
 611            LoadError::Other(msg) => write!(f, "{}", msg),
 612        }
 613    }
 614}
 615
 616impl Error for LoadError {}
 617
 618impl AcpThread {
 619    pub fn new(
 620        title: impl Into<SharedString>,
 621        connection: Rc<dyn AgentConnection>,
 622        project: Entity<Project>,
 623        session_id: acp::SessionId,
 624        cx: &mut Context<Self>,
 625    ) -> Self {
 626        let action_log = cx.new(|_| ActionLog::new(project.clone()));
 627
 628        Self {
 629            action_log,
 630            shared_buffers: Default::default(),
 631            entries: Default::default(),
 632            plan: Default::default(),
 633            title: title.into(),
 634            project,
 635            send_task: None,
 636            connection,
 637            session_id,
 638        }
 639    }
 640
 641    pub fn action_log(&self) -> &Entity<ActionLog> {
 642        &self.action_log
 643    }
 644
 645    pub fn project(&self) -> &Entity<Project> {
 646        &self.project
 647    }
 648
 649    pub fn title(&self) -> SharedString {
 650        self.title.clone()
 651    }
 652
 653    pub fn entries(&self) -> &[AgentThreadEntry] {
 654        &self.entries
 655    }
 656
 657    pub fn status(&self) -> ThreadStatus {
 658        if self.send_task.is_some() {
 659            if self.waiting_for_tool_confirmation() {
 660                ThreadStatus::WaitingForToolConfirmation
 661            } else {
 662                ThreadStatus::Generating
 663            }
 664        } else {
 665            ThreadStatus::Idle
 666        }
 667    }
 668
 669    pub fn has_pending_edit_tool_calls(&self) -> bool {
 670        for entry in self.entries.iter().rev() {
 671            match entry {
 672                AgentThreadEntry::UserMessage(_) => return false,
 673                AgentThreadEntry::ToolCall(
 674                    call @ ToolCall {
 675                        status:
 676                            ToolCallStatus::Allowed {
 677                                status:
 678                                    acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending,
 679                            },
 680                        ..
 681                    },
 682                ) if call.diffs().next().is_some() => {
 683                    return true;
 684                }
 685                AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
 686            }
 687        }
 688
 689        false
 690    }
 691
 692    pub fn used_tools_since_last_user_message(&self) -> bool {
 693        for entry in self.entries.iter().rev() {
 694            match entry {
 695                AgentThreadEntry::UserMessage(..) => return false,
 696                AgentThreadEntry::AssistantMessage(..) => continue,
 697                AgentThreadEntry::ToolCall(..) => return true,
 698            }
 699        }
 700
 701        false
 702    }
 703
 704    pub fn handle_session_update(
 705        &mut self,
 706        update: acp::SessionUpdate,
 707        cx: &mut Context<Self>,
 708    ) -> Result<()> {
 709        match update {
 710            acp::SessionUpdate::UserMessageChunk { content } => {
 711                self.push_user_content_block(content, cx);
 712            }
 713            acp::SessionUpdate::AgentMessageChunk { content } => {
 714                self.push_assistant_content_block(content, false, cx);
 715            }
 716            acp::SessionUpdate::AgentThoughtChunk { content } => {
 717                self.push_assistant_content_block(content, true, cx);
 718            }
 719            acp::SessionUpdate::ToolCall(tool_call) => {
 720                self.upsert_tool_call(tool_call, cx);
 721            }
 722            acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
 723                self.update_tool_call(tool_call_update, cx)?;
 724            }
 725            acp::SessionUpdate::Plan(plan) => {
 726                self.update_plan(plan, cx);
 727            }
 728        }
 729        Ok(())
 730    }
 731
 732    pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
 733        let language_registry = self.project.read(cx).languages().clone();
 734        let entries_len = self.entries.len();
 735
 736        if let Some(last_entry) = self.entries.last_mut()
 737            && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
 738        {
 739            content.append(chunk, &language_registry, cx);
 740            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
 741        } else {
 742            let content = ContentBlock::new(chunk, &language_registry, cx);
 743            self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
 744        }
 745    }
 746
 747    pub fn push_assistant_content_block(
 748        &mut self,
 749        chunk: acp::ContentBlock,
 750        is_thought: bool,
 751        cx: &mut Context<Self>,
 752    ) {
 753        let language_registry = self.project.read(cx).languages().clone();
 754        let entries_len = self.entries.len();
 755        if let Some(last_entry) = self.entries.last_mut()
 756            && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
 757        {
 758            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
 759            match (chunks.last_mut(), is_thought) {
 760                (Some(AssistantMessageChunk::Message { block }), false)
 761                | (Some(AssistantMessageChunk::Thought { block }), true) => {
 762                    block.append(chunk, &language_registry, cx)
 763                }
 764                _ => {
 765                    let block = ContentBlock::new(chunk, &language_registry, cx);
 766                    if is_thought {
 767                        chunks.push(AssistantMessageChunk::Thought { block })
 768                    } else {
 769                        chunks.push(AssistantMessageChunk::Message { block })
 770                    }
 771                }
 772            }
 773        } else {
 774            let block = ContentBlock::new(chunk, &language_registry, cx);
 775            let chunk = if is_thought {
 776                AssistantMessageChunk::Thought { block }
 777            } else {
 778                AssistantMessageChunk::Message { block }
 779            };
 780
 781            self.push_entry(
 782                AgentThreadEntry::AssistantMessage(AssistantMessage {
 783                    chunks: vec![chunk],
 784                }),
 785                cx,
 786            );
 787        }
 788    }
 789
 790    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
 791        self.entries.push(entry);
 792        cx.emit(AcpThreadEvent::NewEntry);
 793    }
 794
 795    pub fn update_tool_call(
 796        &mut self,
 797        update: acp::ToolCallUpdate,
 798        cx: &mut Context<Self>,
 799    ) -> Result<()> {
 800        let languages = self.project.read(cx).languages().clone();
 801
 802        let (ix, current_call) = self
 803            .tool_call_mut(&update.id)
 804            .context("Tool call not found")?;
 805        current_call.update(update.fields, languages, cx);
 806
 807        cx.emit(AcpThreadEvent::EntryUpdated(ix));
 808
 809        Ok(())
 810    }
 811
 812    /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
 813    pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
 814        let status = ToolCallStatus::Allowed {
 815            status: tool_call.status,
 816        };
 817        self.upsert_tool_call_inner(tool_call, status, cx)
 818    }
 819
 820    pub fn upsert_tool_call_inner(
 821        &mut self,
 822        tool_call: acp::ToolCall,
 823        status: ToolCallStatus,
 824        cx: &mut Context<Self>,
 825    ) {
 826        let language_registry = self.project.read(cx).languages().clone();
 827        let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
 828
 829        let location = call.locations.last().cloned();
 830
 831        if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
 832            *current_call = call;
 833
 834            cx.emit(AcpThreadEvent::EntryUpdated(ix));
 835        } else {
 836            self.push_entry(AgentThreadEntry::ToolCall(call), cx);
 837        }
 838
 839        if let Some(location) = location {
 840            self.set_project_location(location, cx)
 841        }
 842    }
 843
 844    fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
 845        // The tool call we are looking for is typically the last one, or very close to the end.
 846        // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
 847        self.entries
 848            .iter_mut()
 849            .enumerate()
 850            .rev()
 851            .find_map(|(index, tool_call)| {
 852                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
 853                    && &tool_call.id == id
 854                {
 855                    Some((index, tool_call))
 856                } else {
 857                    None
 858                }
 859            })
 860    }
 861
 862    pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context<Self>) {
 863        self.project.update(cx, |project, cx| {
 864            let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
 865                return;
 866            };
 867            let buffer = project.open_buffer(path, cx);
 868            cx.spawn(async move |project, cx| {
 869                let buffer = buffer.await?;
 870
 871                project.update(cx, |project, cx| {
 872                    let position = if let Some(line) = location.line {
 873                        let snapshot = buffer.read(cx).snapshot();
 874                        let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
 875                        snapshot.anchor_before(point)
 876                    } else {
 877                        Anchor::MIN
 878                    };
 879
 880                    project.set_agent_location(
 881                        Some(AgentLocation {
 882                            buffer: buffer.downgrade(),
 883                            position,
 884                        }),
 885                        cx,
 886                    );
 887                })
 888            })
 889            .detach_and_log_err(cx);
 890        });
 891    }
 892
 893    pub fn request_tool_call_permission(
 894        &mut self,
 895        tool_call: acp::ToolCall,
 896        options: Vec<acp::PermissionOption>,
 897        cx: &mut Context<Self>,
 898    ) -> oneshot::Receiver<acp::PermissionOptionId> {
 899        let (tx, rx) = oneshot::channel();
 900
 901        let status = ToolCallStatus::WaitingForConfirmation {
 902            options,
 903            respond_tx: tx,
 904        };
 905
 906        self.upsert_tool_call_inner(tool_call, status, cx);
 907        cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
 908        rx
 909    }
 910
 911    pub fn authorize_tool_call(
 912        &mut self,
 913        id: acp::ToolCallId,
 914        option_id: acp::PermissionOptionId,
 915        option_kind: acp::PermissionOptionKind,
 916        cx: &mut Context<Self>,
 917    ) {
 918        let Some((ix, call)) = self.tool_call_mut(&id) else {
 919            return;
 920        };
 921
 922        let new_status = match option_kind {
 923            acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
 924                ToolCallStatus::Rejected
 925            }
 926            acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
 927                ToolCallStatus::Allowed {
 928                    status: acp::ToolCallStatus::InProgress,
 929                }
 930            }
 931        };
 932
 933        let curr_status = mem::replace(&mut call.status, new_status);
 934
 935        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
 936            respond_tx.send(option_id).log_err();
 937        } else if cfg!(debug_assertions) {
 938            panic!("tried to authorize an already authorized tool call");
 939        }
 940
 941        cx.emit(AcpThreadEvent::EntryUpdated(ix));
 942    }
 943
 944    /// Returns true if the last turn is awaiting tool authorization
 945    pub fn waiting_for_tool_confirmation(&self) -> bool {
 946        for entry in self.entries.iter().rev() {
 947            match &entry {
 948                AgentThreadEntry::ToolCall(call) => match call.status {
 949                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
 950                    ToolCallStatus::Allowed { .. }
 951                    | ToolCallStatus::Rejected
 952                    | ToolCallStatus::Canceled => continue,
 953                },
 954                AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
 955                    // Reached the beginning of the turn
 956                    return false;
 957                }
 958            }
 959        }
 960        false
 961    }
 962
 963    pub fn plan(&self) -> &Plan {
 964        &self.plan
 965    }
 966
 967    pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
 968        self.plan = Plan {
 969            entries: request
 970                .entries
 971                .into_iter()
 972                .map(|entry| PlanEntry::from_acp(entry, cx))
 973                .collect(),
 974        };
 975
 976        cx.notify();
 977    }
 978
 979    fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
 980        self.plan
 981            .entries
 982            .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
 983        cx.notify();
 984    }
 985
 986    #[cfg(any(test, feature = "test-support"))]
 987    pub fn send_raw(
 988        &mut self,
 989        message: &str,
 990        cx: &mut Context<Self>,
 991    ) -> BoxFuture<'static, Result<()>> {
 992        self.send(
 993            vec![acp::ContentBlock::Text(acp::TextContent {
 994                text: message.to_string(),
 995                annotations: None,
 996            })],
 997            cx,
 998        )
 999    }
1000
    /// Sends a user prompt to the agent, starting a new turn.
    ///
    /// The message is appended to the thread as a user entry and completed
    /// plan entries are cleared. Any in-flight turn is canceled and awaited
    /// before the prompt is forwarded through the connection.
    ///
    /// The returned future resolves when the prompt finishes:
    /// - If the prompt failed, it emits `AcpThreadEvent::Error` and yields that error.
    /// - Otherwise it emits `AcpThreadEvent::Stopped` and yields `Ok(())`. This
    ///   includes the case where the send task goes away without reporting a result.
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            cx,
        );
        self.push_entry(
            AgentThreadEntry::UserMessage(UserMessage { content: block }),
            cx,
        );
        self.clear_completed_plan_entries(cx);

        // `tx` carries the prompt's result from the send task to the future returned below.
        let (tx, rx) = oneshot::channel();
        // Takes (and will finish) the previous turn's task, if any.
        let cancel_task = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            async {
                // Let the previous turn wind down before starting this one.
                cancel_task.await;

                let result = this
                    .update(cx, |this, cx| {
                        this.connection.prompt(
                            acp::PromptRequest {
                                prompt: message,
                                session_id: this.session_id.clone(),
                            },
                            cx,
                        )
                    })?
                    .await;
                tx.send(result).log_err();
                // The turn is over, so clear the stored handle.
                // NOTE(review): this takes whatever task is stored *now*. If a newer
                // `send` already replaced `send_task` (its `cancel` call took this
                // task first), this drops the newer turn's task instead, which
                // appears to cancel it (gpui tasks cancel on drop — confirm).
                // Consider tagging turns with an id and only clearing a matching one.
                this.update(cx, |this, _cx| this.send_task.take())?;
                anyhow::Ok(())
            }
            .await
            .log_err();
        }));

        cx.spawn(async move |this, cx| match rx.await {
            Ok(Err(e)) => {
                this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error))
                    .log_err();
                Err(e)?
            }
            // Successful prompt, or the sender was dropped without a result.
            _ => {
                this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
                    .log_err();
                Ok(())
            }
        })
        .boxed()
    }
1057
1058    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1059        let Some(send_task) = self.send_task.take() else {
1060            return Task::ready(());
1061        };
1062
1063        for entry in self.entries.iter_mut() {
1064            if let AgentThreadEntry::ToolCall(call) = entry {
1065                let cancel = matches!(
1066                    call.status,
1067                    ToolCallStatus::WaitingForConfirmation { .. }
1068                        | ToolCallStatus::Allowed {
1069                            status: acp::ToolCallStatus::InProgress
1070                        }
1071                );
1072
1073                if cancel {
1074                    call.status = ToolCallStatus::Canceled;
1075                }
1076            }
1077        }
1078
1079        self.connection.cancel(&self.session_id, cx);
1080
1081        // Wait for the send task to complete
1082        cx.foreground_executor().spawn(send_task)
1083    }
1084
1085    pub fn read_text_file(
1086        &self,
1087        path: PathBuf,
1088        line: Option<u32>,
1089        limit: Option<u32>,
1090        reuse_shared_snapshot: bool,
1091        cx: &mut Context<Self>,
1092    ) -> Task<Result<String>> {
1093        let project = self.project.clone();
1094        let action_log = self.action_log.clone();
1095        cx.spawn(async move |this, cx| {
1096            let load = project.update(cx, |project, cx| {
1097                let path = project
1098                    .project_path_for_absolute_path(&path, cx)
1099                    .context("invalid path")?;
1100                anyhow::Ok(project.open_buffer(path, cx))
1101            });
1102            let buffer = load??.await?;
1103
1104            let snapshot = if reuse_shared_snapshot {
1105                this.read_with(cx, |this, _| {
1106                    this.shared_buffers.get(&buffer.clone()).cloned()
1107                })
1108                .log_err()
1109                .flatten()
1110            } else {
1111                None
1112            };
1113
1114            let snapshot = if let Some(snapshot) = snapshot {
1115                snapshot
1116            } else {
1117                action_log.update(cx, |action_log, cx| {
1118                    action_log.buffer_read(buffer.clone(), cx);
1119                })?;
1120                project.update(cx, |project, cx| {
1121                    let position = buffer
1122                        .read(cx)
1123                        .snapshot()
1124                        .anchor_before(Point::new(line.unwrap_or_default(), 0));
1125                    project.set_agent_location(
1126                        Some(AgentLocation {
1127                            buffer: buffer.downgrade(),
1128                            position,
1129                        }),
1130                        cx,
1131                    );
1132                })?;
1133
1134                buffer.update(cx, |buffer, _| buffer.snapshot())?
1135            };
1136
1137            this.update(cx, |this, _| {
1138                let text = snapshot.text();
1139                this.shared_buffers.insert(buffer.clone(), snapshot);
1140                if line.is_none() && limit.is_none() {
1141                    return Ok(text);
1142                }
1143                let limit = limit.unwrap_or(u32::MAX) as usize;
1144                let Some(line) = line else {
1145                    return Ok(text.lines().take(limit).collect::<String>());
1146                };
1147
1148                let count = text.lines().count();
1149                if count < line as usize {
1150                    anyhow::bail!("There are only {} lines", count);
1151                }
1152                Ok(text
1153                    .lines()
1154                    .skip(line as usize + 1)
1155                    .take(limit)
1156                    .collect::<String>())
1157            })?
1158        })
1159    }
1160
    /// Replaces the contents of the file at the absolute `path` with `content`,
    /// then saves it.
    ///
    /// The edits are a diff between `content` and the snapshot this thread
    /// last read for the buffer (from `shared_buffers`), falling back to the
    /// buffer's current snapshot. The edits are anchored to that snapshot, so
    /// changes made to the buffer since the read are kept (see
    /// `test_edits_concurrently_to_user`).
    ///
    /// Side effects, in order:
    /// 1. The agent location moves to the end of the last edit (the buffer
    ///    start if nothing changed).
    /// 2. The action log records a read, then the edit.
    /// 3. The buffer is saved through the project.
    ///
    /// # Errors
    /// Fails if `path` is not inside the project, if the buffer cannot be
    /// opened, if the thread or app has been dropped, or if saving fails.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Prefer the snapshot the agent last read so the diff reflects its view.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Diff on the background executor and express each range as anchors
            // in `snapshot`, so the edits can be applied to a buffer that has
            // changed since.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;
            cx.update(|cx| {
                project.update(cx, |project, cx| {
                    project.set_agent_location(
                        Some(AgentLocation {
                            buffer: buffer.downgrade(),
                            position: edits
                                .last()
                                .map(|(range, _)| range.end)
                                .unwrap_or(Anchor::MIN),
                        }),
                        cx,
                    );
                });

                // Record a read before editing, then record the edit itself.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });
                buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
            })?;
            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1228
1229    pub fn to_markdown(&self, cx: &App) -> String {
1230        self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1231    }
1232}
1233
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use futures::{channel::mpsc, future::LocalBoxFuture, select};
    use gpui::{AsyncApp, TestAppContext, WeakEntity};
    use indoc::indoc;
    use project::FakeFs;
    use rand::Rng as _;
    use serde_json::json;
    use settings::SettingsStore;
    use smol::stream::StreamExt as _;
    use std::{cell::RefCell, rc::Rc, time::Duration};

    use util::path;

    /// Installs the logger, test settings store, project settings and language
    /// globals that every test in this module relies on.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    /// User content blocks are appended to the trailing user message. A new
    /// user message is started once an assistant message has been pushed.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Hello, ".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "world!".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Assistant response".to_string(),
                }),
                false,
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "New user message".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }

    /// Consecutive agent thought chunks are merged into a single
    /// `<thinking>` section of one assistant message.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "Thinking ".into(),
                                },
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "hard!".into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }

    /// An agent write that is diffed against the snapshot it read keeps an
    /// edit the user made to the buffer in between, both in memory and on disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    // Let the test make its own edit before the agent writes.
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(())
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }

    /// A tool call canceled locally still accepts a later `Completed` update
    /// from the agent.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    label: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(())
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::InProgress,
                        ..
                    },
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::Completed,
                        ..
                    },
                    ..
                })
            ));
        });
    }

    /// An edit tool call reported as already `Completed` does not leave the
    /// thread with pending edit tool calls.
    #[gpui::test]
    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/test"), json!({})).await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |_, thread, mut cx| {
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: acp::ToolCallId("test".into()),
                                    label: "Label".into(),
                                    kind: acp::ToolKind::Edit,
                                    status: acp::ToolCallStatus::Completed,
                                    content: vec![acp::ToolCallContent::Diff {
                                        diff: acp::Diff {
                                            path: "/test/test.txt".into(),
                                            old_text: None,
                                            new_text: "foo".into(),
                                        },
                                    }],
                                    locations: vec![],
                                    raw_input: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(())
                }
                .boxed_local()
            }
        }));

        let thread = connection
            .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
            .await
            .unwrap();
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
            .await
            .unwrap();

        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
    }

    /// Waits up to 10 seconds for an event emitted while `thread` contains a
    /// tool call entry, and returns the index of the first such entry.
    ///
    /// Panics on timeout.
    async fn run_until_first_tool_call(
        thread: &Entity<AcpThread>,
        cx: &mut TestAppContext,
    ) -> usize {
        let (mut tx, mut rx) = mpsc::channel::<usize>(1);

        let subscription = cx.update(|cx| {
            cx.subscribe(thread, move |thread, _, cx| {
                let first_tool_call = thread
                    .read(cx)
                    .entries
                    .iter()
                    .position(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)));
                if let Some(ix) = first_tool_call {
                    // Only the first index is consumed. Several events can fire
                    // before the receiver is polled, filling the bounded channel,
                    // so a failed send is expected and harmless.
                    tx.try_send(ix).ok();
                }
            })
        });

        select! {
            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
                panic!("Timeout waiting for tool call")
            }
            ix = rx.next().fuse() => {
                drop(subscription);
                ix.unwrap()
            }
        }
    }

    /// In-memory [`AgentConnection`] used by the tests.
    ///
    /// Sessions are tracked by id. Prompts are forwarded to an optional
    /// `on_user_message` handler, and prompts succeed immediately when no
    /// handler is set.
    #[derive(Clone, Default)]
    struct FakeAgentConnection {
        auth_methods: Vec<acp::AuthMethod>,
        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
        on_user_message: Option<
            Rc<
                dyn Fn(
                        acp::PromptRequest,
                        WeakEntity<AcpThread>,
                        AsyncApp,
                    ) -> LocalBoxFuture<'static, Result<()>>
                    + 'static,
            >,
        >,
    }

    impl FakeAgentConnection {
        /// Creates a connection with no auth methods, no sessions and no prompt handler.
        fn new() -> Self {
            Self {
                auth_methods: Vec::new(),
                on_user_message: None,
                sessions: Arc::default(),
            }
        }

        /// Sets the auth methods accepted by `authenticate`.
        #[expect(unused)]
        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
            self.auth_methods = auth_methods;
            self
        }

        /// Installs the handler invoked for every `prompt` call.
        fn on_user_message(
            mut self,
            handler: impl Fn(
                acp::PromptRequest,
                WeakEntity<AcpThread>,
                AsyncApp,
            ) -> LocalBoxFuture<'static, Result<()>>
            + 'static,
        ) -> Self {
            self.on_user_message.replace(Rc::new(handler));
            self
        }
    }

    impl AgentConnection for FakeAgentConnection {
        fn auth_methods(&self) -> &[acp::AuthMethod] {
            &self.auth_methods
        }

        /// Creates a thread under a random 7-character alphanumeric session id
        /// and registers it for later prompt/cancel lookups.
        fn new_thread(
            self: Rc<Self>,
            project: Entity<Project>,
            _cwd: &Path,
            cx: &mut gpui::AsyncApp,
        ) -> Task<gpui::Result<Entity<AcpThread>>> {
            let session_id = acp::SessionId(
                rand::thread_rng()
                    .sample_iter(&rand::distributions::Alphanumeric)
                    .take(7)
                    .map(char::from)
                    .collect::<String>()
                    .into(),
            );
            let thread = cx
                .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
                .unwrap();
            self.sessions.lock().insert(session_id, thread.downgrade());
            Task::ready(Ok(thread))
        }

        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
            if self.auth_methods().iter().any(|m| m.id == method) {
                Task::ready(Ok(()))
            } else {
                Task::ready(Err(anyhow!("Invalid Auth Method")))
            }
        }

        /// Runs the `on_user_message` handler for the session's thread, or
        /// succeeds immediately when no handler is installed.
        fn prompt(&self, params: acp::PromptRequest, cx: &mut App) -> Task<gpui::Result<()>> {
            let sessions = self.sessions.lock();
            let thread = sessions.get(&params.session_id).unwrap();
            if let Some(handler) = &self.on_user_message {
                let handler = handler.clone();
                let thread = thread.clone();
                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
            } else {
                Task::ready(Ok(()))
            }
        }

        /// Asynchronously cancels the session's thread.
        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
            let sessions = self.sessions.lock();
            // `session_id` is already a reference; borrowing it again would ask
            // for `SessionId: Borrow<&SessionId>`.
            let thread = sessions.get(session_id).unwrap().clone();

            cx.spawn(async move |cx| {
                thread
                    .update(cx, |thread, cx| thread.cancel(cx))
                    .unwrap()
                    .await
            })
            .detach();
        }
    }
}