acp_thread.rs

   1mod connection;
   2mod diff;
   3mod mention;
   4mod terminal;
   5
   6use agent_settings::AgentSettings;
   7use collections::HashSet;
   8pub use connection::*;
   9pub use diff::*;
  10use futures::future::Shared;
  11use language::language_settings::FormatOnSave;
  12pub use mention::*;
  13use project::lsp_store::{FormatTrigger, LspFormatTarget};
  14use serde::{Deserialize, Serialize};
  15use settings::Settings as _;
  16pub use terminal::*;
  17
  18use action_log::ActionLog;
  19use agent_client_protocol::{self as acp};
  20use anyhow::{Context as _, Result, anyhow};
  21use editor::Bias;
  22use futures::{FutureExt, channel::oneshot, future::BoxFuture};
  23use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  24use itertools::Itertools;
  25use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
  26use markdown::Markdown;
  27use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
  28use std::collections::HashMap;
  29use std::error::Error;
  30use std::fmt::{Formatter, Write};
  31use std::ops::Range;
  32use std::process::ExitStatus;
  33use std::rc::Rc;
  34use std::time::{Duration, Instant};
  35use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
  36use ui::App;
  37use util::{ResultExt, get_system_shell};
  38use uuid::Uuid;
  39
/// A user-authored entry in the thread.
#[derive(Debug)]
pub struct UserMessage {
    /// Identifier of the message, when one is available.
    pub id: Option<UserMessageId>,
    /// Displayable content; this is what `to_markdown` renders.
    pub content: ContentBlock,
    /// Protocol-level content blocks stored with the message.
    /// NOTE(review): presumably the blocks `content` was built from — confirm
    /// against the code that constructs this struct.
    pub chunks: Vec<acp::ContentBlock>,
    /// Checkpoint attached to this message, if any.
    pub checkpoint: Option<Checkpoint>,
}
  47
/// A git checkpoint attached to a [`UserMessage`].
#[derive(Debug)]
pub struct Checkpoint {
    /// The underlying checkpoint value from the project's git store module.
    git_checkpoint: GitStoreCheckpoint,
    /// When `true`, the owning user message is headed "## User (checkpoint)"
    /// instead of "## User" in the markdown rendering.
    pub show: bool,
}
  53
  54impl UserMessage {
  55    fn to_markdown(&self, cx: &App) -> String {
  56        let mut markdown = String::new();
  57        if self
  58            .checkpoint
  59            .as_ref()
  60            .is_some_and(|checkpoint| checkpoint.show)
  61        {
  62            writeln!(markdown, "## User (checkpoint)").unwrap();
  63        } else {
  64            writeln!(markdown, "## User").unwrap();
  65        }
  66        writeln!(markdown).unwrap();
  67        writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
  68        writeln!(markdown).unwrap();
  69        markdown
  70    }
  71}
  72
/// An assistant-authored entry, stored as an ordered list of chunks.
#[derive(Debug, PartialEq)]
pub struct AssistantMessage {
    /// Message and thought chunks, in the order they are rendered.
    pub chunks: Vec<AssistantMessageChunk>,
}
  77
  78impl AssistantMessage {
  79    pub fn to_markdown(&self, cx: &App) -> String {
  80        format!(
  81            "## Assistant\n\n{}\n\n",
  82            self.chunks
  83                .iter()
  84                .map(|chunk| chunk.to_markdown(cx))
  85                .join("\n\n")
  86        )
  87    }
  88}
  89
/// One piece of an [`AssistantMessage`].
#[derive(Debug, PartialEq)]
pub enum AssistantMessageChunk {
    /// Regular message content; rendered as-is in the markdown output.
    Message { block: ContentBlock },
    /// Thought content; wrapped in `<thinking>` tags in the markdown output.
    Thought { block: ContentBlock },
}
  95
  96impl AssistantMessageChunk {
  97    pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
  98        Self::Message {
  99            block: ContentBlock::new(chunk.into(), language_registry, cx),
 100        }
 101    }
 102
 103    fn to_markdown(&self, cx: &App) -> String {
 104        match self {
 105            Self::Message { block } => block.to_markdown(cx).to_string(),
 106            Self::Thought { block } => {
 107                format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
 108            }
 109        }
 110    }
 111}
 112
/// A single entry in a thread: a user message, an assistant message, or a
/// tool call.
#[derive(Debug)]
pub enum AgentThreadEntry {
    UserMessage(UserMessage),
    AssistantMessage(AssistantMessage),
    ToolCall(ToolCall),
}
 119
 120impl AgentThreadEntry {
 121    pub fn to_markdown(&self, cx: &App) -> String {
 122        match self {
 123            Self::UserMessage(message) => message.to_markdown(cx),
 124            Self::AssistantMessage(message) => message.to_markdown(cx),
 125            Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
 126        }
 127    }
 128
 129    pub fn user_message(&self) -> Option<&UserMessage> {
 130        if let AgentThreadEntry::UserMessage(message) = self {
 131            Some(message)
 132        } else {
 133            None
 134        }
 135    }
 136
 137    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 138        if let AgentThreadEntry::ToolCall(call) = self {
 139            itertools::Either::Left(call.diffs())
 140        } else {
 141            itertools::Either::Right(std::iter::empty())
 142        }
 143    }
 144
 145    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 146        if let AgentThreadEntry::ToolCall(call) = self {
 147            itertools::Either::Left(call.terminals())
 148        } else {
 149            itertools::Either::Right(std::iter::empty())
 150        }
 151    }
 152
 153    pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
 154        if let AgentThreadEntry::ToolCall(ToolCall {
 155            locations,
 156            resolved_locations,
 157            ..
 158        }) = self
 159        {
 160            Some((
 161                locations.get(ix)?.clone(),
 162                resolved_locations.get(ix)?.clone()?,
 163            ))
 164        } else {
 165            None
 166        }
 167    }
 168}
 169
/// A tool call entry in the thread.
#[derive(Debug)]
pub struct ToolCall {
    /// Protocol identifier of the tool call.
    pub id: acp::ToolCallId,
    /// Markdown label built from the first line of the tool call's title.
    pub label: Entity<Markdown>,
    /// Kind of tool, as reported over the protocol.
    pub kind: acp::ToolKind,
    /// Content produced by the call: content blocks, diffs, and terminals.
    pub content: Vec<ToolCallContent>,
    /// Current lifecycle state of the call.
    pub status: ToolCallStatus,
    /// Locations reported for the call.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Resolved counterparts of `locations`, looked up by the same index in
    /// `AgentThreadEntry::location`; a `None` slot means that location is not
    /// resolved.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    /// Raw input of the call, when provided.
    pub raw_input: Option<serde_json::Value>,
    /// Raw output of the call, when provided.
    pub raw_output: Option<serde_json::Value>,
}
 182
 183impl ToolCall {
 184    fn from_acp(
 185        tool_call: acp::ToolCall,
 186        status: ToolCallStatus,
 187        language_registry: Arc<LanguageRegistry>,
 188        terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
 189        cx: &mut App,
 190    ) -> Result<Self> {
 191        let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
 192            first_line.to_owned() + ""
 193        } else {
 194            tool_call.title
 195        };
 196        let mut content = Vec::with_capacity(tool_call.content.len());
 197        for item in tool_call.content {
 198            content.push(ToolCallContent::from_acp(
 199                item,
 200                language_registry.clone(),
 201                terminals,
 202                cx,
 203            )?);
 204        }
 205
 206        let result = Self {
 207            id: tool_call.id,
 208            label: cx
 209                .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
 210            kind: tool_call.kind,
 211            content,
 212            locations: tool_call.locations,
 213            resolved_locations: Vec::default(),
 214            status,
 215            raw_input: tool_call.raw_input,
 216            raw_output: tool_call.raw_output,
 217        };
 218        Ok(result)
 219    }
 220
 221    fn update_fields(
 222        &mut self,
 223        fields: acp::ToolCallUpdateFields,
 224        language_registry: Arc<LanguageRegistry>,
 225        terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
 226        cx: &mut App,
 227    ) -> Result<()> {
 228        let acp::ToolCallUpdateFields {
 229            kind,
 230            status,
 231            title,
 232            content,
 233            locations,
 234            raw_input,
 235            raw_output,
 236        } = fields;
 237
 238        if let Some(kind) = kind {
 239            self.kind = kind;
 240        }
 241
 242        if let Some(status) = status {
 243            self.status = status.into();
 244        }
 245
 246        if let Some(title) = title {
 247            self.label.update(cx, |label, cx| {
 248                if let Some((first_line, _)) = title.split_once("\n") {
 249                    label.replace(first_line.to_owned() + "", cx)
 250                } else {
 251                    label.replace(title, cx);
 252                }
 253            });
 254        }
 255
 256        if let Some(content) = content {
 257            let new_content_len = content.len();
 258            let mut content = content.into_iter();
 259
 260            // Reuse existing content if we can
 261            for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
 262                old.update_from_acp(new, language_registry.clone(), terminals, cx)?;
 263            }
 264            for new in content {
 265                self.content.push(ToolCallContent::from_acp(
 266                    new,
 267                    language_registry.clone(),
 268                    terminals,
 269                    cx,
 270                )?)
 271            }
 272            self.content.truncate(new_content_len);
 273        }
 274
 275        if let Some(locations) = locations {
 276            self.locations = locations;
 277        }
 278
 279        if let Some(raw_input) = raw_input {
 280            self.raw_input = Some(raw_input);
 281        }
 282
 283        if let Some(raw_output) = raw_output {
 284            if self.content.is_empty()
 285                && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
 286            {
 287                self.content
 288                    .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
 289                        markdown,
 290                    }));
 291            }
 292            self.raw_output = Some(raw_output);
 293        }
 294        Ok(())
 295    }
 296
 297    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 298        self.content.iter().filter_map(|content| match content {
 299            ToolCallContent::Diff(diff) => Some(diff),
 300            ToolCallContent::ContentBlock(_) => None,
 301            ToolCallContent::Terminal(_) => None,
 302        })
 303    }
 304
 305    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 306        self.content.iter().filter_map(|content| match content {
 307            ToolCallContent::Terminal(terminal) => Some(terminal),
 308            ToolCallContent::ContentBlock(_) => None,
 309            ToolCallContent::Diff(_) => None,
 310        })
 311    }
 312
 313    fn to_markdown(&self, cx: &App) -> String {
 314        let mut markdown = format!(
 315            "**Tool Call: {}**\nStatus: {}\n\n",
 316            self.label.read(cx).source(),
 317            self.status
 318        );
 319        for content in &self.content {
 320            markdown.push_str(content.to_markdown(cx).as_str());
 321            markdown.push_str("\n\n");
 322        }
 323        markdown
 324    }
 325
 326    async fn resolve_location(
 327        location: acp::ToolCallLocation,
 328        project: WeakEntity<Project>,
 329        cx: &mut AsyncApp,
 330    ) -> Option<AgentLocation> {
 331        let buffer = project
 332            .update(cx, |project, cx| {
 333                project
 334                    .project_path_for_absolute_path(&location.path, cx)
 335                    .map(|path| project.open_buffer(path, cx))
 336            })
 337            .ok()??;
 338        let buffer = buffer.await.log_err()?;
 339        let position = buffer
 340            .update(cx, |buffer, _| {
 341                if let Some(row) = location.line {
 342                    let snapshot = buffer.snapshot();
 343                    let column = snapshot.indent_size_for_line(row).len;
 344                    let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
 345                    snapshot.anchor_before(point)
 346                } else {
 347                    Anchor::MIN
 348                }
 349            })
 350            .ok()?;
 351
 352        Some(AgentLocation {
 353            buffer: buffer.downgrade(),
 354            position,
 355        })
 356    }
 357
 358    fn resolve_locations(
 359        &self,
 360        project: Entity<Project>,
 361        cx: &mut App,
 362    ) -> Task<Vec<Option<AgentLocation>>> {
 363        let locations = self.locations.clone();
 364        project.update(cx, |_, cx| {
 365            cx.spawn(async move |project, cx| {
 366                let mut new_locations = Vec::new();
 367                for location in locations {
 368                    new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
 369                }
 370                new_locations
 371            })
 372        })
 373    }
 374}
 375
