acp_thread.rs

   1mod connection;
   2mod diff;
   3mod terminal;
   4
   5pub use connection::*;
   6pub use diff::*;
   7pub use terminal::*;
   8
   9use action_log::ActionLog;
  10use agent_client_protocol as acp;
  11use anyhow::{Context as _, Result};
  12use editor::Bias;
  13use futures::{FutureExt, channel::oneshot, future::BoxFuture};
  14use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
  15use itertools::Itertools;
  16use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, text_diff};
  17use markdown::Markdown;
  18use project::{AgentLocation, Project};
  19use std::collections::HashMap;
  20use std::error::Error;
  21use std::fmt::Formatter;
  22use std::process::ExitStatus;
  23use std::rc::Rc;
  24use std::{
  25    fmt::Display,
  26    mem,
  27    path::{Path, PathBuf},
  28    sync::Arc,
  29};
  30use ui::App;
  31use util::ResultExt;
  32
  33#[derive(Debug)]
  34pub struct UserMessage {
  35    pub content: ContentBlock,
  36}
  37
  38impl UserMessage {
  39    pub fn from_acp(
  40        message: impl IntoIterator<Item = acp::ContentBlock>,
  41        language_registry: Arc<LanguageRegistry>,
  42        cx: &mut App,
  43    ) -> Self {
  44        let mut content = ContentBlock::Empty;
  45        for chunk in message {
  46            content.append(chunk, &language_registry, cx)
  47        }
  48        Self { content: content }
  49    }
  50
  51    fn to_markdown(&self, cx: &App) -> String {
  52        format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
  53    }
  54}
  55
  56#[derive(Debug)]
  57pub struct MentionPath<'a>(&'a Path);
  58
  59impl<'a> MentionPath<'a> {
  60    const PREFIX: &'static str = "@file:";
  61
  62    pub fn new(path: &'a Path) -> Self {
  63        MentionPath(path)
  64    }
  65
  66    pub fn try_parse(url: &'a str) -> Option<Self> {
  67        let path = url.strip_prefix(Self::PREFIX)?;
  68        Some(MentionPath(Path::new(path)))
  69    }
  70
  71    pub fn path(&self) -> &Path {
  72        self.0
  73    }
  74}
  75
  76impl Display for MentionPath<'_> {
  77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  78        write!(
  79            f,
  80            "[@{}]({}{})",
  81            self.0.file_name().unwrap_or_default().display(),
  82            Self::PREFIX,
  83            self.0.display()
  84        )
  85    }
  86}
  87
  88#[derive(Debug, PartialEq)]
  89pub struct AssistantMessage {
  90    pub chunks: Vec<AssistantMessageChunk>,
  91}
  92
  93impl AssistantMessage {
  94    pub fn to_markdown(&self, cx: &App) -> String {
  95        format!(
  96            "## Assistant\n\n{}\n\n",
  97            self.chunks
  98                .iter()
  99                .map(|chunk| chunk.to_markdown(cx))
 100                .join("\n\n")
 101        )
 102    }
 103}
 104
 105#[derive(Debug, PartialEq)]
 106pub enum AssistantMessageChunk {
 107    Message { block: ContentBlock },
 108    Thought { block: ContentBlock },
 109}
 110
 111impl AssistantMessageChunk {
 112    pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
 113        Self::Message {
 114            block: ContentBlock::new(chunk.into(), language_registry, cx),
 115        }
 116    }
 117
 118    fn to_markdown(&self, cx: &App) -> String {
 119        match self {
 120            Self::Message { block } => block.to_markdown(cx).to_string(),
 121            Self::Thought { block } => {
 122                format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
 123            }
 124        }
 125    }
 126}
 127
 128#[derive(Debug)]
 129pub enum AgentThreadEntry {
 130    UserMessage(UserMessage),
 131    AssistantMessage(AssistantMessage),
 132    ToolCall(ToolCall),
 133}
 134
 135impl AgentThreadEntry {
 136    fn to_markdown(&self, cx: &App) -> String {
 137        match self {
 138            Self::UserMessage(message) => message.to_markdown(cx),
 139            Self::AssistantMessage(message) => message.to_markdown(cx),
 140            Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
 141        }
 142    }
 143
 144    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 145        if let AgentThreadEntry::ToolCall(call) = self {
 146            itertools::Either::Left(call.diffs())
 147        } else {
 148            itertools::Either::Right(std::iter::empty())
 149        }
 150    }
 151
 152    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 153        if let AgentThreadEntry::ToolCall(call) = self {
 154            itertools::Either::Left(call.terminals())
 155        } else {
 156            itertools::Either::Right(std::iter::empty())
 157        }
 158    }
 159
 160    pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
 161        if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
 162            Some(locations)
 163        } else {
 164            None
 165        }
 166    }
 167}
 168
 169#[derive(Debug)]
 170pub struct ToolCall {
 171    pub id: acp::ToolCallId,
 172    pub label: Entity<Markdown>,
 173    pub kind: acp::ToolKind,
 174    pub content: Vec<ToolCallContent>,
 175    pub status: ToolCallStatus,
 176    pub locations: Vec<acp::ToolCallLocation>,
 177    pub raw_input: Option<serde_json::Value>,
 178    pub raw_output: Option<serde_json::Value>,
 179}
 180
 181impl ToolCall {
 182    fn from_acp(
 183        tool_call: acp::ToolCall,
 184        status: ToolCallStatus,
 185        language_registry: Arc<LanguageRegistry>,
 186        cx: &mut App,
 187    ) -> Self {
 188        Self {
 189            id: tool_call.id,
 190            label: cx.new(|cx| {
 191                Markdown::new(
 192                    tool_call.title.into(),
 193                    Some(language_registry.clone()),
 194                    None,
 195                    cx,
 196                )
 197            }),
 198            kind: tool_call.kind,
 199            content: tool_call
 200                .content
 201                .into_iter()
 202                .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
 203                .collect(),
 204            locations: tool_call.locations,
 205            status,
 206            raw_input: tool_call.raw_input,
 207            raw_output: tool_call.raw_output,
 208        }
 209    }
 210
 211    fn update_fields(
 212        &mut self,
 213        fields: acp::ToolCallUpdateFields,
 214        language_registry: Arc<LanguageRegistry>,
 215        cx: &mut App,
 216    ) {
 217        let acp::ToolCallUpdateFields {
 218            kind,
 219            status,
 220            title,
 221            content,
 222            locations,
 223            raw_input,
 224            raw_output,
 225        } = fields;
 226
 227        if let Some(kind) = kind {
 228            self.kind = kind;
 229        }
 230
 231        if let Some(status) = status {
 232            self.status = ToolCallStatus::Allowed { status };
 233        }
 234
 235        if let Some(title) = title {
 236            self.label.update(cx, |label, cx| {
 237                label.replace(title, cx);
 238            });
 239        }
 240
 241        if let Some(content) = content {
 242            self.content = content
 243                .into_iter()
 244                .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
 245                .collect();
 246        }
 247
 248        if let Some(locations) = locations {
 249            self.locations = locations;
 250        }
 251
 252        if let Some(raw_input) = raw_input {
 253            self.raw_input = Some(raw_input);
 254        }
 255
 256        if let Some(raw_output) = raw_output {
 257            if self.content.is_empty() {
 258                if let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
 259                {
 260                    self.content
 261                        .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
 262                            markdown,
 263                        }));
 264                }
 265            }
 266            self.raw_output = Some(raw_output);
 267        }
 268    }
 269
 270    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 271        self.content.iter().filter_map(|content| match content {
 272            ToolCallContent::Diff(diff) => Some(diff),
 273            ToolCallContent::ContentBlock(_) => None,
 274            ToolCallContent::Terminal(_) => None,
 275        })
 276    }
 277
 278    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 279        self.content.iter().filter_map(|content| match content {
 280            ToolCallContent::Terminal(terminal) => Some(terminal),
 281            ToolCallContent::ContentBlock(_) => None,
 282            ToolCallContent::Diff(_) => None,
 283        })
 284    }
 285
 286    fn to_markdown(&self, cx: &App) -> String {
 287        let mut markdown = format!(
 288            "**Tool Call: {}**\nStatus: {}\n\n",
 289            self.label.read(cx).source(),
 290            self.status
 291        );
 292        for content in &self.content {
 293            markdown.push_str(content.to_markdown(cx).as_str());
 294            markdown.push_str("\n\n");
 295        }
 296        markdown
 297    }
 298}
 299
 300#[derive(Debug)]
 301pub enum ToolCallStatus {
 302    WaitingForConfirmation {
 303        options: Vec<acp::PermissionOption>,
 304        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
 305    },
 306    Allowed {
 307        status: acp::ToolCallStatus,
 308    },
 309    Rejected,
 310    Canceled,
 311}
 312
 313impl Display for ToolCallStatus {
 314    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 315        write!(
 316            f,
 317            "{}",
 318            match self {
 319                ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
 320                ToolCallStatus::Allowed { status } => match status {
 321                    acp::ToolCallStatus::Pending => "Pending",
 322                    acp::ToolCallStatus::InProgress => "In Progress",
 323                    acp::ToolCallStatus::Completed => "Completed",
 324                    acp::ToolCallStatus::Failed => "Failed",
 325                },
 326                ToolCallStatus::Rejected => "Rejected",
 327                ToolCallStatus::Canceled => "Canceled",
 328            }
 329        )
 330    }
 331}
 332
 333#[derive(Debug, PartialEq, Clone)]
 334pub enum ContentBlock {
 335    Empty,
 336    Markdown { markdown: Entity<Markdown> },
 337}
 338
 339impl ContentBlock {
 340    pub fn new(
 341        block: acp::ContentBlock,
 342        language_registry: &Arc<LanguageRegistry>,
 343        cx: &mut App,
 344    ) -> Self {
 345        let mut this = Self::Empty;
 346        this.append(block, language_registry, cx);
 347        this
 348    }
 349
 350    pub fn new_combined(
 351        blocks: impl IntoIterator<Item = acp::ContentBlock>,
 352        language_registry: Arc<LanguageRegistry>,
 353        cx: &mut App,
 354    ) -> Self {
 355        let mut this = Self::Empty;
 356        for block in blocks {
 357            this.append(block, &language_registry, cx);
 358        }
 359        this
 360    }
 361
 362    pub fn append(
 363        &mut self,
 364        block: acp::ContentBlock,
 365        language_registry: &Arc<LanguageRegistry>,
 366        cx: &mut App,
 367    ) {
 368        let new_content = match block {
 369            acp::ContentBlock::Text(text_content) => text_content.text.clone(),
 370            acp::ContentBlock::ResourceLink(resource_link) => {
 371                if let Some(path) = resource_link.uri.strip_prefix("file://") {
 372                    format!("{}", MentionPath(path.as_ref()))
 373                } else {
 374                    resource_link.uri.clone()
 375                }
 376            }
 377            acp::ContentBlock::Image(_)
 378            | acp::ContentBlock::Audio(_)
 379            | acp::ContentBlock::Resource(_) => String::new(),
 380        };
 381
 382        match self {
 383            ContentBlock::Empty => {
 384                *self = ContentBlock::Markdown {
 385                    markdown: cx.new(|cx| {
 386                        Markdown::new(
 387                            new_content.into(),
 388                            Some(language_registry.clone()),
 389                            None,
 390                            cx,
 391                        )
 392                    }),
 393                };
 394            }
 395            ContentBlock::Markdown { markdown } => {
 396                markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
 397            }
 398        }
 399    }
 400
 401    fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
 402        match self {
 403            ContentBlock::Empty => "",
 404            ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
 405        }
 406    }
 407
 408    pub fn markdown(&self) -> Option<&Entity<Markdown>> {
 409        match self {
 410            ContentBlock::Empty => None,
 411            ContentBlock::Markdown { markdown } => Some(markdown),
 412        }
 413    }
 414}
 415
 416#[derive(Debug)]
 417pub enum ToolCallContent {
 418    ContentBlock(ContentBlock),
 419    Diff(Entity<Diff>),
 420    Terminal(Entity<Terminal>),
 421}
 422
 423impl ToolCallContent {
 424    pub fn from_acp(
 425        content: acp::ToolCallContent,
 426        language_registry: Arc<LanguageRegistry>,
 427        cx: &mut App,
 428    ) -> Self {
 429        match content {
 430            acp::ToolCallContent::Content { content } => {
 431                Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
 432            }
 433            acp::ToolCallContent::Diff { diff } => {
 434                Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx)))
 435            }
 436        }
 437    }
 438
 439    pub fn to_markdown(&self, cx: &App) -> String {
 440        match self {
 441            Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
 442            Self::Diff(diff) => diff.read(cx).to_markdown(cx),
 443            Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
 444        }
 445    }
 446}
 447
 448#[derive(Debug, PartialEq)]
 449pub enum ToolCallUpdate {
 450    UpdateFields(acp::ToolCallUpdate),
 451    UpdateDiff(ToolCallUpdateDiff),
 452    UpdateTerminal(ToolCallUpdateTerminal),
 453}
 454
 455impl ToolCallUpdate {
 456    fn id(&self) -> &acp::ToolCallId {
 457        match self {
 458            Self::UpdateFields(update) => &update.id,
 459            Self::UpdateDiff(diff) => &diff.id,
 460            Self::UpdateTerminal(terminal) => &terminal.id,
 461        }
 462    }
 463}
 464
 465impl From<acp::ToolCallUpdate> for ToolCallUpdate {
 466    fn from(update: acp::ToolCallUpdate) -> Self {
 467        Self::UpdateFields(update)
 468    }
 469}
 470
 471impl From<ToolCallUpdateDiff> for ToolCallUpdate {
 472    fn from(diff: ToolCallUpdateDiff) -> Self {
 473        Self::UpdateDiff(diff)
 474    }
 475}
 476
 477#[derive(Debug, PartialEq)]
 478pub struct ToolCallUpdateDiff {
 479    pub id: acp::ToolCallId,
 480    pub diff: Entity<Diff>,
 481}
 482
 483impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
 484    fn from(terminal: ToolCallUpdateTerminal) -> Self {
 485        Self::UpdateTerminal(terminal)
 486    }
 487}
 488
 489#[derive(Debug, PartialEq)]
 490pub struct ToolCallUpdateTerminal {
 491    pub id: acp::ToolCallId,
 492    pub terminal: Entity<Terminal>,
 493}
 494
 495#[derive(Debug, Default)]
 496pub struct Plan {
 497    pub entries: Vec<PlanEntry>,
 498}
 499
 500#[derive(Debug)]
 501pub struct PlanStats<'a> {
 502    pub in_progress_entry: Option<&'a PlanEntry>,
 503    pub pending: u32,
 504    pub completed: u32,
 505}
 506
 507impl Plan {
 508    pub fn is_empty(&self) -> bool {
 509        self.entries.is_empty()
 510    }
 511
 512    pub fn stats(&self) -> PlanStats<'_> {
 513        let mut stats = PlanStats {
 514            in_progress_entry: None,
 515            pending: 0,
 516            completed: 0,
 517        };
 518
 519        for entry in &self.entries {
 520            match &entry.status {
 521                acp::PlanEntryStatus::Pending => {
 522                    stats.pending += 1;
 523                }
 524                acp::PlanEntryStatus::InProgress => {
 525                    stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
 526                }
 527                acp::PlanEntryStatus::Completed => {
 528                    stats.completed += 1;
 529                }
 530            }
 531        }
 532
 533        stats
 534    }
 535}
 536
 537#[derive(Debug)]
 538pub struct PlanEntry {
 539    pub content: Entity<Markdown>,
 540    pub priority: acp::PlanEntryPriority,
 541    pub status: acp::PlanEntryStatus,
 542}
 543
 544impl PlanEntry {
 545    pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
 546        Self {
 547            content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
 548            priority: entry.priority,
 549            status: entry.status,
 550        }
 551    }
 552}
 553
