acp_thread.rs

   1mod connection;
   2pub use connection::*;
   3
   4use agent_client_protocol as acp;
   5use anyhow::{Context as _, Result};
   6use assistant_tool::ActionLog;
   7use buffer_diff::BufferDiff;
   8use editor::{Bias, MultiBuffer, PathKey};
   9use futures::{FutureExt, channel::oneshot, future::BoxFuture};
  10use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
  11use itertools::Itertools;
  12use language::{
  13    Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point,
  14    text_diff,
  15};
  16use markdown::Markdown;
  17use project::{AgentLocation, Project};
  18use std::collections::HashMap;
  19use std::error::Error;
  20use std::fmt::Formatter;
  21use std::process::ExitStatus;
  22use std::rc::Rc;
  23use std::{
  24    fmt::Display,
  25    mem,
  26    path::{Path, PathBuf},
  27    sync::Arc,
  28};
  29use ui::App;
  30use util::ResultExt;
  31
  32#[derive(Debug)]
  33pub struct UserMessage {
  34    pub content: ContentBlock,
  35}
  36
  37impl UserMessage {
  38    pub fn from_acp(
  39        message: impl IntoIterator<Item = acp::ContentBlock>,
  40        language_registry: Arc<LanguageRegistry>,
  41        cx: &mut App,
  42    ) -> Self {
  43        let mut content = ContentBlock::Empty;
  44        for chunk in message {
  45            content.append(chunk, &language_registry, cx)
  46        }
  47        Self { content: content }
  48    }
  49
  50    fn to_markdown(&self, cx: &App) -> String {
  51        format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
  52    }
  53}
  54
  55#[derive(Debug)]
  56pub struct MentionPath<'a>(&'a Path);
  57
  58impl<'a> MentionPath<'a> {
  59    const PREFIX: &'static str = "@file:";
  60
  61    pub fn new(path: &'a Path) -> Self {
  62        MentionPath(path)
  63    }
  64
  65    pub fn try_parse(url: &'a str) -> Option<Self> {
  66        let path = url.strip_prefix(Self::PREFIX)?;
  67        Some(MentionPath(Path::new(path)))
  68    }
  69
  70    pub fn path(&self) -> &Path {
  71        self.0
  72    }
  73}
  74
  75impl Display for MentionPath<'_> {
  76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  77        write!(
  78            f,
  79            "[@{}]({}{})",
  80            self.0.file_name().unwrap_or_default().display(),
  81            Self::PREFIX,
  82            self.0.display()
  83        )
  84    }
  85}
  86
  87#[derive(Debug, PartialEq)]
  88pub struct AssistantMessage {
  89    pub chunks: Vec<AssistantMessageChunk>,
  90}
  91
  92impl AssistantMessage {
  93    pub fn to_markdown(&self, cx: &App) -> String {
  94        format!(
  95            "## Assistant\n\n{}\n\n",
  96            self.chunks
  97                .iter()
  98                .map(|chunk| chunk.to_markdown(cx))
  99                .join("\n\n")
 100        )
 101    }
 102}
 103
 104#[derive(Debug, PartialEq)]
 105pub enum AssistantMessageChunk {
 106    Message { block: ContentBlock },
 107    Thought { block: ContentBlock },
 108}
 109
 110impl AssistantMessageChunk {
 111    pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
 112        Self::Message {
 113            block: ContentBlock::new(chunk.into(), language_registry, cx),
 114        }
 115    }
 116
 117    fn to_markdown(&self, cx: &App) -> String {
 118        match self {
 119            Self::Message { block } => block.to_markdown(cx).to_string(),
 120            Self::Thought { block } => {
 121                format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
 122            }
 123        }
 124    }
 125}
 126
 127#[derive(Debug)]
 128pub enum AgentThreadEntry {
 129    UserMessage(UserMessage),
 130    AssistantMessage(AssistantMessage),
 131    ToolCall(ToolCall),
 132}
 133
 134impl AgentThreadEntry {
 135    fn to_markdown(&self, cx: &App) -> String {
 136        match self {
 137            Self::UserMessage(message) => message.to_markdown(cx),
 138            Self::AssistantMessage(message) => message.to_markdown(cx),
 139            Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
 140        }
 141    }
 142
 143    pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
 144        if let AgentThreadEntry::ToolCall(call) = self {
 145            itertools::Either::Left(call.diffs())
 146        } else {
 147            itertools::Either::Right(std::iter::empty())
 148        }
 149    }
 150
 151    pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
 152        if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
 153            Some(locations)
 154        } else {
 155            None
 156        }
 157    }
 158}
 159
 160#[derive(Debug)]
 161pub struct ToolCall {
 162    pub id: acp::ToolCallId,
 163    pub label: Entity<Markdown>,
 164    pub kind: acp::ToolKind,
 165    pub content: Vec<ToolCallContent>,
 166    pub status: ToolCallStatus,
 167    pub locations: Vec<acp::ToolCallLocation>,
 168    pub raw_input: Option<serde_json::Value>,
 169}
 170
 171impl ToolCall {
 172    fn from_acp(
 173        tool_call: acp::ToolCall,
 174        status: ToolCallStatus,
 175        language_registry: Arc<LanguageRegistry>,
 176        cx: &mut App,
 177    ) -> Self {
 178        Self {
 179            id: tool_call.id,
 180            label: cx.new(|cx| {
 181                Markdown::new(
 182                    tool_call.title.into(),
 183                    Some(language_registry.clone()),
 184                    None,
 185                    cx,
 186                )
 187            }),
 188            kind: tool_call.kind,
 189            content: tool_call
 190                .content
 191                .into_iter()
 192                .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
 193                .collect(),
 194            locations: tool_call.locations,
 195            status,
 196            raw_input: tool_call.raw_input,
 197        }
 198    }
 199
 200    fn update(
 201        &mut self,
 202        fields: acp::ToolCallUpdateFields,
 203        language_registry: Arc<LanguageRegistry>,
 204        cx: &mut App,
 205    ) {
 206        let acp::ToolCallUpdateFields {
 207            kind,
 208            status,
 209            title,
 210            content,
 211            locations,
 212            raw_input,
 213        } = fields;
 214
 215        if let Some(kind) = kind {
 216            self.kind = kind;
 217        }
 218
 219        if let Some(status) = status {
 220            self.status = ToolCallStatus::Allowed { status };
 221        }
 222
 223        if let Some(title) = title {
 224            self.label = cx.new(|cx| Markdown::new_text(title.into(), cx));
 225        }
 226
 227        if let Some(content) = content {
 228            self.content = content
 229                .into_iter()
 230                .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
 231                .collect();
 232        }
 233
 234        if let Some(locations) = locations {
 235            self.locations = locations;
 236        }
 237
 238        if let Some(raw_input) = raw_input {
 239            self.raw_input = Some(raw_input);
 240        }
 241    }
 242
 243    pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
 244        self.content.iter().filter_map(|content| match content {
 245            ToolCallContent::ContentBlock { .. } => None,
 246            ToolCallContent::Diff { diff } => Some(diff),
 247        })
 248    }
 249
 250    fn to_markdown(&self, cx: &App) -> String {
 251        let mut markdown = format!(
 252            "**Tool Call: {}**\nStatus: {}\n\n",
 253            self.label.read(cx).source(),
 254            self.status
 255        );
 256        for content in &self.content {
 257            markdown.push_str(content.to_markdown(cx).as_str());
 258            markdown.push_str("\n\n");
 259        }
 260        markdown
 261    }
 262}
 263
 264#[derive(Debug)]
 265pub enum ToolCallStatus {
 266    WaitingForConfirmation {
 267        options: Vec<acp::PermissionOption>,
 268        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
 269    },
 270    Allowed {
 271        status: acp::ToolCallStatus,
 272    },
 273    Rejected,
 274    Canceled,
 275}
 276
 277impl Display for ToolCallStatus {
 278    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 279        write!(
 280            f,
 281            "{}",
 282            match self {
 283                ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
 284                ToolCallStatus::Allowed { status } => match status {
 285                    acp::ToolCallStatus::Pending => "Pending",
 286                    acp::ToolCallStatus::InProgress => "In Progress",
 287                    acp::ToolCallStatus::Completed => "Completed",
 288                    acp::ToolCallStatus::Failed => "Failed",
 289                },
 290                ToolCallStatus::Rejected => "Rejected",
 291                ToolCallStatus::Canceled => "Canceled",
 292            }
 293        )
 294    }
 295}
 296
 297#[derive(Debug, PartialEq, Clone)]
 298pub enum ContentBlock {
 299    Empty,
 300    Markdown { markdown: Entity<Markdown> },
 301}
 302
 303impl ContentBlock {
 304    pub fn new(
 305        block: acp::ContentBlock,
 306        language_registry: &Arc<LanguageRegistry>,
 307        cx: &mut App,
 308    ) -> Self {
 309        let mut this = Self::Empty;
 310        this.append(block, language_registry, cx);
 311        this
 312    }
 313
 314    pub fn new_combined(
 315        blocks: impl IntoIterator<Item = acp::ContentBlock>,
 316        language_registry: Arc<LanguageRegistry>,
 317        cx: &mut App,
 318    ) -> Self {
 319        let mut this = Self::Empty;
 320        for block in blocks {
 321            this.append(block, &language_registry, cx);
 322        }
 323        this
 324    }
 325
 326    pub fn append(
 327        &mut self,
 328        block: acp::ContentBlock,
 329        language_registry: &Arc<LanguageRegistry>,
 330        cx: &mut App,
 331    ) {
 332        let new_content = match block {
 333            acp::ContentBlock::Text(text_content) => text_content.text.clone(),
 334            acp::ContentBlock::ResourceLink(resource_link) => {
 335                if let Some(path) = resource_link.uri.strip_prefix("file://") {
 336                    format!("{}", MentionPath(path.as_ref()))
 337                } else {
 338                    resource_link.uri.clone()
 339                }
 340            }
 341            acp::ContentBlock::Image(_)
 342            | acp::ContentBlock::Audio(_)
 343            | acp::ContentBlock::Resource(_) => String::new(),
 344        };
 345
 346        match self {
 347            ContentBlock::Empty => {
 348                *self = ContentBlock::Markdown {
 349                    markdown: cx.new(|cx| {
 350                        Markdown::new(
 351                            new_content.into(),
 352                            Some(language_registry.clone()),
 353                            None,
 354                            cx,
 355                        )
 356                    }),
 357                };
 358            }
 359            ContentBlock::Markdown { markdown } => {
 360                markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
 361            }
 362        }
 363    }
 364
 365    fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
 366        match self {
 367            ContentBlock::Empty => "",
 368            ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
 369        }
 370    }
 371
 372    pub fn markdown(&self) -> Option<&Entity<Markdown>> {
 373        match self {
 374            ContentBlock::Empty => None,
 375            ContentBlock::Markdown { markdown } => Some(markdown),
 376        }
 377    }
 378}
 379
 380#[derive(Debug)]
 381pub enum ToolCallContent {
 382    ContentBlock { content: ContentBlock },
 383    Diff { diff: Diff },
 384}
 385
 386impl ToolCallContent {
 387    pub fn from_acp(
 388        content: acp::ToolCallContent,
 389        language_registry: Arc<LanguageRegistry>,
 390        cx: &mut App,
 391    ) -> Self {
 392        match content {
 393            acp::ToolCallContent::Content { content } => Self::ContentBlock {
 394                content: ContentBlock::new(content, &language_registry, cx),
 395            },
 396            acp::ToolCallContent::Diff { diff } => Self::Diff {
 397                diff: Diff::from_acp(diff, language_registry, cx),
 398            },
 399        }
 400    }
 401
 402    pub fn to_markdown(&self, cx: &App) -> String {
 403        match self {
 404            Self::ContentBlock { content } => content.to_markdown(cx).to_string(),
 405            Self::Diff { diff } => diff.to_markdown(cx),
 406        }
 407    }
 408}
 409
 410#[derive(Debug)]
 411pub struct Diff {
 412    pub multibuffer: Entity<MultiBuffer>,
 413    pub path: PathBuf,
 414    _task: Task<Result<()>>,
 415}
 416
 417impl Diff {
 418    pub fn from_acp(
 419        diff: acp::Diff,
 420        language_registry: Arc<LanguageRegistry>,
 421        cx: &mut App,
 422    ) -> Self {
 423        let acp::Diff {
 424            path,
 425            old_text,
 426            new_text,
 427        } = diff;
 428
 429        let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
 430
 431        let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
 432        let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
 433        let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
 434        let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
 435
 436        let task = cx.spawn({
 437            let multibuffer = multibuffer.clone();
 438            let path = path.clone();
 439            async move |cx| {
 440                let language = language_registry
 441                    .language_for_file_path(&path)
 442                    .await
 443                    .log_err();
 444
 445                new_buffer.update(cx, |buffer, cx| buffer.set_language(language.clone(), cx))?;
 446
 447                let old_buffer_snapshot = old_buffer.update(cx, |buffer, cx| {
 448                    buffer.set_language(language, cx);
 449                    buffer.snapshot()
 450                })?;
 451
 452                buffer_diff
 453                    .update(cx, |diff, cx| {
 454                        diff.set_base_text(
 455                            old_buffer_snapshot,
 456                            Some(language_registry),
 457                            new_buffer_snapshot,
 458                            cx,
 459                        )
 460                    })?
 461                    .await?;
 462
 463                multibuffer
 464                    .update(cx, |multibuffer, cx| {
 465                        let hunk_ranges = {
 466                            let buffer = new_buffer.read(cx);
 467                            let diff = buffer_diff.read(cx);
 468                            diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
 469                                .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
 470                                .collect::<Vec<_>>()
 471                        };
 472
 473                        multibuffer.set_excerpts_for_path(
 474                            PathKey::for_buffer(&new_buffer, cx),
 475                            new_buffer.clone(),
 476                            hunk_ranges,
 477                            editor::DEFAULT_MULTIBUFFER_CONTEXT,
 478                            cx,
 479                        );
 480                        multibuffer.add_diff(buffer_diff, cx);
 481                    })
 482                    .log_err();
 483
 484                anyhow::Ok(())
 485            }
 486        });
 487
 488        Self {
 489            multibuffer,
 490            path,
 491            _task: task,
 492        }
 493    }
 494
 495    fn to_markdown(&self, cx: &App) -> String {
 496        let buffer_text = self
 497            .multibuffer
 498            .read(cx)
 499            .all_buffers()
 500            .iter()
 501            .map(|buffer| buffer.read(cx).text())
 502            .join("\n");
 503        format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text)
 504    }
 505}
 506
 507#[derive(Debug, Default)]
 508pub struct Plan {
 509    pub entries: Vec<PlanEntry>,
 510}
 511
 512#[derive(Debug)]
 513pub struct PlanStats<'a> {
 514    pub in_progress_entry: Option<&'a PlanEntry>,
 515    pub pending: u32,
 516    pub completed: u32,
 517}
 518
 519impl Plan {
 520    pub fn is_empty(&self) -> bool {
 521        self.entries.is_empty()
 522    }
 523
 524    pub fn stats(&self) -> PlanStats<'_> {
 525        let mut stats = PlanStats {
 526            in_progress_entry: None,
 527            pending: 0,
 528            completed: 0,
 529        };
 530
 531        for entry in &self.entries {
 532            match &entry.status {
 533                acp::PlanEntryStatus::Pending => {
 534                    stats.pending += 1;
 535                }
 536                acp::PlanEntryStatus::InProgress => {
 537                    stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
 538                }
 539                acp::PlanEntryStatus::Completed => {
 540                    stats.completed += 1;
 541                }
 542            }
 543        }
 544
 545        stats
 546    }
 547}
 548
 549#[derive(Debug)]
 550pub struct PlanEntry {
 551    pub content: Entity<Markdown>,
 552    pub priority: acp::PlanEntryPriority,
 553    pub status: acp::PlanEntryStatus,
 554}
 555
 556impl PlanEntry {
 557    pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
 558        Self {
 559            content: cx.new(|cx| Markdown::new_text(entry.content.into(), cx)),
 560            priority: entry.priority,
 561            status: entry.status,
 562        }
 563    }
 564}
 565
 566pub struct AcpThread {
 567    title: SharedString,
 568    entries: Vec<AgentThreadEntry>,
 569    plan: Plan,
 570    project: Entity<Project>,
 571    action_log: Entity<ActionLog>,
 572    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
 573    send_task: Option<Task<()>>,
 574    connection: Rc<dyn AgentConnection>,
 575    session_id: acp::SessionId,
 576}
 577
 578pub enum AcpThreadEvent {
 579    NewEntry,
 580    EntryUpdated(usize),
 581    ToolAuthorizationRequired,
 582    Stopped,
 583    Error,
 584    ServerExited(ExitStatus),
 585}
 586
 587impl EventEmitter<AcpThreadEvent> for AcpThread {}
 588
 589#[derive(PartialEq, Eq)]
 590pub enum ThreadStatus {
 591    Idle,
 592    WaitingForToolConfirmation,
 593    Generating,
 594}
 595
 596#[derive(Debug, Clone)]
 597pub enum LoadError {
 598    Unsupported {
 599        error_message: SharedString,
 600        upgrade_message: SharedString,
 601        upgrade_command: String,
 602    },
 603    Exited(i32),
 604    Other(SharedString),
 605}
 606
 607impl Display for LoadError {
 608    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 609        match self {
 610            LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
 611            LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
 612            LoadError::Other(msg) => write!(f, "{}", msg),
 613        }
 614    }
 615}
 616
 617impl Error for LoadError {}
 618
 619impl AcpThread {
 620    pub fn new(
 621        title: impl Into<SharedString>,
 622        connection: Rc<dyn AgentConnection>,
 623        project: Entity<Project>,
 624        session_id: acp::SessionId,
 625        cx: &mut Context<Self>,
 626    ) -> Self {
 627        let action_log = cx.new(|_| ActionLog::new(project.clone()));
 628
 629        Self {
 630            action_log,
 631            shared_buffers: Default::default(),
 632            entries: Default::default(),
 633            plan: Default::default(),
 634            title: title.into(),
 635            project,
 636            send_task: None,
 637            connection,
 638            session_id,
 639        }
 640    }
 641
 642    pub fn action_log(&self) -> &Entity<ActionLog> {
 643        &self.action_log
 644    }
 645
 646    pub fn project(&self) -> &Entity<Project> {
 647        &self.project
 648    }
 649
 650    pub fn title(&self) -> SharedString {
 651        self.title.clone()
 652    }
 653
 654    pub fn entries(&self) -> &[AgentThreadEntry] {
 655        &self.entries
 656    }
 657
 658    pub fn session_id(&self) -> &acp::SessionId {
 659        &self.session_id
 660    }
 661
 662    pub fn status(&self) -> ThreadStatus {
 663        if self.send_task.is_some() {
 664            if self.waiting_for_tool_confirmation() {
 665                ThreadStatus::WaitingForToolConfirmation
 666            } else {
 667                ThreadStatus::Generating
 668            }
 669        } else {
 670            ThreadStatus::Idle
 671        }
 672    }
 673
 674    pub fn has_pending_edit_tool_calls(&self) -> bool {
 675        for entry in self.entries.iter().rev() {
 676            match entry {
 677                AgentThreadEntry::UserMessage(_) => return false,
 678                AgentThreadEntry::ToolCall(
 679                    call @ ToolCall {
 680                        status:
 681                            ToolCallStatus::Allowed {
 682                                status:
 683                                    acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending,
 684                            },
 685                        ..
 686                    },
 687                ) if call.diffs().next().is_some() => {
 688                    return true;
 689                }
 690                AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
 691            }
 692        }
 693
 694        false
 695    }
 696
 697    pub fn used_tools_since_last_user_message(&self) -> bool {
 698        for entry in self.entries.iter().rev() {
 699            match entry {
 700                AgentThreadEntry::UserMessage(..) => return false,
 701                AgentThreadEntry::AssistantMessage(..) => continue,
 702                AgentThreadEntry::ToolCall(..) => return true,
 703            }
 704        }
 705
 706        false
 707    }
 708
 709    pub fn handle_session_update(
 710        &mut self,
 711        update: acp::SessionUpdate,
 712        cx: &mut Context<Self>,
 713    ) -> Result<()> {
 714        match update {
 715            acp::SessionUpdate::UserMessageChunk { content } => {
 716                self.push_user_content_block(content, cx);
 717            }
 718            acp::SessionUpdate::AgentMessageChunk { content } => {
 719                self.push_assistant_content_block(content, false, cx);
 720            }
 721            acp::SessionUpdate::AgentThoughtChunk { content } => {
 722                self.push_assistant_content_block(content, true, cx);
 723            }
 724            acp::SessionUpdate::ToolCall(tool_call) => {
 725                self.upsert_tool_call(tool_call, cx);
 726            }
 727            acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
 728                self.update_tool_call(tool_call_update, cx)?;
 729            }
 730            acp::SessionUpdate::Plan(plan) => {
 731                self.update_plan(plan, cx);
 732            }
 733        }
 734        Ok(())
 735    }
 736
 737    pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
 738        let language_registry = self.project.read(cx).languages().clone();
 739        let entries_len = self.entries.len();
 740
 741        if let Some(last_entry) = self.entries.last_mut()
 742            && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
 743        {
 744            content.append(chunk, &language_registry, cx);
 745            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
 746        } else {
 747            let content = ContentBlock::new(chunk, &language_registry, cx);
 748            self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
 749        }
 750    }
 751
 752    pub fn push_assistant_content_block(
 753        &mut self,
 754        chunk: acp::ContentBlock,
 755        is_thought: bool,
 756        cx: &mut Context<Self>,
 757    ) {
 758        let language_registry = self.project.read(cx).languages().clone();
 759        let entries_len = self.entries.len();
 760        if let Some(last_entry) = self.entries.last_mut()
 761            && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
 762        {
 763            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
 764            match (chunks.last_mut(), is_thought) {
 765                (Some(AssistantMessageChunk::Message { block }), false)
 766                | (Some(AssistantMessageChunk::Thought { block }), true) => {
 767                    block.append(chunk, &language_registry, cx)
 768                }
 769                _ => {
 770                    let block = ContentBlock::new(chunk, &language_registry, cx);
 771                    if is_thought {
 772                        chunks.push(AssistantMessageChunk::Thought { block })
 773                    } else {
 774                        chunks.push(AssistantMessageChunk::Message { block })
 775                    }
 776                }
 777            }
 778        } else {
 779            let block = ContentBlock::new(chunk, &language_registry, cx);
 780            let chunk = if is_thought {
 781                AssistantMessageChunk::Thought { block }
 782            } else {
 783                AssistantMessageChunk::Message { block }
 784            };
 785
 786            self.push_entry(
 787                AgentThreadEntry::AssistantMessage(AssistantMessage {
 788                    chunks: vec![chunk],
 789                }),
 790                cx,
 791            );
 792        }
 793    }
 794
 795    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
 796        self.entries.push(entry);
 797        cx.emit(AcpThreadEvent::NewEntry);
 798    }
 799
 800    pub fn update_tool_call(
 801        &mut self,
 802        update: acp::ToolCallUpdate,
 803        cx: &mut Context<Self>,
 804    ) -> Result<()> {
 805        let languages = self.project.read(cx).languages().clone();
 806
 807        let (ix, current_call) = self
 808            .tool_call_mut(&update.id)
 809            .context("Tool call not found")?;
 810        current_call.update(update.fields, languages, cx);
 811
 812        cx.emit(AcpThreadEvent::EntryUpdated(ix));
 813
 814        Ok(())
 815    }
 816
 817    /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
 818    pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
 819        let status = ToolCallStatus::Allowed {
 820            status: tool_call.status,
 821        };
 822        self.upsert_tool_call_inner(tool_call, status, cx)
 823    }
 824
 825    pub fn upsert_tool_call_inner(
 826        &mut self,
 827        tool_call: acp::ToolCall,
 828        status: ToolCallStatus,
 829        cx: &mut Context<Self>,
 830    ) {
 831        let language_registry = self.project.read(cx).languages().clone();
 832        let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
 833
 834        let location = call.locations.last().cloned();
 835
 836        if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
 837            *current_call = call;
 838
 839            cx.emit(AcpThreadEvent::EntryUpdated(ix));
 840        } else {
 841            self.push_entry(AgentThreadEntry::ToolCall(call), cx);
 842        }
 843
 844        if let Some(location) = location {
 845            self.set_project_location(location, cx)
 846        }
 847    }
 848
 849    fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
 850        // The tool call we are looking for is typically the last one, or very close to the end.
 851        // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
 852        self.entries
 853            .iter_mut()
 854            .enumerate()
 855            .rev()
 856            .find_map(|(index, tool_call)| {
 857                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
 858                    && &tool_call.id == id
 859                {
 860                    Some((index, tool_call))
 861                } else {
 862                    None
 863                }
 864            })
 865    }
 866
 867    pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context<Self>) {
 868        self.project.update(cx, |project, cx| {
 869            let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
 870                return;
 871            };
 872            let buffer = project.open_buffer(path, cx);
 873            cx.spawn(async move |project, cx| {
 874                let buffer = buffer.await?;
 875
 876                project.update(cx, |project, cx| {
 877                    let position = if let Some(line) = location.line {
 878                        let snapshot = buffer.read(cx).snapshot();
 879                        let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
 880                        snapshot.anchor_before(point)
 881                    } else {
 882                        Anchor::MIN
 883                    };
 884
 885                    project.set_agent_location(
 886                        Some(AgentLocation {
 887                            buffer: buffer.downgrade(),
 888                            position,
 889                        }),
 890                        cx,
 891                    );
 892                })
 893            })
 894            .detach_and_log_err(cx);
 895        });
 896    }
 897
 898    pub fn request_tool_call_permission(
 899        &mut self,
 900        tool_call: acp::ToolCall,
 901        options: Vec<acp::PermissionOption>,
 902        cx: &mut Context<Self>,
 903    ) -> oneshot::Receiver<acp::PermissionOptionId> {
 904        let (tx, rx) = oneshot::channel();
 905
 906        let status = ToolCallStatus::WaitingForConfirmation {
 907            options,
 908            respond_tx: tx,
 909        };
 910
 911        self.upsert_tool_call_inner(tool_call, status, cx);
 912        cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
 913        rx
 914    }
 915
 916    pub fn authorize_tool_call(
 917        &mut self,
 918        id: acp::ToolCallId,
 919        option_id: acp::PermissionOptionId,
 920        option_kind: acp::PermissionOptionKind,
 921        cx: &mut Context<Self>,
 922    ) {
 923        let Some((ix, call)) = self.tool_call_mut(&id) else {
 924            return;
 925        };
 926
 927        let new_status = match option_kind {
 928            acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
 929                ToolCallStatus::Rejected
 930            }
 931            acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
 932                ToolCallStatus::Allowed {
 933                    status: acp::ToolCallStatus::InProgress,
 934                }
 935            }
 936        };
 937
 938        let curr_status = mem::replace(&mut call.status, new_status);
 939
 940        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
 941            respond_tx.send(option_id).log_err();
 942        } else if cfg!(debug_assertions) {
 943            panic!("tried to authorize an already authorized tool call");
 944        }
 945
 946        cx.emit(AcpThreadEvent::EntryUpdated(ix));
 947    }
 948
 949    /// Returns true if the last turn is awaiting tool authorization
 950    pub fn waiting_for_tool_confirmation(&self) -> bool {
 951        for entry in self.entries.iter().rev() {
 952            match &entry {
 953                AgentThreadEntry::ToolCall(call) => match call.status {
 954                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
 955                    ToolCallStatus::Allowed { .. }
 956                    | ToolCallStatus::Rejected
 957                    | ToolCallStatus::Canceled => continue,
 958                },
 959                AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
 960                    // Reached the beginning of the turn
 961                    return false;
 962                }
 963            }
 964        }
 965        false
 966    }
 967
    /// Returns a reference to the plan currently stored on this thread.
    pub fn plan(&self) -> &Plan {
        &self.plan
    }
 971
 972    pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
 973        self.plan = Plan {
 974            entries: request
 975                .entries
 976                .into_iter()
 977                .map(|entry| PlanEntry::from_acp(entry, cx))
 978                .collect(),
 979        };
 980
 981        cx.notify();
 982    }
 983
 984    fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
 985        self.plan
 986            .entries
 987            .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
 988        cx.notify();
 989    }
 990
 991    #[cfg(any(test, feature = "test-support"))]
 992    pub fn send_raw(
 993        &mut self,
 994        message: &str,
 995        cx: &mut Context<Self>,
 996    ) -> BoxFuture<'static, Result<()>> {
 997        self.send(
 998            vec![acp::ContentBlock::Text(acp::TextContent {
 999                text: message.to_string(),
1000                annotations: None,
1001            })],
1002            cx,
1003        )
1004    }
1005
1006    pub fn send(
1007        &mut self,
1008        message: Vec<acp::ContentBlock>,
1009        cx: &mut Context<Self>,
1010    ) -> BoxFuture<'static, Result<()>> {
1011        let block = ContentBlock::new_combined(
1012            message.clone(),
1013            self.project.read(cx).languages().clone(),
1014            cx,
1015        );
1016        self.push_entry(
1017            AgentThreadEntry::UserMessage(UserMessage { content: block }),
1018            cx,
1019        );
1020        self.clear_completed_plan_entries(cx);
1021
1022        let (tx, rx) = oneshot::channel();
1023        let cancel_task = self.cancel(cx);
1024
1025        self.send_task = Some(cx.spawn(async move |this, cx| {
1026            async {
1027                cancel_task.await;
1028
1029                let result = this
1030                    .update(cx, |this, cx| {
1031                        this.connection.prompt(
1032                            acp::PromptRequest {
1033                                prompt: message,
1034                                session_id: this.session_id.clone(),
1035                            },
1036                            cx,
1037                        )
1038                    })?
1039                    .await;
1040                tx.send(result).log_err();
1041                this.update(cx, |this, _cx| this.send_task.take())?;
1042                anyhow::Ok(())
1043            }
1044            .await
1045            .log_err();
1046        }));
1047
1048        cx.spawn(async move |this, cx| match rx.await {
1049            Ok(Err(e)) => {
1050                this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error))
1051                    .log_err();
1052                Err(e)?
1053            }
1054            _ => {
1055                this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
1056                    .log_err();
1057                Ok(())
1058            }
1059        })
1060        .boxed()
1061    }
1062
1063    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1064        let Some(send_task) = self.send_task.take() else {
1065            return Task::ready(());
1066        };
1067
1068        for entry in self.entries.iter_mut() {
1069            if let AgentThreadEntry::ToolCall(call) = entry {
1070                let cancel = matches!(
1071                    call.status,
1072                    ToolCallStatus::WaitingForConfirmation { .. }
1073                        | ToolCallStatus::Allowed {
1074                            status: acp::ToolCallStatus::InProgress
1075                        }
1076                );
1077
1078                if cancel {
1079                    call.status = ToolCallStatus::Canceled;
1080                }
1081            }
1082        }
1083
1084        self.connection.cancel(&self.session_id, cx);
1085
1086        // Wait for the send task to complete
1087        cx.foreground_executor().spawn(send_task)
1088    }
1089
1090    pub fn read_text_file(
1091        &self,
1092        path: PathBuf,
1093        line: Option<u32>,
1094        limit: Option<u32>,
1095        reuse_shared_snapshot: bool,
1096        cx: &mut Context<Self>,
1097    ) -> Task<Result<String>> {
1098        let project = self.project.clone();
1099        let action_log = self.action_log.clone();
1100        cx.spawn(async move |this, cx| {
1101            let load = project.update(cx, |project, cx| {
1102                let path = project
1103                    .project_path_for_absolute_path(&path, cx)
1104                    .context("invalid path")?;
1105                anyhow::Ok(project.open_buffer(path, cx))
1106            });
1107            let buffer = load??.await?;
1108
1109            let snapshot = if reuse_shared_snapshot {
1110                this.read_with(cx, |this, _| {
1111                    this.shared_buffers.get(&buffer.clone()).cloned()
1112                })
1113                .log_err()
1114                .flatten()
1115            } else {
1116                None
1117            };
1118
1119            let snapshot = if let Some(snapshot) = snapshot {
1120                snapshot
1121            } else {
1122                action_log.update(cx, |action_log, cx| {
1123                    action_log.buffer_read(buffer.clone(), cx);
1124                })?;
1125                project.update(cx, |project, cx| {
1126                    let position = buffer
1127                        .read(cx)
1128                        .snapshot()
1129                        .anchor_before(Point::new(line.unwrap_or_default(), 0));
1130                    project.set_agent_location(
1131                        Some(AgentLocation {
1132                            buffer: buffer.downgrade(),
1133                            position,
1134                        }),
1135                        cx,
1136                    );
1137                })?;
1138
1139                buffer.update(cx, |buffer, _| buffer.snapshot())?
1140            };
1141
1142            this.update(cx, |this, _| {
1143                let text = snapshot.text();
1144                this.shared_buffers.insert(buffer.clone(), snapshot);
1145                if line.is_none() && limit.is_none() {
1146                    return Ok(text);
1147                }
1148                let limit = limit.unwrap_or(u32::MAX) as usize;
1149                let Some(line) = line else {
1150                    return Ok(text.lines().take(limit).collect::<String>());
1151                };
1152
1153                let count = text.lines().count();
1154                if count < line as usize {
1155                    anyhow::bail!("There are only {} lines", count);
1156                }
1157                Ok(text
1158                    .lines()
1159                    .skip(line as usize + 1)
1160                    .take(limit)
1161                    .collect::<String>())
1162            })?
1163        })
1164    }
1165
1166    pub fn write_text_file(
1167        &self,
1168        path: PathBuf,
1169        content: String,
1170        cx: &mut Context<Self>,
1171    ) -> Task<Result<()>> {
1172        let project = self.project.clone();
1173        let action_log = self.action_log.clone();
1174        cx.spawn(async move |this, cx| {
1175            let load = project.update(cx, |project, cx| {
1176                let path = project
1177                    .project_path_for_absolute_path(&path, cx)
1178                    .context("invalid path")?;
1179                anyhow::Ok(project.open_buffer(path, cx))
1180            });
1181            let buffer = load??.await?;
1182            let snapshot = this.update(cx, |this, cx| {
1183                this.shared_buffers
1184                    .get(&buffer)
1185                    .cloned()
1186                    .unwrap_or_else(|| buffer.read(cx).snapshot())
1187            })?;
1188            let edits = cx
1189                .background_executor()
1190                .spawn(async move {
1191                    let old_text = snapshot.text();
1192                    text_diff(old_text.as_str(), &content)
1193                        .into_iter()
1194                        .map(|(range, replacement)| {
1195                            (
1196                                snapshot.anchor_after(range.start)
1197                                    ..snapshot.anchor_before(range.end),
1198                                replacement,
1199                            )
1200                        })
1201                        .collect::<Vec<_>>()
1202                })
1203                .await;
1204            cx.update(|cx| {
1205                project.update(cx, |project, cx| {
1206                    project.set_agent_location(
1207                        Some(AgentLocation {
1208                            buffer: buffer.downgrade(),
1209                            position: edits
1210                                .last()
1211                                .map(|(range, _)| range.end)
1212                                .unwrap_or(Anchor::MIN),
1213                        }),
1214                        cx,
1215                    );
1216                });
1217
1218                action_log.update(cx, |action_log, cx| {
1219                    action_log.buffer_read(buffer.clone(), cx);
1220                });
1221                buffer.update(cx, |buffer, cx| {
1222                    buffer.edit(edits, None, cx);
1223                });
1224                action_log.update(cx, |action_log, cx| {
1225                    action_log.buffer_edited(buffer.clone(), cx);
1226                });
1227            })?;
1228            project
1229                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1230                .await
1231        })
1232    }
1233
1234    pub fn to_markdown(&self, cx: &App) -> String {
1235        self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1236    }
1237
    /// Emits `AcpThreadEvent::ServerExited` carrying the given exit `status`.
    pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
        cx.emit(AcpThreadEvent::ServerExited(status));
    }