/// Lifecycle state of a [`ToolCall`].
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The tool call hasn't started running yet, but we start showing it to
    /// the user.
    Pending,
    /// The tool call is waiting for confirmation from the user.
    WaitingForConfirmation {
        /// Permission options attached to the confirmation request.
        options: Vec<acp::PermissionOption>,
        /// One-shot sender carrying a permission option id.
        /// NOTE(review): presumably the id of the option the user picked —
        /// confirm against the code that sends on this channel.
        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
    },
    /// The tool call is currently running.
    InProgress,
    /// The tool call completed successfully.
    Completed,
    /// The tool call failed.
    Failed,
    /// The user rejected the tool call.
    Rejected,
    /// The user canceled generation so the tool call was canceled.
    Canceled,
}
 397
 398impl From<acp::ToolCallStatus> for ToolCallStatus {
 399    fn from(status: acp::ToolCallStatus) -> Self {
 400        match status {
 401            acp::ToolCallStatus::Pending => Self::Pending,
 402            acp::ToolCallStatus::InProgress => Self::InProgress,
 403            acp::ToolCallStatus::Completed => Self::Completed,
 404            acp::ToolCallStatus::Failed => Self::Failed,
 405        }
 406    }
 407}
 408
 409impl Display for ToolCallStatus {
 410    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 411        write!(
 412            f,
 413            "{}",
 414            match self {
 415                ToolCallStatus::Pending => "Pending",
 416                ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
 417                ToolCallStatus::InProgress => "In Progress",
 418                ToolCallStatus::Completed => "Completed",
 419                ToolCallStatus::Failed => "Failed",
 420                ToolCallStatus::Rejected => "Rejected",
 421                ToolCallStatus::Canceled => "Canceled",
 422            }
 423        )
 424    }
 425}
 426
/// Displayable content accumulated from protocol content blocks.
#[derive(Debug, PartialEq, Clone)]
pub enum ContentBlock {
    /// Nothing has been appended yet.
    Empty,
    /// Markdown text; further appended blocks extend this markdown.
    Markdown { markdown: Entity<Markdown> },
    /// A lone resource link; kept as-is until something else is appended, at
    /// which point the block becomes `Markdown`.
    ResourceLink { resource_link: acp::ResourceLink },
}
 433
 434impl ContentBlock {
 435    pub fn new(
 436        block: acp::ContentBlock,
 437        language_registry: &Arc<LanguageRegistry>,
 438        cx: &mut App,
 439    ) -> Self {
 440        let mut this = Self::Empty;
 441        this.append(block, language_registry, cx);
 442        this
 443    }
 444
 445    pub fn new_combined(
 446        blocks: impl IntoIterator<Item = acp::ContentBlock>,
 447        language_registry: Arc<LanguageRegistry>,
 448        cx: &mut App,
 449    ) -> Self {
 450        let mut this = Self::Empty;
 451        for block in blocks {
 452            this.append(block, &language_registry, cx);
 453        }
 454        this
 455    }
 456
 457    pub fn append(
 458        &mut self,
 459        block: acp::ContentBlock,
 460        language_registry: &Arc<LanguageRegistry>,
 461        cx: &mut App,
 462    ) {
 463        if matches!(self, ContentBlock::Empty)
 464            && let acp::ContentBlock::ResourceLink(resource_link) = block
 465        {
 466            *self = ContentBlock::ResourceLink { resource_link };
 467            return;
 468        }
 469
 470        let new_content = self.block_string_contents(block);
 471
 472        match self {
 473            ContentBlock::Empty => {
 474                *self = Self::create_markdown_block(new_content, language_registry, cx);
 475            }
 476            ContentBlock::Markdown { markdown } => {
 477                markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
 478            }
 479            ContentBlock::ResourceLink { resource_link } => {
 480                let existing_content = Self::resource_link_md(&resource_link.uri);
 481                let combined = format!("{}\n{}", existing_content, new_content);
 482
 483                *self = Self::create_markdown_block(combined, language_registry, cx);
 484            }
 485        }
 486    }
 487
 488    fn create_markdown_block(
 489        content: String,
 490        language_registry: &Arc<LanguageRegistry>,
 491        cx: &mut App,
 492    ) -> ContentBlock {
 493        ContentBlock::Markdown {
 494            markdown: cx
 495                .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
 496        }
 497    }
 498
 499    fn block_string_contents(&self, block: acp::ContentBlock) -> String {
 500        match block {
 501            acp::ContentBlock::Text(text_content) => text_content.text,
 502            acp::ContentBlock::ResourceLink(resource_link) => {
 503                Self::resource_link_md(&resource_link.uri)
 504            }
 505            acp::ContentBlock::Resource(acp::EmbeddedResource {
 506                resource:
 507                    acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
 508                        uri,
 509                        ..
 510                    }),
 511                ..
 512            }) => Self::resource_link_md(&uri),
 513            acp::ContentBlock::Image(image) => Self::image_md(&image),
 514            acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
 515        }
 516    }
 517
 518    fn resource_link_md(uri: &str) -> String {
 519        if let Some(uri) = MentionUri::parse(uri).log_err() {
 520            uri.as_link().to_string()
 521        } else {
 522            uri.to_string()
 523        }
 524    }
 525
 526    fn image_md(_image: &acp::ImageContent) -> String {
 527        "`Image`".into()
 528    }
 529
 530    pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
 531        match self {
 532            ContentBlock::Empty => "",
 533            ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
 534            ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
 535        }
 536    }
 537
 538    pub fn markdown(&self) -> Option<&Entity<Markdown>> {
 539        match self {
 540            ContentBlock::Empty => None,
 541            ContentBlock::Markdown { markdown } => Some(markdown),
 542            ContentBlock::ResourceLink { .. } => None,
 543        }
 544    }
 545
 546    pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
 547        match self {
 548            ContentBlock::ResourceLink { resource_link } => Some(resource_link),
 549            _ => None,
 550        }
 551    }
 552}
 553
/// One item of a tool call's content.
#[derive(Debug)]
pub enum ToolCallContent {
    /// Plain content (markdown or a resource link).
    ContentBlock(ContentBlock),
    /// A diff entity.
    Diff(Entity<Diff>),
    /// A terminal entity.
    Terminal(Entity<Terminal>),
}
 560
 561impl ToolCallContent {
 562    pub fn from_acp(
 563        content: acp::ToolCallContent,
 564        language_registry: Arc<LanguageRegistry>,
 565        terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
 566        cx: &mut App,
 567    ) -> Result<Self> {
 568        match content {
 569            acp::ToolCallContent::Content { content } => Ok(Self::ContentBlock(ContentBlock::new(
 570                content,
 571                &language_registry,
 572                cx,
 573            ))),
 574            acp::ToolCallContent::Diff { diff } => Ok(Self::Diff(cx.new(|cx| {
 575                Diff::finalized(
 576                    diff.path,
 577                    diff.old_text,
 578                    diff.new_text,
 579                    language_registry,
 580                    cx,
 581                )
 582            }))),
 583            acp::ToolCallContent::Terminal { terminal_id } => terminals
 584                .get(&terminal_id)
 585                .cloned()
 586                .map(Self::Terminal)
 587                .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
 588        }
 589    }
 590
 591    pub fn update_from_acp(
 592        &mut self,
 593        new: acp::ToolCallContent,
 594        language_registry: Arc<LanguageRegistry>,
 595        terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
 596        cx: &mut App,
 597    ) -> Result<()> {
 598        let needs_update = match (&self, &new) {
 599            (Self::Diff(old_diff), acp::ToolCallContent::Diff { diff: new_diff }) => {
 600                old_diff.read(cx).needs_update(
 601                    new_diff.old_text.as_deref().unwrap_or(""),
 602                    &new_diff.new_text,
 603                    cx,
 604                )
 605            }
 606            _ => true,
 607        };
 608
 609        if needs_update {
 610            *self = Self::from_acp(new, language_registry, terminals, cx)?;
 611        }
 612        Ok(())
 613    }
 614
 615    pub fn to_markdown(&self, cx: &App) -> String {
 616        match self {
 617            Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
 618            Self::Diff(diff) => diff.read(cx).to_markdown(cx),
 619            Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
 620        }
 621    }
 622}
 623
/// An update targeting an existing tool call.
#[derive(Debug, PartialEq)]
pub enum ToolCallUpdate {
    /// A protocol-level update of the tool call's fields.
    UpdateFields(acp::ToolCallUpdate),
    /// An update carrying a diff entity.
    UpdateDiff(ToolCallUpdateDiff),
    /// An update carrying a terminal entity.
    UpdateTerminal(ToolCallUpdateTerminal),
}
 630
 631impl ToolCallUpdate {
 632    fn id(&self) -> &acp::ToolCallId {
 633        match self {
 634            Self::UpdateFields(update) => &update.id,
 635            Self::UpdateDiff(diff) => &diff.id,
 636            Self::UpdateTerminal(terminal) => &terminal.id,
 637        }
 638    }
 639}
 640
 641impl From<acp::ToolCallUpdate> for ToolCallUpdate {
 642    fn from(update: acp::ToolCallUpdate) -> Self {
 643        Self::UpdateFields(update)
 644    }
 645}
 646
 647impl From<ToolCallUpdateDiff> for ToolCallUpdate {
 648    fn from(diff: ToolCallUpdateDiff) -> Self {
 649        Self::UpdateDiff(diff)
 650    }
 651}
 652
/// Payload of `ToolCallUpdate::UpdateDiff`.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateDiff {
    /// Id of the tool call the update targets.
    pub id: acp::ToolCallId,
    /// The diff entity carried by the update.
    pub diff: Entity<Diff>,
}
 658
 659impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
 660    fn from(terminal: ToolCallUpdateTerminal) -> Self {
 661        Self::UpdateTerminal(terminal)
 662    }
 663}
 664
/// Payload of `ToolCallUpdate::UpdateTerminal`.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateTerminal {
    /// Id of the tool call the update targets.
    pub id: acp::ToolCallId,
    /// The terminal entity carried by the update.
    pub terminal: Entity<Terminal>,
}
 670
/// A plan: an ordered list of entries. The default plan is empty.
#[derive(Debug, Default)]
pub struct Plan {
    /// The plan's entries, in order.
    pub entries: Vec<PlanEntry>,
}
 675
/// Summary of a [`Plan`], as computed by `Plan::stats`.
#[derive(Debug)]
pub struct PlanStats<'a> {
    /// The first entry whose status is in progress, if any.
    pub in_progress_entry: Option<&'a PlanEntry>,
    /// Number of entries whose status is pending.
    pub pending: u32,
    /// Number of entries whose status is completed.
    pub completed: u32,
}
 682
 683impl Plan {
 684    pub fn is_empty(&self) -> bool {
 685        self.entries.is_empty()
 686    }
 687
 688    pub fn stats(&self) -> PlanStats<'_> {
 689        let mut stats = PlanStats {
 690            in_progress_entry: None,
 691            pending: 0,
 692            completed: 0,
 693        };
 694
 695        for entry in &self.entries {
 696            match &entry.status {
 697                acp::PlanEntryStatus::Pending => {
 698                    stats.pending += 1;
 699                }
 700                acp::PlanEntryStatus::InProgress => {
 701                    stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
 702                }
 703                acp::PlanEntryStatus::Completed => {
 704                    stats.completed += 1;
 705                }
 706            }
 707        }
 708
 709        stats
 710    }
 711}
 712
/// A single entry of a [`Plan`].
#[derive(Debug)]
pub struct PlanEntry {
    /// The entry's text, held as a markdown entity.
    pub content: Entity<Markdown>,
    /// Priority of the entry, as reported over the protocol.
    pub priority: acp::PlanEntryPriority,
    /// Status of the entry (pending, in progress, or completed).
    pub status: acp::PlanEntryStatus,
}
 719
 720impl PlanEntry {
 721    pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
 722        Self {
 723            content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
 724            priority: entry.priority,
 725            status: entry.status,
 726        }
 727    }
 728}
 729
/// Token consumption relative to a maximum.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Maximum number of tokens; `ratio` treats `0` as "unknown".
    pub max_tokens: u64,
    /// Number of tokens used so far.
    pub used_tokens: u64,
}
 735
 736impl TokenUsage {
 737    pub fn ratio(&self) -> TokenUsageRatio {
 738        #[cfg(debug_assertions)]
 739        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 740            .unwrap_or("0.8".to_string())
 741            .parse()
 742            .unwrap();
 743        #[cfg(not(debug_assertions))]
 744        let warning_threshold: f32 = 0.8;
 745
 746        // When the maximum is unknown because there is no selected model,
 747        // avoid showing the token limit warning.
 748        if self.max_tokens == 0 {
 749            TokenUsageRatio::Normal
 750        } else if self.used_tokens >= self.max_tokens {
 751            TokenUsageRatio::Exceeded
 752        } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
 753            TokenUsageRatio::Warning
 754        } else {
 755            TokenUsageRatio::Normal
 756        }
 757    }
 758}
 759
