acp_thread.rs

   1mod connection;
   2pub use connection::*;
   3
   4use agent_client_protocol as acp;
   5use anyhow::{Context as _, Result};
   6use assistant_tool::ActionLog;
   7use buffer_diff::BufferDiff;
   8use editor::{Bias, MultiBuffer, PathKey};
   9use futures::{FutureExt, channel::oneshot, future::BoxFuture};
  10use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
  11use itertools::Itertools;
  12use language::{
  13    Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point,
  14    text_diff,
  15};
  16use markdown::Markdown;
  17use project::{AgentLocation, Project};
  18use std::collections::HashMap;
  19use std::error::Error;
  20use std::fmt::Formatter;
  21use std::process::ExitStatus;
  22use std::rc::Rc;
  23use std::{
  24    fmt::Display,
  25    mem,
  26    path::{Path, PathBuf},
  27    sync::Arc,
  28};
  29use ui::App;
  30use util::ResultExt;
  31
/// A message authored by the user.
#[derive(Debug)]
pub struct UserMessage {
    /// The message body, held as a single content block.
    pub content: ContentBlock,
}
  36
  37impl UserMessage {
  38    pub fn from_acp(
  39        message: impl IntoIterator<Item = acp::ContentBlock>,
  40        language_registry: Arc<LanguageRegistry>,
  41        cx: &mut App,
  42    ) -> Self {
  43        let mut content = ContentBlock::Empty;
  44        for chunk in message {
  45            content.append(chunk, &language_registry, cx)
  46        }
  47        Self { content: content }
  48    }
  49
  50    fn to_markdown(&self, cx: &App) -> String {
  51        format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
  52    }
  53}
  54
  55#[derive(Debug)]
  56pub struct MentionPath<'a>(&'a Path);
  57
  58impl<'a> MentionPath<'a> {
  59    const PREFIX: &'static str = "@file:";
  60
  61    pub fn new(path: &'a Path) -> Self {
  62        MentionPath(path)
  63    }
  64
  65    pub fn try_parse(url: &'a str) -> Option<Self> {
  66        let path = url.strip_prefix(Self::PREFIX)?;
  67        Some(MentionPath(Path::new(path)))
  68    }
  69
  70    pub fn path(&self) -> &Path {
  71        self.0
  72    }
  73}
  74
  75impl Display for MentionPath<'_> {
  76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  77        write!(
  78            f,
  79            "[@{}]({}{})",
  80            self.0.file_name().unwrap_or_default().display(),
  81            Self::PREFIX,
  82            self.0.display()
  83        )
  84    }
  85}
  86
/// A message produced by the assistant, stored as an ordered list of chunks.
#[derive(Debug, PartialEq)]
pub struct AssistantMessage {
    /// Message and thought chunks in the order they were received.
    pub chunks: Vec<AssistantMessageChunk>,
}
  91
  92impl AssistantMessage {
  93    pub fn to_markdown(&self, cx: &App) -> String {
  94        format!(
  95            "## Assistant\n\n{}\n\n",
  96            self.chunks
  97                .iter()
  98                .map(|chunk| chunk.to_markdown(cx))
  99                .join("\n\n")
 100        )
 101    }
 102}
 103
/// One piece of an assistant message: either regular output or a thought.
#[derive(Debug, PartialEq)]
pub enum AssistantMessageChunk {
    /// Regular assistant output.
    Message { block: ContentBlock },
    /// A thought; rendered inside `<thinking>` tags when exported to markdown.
    Thought { block: ContentBlock },
}
 109
 110impl AssistantMessageChunk {
 111    pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
 112        Self::Message {
 113            block: ContentBlock::new(chunk.into(), language_registry, cx),
 114        }
 115    }
 116
 117    fn to_markdown(&self, cx: &App) -> String {
 118        match self {
 119            Self::Message { block } => block.to_markdown(cx).to_string(),
 120            Self::Thought { block } => {
 121                format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
 122            }
 123        }
 124    }
 125}
 126
/// A single entry in a thread: a user message, an assistant message, or a tool call.
#[derive(Debug)]
pub enum AgentThreadEntry {
    UserMessage(UserMessage),
    AssistantMessage(AssistantMessage),
    ToolCall(ToolCall),
}
 133
 134impl AgentThreadEntry {
 135    fn to_markdown(&self, cx: &App) -> String {
 136        match self {
 137            Self::UserMessage(message) => message.to_markdown(cx),
 138            Self::AssistantMessage(message) => message.to_markdown(cx),
 139            Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
 140        }
 141    }
 142
 143    pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
 144        if let AgentThreadEntry::ToolCall(call) = self {
 145            itertools::Either::Left(call.diffs())
 146        } else {
 147            itertools::Either::Right(std::iter::empty())
 148        }
 149    }
 150
 151    pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
 152        if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
 153            Some(locations)
 154        } else {
 155            None
 156        }
 157    }
 158}
 159
/// Thread-side state of a tool call reported over ACP.
#[derive(Debug)]
pub struct ToolCall {
    /// Identifier used to match later updates to this call.
    pub id: acp::ToolCallId,
    /// Markdown entity built from the tool call's title.
    pub label: Entity<Markdown>,
    pub kind: acp::ToolKind,
    /// Content produced by the call: content blocks and/or diffs.
    pub content: Vec<ToolCallContent>,
    /// Authorization / execution state of the call.
    pub status: ToolCallStatus,
    pub locations: Vec<acp::ToolCallLocation>,
    pub raw_input: Option<serde_json::Value>,
}
 170
 171impl ToolCall {
 172    fn from_acp(
 173        tool_call: acp::ToolCall,
 174        status: ToolCallStatus,
 175        language_registry: Arc<LanguageRegistry>,
 176        cx: &mut App,
 177    ) -> Self {
 178        Self {
 179            id: tool_call.id,
 180            label: cx.new(|cx| {
 181                Markdown::new(
 182                    tool_call.title.into(),
 183                    Some(language_registry.clone()),
 184                    None,
 185                    cx,
 186                )
 187            }),
 188            kind: tool_call.kind,
 189            content: tool_call
 190                .content
 191                .into_iter()
 192                .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
 193                .collect(),
 194            locations: tool_call.locations,
 195            status,
 196            raw_input: tool_call.raw_input,
 197        }
 198    }
 199
 200    fn update(
 201        &mut self,
 202        fields: acp::ToolCallUpdateFields,
 203        language_registry: Arc<LanguageRegistry>,
 204        cx: &mut App,
 205    ) {
 206        let acp::ToolCallUpdateFields {
 207            kind,
 208            status,
 209            title,
 210            content,
 211            locations,
 212            raw_input,
 213        } = fields;
 214
 215        if let Some(kind) = kind {
 216            self.kind = kind;
 217        }
 218
 219        if let Some(status) = status {
 220            self.status = ToolCallStatus::Allowed { status };
 221        }
 222
 223        if let Some(title) = title {
 224            self.label.update(cx, |label, cx| {
 225                label.replace(title, cx);
 226            });
 227        }
 228
 229        if let Some(content) = content {
 230            self.content = content
 231                .into_iter()
 232                .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
 233                .collect();
 234        }
 235
 236        if let Some(locations) = locations {
 237            self.locations = locations;
 238        }
 239
 240        if let Some(raw_input) = raw_input {
 241            self.raw_input = Some(raw_input);
 242        }
 243    }
 244
 245    pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
 246        self.content.iter().filter_map(|content| match content {
 247            ToolCallContent::ContentBlock { .. } => None,
 248            ToolCallContent::Diff { diff } => Some(diff),
 249        })
 250    }
 251
 252    fn to_markdown(&self, cx: &App) -> String {
 253        let mut markdown = format!(
 254            "**Tool Call: {}**\nStatus: {}\n\n",
 255            self.label.read(cx).source(),
 256            self.status
 257        );
 258        for content in &self.content {
 259            markdown.push_str(content.to_markdown(cx).as_str());
 260            markdown.push_str("\n\n");
 261        }
 262        markdown
 263    }
 264}
 265