/// State of a single agent conversation thread.
pub struct AcpThread {
    /// Title returned by `title()`.
    title: SharedString,
    /// Thread entries in the order they were pushed.
    entries: Vec<AgentThreadEntry>,
    /// The current plan, exposed through `plan()`.
    plan: Plan,
    /// Project used to obtain the language registry and to set the agent location.
    project: Entity<Project>,
    /// Action log created from `project` in `new`.
    action_log: Entity<ActionLog>,
    // NOTE(review): not used in the visible part of this file; presumably
    // snapshots of buffers shared with the agent — confirm against its users.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// When `Some`, `status()` reports the thread as not idle.
    send_task: Option<Task<()>>,
    // NOTE(review): not used in the visible part of this file; presumably the
    // connection over which this session talks to the agent — confirm.
    connection: Rc<dyn AgentConnection>,
    /// Session id returned by `session_id()`.
    session_id: acp::SessionId,
}
 565
/// Events emitted by [`AcpThread`].
pub enum AcpThreadEvent {
    /// A new entry was appended to the thread.
    NewEntry,
    /// The entry at the given index was modified.
    EntryUpdated(usize),
    /// A tool call was registered that is waiting for authorization.
    ToolAuthorizationRequired,
    // NOTE(review): the emit sites of the three variants below are outside
    // the visible part of this file.
    Stopped,
    Error,
    ServerExited(ExitStatus),
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}
 576