/// Classification returned by `TokenUsage::ratio`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenUsageRatio {
    /// Usage is below the warning threshold, or the maximum is unknown (`0`).
    Normal,
    /// Usage is at or above the warning threshold but below the maximum.
    Warning,
    /// Usage has reached or passed the maximum.
    Exceeded,
}
 766
/// Payload of `AcpThreadEvent::Retry`.
// NOTE(review): the field meanings below are inferred from names and types
// only; this part of the file does not show where the struct is built.
#[derive(Debug, Clone)]
pub struct RetryStatus {
    /// Text of the most recent error (presumably the one that caused the retry).
    pub last_error: SharedString,
    /// Attempt counter (presumably the current attempt; base not shown here).
    pub attempt: usize,
    /// Upper bound on attempts.
    pub max_attempts: usize,
    /// An instant (presumably when the retry wait started).
    pub started_at: Instant,
    /// A duration (presumably how long to wait before retrying).
    pub duration: Duration,
}
 775
/// State of a single agent thread (one session on an agent connection).
pub struct AcpThread {
    /// Title of the thread.
    title: SharedString,
    /// Entries of the thread, in order.
    entries: Vec<AgentThreadEntry>,
    /// The current plan; starts out empty.
    plan: Plan,
    /// The project this thread operates on.
    project: Entity<Project>,
    /// Action log associated with the thread.
    action_log: Entity<ActionLog>,
    /// Buffers paired with a snapshot.
    /// NOTE(review): presumably the snapshot taken when the buffer was shared
    /// with the agent — confirm against the code that fills this map.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// The in-flight send task; `status()` reports `Generating` while this is
    /// `Some`.
    send_task: Option<Task<()>>,
    /// Connection to the agent.
    connection: Rc<dyn AgentConnection>,
    /// Id of the session this thread belongs to.
    session_id: acp::SessionId,
    /// Latest known token usage, if any.
    token_usage: Option<TokenUsage>,
    /// Most recent prompt capabilities received from the channel given to
    /// `new`.
    prompt_capabilities: acp::PromptCapabilities,
    /// Task that keeps `prompt_capabilities` up to date; stored so it lives as
    /// long as the thread.
    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
    /// Shared task that picks the shell: the system shell on Windows,
    /// otherwise `bash` when it is found, else the system shell.
    determine_shell: Shared<Task<String>>,
    /// Terminals known to this thread, keyed by terminal id.
    terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
}
 792
/// Events emitted by [`AcpThread`].
// NOTE(review): only `PromptCapabilitiesUpdated` is emitted in the part of the
// file reviewed here; confirm the other variants against their emit sites.
#[derive(Debug)]
pub enum AcpThreadEvent {
    NewEntry,
    TitleUpdated,
    TokenUsageUpdated,
    /// Carries an index (presumably into the thread's entries).
    EntryUpdated(usize),
    /// Carries an index range (presumably into the thread's entries).
    EntriesRemoved(Range<usize>),
    ToolAuthorizationRequired,
    Retry(RetryStatus),
    Stopped,
    Error,
    LoadError(LoadError),
    /// Emitted right after the thread's prompt capabilities are replaced with
    /// a value received from the capabilities channel.
    PromptCapabilitiesUpdated,
    Refusal,
    AvailableCommandsUpdated(Vec<acp::AvailableCommand>),
    ModeUpdated(acp::SessionModeId),
}
 810
// Marks `AcpThread` as an emitter of `AcpThreadEvent`.
impl EventEmitter<AcpThreadEvent> for AcpThread {}
 812
/// Whether a thread currently has a send task in flight.
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    /// No send task is stored on the thread.
    Idle,
    /// A send task is stored on the thread.
    Generating,
}
 818
/// Errors describing a failed load; see the `Display` impl for the
/// user-facing text of each variant.
#[derive(Debug, Clone)]
pub enum LoadError {
    /// The found version is older than the minimum supported one.
    Unsupported {
        /// Shown as the source of the unsupported version in the message.
        command: SharedString,
        /// The version that was found.
        current_version: SharedString,
        /// The lowest version that is supported.
        minimum_version: SharedString,
    },
    /// Installation failed; the payload is the message.
    FailedToInstall(SharedString),
    /// The server process exited with the given status.
    Exited {
        status: ExitStatus,
    },
    /// Any other failure, described by its message.
    Other(SharedString),
}
 832
 833impl Display for LoadError {
 834    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 835        match self {
 836            LoadError::Unsupported {
 837                command: path,
 838                current_version,
 839                minimum_version,
 840            } => {
 841                write!(
 842                    f,
 843                    "version {current_version} from {path} is not supported (need at least {minimum_version})"
 844                )
 845            }
 846            LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
 847            LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
 848            LoadError::Other(msg) => write!(f, "{msg}"),
 849        }
 850    }
 851}
 852
// `LoadError` relies on the default `Error` methods; its message comes from
// the `Display` impl.
impl Error for LoadError {}
 854
 855impl AcpThread {
 856    pub fn new(
 857        title: impl Into<SharedString>,
 858        connection: Rc<dyn AgentConnection>,
 859        project: Entity<Project>,
 860        action_log: Entity<ActionLog>,
 861        session_id: acp::SessionId,
 862        mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
 863        cx: &mut Context<Self>,
 864    ) -> Self {
 865        let prompt_capabilities = *prompt_capabilities_rx.borrow();
 866        let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
 867            loop {
 868                let caps = prompt_capabilities_rx.recv().await?;
 869                this.update(cx, |this, cx| {
 870                    this.prompt_capabilities = caps;
 871                    cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
 872                })?;
 873            }
 874        });
 875
 876        let determine_shell = cx
 877            .background_spawn(async move {
 878                if cfg!(windows) {
 879                    return get_system_shell();
 880                }
 881
 882                if which::which("bash").is_ok() {
 883                    "bash".into()
 884                } else {
 885                    get_system_shell()
 886                }
 887            })
 888            .shared();
 889
 890        Self {
 891            action_log,
 892            shared_buffers: Default::default(),
 893            entries: Default::default(),
 894            plan: Default::default(),
 895            title: title.into(),
 896            project,
 897            send_task: None,
 898            connection,
 899            session_id,
 900            token_usage: None,
 901            prompt_capabilities,
 902            _observe_prompt_capabilities: task,
 903            terminals: HashMap::default(),
 904            determine_shell,
 905        }
 906    }
 907
    /// Returns a copy of the most recently stored prompt capabilities.
    pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
        self.prompt_capabilities
    }
 911
    /// Returns the agent connection this thread uses.
    pub fn connection(&self) -> &Rc<dyn AgentConnection> {
        &self.connection
    }
 915
    /// Returns the thread's action log.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
 919
    /// Returns the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
 923
    /// Returns a clone of the thread's title.
    pub fn title(&self) -> SharedString {
        self.title.clone()
    }
 927
    /// Returns the thread's entries, in order.
    pub fn entries(&self) -> &[AgentThreadEntry] {
        &self.entries
    }
 931
    /// Returns the id of the session this thread belongs to.
    pub fn session_id(&self) -> &acp::SessionId {
        &self.session_id
    }
 935
 936    pub fn status(&self) -> ThreadStatus {
 937        if self.send_task.is_some() {
 938            ThreadStatus::Generating
 939        } else {
 940            ThreadStatus::Idle
 941        }
 942    }
 943
    /// Returns the latest known token usage, if any.
    pub fn token_usage(&self) -> Option<&TokenUsage> {
        self.token_usage.as_ref()
    }
 947
 948    pub fn has_pending_edit_tool_calls(&self) -> bool {
 949        for entry in self.entries.iter().rev() {
 950            match entry {
 951                AgentThreadEntry::UserMessage(_) => return false,
 952                AgentThreadEntry::ToolCall(
 953                    call @ ToolCall {
 954                        status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
 955                        ..
 956                    },
 957                ) if call.diffs().next().is_some() => {
 958                    return true;
 959                }
 960                AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
 961            }
 962        }
 963
 964        false
 965    }
 966
 967    pub fn used_tools_since_last_user_message(&self) -> bool {
 968        for entry in self.entries.iter().rev() {
 969            match entry {
 970                AgentThreadEntry::UserMessage(..) => return false,
 971                AgentThreadEntry::AssistantMessage(..) => continue,
 972                AgentThreadEntry::ToolCall(..) => return true,
 973            }
 974        }
 975
 976        false
 977    }
 978
 979    pub fn handle_session_update(
 980        &mut self,
 981        update: acp::SessionUpdate,
 982        cx: &mut Context<Self>,
 983    ) -> Result<(), acp::Error> {
 984        match update {
 985            acp::SessionUpdate::UserMessageChunk { content } => {
 986                self.push_user_content_block(None, content, cx);
 987            }
 988            acp::SessionUpdate::AgentMessageChunk { content } => {
 989                self.push_assistant_content_block(content, false, cx);
 990            }
 991            acp::SessionUpdate::AgentThoughtChunk { content } => {
 992                self.push_assistant_content_block(content, true, cx);
 993            }
 994            acp::SessionUpdate::ToolCall(tool_call) => {
 995                self.upsert_tool_call(tool_call, cx)?;
 996            }
 997            acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
 998                self.update_tool_call(tool_call_update, cx)?;
 999            }
1000            acp::SessionUpdate::Plan(plan) => {
1001                self.update_plan(plan, cx);
1002            }
1003            acp::SessionUpdate::AvailableCommandsUpdate { available_commands } => {
1004                cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands))
1005            }
1006            acp::SessionUpdate::CurrentModeUpdate { current_mode_id } => {
1007                cx.emit(AcpThreadEvent::ModeUpdated(current_mode_id))
1008            }
1009        }
1010        Ok(())
1011    }
1012
1013    pub fn push_user_content_block(
1014        &mut self,
1015        message_id: Option<UserMessageId>,
1016        chunk: acp::ContentBlock,
1017        cx: &mut Context<Self>,
1018    ) {
1019        let language_registry = self.project.read(cx).languages().clone();
1020        let entries_len = self.entries.len();
1021
1022        if let Some(last_entry) = self.entries.last_mut()
1023            && let AgentThreadEntry::UserMessage(UserMessage {
1024                id,
1025                content,
1026                chunks,
1027                ..
1028            }) = last_entry
1029        {
1030            *id = message_id.or(id.take());
1031            content.append(chunk.clone(), &language_registry, cx);
1032            chunks.push(chunk);
1033            let idx = entries_len - 1;
1034            cx.emit(AcpThreadEvent::EntryUpdated(idx));
1035        } else {
1036            let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
1037            self.push_entry(
1038                AgentThreadEntry::UserMessage(UserMessage {
1039                    id: message_id,
1040                    content,
1041                    chunks: vec![chunk],
1042                    checkpoint: None,
1043                }),
1044                cx,
1045            );
1046        }
1047    }
1048
1049    pub fn push_assistant_content_block(
1050        &mut self,
1051        chunk: acp::ContentBlock,
1052        is_thought: bool,
1053        cx: &mut Context<Self>,
1054    ) {
1055        let language_registry = self.project.read(cx).languages().clone();
1056        let entries_len = self.entries.len();
1057        if let Some(last_entry) = self.entries.last_mut()
1058            && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
1059        {
1060            let idx = entries_len - 1;
1061            cx.emit(AcpThreadEvent::EntryUpdated(idx));
1062            match (chunks.last_mut(), is_thought) {
1063                (Some(AssistantMessageChunk::Message { block }), false)
1064                | (Some(AssistantMessageChunk::Thought { block }), true) => {
1065                    block.append(chunk, &language_registry, cx)
1066                }
1067                _ => {
1068                    let block = ContentBlock::new(chunk, &language_registry, cx);
1069                    if is_thought {
1070                        chunks.push(AssistantMessageChunk::Thought { block })
1071                    } else {
1072                        chunks.push(AssistantMessageChunk::Message { block })
1073                    }
1074                }
1075            }
1076        } else {
1077            let block = ContentBlock::new(chunk, &language_registry, cx);
1078            let chunk = if is_thought {
1079                AssistantMessageChunk::Thought { block }
1080            } else {
1081                AssistantMessageChunk::Message { block }
1082            };
1083
1084            self.push_entry(
1085                AgentThreadEntry::AssistantMessage(AssistantMessage {
1086                    chunks: vec![chunk],
1087                }),
1088                cx,
1089            );
1090        }
1091    }
1092
1093    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1094        self.entries.push(entry);
1095        cx.emit(AcpThreadEvent::NewEntry);
1096    }
1097
1098    pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1099        self.connection.set_title(&self.session_id, cx).is_some()
1100    }
1101
1102    pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1103        if title != self.title {
1104            self.title = title.clone();
1105            cx.emit(AcpThreadEvent::TitleUpdated);
1106            if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1107                return set_title.run(title, cx);
1108            }
1109        }
1110        Task::ready(Ok(()))
1111    }
1112
1113    pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1114        self.token_usage = usage;
1115        cx.emit(AcpThreadEvent::TokenUsageUpdated);
1116    }
1117
1118    pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1119        cx.emit(AcpThreadEvent::Retry(status));
1120    }
1121
1122    pub fn update_tool_call(
1123        &mut self,
1124        update: impl Into<ToolCallUpdate>,
1125        cx: &mut Context<Self>,
1126    ) -> Result<()> {
1127        let update = update.into();
1128        let languages = self.project.read(cx).languages().clone();
1129
1130        let ix = self
1131            .index_for_tool_call(update.id())
1132            .context("Tool call not found")?;
1133        let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1134            unreachable!()
1135        };
1136
1137        match update {
1138            ToolCallUpdate::UpdateFields(update) => {
1139                let location_updated = update.fields.locations.is_some();
1140                call.update_fields(update.fields, languages, &self.terminals, cx)?;
1141                if location_updated {
1142                    self.resolve_locations(update.id, cx);
1143                }
1144            }
1145            ToolCallUpdate::UpdateDiff(update) => {
1146                call.content.clear();
1147                call.content.push(ToolCallContent::Diff(update.diff));
1148            }
1149            ToolCallUpdate::UpdateTerminal(update) => {
1150                call.content.clear();
1151                call.content
1152                    .push(ToolCallContent::Terminal(update.terminal));
1153            }
1154        }
1155
1156        cx.emit(AcpThreadEvent::EntryUpdated(ix));
1157
1158        Ok(())
1159    }
1160
1161    /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1162    pub fn upsert_tool_call(
1163        &mut self,
1164        tool_call: acp::ToolCall,
1165        cx: &mut Context<Self>,
1166    ) -> Result<(), acp::Error> {
1167        let status = tool_call.status.into();
1168        self.upsert_tool_call_inner(tool_call.into(), status, cx)
1169    }
1170
1171    /// Fails if id does not match an existing entry.
1172    pub fn upsert_tool_call_inner(
1173        &mut self,
1174        update: acp::ToolCallUpdate,
1175        status: ToolCallStatus,
1176        cx: &mut Context<Self>,
1177    ) -> Result<(), acp::Error> {
1178        let language_registry = self.project.read(cx).languages().clone();
1179        let id = update.id.clone();
1180
1181        if let Some(ix) = self.index_for_tool_call(&id) {
1182            let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1183                unreachable!()
1184            };
1185
1186            call.update_fields(update.fields, language_registry, &self.terminals, cx)?;
1187            call.status = status;
1188
1189            cx.emit(AcpThreadEvent::EntryUpdated(ix));
1190        } else {
1191            let call = ToolCall::from_acp(
1192                update.try_into()?,
1193                status,
1194                language_registry,
1195                &self.terminals,
1196                cx,
1197            )?;
1198            self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1199        };
1200
1201        self.resolve_locations(id, cx);
1202        Ok(())
1203    }
1204
1205    fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1206        self.entries
1207            .iter()
1208            .enumerate()
1209            .rev()
1210            .find_map(|(index, entry)| {
1211                if let AgentThreadEntry::ToolCall(tool_call) = entry
1212                    && &tool_call.id == id
1213                {
1214                    Some(index)
1215                } else {
1216                    None
1217                }
1218            })
1219    }
1220
1221    fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1222        // The tool call we are looking for is typically the last one, or very close to the end.
1223        // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1224        self.entries
1225            .iter_mut()
1226            .enumerate()
1227            .rev()
1228            .find_map(|(index, tool_call)| {
1229                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1230                    && &tool_call.id == id
1231                {
1232                    Some((index, tool_call))
1233                } else {
1234                    None
1235                }
1236            })
1237    }
1238
1239    pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1240        self.entries
1241            .iter()
1242            .enumerate()
1243            .rev()
1244            .find_map(|(index, tool_call)| {
1245                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1246                    && &tool_call.id == id
1247                {
1248                    Some((index, tool_call))
1249                } else {
1250                    None
1251                }
1252            })
1253    }
1254
    /// Starts resolving the locations of the tool call with the given `id`.
    ///
    /// Does nothing if no such tool call exists. Otherwise a detached task
    /// awaits the resolution and then:
    /// - if the last resolved location is present and the project already has
    ///   an agent location, moves the project's agent location there (unless
    ///   the move is ignored, see the comment below);
    /// - stores the resolved locations on the tool call and emits
    ///   `EntryUpdated` when they differ from the previously stored ones.
    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
        let project = self.project.clone();
        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
            return;
        };
        let task = tool_call.resolve_locations(project, cx);
        cx.spawn(async move |this, cx| {
            let resolved_locations = task.await;
            this.update(cx, |this, cx| {
                let project = this.project.clone();
                // Look the tool call up again: entries may have changed while the task ran.
                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
                    return;
                };
                if let Some(Some(location)) = resolved_locations.last() {
                    project.update(cx, |project, cx| {
                        if let Some(agent_location) = project.agent_location() {
                            let should_ignore = agent_location.buffer == location.buffer
                                && location
                                    .buffer
                                    .update(cx, |buffer, _| {
                                        let snapshot = buffer.snapshot();
                                        let old_position =
                                            agent_location.position.to_point(&snapshot);
                                        let new_position = location.position.to_point(&snapshot);
                                        // Ignore this so that when we get updates from the edit tool
                                        // the position doesn't reset to the start of the line.
                                        old_position.row == new_position.row
                                            && old_position.column > new_position.column
                                    })
                                    .ok()
                                    .unwrap_or_default();
                            if !should_ignore {
                                project.set_agent_location(Some(location.clone()), cx);
                            }
                        }
                    });
                }
                if tool_call.resolved_locations != resolved_locations {
                    tool_call.resolved_locations = resolved_locations;
                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
                }
            })
        })
        .detach();
    }