/// Authorization and execution state of a tool call.
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The call is waiting for a permission choice.
    WaitingForConfirmation {
        /// Permission options that can be chosen from.
        options: Vec<acp::PermissionOption>,
        /// Channel on which the chosen option id is sent back.
        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
    },
    /// The call is allowed; `status` is the ACP-level progress of the call.
    Allowed {
        status: acp::ToolCallStatus,
    },
    Rejected,
    Canceled,
}
 278
 279impl Display for ToolCallStatus {
 280    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 281        write!(
 282            f,
 283            "{}",
 284            match self {
 285                ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
 286                ToolCallStatus::Allowed { status } => match status {
 287                    acp::ToolCallStatus::Pending => "Pending",
 288                    acp::ToolCallStatus::InProgress => "In Progress",
 289                    acp::ToolCallStatus::Completed => "Completed",
 290                    acp::ToolCallStatus::Failed => "Failed",
 291                },
 292                ToolCallStatus::Rejected => "Rejected",
 293                ToolCallStatus::Canceled => "Canceled",
 294            }
 295        )
 296    }
 297}
 298
/// Accumulated message content, backed by a markdown entity once anything was appended.
#[derive(Debug, PartialEq, Clone)]
pub enum ContentBlock {
    /// Nothing has been appended yet.
    Empty,
    /// Content appended so far, held as markdown source.
    Markdown { markdown: Entity<Markdown> },
}
 304
 305impl ContentBlock {
 306    pub fn new(
 307        block: acp::ContentBlock,
 308        language_registry: &Arc<LanguageRegistry>,
 309        cx: &mut App,
 310    ) -> Self {
 311        let mut this = Self::Empty;
 312        this.append(block, language_registry, cx);
 313        this
 314    }
 315
 316    pub fn new_combined(
 317        blocks: impl IntoIterator<Item = acp::ContentBlock>,
 318        language_registry: Arc<LanguageRegistry>,
 319        cx: &mut App,
 320    ) -> Self {
 321        let mut this = Self::Empty;
 322        for block in blocks {
 323            this.append(block, &language_registry, cx);
 324        }
 325        this
 326    }
 327
 328    pub fn append(
 329        &mut self,
 330        block: acp::ContentBlock,
 331        language_registry: &Arc<LanguageRegistry>,
 332        cx: &mut App,
 333    ) {
 334        let new_content = match block {
 335            acp::ContentBlock::Text(text_content) => text_content.text.clone(),
 336            acp::ContentBlock::ResourceLink(resource_link) => {
 337                if let Some(path) = resource_link.uri.strip_prefix("file://") {
 338                    format!("{}", MentionPath(path.as_ref()))
 339                } else {
 340                    resource_link.uri.clone()
 341                }
 342            }
 343            acp::ContentBlock::Image(_)
 344            | acp::ContentBlock::Audio(_)
 345            | acp::ContentBlock::Resource(_) => String::new(),
 346        };
 347
 348        match self {
 349            ContentBlock::Empty => {
 350                *self = ContentBlock::Markdown {
 351                    markdown: cx.new(|cx| {
 352                        Markdown::new(
 353                            new_content.into(),
 354                            Some(language_registry.clone()),
 355                            None,
 356                            cx,
 357                        )
 358                    }),
 359                };
 360            }
 361            ContentBlock::Markdown { markdown } => {
 362                markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
 363            }
 364        }
 365    }
 366
 367    fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
 368        match self {
 369            ContentBlock::Empty => "",
 370            ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
 371        }
 372    }
 373
 374    pub fn markdown(&self) -> Option<&Entity<Markdown>> {
 375        match self {
 376            ContentBlock::Empty => None,
 377            ContentBlock::Markdown { markdown } => Some(markdown),
 378        }
 379    }
 380}
 381
/// One item of content attached to a tool call.
#[derive(Debug)]
pub enum ToolCallContent {
    /// Textual content.
    ContentBlock { content: ContentBlock },
    /// A file diff.
    Diff { diff: Diff },
}
 387
 388impl ToolCallContent {
 389    pub fn from_acp(
 390        content: acp::ToolCallContent,
 391        language_registry: Arc<LanguageRegistry>,
 392        cx: &mut App,
 393    ) -> Self {
 394        match content {
 395            acp::ToolCallContent::Content { content } => Self::ContentBlock {
 396                content: ContentBlock::new(content, &language_registry, cx),
 397            },
 398            acp::ToolCallContent::Diff { diff } => Self::Diff {
 399                diff: Diff::from_acp(diff, language_registry, cx),
 400            },
 401        }
 402    }
 403
 404    pub fn to_markdown(&self, cx: &App) -> String {
 405        match self {
 406            Self::ContentBlock { content } => content.to_markdown(cx).to_string(),
 407            Self::Diff { diff } => diff.to_markdown(cx),
 408        }
 409    }
 410}
 411