1241}
1242
1243#[cfg(test)]
1244mod tests {
1245    use super::*;
1246    use anyhow::anyhow;
1247    use futures::{channel::mpsc, future::LocalBoxFuture, select};
1248    use gpui::{AsyncApp, TestAppContext, WeakEntity};
1249    use indoc::indoc;
1250    use project::FakeFs;
1251    use rand::Rng as _;
1252    use serde_json::json;
1253    use settings::SettingsStore;
1254    use smol::stream::StreamExt as _;
1255    use std::{cell::RefCell, rc::Rc, time::Duration};
1256
1257    use util::path;
1258
    /// Shared setup for the tests in this module: tries to install a logger
    /// (discarding the result), registers a test `SettingsStore` as a global,
    /// then initializes project settings and the `language` crate.
    fn init_test(cx: &mut TestAppContext) {
        // `.ok()` discards the result, so a failed logger init does not abort the test.
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }
1268
1269    #[gpui::test]
1270    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1271        init_test(cx);
1272
1273        let fs = FakeFs::new(cx.executor());
1274        let project = Project::test(fs, [], cx).await;
1275        let connection = Rc::new(FakeAgentConnection::new());
1276        let thread = cx
1277            .spawn(async move |mut cx| {
1278                connection
1279                    .new_thread(project, Path::new(path!("/test")), &mut cx)
1280                    .await
1281            })
1282            .await
1283            .unwrap();
1284
1285        // Test creating a new user message
1286        thread.update(cx, |thread, cx| {
1287            thread.push_user_content_block(
1288                acp::ContentBlock::Text(acp::TextContent {
1289                    annotations: None,
1290                    text: "Hello, ".to_string(),
1291                }),
1292                cx,
1293            );
1294        });
1295
1296        thread.update(cx, |thread, cx| {
1297            assert_eq!(thread.entries.len(), 1);
1298            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1299                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1300            } else {
1301                panic!("Expected UserMessage");
1302            }
1303        });
1304
1305        // Test appending to existing user message
1306        thread.update(cx, |thread, cx| {
1307            thread.push_user_content_block(
1308                acp::ContentBlock::Text(acp::TextContent {
1309                    annotations: None,
1310                    text: "world!".to_string(),
1311                }),
1312                cx,
1313            );
1314        });
1315
1316        thread.update(cx, |thread, cx| {
1317            assert_eq!(thread.entries.len(), 1);
1318            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1319                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1320            } else {
1321                panic!("Expected UserMessage");
1322            }
1323        });
1324
1325        // Test creating new user message after assistant message
1326        thread.update(cx, |thread, cx| {
1327            thread.push_assistant_content_block(
1328                acp::ContentBlock::Text(acp::TextContent {
1329                    annotations: None,
1330                    text: "Assistant response".to_string(),
1331                }),
1332                false,
1333                cx,
1334            );
1335        });
1336
1337        thread.update(cx, |thread, cx| {
1338            thread.push_user_content_block(
1339                acp::ContentBlock::Text(acp::TextContent {
1340                    annotations: None,
1341                    text: "New user message".to_string(),
1342                }),
1343                cx,
1344            );
1345        });
1346
1347        thread.update(cx, |thread, cx| {
1348            assert_eq!(thread.entries.len(), 3);
1349            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1350                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1351            } else {
1352                panic!("Expected UserMessage at index 2");
1353            }
1354        });
1355    }
1356
1357    #[gpui::test]
1358    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
1359        init_test(cx);
1360
1361        let fs = FakeFs::new(cx.executor());
1362        let project = Project::test(fs, [], cx).await;
1363        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1364            |_, thread, mut cx| {
1365                async move {
1366                    thread.update(&mut cx, |thread, cx| {
1367                        thread
1368                            .handle_session_update(
1369                                acp::SessionUpdate::AgentThoughtChunk {
1370                                    content: "Thinking ".into(),
1371                                },
1372                                cx,
1373                            )
1374                            .unwrap();
1375                        thread
1376                            .handle_session_update(
1377                                acp::SessionUpdate::AgentThoughtChunk {
1378                                    content: "hard!".into(),
1379                                },
1380                                cx,
1381                            )
1382                            .unwrap();
1383                    })?;
1384                    Ok(acp::PromptResponse {
1385                        stop_reason: acp::StopReason::EndTurn,
1386                    })
1387                }
1388                .boxed_local()
1389            },
1390        ));
1391
1392        let thread = cx
1393            .spawn(async move |mut cx| {
1394                connection
1395                    .new_thread(project, Path::new(path!("/test")), &mut cx)
1396                    .await
1397            })
1398            .await
1399            .unwrap();
1400
1401        thread
1402            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
1403            .await
1404            .unwrap();
1405
1406        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
1407        assert_eq!(
1408            output,
1409            indoc! {r#"
1410            ## User
1411
1412            Hello from Zed!
1413
1414            ## Assistant
1415
1416            <thinking>
1417            Thinking hard!
1418            </thinking>
1419
1420            "#}
1421        );
1422    }
1423
    /// Checks that an agent write is merged with a user edit made after the
    /// agent read the file: the agent appends lines based on the content it
    /// read, the user meanwhile prepends a line, and both changes must end up
    /// in the buffer and on disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test body once the fake agent has finished reading the file.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        // The fake agent reads the file, notifies the test, then writes an
        // extended version of what it read.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        // Open the same file from the test side so it can be edited "by the user".
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        // Wait until the agent has read the file, then make the user edit.
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        // Both the user's prepended line and the agent's appended lines are present.
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
1502
    /// Checks that a tool call marked `Canceled` by `cancel` still becomes
    /// `Allowed` with a `Completed` status when a later update reports that it
    /// completed.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        // The fake agent answers every prompt by reporting one in-progress tool call.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // Entry 0 is the user message; entry 1 is the tool call reported above.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::InProgress,
                        ..
                    },
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        // Cancelling marks the in-progress tool call as canceled.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // A completion update that arrives after the cancellation.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::Completed,
                        ..
                    },
                    ..
                })
            ));
        });
    }