1300
1301    pub fn request_tool_call_authorization(
1302        &mut self,
1303        tool_call: acp::ToolCallUpdate,
1304        options: Vec<acp::PermissionOption>,
1305        respect_always_allow_setting: bool,
1306        cx: &mut Context<Self>,
1307    ) -> Result<BoxFuture<'static, acp::RequestPermissionOutcome>> {
1308        let (tx, rx) = oneshot::channel();
1309
1310        if respect_always_allow_setting && AgentSettings::get_global(cx).always_allow_tool_actions {
1311            // Don't use AllowAlways, because then if you were to turn off always_allow_tool_actions,
1312            // some tools would (incorrectly) continue to auto-accept.
1313            if let Some(allow_once_option) = options.iter().find_map(|option| {
1314                if matches!(option.kind, acp::PermissionOptionKind::AllowOnce) {
1315                    Some(option.id.clone())
1316                } else {
1317                    None
1318                }
1319            }) {
1320                self.upsert_tool_call_inner(tool_call, ToolCallStatus::Pending, cx)?;
1321                return Ok(async {
1322                    acp::RequestPermissionOutcome::Selected {
1323                        option_id: allow_once_option,
1324                    }
1325                }
1326                .boxed());
1327            }
1328        }
1329
1330        let status = ToolCallStatus::WaitingForConfirmation {
1331            options,
1332            respond_tx: tx,
1333        };
1334
1335        self.upsert_tool_call_inner(tool_call, status, cx)?;
1336        cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1337
1338        let fut = async {
1339            match rx.await {
1340                Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
1341                Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1342            }
1343        }
1344        .boxed();
1345
1346        Ok(fut)
1347    }
1348
1349    pub fn authorize_tool_call(
1350        &mut self,
1351        id: acp::ToolCallId,
1352        option_id: acp::PermissionOptionId,
1353        option_kind: acp::PermissionOptionKind,
1354        cx: &mut Context<Self>,
1355    ) {
1356        let Some((ix, call)) = self.tool_call_mut(&id) else {
1357            return;
1358        };
1359
1360        let new_status = match option_kind {
1361            acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1362                ToolCallStatus::Rejected
1363            }
1364            acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1365                ToolCallStatus::InProgress
1366            }
1367        };
1368
1369        let curr_status = mem::replace(&mut call.status, new_status);
1370
1371        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1372            respond_tx.send(option_id).log_err();
1373        } else if cfg!(debug_assertions) {
1374            panic!("tried to authorize an already authorized tool call");
1375        }
1376
1377        cx.emit(AcpThreadEvent::EntryUpdated(ix));
1378    }
1379
1380    pub fn first_tool_awaiting_confirmation(&self) -> Option<&ToolCall> {
1381        let mut first_tool_call = None;
1382
1383        for entry in self.entries.iter().rev() {
1384            match &entry {
1385                AgentThreadEntry::ToolCall(call) => {
1386                    if let ToolCallStatus::WaitingForConfirmation { .. } = call.status {
1387                        first_tool_call = Some(call);
1388                    } else {
1389                        continue;
1390                    }
1391                }
1392                AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1393                    // Reached the beginning of the turn.
1394                    // If we had pending permission requests in the previous turn, they have been cancelled.
1395                    break;
1396                }
1397            }
1398        }
1399
1400        first_tool_call
1401    }
1402
1403    pub fn plan(&self) -> &Plan {
1404        &self.plan
1405    }
1406
1407    pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1408        let new_entries_len = request.entries.len();
1409        let mut new_entries = request.entries.into_iter();
1410
1411        // Reuse existing markdown to prevent flickering
1412        for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1413            let PlanEntry {
1414                content,
1415                priority,
1416                status,
1417            } = old;
1418            content.update(cx, |old, cx| {
1419                old.replace(new.content, cx);
1420            });
1421            *priority = new.priority;
1422            *status = new.status;
1423        }
1424        for new in new_entries {
1425            self.plan.entries.push(PlanEntry::from_acp(new, cx))
1426        }
1427        self.plan.entries.truncate(new_entries_len);
1428
1429        cx.notify();
1430    }
1431
1432    fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1433        self.plan
1434            .entries
1435            .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1436        cx.notify();
1437    }
1438
1439    #[cfg(any(test, feature = "test-support"))]
1440    pub fn send_raw(
1441        &mut self,
1442        message: &str,
1443        cx: &mut Context<Self>,
1444    ) -> BoxFuture<'static, Result<()>> {
1445        self.send(
1446            vec![acp::ContentBlock::Text(acp::TextContent {
1447                text: message.to_string(),
1448                annotations: None,
1449            })],
1450            cx,
1451        )
1452    }
1453
    /// Sends `message` to the agent as a new user turn.
    ///
    /// Inside the turn started via `run_turn`, this pushes a user message
    /// entry, requests a git checkpoint, attaches it (initially hidden) to the
    /// last user message, and then prompts the connection. The returned future
    /// is the one produced by `run_turn`.
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            cx,
        );
        let request = acp::PromptRequest {
            prompt: message.clone(),
            session_id: self.session_id.clone(),
        };
        let git_store = self.project.read(cx).git_store().clone();

        // A message id is only assigned when the connection offers truncation for this session.
        let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
            Some(UserMessageId::new())
        } else {
            None
        };

        self.run_turn(cx, async move |this, cx| {
            // Best effort: a failure to push the entry (e.g. the thread is gone) is ignored here.
            this.update(cx, |this, cx| {
                this.push_entry(
                    AgentThreadEntry::UserMessage(UserMessage {
                        id: message_id.clone(),
                        content: block,
                        chunks: message,
                        checkpoint: None,
                    }),
                    cx,
                );
            })
            .ok();

            // Checkpoint failures are logged and the turn continues without a checkpoint.
            let old_checkpoint = git_store
                .update(cx, |git, cx| git.checkpoint(cx))?
                .await
                .context("failed to get old checkpoint")
                .log_err();
            this.update(cx, |this, cx| {
                if let Some((_ix, message)) = this.last_user_message() {
                    // `show` starts out false; `update_last_checkpoint` may flip it later.
                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
                        git_checkpoint,
                        show: false,
                    });
                }
                this.connection.prompt(message_id, request, cx)
            })?
            .await
        })
    }