/// A diff attached to a tool call.
#[derive(Debug)]
pub struct Diff {
    /// Read-only multibuffer that the background task fills with the diff's excerpts.
    pub multibuffer: Entity<MultiBuffer>,
    /// Path taken from the ACP diff.
    pub path: PathBuf,
    /// Task that computes the diff and fills `multibuffer`; stored alongside the diff.
    _task: Task<Result<()>>,
}
 418
 419impl Diff {
 420    pub fn from_acp(
 421        diff: acp::Diff,
 422        language_registry: Arc<LanguageRegistry>,
 423        cx: &mut App,
 424    ) -> Self {
 425        let acp::Diff {
 426            path,
 427            old_text,
 428            new_text,
 429        } = diff;
 430
 431        let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
 432
 433        let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
 434        let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
 435        let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
 436        let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
 437
 438        let task = cx.spawn({
 439            let multibuffer = multibuffer.clone();
 440            let path = path.clone();
 441            async move |cx| {
 442                let language = language_registry
 443                    .language_for_file_path(&path)
 444                    .await
 445                    .log_err();
 446
 447                new_buffer.update(cx, |buffer, cx| buffer.set_language(language.clone(), cx))?;
 448
 449                let old_buffer_snapshot = old_buffer.update(cx, |buffer, cx| {
 450                    buffer.set_language(language, cx);
 451                    buffer.snapshot()
 452                })?;
 453
 454                buffer_diff
 455                    .update(cx, |diff, cx| {
 456                        diff.set_base_text(
 457                            old_buffer_snapshot,
 458                            Some(language_registry),
 459                            new_buffer_snapshot,
 460                            cx,
 461                        )
 462                    })?
 463                    .await?;
 464
 465                multibuffer
 466                    .update(cx, |multibuffer, cx| {
 467                        let hunk_ranges = {
 468                            let buffer = new_buffer.read(cx);
 469                            let diff = buffer_diff.read(cx);
 470                            diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
 471                                .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
 472                                .collect::<Vec<_>>()
 473                        };
 474
 475                        multibuffer.set_excerpts_for_path(
 476                            PathKey::for_buffer(&new_buffer, cx),
 477                            new_buffer.clone(),
 478                            hunk_ranges,
 479                            editor::DEFAULT_MULTIBUFFER_CONTEXT,
 480                            cx,
 481                        );
 482                        multibuffer.add_diff(buffer_diff, cx);
 483                    })
 484                    .log_err();
 485
 486                anyhow::Ok(())
 487            }
 488        });
 489
 490        Self {
 491            multibuffer,
 492            path,
 493            _task: task,
 494        }
 495    }
 496
 497    fn to_markdown(&self, cx: &App) -> String {
 498        let buffer_text = self
 499            .multibuffer
 500            .read(cx)
 501            .all_buffers()
 502            .iter()
 503            .map(|buffer| buffer.read(cx).text())
 504            .join("\n");
 505        format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text)
 506    }
 507}
 508
/// A plan: an ordered list of plan entries. The default plan has no entries.
#[derive(Debug, Default)]
pub struct Plan {
    pub entries: Vec<PlanEntry>,
}
 513
/// Summary of a [`Plan`], as computed by [`Plan::stats`].
#[derive(Debug)]
pub struct PlanStats<'a> {
    /// The first entry, in list order, whose status is in progress.
    pub in_progress_entry: Option<&'a PlanEntry>,
    /// Number of entries whose status is pending.
    pub pending: u32,
    /// Number of entries whose status is completed.
    pub completed: u32,
}
 520
 521impl Plan {
 522    pub fn is_empty(&self) -> bool {
 523        self.entries.is_empty()
 524    }
 525
 526    pub fn stats(&self) -> PlanStats<'_> {
 527        let mut stats = PlanStats {
 528            in_progress_entry: None,
 529            pending: 0,
 530            completed: 0,
 531        };
 532
 533        for entry in &self.entries {
 534            match &entry.status {
 535                acp::PlanEntryStatus::Pending => {
 536                    stats.pending += 1;
 537                }
 538                acp::PlanEntryStatus::InProgress => {
 539                    stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
 540                }
 541                acp::PlanEntryStatus::Completed => {
 542                    stats.completed += 1;
 543                }
 544            }
 545        }
 546
 547        stats
 548    }
 549}
 550
/// A single entry of a [`Plan`].
#[derive(Debug)]
pub struct PlanEntry {
    /// Markdown entity holding the entry's text.
    pub content: Entity<Markdown>,
    pub priority: acp::PlanEntryPriority,
    pub status: acp::PlanEntryStatus,
}
 557
 558impl PlanEntry {
 559    pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
 560        Self {
 561            content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
 562            priority: entry.priority,
 563            status: entry.status,
 564        }
 565    }
 566}
 567
/// State of one agent thread: its entries, plan, and the connection/session it belongs to.
pub struct AcpThread {
    title: SharedString,
    /// User messages, assistant messages and tool calls, in order.
    entries: Vec<AgentThreadEntry>,
    plan: Plan,
    project: Entity<Project>,
    /// Action log created for `project` when the thread is constructed.
    action_log: Entity<ActionLog>,
    // NOTE(review): usage is not visible in this part of the file — presumably snapshots of
    // buffers shared with the agent; confirm against the rest of the impl.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// `Some` while a send is in flight; `status()` reports `Idle` when this is `None`.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
}
 579
/// Events emitted by [`AcpThread`].
pub enum AcpThreadEvent {
    /// A new entry was pushed onto the thread.
    NewEntry,
    /// The entry at the given index was modified.
    EntryUpdated(usize),
    /// A tool call is waiting for a permission decision.
    ToolAuthorizationRequired,
    Stopped,
    Error,
    ServerExited(ExitStatus),
}
 588
// Allows `AcpThread` to emit `AcpThreadEvent`s through its context.
impl EventEmitter<AcpThreadEvent> for AcpThread {}
 590
/// Coarse activity state of a thread, as reported by `AcpThread::status`.
#[derive(PartialEq, Eq)]
pub enum ThreadStatus {
    /// No send task is in flight.
    Idle,
    /// A send is in flight and the current turn has a tool call awaiting confirmation.
    WaitingForToolConfirmation,
    /// A send is in flight and nothing is awaiting confirmation.
    Generating,
}
 597
/// Error describing why loading failed.
#[derive(Debug, Clone)]
pub enum LoadError {
    /// Only `error_message` is shown by the `Display` impl; the upgrade fields are
    /// carried along for callers.
    Unsupported {
        error_message: SharedString,
        upgrade_message: SharedString,
        upgrade_command: String,
    },
    /// Displayed as "Server exited with status N", where N is the carried value.
    Exited(i32),
    /// Any other failure, carrying the message to display.
    Other(SharedString),
}
 608
 609impl Display for LoadError {
 610    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 611        match self {
 612            LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
 613            LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
 614            LoadError::Other(msg) => write!(f, "{}", msg),
 615        }
 616    }
 617}
 618