/// Coarse activity state of a thread, as computed by `AcpThread::status`.
#[derive(PartialEq, Eq)]
pub enum ThreadStatus {
    /// No send task is present.
    Idle,
    /// A send task is present and a tool call in the latest turn awaits confirmation.
    WaitingForToolConfirmation,
    /// A send task is present and nothing awaits confirmation.
    Generating,
}
 583
 584#[derive(Debug, Clone)]
 585pub enum LoadError {
 586    Unsupported {
 587        error_message: SharedString,
 588        upgrade_message: SharedString,
 589        upgrade_command: String,
 590    },
 591    Exited(i32),
 592    Other(SharedString),
 593}
 594
 595impl Display for LoadError {
 596    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 597        match self {
 598            LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
 599            LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
 600            LoadError::Other(msg) => write!(f, "{}", msg),
 601        }
 602    }
 603}
 604
 605impl Error for LoadError {}
 606
 607impl AcpThread {
 608    pub fn new(
 609        title: impl Into<SharedString>,
 610        connection: Rc<dyn AgentConnection>,
 611        project: Entity<Project>,
 612        session_id: acp::SessionId,
 613        cx: &mut Context<Self>,
 614    ) -> Self {
 615        let action_log = cx.new(|_| ActionLog::new(project.clone()));
 616
 617        Self {
 618            action_log,
 619            shared_buffers: Default::default(),
 620            entries: Default::default(),
 621            plan: Default::default(),
 622            title: title.into(),
 623            project,
 624            send_task: None,
 625            connection,
 626            session_id,
 627        }
 628    }
 629
    /// Returns the action log entity owned by this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
 633
    /// Returns the project this thread was created with.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
 637
    /// Returns a clone of the thread's title.
    pub fn title(&self) -> SharedString {
        self.title.clone()
    }
 641
    /// Returns all thread entries in insertion order.
    pub fn entries(&self) -> &[AgentThreadEntry] {
        &self.entries
    }
 645
    /// Returns the ACP session id this thread was created with.
    pub fn session_id(&self) -> &acp::SessionId {
        &self.session_id
    }
 649
 650    pub fn status(&self) -> ThreadStatus {
 651        if self.send_task.is_some() {
 652            if self.waiting_for_tool_confirmation() {
 653                ThreadStatus::WaitingForToolConfirmation
 654            } else {
 655                ThreadStatus::Generating
 656            }
 657        } else {
 658            ThreadStatus::Idle
 659        }
 660    }
 661
 662    pub fn has_pending_edit_tool_calls(&self) -> bool {
 663        for entry in self.entries.iter().rev() {
 664            match entry {
 665                AgentThreadEntry::UserMessage(_) => return false,
 666                AgentThreadEntry::ToolCall(
 667                    call @ ToolCall {
 668                        status:
 669                            ToolCallStatus::Allowed {
 670                                status:
 671                                    acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending,
 672                            },
 673                        ..
 674                    },
 675                ) if call.diffs().next().is_some() => {
 676                    return true;
 677                }
 678                AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
 679            }
 680        }
 681
 682        false
 683    }
 684
 685    pub fn used_tools_since_last_user_message(&self) -> bool {
 686        for entry in self.entries.iter().rev() {
 687            match entry {
 688                AgentThreadEntry::UserMessage(..) => return false,
 689                AgentThreadEntry::AssistantMessage(..) => continue,
 690                AgentThreadEntry::ToolCall(..) => return true,
 691            }
 692        }
 693
 694        false
 695    }
 696
 697    pub fn handle_session_update(
 698        &mut self,
 699        update: acp::SessionUpdate,
 700        cx: &mut Context<Self>,
 701    ) -> Result<()> {
 702        match update {
 703            acp::SessionUpdate::UserMessageChunk { content } => {
 704                self.push_user_content_block(content, cx);
 705            }
 706            acp::SessionUpdate::AgentMessageChunk { content } => {
 707                self.push_assistant_content_block(content, false, cx);
 708            }
 709            acp::SessionUpdate::AgentThoughtChunk { content } => {
 710                self.push_assistant_content_block(content, true, cx);
 711            }
 712            acp::SessionUpdate::ToolCall(tool_call) => {
 713                self.upsert_tool_call(tool_call, cx);
 714            }
 715            acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
 716                self.update_tool_call(tool_call_update, cx)?;
 717            }
 718            acp::SessionUpdate::Plan(plan) => {
 719                self.update_plan(plan, cx);
 720            }
 721        }
 722        Ok(())
 723    }
 724
 725    pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
 726        let language_registry = self.project.read(cx).languages().clone();
 727        let entries_len = self.entries.len();
 728
 729        if let Some(last_entry) = self.entries.last_mut()
 730            && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
 731        {
 732            content.append(chunk, &language_registry, cx);
 733            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
 734        } else {
 735            let content = ContentBlock::new(chunk, &language_registry, cx);
 736            self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
 737        }
 738    }
 739
 740    pub fn push_assistant_content_block(
 741        &mut self,
 742        chunk: acp::ContentBlock,
 743        is_thought: bool,
 744        cx: &mut Context<Self>,
 745    ) {
 746        let language_registry = self.project.read(cx).languages().clone();
 747        let entries_len = self.entries.len();
 748        if let Some(last_entry) = self.entries.last_mut()
 749            && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
 750        {
 751            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
 752            match (chunks.last_mut(), is_thought) {
 753                (Some(AssistantMessageChunk::Message { block }), false)
 754                | (Some(AssistantMessageChunk::Thought { block }), true) => {
 755                    block.append(chunk, &language_registry, cx)
 756                }
 757                _ => {
 758                    let block = ContentBlock::new(chunk, &language_registry, cx);
 759                    if is_thought {
 760                        chunks.push(AssistantMessageChunk::Thought { block })
 761                    } else {
 762                        chunks.push(AssistantMessageChunk::Message { block })
 763                    }
 764                }
 765            }
 766        } else {
 767            let block = ContentBlock::new(chunk, &language_registry, cx);
 768            let chunk = if is_thought {
 769                AssistantMessageChunk::Thought { block }
 770            } else {
 771                AssistantMessageChunk::Message { block }
 772            };
 773
 774            self.push_entry(
 775                AgentThreadEntry::AssistantMessage(AssistantMessage {
 776                    chunks: vec![chunk],
 777                }),
 778                cx,
 779            );
 780        }
 781    }
 782
    /// Appends `entry` to the thread and emits [`AcpThreadEvent::NewEntry`].
    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
        self.entries.push(entry);
        cx.emit(AcpThreadEvent::NewEntry);
    }
 787
 788    pub fn update_tool_call(
 789        &mut self,
 790        update: impl Into<ToolCallUpdate>,
 791        cx: &mut Context<Self>,
 792    ) -> Result<()> {
 793        let update = update.into();
 794        let languages = self.project.read(cx).languages().clone();
 795
 796        let (ix, current_call) = self
 797            .tool_call_mut(update.id())
 798            .context("Tool call not found")?;
 799        match update {
 800            ToolCallUpdate::UpdateFields(update) => {
 801                current_call.update_fields(update.fields, languages, cx);
 802            }
 803            ToolCallUpdate::UpdateDiff(update) => {
 804                current_call.content.clear();
 805                current_call
 806                    .content
 807                    .push(ToolCallContent::Diff(update.diff));
 808            }
 809            ToolCallUpdate::UpdateTerminal(update) => {
 810                current_call.content.clear();
 811                current_call
 812                    .content
 813                    .push(ToolCallContent::Terminal(update.terminal));
 814            }
 815        }
 816
 817        cx.emit(AcpThreadEvent::EntryUpdated(ix));
 818
 819        Ok(())
 820    }
 821
 822    /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
 823    pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
 824        let status = ToolCallStatus::Allowed {
 825            status: tool_call.status,
 826        };
 827        self.upsert_tool_call_inner(tool_call, status, cx)
 828    }
 829
 830    pub fn upsert_tool_call_inner(
 831        &mut self,
 832        tool_call: acp::ToolCall,
 833        status: ToolCallStatus,
 834        cx: &mut Context<Self>,
 835    ) {
 836        let language_registry = self.project.read(cx).languages().clone();
 837        let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
 838
 839        let location = call.locations.last().cloned();
 840
 841        if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
 842            *current_call = call;
 843
 844            cx.emit(AcpThreadEvent::EntryUpdated(ix));
 845        } else {
 846            self.push_entry(AgentThreadEntry::ToolCall(call), cx);
 847        }
 848
 849        if let Some(location) = location {
 850            self.set_project_location(location, cx)
 851        }
 852    }
 853
 854    fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
 855        // The tool call we are looking for is typically the last one, or very close to the end.
 856        // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
 857        self.entries
 858            .iter_mut()
 859            .enumerate()
 860            .rev()
 861            .find_map(|(index, tool_call)| {
 862                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
 863                    && &tool_call.id == id
 864                {
 865                    Some((index, tool_call))
 866                } else {
 867                    None
 868                }
 869            })
 870    }
 871
 872    pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context<Self>) {
 873        self.project.update(cx, |project, cx| {
 874            let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
 875                return;
 876            };
 877            let buffer = project.open_buffer(path, cx);
 878            cx.spawn(async move |project, cx| {
 879                let buffer = buffer.await?;
 880
 881                project.update(cx, |project, cx| {
 882                    let position = if let Some(line) = location.line {
 883                        let snapshot = buffer.read(cx).snapshot();
 884                        let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
 885                        snapshot.anchor_before(point)
 886                    } else {
 887                        Anchor::MIN
 888                    };
 889
 890                    project.set_agent_location(
 891                        Some(AgentLocation {
 892                            buffer: buffer.downgrade(),
 893                            position,
 894                        }),
 895                        cx,
 896                    );
 897                })
 898            })
 899            .detach_and_log_err(cx);
 900        });
 901    }
 902
 903    pub fn request_tool_call_authorization(
 904        &mut self,
 905        tool_call: acp::ToolCall,
 906        options: Vec<acp::PermissionOption>,
 907        cx: &mut Context<Self>,
 908    ) -> oneshot::Receiver<acp::PermissionOptionId> {
 909        let (tx, rx) = oneshot::channel();
 910
 911        let status = ToolCallStatus::WaitingForConfirmation {
 912            options,
 913            respond_tx: tx,
 914        };
 915
 916        self.upsert_tool_call_inner(tool_call, status, cx);
 917        cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
 918        rx
 919    }
 920
 921    pub fn authorize_tool_call(
 922        &mut self,
 923        id: acp::ToolCallId,
 924        option_id: acp::PermissionOptionId,
 925        option_kind: acp::PermissionOptionKind,
 926        cx: &mut Context<Self>,
 927    ) {
 928        let Some((ix, call)) = self.tool_call_mut(&id) else {
 929            return;
 930        };
 931
 932        let new_status = match option_kind {
 933            acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
 934                ToolCallStatus::Rejected
 935            }
 936            acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
 937                ToolCallStatus::Allowed {
 938                    status: acp::ToolCallStatus::InProgress,
 939                }
 940            }
 941        };
 942
 943        let curr_status = mem::replace(&mut call.status, new_status);
 944
 945        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
 946            respond_tx.send(option_id).log_err();
 947        } else if cfg!(debug_assertions) {
 948            panic!("tried to authorize an already authorized tool call");
 949        }
 950
 951        cx.emit(AcpThreadEvent::EntryUpdated(ix));
 952    }
 953
 954    /// Returns true if the last turn is awaiting tool authorization
 955    pub fn waiting_for_tool_confirmation(&self) -> bool {
 956        for entry in self.entries.iter().rev() {
 957            match &entry {
 958                AgentThreadEntry::ToolCall(call) => match call.status {
 959                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
 960                    ToolCallStatus::Allowed { .. }
 961                    | ToolCallStatus::Rejected
 962                    | ToolCallStatus::Canceled => continue,
 963                },
 964                AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
 965                    // Reached the beginning of the turn
 966                    return false;
 967                }
 968            }
 969        }
 970        false
 971    }
 972
    /// Returns the plan currently stored on this thread.
    pub fn plan(&self) -> &Plan {
        &self.plan
    }
 976
 977    pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
 978        let new_entries_len = request.entries.len();
 979        let mut new_entries = request.entries.into_iter();
 980
 981        // Reuse existing markdown to prevent flickering
 982        for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
 983            let PlanEntry {
 984                content,
 985                priority,
 986                status,
 987            } = old;
 988            content.update(cx, |old, cx| {
 989                old.replace(new.content, cx);
 990            });
 991            *priority = new.priority;
 992            *status = new.status;
 993        }
 994        for new in new_entries {
 995            self.plan.entries.push(PlanEntry::from_acp(new, cx))
 996        }
 997        self.plan.entries.truncate(new_entries_len);
 998
 999        cx.notify();
