acp_thread.rs

   1mod connection;
   2mod diff;
   3mod mention;
   4mod terminal;
   5
   6use agent_settings::AgentSettings;
   7use collections::HashSet;
   8pub use connection::*;
   9pub use diff::*;
  10use futures::future::Shared;
  11use language::language_settings::FormatOnSave;
  12pub use mention::*;
  13use project::lsp_store::{FormatTrigger, LspFormatTarget};
  14use serde::{Deserialize, Serialize};
  15use settings::Settings as _;
  16pub use terminal::*;
  17
  18use action_log::ActionLog;
  19use agent_client_protocol::{self as acp};
  20use anyhow::{Context as _, Result, anyhow};
  21use editor::Bias;
  22use futures::{FutureExt, channel::oneshot, future::BoxFuture};
  23use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  24use itertools::Itertools;
  25use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
  26use markdown::Markdown;
  27use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
  28use std::collections::HashMap;
  29use std::error::Error;
  30use std::fmt::{Formatter, Write};
  31use std::ops::Range;
  32use std::process::ExitStatus;
  33use std::rc::Rc;
  34use std::time::{Duration, Instant};
  35use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
  36use ui::App;
  37use util::{ResultExt, get_system_shell};
  38use uuid::Uuid;
  39
  40#[derive(Debug)]
  41pub struct UserMessage {
  42    pub id: Option<UserMessageId>,
  43    pub content: ContentBlock,
  44    pub chunks: Vec<acp::ContentBlock>,
  45    pub checkpoint: Option<Checkpoint>,
  46}
  47
  48#[derive(Debug)]
  49pub struct Checkpoint {
  50    git_checkpoint: GitStoreCheckpoint,
  51    pub show: bool,
  52}
  53
  54impl UserMessage {
  55    fn to_markdown(&self, cx: &App) -> String {
  56        let mut markdown = String::new();
  57        if self
  58            .checkpoint
  59            .as_ref()
  60            .is_some_and(|checkpoint| checkpoint.show)
  61        {
  62            writeln!(markdown, "## User (checkpoint)").unwrap();
  63        } else {
  64            writeln!(markdown, "## User").unwrap();
  65        }
  66        writeln!(markdown).unwrap();
  67        writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
  68        writeln!(markdown).unwrap();
  69        markdown
  70    }
  71}
  72
  73#[derive(Debug, PartialEq)]
  74pub struct AssistantMessage {
  75    pub chunks: Vec<AssistantMessageChunk>,
  76}
  77
  78impl AssistantMessage {
  79    pub fn to_markdown(&self, cx: &App) -> String {
  80        format!(
  81            "## Assistant\n\n{}\n\n",
  82            self.chunks
  83                .iter()
  84                .map(|chunk| chunk.to_markdown(cx))
  85                .join("\n\n")
  86        )
  87    }
  88}
  89
  90#[derive(Debug, PartialEq)]
  91pub enum AssistantMessageChunk {
  92    Message { block: ContentBlock },
  93    Thought { block: ContentBlock },
  94}
  95
  96impl AssistantMessageChunk {
  97    pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
  98        Self::Message {
  99            block: ContentBlock::new(chunk.into(), language_registry, cx),
 100        }
 101    }
 102
 103    fn to_markdown(&self, cx: &App) -> String {
 104        match self {
 105            Self::Message { block } => block.to_markdown(cx).to_string(),
 106            Self::Thought { block } => {
 107                format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
 108            }
 109        }
 110    }
 111}
 112
 113#[derive(Debug)]
 114pub enum AgentThreadEntry {
 115    UserMessage(UserMessage),
 116    AssistantMessage(AssistantMessage),
 117    ToolCall(ToolCall),
 118}
 119
 120impl AgentThreadEntry {
 121    pub fn to_markdown(&self, cx: &App) -> String {
 122        match self {
 123            Self::UserMessage(message) => message.to_markdown(cx),
 124            Self::AssistantMessage(message) => message.to_markdown(cx),
 125            Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
 126        }
 127    }
 128
 129    pub fn user_message(&self) -> Option<&UserMessage> {
 130        if let AgentThreadEntry::UserMessage(message) = self {
 131            Some(message)
 132        } else {
 133            None
 134        }
 135    }
 136
 137    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 138        if let AgentThreadEntry::ToolCall(call) = self {
 139            itertools::Either::Left(call.diffs())
 140        } else {
 141            itertools::Either::Right(std::iter::empty())
 142        }
 143    }
 144
 145    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 146        if let AgentThreadEntry::ToolCall(call) = self {
 147            itertools::Either::Left(call.terminals())
 148        } else {
 149            itertools::Either::Right(std::iter::empty())
 150        }
 151    }
 152
 153    pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
 154        if let AgentThreadEntry::ToolCall(ToolCall {
 155            locations,
 156            resolved_locations,
 157            ..
 158        }) = self
 159        {
 160            Some((
 161                locations.get(ix)?.clone(),
 162                resolved_locations.get(ix)?.clone()?,
 163            ))
 164        } else {
 165            None
 166        }
 167    }
 168}
 169
 170#[derive(Debug)]
 171pub struct ToolCall {
 172    pub id: acp::ToolCallId,
 173    pub label: Entity<Markdown>,
 174    pub kind: acp::ToolKind,
 175    pub content: Vec<ToolCallContent>,
 176    pub status: ToolCallStatus,
 177    pub locations: Vec<acp::ToolCallLocation>,
 178    pub resolved_locations: Vec<Option<AgentLocation>>,
 179    pub raw_input: Option<serde_json::Value>,
 180    pub raw_output: Option<serde_json::Value>,
 181}
 182
 183impl ToolCall {
 184    fn from_acp(
 185        tool_call: acp::ToolCall,
 186        status: ToolCallStatus,
 187        language_registry: Arc<LanguageRegistry>,
 188        terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
 189        cx: &mut App,
 190    ) -> Result<Self> {
 191        let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
 192            first_line.to_owned() + ""
 193        } else {
 194            tool_call.title
 195        };
 196        let mut content = Vec::with_capacity(tool_call.content.len());
 197        for item in tool_call.content {
 198            content.push(ToolCallContent::from_acp(
 199                item,
 200                language_registry.clone(),
 201                terminals,
 202                cx,
 203            )?);
 204        }
 205
 206        let result = Self {
 207            id: tool_call.id,
 208            label: cx
 209                .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
 210            kind: tool_call.kind,
 211            content,
 212            locations: tool_call.locations,
 213            resolved_locations: Vec::default(),
 214            status,
 215            raw_input: tool_call.raw_input,
 216            raw_output: tool_call.raw_output,
 217        };
 218        Ok(result)
 219    }
 220
 221    fn update_fields(
 222        &mut self,
 223        fields: acp::ToolCallUpdateFields,
 224        language_registry: Arc<LanguageRegistry>,
 225        terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
 226        cx: &mut App,
 227    ) -> Result<()> {
 228        let acp::ToolCallUpdateFields {
 229            kind,
 230            status,
 231            title,
 232            content,
 233            locations,
 234            raw_input,
 235            raw_output,
 236        } = fields;
 237
 238        if let Some(kind) = kind {
 239            self.kind = kind;
 240        }
 241
 242        if let Some(status) = status {
 243            self.status = status.into();
 244        }
 245
 246        if let Some(title) = title {
 247            self.label.update(cx, |label, cx| {
 248                if let Some((first_line, _)) = title.split_once("\n") {
 249                    label.replace(first_line.to_owned() + "", cx)
 250                } else {
 251                    label.replace(title, cx);
 252                }
 253            });
 254        }
 255
 256        if let Some(content) = content {
 257            let new_content_len = content.len();
 258            let mut content = content.into_iter();
 259
 260            // Reuse existing content if we can
 261            for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
 262                old.update_from_acp(new, language_registry.clone(), terminals, cx)?;
 263            }
 264            for new in content {
 265                self.content.push(ToolCallContent::from_acp(
 266                    new,
 267                    language_registry.clone(),
 268                    terminals,
 269                    cx,
 270                )?)
 271            }
 272            self.content.truncate(new_content_len);
 273        }
 274
 275        if let Some(locations) = locations {
 276            self.locations = locations;
 277        }
 278
 279        if let Some(raw_input) = raw_input {
 280            self.raw_input = Some(raw_input);
 281        }
 282
 283        if let Some(raw_output) = raw_output {
 284            if self.content.is_empty()
 285                && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
 286            {
 287                self.content
 288                    .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
 289                        markdown,
 290                    }));
 291            }
 292            self.raw_output = Some(raw_output);
 293        }
 294        Ok(())
 295    }
 296
 297    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 298        self.content.iter().filter_map(|content| match content {
 299            ToolCallContent::Diff(diff) => Some(diff),
 300            ToolCallContent::ContentBlock(_) => None,
 301            ToolCallContent::Terminal(_) => None,
 302        })
 303    }
 304
 305    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 306        self.content.iter().filter_map(|content| match content {
 307            ToolCallContent::Terminal(terminal) => Some(terminal),
 308            ToolCallContent::ContentBlock(_) => None,
 309            ToolCallContent::Diff(_) => None,
 310        })
 311    }
 312
 313    fn to_markdown(&self, cx: &App) -> String {
 314        let mut markdown = format!(
 315            "**Tool Call: {}**\nStatus: {}\n\n",
 316            self.label.read(cx).source(),
 317            self.status
 318        );
 319        for content in &self.content {
 320            markdown.push_str(content.to_markdown(cx).as_str());
 321            markdown.push_str("\n\n");
 322        }
 323        markdown
 324    }
 325
 326    async fn resolve_location(
 327        location: acp::ToolCallLocation,
 328        project: WeakEntity<Project>,
 329        cx: &mut AsyncApp,
 330    ) -> Option<AgentLocation> {
 331        let buffer = project
 332            .update(cx, |project, cx| {
 333                project
 334                    .project_path_for_absolute_path(&location.path, cx)
 335                    .map(|path| project.open_buffer(path, cx))
 336            })
 337            .ok()??;
 338        let buffer = buffer.await.log_err()?;
 339        let position = buffer
 340            .update(cx, |buffer, _| {
 341                if let Some(row) = location.line {
 342                    let snapshot = buffer.snapshot();
 343                    let column = snapshot.indent_size_for_line(row).len;
 344                    let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
 345                    snapshot.anchor_before(point)
 346                } else {
 347                    Anchor::MIN
 348                }
 349            })
 350            .ok()?;
 351
 352        Some(AgentLocation {
 353            buffer: buffer.downgrade(),
 354            position,
 355        })
 356    }
 357
 358    fn resolve_locations(
 359        &self,
 360        project: Entity<Project>,
 361        cx: &mut App,
 362    ) -> Task<Vec<Option<AgentLocation>>> {
 363        let locations = self.locations.clone();
 364        project.update(cx, |_, cx| {
 365            cx.spawn(async move |project, cx| {
 366                let mut new_locations = Vec::new();
 367                for location in locations {
 368                    new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
 369                }
 370                new_locations
 371            })
 372        })
 373    }
 374}
 375
 376#[derive(Debug)]
 377pub enum ToolCallStatus {
 378    /// The tool call hasn't started running yet, but we start showing it to
 379    /// the user.
 380    Pending,
 381    /// The tool call is waiting for confirmation from the user.
 382    WaitingForConfirmation {
 383        options: Vec<acp::PermissionOption>,
 384        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
 385    },
 386    /// The tool call is currently running.
 387    InProgress,
 388    /// The tool call completed successfully.
 389    Completed,
 390    /// The tool call failed.
 391    Failed,
 392    /// The user rejected the tool call.
 393    Rejected,
 394    /// The user canceled generation so the tool call was canceled.
 395    Canceled,
 396}
 397
 398impl From<acp::ToolCallStatus> for ToolCallStatus {
 399    fn from(status: acp::ToolCallStatus) -> Self {
 400        match status {
 401            acp::ToolCallStatus::Pending => Self::Pending,
 402            acp::ToolCallStatus::InProgress => Self::InProgress,
 403            acp::ToolCallStatus::Completed => Self::Completed,
 404            acp::ToolCallStatus::Failed => Self::Failed,
 405        }
 406    }
 407}
 408
 409impl Display for ToolCallStatus {
 410    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 411        write!(
 412            f,
 413            "{}",
 414            match self {
 415                ToolCallStatus::Pending => "Pending",
 416                ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
 417                ToolCallStatus::InProgress => "In Progress",
 418                ToolCallStatus::Completed => "Completed",
 419                ToolCallStatus::Failed => "Failed",
 420                ToolCallStatus::Rejected => "Rejected",
 421                ToolCallStatus::Canceled => "Canceled",
 422            }
 423        )
 424    }
 425}
 426
 427#[derive(Debug, PartialEq, Clone)]
 428pub enum ContentBlock {
 429    Empty,
 430    Markdown { markdown: Entity<Markdown> },
 431    ResourceLink { resource_link: acp::ResourceLink },
 432}
 433
 434impl ContentBlock {
 435    pub fn new(
 436        block: acp::ContentBlock,
 437        language_registry: &Arc<LanguageRegistry>,
 438        cx: &mut App,
 439    ) -> Self {
 440        let mut this = Self::Empty;
 441        this.append(block, language_registry, cx);
 442        this
 443    }
 444
 445    pub fn new_combined(
 446        blocks: impl IntoIterator<Item = acp::ContentBlock>,
 447        language_registry: Arc<LanguageRegistry>,
 448        cx: &mut App,
 449    ) -> Self {
 450        let mut this = Self::Empty;
 451        for block in blocks {
 452            this.append(block, &language_registry, cx);
 453        }
 454        this
 455    }
 456
 457    pub fn append(
 458        &mut self,
 459        block: acp::ContentBlock,
 460        language_registry: &Arc<LanguageRegistry>,
 461        cx: &mut App,
 462    ) {
 463        if matches!(self, ContentBlock::Empty)
 464            && let acp::ContentBlock::ResourceLink(resource_link) = block
 465        {
 466            *self = ContentBlock::ResourceLink { resource_link };
 467            return;
 468        }
 469
 470        let new_content = self.block_string_contents(block);
 471
 472        match self {
 473            ContentBlock::Empty => {
 474                *self = Self::create_markdown_block(new_content, language_registry, cx);
 475            }
 476            ContentBlock::Markdown { markdown } => {
 477                markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
 478            }
 479            ContentBlock::ResourceLink { resource_link } => {
 480                let existing_content = Self::resource_link_md(&resource_link.uri);
 481                let combined = format!("{}\n{}", existing_content, new_content);
 482
 483                *self = Self::create_markdown_block(combined, language_registry, cx);
 484            }
 485        }
 486    }
 487
 488    fn create_markdown_block(
 489        content: String,
 490        language_registry: &Arc<LanguageRegistry>,
 491        cx: &mut App,
 492    ) -> ContentBlock {
 493        ContentBlock::Markdown {
 494            markdown: cx
 495                .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
 496        }
 497    }
 498
 499    fn block_string_contents(&self, block: acp::ContentBlock) -> String {
 500        match block {
 501            acp::ContentBlock::Text(text_content) => text_content.text,
 502            acp::ContentBlock::ResourceLink(resource_link) => {
 503                Self::resource_link_md(&resource_link.uri)
 504            }
 505            acp::ContentBlock::Resource(acp::EmbeddedResource {
 506                resource:
 507                    acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
 508                        uri,
 509                        ..
 510                    }),
 511                ..
 512            }) => Self::resource_link_md(&uri),
 513            acp::ContentBlock::Image(image) => Self::image_md(&image),
 514            acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
 515        }
 516    }
 517
 518    fn resource_link_md(uri: &str) -> String {
 519        if let Some(uri) = MentionUri::parse(uri).log_err() {
 520            uri.as_link().to_string()
 521        } else {
 522            uri.to_string()
 523        }
 524    }
 525
 526    fn image_md(_image: &acp::ImageContent) -> String {
 527        "`Image`".into()
 528    }
 529
 530    pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
 531        match self {
 532            ContentBlock::Empty => "",
 533            ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
 534            ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
 535        }
 536    }
 537
 538    pub fn markdown(&self) -> Option<&Entity<Markdown>> {
 539        match self {
 540            ContentBlock::Empty => None,
 541            ContentBlock::Markdown { markdown } => Some(markdown),
 542            ContentBlock::ResourceLink { .. } => None,
 543        }
 544    }
 545
 546    pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
 547        match self {
 548            ContentBlock::ResourceLink { resource_link } => Some(resource_link),
 549            _ => None,
 550        }
 551    }
 552}
 553
 554#[derive(Debug)]
 555pub enum ToolCallContent {
 556    ContentBlock(ContentBlock),
 557    Diff(Entity<Diff>),
 558    Terminal(Entity<Terminal>),
 559}
 560
 561impl ToolCallContent {
 562    pub fn from_acp(
 563        content: acp::ToolCallContent,
 564        language_registry: Arc<LanguageRegistry>,
 565        terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
 566        cx: &mut App,
 567    ) -> Result<Self> {
 568        match content {
 569            acp::ToolCallContent::Content { content } => Ok(Self::ContentBlock(ContentBlock::new(
 570                content,
 571                &language_registry,
 572                cx,
 573            ))),
 574            acp::ToolCallContent::Diff { diff } => Ok(Self::Diff(cx.new(|cx| {
 575                Diff::finalized(
 576                    diff.path,
 577                    diff.old_text,
 578                    diff.new_text,
 579                    language_registry,
 580                    cx,
 581                )
 582            }))),
 583            acp::ToolCallContent::Terminal { terminal_id } => terminals
 584                .get(&terminal_id)
 585                .cloned()
 586                .map(Self::Terminal)
 587                .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
 588        }
 589    }
 590
 591    pub fn update_from_acp(
 592        &mut self,
 593        new: acp::ToolCallContent,
 594        language_registry: Arc<LanguageRegistry>,
 595        terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
 596        cx: &mut App,
 597    ) -> Result<()> {
 598        let needs_update = match (&self, &new) {
 599            (Self::Diff(old_diff), acp::ToolCallContent::Diff { diff: new_diff }) => {
 600                old_diff.read(cx).needs_update(
 601                    new_diff.old_text.as_deref().unwrap_or(""),
 602                    &new_diff.new_text,
 603                    cx,
 604                )
 605            }
 606            _ => true,
 607        };
 608
 609        if needs_update {
 610            *self = Self::from_acp(new, language_registry, terminals, cx)?;
 611        }
 612        Ok(())
 613    }
 614
 615    pub fn to_markdown(&self, cx: &App) -> String {
 616        match self {
 617            Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
 618            Self::Diff(diff) => diff.read(cx).to_markdown(cx),
 619            Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
 620        }
 621    }
 622}
 623
 624#[derive(Debug, PartialEq)]
 625pub enum ToolCallUpdate {
 626    UpdateFields(acp::ToolCallUpdate),
 627    UpdateDiff(ToolCallUpdateDiff),
 628    UpdateTerminal(ToolCallUpdateTerminal),
 629}
 630
 631impl ToolCallUpdate {
 632    fn id(&self) -> &acp::ToolCallId {
 633        match self {
 634            Self::UpdateFields(update) => &update.id,
 635            Self::UpdateDiff(diff) => &diff.id,
 636            Self::UpdateTerminal(terminal) => &terminal.id,
 637        }
 638    }
 639}
 640
 641impl From<acp::ToolCallUpdate> for ToolCallUpdate {
 642    fn from(update: acp::ToolCallUpdate) -> Self {
 643        Self::UpdateFields(update)
 644    }
 645}
 646
 647impl From<ToolCallUpdateDiff> for ToolCallUpdate {
 648    fn from(diff: ToolCallUpdateDiff) -> Self {
 649        Self::UpdateDiff(diff)
 650    }
 651}
 652
 653#[derive(Debug, PartialEq)]
 654pub struct ToolCallUpdateDiff {
 655    pub id: acp::ToolCallId,
 656    pub diff: Entity<Diff>,
 657}
 658
 659impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
 660    fn from(terminal: ToolCallUpdateTerminal) -> Self {
 661        Self::UpdateTerminal(terminal)
 662    }
 663}
 664
 665#[derive(Debug, PartialEq)]
 666pub struct ToolCallUpdateTerminal {
 667    pub id: acp::ToolCallId,
 668    pub terminal: Entity<Terminal>,
 669}
 670
 671#[derive(Debug, Default)]
 672pub struct Plan {
 673    pub entries: Vec<PlanEntry>,
 674}
 675
 676#[derive(Debug)]
 677pub struct PlanStats<'a> {
 678    pub in_progress_entry: Option<&'a PlanEntry>,
 679    pub pending: u32,
 680    pub completed: u32,
 681}
 682
 683impl Plan {
 684    pub fn is_empty(&self) -> bool {
 685        self.entries.is_empty()
 686    }
 687
 688    pub fn stats(&self) -> PlanStats<'_> {
 689        let mut stats = PlanStats {
 690            in_progress_entry: None,
 691            pending: 0,
 692            completed: 0,
 693        };
 694
 695        for entry in &self.entries {
 696            match &entry.status {
 697                acp::PlanEntryStatus::Pending => {
 698                    stats.pending += 1;
 699                }
 700                acp::PlanEntryStatus::InProgress => {
 701                    stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
 702                }
 703                acp::PlanEntryStatus::Completed => {
 704                    stats.completed += 1;
 705                }
 706            }
 707        }
 708
 709        stats
 710    }
 711}
 712
 713#[derive(Debug)]
 714pub struct PlanEntry {
 715    pub content: Entity<Markdown>,
 716    pub priority: acp::PlanEntryPriority,
 717    pub status: acp::PlanEntryStatus,
 718}
 719
 720impl PlanEntry {
 721    pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
 722        Self {
 723            content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
 724            priority: entry.priority,
 725            status: entry.status,
 726        }
 727    }
 728}
 729
 730#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 731pub struct TokenUsage {
 732    pub max_tokens: u64,
 733    pub used_tokens: u64,
 734}
 735
 736impl TokenUsage {
 737    pub fn ratio(&self) -> TokenUsageRatio {
 738        #[cfg(debug_assertions)]
 739        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 740            .unwrap_or("0.8".to_string())
 741            .parse()
 742            .unwrap();
 743        #[cfg(not(debug_assertions))]
 744        let warning_threshold: f32 = 0.8;
 745
 746        // When the maximum is unknown because there is no selected model,
 747        // avoid showing the token limit warning.
 748        if self.max_tokens == 0 {
 749            TokenUsageRatio::Normal
 750        } else if self.used_tokens >= self.max_tokens {
 751            TokenUsageRatio::Exceeded
 752        } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
 753            TokenUsageRatio::Warning
 754        } else {
 755            TokenUsageRatio::Normal
 756        }
 757    }
 758}
 759
 760#[derive(Debug, Clone, PartialEq, Eq)]
 761pub enum TokenUsageRatio {
 762    Normal,
 763    Warning,
 764    Exceeded,
 765}
 766
 767#[derive(Debug, Clone)]
 768pub struct RetryStatus {
 769    pub last_error: SharedString,
 770    pub attempt: usize,
 771    pub max_attempts: usize,
 772    pub started_at: Instant,
 773    pub duration: Duration,
 774}
 775
 776pub struct AcpThread {
 777    title: SharedString,
 778    entries: Vec<AgentThreadEntry>,
 779    plan: Plan,
 780    project: Entity<Project>,
 781    action_log: Entity<ActionLog>,
 782    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
 783    send_task: Option<Task<()>>,
 784    connection: Rc<dyn AgentConnection>,
 785    session_id: acp::SessionId,
 786    token_usage: Option<TokenUsage>,
 787    prompt_capabilities: acp::PromptCapabilities,
 788    available_commands: Vec<acp::AvailableCommand>,
 789    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
 790    determine_shell: Shared<Task<String>>,
 791    terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
 792}
 793
 794#[derive(Debug)]
 795pub enum AcpThreadEvent {
 796    NewEntry,
 797    TitleUpdated,
 798    TokenUsageUpdated,
 799    EntryUpdated(usize),
 800    EntriesRemoved(Range<usize>),
 801    ToolAuthorizationRequired,
 802    Retry(RetryStatus),
 803    Stopped,
 804    Error,
 805    LoadError(LoadError),
 806    PromptCapabilitiesUpdated,
 807    Refusal,
 808}
 809
 810impl EventEmitter<AcpThreadEvent> for AcpThread {}
 811
 812#[derive(PartialEq, Eq, Debug)]
 813pub enum ThreadStatus {
 814    Idle,
 815    WaitingForToolConfirmation,
 816    Generating,
 817}
 818
 819#[derive(Debug, Clone)]
 820pub enum LoadError {
 821    Unsupported {
 822        command: SharedString,
 823        current_version: SharedString,
 824        minimum_version: SharedString,
 825    },
 826    FailedToInstall(SharedString),
 827    Exited {
 828        status: ExitStatus,
 829    },
 830    Other(SharedString),
 831}
 832
 833impl Display for LoadError {
 834    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 835        match self {
 836            LoadError::Unsupported {
 837                command: path,
 838                current_version,
 839                minimum_version,
 840            } => {
 841                write!(
 842                    f,
 843                    "version {current_version} from {path} is not supported (need at least {minimum_version})"
 844                )
 845            }
 846            LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
 847            LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
 848            LoadError::Other(msg) => write!(f, "{msg}"),
 849        }
 850    }
 851}
 852
 853impl Error for LoadError {}
 854
 855impl AcpThread {
 856    pub fn new(
 857        title: impl Into<SharedString>,
 858        connection: Rc<dyn AgentConnection>,
 859        project: Entity<Project>,
 860        action_log: Entity<ActionLog>,
 861        session_id: acp::SessionId,
 862        mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
 863        available_commands: Vec<acp::AvailableCommand>,
 864        cx: &mut Context<Self>,
 865    ) -> Self {
 866        let prompt_capabilities = *prompt_capabilities_rx.borrow();
 867        let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
 868            loop {
 869                let caps = prompt_capabilities_rx.recv().await?;
 870                this.update(cx, |this, cx| {
 871                    this.prompt_capabilities = caps;
 872                    cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
 873                })?;
 874            }
 875        });
 876
 877        let determine_shell = cx
 878            .background_spawn(async move {
 879                if cfg!(windows) {
 880                    return get_system_shell();
 881                }
 882
 883                if which::which("bash").is_ok() {
 884                    "bash".into()
 885                } else {
 886                    get_system_shell()
 887                }
 888            })
 889            .shared();
 890
 891        Self {
 892            action_log,
 893            shared_buffers: Default::default(),
 894            entries: Default::default(),
 895            plan: Default::default(),
 896            title: title.into(),
 897            project,
 898            send_task: None,
 899            connection,
 900            session_id,
 901            token_usage: None,
 902            prompt_capabilities,
 903            available_commands,
 904            _observe_prompt_capabilities: task,
 905            terminals: HashMap::default(),
 906            determine_shell,
 907        }
 908    }
 909
 910    pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
 911        self.prompt_capabilities
 912    }
 913
 914    pub fn available_commands(&self) -> Vec<acp::AvailableCommand> {
 915        self.available_commands.clone()
 916    }
 917
 918    pub fn connection(&self) -> &Rc<dyn AgentConnection> {
 919        &self.connection
 920    }
 921
 922    pub fn action_log(&self) -> &Entity<ActionLog> {
 923        &self.action_log
 924    }
 925
 926    pub fn project(&self) -> &Entity<Project> {
 927        &self.project
 928    }
 929
 930    pub fn title(&self) -> SharedString {
 931        self.title.clone()
 932    }
 933
 934    pub fn entries(&self) -> &[AgentThreadEntry] {
 935        &self.entries
 936    }
 937
 938    pub fn session_id(&self) -> &acp::SessionId {
 939        &self.session_id
 940    }
 941
 942    pub fn status(&self) -> ThreadStatus {
 943        if self.send_task.is_some() {
 944            if self.waiting_for_tool_confirmation() {
 945                ThreadStatus::WaitingForToolConfirmation
 946            } else {
 947                ThreadStatus::Generating
 948            }
 949        } else {
 950            ThreadStatus::Idle
 951        }
 952    }
 953
 954    pub fn token_usage(&self) -> Option<&TokenUsage> {
 955        self.token_usage.as_ref()
 956    }
 957
 958    pub fn has_pending_edit_tool_calls(&self) -> bool {
 959        for entry in self.entries.iter().rev() {
 960            match entry {
 961                AgentThreadEntry::UserMessage(_) => return false,
 962                AgentThreadEntry::ToolCall(
 963                    call @ ToolCall {
 964                        status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
 965                        ..
 966                    },
 967                ) if call.diffs().next().is_some() => {
 968                    return true;
 969                }
 970                AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
 971            }
 972        }
 973
 974        false
 975    }
 976
 977    pub fn used_tools_since_last_user_message(&self) -> bool {
 978        for entry in self.entries.iter().rev() {
 979            match entry {
 980                AgentThreadEntry::UserMessage(..) => return false,
 981                AgentThreadEntry::AssistantMessage(..) => continue,
 982                AgentThreadEntry::ToolCall(..) => return true,
 983            }
 984        }
 985
 986        false
 987    }
 988
 989    pub fn handle_session_update(
 990        &mut self,
 991        update: acp::SessionUpdate,
 992        cx: &mut Context<Self>,
 993    ) -> Result<(), acp::Error> {
 994        match update {
 995            acp::SessionUpdate::UserMessageChunk { content } => {
 996                self.push_user_content_block(None, content, cx);
 997            }
 998            acp::SessionUpdate::AgentMessageChunk { content } => {
 999                self.push_assistant_content_block(content, false, cx);
1000            }
1001            acp::SessionUpdate::AgentThoughtChunk { content } => {
1002                self.push_assistant_content_block(content, true, cx);
1003            }
1004            acp::SessionUpdate::ToolCall(tool_call) => {
1005                self.upsert_tool_call(tool_call, cx)?;
1006            }
1007            acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
1008                self.update_tool_call(tool_call_update, cx)?;
1009            }
1010            acp::SessionUpdate::Plan(plan) => {
1011                self.update_plan(plan, cx);
1012            }
1013        }
1014        Ok(())
1015    }
1016
1017    pub fn push_user_content_block(
1018        &mut self,
1019        message_id: Option<UserMessageId>,
1020        chunk: acp::ContentBlock,
1021        cx: &mut Context<Self>,
1022    ) {
1023        let language_registry = self.project.read(cx).languages().clone();
1024        let entries_len = self.entries.len();
1025
1026        if let Some(last_entry) = self.entries.last_mut()
1027            && let AgentThreadEntry::UserMessage(UserMessage {
1028                id,
1029                content,
1030                chunks,
1031                ..
1032            }) = last_entry
1033        {
1034            *id = message_id.or(id.take());
1035            content.append(chunk.clone(), &language_registry, cx);
1036            chunks.push(chunk);
1037            let idx = entries_len - 1;
1038            cx.emit(AcpThreadEvent::EntryUpdated(idx));
1039        } else {
1040            let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
1041            self.push_entry(
1042                AgentThreadEntry::UserMessage(UserMessage {
1043                    id: message_id,
1044                    content,
1045                    chunks: vec![chunk],
1046                    checkpoint: None,
1047                }),
1048                cx,
1049            );
1050        }
1051    }
1052
1053    pub fn push_assistant_content_block(
1054        &mut self,
1055        chunk: acp::ContentBlock,
1056        is_thought: bool,
1057        cx: &mut Context<Self>,
1058    ) {
1059        let language_registry = self.project.read(cx).languages().clone();
1060        let entries_len = self.entries.len();
1061        if let Some(last_entry) = self.entries.last_mut()
1062            && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
1063        {
1064            let idx = entries_len - 1;
1065            cx.emit(AcpThreadEvent::EntryUpdated(idx));
1066            match (chunks.last_mut(), is_thought) {
1067                (Some(AssistantMessageChunk::Message { block }), false)
1068                | (Some(AssistantMessageChunk::Thought { block }), true) => {
1069                    block.append(chunk, &language_registry, cx)
1070                }
1071                _ => {
1072                    let block = ContentBlock::new(chunk, &language_registry, cx);
1073                    if is_thought {
1074                        chunks.push(AssistantMessageChunk::Thought { block })
1075                    } else {
1076                        chunks.push(AssistantMessageChunk::Message { block })
1077                    }
1078                }
1079            }
1080        } else {
1081            let block = ContentBlock::new(chunk, &language_registry, cx);
1082            let chunk = if is_thought {
1083                AssistantMessageChunk::Thought { block }
1084            } else {
1085                AssistantMessageChunk::Message { block }
1086            };
1087
1088            self.push_entry(
1089                AgentThreadEntry::AssistantMessage(AssistantMessage {
1090                    chunks: vec![chunk],
1091                }),
1092                cx,
1093            );
1094        }
1095    }
1096
1097    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1098        self.entries.push(entry);
1099        cx.emit(AcpThreadEvent::NewEntry);
1100    }
1101
1102    pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1103        self.connection.set_title(&self.session_id, cx).is_some()
1104    }
1105
1106    pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1107        if title != self.title {
1108            self.title = title.clone();
1109            cx.emit(AcpThreadEvent::TitleUpdated);
1110            if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1111                return set_title.run(title, cx);
1112            }
1113        }
1114        Task::ready(Ok(()))
1115    }
1116
1117    pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1118        self.token_usage = usage;
1119        cx.emit(AcpThreadEvent::TokenUsageUpdated);
1120    }
1121
1122    pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1123        cx.emit(AcpThreadEvent::Retry(status));
1124    }
1125
1126    pub fn update_tool_call(
1127        &mut self,
1128        update: impl Into<ToolCallUpdate>,
1129        cx: &mut Context<Self>,
1130    ) -> Result<()> {
1131        let update = update.into();
1132        let languages = self.project.read(cx).languages().clone();
1133
1134        let ix = self
1135            .index_for_tool_call(update.id())
1136            .context("Tool call not found")?;
1137        let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1138            unreachable!()
1139        };
1140
1141        match update {
1142            ToolCallUpdate::UpdateFields(update) => {
1143                let location_updated = update.fields.locations.is_some();
1144                call.update_fields(update.fields, languages, &self.terminals, cx)?;
1145                if location_updated {
1146                    self.resolve_locations(update.id, cx);
1147                }
1148            }
1149            ToolCallUpdate::UpdateDiff(update) => {
1150                call.content.clear();
1151                call.content.push(ToolCallContent::Diff(update.diff));
1152            }
1153            ToolCallUpdate::UpdateTerminal(update) => {
1154                call.content.clear();
1155                call.content
1156                    .push(ToolCallContent::Terminal(update.terminal));
1157            }
1158        }
1159
1160        cx.emit(AcpThreadEvent::EntryUpdated(ix));
1161
1162        Ok(())
1163    }
1164
1165    /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1166    pub fn upsert_tool_call(
1167        &mut self,
1168        tool_call: acp::ToolCall,
1169        cx: &mut Context<Self>,
1170    ) -> Result<(), acp::Error> {
1171        let status = tool_call.status.into();
1172        self.upsert_tool_call_inner(tool_call.into(), status, cx)
1173    }
1174
1175    /// Fails if id does not match an existing entry.
1176    pub fn upsert_tool_call_inner(
1177        &mut self,
1178        update: acp::ToolCallUpdate,
1179        status: ToolCallStatus,
1180        cx: &mut Context<Self>,
1181    ) -> Result<(), acp::Error> {
1182        let language_registry = self.project.read(cx).languages().clone();
1183        let id = update.id.clone();
1184
1185        if let Some(ix) = self.index_for_tool_call(&id) {
1186            let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1187                unreachable!()
1188            };
1189
1190            call.update_fields(update.fields, language_registry, &self.terminals, cx)?;
1191            call.status = status;
1192
1193            cx.emit(AcpThreadEvent::EntryUpdated(ix));
1194        } else {
1195            let call = ToolCall::from_acp(
1196                update.try_into()?,
1197                status,
1198                language_registry,
1199                &self.terminals,
1200                cx,
1201            )?;
1202            self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1203        };
1204
1205        self.resolve_locations(id, cx);
1206        Ok(())
1207    }
1208
1209    fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1210        self.entries
1211            .iter()
1212            .enumerate()
1213            .rev()
1214            .find_map(|(index, entry)| {
1215                if let AgentThreadEntry::ToolCall(tool_call) = entry
1216                    && &tool_call.id == id
1217                {
1218                    Some(index)
1219                } else {
1220                    None
1221                }
1222            })
1223    }
1224
1225    fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1226        // The tool call we are looking for is typically the last one, or very close to the end.
1227        // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1228        self.entries
1229            .iter_mut()
1230            .enumerate()
1231            .rev()
1232            .find_map(|(index, tool_call)| {
1233                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1234                    && &tool_call.id == id
1235                {
1236                    Some((index, tool_call))
1237                } else {
1238                    None
1239                }
1240            })
1241    }
1242
1243    pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1244        self.entries
1245            .iter()
1246            .enumerate()
1247            .rev()
1248            .find_map(|(index, tool_call)| {
1249                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1250                    && &tool_call.id == id
1251                {
1252                    Some((index, tool_call))
1253                } else {
1254                    None
1255                }
1256            })
1257    }
1258
1259    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1260        let project = self.project.clone();
1261        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1262            return;
1263        };
1264        let task = tool_call.resolve_locations(project, cx);
1265        cx.spawn(async move |this, cx| {
1266            let resolved_locations = task.await;
1267            this.update(cx, |this, cx| {
1268                let project = this.project.clone();
1269                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1270                    return;
1271                };
1272                if let Some(Some(location)) = resolved_locations.last() {
1273                    project.update(cx, |project, cx| {
1274                        if let Some(agent_location) = project.agent_location() {
1275                            let should_ignore = agent_location.buffer == location.buffer
1276                                && location
1277                                    .buffer
1278                                    .update(cx, |buffer, _| {
1279                                        let snapshot = buffer.snapshot();
1280                                        let old_position =
1281                                            agent_location.position.to_point(&snapshot);
1282                                        let new_position = location.position.to_point(&snapshot);
1283                                        // ignore this so that when we get updates from the edit tool
1284                                        // the position doesn't reset to the startof line
1285                                        old_position.row == new_position.row
1286                                            && old_position.column > new_position.column
1287                                    })
1288                                    .ok()
1289                                    .unwrap_or_default();
1290                            if !should_ignore {
1291                                project.set_agent_location(Some(location.clone()), cx);
1292                            }
1293                        }
1294                    });
1295                }
1296                if tool_call.resolved_locations != resolved_locations {
1297                    tool_call.resolved_locations = resolved_locations;
1298                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
1299                }
1300            })
1301        })
1302        .detach();
1303    }
1304
1305    pub fn request_tool_call_authorization(
1306        &mut self,
1307        tool_call: acp::ToolCallUpdate,
1308        options: Vec<acp::PermissionOption>,
1309        cx: &mut Context<Self>,
1310    ) -> Result<BoxFuture<'static, acp::RequestPermissionOutcome>> {
1311        let (tx, rx) = oneshot::channel();
1312
1313        if AgentSettings::get_global(cx).always_allow_tool_actions {
1314            // Don't use AllowAlways, because then if you were to turn off always_allow_tool_actions,
1315            // some tools would (incorrectly) continue to auto-accept.
1316            if let Some(allow_once_option) = options.iter().find_map(|option| {
1317                if matches!(option.kind, acp::PermissionOptionKind::AllowOnce) {
1318                    Some(option.id.clone())
1319                } else {
1320                    None
1321                }
1322            }) {
1323                self.upsert_tool_call_inner(tool_call, ToolCallStatus::Pending, cx)?;
1324                return Ok(async {
1325                    acp::RequestPermissionOutcome::Selected {
1326                        option_id: allow_once_option,
1327                    }
1328                }
1329                .boxed());
1330            }
1331        }
1332
1333        let status = ToolCallStatus::WaitingForConfirmation {
1334            options,
1335            respond_tx: tx,
1336        };
1337
1338        self.upsert_tool_call_inner(tool_call, status, cx)?;
1339        cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1340
1341        let fut = async {
1342            match rx.await {
1343                Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
1344                Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1345            }
1346        }
1347        .boxed();
1348
1349        Ok(fut)
1350    }
1351
1352    pub fn authorize_tool_call(
1353        &mut self,
1354        id: acp::ToolCallId,
1355        option_id: acp::PermissionOptionId,
1356        option_kind: acp::PermissionOptionKind,
1357        cx: &mut Context<Self>,
1358    ) {
1359        let Some((ix, call)) = self.tool_call_mut(&id) else {
1360            return;
1361        };
1362
1363        let new_status = match option_kind {
1364            acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1365                ToolCallStatus::Rejected
1366            }
1367            acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1368                ToolCallStatus::InProgress
1369            }
1370        };
1371
1372        let curr_status = mem::replace(&mut call.status, new_status);
1373
1374        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1375            respond_tx.send(option_id).log_err();
1376        } else if cfg!(debug_assertions) {
1377            panic!("tried to authorize an already authorized tool call");
1378        }
1379
1380        cx.emit(AcpThreadEvent::EntryUpdated(ix));
1381    }
1382
1383    /// Returns true if the last turn is awaiting tool authorization
1384    pub fn waiting_for_tool_confirmation(&self) -> bool {
1385        for entry in self.entries.iter().rev() {
1386            match &entry {
1387                AgentThreadEntry::ToolCall(call) => match call.status {
1388                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
1389                    ToolCallStatus::Pending
1390                    | ToolCallStatus::InProgress
1391                    | ToolCallStatus::Completed
1392                    | ToolCallStatus::Failed
1393                    | ToolCallStatus::Rejected
1394                    | ToolCallStatus::Canceled => continue,
1395                },
1396                AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1397                    // Reached the beginning of the turn
1398                    return false;
1399                }
1400            }
1401        }
1402        false
1403    }
1404
1405    pub fn plan(&self) -> &Plan {
1406        &self.plan
1407    }
1408
1409    pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1410        let new_entries_len = request.entries.len();
1411        let mut new_entries = request.entries.into_iter();
1412
1413        // Reuse existing markdown to prevent flickering
1414        for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1415            let PlanEntry {
1416                content,
1417                priority,
1418                status,
1419            } = old;
1420            content.update(cx, |old, cx| {
1421                old.replace(new.content, cx);
1422            });
1423            *priority = new.priority;
1424            *status = new.status;
1425        }
1426        for new in new_entries {
1427            self.plan.entries.push(PlanEntry::from_acp(new, cx))
1428        }
1429        self.plan.entries.truncate(new_entries_len);
1430
1431        cx.notify();
1432    }
1433
1434    fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1435        self.plan
1436            .entries
1437            .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1438        cx.notify();
1439    }
1440
1441    #[cfg(any(test, feature = "test-support"))]
1442    pub fn send_raw(
1443        &mut self,
1444        message: &str,
1445        cx: &mut Context<Self>,
1446    ) -> BoxFuture<'static, Result<()>> {
1447        self.send(
1448            vec![acp::ContentBlock::Text(acp::TextContent {
1449                text: message.to_string(),
1450                annotations: None,
1451            })],
1452            cx,
1453        )
1454    }
1455
1456    pub fn send(
1457        &mut self,
1458        message: Vec<acp::ContentBlock>,
1459        cx: &mut Context<Self>,
1460    ) -> BoxFuture<'static, Result<()>> {
1461        let block = ContentBlock::new_combined(
1462            message.clone(),
1463            self.project.read(cx).languages().clone(),
1464            cx,
1465        );
1466        let request = acp::PromptRequest {
1467            prompt: message.clone(),
1468            session_id: self.session_id.clone(),
1469        };
1470        let git_store = self.project.read(cx).git_store().clone();
1471
1472        let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
1473            Some(UserMessageId::new())
1474        } else {
1475            None
1476        };
1477
1478        self.run_turn(cx, async move |this, cx| {
1479            this.update(cx, |this, cx| {
1480                this.push_entry(
1481                    AgentThreadEntry::UserMessage(UserMessage {
1482                        id: message_id.clone(),
1483                        content: block,
1484                        chunks: message,
1485                        checkpoint: None,
1486                    }),
1487                    cx,
1488                );
1489            })
1490            .ok();
1491
1492            let old_checkpoint = git_store
1493                .update(cx, |git, cx| git.checkpoint(cx))?
1494                .await
1495                .context("failed to get old checkpoint")
1496                .log_err();
1497            this.update(cx, |this, cx| {
1498                if let Some((_ix, message)) = this.last_user_message() {
1499                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
1500                        git_checkpoint,
1501                        show: false,
1502                    });
1503                }
1504                this.connection.prompt(message_id, request, cx)
1505            })?
1506            .await
1507        })
1508    }
1509
1510    pub fn can_resume(&self, cx: &App) -> bool {
1511        self.connection.resume(&self.session_id, cx).is_some()
1512    }
1513
1514    pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1515        self.run_turn(cx, async move |this, cx| {
1516            this.update(cx, |this, cx| {
1517                this.connection
1518                    .resume(&this.session_id, cx)
1519                    .map(|resume| resume.run(cx))
1520            })?
1521            .context("resuming a session is not supported")?
1522            .await
1523        })
1524    }
1525
1526    fn run_turn(
1527        &mut self,
1528        cx: &mut Context<Self>,
1529        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
1530    ) -> BoxFuture<'static, Result<()>> {
1531        self.clear_completed_plan_entries(cx);
1532
1533        let (tx, rx) = oneshot::channel();
1534        let cancel_task = self.cancel(cx);
1535
1536        self.send_task = Some(cx.spawn(async move |this, cx| {
1537            cancel_task.await;
1538            tx.send(f(this, cx).await).ok();
1539        }));
1540
1541        cx.spawn(async move |this, cx| {
1542            let response = rx.await;
1543
1544            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
1545                .await?;
1546
1547            this.update(cx, |this, cx| {
1548                this.project
1549                    .update(cx, |project, cx| project.set_agent_location(None, cx));
1550                match response {
1551                    Ok(Err(e)) => {
1552                        this.send_task.take();
1553                        cx.emit(AcpThreadEvent::Error);
1554                        Err(e)
1555                    }
1556                    result => {
1557                        let canceled = matches!(
1558                            result,
1559                            Ok(Ok(acp::PromptResponse {
1560                                stop_reason: acp::StopReason::Cancelled
1561                            }))
1562                        );
1563
1564                        // We only take the task if the current prompt wasn't canceled.
1565                        //
1566                        // This prompt may have been canceled because another one was sent
1567                        // while it was still generating. In these cases, dropping `send_task`
1568                        // would cause the next generation to be canceled.
1569                        if !canceled {
1570                            this.send_task.take();
1571                        }
1572
1573                        // Handle refusal - distinguish between user prompt and tool call refusals
1574                        if let Ok(Ok(acp::PromptResponse {
1575                            stop_reason: acp::StopReason::Refusal,
1576                        })) = result
1577                        {
1578                            if let Some((user_msg_ix, _)) = this.last_user_message() {
1579                                // Check if there's a completed tool call with results after the last user message
1580                                // This indicates the refusal is in response to tool output, not the user's prompt
1581                                let has_completed_tool_call_after_user_msg =
1582                                    this.entries.iter().skip(user_msg_ix + 1).any(|entry| {
1583                                        if let AgentThreadEntry::ToolCall(tool_call) = entry {
1584                                            // Check if the tool call has completed and has output
1585                                            matches!(tool_call.status, ToolCallStatus::Completed)
1586                                                && tool_call.raw_output.is_some()
1587                                        } else {
1588                                            false
1589                                        }
1590                                    });
1591
1592                                if has_completed_tool_call_after_user_msg {
1593                                    // Refusal is due to tool output - don't truncate, just notify
1594                                    // The model refused based on what the tool returned
1595                                    cx.emit(AcpThreadEvent::Refusal);
1596                                } else {
1597                                    // User prompt was refused - truncate back to before the user message
1598                                    let range = user_msg_ix..this.entries.len();
1599                                    if range.start < range.end {
1600                                        this.entries.truncate(user_msg_ix);
1601                                        cx.emit(AcpThreadEvent::EntriesRemoved(range));
1602                                    }
1603                                    cx.emit(AcpThreadEvent::Refusal);
1604                                }
1605                            } else {
1606                                // No user message found, treat as general refusal
1607                                cx.emit(AcpThreadEvent::Refusal);
1608                            }
1609                        }
1610
1611                        cx.emit(AcpThreadEvent::Stopped);
1612                        Ok(())
1613                    }
1614                }
1615            })?
1616        })
1617        .boxed()
1618    }
1619
1620    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1621        let Some(send_task) = self.send_task.take() else {
1622            return Task::ready(());
1623        };
1624
1625        for entry in self.entries.iter_mut() {
1626            if let AgentThreadEntry::ToolCall(call) = entry {
1627                let cancel = matches!(
1628                    call.status,
1629                    ToolCallStatus::Pending
1630                        | ToolCallStatus::WaitingForConfirmation { .. }
1631                        | ToolCallStatus::InProgress
1632                );
1633
1634                if cancel {
1635                    call.status = ToolCallStatus::Canceled;
1636                }
1637            }
1638        }
1639
1640        self.connection.cancel(&self.session_id, cx);
1641
1642        // Wait for the send task to complete
1643        cx.foreground_executor().spawn(send_task)
1644    }
1645
1646    /// Rewinds this thread to before the entry at `index`, removing it and all
1647    /// subsequent entries while reverting any changes made from that point.
1648    pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1649        let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
1650            return Task::ready(Err(anyhow!("not supported")));
1651        };
1652        let Some(message) = self.user_message(&id) else {
1653            return Task::ready(Err(anyhow!("message not found")));
1654        };
1655
1656        let checkpoint = message
1657            .checkpoint
1658            .as_ref()
1659            .map(|c| c.git_checkpoint.clone());
1660
1661        let git_store = self.project.read(cx).git_store().clone();
1662        cx.spawn(async move |this, cx| {
1663            if let Some(checkpoint) = checkpoint {
1664                git_store
1665                    .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1666                    .await?;
1667            }
1668
1669            cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
1670            this.update(cx, |this, cx| {
1671                if let Some((ix, _)) = this.user_message_mut(&id) {
1672                    let range = ix..this.entries.len();
1673                    this.entries.truncate(ix);
1674                    cx.emit(AcpThreadEvent::EntriesRemoved(range));
1675                }
1676            })
1677        })
1678    }
1679
1680    fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1681        let git_store = self.project.read(cx).git_store().clone();
1682
1683        let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1684            if let Some(checkpoint) = message.checkpoint.as_ref() {
1685                checkpoint.git_checkpoint.clone()
1686            } else {
1687                return Task::ready(Ok(()));
1688            }
1689        } else {
1690            return Task::ready(Ok(()));
1691        };
1692
1693        let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1694        cx.spawn(async move |this, cx| {
1695            let new_checkpoint = new_checkpoint
1696                .await
1697                .context("failed to get new checkpoint")
1698                .log_err();
1699            if let Some(new_checkpoint) = new_checkpoint {
1700                let equal = git_store
1701                    .update(cx, |git, cx| {
1702                        git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1703                    })?
1704                    .await
1705                    .unwrap_or(true);
1706                this.update(cx, |this, cx| {
1707                    let (ix, message) = this.last_user_message().context("no user message")?;
1708                    let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1709                    checkpoint.show = !equal;
1710                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
1711                    anyhow::Ok(())
1712                })??;
1713            }
1714
1715            Ok(())
1716        })
1717    }
1718
1719    fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1720        self.entries
1721            .iter_mut()
1722            .enumerate()
1723            .rev()
1724            .find_map(|(ix, entry)| {
1725                if let AgentThreadEntry::UserMessage(message) = entry {
1726                    Some((ix, message))
1727                } else {
1728                    None
1729                }
1730            })
1731    }
1732
1733    fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
1734        self.entries.iter().find_map(|entry| {
1735            if let AgentThreadEntry::UserMessage(message) = entry {
1736                if message.id.as_ref() == Some(id) {
1737                    Some(message)
1738                } else {
1739                    None
1740                }
1741            } else {
1742                None
1743            }
1744        })
1745    }
1746
1747    fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1748        self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1749            if let AgentThreadEntry::UserMessage(message) = entry {
1750                if message.id.as_ref() == Some(id) {
1751                    Some((ix, message))
1752                } else {
1753                    None
1754                }
1755            } else {
1756                None
1757            }
1758        })
1759    }
1760
1761    pub fn read_text_file(
1762        &self,
1763        path: PathBuf,
1764        line: Option<u32>,
1765        limit: Option<u32>,
1766        reuse_shared_snapshot: bool,
1767        cx: &mut Context<Self>,
1768    ) -> Task<Result<String>> {
1769        let project = self.project.clone();
1770        let action_log = self.action_log.clone();
1771        cx.spawn(async move |this, cx| {
1772            let load = project.update(cx, |project, cx| {
1773                let path = project
1774                    .project_path_for_absolute_path(&path, cx)
1775                    .context("invalid path")?;
1776                anyhow::Ok(project.open_buffer(path, cx))
1777            });
1778            let buffer = load??.await?;
1779
1780            let snapshot = if reuse_shared_snapshot {
1781                this.read_with(cx, |this, _| {
1782                    this.shared_buffers.get(&buffer.clone()).cloned()
1783                })
1784                .log_err()
1785                .flatten()
1786            } else {
1787                None
1788            };
1789
1790            let snapshot = if let Some(snapshot) = snapshot {
1791                snapshot
1792            } else {
1793                action_log.update(cx, |action_log, cx| {
1794                    action_log.buffer_read(buffer.clone(), cx);
1795                })?;
1796                project.update(cx, |project, cx| {
1797                    let position = buffer
1798                        .read(cx)
1799                        .snapshot()
1800                        .anchor_before(Point::new(line.unwrap_or_default(), 0));
1801                    project.set_agent_location(
1802                        Some(AgentLocation {
1803                            buffer: buffer.downgrade(),
1804                            position,
1805                        }),
1806                        cx,
1807                    );
1808                })?;
1809
1810                buffer.update(cx, |buffer, _| buffer.snapshot())?
1811            };
1812
1813            this.update(cx, |this, _| {
1814                let text = snapshot.text();
1815                this.shared_buffers.insert(buffer.clone(), snapshot);
1816                if line.is_none() && limit.is_none() {
1817                    return Ok(text);
1818                }
1819                let limit = limit.unwrap_or(u32::MAX) as usize;
1820                let Some(line) = line else {
1821                    return Ok(text.lines().take(limit).collect::<String>());
1822                };
1823
1824                let count = text.lines().count();
1825                if count < line as usize {
1826                    anyhow::bail!("There are only {} lines", count);
1827                }
1828                Ok(text
1829                    .lines()
1830                    .skip(line as usize + 1)
1831                    .take(limit)
1832                    .collect::<String>())
1833            })?
1834        })
1835    }
1836
    /// Replaces the contents of the file at the absolute `path` with `content`, then saves it.
    ///
    /// Rather than overwriting the buffer wholesale, this diffs `content` against a snapshot
    /// of the buffer and applies only the changed ranges. The snapshot comes from
    /// `shared_buffers` when one was recorded for this buffer (e.g. by `read_text_file`),
    /// otherwise it is the buffer's current snapshot. The edit ranges are anchors into that
    /// snapshot, so text the user added to the buffer after the snapshot was taken is kept.
    ///
    /// Side effects, in order:
    /// - the agent location is moved to the end of the last edit (or `Anchor::MIN` when the
    ///   content is unchanged);
    /// - the buffer is reported to the action log as read, edited, then reported as edited;
    /// - if the buffer's language settings have `format_on_save` enabled, the buffer is
    ///   formatted with `FormatTrigger::Save` and reported as edited again (formatting
    ///   errors are logged and ignored);
    /// - the buffer is saved through the project.
    ///
    /// # Errors
    ///
    /// Fails if `path` does not resolve to a project path, the buffer cannot be opened, the
    /// thread is no longer alive, or saving fails.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Prefer the snapshot the agent last read, so the diff is computed against the
            // text the agent actually saw.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Diffing can be expensive for large files, so it runs off the main thread.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            // Presumably `anchor_after` / `anchor_before` are chosen so that
                            // text inserted exactly at either boundary after the snapshot
                            // stays outside the replaced range.
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::MIN),
                    }),
                    cx,
                );
            })?;

            let format_on_save = cx.update(|cx| {
                // NOTE(review): the read is recorded before editing, presumably so the action
                // log has a baseline to diff the agent's edit against — confirm in ActionLog.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            })?;

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                })?;
                // Formatting is best-effort: a failure is logged but does not block the save.
                format_task.await.log_err();

                // Formatting may have changed the buffer again, so report it as edited.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                })?;
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1933
1934    pub fn create_terminal(
1935        &self,
1936        mut command: String,
1937        args: Vec<String>,
1938        extra_env: Vec<acp::EnvVariable>,
1939        cwd: Option<PathBuf>,
1940        output_byte_limit: Option<u64>,
1941        cx: &mut Context<Self>,
1942    ) -> Task<Result<Entity<Terminal>>> {
1943        for arg in args {
1944            command.push(' ');
1945            command.push_str(&arg);
1946        }
1947
1948        let shell_command = if cfg!(windows) {
1949            format!("$null | & {{{}}}", command.replace("\"", "'"))
1950        } else if let Some(cwd) = cwd.as_ref().and_then(|cwd| cwd.as_os_str().to_str()) {
1951            // Make sure once we're *inside* the shell, we cd into `cwd`
1952            format!("(cd {cwd}; {}) </dev/null", command)
1953        } else {
1954            format!("({}) </dev/null", command)
1955        };
1956        let args = vec!["-c".into(), shell_command];
1957
1958        let env = match &cwd {
1959            Some(dir) => self.project.update(cx, |project, cx| {
1960                project.directory_environment(dir.as_path().into(), cx)
1961            }),
1962            None => Task::ready(None).shared(),
1963        };
1964
1965        let env = cx.spawn(async move |_, _| {
1966            let mut env = env.await.unwrap_or_default();
1967            if cfg!(unix) {
1968                env.insert("PAGER".into(), "cat".into());
1969            }
1970            for var in extra_env {
1971                env.insert(var.name, var.value);
1972            }
1973            env
1974        });
1975
1976        let project = self.project.clone();
1977        let language_registry = project.read(cx).languages().clone();
1978        let determine_shell = self.determine_shell.clone();
1979
1980        let terminal_id = acp::TerminalId(Uuid::new_v4().to_string().into());
1981        let terminal_task = cx.spawn({
1982            let terminal_id = terminal_id.clone();
1983            async move |_this, cx| {
1984                let program = determine_shell.await;
1985                let env = env.await;
1986                let terminal = project
1987                    .update(cx, |project, cx| {
1988                        project.create_terminal_task(
1989                            task::SpawnInTerminal {
1990                                command: Some(program),
1991                                args,
1992                                cwd: cwd.clone(),
1993                                env,
1994                                ..Default::default()
1995                            },
1996                            cx,
1997                        )
1998                    })?
1999                    .await?;
2000
2001                cx.new(|cx| {
2002                    Terminal::new(
2003                        terminal_id,
2004                        command,
2005                        cwd,
2006                        output_byte_limit.map(|l| l as usize),
2007                        terminal,
2008                        language_registry,
2009                        cx,
2010                    )
2011                })
2012            }
2013        });
2014
2015        cx.spawn(async move |this, cx| {
2016            let terminal = terminal_task.await?;
2017            this.update(cx, |this, _cx| {
2018                this.terminals.insert(terminal_id, terminal.clone());
2019                terminal
2020            })
2021        })
2022    }
2023
2024    pub fn kill_terminal(
2025        &mut self,
2026        terminal_id: acp::TerminalId,
2027        cx: &mut Context<Self>,
2028    ) -> Result<()> {
2029        self.terminals
2030            .get(&terminal_id)
2031            .context("Terminal not found")?
2032            .update(cx, |terminal, cx| {
2033                terminal.kill(cx);
2034            });
2035
2036        Ok(())
2037    }
2038
2039    pub fn release_terminal(
2040        &mut self,
2041        terminal_id: acp::TerminalId,
2042        cx: &mut Context<Self>,
2043    ) -> Result<()> {
2044        self.terminals
2045            .remove(&terminal_id)
2046            .context("Terminal not found")?
2047            .update(cx, |terminal, cx| {
2048                terminal.kill(cx);
2049            });
2050
2051        Ok(())
2052    }
2053
2054    pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
2055        self.terminals
2056            .get(&terminal_id)
2057            .context("Terminal not found")
2058            .cloned()
2059    }
2060
2061    pub fn to_markdown(&self, cx: &App) -> String {
2062        self.entries.iter().map(|e| e.to_markdown(cx)).collect()
2063    }
2064
2065    pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
2066        cx.emit(AcpThreadEvent::LoadError(error));
2067    }
2068}
2069
2070fn markdown_for_raw_output(
2071    raw_output: &serde_json::Value,
2072    language_registry: &Arc<LanguageRegistry>,
2073    cx: &mut App,
2074) -> Option<Entity<Markdown>> {
2075    match raw_output {
2076        serde_json::Value::Null => None,
2077        serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2078            Markdown::new(
2079                value.to_string().into(),
2080                Some(language_registry.clone()),
2081                None,
2082                cx,
2083            )
2084        })),
2085        serde_json::Value::Number(value) => Some(cx.new(|cx| {
2086            Markdown::new(
2087                value.to_string().into(),
2088                Some(language_registry.clone()),
2089                None,
2090                cx,
2091            )
2092        })),
2093        serde_json::Value::String(value) => Some(cx.new(|cx| {
2094            Markdown::new(
2095                value.clone().into(),
2096                Some(language_registry.clone()),
2097                None,
2098                cx,
2099            )
2100        })),
2101        value => Some(cx.new(|cx| {
2102            Markdown::new(
2103                format!("```json\n{}\n```", value).into(),
2104                Some(language_registry.clone()),
2105                None,
2106                cx,
2107            )
2108        })),
2109    }
2110}
2111
2112#[cfg(test)]
2113mod tests {
2114    use super::*;
2115    use anyhow::anyhow;
2116    use futures::{channel::mpsc, future::LocalBoxFuture, select};
2117    use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2118    use indoc::indoc;
2119    use project::{FakeFs, Fs};
2120    use rand::Rng as _;
2121    use serde_json::json;
2122    use settings::SettingsStore;
2123    use smol::stream::StreamExt as _;
2124    use std::{
2125        any::Any,
2126        cell::RefCell,
2127        path::Path,
2128        rc::Rc,
2129        sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2130        time::Duration,
2131    };
2132    use util::path;
2133
2134    fn init_test(cx: &mut TestAppContext) {
2135        env_logger::try_init().ok();
2136        cx.update(|cx| {
2137            let settings_store = SettingsStore::test(cx);
2138            cx.set_global(settings_store);
2139            Project::init_settings(cx);
2140            language::init(cx);
2141        });
2142    }
2143
2144    #[gpui::test]
2145    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
2146        init_test(cx);
2147
2148        let fs = FakeFs::new(cx.executor());
2149        let project = Project::test(fs, [], cx).await;
2150        let connection = Rc::new(FakeAgentConnection::new());
2151        let thread = cx
2152            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2153            .await
2154            .unwrap();
2155
2156        // Test creating a new user message
2157        thread.update(cx, |thread, cx| {
2158            thread.push_user_content_block(
2159                None,
2160                acp::ContentBlock::Text(acp::TextContent {
2161                    annotations: None,
2162                    text: "Hello, ".to_string(),
2163                }),
2164                cx,
2165            );
2166        });
2167
2168        thread.update(cx, |thread, cx| {
2169            assert_eq!(thread.entries.len(), 1);
2170            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2171                assert_eq!(user_msg.id, None);
2172                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
2173            } else {
2174                panic!("Expected UserMessage");
2175            }
2176        });
2177
2178        // Test appending to existing user message
2179        let message_1_id = UserMessageId::new();
2180        thread.update(cx, |thread, cx| {
2181            thread.push_user_content_block(
2182                Some(message_1_id.clone()),
2183                acp::ContentBlock::Text(acp::TextContent {
2184                    annotations: None,
2185                    text: "world!".to_string(),
2186                }),
2187                cx,
2188            );
2189        });
2190
2191        thread.update(cx, |thread, cx| {
2192            assert_eq!(thread.entries.len(), 1);
2193            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2194                assert_eq!(user_msg.id, Some(message_1_id));
2195                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
2196            } else {
2197                panic!("Expected UserMessage");
2198            }
2199        });
2200
2201        // Test creating new user message after assistant message
2202        thread.update(cx, |thread, cx| {
2203            thread.push_assistant_content_block(
2204                acp::ContentBlock::Text(acp::TextContent {
2205                    annotations: None,
2206                    text: "Assistant response".to_string(),
2207                }),
2208                false,
2209                cx,
2210            );
2211        });
2212
2213        let message_2_id = UserMessageId::new();
2214        thread.update(cx, |thread, cx| {
2215            thread.push_user_content_block(
2216                Some(message_2_id.clone()),
2217                acp::ContentBlock::Text(acp::TextContent {
2218                    annotations: None,
2219                    text: "New user message".to_string(),
2220                }),
2221                cx,
2222            );
2223        });
2224
2225        thread.update(cx, |thread, cx| {
2226            assert_eq!(thread.entries.len(), 3);
2227            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
2228                assert_eq!(user_msg.id, Some(message_2_id));
2229                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
2230            } else {
2231                panic!("Expected UserMessage at index 2");
2232            }
2233        });
2234    }
2235
2236    #[gpui::test]
2237    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
2238        init_test(cx);
2239
2240        let fs = FakeFs::new(cx.executor());
2241        let project = Project::test(fs, [], cx).await;
2242        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2243            |_, thread, mut cx| {
2244                async move {
2245                    thread.update(&mut cx, |thread, cx| {
2246                        thread
2247                            .handle_session_update(
2248                                acp::SessionUpdate::AgentThoughtChunk {
2249                                    content: "Thinking ".into(),
2250                                },
2251                                cx,
2252                            )
2253                            .unwrap();
2254                        thread
2255                            .handle_session_update(
2256                                acp::SessionUpdate::AgentThoughtChunk {
2257                                    content: "hard!".into(),
2258                                },
2259                                cx,
2260                            )
2261                            .unwrap();
2262                    })?;
2263                    Ok(acp::PromptResponse {
2264                        stop_reason: acp::StopReason::EndTurn,
2265                    })
2266                }
2267                .boxed_local()
2268            },
2269        ));
2270
2271        let thread = cx
2272            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2273            .await
2274            .unwrap();
2275
2276        thread
2277            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
2278            .await
2279            .unwrap();
2280
2281        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
2282        assert_eq!(
2283            output,
2284            indoc! {r#"
2285            ## User
2286
2287            Hello from Zed!
2288
2289            ## Assistant
2290
2291            <thinking>
2292            Thinking hard!
2293            </thinking>
2294
2295            "#}
2296        );
2297    }
2298
2299    #[gpui::test]
2300    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
2301        init_test(cx);
2302
2303        let fs = FakeFs::new(cx.executor());
2304        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
2305            .await;
2306        let project = Project::test(fs.clone(), [], cx).await;
2307        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
2308        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
2309        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2310            move |_, thread, mut cx| {
2311                let read_file_tx = read_file_tx.clone();
2312                async move {
2313                    let content = thread
2314                        .update(&mut cx, |thread, cx| {
2315                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2316                        })
2317                        .unwrap()
2318                        .await
2319                        .unwrap();
2320                    assert_eq!(content, "one\ntwo\nthree\n");
2321                    read_file_tx.take().unwrap().send(()).unwrap();
2322                    thread
2323                        .update(&mut cx, |thread, cx| {
2324                            thread.write_text_file(
2325                                path!("/tmp/foo").into(),
2326                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
2327                                cx,
2328                            )
2329                        })
2330                        .unwrap()
2331                        .await
2332                        .unwrap();
2333                    Ok(acp::PromptResponse {
2334                        stop_reason: acp::StopReason::EndTurn,
2335                    })
2336                }
2337                .boxed_local()
2338            },
2339        ));
2340
2341        let (worktree, pathbuf) = project
2342            .update(cx, |project, cx| {
2343                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2344            })
2345            .await
2346            .unwrap();
2347        let buffer = project
2348            .update(cx, |project, cx| {
2349                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
2350            })
2351            .await
2352            .unwrap();
2353
2354        let thread = cx
2355            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2356            .await
2357            .unwrap();
2358
2359        let request = thread.update(cx, |thread, cx| {
2360            thread.send_raw("Extend the count in /tmp/foo", cx)
2361        });
2362        read_file_rx.await.ok();
2363        buffer.update(cx, |buffer, cx| {
2364            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
2365        });
2366        cx.run_until_parked();
2367        assert_eq!(
2368            buffer.read_with(cx, |buffer, _| buffer.text()),
2369            "zero\none\ntwo\nthree\nfour\nfive\n"
2370        );
2371        assert_eq!(
2372            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
2373            "zero\none\ntwo\nthree\nfour\nfive\n"
2374        );
2375        request.await.unwrap();
2376    }
2377
2378    #[gpui::test]
2379    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
2380        init_test(cx);
2381
2382        let fs = FakeFs::new(cx.executor());
2383        let project = Project::test(fs, [], cx).await;
2384        let id = acp::ToolCallId("test".into());
2385
2386        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2387            let id = id.clone();
2388            move |_, thread, mut cx| {
2389                let id = id.clone();
2390                async move {
2391                    thread
2392                        .update(&mut cx, |thread, cx| {
2393                            thread.handle_session_update(
2394                                acp::SessionUpdate::ToolCall(acp::ToolCall {
2395                                    id: id.clone(),
2396                                    title: "Label".into(),
2397                                    kind: acp::ToolKind::Fetch,
2398                                    status: acp::ToolCallStatus::InProgress,
2399                                    content: vec![],
2400                                    locations: vec![],
2401                                    raw_input: None,
2402                                    raw_output: None,
2403                                }),
2404                                cx,
2405                            )
2406                        })
2407                        .unwrap()
2408                        .unwrap();
2409                    Ok(acp::PromptResponse {
2410                        stop_reason: acp::StopReason::EndTurn,
2411                    })
2412                }
2413                .boxed_local()
2414            }
2415        }));
2416
2417        let thread = cx
2418            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2419            .await
2420            .unwrap();
2421
2422        let request = thread.update(cx, |thread, cx| {
2423            thread.send_raw("Fetch https://example.com", cx)
2424        });
2425
2426        run_until_first_tool_call(&thread, cx).await;
2427
2428        thread.read_with(cx, |thread, _| {
2429            assert!(matches!(
2430                thread.entries[1],
2431                AgentThreadEntry::ToolCall(ToolCall {
2432                    status: ToolCallStatus::InProgress,
2433                    ..
2434                })
2435            ));
2436        });
2437
2438        thread.update(cx, |thread, cx| thread.cancel(cx)).await;
2439
2440        thread.read_with(cx, |thread, _| {
2441            assert!(matches!(
2442                &thread.entries[1],
2443                AgentThreadEntry::ToolCall(ToolCall {
2444                    status: ToolCallStatus::Canceled,
2445                    ..
2446                })
2447            ));
2448        });
2449
2450        thread
2451            .update(cx, |thread, cx| {
2452                thread.handle_session_update(
2453                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
2454                        id,
2455                        fields: acp::ToolCallUpdateFields {
2456                            status: Some(acp::ToolCallStatus::Completed),
2457                            ..Default::default()
2458                        },
2459                    }),
2460                    cx,
2461                )
2462            })
2463            .unwrap();
2464
2465        request.await.unwrap();
2466
2467        thread.read_with(cx, |thread, _| {
2468            assert!(matches!(
2469                thread.entries[1],
2470                AgentThreadEntry::ToolCall(ToolCall {
2471                    status: ToolCallStatus::Completed,
2472                    ..
2473                })
2474            ));
2475        });
2476    }
2477
2478    #[gpui::test]
2479    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
2480        init_test(cx);
2481        let fs = FakeFs::new(cx.background_executor.clone());
2482        fs.insert_tree(path!("/test"), json!({})).await;
2483        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
2484
2485        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2486            move |_, thread, mut cx| {
2487                async move {
2488                    thread
2489                        .update(&mut cx, |thread, cx| {
2490                            thread.handle_session_update(
2491                                acp::SessionUpdate::ToolCall(acp::ToolCall {
2492                                    id: acp::ToolCallId("test".into()),
2493                                    title: "Label".into(),
2494                                    kind: acp::ToolKind::Edit,
2495                                    status: acp::ToolCallStatus::Completed,
2496                                    content: vec![acp::ToolCallContent::Diff {
2497                                        diff: acp::Diff {
2498                                            path: "/test/test.txt".into(),
2499                                            old_text: None,
2500                                            new_text: "foo".into(),
2501                                        },
2502                                    }],
2503                                    locations: vec![],
2504                                    raw_input: None,
2505                                    raw_output: None,
2506                                }),
2507                                cx,
2508                            )
2509                        })
2510                        .unwrap()
2511                        .unwrap();
2512                    Ok(acp::PromptResponse {
2513                        stop_reason: acp::StopReason::EndTurn,
2514                    })
2515                }
2516                .boxed_local()
2517            }
2518        }));
2519
2520        let thread = cx
2521            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2522            .await
2523            .unwrap();
2524
2525        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
2526            .await
2527            .unwrap();
2528
2529        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
2530    }
2531
2532    #[gpui::test(iterations = 10)]
2533    async fn test_checkpoints(cx: &mut TestAppContext) {
2534        init_test(cx);
2535        let fs = FakeFs::new(cx.background_executor.clone());
2536        fs.insert_tree(
2537            path!("/test"),
2538            json!({
2539                ".git": {}
2540            }),
2541        )
2542        .await;
2543        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
2544
2545        let simulate_changes = Arc::new(AtomicBool::new(true));
2546        let next_filename = Arc::new(AtomicUsize::new(0));
2547        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2548            let simulate_changes = simulate_changes.clone();
2549            let next_filename = next_filename.clone();
2550            let fs = fs.clone();
2551            move |request, thread, mut cx| {
2552                let fs = fs.clone();
2553                let simulate_changes = simulate_changes.clone();
2554                let next_filename = next_filename.clone();
2555                async move {
2556                    if simulate_changes.load(SeqCst) {
2557                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
2558                        fs.write(Path::new(&filename), b"").await?;
2559                    }
2560
2561                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2562                        panic!("expected text content block");
2563                    };
2564                    thread.update(&mut cx, |thread, cx| {
2565                        thread
2566                            .handle_session_update(
2567                                acp::SessionUpdate::AgentMessageChunk {
2568                                    content: content.text.to_uppercase().into(),
2569                                },
2570                                cx,
2571                            )
2572                            .unwrap();
2573                    })?;
2574                    Ok(acp::PromptResponse {
2575                        stop_reason: acp::StopReason::EndTurn,
2576                    })
2577                }
2578                .boxed_local()
2579            }
2580        }));
2581        let thread = cx
2582            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2583            .await
2584            .unwrap();
2585
2586        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
2587            .await
2588            .unwrap();
2589        thread.read_with(cx, |thread, cx| {
2590            assert_eq!(
2591                thread.to_markdown(cx),
2592                indoc! {"
2593                    ## User (checkpoint)
2594
2595                    Lorem
2596
2597                    ## Assistant
2598
2599                    LOREM
2600
2601                "}
2602            );
2603        });
2604        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2605
2606        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
2607            .await
2608            .unwrap();
2609        thread.read_with(cx, |thread, cx| {
2610            assert_eq!(
2611                thread.to_markdown(cx),
2612                indoc! {"
2613                    ## User (checkpoint)
2614
2615                    Lorem
2616
2617                    ## Assistant
2618
2619                    LOREM
2620
2621                    ## User (checkpoint)
2622
2623                    ipsum
2624
2625                    ## Assistant
2626
2627                    IPSUM
2628
2629                "}
2630            );
2631        });
2632        assert_eq!(
2633            fs.files(),
2634            vec![
2635                Path::new(path!("/test/file-0")),
2636                Path::new(path!("/test/file-1"))
2637            ]
2638        );
2639
2640        // Checkpoint isn't stored when there are no changes.
2641        simulate_changes.store(false, SeqCst);
2642        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
2643            .await
2644            .unwrap();
2645        thread.read_with(cx, |thread, cx| {
2646            assert_eq!(
2647                thread.to_markdown(cx),
2648                indoc! {"
2649                    ## User (checkpoint)
2650
2651                    Lorem
2652
2653                    ## Assistant
2654
2655                    LOREM
2656
2657                    ## User (checkpoint)
2658
2659                    ipsum
2660
2661                    ## Assistant
2662
2663                    IPSUM
2664
2665                    ## User
2666
2667                    dolor
2668
2669                    ## Assistant
2670
2671                    DOLOR
2672
2673                "}
2674            );
2675        });
2676        assert_eq!(
2677            fs.files(),
2678            vec![
2679                Path::new(path!("/test/file-0")),
2680                Path::new(path!("/test/file-1"))
2681            ]
2682        );
2683
2684        // Rewinding the conversation truncates the history and restores the checkpoint.
2685        thread
2686            .update(cx, |thread, cx| {
2687                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
2688                    panic!("unexpected entries {:?}", thread.entries)
2689                };
2690                thread.rewind(message.id.clone().unwrap(), cx)
2691            })
2692            .await
2693            .unwrap();
2694        thread.read_with(cx, |thread, cx| {
2695            assert_eq!(
2696                thread.to_markdown(cx),
2697                indoc! {"
2698                    ## User (checkpoint)
2699
2700                    Lorem
2701
2702                    ## Assistant
2703
2704                    LOREM
2705
2706                "}
2707            );
2708        });
2709        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2710    }
2711
2712    #[gpui::test]
2713    async fn test_tool_result_refusal(cx: &mut TestAppContext) {
2714        use std::sync::atomic::AtomicUsize;
2715        init_test(cx);
2716
2717        let fs = FakeFs::new(cx.executor());
2718        let project = Project::test(fs, None, cx).await;
2719
2720        // Create a connection that simulates refusal after tool result
2721        let prompt_count = Arc::new(AtomicUsize::new(0));
2722        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2723            let prompt_count = prompt_count.clone();
2724            move |_request, thread, mut cx| {
2725                let count = prompt_count.fetch_add(1, SeqCst);
2726                async move {
2727                    if count == 0 {
2728                        // First prompt: Generate a tool call with result
2729                        thread.update(&mut cx, |thread, cx| {
2730                            thread
2731                                .handle_session_update(
2732                                    acp::SessionUpdate::ToolCall(acp::ToolCall {
2733                                        id: acp::ToolCallId("tool1".into()),
2734                                        title: "Test Tool".into(),
2735                                        kind: acp::ToolKind::Fetch,
2736                                        status: acp::ToolCallStatus::Completed,
2737                                        content: vec![],
2738                                        locations: vec![],
2739                                        raw_input: Some(serde_json::json!({"query": "test"})),
2740                                        raw_output: Some(
2741                                            serde_json::json!({"result": "inappropriate content"}),
2742                                        ),
2743                                    }),
2744                                    cx,
2745                                )
2746                                .unwrap();
2747                        })?;
2748
2749                        // Now return refusal because of the tool result
2750                        Ok(acp::PromptResponse {
2751                            stop_reason: acp::StopReason::Refusal,
2752                        })
2753                    } else {
2754                        Ok(acp::PromptResponse {
2755                            stop_reason: acp::StopReason::EndTurn,
2756                        })
2757                    }
2758                }
2759                .boxed_local()
2760            }
2761        }));
2762
2763        let thread = cx
2764            .update(|cx| connection.new_thread(project, Path::new("/test"), cx))
2765            .await
2766            .unwrap();
2767
2768        // Track if we see a Refusal event
2769        let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
2770        let saw_refusal_event_captured = saw_refusal_event.clone();
2771        thread.update(cx, |_thread, cx| {
2772            cx.subscribe(
2773                &thread,
2774                move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
2775                    if matches!(event, AcpThreadEvent::Refusal) {
2776                        *saw_refusal_event_captured.lock().unwrap() = true;
2777                    }
2778                },
2779            )
2780            .detach();
2781        });
2782
2783        // Send a user message - this will trigger tool call and then refusal
2784        let send_task = thread.update(cx, |thread, cx| {
2785            thread.send(
2786                vec![acp::ContentBlock::Text(acp::TextContent {
2787                    text: "Hello".into(),
2788                    annotations: None,
2789                })],
2790                cx,
2791            )
2792        });
2793        cx.background_executor.spawn(send_task).detach();
2794        cx.run_until_parked();
2795
2796        // Verify that:
2797        // 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt)
2798        // 2. The user message was NOT truncated
2799        assert!(
2800            *saw_refusal_event.lock().unwrap(),
2801            "Refusal event should be emitted for tool result refusals"
2802        );
2803
2804        thread.read_with(cx, |thread, _| {
2805            let entries = thread.entries();
2806            assert!(entries.len() >= 2, "Should have user message and tool call");
2807
2808            // Verify user message is still there
2809            assert!(
2810                matches!(entries[0], AgentThreadEntry::UserMessage(_)),
2811                "User message should not be truncated"
2812            );
2813
2814            // Verify tool call is there with result
2815            if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] {
2816                assert!(
2817                    tool_call.raw_output.is_some(),
2818                    "Tool call should have output"
2819                );
2820            } else {
2821                panic!("Expected tool call at index 1");
2822            }
2823        });
2824    }
2825
2826    #[gpui::test]
2827    async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) {
2828        init_test(cx);
2829
2830        let fs = FakeFs::new(cx.executor());
2831        let project = Project::test(fs, None, cx).await;
2832
2833        let refuse_next = Arc::new(AtomicBool::new(false));
2834        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2835            let refuse_next = refuse_next.clone();
2836            move |_request, _thread, _cx| {
2837                if refuse_next.load(SeqCst) {
2838                    async move {
2839                        Ok(acp::PromptResponse {
2840                            stop_reason: acp::StopReason::Refusal,
2841                        })
2842                    }
2843                    .boxed_local()
2844                } else {
2845                    async move {
2846                        Ok(acp::PromptResponse {
2847                            stop_reason: acp::StopReason::EndTurn,
2848                        })
2849                    }
2850                    .boxed_local()
2851                }
2852            }
2853        }));
2854
2855        let thread = cx
2856            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2857            .await
2858            .unwrap();
2859
2860        // Track if we see a Refusal event
2861        let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
2862        let saw_refusal_event_captured = saw_refusal_event.clone();
2863        thread.update(cx, |_thread, cx| {
2864            cx.subscribe(
2865                &thread,
2866                move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
2867                    if matches!(event, AcpThreadEvent::Refusal) {
2868                        *saw_refusal_event_captured.lock().unwrap() = true;
2869                    }
2870                },
2871            )
2872            .detach();
2873        });
2874
2875        // Send a message that will be refused
2876        refuse_next.store(true, SeqCst);
2877        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
2878            .await
2879            .unwrap();
2880
2881        // Verify that a Refusal event WAS emitted for user prompt refusal
2882        assert!(
2883            *saw_refusal_event.lock().unwrap(),
2884            "Refusal event should be emitted for user prompt refusals"
2885        );
2886
2887        // Verify the message was truncated (user prompt refusal)
2888        thread.read_with(cx, |thread, cx| {
2889            assert_eq!(thread.to_markdown(cx), "");
2890        });
2891    }
2892
2893    #[gpui::test]
2894    async fn test_refusal(cx: &mut TestAppContext) {
2895        init_test(cx);
2896        let fs = FakeFs::new(cx.background_executor.clone());
2897        fs.insert_tree(path!("/"), json!({})).await;
2898        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
2899
2900        let refuse_next = Arc::new(AtomicBool::new(false));
2901        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2902            let refuse_next = refuse_next.clone();
2903            move |request, thread, mut cx| {
2904                let refuse_next = refuse_next.clone();
2905                async move {
2906                    if refuse_next.load(SeqCst) {
2907                        return Ok(acp::PromptResponse {
2908                            stop_reason: acp::StopReason::Refusal,
2909                        });
2910                    }
2911
2912                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2913                        panic!("expected text content block");
2914                    };
2915                    thread.update(&mut cx, |thread, cx| {
2916                        thread
2917                            .handle_session_update(
2918                                acp::SessionUpdate::AgentMessageChunk {
2919                                    content: content.text.to_uppercase().into(),
2920                                },
2921                                cx,
2922                            )
2923                            .unwrap();
2924                    })?;
2925                    Ok(acp::PromptResponse {
2926                        stop_reason: acp::StopReason::EndTurn,
2927                    })
2928                }
2929                .boxed_local()
2930            }
2931        }));
2932        let thread = cx
2933            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2934            .await
2935            .unwrap();
2936
2937        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
2938            .await
2939            .unwrap();
2940        thread.read_with(cx, |thread, cx| {
2941            assert_eq!(
2942                thread.to_markdown(cx),
2943                indoc! {"
2944                    ## User
2945
2946                    hello
2947
2948                    ## Assistant
2949
2950                    HELLO
2951
2952                "}
2953            );
2954        });
2955
2956        // Simulate refusing the second message. The message should be truncated
2957        // when a user prompt is refused.
2958        refuse_next.store(true, SeqCst);
2959        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
2960            .await
2961            .unwrap();
2962        thread.read_with(cx, |thread, cx| {
2963            assert_eq!(
2964                thread.to_markdown(cx),
2965                indoc! {"
2966                    ## User
2967
2968                    hello
2969
2970                    ## Assistant
2971
2972                    HELLO
2973
2974                "}
2975            );
2976        });
2977    }
2978
2979    async fn run_until_first_tool_call(
2980        thread: &Entity<AcpThread>,
2981        cx: &mut TestAppContext,
2982    ) -> usize {
2983        let (mut tx, mut rx) = mpsc::channel::<usize>(1);
2984
2985        let subscription = cx.update(|cx| {
2986            cx.subscribe(thread, move |thread, _, cx| {
2987                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
2988                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
2989                        return tx.try_send(ix).unwrap();
2990                    }
2991                }
2992            })
2993        });
2994
2995        select! {
2996            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
2997                panic!("Timeout waiting for tool call")
2998            }
2999            ix = rx.next().fuse() => {
3000                drop(subscription);
3001                ix.unwrap()
3002            }
3003        }
3004    }
3005
3006    #[derive(Clone, Default)]
3007    struct FakeAgentConnection {
3008        auth_methods: Vec<acp::AuthMethod>,
3009        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
3010        on_user_message: Option<
3011            Rc<
3012                dyn Fn(
3013                        acp::PromptRequest,
3014                        WeakEntity<AcpThread>,
3015                        AsyncApp,
3016                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3017                    + 'static,
3018            >,
3019        >,
3020    }
3021
3022    impl FakeAgentConnection {
3023        fn new() -> Self {
3024            Self {
3025                auth_methods: Vec::new(),
3026                on_user_message: None,
3027                sessions: Arc::default(),
3028            }
3029        }
3030
3031        #[expect(unused)]
3032        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
3033            self.auth_methods = auth_methods;
3034            self
3035        }
3036
3037        fn on_user_message(
3038            mut self,
3039            handler: impl Fn(
3040                acp::PromptRequest,
3041                WeakEntity<AcpThread>,
3042                AsyncApp,
3043            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3044            + 'static,
3045        ) -> Self {
3046            self.on_user_message.replace(Rc::new(handler));
3047            self
3048        }
3049    }
3050
3051    impl AgentConnection for FakeAgentConnection {
3052        fn auth_methods(&self) -> &[acp::AuthMethod] {
3053            &self.auth_methods
3054        }
3055
3056        fn new_thread(
3057            self: Rc<Self>,
3058            project: Entity<Project>,
3059            _cwd: &Path,
3060            cx: &mut App,
3061        ) -> Task<gpui::Result<Entity<AcpThread>>> {
3062            let session_id = acp::SessionId(
3063                rand::thread_rng()
3064                    .sample_iter(&rand::distributions::Alphanumeric)
3065                    .take(7)
3066                    .map(char::from)
3067                    .collect::<String>()
3068                    .into(),
3069            );
3070            let action_log = cx.new(|_| ActionLog::new(project.clone()));
3071            let thread = cx.new(|cx| {
3072                AcpThread::new(
3073                    "Test",
3074                    self.clone(),
3075                    project,
3076                    action_log,
3077                    session_id.clone(),
3078                    watch::Receiver::constant(acp::PromptCapabilities {
3079                        image: true,
3080                        audio: true,
3081                        embedded_context: true,
3082                    }),
3083                    vec![],
3084                    cx,
3085                )
3086            });
3087            self.sessions.lock().insert(session_id, thread.downgrade());
3088            Task::ready(Ok(thread))
3089        }
3090
3091        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
3092            if self.auth_methods().iter().any(|m| m.id == method) {
3093                Task::ready(Ok(()))
3094            } else {
3095                Task::ready(Err(anyhow!("Invalid Auth Method")))
3096            }
3097        }
3098
3099        fn prompt(
3100            &self,
3101            _id: Option<UserMessageId>,
3102            params: acp::PromptRequest,
3103            cx: &mut App,
3104        ) -> Task<gpui::Result<acp::PromptResponse>> {
3105            let sessions = self.sessions.lock();
3106            let thread = sessions.get(&params.session_id).unwrap();
3107            if let Some(handler) = &self.on_user_message {
3108                let handler = handler.clone();
3109                let thread = thread.clone();
3110                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
3111            } else {
3112                Task::ready(Ok(acp::PromptResponse {
3113                    stop_reason: acp::StopReason::EndTurn,
3114                }))
3115            }
3116        }
3117
3118        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
3119            let sessions = self.sessions.lock();
3120            let thread = sessions.get(session_id).unwrap().clone();
3121
3122            cx.spawn(async move |cx| {
3123                thread
3124                    .update(cx, |thread, cx| thread.cancel(cx))
3125                    .unwrap()
3126                    .await
3127            })
3128            .detach();
3129        }
3130
3131        fn truncate(
3132            &self,
3133            session_id: &acp::SessionId,
3134            _cx: &App,
3135        ) -> Option<Rc<dyn AgentSessionTruncate>> {
3136            Some(Rc::new(FakeAgentSessionEditor {
3137                _session_id: session_id.clone(),
3138            }))
3139        }
3140
3141        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
3142            self
3143        }
3144    }
3145
    /// Truncation handle returned by the fake connection's `truncate`. It stores
    /// the session id it was created for but never reads it.
    struct FakeAgentSessionEditor {
        _session_id: acp::SessionId,
    }

    impl AgentSessionTruncate for FakeAgentSessionEditor {
        /// Accepts any truncation request and reports success immediately
        /// without doing any work; the actual history truncation under test
        /// happens on the `AcpThread` side.
        fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
            Task::ready(Ok(()))
        }
    }
3155}