1507
1508    pub fn can_resume(&self, cx: &App) -> bool {
1509        self.connection.resume(&self.session_id, cx).is_some()
1510    }
1511
1512    pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1513        self.run_turn(cx, async move |this, cx| {
1514            this.update(cx, |this, cx| {
1515                this.connection
1516                    .resume(&this.session_id, cx)
1517                    .map(|resume| resume.run(cx))
1518            })?
1519            .context("resuming a session is not supported")?
1520            .await
1521        })
1522    }
1523
    /// Runs one turn of the conversation.
    ///
    /// Completed plan entries are cleared and any turn already in flight is
    /// cancelled first. `f` is then run inside the task stored in `send_task`
    /// (after the cancellation has finished), and its result is forwarded
    /// through a oneshot channel to the returned future.
    ///
    /// The returned future waits for that result, updates the last checkpoint,
    /// clears the project's agent location, and then:
    /// - on an error from `f`: clears `send_task`, emits `Error`, and returns the error;
    /// - otherwise: clears `send_task` unless the turn was cancelled, handles a
    ///   refusal stop reason (see inline comments), emits `Stopped`, and returns `Ok(())`.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<()>> {
        self.clear_completed_plan_entries(cx);

        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            // Let the previous turn finish cancelling before starting this one.
            cancel_task.await;
            tx.send(f(this, cx).await).ok();
        }));

        cx.spawn(async move |this, cx| {
            // The outer `Result` of `response` is an error when the sender was dropped
            // without sending, e.g. because the send task was dropped.
            let response = rx.await;

            // NOTE(review): an error here returns early, before `send_task` is cleared
            // and before any event is emitted — confirm this is intended.
            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));
                match response {
                    Ok(Err(e)) => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        Err(e)
                    }
                    result => {
                        let canceled = matches!(
                            result,
                            Ok(Ok(acp::PromptResponse {
                                stop_reason: acp::StopReason::Cancelled
                            }))
                        );

                        // We only take the task if the current prompt wasn't canceled.
                        //
                        // This prompt may have been canceled because another one was sent
                        // while it was still generating. In these cases, dropping `send_task`
                        // would cause the next generation to be canceled.
                        if !canceled {
                            this.send_task.take();
                        }

                        // Handle refusal - distinguish between user prompt and tool call refusals
                        if let Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                        })) = result
                        {
                            if let Some((user_msg_ix, _)) = this.last_user_message() {
                                // Check if there's a completed tool call with results after the last user message
                                // This indicates the refusal is in response to tool output, not the user's prompt
                                let has_completed_tool_call_after_user_msg =
                                    this.entries.iter().skip(user_msg_ix + 1).any(|entry| {
                                        if let AgentThreadEntry::ToolCall(tool_call) = entry {
                                            // Check if the tool call has completed and has output
                                            matches!(tool_call.status, ToolCallStatus::Completed)
                                                && tool_call.raw_output.is_some()
                                        } else {
                                            false
                                        }
                                    });

                                if has_completed_tool_call_after_user_msg {
                                    // Refusal is due to tool output - don't truncate, just notify
                                    // The model refused based on what the tool returned
                                    cx.emit(AcpThreadEvent::Refusal);
                                } else {
                                    // User prompt was refused - truncate back to before the user message
                                    let range = user_msg_ix..this.entries.len();
                                    if range.start < range.end {
                                        this.entries.truncate(user_msg_ix);
                                        cx.emit(AcpThreadEvent::EntriesRemoved(range));
                                    }
                                    cx.emit(AcpThreadEvent::Refusal);
                                }
                            } else {
                                // No user message found, treat as general refusal
                                cx.emit(AcpThreadEvent::Refusal);
                            }
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(())
                    }
                }
            })?
        })
        .boxed()
    }
1617
1618    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1619        let Some(send_task) = self.send_task.take() else {
1620            return Task::ready(());
1621        };
1622
1623        for entry in self.entries.iter_mut() {
1624            if let AgentThreadEntry::ToolCall(call) = entry {
1625                let cancel = matches!(
1626                    call.status,
1627                    ToolCallStatus::Pending
1628                        | ToolCallStatus::WaitingForConfirmation { .. }
1629                        | ToolCallStatus::InProgress
1630                );
1631
1632                if cancel {
1633                    call.status = ToolCallStatus::Canceled;
1634                }
1635            }
1636        }
1637
1638        self.connection.cancel(&self.session_id, cx);
1639
1640        // Wait for the send task to complete
1641        cx.foreground_executor().spawn(send_task)
1642    }
1643
1644    /// Restores the git working tree to the state at the given checkpoint (if one exists)
1645    pub fn restore_checkpoint(
1646        &mut self,
1647        id: UserMessageId,
1648        cx: &mut Context<Self>,
1649    ) -> Task<Result<()>> {
1650        let Some((_, message)) = self.user_message_mut(&id) else {
1651            return Task::ready(Err(anyhow!("message not found")));
1652        };
1653
1654        let checkpoint = message
1655            .checkpoint
1656            .as_ref()
1657            .map(|c| c.git_checkpoint.clone());
1658        let rewind = self.rewind(id.clone(), cx);
1659        let git_store = self.project.read(cx).git_store().clone();
1660
1661        cx.spawn(async move |_, cx| {
1662            rewind.await?;
1663            if let Some(checkpoint) = checkpoint {
1664                git_store
1665                    .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1666                    .await?;
1667            }
1668
1669            Ok(())
1670        })
1671    }
1672
1673    /// Rewinds this thread to before the entry at `index`, removing it and all
1674    /// subsequent entries while rejecting any action_log changes made from that point.
1675    /// Unlike `restore_checkpoint`, this method does not restore from git.
1676    pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1677        let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
1678            return Task::ready(Err(anyhow!("not supported")));
1679        };
1680
1681        cx.spawn(async move |this, cx| {
1682            cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
1683            this.update(cx, |this, cx| {
1684                if let Some((ix, _)) = this.user_message_mut(&id) {
1685                    let range = ix..this.entries.len();
1686                    this.entries.truncate(ix);
1687                    cx.emit(AcpThreadEvent::EntriesRemoved(range));
1688                }
1689                this.action_log()
1690                    .update(cx, |action_log, cx| action_log.reject_all_edits(cx))
1691            })?
1692            .await;
1693            Ok(())
1694        })
1695    }
1696
    /// Decides whether the checkpoint on the last user message should be shown.
    ///
    /// Returns a ready `Ok(())` task if there is no user message or it has no
    /// checkpoint. Otherwise a new git checkpoint is taken and compared with
    /// the stored one; `show` is set to `true` exactly when they differ, and
    /// `EntryUpdated` is emitted for the message.
    fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let git_store = self.project.read(cx).git_store().clone();

        let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
            if let Some(checkpoint) = message.checkpoint.as_ref() {
                checkpoint.git_checkpoint.clone()
            } else {
                return Task::ready(Ok(()));
            }
        } else {
            return Task::ready(Ok(()));
        };

        let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
        cx.spawn(async move |this, cx| {
            // A failure to take the new checkpoint is logged and otherwise ignored.
            let new_checkpoint = new_checkpoint
                .await
                .context("failed to get new checkpoint")
                .log_err();
            if let Some(new_checkpoint) = new_checkpoint {
                // If the comparison itself fails, treat the checkpoints as equal.
                let equal = git_store
                    .update(cx, |git, cx| {
                        git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
                    })?
                    .await
                    .unwrap_or(true);
                this.update(cx, |this, cx| {
                    // Looked up again because entries may have changed while awaiting.
                    let (ix, message) = this.last_user_message().context("no user message")?;
                    let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
                    checkpoint.show = !equal;
                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
                    anyhow::Ok(())
                })??;
            }

            Ok(())
        })
    }