1000    }
1001
1002    fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1003        self.plan
1004            .entries
1005            .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1006        cx.notify();
1007    }
1008
1009    #[cfg(any(test, feature = "test-support"))]
1010    pub fn send_raw(
1011        &mut self,
1012        message: &str,
1013        cx: &mut Context<Self>,
1014    ) -> BoxFuture<'static, Result<()>> {
1015        self.send(
1016            vec![acp::ContentBlock::Text(acp::TextContent {
1017                text: message.to_string(),
1018                annotations: None,
1019            })],
1020            cx,
1021        )
1022    }
1023
    /// Sends a user message to the agent.
    ///
    /// The message is appended to the thread as a `UserMessage` entry, completed
    /// plan entries are cleared, and any prompt that is still in flight is
    /// cancelled (via `cancel`) before the new prompt is issued on the
    /// connection.
    ///
    /// The returned future resolves once the prompt has finished:
    /// - with the prompt's error, after emitting `AcpThreadEvent::Error`, if the
    ///   connection reported a failure;
    /// - with `Ok(())`, after emitting `AcpThreadEvent::Stopped`, in every other
    ///   case (including when the result channel was dropped without a value).
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        // Render the message into the thread before talking to the agent.
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            cx,
        );
        self.push_entry(
            AgentThreadEntry::UserMessage(UserMessage { content: block }),
            cx,
        );
        self.clear_completed_plan_entries(cx);

        // `tx`/`rx` carry the prompt result from the send task to the returned future.
        let (tx, rx) = oneshot::channel();
        // Takes the previous `send_task` (if any); the new task awaits its completion below.
        let cancel_task = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            async {
                // Let the previous prompt wind down before starting the next one.
                cancel_task.await;

                let result = this
                    .update(cx, |this, cx| {
                        this.connection.prompt(
                            acp::PromptRequest {
                                prompt: message,
                                session_id: this.session_id.clone(),
                            },
                            cx,
                        )
                    })?
                    .await;

                tx.send(result).log_err();

                anyhow::Ok(())
            }
            .await
            .log_err();
        }));

        cx.spawn(async move |this, cx| match rx.await {
            // The prompt itself failed: surface the error to listeners and the caller.
            Ok(Err(e)) => {
                this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error))
                    .log_err();
                Err(e)?
            }
            // Either a successful response or a dropped sender.
            result => {
                let cancelled = matches!(
                    result,
                    Ok(Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::Cancelled
                    }))
                );

                // We only take the task if the current prompt wasn't cancelled.
                //
                // This prompt may have been cancelled because another one was sent
                // while it was still generating. In these cases, dropping `send_task`
                // would cause the next generation to be cancelled.
                if !cancelled {
                    this.update(cx, |this, _cx| this.send_task.take()).ok();
                }

                this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
                    .log_err();
                Ok(())
            }
        })
        .boxed()
    }
