acp_thread.rs

   1mod connection;
   2mod diff;
   3mod mention;
   4mod terminal;
   5
   6pub use connection::*;
   7pub use diff::*;
   8pub use mention::*;
   9pub use terminal::*;
  10
  11use action_log::ActionLog;
  12use agent_client_protocol::{self as acp};
  13use anyhow::{Context as _, Result};
  14use editor::Bias;
  15use futures::{FutureExt, channel::oneshot, future::BoxFuture};
  16use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
  17use itertools::Itertools;
  18use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, text_diff};
  19use markdown::Markdown;
  20use project::{AgentLocation, Project};
  21use std::collections::HashMap;
  22use std::error::Error;
  23use std::fmt::Formatter;
  24use std::process::ExitStatus;
  25use std::rc::Rc;
  26use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
  27use ui::App;
  28use util::ResultExt;
  29
  30#[derive(Debug)]
  31pub struct UserMessage {
  32    pub content: ContentBlock,
  33}
  34
  35impl UserMessage {
  36    pub fn from_acp(
  37        message: impl IntoIterator<Item = acp::ContentBlock>,
  38        language_registry: Arc<LanguageRegistry>,
  39        cx: &mut App,
  40    ) -> Self {
  41        let mut content = ContentBlock::Empty;
  42        for chunk in message {
  43            content.append(chunk, &language_registry, cx)
  44        }
  45        Self { content: content }
  46    }
  47
  48    fn to_markdown(&self, cx: &App) -> String {
  49        format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
  50    }
  51}
  52
  53#[derive(Debug, PartialEq)]
  54pub struct AssistantMessage {
  55    pub chunks: Vec<AssistantMessageChunk>,
  56}
  57
  58impl AssistantMessage {
  59    pub fn to_markdown(&self, cx: &App) -> String {
  60        format!(
  61            "## Assistant\n\n{}\n\n",
  62            self.chunks
  63                .iter()
  64                .map(|chunk| chunk.to_markdown(cx))
  65                .join("\n\n")
  66        )
  67    }
  68}
  69
  70#[derive(Debug, PartialEq)]
  71pub enum AssistantMessageChunk {
  72    Message { block: ContentBlock },
  73    Thought { block: ContentBlock },
  74}
  75
  76impl AssistantMessageChunk {
  77    pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
  78        Self::Message {
  79            block: ContentBlock::new(chunk.into(), language_registry, cx),
  80        }
  81    }
  82
  83    fn to_markdown(&self, cx: &App) -> String {
  84        match self {
  85            Self::Message { block } => block.to_markdown(cx).to_string(),
  86            Self::Thought { block } => {
  87                format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
  88            }
  89        }
  90    }
  91}
  92
  93#[derive(Debug)]
  94pub enum AgentThreadEntry {
  95    UserMessage(UserMessage),
  96    AssistantMessage(AssistantMessage),
  97    ToolCall(ToolCall),
  98}
  99
 100impl AgentThreadEntry {
 101    fn to_markdown(&self, cx: &App) -> String {
 102        match self {
 103            Self::UserMessage(message) => message.to_markdown(cx),
 104            Self::AssistantMessage(message) => message.to_markdown(cx),
 105            Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
 106        }
 107    }
 108
 109    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 110        if let AgentThreadEntry::ToolCall(call) = self {
 111            itertools::Either::Left(call.diffs())
 112        } else {
 113            itertools::Either::Right(std::iter::empty())
 114        }
 115    }
 116
 117    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 118        if let AgentThreadEntry::ToolCall(call) = self {
 119            itertools::Either::Left(call.terminals())
 120        } else {
 121            itertools::Either::Right(std::iter::empty())
 122        }
 123    }
 124
 125    pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
 126        if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
 127            Some(locations)
 128        } else {
 129            None
 130        }
 131    }
 132}
 133
 134#[derive(Debug)]
 135pub struct ToolCall {
 136    pub id: acp::ToolCallId,
 137    pub label: Entity<Markdown>,
 138    pub kind: acp::ToolKind,
 139    pub content: Vec<ToolCallContent>,
 140    pub status: ToolCallStatus,
 141    pub locations: Vec<acp::ToolCallLocation>,
 142    pub raw_input: Option<serde_json::Value>,
 143    pub raw_output: Option<serde_json::Value>,
 144}
 145
 146impl ToolCall {
 147    fn from_acp(
 148        tool_call: acp::ToolCall,
 149        status: ToolCallStatus,
 150        language_registry: Arc<LanguageRegistry>,
 151        cx: &mut App,
 152    ) -> Self {
 153        Self {
 154            id: tool_call.id,
 155            label: cx.new(|cx| {
 156                Markdown::new(
 157                    tool_call.title.into(),
 158                    Some(language_registry.clone()),
 159                    None,
 160                    cx,
 161                )
 162            }),
 163            kind: tool_call.kind,
 164            content: tool_call
 165                .content
 166                .into_iter()
 167                .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
 168                .collect(),
 169            locations: tool_call.locations,
 170            status,
 171            raw_input: tool_call.raw_input,
 172            raw_output: tool_call.raw_output,
 173        }
 174    }
 175
 176    fn update_fields(
 177        &mut self,
 178        fields: acp::ToolCallUpdateFields,
 179        language_registry: Arc<LanguageRegistry>,
 180        cx: &mut App,
 181    ) {
 182        let acp::ToolCallUpdateFields {
 183            kind,
 184            status,
 185            title,
 186            content,
 187            locations,
 188            raw_input,
 189            raw_output,
 190        } = fields;
 191
 192        if let Some(kind) = kind {
 193            self.kind = kind;
 194        }
 195
 196        if let Some(status) = status {
 197            self.status = ToolCallStatus::Allowed { status };
 198        }
 199
 200        if let Some(title) = title {
 201            self.label.update(cx, |label, cx| {
 202                label.replace(title, cx);
 203            });
 204        }
 205
 206        if let Some(content) = content {
 207            self.content = content
 208                .into_iter()
 209                .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
 210                .collect();
 211        }
 212
 213        if let Some(locations) = locations {
 214            self.locations = locations;
 215        }
 216
 217        if let Some(raw_input) = raw_input {
 218            self.raw_input = Some(raw_input);
 219        }
 220
 221        if let Some(raw_output) = raw_output {
 222            if self.content.is_empty() {
 223                if let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
 224                {
 225                    self.content
 226                        .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
 227                            markdown,
 228                        }));
 229                }
 230            }
 231            self.raw_output = Some(raw_output);
 232        }
 233    }
 234
 235    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 236        self.content.iter().filter_map(|content| match content {
 237            ToolCallContent::Diff(diff) => Some(diff),
 238            ToolCallContent::ContentBlock(_) => None,
 239            ToolCallContent::Terminal(_) => None,
 240        })
 241    }
 242
 243    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 244        self.content.iter().filter_map(|content| match content {
 245            ToolCallContent::Terminal(terminal) => Some(terminal),
 246            ToolCallContent::ContentBlock(_) => None,
 247            ToolCallContent::Diff(_) => None,
 248        })
 249    }
 250
 251    fn to_markdown(&self, cx: &App) -> String {
 252        let mut markdown = format!(
 253            "**Tool Call: {}**\nStatus: {}\n\n",
 254            self.label.read(cx).source(),
 255            self.status
 256        );
 257        for content in &self.content {
 258            markdown.push_str(content.to_markdown(cx).as_str());
 259            markdown.push_str("\n\n");
 260        }
 261        markdown
 262    }
 263}
 264
 265#[derive(Debug)]
 266pub enum ToolCallStatus {
 267    WaitingForConfirmation {
 268        options: Vec<acp::PermissionOption>,
 269        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
 270    },
 271    Allowed {
 272        status: acp::ToolCallStatus,
 273    },
 274    Rejected,
 275    Canceled,
 276}
 277
 278impl Display for ToolCallStatus {
 279    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 280        write!(
 281            f,
 282            "{}",
 283            match self {
 284                ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
 285                ToolCallStatus::Allowed { status } => match status {
 286                    acp::ToolCallStatus::Pending => "Pending",
 287                    acp::ToolCallStatus::InProgress => "In Progress",
 288                    acp::ToolCallStatus::Completed => "Completed",
 289                    acp::ToolCallStatus::Failed => "Failed",
 290                },
 291                ToolCallStatus::Rejected => "Rejected",
 292                ToolCallStatus::Canceled => "Canceled",
 293            }
 294        )
 295    }
 296}
 297
 298#[derive(Debug, PartialEq, Clone)]
 299pub enum ContentBlock {
 300    Empty,
 301    Markdown { markdown: Entity<Markdown> },
 302}
 303
 304impl ContentBlock {
 305    pub fn new(
 306        block: acp::ContentBlock,
 307        language_registry: &Arc<LanguageRegistry>,
 308        cx: &mut App,
 309    ) -> Self {
 310        let mut this = Self::Empty;
 311        this.append(block, language_registry, cx);
 312        this
 313    }
 314
 315    pub fn new_combined(
 316        blocks: impl IntoIterator<Item = acp::ContentBlock>,
 317        language_registry: Arc<LanguageRegistry>,
 318        cx: &mut App,
 319    ) -> Self {
 320        let mut this = Self::Empty;
 321        for block in blocks {
 322            this.append(block, &language_registry, cx);
 323        }
 324        this
 325    }
 326
 327    pub fn append(
 328        &mut self,
 329        block: acp::ContentBlock,
 330        language_registry: &Arc<LanguageRegistry>,
 331        cx: &mut App,
 332    ) {
 333        let new_content = match block {
 334            acp::ContentBlock::Text(text_content) => text_content.text.clone(),
 335            acp::ContentBlock::Resource(acp::EmbeddedResource {
 336                resource:
 337                    acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
 338                        uri,
 339                        ..
 340                    }),
 341                ..
 342            }) => {
 343                if let Some(uri) = MentionUri::parse(&uri).log_err() {
 344                    uri.to_link()
 345                } else {
 346                    uri.clone()
 347                }
 348            }
 349            acp::ContentBlock::Image(_)
 350            | acp::ContentBlock::Audio(_)
 351            | acp::ContentBlock::Resource(acp::EmbeddedResource { .. })
 352            | acp::ContentBlock::ResourceLink(_) => String::new(),
 353        };
 354
 355        match self {
 356            ContentBlock::Empty => {
 357                *self = ContentBlock::Markdown {
 358                    markdown: cx.new(|cx| {
 359                        Markdown::new(
 360                            new_content.into(),
 361                            Some(language_registry.clone()),
 362                            None,
 363                            cx,
 364                        )
 365                    }),
 366                };
 367            }
 368            ContentBlock::Markdown { markdown } => {
 369                markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
 370            }
 371        }
 372    }
 373
 374    fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
 375        match self {
 376            ContentBlock::Empty => "",
 377            ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
 378        }
 379    }
 380
 381    pub fn markdown(&self) -> Option<&Entity<Markdown>> {
 382        match self {
 383            ContentBlock::Empty => None,
 384            ContentBlock::Markdown { markdown } => Some(markdown),
 385        }
 386    }
 387}
 388
 389#[derive(Debug)]
 390pub enum ToolCallContent {
 391    ContentBlock(ContentBlock),
 392    Diff(Entity<Diff>),
 393    Terminal(Entity<Terminal>),
 394}
 395
 396impl ToolCallContent {
 397    pub fn from_acp(
 398        content: acp::ToolCallContent,
 399        language_registry: Arc<LanguageRegistry>,
 400        cx: &mut App,
 401    ) -> Self {
 402        match content {
 403            acp::ToolCallContent::Content { content } => {
 404                Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
 405            }
 406            acp::ToolCallContent::Diff { diff } => {
 407                Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx)))
 408            }
 409        }
 410    }
 411
 412    pub fn to_markdown(&self, cx: &App) -> String {
 413        match self {
 414            Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
 415            Self::Diff(diff) => diff.read(cx).to_markdown(cx),
 416            Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
 417        }
 418    }
 419}
 420
 421#[derive(Debug, PartialEq)]
 422pub enum ToolCallUpdate {
 423    UpdateFields(acp::ToolCallUpdate),
 424    UpdateDiff(ToolCallUpdateDiff),
 425    UpdateTerminal(ToolCallUpdateTerminal),
 426}
 427
 428impl ToolCallUpdate {
 429    fn id(&self) -> &acp::ToolCallId {
 430        match self {
 431            Self::UpdateFields(update) => &update.id,
 432            Self::UpdateDiff(diff) => &diff.id,
 433            Self::UpdateTerminal(terminal) => &terminal.id,
 434        }
 435    }
 436}
 437
 438impl From<acp::ToolCallUpdate> for ToolCallUpdate {
 439    fn from(update: acp::ToolCallUpdate) -> Self {
 440        Self::UpdateFields(update)
 441    }
 442}
 443
 444impl From<ToolCallUpdateDiff> for ToolCallUpdate {
 445    fn from(diff: ToolCallUpdateDiff) -> Self {
 446        Self::UpdateDiff(diff)
 447    }
 448}
 449
 450#[derive(Debug, PartialEq)]
 451pub struct ToolCallUpdateDiff {
 452    pub id: acp::ToolCallId,
 453    pub diff: Entity<Diff>,
 454}
 455
 456impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
 457    fn from(terminal: ToolCallUpdateTerminal) -> Self {
 458        Self::UpdateTerminal(terminal)
 459    }
 460}
 461
 462#[derive(Debug, PartialEq)]
 463pub struct ToolCallUpdateTerminal {
 464    pub id: acp::ToolCallId,
 465    pub terminal: Entity<Terminal>,
 466}
 467
 468#[derive(Debug, Default)]
 469pub struct Plan {
 470    pub entries: Vec<PlanEntry>,
 471}
 472
 473#[derive(Debug)]
 474pub struct PlanStats<'a> {
 475    pub in_progress_entry: Option<&'a PlanEntry>,
 476    pub pending: u32,
 477    pub completed: u32,
 478}
 479
 480impl Plan {
 481    pub fn is_empty(&self) -> bool {
 482        self.entries.is_empty()
 483    }
 484
 485    pub fn stats(&self) -> PlanStats<'_> {
 486        let mut stats = PlanStats {
 487            in_progress_entry: None,
 488            pending: 0,
 489            completed: 0,
 490        };
 491
 492        for entry in &self.entries {
 493            match &entry.status {
 494                acp::PlanEntryStatus::Pending => {
 495                    stats.pending += 1;
 496                }
 497                acp::PlanEntryStatus::InProgress => {
 498                    stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
 499                }
 500                acp::PlanEntryStatus::Completed => {
 501                    stats.completed += 1;
 502                }
 503            }
 504        }
 505
 506        stats
 507    }
 508}
 509
 510#[derive(Debug)]
 511pub struct PlanEntry {
 512    pub content: Entity<Markdown>,
 513    pub priority: acp::PlanEntryPriority,
 514    pub status: acp::PlanEntryStatus,
 515}
 516
 517impl PlanEntry {
 518    pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
 519        Self {
 520            content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
 521            priority: entry.priority,
 522            status: entry.status,
 523        }
 524    }
 525}
 526
 527pub struct AcpThread {
 528    title: SharedString,
 529    entries: Vec<AgentThreadEntry>,
 530    plan: Plan,
 531    project: Entity<Project>,
 532    action_log: Entity<ActionLog>,
 533    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
 534    send_task: Option<Task<()>>,
 535    connection: Rc<dyn AgentConnection>,
 536    session_id: acp::SessionId,
 537}
 538
 539pub enum AcpThreadEvent {
 540    NewEntry,
 541    EntryUpdated(usize),
 542    ToolAuthorizationRequired,
 543    Stopped,
 544    Error,
 545    ServerExited(ExitStatus),
 546}
 547
 548impl EventEmitter<AcpThreadEvent> for AcpThread {}
 549
 550#[derive(PartialEq, Eq)]
 551pub enum ThreadStatus {
 552    Idle,
 553    WaitingForToolConfirmation,
 554    Generating,
 555}
 556
 557#[derive(Debug, Clone)]
 558pub enum LoadError {
 559    Unsupported {
 560        error_message: SharedString,
 561        upgrade_message: SharedString,
 562        upgrade_command: String,
 563    },
 564    Exited(i32),
 565    Other(SharedString),
 566}
 567
 568impl Display for LoadError {
 569    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 570        match self {
 571            LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
 572            LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
 573            LoadError::Other(msg) => write!(f, "{}", msg),
 574        }
 575    }
 576}
 577
 578impl Error for LoadError {}
 579
 580impl AcpThread {
 581    pub fn new(
 582        title: impl Into<SharedString>,
 583        connection: Rc<dyn AgentConnection>,
 584        project: Entity<Project>,
 585        session_id: acp::SessionId,
 586        cx: &mut Context<Self>,
 587    ) -> Self {
 588        let action_log = cx.new(|_| ActionLog::new(project.clone()));
 589
 590        Self {
 591            action_log,
 592            shared_buffers: Default::default(),
 593            entries: Default::default(),
 594            plan: Default::default(),
 595            title: title.into(),
 596            project,
 597            send_task: None,
 598            connection,
 599            session_id,
 600        }
 601    }
 602
 603    pub fn action_log(&self) -> &Entity<ActionLog> {
 604        &self.action_log
 605    }
 606
 607    pub fn project(&self) -> &Entity<Project> {
 608        &self.project
 609    }
 610
 611    pub fn title(&self) -> SharedString {
 612        self.title.clone()
 613    }
 614
 615    pub fn entries(&self) -> &[AgentThreadEntry] {
 616        &self.entries
 617    }
 618
 619    pub fn session_id(&self) -> &acp::SessionId {
 620        &self.session_id
 621    }
 622
 623    pub fn status(&self) -> ThreadStatus {
 624        if self.send_task.is_some() {
 625            if self.waiting_for_tool_confirmation() {
 626                ThreadStatus::WaitingForToolConfirmation
 627            } else {
 628                ThreadStatus::Generating
 629            }
 630        } else {
 631            ThreadStatus::Idle
 632        }
 633    }
 634
 635    pub fn has_pending_edit_tool_calls(&self) -> bool {
 636        for entry in self.entries.iter().rev() {
 637            match entry {
 638                AgentThreadEntry::UserMessage(_) => return false,
 639                AgentThreadEntry::ToolCall(
 640                    call @ ToolCall {
 641                        status:
 642                            ToolCallStatus::Allowed {
 643                                status:
 644                                    acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending,
 645                            },
 646                        ..
 647                    },
 648                ) if call.diffs().next().is_some() => {
 649                    return true;
 650                }
 651                AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
 652            }
 653        }
 654
 655        false
 656    }
 657
 658    pub fn used_tools_since_last_user_message(&self) -> bool {
 659        for entry in self.entries.iter().rev() {
 660            match entry {
 661                AgentThreadEntry::UserMessage(..) => return false,
 662                AgentThreadEntry::AssistantMessage(..) => continue,
 663                AgentThreadEntry::ToolCall(..) => return true,
 664            }
 665        }
 666
 667        false
 668    }
 669
 670    pub fn handle_session_update(
 671        &mut self,
 672        update: acp::SessionUpdate,
 673        cx: &mut Context<Self>,
 674    ) -> Result<()> {
 675        match update {
 676            acp::SessionUpdate::UserMessageChunk { content } => {
 677                self.push_user_content_block(content, cx);
 678            }
 679            acp::SessionUpdate::AgentMessageChunk { content } => {
 680                self.push_assistant_content_block(content, false, cx);
 681            }
 682            acp::SessionUpdate::AgentThoughtChunk { content } => {
 683                self.push_assistant_content_block(content, true, cx);
 684            }
 685            acp::SessionUpdate::ToolCall(tool_call) => {
 686                self.upsert_tool_call(tool_call, cx);
 687            }
 688            acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
 689                self.update_tool_call(tool_call_update, cx)?;
 690            }
 691            acp::SessionUpdate::Plan(plan) => {
 692                self.update_plan(plan, cx);
 693            }
 694        }
 695        Ok(())
 696    }
 697
 698    pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
 699        let language_registry = self.project.read(cx).languages().clone();
 700        let entries_len = self.entries.len();
 701
 702        if let Some(last_entry) = self.entries.last_mut()
 703            && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
 704        {
 705            content.append(chunk, &language_registry, cx);
 706            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
 707        } else {
 708            let content = ContentBlock::new(chunk, &language_registry, cx);
 709            self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
 710        }
 711    }
 712
 713    pub fn push_assistant_content_block(
 714        &mut self,
 715        chunk: acp::ContentBlock,
 716        is_thought: bool,
 717        cx: &mut Context<Self>,
 718    ) {
 719        let language_registry = self.project.read(cx).languages().clone();
 720        let entries_len = self.entries.len();
 721        if let Some(last_entry) = self.entries.last_mut()
 722            && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
 723        {
 724            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
 725            match (chunks.last_mut(), is_thought) {
 726                (Some(AssistantMessageChunk::Message { block }), false)
 727                | (Some(AssistantMessageChunk::Thought { block }), true) => {
 728                    block.append(chunk, &language_registry, cx)
 729                }
 730                _ => {
 731                    let block = ContentBlock::new(chunk, &language_registry, cx);
 732                    if is_thought {
 733                        chunks.push(AssistantMessageChunk::Thought { block })
 734                    } else {
 735                        chunks.push(AssistantMessageChunk::Message { block })
 736                    }
 737                }
 738            }
 739        } else {
 740            let block = ContentBlock::new(chunk, &language_registry, cx);
 741            let chunk = if is_thought {
 742                AssistantMessageChunk::Thought { block }
 743            } else {
 744                AssistantMessageChunk::Message { block }
 745            };
 746
 747            self.push_entry(
 748                AgentThreadEntry::AssistantMessage(AssistantMessage {
 749                    chunks: vec![chunk],
 750                }),
 751                cx,
 752            );
 753        }
 754    }
 755
 756    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
 757        self.entries.push(entry);
 758        cx.emit(AcpThreadEvent::NewEntry);
 759    }
 760
 761    pub fn update_tool_call(
 762        &mut self,
 763        update: impl Into<ToolCallUpdate>,
 764        cx: &mut Context<Self>,
 765    ) -> Result<()> {
 766        let update = update.into();
 767        let languages = self.project.read(cx).languages().clone();
 768
 769        let (ix, current_call) = self
 770            .tool_call_mut(update.id())
 771            .context("Tool call not found")?;
 772        match update {
 773            ToolCallUpdate::UpdateFields(update) => {
 774                current_call.update_fields(update.fields, languages, cx);
 775            }
 776            ToolCallUpdate::UpdateDiff(update) => {
 777                current_call.content.clear();
 778                current_call
 779                    .content
 780                    .push(ToolCallContent::Diff(update.diff));
 781            }
 782            ToolCallUpdate::UpdateTerminal(update) => {
 783                current_call.content.clear();
 784                current_call
 785                    .content
 786                    .push(ToolCallContent::Terminal(update.terminal));
 787            }
 788        }
 789
 790        cx.emit(AcpThreadEvent::EntryUpdated(ix));
 791
 792        Ok(())
 793    }
 794
 795    /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
 796    pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
 797        let status = ToolCallStatus::Allowed {
 798            status: tool_call.status,
 799        };
 800        self.upsert_tool_call_inner(tool_call, status, cx)
 801    }
 802
 803    pub fn upsert_tool_call_inner(
 804        &mut self,
 805        tool_call: acp::ToolCall,
 806        status: ToolCallStatus,
 807        cx: &mut Context<Self>,
 808    ) {
 809        let language_registry = self.project.read(cx).languages().clone();
 810        let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
 811
 812        let location = call.locations.last().cloned();
 813
 814        if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
 815            *current_call = call;
 816
 817            cx.emit(AcpThreadEvent::EntryUpdated(ix));
 818        } else {
 819            self.push_entry(AgentThreadEntry::ToolCall(call), cx);
 820        }
 821
 822        if let Some(location) = location {
 823            self.set_project_location(location, cx)
 824        }
 825    }
 826
 827    fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
 828        // The tool call we are looking for is typically the last one, or very close to the end.
 829        // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
 830        self.entries
 831            .iter_mut()
 832            .enumerate()
 833            .rev()
 834            .find_map(|(index, tool_call)| {
 835                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
 836                    && &tool_call.id == id
 837                {
 838                    Some((index, tool_call))
 839                } else {
 840                    None
 841                }
 842            })
 843    }
 844
 845    pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context<Self>) {
 846        self.project.update(cx, |project, cx| {
 847            let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
 848                return;
 849            };
 850            let buffer = project.open_buffer(path, cx);
 851            cx.spawn(async move |project, cx| {
 852                let buffer = buffer.await?;
 853
 854                project.update(cx, |project, cx| {
 855                    let position = if let Some(line) = location.line {
 856                        let snapshot = buffer.read(cx).snapshot();
 857                        let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
 858                        snapshot.anchor_before(point)
 859                    } else {
 860                        Anchor::MIN
 861                    };
 862
 863                    project.set_agent_location(
 864                        Some(AgentLocation {
 865                            buffer: buffer.downgrade(),
 866                            position,
 867                        }),
 868                        cx,
 869                    );
 870                })
 871            })
 872            .detach_and_log_err(cx);
 873        });
 874    }
 875
 876    pub fn request_tool_call_authorization(
 877        &mut self,
 878        tool_call: acp::ToolCall,
 879        options: Vec<acp::PermissionOption>,
 880        cx: &mut Context<Self>,
 881    ) -> oneshot::Receiver<acp::PermissionOptionId> {
 882        let (tx, rx) = oneshot::channel();
 883
 884        let status = ToolCallStatus::WaitingForConfirmation {
 885            options,
 886            respond_tx: tx,
 887        };
 888
 889        self.upsert_tool_call_inner(tool_call, status, cx);
 890        cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
 891        rx
 892    }
 893
 894    pub fn authorize_tool_call(
 895        &mut self,
 896        id: acp::ToolCallId,
 897        option_id: acp::PermissionOptionId,
 898        option_kind: acp::PermissionOptionKind,
 899        cx: &mut Context<Self>,
 900    ) {
 901        let Some((ix, call)) = self.tool_call_mut(&id) else {
 902            return;
 903        };
 904
 905        let new_status = match option_kind {
 906            acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
 907                ToolCallStatus::Rejected
 908            }
 909            acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
 910                ToolCallStatus::Allowed {
 911                    status: acp::ToolCallStatus::InProgress,
 912                }
 913            }
 914        };
 915
 916        let curr_status = mem::replace(&mut call.status, new_status);
 917
 918        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
 919            respond_tx.send(option_id).log_err();
 920        } else if cfg!(debug_assertions) {
 921            panic!("tried to authorize an already authorized tool call");
 922        }
 923
 924        cx.emit(AcpThreadEvent::EntryUpdated(ix));
 925    }
 926
 927    /// Returns true if the last turn is awaiting tool authorization
 928    pub fn waiting_for_tool_confirmation(&self) -> bool {
 929        for entry in self.entries.iter().rev() {
 930            match &entry {
 931                AgentThreadEntry::ToolCall(call) => match call.status {
 932                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
 933                    ToolCallStatus::Allowed { .. }
 934                    | ToolCallStatus::Rejected
 935                    | ToolCallStatus::Canceled => continue,
 936                },
 937                AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
 938                    // Reached the beginning of the turn
 939                    return false;
 940                }
 941            }
 942        }
 943        false
 944    }
 945
 946    pub fn plan(&self) -> &Plan {
 947        &self.plan
 948    }
 949
 950    pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
 951        let new_entries_len = request.entries.len();
 952        let mut new_entries = request.entries.into_iter();
 953
 954        // Reuse existing markdown to prevent flickering
 955        for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
 956            let PlanEntry {
 957                content,
 958                priority,
 959                status,
 960            } = old;
 961            content.update(cx, |old, cx| {
 962                old.replace(new.content, cx);
 963            });
 964            *priority = new.priority;
 965            *status = new.status;
 966        }
 967        for new in new_entries {
 968            self.plan.entries.push(PlanEntry::from_acp(new, cx))
 969        }
 970        self.plan.entries.truncate(new_entries_len);
 971
 972        cx.notify();
 973    }
 974
 975    fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
 976        self.plan
 977            .entries
 978            .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
 979        cx.notify();
 980    }
 981
 982    #[cfg(any(test, feature = "test-support"))]
 983    pub fn send_raw(
 984        &mut self,
 985        message: &str,
 986        cx: &mut Context<Self>,
 987    ) -> BoxFuture<'static, Result<()>> {
 988        self.send(
 989            vec![acp::ContentBlock::Text(acp::TextContent {
 990                text: message.to_string(),
 991                annotations: None,
 992            })],
 993            cx,
 994        )
 995    }
 996
 997    pub fn send(
 998        &mut self,
 999        message: Vec<acp::ContentBlock>,
1000        cx: &mut Context<Self>,
1001    ) -> BoxFuture<'static, Result<()>> {
1002        let block = ContentBlock::new_combined(
1003            message.clone(),
1004            self.project.read(cx).languages().clone(),
1005            cx,
1006        );
1007        self.push_entry(
1008            AgentThreadEntry::UserMessage(UserMessage { content: block }),
1009            cx,
1010        );
1011        self.clear_completed_plan_entries(cx);
1012
1013        let (tx, rx) = oneshot::channel();
1014        let cancel_task = self.cancel(cx);
1015
1016        self.send_task = Some(cx.spawn(async move |this, cx| {
1017            async {
1018                cancel_task.await;
1019
1020                let result = this
1021                    .update(cx, |this, cx| {
1022                        this.connection.prompt(
1023                            acp::PromptRequest {
1024                                prompt: message,
1025                                session_id: this.session_id.clone(),
1026                            },
1027                            cx,
1028                        )
1029                    })?
1030                    .await;
1031
1032                tx.send(result).log_err();
1033
1034                anyhow::Ok(())
1035            }
1036            .await
1037            .log_err();
1038        }));
1039
1040        cx.spawn(async move |this, cx| match rx.await {
1041            Ok(Err(e)) => {
1042                this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error))
1043                    .log_err();
1044                Err(e)?
1045            }
1046            result => {
1047                let cancelled = matches!(
1048                    result,
1049                    Ok(Ok(acp::PromptResponse {
1050                        stop_reason: acp::StopReason::Cancelled
1051                    }))
1052                );
1053
1054                // We only take the task if the current prompt wasn't cancelled.
1055                //
1056                // This prompt may have been cancelled because another one was sent
1057                // while it was still generating. In these cases, dropping `send_task`
1058                // would cause the next generation to be cancelled.
1059                if !cancelled {
1060                    this.update(cx, |this, _cx| this.send_task.take()).ok();
1061                }
1062
1063                this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
1064                    .log_err();
1065                Ok(())
1066            }
1067        })
1068        .boxed()
1069    }
1070
1071    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1072        let Some(send_task) = self.send_task.take() else {
1073            return Task::ready(());
1074        };
1075
1076        for entry in self.entries.iter_mut() {
1077            if let AgentThreadEntry::ToolCall(call) = entry {
1078                let cancel = matches!(
1079                    call.status,
1080                    ToolCallStatus::WaitingForConfirmation { .. }
1081                        | ToolCallStatus::Allowed {
1082                            status: acp::ToolCallStatus::InProgress
1083                        }
1084                );
1085
1086                if cancel {
1087                    call.status = ToolCallStatus::Canceled;
1088                }
1089            }
1090        }
1091
1092        self.connection.cancel(&self.session_id, cx);
1093
1094        // Wait for the send task to complete
1095        cx.foreground_executor().spawn(send_task)
1096    }
1097
1098    pub fn read_text_file(
1099        &self,
1100        path: PathBuf,
1101        line: Option<u32>,
1102        limit: Option<u32>,
1103        reuse_shared_snapshot: bool,
1104        cx: &mut Context<Self>,
1105    ) -> Task<Result<String>> {
1106        let project = self.project.clone();
1107        let action_log = self.action_log.clone();
1108        cx.spawn(async move |this, cx| {
1109            let load = project.update(cx, |project, cx| {
1110                let path = project
1111                    .project_path_for_absolute_path(&path, cx)
1112                    .context("invalid path")?;
1113                anyhow::Ok(project.open_buffer(path, cx))
1114            });
1115            let buffer = load??.await?;
1116
1117            let snapshot = if reuse_shared_snapshot {
1118                this.read_with(cx, |this, _| {
1119                    this.shared_buffers.get(&buffer.clone()).cloned()
1120                })
1121                .log_err()
1122                .flatten()
1123            } else {
1124                None
1125            };
1126
1127            let snapshot = if let Some(snapshot) = snapshot {
1128                snapshot
1129            } else {
1130                action_log.update(cx, |action_log, cx| {
1131                    action_log.buffer_read(buffer.clone(), cx);
1132                })?;
1133                project.update(cx, |project, cx| {
1134                    let position = buffer
1135                        .read(cx)
1136                        .snapshot()
1137                        .anchor_before(Point::new(line.unwrap_or_default(), 0));
1138                    project.set_agent_location(
1139                        Some(AgentLocation {
1140                            buffer: buffer.downgrade(),
1141                            position,
1142                        }),
1143                        cx,
1144                    );
1145                })?;
1146
1147                buffer.update(cx, |buffer, _| buffer.snapshot())?
1148            };
1149
1150            this.update(cx, |this, _| {
1151                let text = snapshot.text();
1152                this.shared_buffers.insert(buffer.clone(), snapshot);
1153                if line.is_none() && limit.is_none() {
1154                    return Ok(text);
1155                }
1156                let limit = limit.unwrap_or(u32::MAX) as usize;
1157                let Some(line) = line else {
1158                    return Ok(text.lines().take(limit).collect::<String>());
1159                };
1160
1161                let count = text.lines().count();
1162                if count < line as usize {
1163                    anyhow::bail!("There are only {} lines", count);
1164                }
1165                Ok(text
1166                    .lines()
1167                    .skip(line as usize + 1)
1168                    .take(limit)
1169                    .collect::<String>())
1170            })?
1171        })
1172    }
1173
1174    pub fn write_text_file(
1175        &self,
1176        path: PathBuf,
1177        content: String,
1178        cx: &mut Context<Self>,
1179    ) -> Task<Result<()>> {
1180        let project = self.project.clone();
1181        let action_log = self.action_log.clone();
1182        cx.spawn(async move |this, cx| {
1183            let load = project.update(cx, |project, cx| {
1184                let path = project
1185                    .project_path_for_absolute_path(&path, cx)
1186                    .context("invalid path")?;
1187                anyhow::Ok(project.open_buffer(path, cx))
1188            });
1189            let buffer = load??.await?;
1190            let snapshot = this.update(cx, |this, cx| {
1191                this.shared_buffers
1192                    .get(&buffer)
1193                    .cloned()
1194                    .unwrap_or_else(|| buffer.read(cx).snapshot())
1195            })?;
1196            let edits = cx
1197                .background_executor()
1198                .spawn(async move {
1199                    let old_text = snapshot.text();
1200                    text_diff(old_text.as_str(), &content)
1201                        .into_iter()
1202                        .map(|(range, replacement)| {
1203                            (
1204                                snapshot.anchor_after(range.start)
1205                                    ..snapshot.anchor_before(range.end),
1206                                replacement,
1207                            )
1208                        })
1209                        .collect::<Vec<_>>()
1210                })
1211                .await;
1212            cx.update(|cx| {
1213                project.update(cx, |project, cx| {
1214                    project.set_agent_location(
1215                        Some(AgentLocation {
1216                            buffer: buffer.downgrade(),
1217                            position: edits
1218                                .last()
1219                                .map(|(range, _)| range.end)
1220                                .unwrap_or(Anchor::MIN),
1221                        }),
1222                        cx,
1223                    );
1224                });
1225
1226                action_log.update(cx, |action_log, cx| {
1227                    action_log.buffer_read(buffer.clone(), cx);
1228                });
1229                buffer.update(cx, |buffer, cx| {
1230                    buffer.edit(edits, None, cx);
1231                });
1232                action_log.update(cx, |action_log, cx| {
1233                    action_log.buffer_edited(buffer.clone(), cx);
1234                });
1235            })?;
1236            project
1237                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1238                .await
1239        })
1240    }
1241
1242    pub fn to_markdown(&self, cx: &App) -> String {
1243        self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1244    }
1245
1246    pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
1247        cx.emit(AcpThreadEvent::ServerExited(status));
1248    }
1249}
1250
1251fn markdown_for_raw_output(
1252    raw_output: &serde_json::Value,
1253    language_registry: &Arc<LanguageRegistry>,
1254    cx: &mut App,
1255) -> Option<Entity<Markdown>> {
1256    match raw_output {
1257        serde_json::Value::Null => None,
1258        serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1259            Markdown::new(
1260                value.to_string().into(),
1261                Some(language_registry.clone()),
1262                None,
1263                cx,
1264            )
1265        })),
1266        serde_json::Value::Number(value) => Some(cx.new(|cx| {
1267            Markdown::new(
1268                value.to_string().into(),
1269                Some(language_registry.clone()),
1270                None,
1271                cx,
1272            )
1273        })),
1274        serde_json::Value::String(value) => Some(cx.new(|cx| {
1275            Markdown::new(
1276                value.clone().into(),
1277                Some(language_registry.clone()),
1278                None,
1279                cx,
1280            )
1281        })),
1282        value => Some(cx.new(|cx| {
1283            Markdown::new(
1284                format!("```json\n{}\n```", value).into(),
1285                Some(language_registry.clone()),
1286                None,
1287                cx,
1288            )
1289        })),
1290    }
1291}
1292
1293#[cfg(test)]
1294mod tests {
1295    use super::*;
1296    use anyhow::anyhow;
1297    use futures::{channel::mpsc, future::LocalBoxFuture, select};
1298    use gpui::{AsyncApp, TestAppContext, WeakEntity};
1299    use indoc::indoc;
1300    use project::FakeFs;
1301    use rand::Rng as _;
1302    use serde_json::json;
1303    use settings::SettingsStore;
1304    use smol::stream::StreamExt as _;
1305    use std::{cell::RefCell, path::Path, rc::Rc, time::Duration};
1306
1307    use util::path;
1308
1309    fn init_test(cx: &mut TestAppContext) {
1310        env_logger::try_init().ok();
1311        cx.update(|cx| {
1312            let settings_store = SettingsStore::test(cx);
1313            cx.set_global(settings_store);
1314            Project::init_settings(cx);
1315            language::init(cx);
1316        });
1317    }
1318
    // Checks that `push_user_content_block` starts a user message, appends to it while it is
    // the latest entry, and starts a fresh user message once an assistant message follows.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Hello, ".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "world!".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Assistant response".to_string(),
                }),
                false,
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "New user message".to_string(),
                }),
                cx,
            );
        });

        // Expected entries: user message, assistant message, new user message.
        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }
1406
    // Checks that consecutive `AgentThoughtChunk` updates are merged into a single
    // `<thinking>` block in the thread's markdown rendering.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        // The fake agent answers every prompt with two thought chunks, then ends the turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "Thinking ".into(),
                                },
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "hard!".into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
1473
    // Checks that a user edit made between the agent's read and its write of the same file
    // survives the agent's write, both in the open buffer and on disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test once the agent has read the file, so the user edit below lands
        // after the read and before the write.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        // Simulated user edit, made while the agent is between its read and its write.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
1552
    // Checks that a tool call marked `Canceled` by `AcpThread::cancel` can still be moved to
    // `Completed` by a later `ToolCallUpdate` from the agent.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        // The fake agent starts an in-progress tool call with `id`, then ends the turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // Entry 0 is the user message; entry 1 is the tool call.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::InProgress,
                        ..
                    },
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // A completion reported after cancellation must still be applied.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::Completed,
                        ..
                    },
                    ..
                })
            ));
        });
    }
1662
    // Checks that an edit tool call which arrives already `Completed` leaves no pending edit
    // tool calls on the thread once the turn finishes.
    #[gpui::test]
    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/test"), json!({})).await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

        // The fake agent reports a single edit tool call that is already completed.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |_, thread, mut cx| {
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: acp::ToolCallId("test".into()),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Edit,
                                    status: acp::ToolCallStatus::Completed,
                                    content: vec![acp::ToolCallContent::Diff {
                                        diff: acp::Diff {
                                            path: "/test/test.txt".into(),
                                            old_text: None,
                                            new_text: "foo".into(),
                                        },
                                    }],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = connection
            .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
            .await
            .unwrap();
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
            .await
            .unwrap();

        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
    }