// Marker impl: `LoadError` relies on its `Debug` and `Display` impls and reports no source.
impl Error for LoadError {}
 620
 621impl AcpThread {
 622    pub fn new(
 623        title: impl Into<SharedString>,
 624        connection: Rc<dyn AgentConnection>,
 625        project: Entity<Project>,
 626        session_id: acp::SessionId,
 627        cx: &mut Context<Self>,
 628    ) -> Self {
 629        let action_log = cx.new(|_| ActionLog::new(project.clone()));
 630
 631        Self {
 632            action_log,
 633            shared_buffers: Default::default(),
 634            entries: Default::default(),
 635            plan: Default::default(),
 636            title: title.into(),
 637            project,
 638            send_task: None,
 639            connection,
 640            session_id,
 641        }
 642    }
 643
    /// The action log created for this thread's project.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
 647
    /// The project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
 651
    /// The thread's title, returned as a clone.
    pub fn title(&self) -> SharedString {
        self.title.clone()
    }
 655
    /// All entries of the thread, in order.
    pub fn entries(&self) -> &[AgentThreadEntry] {
        &self.entries
    }
 659
    /// The ACP session id this thread was created with.
    pub fn session_id(&self) -> &acp::SessionId {
        &self.session_id
    }
 663
 664    pub fn status(&self) -> ThreadStatus {
 665        if self.send_task.is_some() {
 666            if self.waiting_for_tool_confirmation() {
 667                ThreadStatus::WaitingForToolConfirmation
 668            } else {
 669                ThreadStatus::Generating
 670            }
 671        } else {
 672            ThreadStatus::Idle
 673        }
 674    }
 675
 676    pub fn has_pending_edit_tool_calls(&self) -> bool {
 677        for entry in self.entries.iter().rev() {
 678            match entry {
 679                AgentThreadEntry::UserMessage(_) => return false,
 680                AgentThreadEntry::ToolCall(
 681                    call @ ToolCall {
 682                        status:
 683                            ToolCallStatus::Allowed {
 684                                status:
 685                                    acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending,
 686                            },
 687                        ..
 688                    },
 689                ) if call.diffs().next().is_some() => {
 690                    return true;
 691                }
 692                AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
 693            }
 694        }
 695
 696        false
 697    }
 698
 699    pub fn used_tools_since_last_user_message(&self) -> bool {
 700        for entry in self.entries.iter().rev() {
 701            match entry {
 702                AgentThreadEntry::UserMessage(..) => return false,
 703                AgentThreadEntry::AssistantMessage(..) => continue,
 704                AgentThreadEntry::ToolCall(..) => return true,
 705            }
 706        }
 707
 708        false
 709    }
 710
 711    pub fn handle_session_update(
 712        &mut self,
 713        update: acp::SessionUpdate,
 714        cx: &mut Context<Self>,
 715    ) -> Result<()> {
 716        match update {
 717            acp::SessionUpdate::UserMessageChunk { content } => {
 718                self.push_user_content_block(content, cx);
 719            }
 720            acp::SessionUpdate::AgentMessageChunk { content } => {
 721                self.push_assistant_content_block(content, false, cx);
 722            }
 723            acp::SessionUpdate::AgentThoughtChunk { content } => {
 724                self.push_assistant_content_block(content, true, cx);
 725            }
 726            acp::SessionUpdate::ToolCall(tool_call) => {
 727                self.upsert_tool_call(tool_call, cx);
 728            }
 729            acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
 730                self.update_tool_call(tool_call_update, cx)?;
 731            }
 732            acp::SessionUpdate::Plan(plan) => {
 733                self.update_plan(plan, cx);
 734            }
 735        }
 736        Ok(())
 737    }
 738
 739    pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
 740        let language_registry = self.project.read(cx).languages().clone();
 741        let entries_len = self.entries.len();
 742
 743        if let Some(last_entry) = self.entries.last_mut()
 744            && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
 745        {
 746            content.append(chunk, &language_registry, cx);
 747            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
 748        } else {
 749            let content = ContentBlock::new(chunk, &language_registry, cx);
 750            self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
 751        }
 752    }
 753
 754    pub fn push_assistant_content_block(
 755        &mut self,
 756        chunk: acp::ContentBlock,
 757        is_thought: bool,
 758        cx: &mut Context<Self>,
 759    ) {
 760        let language_registry = self.project.read(cx).languages().clone();
 761        let entries_len = self.entries.len();
 762        if let Some(last_entry) = self.entries.last_mut()
 763            && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
 764        {
 765            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
 766            match (chunks.last_mut(), is_thought) {
 767                (Some(AssistantMessageChunk::Message { block }), false)
 768                | (Some(AssistantMessageChunk::Thought { block }), true) => {
 769                    block.append(chunk, &language_registry, cx)
 770                }
 771                _ => {
 772                    let block = ContentBlock::new(chunk, &language_registry, cx);
 773                    if is_thought {
 774                        chunks.push(AssistantMessageChunk::Thought { block })
 775                    } else {
 776                        chunks.push(AssistantMessageChunk::Message { block })
 777                    }
 778                }
 779            }
 780        } else {
 781            let block = ContentBlock::new(chunk, &language_registry, cx);
 782            let chunk = if is_thought {
 783                AssistantMessageChunk::Thought { block }
 784            } else {
 785                AssistantMessageChunk::Message { block }
 786            };
 787
 788            self.push_entry(
 789                AgentThreadEntry::AssistantMessage(AssistantMessage {
 790                    chunks: vec![chunk],
 791                }),
 792                cx,
 793            );
 794        }
 795    }
 796
    /// Appends `entry` to the thread and emits `AcpThreadEvent::NewEntry`.
    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
        self.entries.push(entry);
        cx.emit(AcpThreadEvent::NewEntry);
    }
 801
 802    pub fn update_tool_call(
 803        &mut self,
 804        update: acp::ToolCallUpdate,
 805        cx: &mut Context<Self>,
 806    ) -> Result<()> {
 807        let languages = self.project.read(cx).languages().clone();
 808
 809        let (ix, current_call) = self
 810            .tool_call_mut(&update.id)
 811            .context("Tool call not found")?;
 812        current_call.update(update.fields, languages, cx);
 813
 814        cx.emit(AcpThreadEvent::EntryUpdated(ix));
 815
 816        Ok(())
 817    }
 818
 819    /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
 820    pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
 821        let status = ToolCallStatus::Allowed {
 822            status: tool_call.status,
 823        };
 824        self.upsert_tool_call_inner(tool_call, status, cx)
 825    }
 826
 827    pub fn upsert_tool_call_inner(
 828        &mut self,
 829        tool_call: acp::ToolCall,
 830        status: ToolCallStatus,
 831        cx: &mut Context<Self>,
 832    ) {
 833        let language_registry = self.project.read(cx).languages().clone();
 834        let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
 835
 836        let location = call.locations.last().cloned();
 837
 838        if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
 839            *current_call = call;
 840
 841            cx.emit(AcpThreadEvent::EntryUpdated(ix));
 842        } else {
 843            self.push_entry(AgentThreadEntry::ToolCall(call), cx);
 844        }
 845
 846        if let Some(location) = location {
 847            self.set_project_location(location, cx)
 848        }
 849    }
 850
 851    fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
 852        // The tool call we are looking for is typically the last one, or very close to the end.
 853        // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
 854        self.entries
 855            .iter_mut()
 856            .enumerate()
 857            .rev()
 858            .find_map(|(index, tool_call)| {
 859                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
 860                    && &tool_call.id == id
 861                {
 862                    Some((index, tool_call))
 863                } else {
 864                    None
 865                }
 866            })
 867    }
 868
 869    pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context<Self>) {
 870        self.project.update(cx, |project, cx| {
 871            let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
 872                return;
 873            };
 874            let buffer = project.open_buffer(path, cx);
 875            cx.spawn(async move |project, cx| {
 876                let buffer = buffer.await?;
 877
 878                project.update(cx, |project, cx| {
 879                    let position = if let Some(line) = location.line {
 880                        let snapshot = buffer.read(cx).snapshot();
 881                        let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
 882                        snapshot.anchor_before(point)
 883                    } else {
 884                        Anchor::MIN
 885                    };
 886
 887                    project.set_agent_location(
 888                        Some(AgentLocation {
 889                            buffer: buffer.downgrade(),
 890                            position,
 891                        }),
 892                        cx,
 893                    );
 894                })
 895            })
 896            .detach_and_log_err(cx);
 897        });
 898    }
 899
 900    pub fn request_tool_call_permission(
 901        &mut self,
 902        tool_call: acp::ToolCall,
 903        options: Vec<acp::PermissionOption>,
 904        cx: &mut Context<Self>,
 905    ) -> oneshot::Receiver<acp::PermissionOptionId> {
 906        let (tx, rx) = oneshot::channel();
 907
 908        let status = ToolCallStatus::WaitingForConfirmation {
 909            options,
 910            respond_tx: tx,
 911        };
 912
 913        self.upsert_tool_call_inner(tool_call, status, cx);
 914        cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
 915        rx
 916    }
 917
 918    pub fn authorize_tool_call(
 919        &mut self,
 920        id: acp::ToolCallId,
 921        option_id: acp::PermissionOptionId,
 922        option_kind: acp::PermissionOptionKind,
 923        cx: &mut Context<Self>,
 924    ) {
 925        let Some((ix, call)) = self.tool_call_mut(&id) else {
 926            return;
 927        };
 928
 929        let new_status = match option_kind {
 930            acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
 931                ToolCallStatus::Rejected
 932            }
 933            acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
 934                ToolCallStatus::Allowed {
 935                    status: acp::ToolCallStatus::InProgress,
 936                }
 937            }
 938        };
 939
 940        let curr_status = mem::replace(&mut call.status, new_status);
 941
 942        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
 943            respond_tx.send(option_id).log_err();
 944        } else if cfg!(debug_assertions) {
 945            panic!("tried to authorize an already authorized tool call");
 946        }
 947
 948        cx.emit(AcpThreadEvent::EntryUpdated(ix));
 949    }
 950
 951    /// Returns true if the last turn is awaiting tool authorization
 952    pub fn waiting_for_tool_confirmation(&self) -> bool {
 953        for entry in self.entries.iter().rev() {
 954            match &entry {
 955                AgentThreadEntry::ToolCall(call) => match call.status {
 956                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
 957                    ToolCallStatus::Allowed { .. }
 958                    | ToolCallStatus::Rejected
 959                    | ToolCallStatus::Canceled => continue,
 960                },
 961                AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
 962                    // Reached the beginning of the turn
 963                    return false;
 964                }
 965            }
 966        }
 967        false
 968    }
 969
    /// Returns a reference to this thread's current plan.
    pub fn plan(&self) -> &Plan {
        &self.plan
    }
 973
 974    pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
 975        let new_entries_len = request.entries.len();
 976        let mut new_entries = request.entries.into_iter();
 977
 978        // Reuse existing markdown to prevent flickering
 979        for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
 980            let PlanEntry {
 981                content,
 982                priority,
 983                status,
 984            } = old;
 985            content.update(cx, |old, cx| {
 986                old.replace(new.content, cx);
 987            });
 988            *priority = new.priority;
 989            *status = new.status;
 990        }
 991        for new in new_entries {
 992            self.plan.entries.push(PlanEntry::from_acp(new, cx))
 993        }
 994        self.plan.entries.truncate(new_entries_len);
 995
 996        cx.notify();
 997    }
 998
 999    fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1000        self.plan