1097
1098    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1099        let Some(send_task) = self.send_task.take() else {
1100            return Task::ready(());
1101        };
1102
1103        for entry in self.entries.iter_mut() {
1104            if let AgentThreadEntry::ToolCall(call) = entry {
1105                let cancel = matches!(
1106                    call.status,
1107                    ToolCallStatus::WaitingForConfirmation { .. }
1108                        | ToolCallStatus::Allowed {
1109                            status: acp::ToolCallStatus::InProgress
1110                        }
1111                );
1112
1113                if cancel {
1114                    call.status = ToolCallStatus::Canceled;
1115                }
1116            }
1117        }
1118
1119        self.connection.cancel(&self.session_id, cx);
1120
1121        // Wait for the send task to complete
1122        cx.foreground_executor().spawn(send_task)
1123    }
1124
1125    pub fn read_text_file(
1126        &self,
1127        path: PathBuf,
1128        line: Option<u32>,
1129        limit: Option<u32>,
1130        reuse_shared_snapshot: bool,
1131        cx: &mut Context<Self>,
1132    ) -> Task<Result<String>> {
1133        let project = self.project.clone();
1134        let action_log = self.action_log.clone();
1135        cx.spawn(async move |this, cx| {
1136            let load = project.update(cx, |project, cx| {
1137                let path = project
1138                    .project_path_for_absolute_path(&path, cx)
1139                    .context("invalid path")?;
1140                anyhow::Ok(project.open_buffer(path, cx))
1141            });
1142            let buffer = load??.await?;
1143
1144            let snapshot = if reuse_shared_snapshot {
1145                this.read_with(cx, |this, _| {
1146                    this.shared_buffers.get(&buffer.clone()).cloned()
1147                })
1148                .log_err()
1149                .flatten()
1150            } else {
1151                None
1152            };
1153
1154            let snapshot = if let Some(snapshot) = snapshot {
1155                snapshot
1156            } else {
1157                action_log.update(cx, |action_log, cx| {
1158                    action_log.buffer_read(buffer.clone(), cx);
1159                })?;
1160                project.update(cx, |project, cx| {
1161                    let position = buffer
1162                        .read(cx)
1163                        .snapshot()
1164                        .anchor_before(Point::new(line.unwrap_or_default(), 0));
1165                    project.set_agent_location(
1166                        Some(AgentLocation {
1167                            buffer: buffer.downgrade(),
1168                            position,
1169                        }),
1170                        cx,
1171                    );
1172                })?;
1173
1174                buffer.update(cx, |buffer, _| buffer.snapshot())?
1175            };
1176
1177            this.update(cx, |this, _| {
1178                let text = snapshot.text();
1179                this.shared_buffers.insert(buffer.clone(), snapshot);
1180                if line.is_none() && limit.is_none() {
1181                    return Ok(text);
1182                }
1183                let limit = limit.unwrap_or(u32::MAX) as usize;
1184                let Some(line) = line else {
1185                    return Ok(text.lines().take(limit).collect::<String>());
1186                };
1187
1188                let count = text.lines().count();
1189                if count < line as usize {
1190                    anyhow::bail!("There are only {} lines", count);
1191                }
1192                Ok(text
1193                    .lines()
1194                    .skip(line as usize + 1)
1195                    .take(limit)
1196                    .collect::<String>())
1197            })?
1198        })
1199    }
1200
    /// Replaces the contents of the file at `path` with `content` and saves it.
    ///
    /// The new content is diffed against the snapshot stored in
    /// `shared_buffers` for this buffer, falling back to the buffer's current
    /// snapshot when none is stored. The resulting edits are applied to the
    /// live buffer, the read and the edit are recorded in the action log, the
    /// agent location is moved to the end of the last edit (or `Anchor::MIN`
    /// when there are no edits), and the buffer is saved through the project.
    ///
    /// # Errors
    ///
    /// Fails if `path` is not a valid project path, if the buffer cannot be
    /// opened or saved, or if the thread, project or app has been dropped.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Prefer the snapshot previously stored for this buffer as the diff base.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Compute the diff off the foreground executor.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            // Express each edit range as anchors in the diff-base snapshot.
                            // NOTE(review): presumably so the edits still land correctly if
                            // the live buffer changed since the snapshot — confirm.
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;
            cx.update(|cx| {
                project.update(cx, |project, cx| {
                    project.set_agent_location(
                        Some(AgentLocation {
                            buffer: buffer.downgrade(),
                            position: edits
                                .last()
                                .map(|(range, _)| range.end)
                                .unwrap_or(Anchor::MIN),
                        }),
                        cx,
                    );
                });

                // Record the read before editing, then record the edit itself.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });
                buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
            })?;
            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1268