1611
1612    #[gpui::test]
1613    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
1614        init_test(cx);
1615        let fs = FakeFs::new(cx.background_executor.clone());
1616        fs.insert_tree(path!("/test"), json!({})).await;
1617        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
1618
1619        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
1620            move |_, thread, mut cx| {
1621                async move {
1622                    thread
1623                        .update(&mut cx, |thread, cx| {
1624                            thread.handle_session_update(
1625                                acp::SessionUpdate::ToolCall(acp::ToolCall {
1626                                    id: acp::ToolCallId("test".into()),
1627                                    title: "Label".into(),
1628                                    kind: acp::ToolKind::Edit,
1629                                    status: acp::ToolCallStatus::Completed,
1630                                    content: vec![acp::ToolCallContent::Diff {
1631                                        diff: acp::Diff {
1632                                            path: "/test/test.txt".into(),
1633                                            old_text: None,
1634                                            new_text: "foo".into(),
1635                                        },
1636                                    }],
1637                                    locations: vec![],
1638                                    raw_input: None,
1639                                }),
1640                                cx,
1641                            )
1642                        })
1643                        .unwrap()
1644                        .unwrap();
1645                    Ok(acp::PromptResponse {
1646                        stop_reason: acp::StopReason::EndTurn,
1647                    })
1648                }
1649                .boxed_local()
1650            }
1651        }));
1652
1653        let thread = connection
1654            .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
1655            .await
1656            .unwrap();
1657        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
1658            .await
1659            .unwrap();
1660
1661        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
1662    }
1663
1664    async fn run_until_first_tool_call(
1665        thread: &Entity<AcpThread>,
1666        cx: &mut TestAppContext,
1667    ) -> usize {
1668        let (mut tx, mut rx) = mpsc::channel::<usize>(1);
1669
1670        let subscription = cx.update(|cx| {
1671            cx.subscribe(thread, move |thread, _, cx| {
1672                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
1673                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
1674                        return tx.try_send(ix).unwrap();
1675                    }
1676                }
1677            })
1678        });
1679
1680        select! {
1681            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
1682                panic!("Timeout waiting for tool call")
1683            }
1684            ix = rx.next().fuse() => {
1685                drop(subscription);
1686                ix.unwrap()
1687            }
1688        }
1689    }
1690
    /// In-memory `AgentConnection` used by the tests in this module.
    #[derive(Clone, Default)]
    struct FakeAgentConnection {
        /// Methods reported by `auth_methods` and accepted by `authenticate`.
        auth_methods: Vec<acp::AuthMethod>,
        /// Threads created by `new_thread`, keyed by their session id.
        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
        /// Handler run by `prompt`; when unset, `prompt` immediately ends the turn.
        on_user_message: Option<
            Rc<
                dyn Fn(
                        acp::PromptRequest,
                        WeakEntity<AcpThread>,
                        AsyncApp,
                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
                    + 'static,
            >,
        >,
    }
1706
1707    impl FakeAgentConnection {
1708        fn new() -> Self {
1709            Self {
1710                auth_methods: Vec::new(),
1711                on_user_message: None,
1712                sessions: Arc::default(),
1713            }
1714        }
1715
1716        #[expect(unused)]
1717        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
1718            self.auth_methods = auth_methods;
1719            self
1720        }
1721
1722        fn on_user_message(
1723            mut self,
1724            handler: impl Fn(
1725                acp::PromptRequest,
1726                WeakEntity<AcpThread>,
1727                AsyncApp,
1728            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1729            + 'static,
1730        ) -> Self {
1731            self.on_user_message.replace(Rc::new(handler));
1732            self
1733        }
1734    }
1735
1736    impl AgentConnection for FakeAgentConnection {
1737        fn auth_methods(&self) -> &[acp::AuthMethod] {
1738            &self.auth_methods
1739        }
1740
1741        fn new_thread(
1742            self: Rc<Self>,
1743            project: Entity<Project>,
1744            _cwd: &Path,
1745            cx: &mut gpui::AsyncApp,
1746        ) -> Task<gpui::Result<Entity<AcpThread>>> {
1747            let session_id = acp::SessionId(
1748                rand::thread_rng()
1749                    .sample_iter(&rand::distributions::Alphanumeric)
1750                    .take(7)
1751                    .map(char::from)
1752                    .collect::<String>()
1753                    .into(),
1754            );
1755            let thread = cx
1756                .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
1757                .unwrap();
1758            self.sessions.lock().insert(session_id, thread.downgrade());
1759            Task::ready(Ok(thread))
1760        }
1761
1762        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
1763            if self.auth_methods().iter().any(|m| m.id == method) {
1764                Task::ready(Ok(()))
1765            } else {
1766                Task::ready(Err(anyhow!("Invalid Auth Method")))
1767            }
1768        }
1769
1770        fn prompt(
1771            &self,
1772            params: acp::PromptRequest,
1773            cx: &mut App,
1774        ) -> Task<gpui::Result<acp::PromptResponse>> {
1775            let sessions = self.sessions.lock();
1776            let thread = sessions.get(&params.session_id).unwrap();
1777            if let Some(handler) = &self.on_user_message {
1778                let handler = handler.clone();
1779                let thread = thread.clone();
1780                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
1781            } else {
1782                Task::ready(Ok(acp::PromptResponse {
1783                    stop_reason: acp::StopReason::EndTurn,
1784                }))
1785            }
1786        }
1787
1788        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1789            let sessions = self.sessions.lock();
1790            let thread = sessions.get(&session_id).unwrap().clone();
1791
1792            cx.spawn(async move |cx| {
1793                thread
1794                    .update(cx, |thread, cx| thread.cancel(cx))
1795                    .unwrap()
1796                    .await
1797            })
1798            .detach();
1799        }
1800    }
1801}