1001            .entries
1002            .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1003        cx.notify();
1004    }
1005
1006    #[cfg(any(test, feature = "test-support"))]
1007    pub fn send_raw(
1008        &mut self,
1009        message: &str,
1010        cx: &mut Context<Self>,
1011    ) -> BoxFuture<'static, Result<()>> {
1012        self.send(
1013            vec![acp::ContentBlock::Text(acp::TextContent {
1014                text: message.to_string(),
1015                annotations: None,
1016            })],
1017            cx,
1018        )
1019    }
1020
1021    pub fn send(
1022        &mut self,
1023        message: Vec<acp::ContentBlock>,
1024        cx: &mut Context<Self>,
1025    ) -> BoxFuture<'static, Result<()>> {
1026        let block = ContentBlock::new_combined(
1027            message.clone(),
1028            self.project.read(cx).languages().clone(),
1029            cx,
1030        );
1031        self.push_entry(
1032            AgentThreadEntry::UserMessage(UserMessage { content: block }),
1033            cx,
1034        );
1035        self.clear_completed_plan_entries(cx);
1036
1037        let (tx, rx) = oneshot::channel();
1038        let cancel_task = self.cancel(cx);
1039
1040        self.send_task = Some(cx.spawn(async move |this, cx| {
1041            async {
1042                cancel_task.await;
1043
1044                let result = this
1045                    .update(cx, |this, cx| {
1046                        this.connection.prompt(
1047                            acp::PromptRequest {
1048                                prompt: message,
1049                                session_id: this.session_id.clone(),
1050                            },
1051                            cx,
1052                        )
1053                    })?
1054                    .await;
1055                tx.send(result).log_err();
1056                this.update(cx, |this, _cx| this.send_task.take())?;
1057                anyhow::Ok(())
1058            }
1059            .await
1060            .log_err();
1061        }));
1062
1063        cx.spawn(async move |this, cx| match rx.await {
1064            Ok(Err(e)) => {
1065                this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error))
1066                    .log_err();
1067                Err(e)?
1068            }
1069            _ => {
1070                this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
1071                    .log_err();
1072                Ok(())
1073            }
1074        })
1075        .boxed()
1076    }
1077
1078    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1079        let Some(send_task) = self.send_task.take() else {
1080            return Task::ready(());
1081        };
1082
1083        for entry in self.entries.iter_mut() {
1084            if let AgentThreadEntry::ToolCall(call) = entry {
1085                let cancel = matches!(
1086                    call.status,
1087                    ToolCallStatus::WaitingForConfirmation { .. }
1088                        | ToolCallStatus::Allowed {
1089                            status: acp::ToolCallStatus::InProgress
1090                        }
1091                );
1092
1093                if cancel {
1094                    call.status = ToolCallStatus::Canceled;
1095                }
1096            }
1097        }
1098
1099        self.connection.cancel(&self.session_id, cx);
1100
1101        // Wait for the send task to complete
1102        cx.foreground_executor().spawn(send_task)
1103    }
1104
1105    pub fn read_text_file(
1106        &self,
1107        path: PathBuf,
1108        line: Option<u32>,
1109        limit: Option<u32>,
1110        reuse_shared_snapshot: bool,
1111        cx: &mut Context<Self>,
1112    ) -> Task<Result<String>> {
1113        let project = self.project.clone();
1114        let action_log = self.action_log.clone();
1115        cx.spawn(async move |this, cx| {
1116            let load = project.update(cx, |project, cx| {
1117                let path = project
1118                    .project_path_for_absolute_path(&path, cx)
1119                    .context("invalid path")?;
1120                anyhow::Ok(project.open_buffer(path, cx))
1121            });
1122            let buffer = load??.await?;
1123
1124            let snapshot = if reuse_shared_snapshot {
1125                this.read_with(cx, |this, _| {
1126                    this.shared_buffers.get(&buffer.clone()).cloned()
1127                })
1128                .log_err()
1129                .flatten()
1130            } else {
1131                None
1132            };
1133
1134            let snapshot = if let Some(snapshot) = snapshot {
1135                snapshot
1136            } else {
1137                action_log.update(cx, |action_log, cx| {
1138                    action_log.buffer_read(buffer.clone(), cx);
1139                })?;
1140                project.update(cx, |project, cx| {
1141                    let position = buffer
1142                        .read(cx)
1143                        .snapshot()
1144                        .anchor_before(Point::new(line.unwrap_or_default(), 0));
1145                    project.set_agent_location(
1146                        Some(AgentLocation {
1147                            buffer: buffer.downgrade(),
1148                            position,
1149                        }),
1150                        cx,
1151                    );
1152                })?;
1153
1154                buffer.update(cx, |buffer, _| buffer.snapshot())?
1155            };
1156
1157            this.update(cx, |this, _| {
1158                let text = snapshot.text();
1159                this.shared_buffers.insert(buffer.clone(), snapshot);
1160                if line.is_none() && limit.is_none() {
1161                    return Ok(text);
1162                }
1163                let limit = limit.unwrap_or(u32::MAX) as usize;
1164                let Some(line) = line else {
1165                    return Ok(text.lines().take(limit).collect::<String>());
1166                };
1167
1168                let count = text.lines().count();
1169                if count < line as usize {
1170                    anyhow::bail!("There are only {} lines", count);
1171                }
1172                Ok(text
1173                    .lines()
1174                    .skip(line as usize + 1)
1175                    .take(limit)
1176                    .collect::<String>())
1177            })?
1178        })
1179    }
1180
1181    pub fn write_text_file(
1182        &self,
1183        path: PathBuf,
1184        content: String,
1185        cx: &mut Context<Self>,
1186    ) -> Task<Result<()>> {
1187        let project = self.project.clone();
1188        let action_log = self.action_log.clone();
1189        cx.spawn(async move |this, cx| {
1190            let load = project.update(cx, |project, cx| {
1191                let path = project
1192                    .project_path_for_absolute_path(&path, cx)
1193                    .context("invalid path")?;
1194                anyhow::Ok(project.open_buffer(path, cx))
1195            });
1196            let buffer = load??.await?;
1197            let snapshot = this.update(cx, |this, cx| {
1198                this.shared_buffers
1199                    .get(&buffer)
1200                    .cloned()
1201                    .unwrap_or_else(|| buffer.read(cx).snapshot())
1202            })?;
1203            let edits = cx
1204                .background_executor()
1205                .spawn(async move {
1206                    let old_text = snapshot.text();
1207                    text_diff(old_text.as_str(), &content)
1208                        .into_iter()
1209                        .map(|(range, replacement)| {
1210                            (
1211                                snapshot.anchor_after(range.start)
1212                                    ..snapshot.anchor_before(range.end),
1213                                replacement,
1214                            )
1215                        })
1216                        .collect::<Vec<_>>()
1217                })
1218                .await;
1219            cx.update(|cx| {
1220                project.update(cx, |project, cx| {
1221                    project.set_agent_location(
1222                        Some(AgentLocation {
1223                            buffer: buffer.downgrade(),
1224                            position: edits
1225                                .last()
1226                                .map(|(range, _)| range.end)
1227                                .unwrap_or(Anchor::MIN),
1228                        }),
1229                        cx,
1230                    );
1231                });
1232
1233                action_log.update(cx, |action_log, cx| {
1234                    action_log.buffer_read(buffer.clone(), cx);
1235                });
1236                buffer.update(cx, |buffer, cx| {
1237                    buffer.edit(edits, None, cx);
1238                });
1239                action_log.update(cx, |action_log, cx| {
1240                    action_log.buffer_edited(buffer.clone(), cx);
1241                });
1242            })?;
1243            project
1244                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1245                .await
1246        })
1247    }
1248
1249    pub fn to_markdown(&self, cx: &App) -> String {
1250        self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1251    }
1252
    /// Emits an `AcpThreadEvent::ServerExited` event carrying `status`.
    pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
        cx.emit(AcpThreadEvent::ServerExited(status));
    }