1269    pub fn to_markdown(&self, cx: &App) -> String {
1270        self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1271    }
1272
    /// Emits `AcpThreadEvent::ServerExited` carrying the given exit `status`.
    pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
        cx.emit(AcpThreadEvent::ServerExited(status));
    }
1276}
1277
1278fn markdown_for_raw_output(
1279    raw_output: &serde_json::Value,
1280    language_registry: &Arc<LanguageRegistry>,
1281    cx: &mut App,
1282) -> Option<Entity<Markdown>> {
1283    match raw_output {
1284        serde_json::Value::Null => None,
1285        serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1286            Markdown::new(
1287                value.to_string().into(),
1288                Some(language_registry.clone()),
1289                None,
1290                cx,
1291            )
1292        })),
1293        serde_json::Value::Number(value) => Some(cx.new(|cx| {
1294            Markdown::new(
1295                value.to_string().into(),
1296                Some(language_registry.clone()),
1297                None,
1298                cx,
1299            )
1300        })),
1301        serde_json::Value::String(value) => Some(cx.new(|cx| {
1302            Markdown::new(
1303                value.clone().into(),
1304                Some(language_registry.clone()),
1305                None,
1306                cx,
1307            )
1308        })),
1309        value => Some(cx.new(|cx| {
1310            Markdown::new(
1311                format!("```json\n{}\n```", value).into(),
1312                Some(language_registry.clone()),
1313                None,
1314                cx,
1315            )
1316        })),
1317    }
1318}
1319
1320#[cfg(test)]
1321mod tests {
1322    use super::*;
1323    use anyhow::anyhow;
1324    use futures::{channel::mpsc, future::LocalBoxFuture, select};
1325    use gpui::{AsyncApp, TestAppContext, WeakEntity};
1326    use indoc::indoc;
1327    use project::FakeFs;
1328    use rand::Rng as _;
1329    use serde_json::json;
1330    use settings::SettingsStore;
1331    use smol::stream::StreamExt as _;
1332    use std::{cell::RefCell, rc::Rc, time::Duration};
1333
1334    use util::path;
1335
    /// Shared setup for the tests in this module: initializes logging (ignoring
    /// the error if it was already initialized), installs a test settings store
    /// as a global, and initializes project settings and the language crate.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }
1345
1346    #[gpui::test]
1347    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1348        init_test(cx);
1349
1350        let fs = FakeFs::new(cx.executor());
1351        let project = Project::test(fs, [], cx).await;
1352        let connection = Rc::new(FakeAgentConnection::new());
1353        let thread = cx
1354            .spawn(async move |mut cx| {
1355                connection
1356                    .new_thread(project, Path::new(path!("/test")), &mut cx)
1357                    .await
1358            })
1359            .await
1360            .unwrap();
1361
1362        // Test creating a new user message
1363        thread.update(cx, |thread, cx| {
1364            thread.push_user_content_block(
1365                acp::ContentBlock::Text(acp::TextContent {
1366                    annotations: None,
1367                    text: "Hello, ".to_string(),
1368                }),
1369                cx,
1370            );
1371        });
1372
1373        thread.update(cx, |thread, cx| {
1374            assert_eq!(thread.entries.len(), 1);
1375            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1376                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1377            } else {
1378                panic!("Expected UserMessage");
1379            }
1380        });
1381
1382        // Test appending to existing user message
1383        thread.update(cx, |thread, cx| {
1384            thread.push_user_content_block(
1385                acp::ContentBlock::Text(acp::TextContent {
1386                    annotations: None,
1387                    text: "world!".to_string(),
1388                }),
1389                cx,
1390            );
1391        });
1392
1393        thread.update(cx, |thread, cx| {
1394            assert_eq!(thread.entries.len(), 1);
1395            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1396                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1397            } else {
1398                panic!("Expected UserMessage");
1399            }
1400        });
1401
1402        // Test creating new user message after assistant message
1403        thread.update(cx, |thread, cx| {
1404            thread.push_assistant_content_block(
1405                acp::ContentBlock::Text(acp::TextContent {
1406                    annotations: None,
1407                    text: "Assistant response".to_string(),
1408                }),
1409                false,
1410                cx,
1411            );
1412        });
1413
1414        thread.update(cx, |thread, cx| {
1415            thread.push_user_content_block(
1416                acp::ContentBlock::Text(acp::TextContent {
1417                    annotations: None,
1418                    text: "New user message".to_string(),
1419                }),
1420                cx,
1421            );
1422        });
1423
1424        thread.update(cx, |thread, cx| {
1425            assert_eq!(thread.entries.len(), 3);
1426            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1427                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1428            } else {
1429                panic!("Expected UserMessage at index 2");
1430            }
1431        });
1432    }
1433
1434    #[gpui::test]
1435    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
1436        init_test(cx);
1437
1438        let fs = FakeFs::new(cx.executor());
1439        let project = Project::test(fs, [], cx).await;
1440        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1441            |_, thread, mut cx| {
1442                async move {
1443                    thread.update(&mut cx, |thread, cx| {
1444                        thread
1445                            .handle_session_update(
1446                                acp::SessionUpdate::AgentThoughtChunk {
1447                                    content: "Thinking ".into(),
1448                                },
1449                                cx,
1450                            )
1451                            .unwrap();
1452                        thread
1453                            .handle_session_update(
1454                                acp::SessionUpdate::AgentThoughtChunk {
1455                                    content: "hard!".into(),
1456                                },
1457                                cx,
1458                            )
1459                            .unwrap();
1460                    })?;
1461                    Ok(acp::PromptResponse {
1462                        stop_reason: acp::StopReason::EndTurn,
1463                    })
1464                }
1465                .boxed_local()
1466            },
1467        ));
1468
1469        let thread = cx
1470            .spawn(async move |mut cx| {
1471                connection
1472                    .new_thread(project, Path::new(path!("/test")), &mut cx)
1473                    .await
1474            })
1475            .await
1476            .unwrap();
1477
1478        thread
1479            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
1480            .await
1481            .unwrap();
1482
1483        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
1484        assert_eq!(
1485            output,
1486            indoc! {r#"
1487            ## User
1488
1489            Hello from Zed!
1490
1491            ## Assistant
1492
1493            <thinking>
1494            Thinking hard!
1495            </thinking>
1496
1497            "#}
1498        );
1499    }
1500
    // Checks that an agent write based on an earlier read is combined with a
    // user edit made in between: both the user's "zero" line and the agent's
    // "four"/"five" lines must end up in the buffer and on disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test body once the fake agent has finished reading the file.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        // The fake agent reads /tmp/foo, signals the test, then writes an extended version.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        // Open the same file as a buffer so the test can edit it like a user would.
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        // Wait for the agent's read, then make the user edit before its write runs.
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        // Both the user's insertion and the agent's appended lines are present.
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
1579
    // Checks that a tool call which was marked as canceled is still moved to
    // `Allowed { Completed }` when a later update reports it as completed.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        // The fake agent reports a single in-progress tool call and ends the turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // Entry 0 is the user message; entry 1 is the tool call, still in progress.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::InProgress,
                        ..
                    },
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        // Cancelling marks the in-progress tool call as canceled.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // A completion update for the same tool call arrives after the cancel.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        // The late completion wins over the canceled status.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::Completed,
                        ..
                    },
                    ..
                })
            ));
        });
    }