1715
1716    async fn run_until_first_tool_call(
1717        thread: &Entity<AcpThread>,
1718        cx: &mut TestAppContext,
1719    ) -> usize {
1720        let (mut tx, mut rx) = mpsc::channel::<usize>(1);
1721
1722        let subscription = cx.update(|cx| {
1723            cx.subscribe(thread, move |thread, _, cx| {
1724                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
1725                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
1726                        return tx.try_send(ix).unwrap();
1727                    }
1728                }
1729            })
1730        });
1731
1732        select! {
1733            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
1734                panic!("Timeout waiting for tool call")
1735            }
1736            ix = rx.next().fuse() => {
1737                drop(subscription);
1738                ix.unwrap()
1739            }
1740        }
1741    }
1742
1743    #[derive(Clone, Default)]
1744    struct FakeAgentConnection {
1745        auth_methods: Vec<acp::AuthMethod>,
1746        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
1747        on_user_message: Option<
1748            Rc<
1749                dyn Fn(
1750                        acp::PromptRequest,
1751                        WeakEntity<AcpThread>,
1752                        AsyncApp,
1753                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1754                    + 'static,
1755            >,
1756        >,
1757    }
1758
1759    impl FakeAgentConnection {
1760        fn new() -> Self {
1761            Self {
1762                auth_methods: Vec::new(),
1763                on_user_message: None,
1764                sessions: Arc::default(),
1765            }
1766        }
1767
1768        #[expect(unused)]
1769        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
1770            self.auth_methods = auth_methods;
1771            self
1772        }
1773
1774        fn on_user_message(
1775            mut self,
1776            handler: impl Fn(
1777                acp::PromptRequest,
1778                WeakEntity<AcpThread>,
1779                AsyncApp,
1780            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1781            + 'static,
1782        ) -> Self {
1783            self.on_user_message.replace(Rc::new(handler));
1784            self
1785        }
1786    }
1787
1788    impl AgentConnection for FakeAgentConnection {
1789        fn auth_methods(&self) -> &[acp::AuthMethod] {
1790            &self.auth_methods
1791        }
1792
1793        fn new_thread(
1794            self: Rc<Self>,
1795            project: Entity<Project>,
1796            _cwd: &Path,
1797            cx: &mut gpui::AsyncApp,
1798        ) -> Task<gpui::Result<Entity<AcpThread>>> {
1799            let session_id = acp::SessionId(
1800                rand::thread_rng()
1801                    .sample_iter(&rand::distributions::Alphanumeric)
1802                    .take(7)
1803                    .map(char::from)
1804                    .collect::<String>()
1805                    .into(),
1806            );
1807            let thread = cx
1808                .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
1809                .unwrap();
1810            self.sessions.lock().insert(session_id, thread.downgrade());
1811            Task::ready(Ok(thread))
1812        }
1813
1814        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
1815            if self.auth_methods().iter().any(|m| m.id == method) {
1816                Task::ready(Ok(()))
1817            } else {
1818                Task::ready(Err(anyhow!("Invalid Auth Method")))
1819            }
1820        }
1821
1822        fn prompt(
1823            &self,
1824            params: acp::PromptRequest,
1825            cx: &mut App,
1826        ) -> Task<gpui::Result<acp::PromptResponse>> {
1827            let sessions = self.sessions.lock();
1828            let thread = sessions.get(&params.session_id).unwrap();
1829            if let Some(handler) = &self.on_user_message {
1830                let handler = handler.clone();
1831                let thread = thread.clone();
1832                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
1833            } else {
1834                Task::ready(Ok(acp::PromptResponse {
1835                    stop_reason: acp::StopReason::EndTurn,
1836                }))
1837            }
1838        }
1839
1840        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1841            let sessions = self.sessions.lock();
1842            let thread = sessions.get(&session_id).unwrap().clone();
1843
1844            cx.spawn(async move |cx| {
1845                thread
1846                    .update(cx, |thread, cx| thread.cancel(cx))
1847                    .unwrap()
1848                    .await
1849            })
1850            .detach();
1851        }
1852    }
1853}