1256}
1257
1258#[cfg(test)]
1259mod tests {
1260    use super::*;
1261    use anyhow::anyhow;
1262    use futures::{channel::mpsc, future::LocalBoxFuture, select};
1263    use gpui::{AsyncApp, TestAppContext, WeakEntity};
1264    use indoc::indoc;
1265    use project::FakeFs;
1266    use rand::Rng as _;
1267    use serde_json::json;
1268    use settings::SettingsStore;
1269    use smol::stream::StreamExt as _;
1270    use std::{cell::RefCell, rc::Rc, time::Duration};
1271
1272    use util::path;
1273
1274    fn init_test(cx: &mut TestAppContext) {
1275        env_logger::try_init().ok();
1276        cx.update(|cx| {
1277            let settings_store = SettingsStore::test(cx);
1278            cx.set_global(settings_store);
1279            Project::init_settings(cx);
1280            language::init(cx);
1281        });
1282    }
1283
1284    #[gpui::test]
1285    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1286        init_test(cx);
1287
1288        let fs = FakeFs::new(cx.executor());
1289        let project = Project::test(fs, [], cx).await;
1290        let connection = Rc::new(FakeAgentConnection::new());
1291        let thread = cx
1292            .spawn(async move |mut cx| {
1293                connection
1294                    .new_thread(project, Path::new(path!("/test")), &mut cx)
1295                    .await
1296            })
1297            .await
1298            .unwrap();
1299
1300        // Test creating a new user message
1301        thread.update(cx, |thread, cx| {
1302            thread.push_user_content_block(
1303                acp::ContentBlock::Text(acp::TextContent {
1304                    annotations: None,
1305                    text: "Hello, ".to_string(),
1306                }),
1307                cx,
1308            );
1309        });
1310
1311        thread.update(cx, |thread, cx| {
1312            assert_eq!(thread.entries.len(), 1);
1313            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1314                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1315            } else {
1316                panic!("Expected UserMessage");
1317            }
1318        });
1319
1320        // Test appending to existing user message
1321        thread.update(cx, |thread, cx| {
1322            thread.push_user_content_block(
1323                acp::ContentBlock::Text(acp::TextContent {
1324                    annotations: None,
1325                    text: "world!".to_string(),
1326                }),
1327                cx,
1328            );
1329        });
1330
1331        thread.update(cx, |thread, cx| {
1332            assert_eq!(thread.entries.len(), 1);
1333            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1334                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1335            } else {
1336                panic!("Expected UserMessage");
1337            }
1338        });
1339
1340        // Test creating new user message after assistant message
1341        thread.update(cx, |thread, cx| {
1342            thread.push_assistant_content_block(
1343                acp::ContentBlock::Text(acp::TextContent {
1344                    annotations: None,
1345                    text: "Assistant response".to_string(),
1346                }),
1347                false,
1348                cx,
1349            );
1350        });
1351
1352        thread.update(cx, |thread, cx| {
1353            thread.push_user_content_block(
1354                acp::ContentBlock::Text(acp::TextContent {
1355                    annotations: None,
1356                    text: "New user message".to_string(),
1357                }),
1358                cx,
1359            );
1360        });
1361
1362        thread.update(cx, |thread, cx| {
1363            assert_eq!(thread.entries.len(), 3);
1364            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1365                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1366            } else {
1367                panic!("Expected UserMessage at index 2");
1368            }
1369        });
1370    }
1371
1372    #[gpui::test]
1373    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
1374        init_test(cx);
1375
1376        let fs = FakeFs::new(cx.executor());
1377        let project = Project::test(fs, [], cx).await;
1378        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1379            |_, thread, mut cx| {
1380                async move {
1381                    thread.update(&mut cx, |thread, cx| {
1382                        thread
1383                            .handle_session_update(
1384                                acp::SessionUpdate::AgentThoughtChunk {
1385                                    content: "Thinking ".into(),
1386                                },
1387                                cx,
1388                            )
1389                            .unwrap();
1390                        thread
1391                            .handle_session_update(
1392                                acp::SessionUpdate::AgentThoughtChunk {
1393                                    content: "hard!".into(),
1394                                },
1395                                cx,
1396                            )
1397                            .unwrap();
1398                    })?;
1399                    Ok(acp::PromptResponse {
1400                        stop_reason: acp::StopReason::EndTurn,
1401                    })
1402                }
1403                .boxed_local()
1404            },
1405        ));
1406
1407        let thread = cx
1408            .spawn(async move |mut cx| {
1409                connection
1410                    .new_thread(project, Path::new(path!("/test")), &mut cx)
1411                    .await
1412            })
1413            .await
1414            .unwrap();
1415
1416        thread
1417            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
1418            .await
1419            .unwrap();
1420
1421        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
1422        assert_eq!(
1423            output,
1424            indoc! {r#"
1425            ## User
1426
1427            Hello from Zed!
1428
1429            ## Assistant
1430
1431            <thinking>
1432            Thinking hard!
1433            </thinking>
1434
1435            "#}
1436        );
1437    }
1438
1439    #[gpui::test]
1440    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
1441        init_test(cx);
1442
1443        let fs = FakeFs::new(cx.executor());
1444        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
1445            .await;
1446        let project = Project::test(fs.clone(), [], cx).await;
1447        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
1448        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
1449        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1450            move |_, thread, mut cx| {
1451                let read_file_tx = read_file_tx.clone();
1452                async move {
1453                    let content = thread
1454                        .update(&mut cx, |thread, cx| {
1455                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
1456                        })
1457                        .unwrap()
1458                        .await
1459                        .unwrap();
1460                    assert_eq!(content, "one\ntwo\nthree\n");
1461                    read_file_tx.take().unwrap().send(()).unwrap();
1462                    thread
1463                        .update(&mut cx, |thread, cx| {
1464                            thread.write_text_file(
1465                                path!("/tmp/foo").into(),
1466                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
1467                                cx,
1468                            )
1469                        })
1470                        .unwrap()
1471                        .await
1472                        .unwrap();
1473                    Ok(acp::PromptResponse {
1474                        stop_reason: acp::StopReason::EndTurn,
1475                    })
1476                }
1477                .boxed_local()
1478            },
1479        ));
1480
1481        let (worktree, pathbuf) = project
1482            .update(cx, |project, cx| {
1483                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
1484            })
1485            .await
1486            .unwrap();
1487        let buffer = project
1488            .update(cx, |project, cx| {
1489                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
1490            })
1491            .await
1492            .unwrap();
1493
1494        let thread = cx
1495            .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
1496            .await
1497            .unwrap();
1498
1499        let request = thread.update(cx, |thread, cx| {
1500            thread.send_raw("Extend the count in /tmp/foo", cx)
1501        });
1502        read_file_rx.await.ok();
1503        buffer.update(cx, |buffer, cx| {
1504            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
1505        });
1506        cx.run_until_parked();
1507        assert_eq!(
1508            buffer.read_with(cx, |buffer, _| buffer.text()),
1509            "zero\none\ntwo\nthree\nfour\nfive\n"
1510        );
1511        assert_eq!(
1512            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
1513            "zero\none\ntwo\nthree\nfour\nfive\n"
1514        );
1515        request.await.unwrap();
1516    }
1517
1518    #[gpui::test]
1519    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
1520        init_test(cx);
1521
1522        let fs = FakeFs::new(cx.executor());
1523        let project = Project::test(fs, [], cx).await;
1524        let id = acp::ToolCallId("test".into());
1525
1526        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
1527            let id = id.clone();
1528            move |_, thread, mut cx| {
1529                let id = id.clone();
1530                async move {
1531                    thread
1532                        .update(&mut cx, |thread, cx| {
1533                            thread.handle_session_update(
1534                                acp::SessionUpdate::ToolCall(acp::ToolCall {
1535                                    id: id.clone(),
1536                                    title: "Label".into(),
1537                                    kind: acp::ToolKind::Fetch,
1538                                    status: acp::ToolCallStatus::InProgress,
1539                                    content: vec![],
1540                                    locations: vec![],
1541                                    raw_input: None,
1542                                }),
1543                                cx,
1544                            )
1545                        })
1546                        .unwrap()
1547                        .unwrap();
1548                    Ok(acp::PromptResponse {
1549                        stop_reason: acp::StopReason::EndTurn,
1550                    })
1551                }
1552                .boxed_local()
1553            }
1554        }));
1555
1556        let thread = cx
1557            .spawn(async move |mut cx| {
1558                connection
1559                    .new_thread(project, Path::new(path!("/test")), &mut cx)
1560                    .await
1561            })
1562            .await
1563            .unwrap();
1564
1565        let request = thread.update(cx, |thread, cx| {
1566            thread.send_raw("Fetch https://example.com", cx)
1567        });
1568
1569        run_until_first_tool_call(&thread, cx).await;
1570
1571        thread.read_with(cx, |thread, _| {
1572            assert!(matches!(
1573                thread.entries[1],
1574                AgentThreadEntry::ToolCall(ToolCall {
1575                    status: ToolCallStatus::Allowed {
1576                        status: acp::ToolCallStatus::InProgress,
1577                        ..
1578                    },
1579                    ..
1580                })
1581            ));
1582        });
1583
1584        thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1585
1586        thread.read_with(cx, |thread, _| {
1587            assert!(matches!(
1588                &thread.entries[1],
1589                AgentThreadEntry::ToolCall(ToolCall {
1590                    status: ToolCallStatus::Canceled,
1591                    ..
1592                })
1593            ));
1594        });
1595
1596        thread
1597            .update(cx, |thread, cx| {
1598                thread.handle_session_update(
1599                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
1600                        id,
1601                        fields: acp::ToolCallUpdateFields {
1602                            status: Some(acp::ToolCallStatus::Completed),
1603                            ..Default::default()
1604                        },
1605                    }),
1606                    cx,
1607                )
1608            })
1609            .unwrap();
1610
1611        request.await.unwrap();
1612
1613        thread.read_with(cx, |thread, _| {
1614            assert!(matches!(
1615                thread.entries[1],
1616                AgentThreadEntry::ToolCall(ToolCall {
1617                    status: ToolCallStatus::Allowed {
1618                        status: acp::ToolCallStatus::Completed,
1619                        ..
1620                    },
1621                    ..
1622                })
1623            ));
1624        });
1625    }
1626
1627    #[gpui::test]
1628    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
1629        init_test(cx);
1630        let fs = FakeFs::new(cx.background_executor.clone());
1631        fs.insert_tree(path!("/test"), json!({})).await;
1632        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
1633
1634        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
1635            move |_, thread, mut cx| {
1636                async move {
1637                    thread
1638                        .update(&mut cx, |thread, cx| {
1639                            thread.handle_session_update(
1640                                acp::SessionUpdate::ToolCall(acp::ToolCall {
1641                                    id: acp::ToolCallId("test".into()),
1642                                    title: "Label".into(),
1643                                    kind: acp::ToolKind::Edit,
1644                                    status: acp::ToolCallStatus::Completed,
1645                                    content: vec![acp::ToolCallContent::Diff {
1646                                        diff: acp::Diff {
1647                                            path: "/test/test.txt".into(),
1648                                            old_text: None,
1649                                            new_text: "foo".into(),
1650                                        },
1651                                    }],
1652                                    locations: vec![],
1653                                    raw_input: None,
1654                                }),
1655                                cx,
1656                            )
1657                        })
1658                        .unwrap()
1659                        .unwrap();
1660                    Ok(acp::PromptResponse {
1661                        stop_reason: acp::StopReason::EndTurn,
1662                    })
1663                }
1664                .boxed_local()
1665            }
1666        }));
1667
1668        let thread = connection
1669            .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
1670            .await
1671            .unwrap();
1672        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
1673            .await
1674            .unwrap();
1675
1676        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
1677    }
1678
1679    async fn run_until_first_tool_call(
1680        thread: &Entity<AcpThread>,
1681        cx: &mut TestAppContext,
1682    ) -> usize {
1683        let (mut tx, mut rx) = mpsc::channel::<usize>(1);
1684
1685        let subscription = cx.update(|cx| {
1686            cx.subscribe(thread, move |thread, _, cx| {
1687                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
1688                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
1689                        return tx.try_send(ix).unwrap();
1690                    }
1691                }
1692            })
1693        });
1694
1695        select! {
1696            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
1697                panic!("Timeout waiting for tool call")
1698            }
1699            ix = rx.next().fuse() => {
1700                drop(subscription);
1701                ix.unwrap()
1702            }
1703        }
1704    }
1705
1706    #[derive(Clone, Default)]
1707    struct FakeAgentConnection {
1708        auth_methods: Vec<acp::AuthMethod>,
1709        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
1710        on_user_message: Option<
1711            Rc<
1712                dyn Fn(
1713                        acp::PromptRequest,
1714                        WeakEntity<AcpThread>,
1715                        AsyncApp,
1716                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1717                    + 'static,
1718            >,
1719        >,
1720    }
1721
1722    impl FakeAgentConnection {
1723        fn new() -> Self {
1724            Self {
1725                auth_methods: Vec::new(),
1726                on_user_message: None,
1727                sessions: Arc::default(),
1728            }
1729        }
1730
1731        #[expect(unused)]
1732        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
1733            self.auth_methods = auth_methods;
1734            self
1735        }
1736
1737        fn on_user_message(
1738            mut self,
1739            handler: impl Fn(
1740                acp::PromptRequest,
1741                WeakEntity<AcpThread>,
1742                AsyncApp,
1743            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1744            + 'static,
1745        ) -> Self {
1746            self.on_user_message.replace(Rc::new(handler));
1747            self
1748        }
1749    }
1750
1751    impl AgentConnection for FakeAgentConnection {
1752        fn auth_methods(&self) -> &[acp::AuthMethod] {
1753            &self.auth_methods
1754        }
1755
1756        fn new_thread(
1757            self: Rc<Self>,
1758            project: Entity<Project>,
1759            _cwd: &Path,
1760            cx: &mut gpui::AsyncApp,
1761        ) -> Task<gpui::Result<Entity<AcpThread>>> {
1762            let session_id = acp::SessionId(
1763                rand::thread_rng()
1764                    .sample_iter(&rand::distributions::Alphanumeric)
1765                    .take(7)
1766                    .map(char::from)
1767                    .collect::<String>()
1768                    .into(),
1769            );
1770            let thread = cx
1771                .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
1772                .unwrap();
1773            self.sessions.lock().insert(session_id, thread.downgrade());
1774            Task::ready(Ok(thread))
1775        }
1776
1777        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
1778            if self.auth_methods().iter().any(|m| m.id == method) {
1779                Task::ready(Ok(()))
1780            } else {
1781                Task::ready(Err(anyhow!("Invalid Auth Method")))
1782            }
1783        }
1784
1785        fn prompt(
1786            &self,
1787            params: acp::PromptRequest,
1788            cx: &mut App,
1789        ) -> Task<gpui::Result<acp::PromptResponse>> {
1790            let sessions = self.sessions.lock();
1791            let thread = sessions.get(&params.session_id).unwrap();
1792            if let Some(handler) = &self.on_user_message {
1793                let handler = handler.clone();
1794                let thread = thread.clone();
1795                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
1796            } else {
1797                Task::ready(Ok(acp::PromptResponse {
1798                    stop_reason: acp::StopReason::EndTurn,
1799                }))
1800            }
1801        }
1802
1803        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1804            let sessions = self.sessions.lock();
1805            let thread = sessions.get(&session_id).unwrap().clone();
1806
1807            cx.spawn(async move |cx| {
1808                thread
1809                    .update(cx, |thread, cx| thread.cancel(cx))
1810                    .unwrap()
1811                    .await
1812            })
1813            .detach();
1814        }
1815    }
1816}