1689
    // Checks that an edit tool call reported as already completed does not
    // count as a pending edit once the prompt has finished.
    #[gpui::test]
    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/test"), json!({})).await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

        // The fake agent reports one completed edit tool call carrying a diff.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |_, thread, mut cx| {
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: acp::ToolCallId("test".into()),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Edit,
                                    status: acp::ToolCallStatus::Completed,
                                    content: vec![acp::ToolCallContent::Diff {
                                        diff: acp::Diff {
                                            path: "/test/test.txt".into(),
                                            old_text: None,
                                            new_text: "foo".into(),
                                        },
                                    }],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = connection
            .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
            .await
            .unwrap();
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
            .await
            .unwrap();

        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
    }
1742
    /// Waits until `thread` contains a tool call entry and returns the index of
    /// the first one.
    ///
    /// The thread's entries are scanned each time the thread emits an event.
    ///
    /// # Panics
    ///
    /// Panics if no tool call shows up within 10 seconds, or if the index cannot
    /// be queued on the internal channel.
    async fn run_until_first_tool_call(
        thread: &Entity<AcpThread>,
        cx: &mut TestAppContext,
    ) -> usize {
        let (mut tx, mut rx) = mpsc::channel::<usize>(1);

        // On every thread event, report the index of the first tool call entry, if any.
        let subscription = cx.update(|cx| {
            cx.subscribe(thread, move |thread, _, cx| {
                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
                        return tx.try_send(ix).unwrap();
                    }
                }
            })
        });

        // Race the first reported index against a timeout.
        select! {
            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
                panic!("Timeout waiting for tool call")
            }
            ix = rx.next().fuse() => {
                drop(subscription);
                ix.unwrap()
            }
        }
    }