1735
1736    fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1737        self.entries
1738            .iter_mut()
1739            .enumerate()
1740            .rev()
1741            .find_map(|(ix, entry)| {
1742                if let AgentThreadEntry::UserMessage(message) = entry {
1743                    Some((ix, message))
1744                } else {
1745                    None
1746                }
1747            })
1748    }
1749
1750    fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1751        self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1752            if let AgentThreadEntry::UserMessage(message) = entry {
1753                if message.id.as_ref() == Some(id) {
1754                    Some((ix, message))
1755                } else {
1756                    None
1757                }
1758            } else {
1759                None
1760            }
1761        })
1762    }
1763
1764    pub fn read_text_file(
1765        &self,
1766        path: PathBuf,
1767        line: Option<u32>,
1768        limit: Option<u32>,
1769        reuse_shared_snapshot: bool,
1770        cx: &mut Context<Self>,
1771    ) -> Task<Result<String>> {
1772        // Args are 1-based, move to 0-based
1773        let line = line.unwrap_or_default().saturating_sub(1);
1774        let limit = limit.unwrap_or(u32::MAX);
1775        let project = self.project.clone();
1776        let action_log = self.action_log.clone();
1777        cx.spawn(async move |this, cx| {
1778            let load = project.update(cx, |project, cx| {
1779                let path = project
1780                    .project_path_for_absolute_path(&path, cx)
1781                    .context("invalid path")?;
1782                anyhow::Ok(project.open_buffer(path, cx))
1783            });
1784            let buffer = load??.await?;
1785
1786            let snapshot = if reuse_shared_snapshot {
1787                this.read_with(cx, |this, _| {
1788                    this.shared_buffers.get(&buffer.clone()).cloned()
1789                })
1790                .log_err()
1791                .flatten()
1792            } else {
1793                None
1794            };
1795
1796            let snapshot = if let Some(snapshot) = snapshot {
1797                snapshot
1798            } else {
1799                action_log.update(cx, |action_log, cx| {
1800                    action_log.buffer_read(buffer.clone(), cx);
1801                })?;
1802
1803                let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot())?;
1804                this.update(cx, |this, _| {
1805                    this.shared_buffers.insert(buffer.clone(), snapshot.clone());
1806                })?;
1807                snapshot
1808            };
1809
1810            let max_point = snapshot.max_point();
1811            if line >= max_point.row {
1812                anyhow::bail!(
1813                    "Attempting to read beyond the end of the file, line {}:{}",
1814                    max_point.row + 1,
1815                    max_point.column
1816                );
1817            }
1818
1819            let start = snapshot.anchor_before(Point::new(line, 0));
1820            let end = snapshot.anchor_before(Point::new(line.saturating_add(limit), 0));
1821
1822            project.update(cx, |project, cx| {
1823                project.set_agent_location(
1824                    Some(AgentLocation {
1825                        buffer: buffer.downgrade(),
1826                        position: start,
1827                    }),
1828                    cx,
1829                );
1830            })?;
1831
1832            Ok(snapshot.text_for_range(start..end).collect::<String>())
1833        })
1834    }
1835
1836    pub fn write_text_file(
1837        &self,
1838        path: PathBuf,
1839        content: String,
1840        cx: &mut Context<Self>,
1841    ) -> Task<Result<()>> {
1842        let project = self.project.clone();
1843        let action_log = self.action_log.clone();
1844        cx.spawn(async move |this, cx| {
1845            let load = project.update(cx, |project, cx| {
1846                let path = project
1847                    .project_path_for_absolute_path(&path, cx)
1848                    .context("invalid path")?;
1849                anyhow::Ok(project.open_buffer(path, cx))
1850            });
1851            let buffer = load??.await?;
1852            let snapshot = this.update(cx, |this, cx| {
1853                this.shared_buffers
1854                    .get(&buffer)
1855                    .cloned()
1856                    .unwrap_or_else(|| buffer.read(cx).snapshot())
1857            })?;
1858            let edits = cx
1859                .background_executor()
1860                .spawn(async move {
1861                    let old_text = snapshot.text();
1862                    text_diff(old_text.as_str(), &content)
1863                        .into_iter()
1864                        .map(|(range, replacement)| {
1865                            (
1866                                snapshot.anchor_after(range.start)
1867                                    ..snapshot.anchor_before(range.end),
1868                                replacement,
1869                            )
1870                        })
1871                        .collect::<Vec<_>>()
1872                })
1873                .await;
1874
1875            project.update(cx, |project, cx| {
1876                project.set_agent_location(
1877                    Some(AgentLocation {
1878                        buffer: buffer.downgrade(),
1879                        position: edits
1880                            .last()
1881                            .map(|(range, _)| range.end)
1882                            .unwrap_or(Anchor::MIN),
1883                    }),
1884                    cx,
1885                );
1886            })?;
1887
1888            let format_on_save = cx.update(|cx| {
1889                action_log.update(cx, |action_log, cx| {
1890                    action_log.buffer_read(buffer.clone(), cx);
1891                });
1892
1893                let format_on_save = buffer.update(cx, |buffer, cx| {
1894                    buffer.edit(edits, None, cx);
1895
1896                    let settings = language::language_settings::language_settings(
1897                        buffer.language().map(|l| l.name()),
1898                        buffer.file(),
1899                        cx,
1900                    );
1901
1902                    settings.format_on_save != FormatOnSave::Off
1903                });
1904                action_log.update(cx, |action_log, cx| {
1905                    action_log.buffer_edited(buffer.clone(), cx);
1906                });
1907                format_on_save
1908            })?;
1909
1910            if format_on_save {
1911                let format_task = project.update(cx, |project, cx| {
1912                    project.format(
1913                        HashSet::from_iter([buffer.clone()]),
1914                        LspFormatTarget::Buffers,
1915                        false,
1916                        FormatTrigger::Save,
1917                        cx,
1918                    )
1919                })?;
1920                format_task.await.log_err();
1921
1922                action_log.update(cx, |action_log, cx| {
1923                    action_log.buffer_edited(buffer.clone(), cx);
1924                })?;
1925            }
1926
1927            project
1928                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1929                .await
1930        })
1931    }
1932
1933    pub fn create_terminal(
1934        &self,
1935        mut command: String,
1936        args: Vec<String>,
1937        extra_env: Vec<acp::EnvVariable>,
1938        cwd: Option<PathBuf>,
1939        output_byte_limit: Option<u64>,
1940        cx: &mut Context<Self>,
1941    ) -> Task<Result<Entity<Terminal>>> {
1942        for arg in args {
1943            command.push(' ');
1944            command.push_str(&arg);
1945        }
1946
1947        let shell_command = if cfg!(windows) {
1948            format!("$null | & {{{}}}", command.replace("\"", "'"))
1949        } else if let Some(cwd) = cwd.as_ref().and_then(|cwd| cwd.as_os_str().to_str()) {
1950            // Make sure once we're *inside* the shell, we cd into `cwd`
1951            format!("(cd {cwd}; {}) </dev/null", command)
1952        } else {
1953            format!("({}) </dev/null", command)
1954        };
1955        let args = vec!["-c".into(), shell_command];
1956
1957        let env = match &cwd {
1958            Some(dir) => self.project.update(cx, |project, cx| {
1959                project.directory_environment(dir.as_path().into(), cx)
1960            }),
1961            None => Task::ready(None).shared(),
1962        };
1963
1964        let env = cx.spawn(async move |_, _| {
1965            let mut env = env.await.unwrap_or_default();
1966            if cfg!(unix) {
1967                env.insert("PAGER".into(), "cat".into());
1968            }
1969            for var in extra_env {
1970                env.insert(var.name, var.value);
1971            }
1972            env
1973        });
1974
1975        let project = self.project.clone();
1976        let language_registry = project.read(cx).languages().clone();
1977        let determine_shell = self.determine_shell.clone();
1978
1979        let terminal_id = acp::TerminalId(Uuid::new_v4().to_string().into());
1980        let terminal_task = cx.spawn({
1981            let terminal_id = terminal_id.clone();
1982            async move |_this, cx| {
1983                let program = determine_shell.await;
1984                let env = env.await;
1985                let terminal = project
1986                    .update(cx, |project, cx| {
1987                        project.create_terminal_task(
1988                            task::SpawnInTerminal {
1989                                command: Some(program),
1990                                args,
1991                                cwd: cwd.clone(),
1992                                env,
1993                                ..Default::default()
1994                            },
1995                            cx,
1996                        )
1997                    })?
1998                    .await?;
1999
2000                cx.new(|cx| {
2001                    Terminal::new(
2002                        terminal_id,
2003                        command,
2004                        cwd,
2005                        output_byte_limit.map(|l| l as usize),
2006                        terminal,
2007                        language_registry,
2008                        cx,
2009                    )
2010                })
2011            }
2012        });
2013
2014        cx.spawn(async move |this, cx| {
2015            let terminal = terminal_task.await?;
2016            this.update(cx, |this, _cx| {
2017                this.terminals.insert(terminal_id, terminal.clone());
2018                terminal
2019            })
2020        })
2021    }
2022
2023    pub fn kill_terminal(
2024        &mut self,
2025        terminal_id: acp::TerminalId,
2026        cx: &mut Context<Self>,
2027    ) -> Result<()> {
2028        self.terminals
2029            .get(&terminal_id)
2030            .context("Terminal not found")?
2031            .update(cx, |terminal, cx| {
2032                terminal.kill(cx);
2033            });
2034
2035        Ok(())
2036    }
2037
2038    pub fn release_terminal(
2039        &mut self,
2040        terminal_id: acp::TerminalId,
2041        cx: &mut Context<Self>,
2042    ) -> Result<()> {
2043        self.terminals
2044            .remove(&terminal_id)
2045            .context("Terminal not found")?
2046            .update(cx, |terminal, cx| {
2047                terminal.kill(cx);
2048            });
2049
2050        Ok(())
2051    }
2052
2053    pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
2054        self.terminals
2055            .get(&terminal_id)
2056            .context("Terminal not found")
2057            .cloned()
2058    }
2059
2060    pub fn to_markdown(&self, cx: &App) -> String {
2061        self.entries.iter().map(|e| e.to_markdown(cx)).collect()
2062    }
2063
2064    pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
2065        cx.emit(AcpThreadEvent::LoadError(error));
2066    }
2067}
2068
2069fn markdown_for_raw_output(
2070    raw_output: &serde_json::Value,
2071    language_registry: &Arc<LanguageRegistry>,
2072    cx: &mut App,
2073) -> Option<Entity<Markdown>> {
2074    match raw_output {
2075        serde_json::Value::Null => None,
2076        serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2077            Markdown::new(
2078                value.to_string().into(),
2079                Some(language_registry.clone()),
2080                None,
2081                cx,
2082            )
2083        })),
2084        serde_json::Value::Number(value) => Some(cx.new(|cx| {
2085            Markdown::new(
2086                value.to_string().into(),
2087                Some(language_registry.clone()),
2088                None,
2089                cx,
2090            )
2091        })),
2092        serde_json::Value::String(value) => Some(cx.new(|cx| {
2093            Markdown::new(
2094                value.clone().into(),
2095                Some(language_registry.clone()),
2096                None,
2097                cx,
2098            )
2099        })),
2100        value => Some(cx.new(|cx| {
2101            Markdown::new(
2102                format!("```json\n{}\n```", value).into(),
2103                Some(language_registry.clone()),
2104                None,
2105                cx,
2106            )
2107        })),
2108    }
2109}
2110
2111#[cfg(test)]
2112mod tests {
2113    use super::*;
2114    use anyhow::anyhow;
2115    use futures::{channel::mpsc, future::LocalBoxFuture, select};
2116    use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2117    use indoc::indoc;
2118    use project::{FakeFs, Fs};
2119    use rand::{distr, prelude::*};
2120    use serde_json::json;
2121    use settings::SettingsStore;
2122    use smol::stream::StreamExt as _;
2123    use std::{
2124        any::Any,
2125        cell::RefCell,
2126        path::Path,
2127        rc::Rc,
2128        sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2129        time::Duration,
2130    };
2131    use util::path;
2132
2133    fn init_test(cx: &mut TestAppContext) {
2134        env_logger::try_init().ok();
2135        cx.update(|cx| {
2136            let settings_store = SettingsStore::test(cx);
2137            cx.set_global(settings_store);
2138            Project::init_settings(cx);
2139            language::init(cx);
2140        });
2141    }
2142
2143    #[gpui::test]
2144    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
2145        init_test(cx);
2146
2147        let fs = FakeFs::new(cx.executor());
2148        let project = Project::test(fs, [], cx).await;
2149        let connection = Rc::new(FakeAgentConnection::new());
2150        let thread = cx
2151            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2152            .await
2153            .unwrap();
2154
2155        // Test creating a new user message
2156        thread.update(cx, |thread, cx| {
2157            thread.push_user_content_block(
2158                None,
2159                acp::ContentBlock::Text(acp::TextContent {
2160                    annotations: None,
2161                    text: "Hello, ".to_string(),
2162                }),
2163                cx,
2164            );
2165        });
2166
2167        thread.update(cx, |thread, cx| {
2168            assert_eq!(thread.entries.len(), 1);
2169            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2170                assert_eq!(user_msg.id, None);
2171                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
2172            } else {
2173                panic!("Expected UserMessage");
2174            }
2175        });
2176
2177        // Test appending to existing user message
2178        let message_1_id = UserMessageId::new();
2179        thread.update(cx, |thread, cx| {
2180            thread.push_user_content_block(
2181                Some(message_1_id.clone()),
2182                acp::ContentBlock::Text(acp::TextContent {
2183                    annotations: None,
2184                    text: "world!".to_string(),
2185                }),
2186                cx,
2187            );
2188        });
2189
2190        thread.update(cx, |thread, cx| {
2191            assert_eq!(thread.entries.len(), 1);
2192            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2193                assert_eq!(user_msg.id, Some(message_1_id));
2194                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
2195            } else {
2196                panic!("Expected UserMessage");
2197            }
2198        });
2199
2200        // Test creating new user message after assistant message
2201        thread.update(cx, |thread, cx| {
2202            thread.push_assistant_content_block(
2203                acp::ContentBlock::Text(acp::TextContent {
2204                    annotations: None,
2205                    text: "Assistant response".to_string(),
2206                }),
2207                false,
2208                cx,
2209            );
2210        });
2211
2212        let message_2_id = UserMessageId::new();
2213        thread.update(cx, |thread, cx| {
2214            thread.push_user_content_block(
2215                Some(message_2_id.clone()),
2216                acp::ContentBlock::Text(acp::TextContent {
2217                    annotations: None,
2218                    text: "New user message".to_string(),
2219                }),
2220                cx,
2221            );
2222        });
2223
2224        thread.update(cx, |thread, cx| {
2225            assert_eq!(thread.entries.len(), 3);
2226            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
2227                assert_eq!(user_msg.id, Some(message_2_id));
2228                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
2229            } else {
2230                panic!("Expected UserMessage at index 2");
2231            }
2232        });
2233    }
2234
2235    #[gpui::test]
2236    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
2237        init_test(cx);
2238
2239        let fs = FakeFs::new(cx.executor());
2240        let project = Project::test(fs, [], cx).await;
2241        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2242            |_, thread, mut cx| {
2243                async move {
2244                    thread.update(&mut cx, |thread, cx| {
2245                        thread
2246                            .handle_session_update(
2247                                acp::SessionUpdate::AgentThoughtChunk {
2248                                    content: "Thinking ".into(),
2249                                },
2250                                cx,
2251                            )
2252                            .unwrap();
2253                        thread
2254                            .handle_session_update(
2255                                acp::SessionUpdate::AgentThoughtChunk {
2256                                    content: "hard!".into(),
2257                                },
2258                                cx,
2259                            )
2260                            .unwrap();
2261                    })?;
2262                    Ok(acp::PromptResponse {
2263                        stop_reason: acp::StopReason::EndTurn,
2264                    })
2265                }
2266                .boxed_local()
2267            },
2268        ));
2269
2270        let thread = cx
2271            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2272            .await
2273            .unwrap();
2274
2275        thread
2276            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
2277            .await
2278            .unwrap();
2279
2280        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
2281        assert_eq!(
2282            output,
2283            indoc! {r#"
2284            ## User
2285
2286            Hello from Zed!
2287
2288            ## Assistant
2289
2290            <thinking>
2291            Thinking hard!
2292            </thinking>
2293
2294            "#}
2295        );
2296    }
2297
2298    #[gpui::test]
2299    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
2300        init_test(cx);
2301
2302        let fs = FakeFs::new(cx.executor());
2303        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
2304            .await;
2305        let project = Project::test(fs.clone(), [], cx).await;
2306        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
2307        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
2308        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2309            move |_, thread, mut cx| {
2310                let read_file_tx = read_file_tx.clone();
2311                async move {
2312                    let content = thread
2313                        .update(&mut cx, |thread, cx| {
2314                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2315                        })
2316                        .unwrap()
2317                        .await
2318                        .unwrap();
2319                    assert_eq!(content, "one\ntwo\nthree\n");
2320                    read_file_tx.take().unwrap().send(()).unwrap();
2321                    thread
2322                        .update(&mut cx, |thread, cx| {
2323                            thread.write_text_file(
2324                                path!("/tmp/foo").into(),
2325                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
2326                                cx,
2327                            )
2328                        })
2329                        .unwrap()
2330                        .await
2331                        .unwrap();
2332                    Ok(acp::PromptResponse {
2333                        stop_reason: acp::StopReason::EndTurn,
2334                    })
2335                }
2336                .boxed_local()
2337            },
2338        ));
2339
2340        let (worktree, pathbuf) = project
2341            .update(cx, |project, cx| {
2342                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2343            })
2344            .await
2345            .unwrap();
2346        let buffer = project
2347            .update(cx, |project, cx| {
2348                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
2349            })
2350            .await
2351            .unwrap();
2352
2353        let thread = cx
2354            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2355            .await
2356            .unwrap();
2357
2358        let request = thread.update(cx, |thread, cx| {
2359            thread.send_raw("Extend the count in /tmp/foo", cx)
2360        });
2361        read_file_rx.await.ok();
2362        buffer.update(cx, |buffer, cx| {
2363            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
2364        });
2365        cx.run_until_parked();
2366        assert_eq!(
2367            buffer.read_with(cx, |buffer, _| buffer.text()),
2368            "zero\none\ntwo\nthree\nfour\nfive\n"
2369        );
2370        assert_eq!(
2371            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
2372            "zero\none\ntwo\nthree\nfour\nfive\n"
2373        );
2374        request.await.unwrap();
2375    }
2376
2377    #[gpui::test]
2378    async fn test_reading_from_line(cx: &mut TestAppContext) {
2379        init_test(cx);
2380
2381        let fs = FakeFs::new(cx.executor());
2382        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\nfour\n"}))
2383            .await;
2384        let project = Project::test(fs.clone(), [], cx).await;
2385        project
2386            .update(cx, |project, cx| {
2387                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2388            })
2389            .await
2390            .unwrap();
2391
2392        let connection = Rc::new(FakeAgentConnection::new());
2393
2394        let thread = cx
2395            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2396            .await
2397            .unwrap();
2398
2399        // Whole file
2400        let content = thread
2401            .update(cx, |thread, cx| {
2402                thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2403            })
2404            .await
2405            .unwrap();
2406
2407        assert_eq!(content, "one\ntwo\nthree\nfour\n");
2408
2409        // Only start line
2410        let content = thread
2411            .update(cx, |thread, cx| {
2412                thread.read_text_file(path!("/tmp/foo").into(), Some(3), None, false, cx)
2413            })
2414            .await
2415            .unwrap();
2416
2417        assert_eq!(content, "three\nfour\n");
2418
2419        // Only limit
2420        let content = thread
2421            .update(cx, |thread, cx| {
2422                thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
2423            })
2424            .await
2425            .unwrap();
2426
2427        assert_eq!(content, "one\ntwo\n");
2428
2429        // Range
2430        let content = thread
2431            .update(cx, |thread, cx| {
2432                thread.read_text_file(path!("/tmp/foo").into(), Some(2), Some(2), false, cx)
2433            })
2434            .await
2435            .unwrap();
2436
2437        assert_eq!(content, "two\nthree\n");
2438
2439        // Invalid
2440        let err = thread
2441            .update(cx, |thread, cx| {
2442                thread.read_text_file(path!("/tmp/foo").into(), Some(5), Some(2), false, cx)
2443            })
2444            .await
2445            .unwrap_err();
2446
2447        assert_eq!(
2448            err.to_string(),
2449            "Attempting to read beyond the end of the file, line 5:0"
2450        );
2451    }
2452
2453    #[gpui::test]
2454    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
2455        init_test(cx);
2456
2457        let fs = FakeFs::new(cx.executor());
2458        let project = Project::test(fs, [], cx).await;
2459        let id = acp::ToolCallId("test".into());
2460
2461        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2462            let id = id.clone();
2463            move |_, thread, mut cx| {
2464                let id = id.clone();
2465                async move {
2466                    thread
2467                        .update(&mut cx, |thread, cx| {
2468                            thread.handle_session_update(
2469                                acp::SessionUpdate::ToolCall(acp::ToolCall {
2470                                    id: id.clone(),
2471                                    title: "Label".into(),
2472                                    kind: acp::ToolKind::Fetch,
2473                                    status: acp::ToolCallStatus::InProgress,
2474                                    content: vec![],
2475                                    locations: vec![],
2476                                    raw_input: None,
2477                                    raw_output: None,
2478                                }),
2479                                cx,
2480                            )
2481                        })
2482                        .unwrap()
2483                        .unwrap();
2484                    Ok(acp::PromptResponse {
2485                        stop_reason: acp::StopReason::EndTurn,
2486                    })
2487                }
2488                .boxed_local()
2489            }
2490        }));
2491
2492        let thread = cx
2493            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2494            .await
2495            .unwrap();
2496
2497        let request = thread.update(cx, |thread, cx| {
2498            thread.send_raw("Fetch https://example.com", cx)
2499        });
2500
2501        run_until_first_tool_call(&thread, cx).await;
2502
2503        thread.read_with(cx, |thread, _| {
2504            assert!(matches!(
2505                thread.entries[1],
2506                AgentThreadEntry::ToolCall(ToolCall {
2507                    status: ToolCallStatus::InProgress,
2508                    ..
2509                })
2510            ));
2511        });
2512
2513        thread.update(cx, |thread, cx| thread.cancel(cx)).await;
2514
2515        thread.read_with(cx, |thread, _| {
2516            assert!(matches!(
2517                &thread.entries[1],
2518                AgentThreadEntry::ToolCall(ToolCall {
2519                    status: ToolCallStatus::Canceled,
2520                    ..
2521                })
2522            ));
2523        });
2524
2525        thread
2526            .update(cx, |thread, cx| {
2527                thread.handle_session_update(
2528                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
2529                        id,
2530                        fields: acp::ToolCallUpdateFields {
2531                            status: Some(acp::ToolCallStatus::Completed),
2532                            ..Default::default()
2533                        },
2534                    }),
2535                    cx,
2536                )
2537            })
2538            .unwrap();
2539
2540        request.await.unwrap();
2541
2542        thread.read_with(cx, |thread, _| {
2543            assert!(matches!(
2544                thread.entries[1],
2545                AgentThreadEntry::ToolCall(ToolCall {
2546                    status: ToolCallStatus::Completed,
2547                    ..
2548                })
2549            ));
2550        });
2551    }
2552
2553    #[gpui::test]
2554    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
2555        init_test(cx);
2556        let fs = FakeFs::new(cx.background_executor.clone());
2557        fs.insert_tree(path!("/test"), json!({})).await;
2558        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
2559
2560        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2561            move |_, thread, mut cx| {
2562                async move {
2563                    thread
2564                        .update(&mut cx, |thread, cx| {
2565                            thread.handle_session_update(
2566                                acp::SessionUpdate::ToolCall(acp::ToolCall {
2567                                    id: acp::ToolCallId("test".into()),
2568                                    title: "Label".into(),
2569                                    kind: acp::ToolKind::Edit,
2570                                    status: acp::ToolCallStatus::Completed,
2571                                    content: vec![acp::ToolCallContent::Diff {
2572                                        diff: acp::Diff {
2573                                            path: "/test/test.txt".into(),
2574                                            old_text: None,
2575                                            new_text: "foo".into(),
2576                                        },
2577                                    }],
2578                                    locations: vec![],
2579                                    raw_input: None,
2580                                    raw_output: None,
2581                                }),
2582                                cx,
2583                            )
2584                        })
2585                        .unwrap()
2586                        .unwrap();
2587                    Ok(acp::PromptResponse {
2588                        stop_reason: acp::StopReason::EndTurn,
2589                    })
2590                }
2591                .boxed_local()
2592            }
2593        }));
2594
2595        let thread = cx
2596            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2597            .await
2598            .unwrap();
2599
2600        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
2601            .await
2602            .unwrap();
2603
2604        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
2605    }
2606
2607    #[gpui::test(iterations = 10)]
2608    async fn test_checkpoints(cx: &mut TestAppContext) {
2609        init_test(cx);
2610        let fs = FakeFs::new(cx.background_executor.clone());
2611        fs.insert_tree(
2612            path!("/test"),
2613            json!({
2614                ".git": {}
2615            }),
2616        )
2617        .await;
2618        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
2619
2620        let simulate_changes = Arc::new(AtomicBool::new(true));
2621        let next_filename = Arc::new(AtomicUsize::new(0));
2622        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2623            let simulate_changes = simulate_changes.clone();
2624            let next_filename = next_filename.clone();
2625            let fs = fs.clone();
2626            move |request, thread, mut cx| {
2627                let fs = fs.clone();
2628                let simulate_changes = simulate_changes.clone();
2629                let next_filename = next_filename.clone();
2630                async move {
2631                    if simulate_changes.load(SeqCst) {
2632                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
2633                        fs.write(Path::new(&filename), b"").await?;
2634                    }
2635
2636                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2637                        panic!("expected text content block");
2638                    };
2639                    thread.update(&mut cx, |thread, cx| {
2640                        thread
2641                            .handle_session_update(
2642                                acp::SessionUpdate::AgentMessageChunk {
2643                                    content: content.text.to_uppercase().into(),
2644                                },
2645                                cx,
2646                            )
2647                            .unwrap();
2648                    })?;
2649                    Ok(acp::PromptResponse {
2650                        stop_reason: acp::StopReason::EndTurn,
2651                    })
2652                }
2653                .boxed_local()
2654            }
2655        }));
2656        let thread = cx
2657            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2658            .await
2659            .unwrap();
2660
2661        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
2662            .await
2663            .unwrap();
2664        thread.read_with(cx, |thread, cx| {
2665            assert_eq!(
2666                thread.to_markdown(cx),
2667                indoc! {"
2668                    ## User (checkpoint)
2669
2670                    Lorem
2671
2672                    ## Assistant
2673
2674                    LOREM
2675
2676                "}
2677            );
2678        });
2679        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2680
2681        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
2682            .await
2683            .unwrap();
2684        thread.read_with(cx, |thread, cx| {
2685            assert_eq!(
2686                thread.to_markdown(cx),
2687                indoc! {"
2688                    ## User (checkpoint)
2689
2690                    Lorem
2691
2692                    ## Assistant
2693
2694                    LOREM
2695
2696                    ## User (checkpoint)
2697
2698                    ipsum
2699
2700                    ## Assistant
2701
2702                    IPSUM
2703
2704                "}
2705            );
2706        });
2707        assert_eq!(
2708            fs.files(),
2709            vec![
2710                Path::new(path!("/test/file-0")),
2711                Path::new(path!("/test/file-1"))
2712            ]
2713        );
2714
2715        // Checkpoint isn't stored when there are no changes.
2716        simulate_changes.store(false, SeqCst);
2717        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
2718            .await
2719            .unwrap();
2720        thread.read_with(cx, |thread, cx| {
2721            assert_eq!(
2722                thread.to_markdown(cx),
2723                indoc! {"
2724                    ## User (checkpoint)
2725
2726                    Lorem
2727
2728                    ## Assistant
2729
2730                    LOREM
2731
2732                    ## User (checkpoint)
2733
2734                    ipsum
2735
2736                    ## Assistant
2737
2738                    IPSUM
2739
2740                    ## User
2741
2742                    dolor
2743
2744                    ## Assistant
2745
2746                    DOLOR
2747
2748                "}
2749            );
2750        });
2751        assert_eq!(
2752            fs.files(),
2753            vec![
2754                Path::new(path!("/test/file-0")),
2755                Path::new(path!("/test/file-1"))
2756            ]
2757        );
2758
2759        // Rewinding the conversation truncates the history and restores the checkpoint.
2760        thread
2761            .update(cx, |thread, cx| {
2762                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
2763                    panic!("unexpected entries {:?}", thread.entries)
2764                };
2765                thread.restore_checkpoint(message.id.clone().unwrap(), cx)
2766            })
2767            .await
2768            .unwrap();
2769        thread.read_with(cx, |thread, cx| {
2770            assert_eq!(
2771                thread.to_markdown(cx),
2772                indoc! {"
2773                    ## User (checkpoint)
2774
2775                    Lorem
2776
2777                    ## Assistant
2778
2779                    LOREM
2780
2781                "}
2782            );
2783        });
2784        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2785    }
2786
2787    #[gpui::test]
2788    async fn test_tool_result_refusal(cx: &mut TestAppContext) {
2789        use std::sync::atomic::AtomicUsize;
2790        init_test(cx);
2791
2792        let fs = FakeFs::new(cx.executor());
2793        let project = Project::test(fs, None, cx).await;
2794
2795        // Create a connection that simulates refusal after tool result
2796        let prompt_count = Arc::new(AtomicUsize::new(0));
2797        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2798            let prompt_count = prompt_count.clone();
2799            move |_request, thread, mut cx| {
2800                let count = prompt_count.fetch_add(1, SeqCst);
2801                async move {
2802                    if count == 0 {
2803                        // First prompt: Generate a tool call with result
2804                        thread.update(&mut cx, |thread, cx| {
2805                            thread
2806                                .handle_session_update(
2807                                    acp::SessionUpdate::ToolCall(acp::ToolCall {
2808                                        id: acp::ToolCallId("tool1".into()),
2809                                        title: "Test Tool".into(),
2810                                        kind: acp::ToolKind::Fetch,
2811                                        status: acp::ToolCallStatus::Completed,
2812                                        content: vec![],
2813                                        locations: vec![],
2814                                        raw_input: Some(serde_json::json!({"query": "test"})),
2815                                        raw_output: Some(
2816                                            serde_json::json!({"result": "inappropriate content"}),
2817                                        ),
2818                                    }),
2819                                    cx,
2820                                )
2821                                .unwrap();
2822                        })?;
2823
2824                        // Now return refusal because of the tool result
2825                        Ok(acp::PromptResponse {
2826                            stop_reason: acp::StopReason::Refusal,
2827                        })
2828                    } else {
2829                        Ok(acp::PromptResponse {
2830                            stop_reason: acp::StopReason::EndTurn,
2831                        })
2832                    }
2833                }
2834                .boxed_local()
2835            }
2836        }));
2837
2838        let thread = cx
2839            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2840            .await
2841            .unwrap();
2842
2843        // Track if we see a Refusal event
2844        let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
2845        let saw_refusal_event_captured = saw_refusal_event.clone();
2846        thread.update(cx, |_thread, cx| {
2847            cx.subscribe(
2848                &thread,
2849                move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
2850                    if matches!(event, AcpThreadEvent::Refusal) {
2851                        *saw_refusal_event_captured.lock().unwrap() = true;
2852                    }
2853                },
2854            )
2855            .detach();
2856        });
2857
2858        // Send a user message - this will trigger tool call and then refusal
2859        let send_task = thread.update(cx, |thread, cx| {
2860            thread.send(
2861                vec![acp::ContentBlock::Text(acp::TextContent {
2862                    text: "Hello".into(),
2863                    annotations: None,
2864                })],
2865                cx,
2866            )
2867        });
2868        cx.background_executor.spawn(send_task).detach();
2869        cx.run_until_parked();
2870
2871        // Verify that:
2872        // 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt)
2873        // 2. The user message was NOT truncated
2874        assert!(
2875            *saw_refusal_event.lock().unwrap(),
2876            "Refusal event should be emitted for tool result refusals"
2877        );
2878
2879        thread.read_with(cx, |thread, _| {
2880            let entries = thread.entries();
2881            assert!(entries.len() >= 2, "Should have user message and tool call");
2882
2883            // Verify user message is still there
2884            assert!(
2885                matches!(entries[0], AgentThreadEntry::UserMessage(_)),
2886                "User message should not be truncated"
2887            );
2888
2889            // Verify tool call is there with result
2890            if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] {
2891                assert!(
2892                    tool_call.raw_output.is_some(),
2893                    "Tool call should have output"
2894                );
2895            } else {
2896                panic!("Expected tool call at index 1");
2897            }
2898        });
2899    }
2900
2901    #[gpui::test]
2902    async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) {
2903        init_test(cx);
2904
2905        let fs = FakeFs::new(cx.executor());
2906        let project = Project::test(fs, None, cx).await;
2907
2908        let refuse_next = Arc::new(AtomicBool::new(false));
2909        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2910            let refuse_next = refuse_next.clone();
2911            move |_request, _thread, _cx| {
2912                if refuse_next.load(SeqCst) {
2913                    async move {
2914                        Ok(acp::PromptResponse {
2915                            stop_reason: acp::StopReason::Refusal,
2916                        })
2917                    }
2918                    .boxed_local()
2919                } else {
2920                    async move {
2921                        Ok(acp::PromptResponse {
2922                            stop_reason: acp::StopReason::EndTurn,
2923                        })
2924                    }
2925                    .boxed_local()
2926                }
2927            }
2928        }));
2929
2930        let thread = cx
2931            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2932            .await
2933            .unwrap();
2934
2935        // Track if we see a Refusal event
2936        let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
2937        let saw_refusal_event_captured = saw_refusal_event.clone();
2938        thread.update(cx, |_thread, cx| {
2939            cx.subscribe(
2940                &thread,
2941                move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
2942                    if matches!(event, AcpThreadEvent::Refusal) {
2943                        *saw_refusal_event_captured.lock().unwrap() = true;
2944                    }
2945                },
2946            )
2947            .detach();
2948        });
2949
2950        // Send a message that will be refused
2951        refuse_next.store(true, SeqCst);
2952        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
2953            .await
2954            .unwrap();
2955
2956        // Verify that a Refusal event WAS emitted for user prompt refusal
2957        assert!(
2958            *saw_refusal_event.lock().unwrap(),
2959            "Refusal event should be emitted for user prompt refusals"
2960        );
2961
2962        // Verify the message was truncated (user prompt refusal)
2963        thread.read_with(cx, |thread, cx| {
2964            assert_eq!(thread.to_markdown(cx), "");
2965        });
2966    }
2967
2968    #[gpui::test]
2969    async fn test_refusal(cx: &mut TestAppContext) {
2970        init_test(cx);
2971        let fs = FakeFs::new(cx.background_executor.clone());
2972        fs.insert_tree(path!("/"), json!({})).await;
2973        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
2974
2975        let refuse_next = Arc::new(AtomicBool::new(false));
2976        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2977            let refuse_next = refuse_next.clone();
2978            move |request, thread, mut cx| {
2979                let refuse_next = refuse_next.clone();
2980                async move {
2981                    if refuse_next.load(SeqCst) {
2982                        return Ok(acp::PromptResponse {
2983                            stop_reason: acp::StopReason::Refusal,
2984                        });
2985                    }
2986
2987                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2988                        panic!("expected text content block");
2989                    };
2990                    thread.update(&mut cx, |thread, cx| {
2991                        thread
2992                            .handle_session_update(
2993                                acp::SessionUpdate::AgentMessageChunk {
2994                                    content: content.text.to_uppercase().into(),
2995                                },
2996                                cx,
2997                            )
2998                            .unwrap();
2999                    })?;
3000                    Ok(acp::PromptResponse {
3001                        stop_reason: acp::StopReason::EndTurn,
3002                    })
3003                }
3004                .boxed_local()
3005            }
3006        }));
3007        let thread = cx
3008            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3009            .await
3010            .unwrap();
3011
3012        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3013            .await
3014            .unwrap();
3015        thread.read_with(cx, |thread, cx| {
3016            assert_eq!(
3017                thread.to_markdown(cx),
3018                indoc! {"
3019                    ## User
3020
3021                    hello
3022
3023                    ## Assistant
3024
3025                    HELLO
3026
3027                "}
3028            );
3029        });
3030
3031        // Simulate refusing the second message. The message should be truncated
3032        // when a user prompt is refused.
3033        refuse_next.store(true, SeqCst);
3034        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
3035            .await
3036            .unwrap();
3037        thread.read_with(cx, |thread, cx| {
3038            assert_eq!(
3039                thread.to_markdown(cx),
3040                indoc! {"
3041                    ## User
3042
3043                    hello
3044
3045                    ## Assistant
3046
3047                    HELLO
3048
3049                "}
3050            );
3051        });
3052    }
3053
3054    async fn run_until_first_tool_call(
3055        thread: &Entity<AcpThread>,
3056        cx: &mut TestAppContext,
3057    ) -> usize {
3058        let (mut tx, mut rx) = mpsc::channel::<usize>(1);
3059
3060        let subscription = cx.update(|cx| {
3061            cx.subscribe(thread, move |thread, _, cx| {
3062                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
3063                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
3064                        return tx.try_send(ix).unwrap();
3065                    }
3066                }
3067            })
3068        });
3069
3070        select! {
3071            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
3072                panic!("Timeout waiting for tool call")
3073            }
3074            ix = rx.next().fuse() => {
3075                drop(subscription);
3076                ix.unwrap()
3077            }
3078        }
3079    }
3080
3081    #[derive(Clone, Default)]
3082    struct FakeAgentConnection {
3083        auth_methods: Vec<acp::AuthMethod>,
3084        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
3085        on_user_message: Option<
3086            Rc<
3087                dyn Fn(
3088                        acp::PromptRequest,
3089                        WeakEntity<AcpThread>,
3090                        AsyncApp,
3091                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3092                    + 'static,
3093            >,
3094        >,
3095    }
3096
3097    impl FakeAgentConnection {
3098        fn new() -> Self {
3099            Self {
3100                auth_methods: Vec::new(),
3101                on_user_message: None,
3102                sessions: Arc::default(),
3103            }
3104        }
3105
3106        #[expect(unused)]
3107        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
3108            self.auth_methods = auth_methods;
3109            self
3110        }
3111
3112        fn on_user_message(
3113            mut self,
3114            handler: impl Fn(
3115                acp::PromptRequest,
3116                WeakEntity<AcpThread>,
3117                AsyncApp,
3118            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3119            + 'static,
3120        ) -> Self {
3121            self.on_user_message.replace(Rc::new(handler));
3122            self
3123        }
3124    }
3125
3126    impl AgentConnection for FakeAgentConnection {
3127        fn auth_methods(&self) -> &[acp::AuthMethod] {
3128            &self.auth_methods
3129        }
3130
3131        fn new_thread(
3132            self: Rc<Self>,
3133            project: Entity<Project>,
3134            _cwd: &Path,
3135            cx: &mut App,
3136        ) -> Task<gpui::Result<Entity<AcpThread>>> {
3137            let session_id = acp::SessionId(
3138                rand::rng()
3139                    .sample_iter(&distr::Alphanumeric)
3140                    .take(7)
3141                    .map(char::from)
3142                    .collect::<String>()
3143                    .into(),
3144            );
3145            let action_log = cx.new(|_| ActionLog::new(project.clone()));
3146            let thread = cx.new(|cx| {
3147                AcpThread::new(
3148                    "Test",
3149                    self.clone(),
3150                    project,
3151                    action_log,
3152                    session_id.clone(),
3153                    watch::Receiver::constant(acp::PromptCapabilities {
3154                        image: true,
3155                        audio: true,
3156                        embedded_context: true,
3157                    }),
3158                    cx,
3159                )
3160            });
3161            self.sessions.lock().insert(session_id, thread.downgrade());
3162            Task::ready(Ok(thread))
3163        }
3164
3165        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
3166            if self.auth_methods().iter().any(|m| m.id == method) {
3167                Task::ready(Ok(()))
3168            } else {
3169                Task::ready(Err(anyhow!("Invalid Auth Method")))
3170            }
3171        }
3172
3173        fn prompt(
3174            &self,
3175            _id: Option<UserMessageId>,
3176            params: acp::PromptRequest,
3177            cx: &mut App,
3178        ) -> Task<gpui::Result<acp::PromptResponse>> {
3179            let sessions = self.sessions.lock();
3180            let thread = sessions.get(&params.session_id).unwrap();
3181            if let Some(handler) = &self.on_user_message {
3182                let handler = handler.clone();
3183                let thread = thread.clone();
3184                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
3185            } else {
3186                Task::ready(Ok(acp::PromptResponse {
3187                    stop_reason: acp::StopReason::EndTurn,
3188                }))
3189            }
3190        }
3191
3192        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
3193            let sessions = self.sessions.lock();
3194            let thread = sessions.get(session_id).unwrap().clone();
3195
3196            cx.spawn(async move |cx| {
3197                thread
3198                    .update(cx, |thread, cx| thread.cancel(cx))
3199                    .unwrap()
3200                    .await
3201            })
3202            .detach();
3203        }
3204
3205        fn truncate(
3206            &self,
3207            session_id: &acp::SessionId,
3208            _cx: &App,
3209        ) -> Option<Rc<dyn AgentSessionTruncate>> {
3210            Some(Rc::new(FakeAgentSessionEditor {
3211                _session_id: session_id.clone(),
3212            }))
3213        }
3214
3215        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
3216            self
3217        }
3218    }
3219
3220    struct FakeAgentSessionEditor {
3221        _session_id: acp::SessionId,
3222    }
3223
3224    impl AgentSessionTruncate for FakeAgentSessionEditor {
3225        fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
3226            Task::ready(Ok(()))
3227        }
3228    }
3229}