1769
    /// In-memory `AgentConnection` used by the tests in this module.
    #[derive(Clone, Default)]
    struct FakeAgentConnection {
        /// Auth methods reported by `auth_methods`.
        auth_methods: Vec<acp::AuthMethod>,
        /// Threads keyed by session id.
        // NOTE(review): the code that fills this map is outside this excerpt —
        // presumably `new_thread`; confirm.
        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
        /// Optional handler installed through the `on_user_message` builder. It
        /// receives the prompt request, a weak handle to the thread and an async
        /// app context, and produces the prompt response.
        // NOTE(review): presumably invoked when a prompt is sent — confirm
        // against the `prompt` implementation.
        on_user_message: Option<
            Rc<
                dyn Fn(
                        acp::PromptRequest,
                        WeakEntity<AcpThread>,
                        AsyncApp,
                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
                    + 'static,
            >,
        >,
    }
1785
1786    impl FakeAgentConnection {
1787        fn new() -> Self {
1788            Self {
1789                auth_methods: Vec::new(),
1790                on_user_message: None,
1791                sessions: Arc::default(),
1792            }
1793        }
1794
1795        #[expect(unused)]
1796        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
1797            self.auth_methods = auth_methods;
1798            self
1799        }
1800
1801        fn on_user_message(
1802            mut self,
1803            handler: impl Fn(
1804                acp::PromptRequest,
1805                WeakEntity<AcpThread>,
1806                AsyncApp,
1807            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1808            + 'static,
1809        ) -> Self {
1810            self.on_user_message.replace(Rc::new(handler));
1811            self
1812        }
1813    }
1814
1815    impl AgentConnection for FakeAgentConnection {
1816        fn auth_methods(&self) -> &[acp::AuthMethod] {
1817            &self.auth_methods
1818        }
1819
1820        fn new_thread(
1821            self: Rc<Self>,
1822            project: Entity<Project>,
1823            _cwd: &Path,
1824            cx: &mut gpui::AsyncApp,
1825        ) -> Task<gpui::Result<Entity<AcpThread>>> {
1826            let session_id = acp::SessionId(
1827                rand::thread_rng()
1828                    .sample_iter(&rand::distributions::Alphanumeric)
1829                    .take(7)
1830                    .map(char::from)
1831                    .collect::<String>()
1832                    .into(),
1833            );
1834            let thread = cx
1835                .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
1836                .unwrap();
1837            self.sessions.lock().insert(session_id, thread.downgrade());
1838            Task::ready(Ok(thread))
1839        }
1840
1841        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
1842            if self.auth_methods().iter().any(|m| m.id == method) {
1843                Task::ready(Ok(()))
1844            } else {
1845                Task::ready(Err(anyhow!("Invalid Auth Method")))
1846            }
1847        }
1848
1849        fn prompt(
1850            &self,
1851            params: acp::PromptRequest,
1852            cx: &mut App,
1853        ) -> Task<gpui::Result<acp::PromptResponse>> {
1854            let sessions = self.sessions.lock();
1855            let thread = sessions.get(&params.session_id).unwrap();
1856            if let Some(handler) = &self.on_user_message {
1857                let handler = handler.clone();
1858                let thread = thread.clone();
1859                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
1860            } else {
1861                Task::ready(Ok(acp::PromptResponse {
1862                    stop_reason: acp::StopReason::EndTurn,
1863                }))
1864            }
1865        }
1866
1867        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1868            let sessions = self.sessions.lock();
1869            let thread = sessions.get(&session_id).unwrap().clone();
1870
1871            cx.spawn(async move |cx| {
1872                thread
1873                    .update(cx, |thread, cx| thread.cancel(cx))
1874                    .unwrap()
1875                    .await
1876            })
1877            .detach();
1878        }
1879    }
1880}