acp_thread.rs

   1mod connection;
   2mod diff;
   3mod mention;
   4mod terminal;
   5
   6use ::terminal::terminal_settings::TerminalSettings;
   7use agent_settings::AgentSettings;
   8use collections::HashSet;
   9pub use connection::*;
  10pub use diff::*;
  11use language::language_settings::FormatOnSave;
  12pub use mention::*;
  13use project::lsp_store::{FormatTrigger, LspFormatTarget};
  14use serde::{Deserialize, Serialize};
  15use settings::Settings as _;
  16use task::{Shell, ShellBuilder};
  17pub use terminal::*;
  18
  19use action_log::ActionLog;
  20use agent_client_protocol::{self as acp};
  21use anyhow::{Context as _, Result, anyhow};
  22use editor::Bias;
  23use futures::{FutureExt, channel::oneshot, future::BoxFuture};
  24use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  25use itertools::Itertools;
  26use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
  27use markdown::Markdown;
  28use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
  29use std::collections::HashMap;
  30use std::error::Error;
  31use std::fmt::{Formatter, Write};
  32use std::ops::Range;
  33use std::process::ExitStatus;
  34use std::rc::Rc;
  35use std::time::{Duration, Instant};
  36use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
  37use ui::App;
  38use util::{ResultExt, get_default_system_shell};
  39use uuid::Uuid;
  40
  41#[derive(Debug)]
  42pub struct UserMessage {
  43    pub id: Option<UserMessageId>,
  44    pub content: ContentBlock,
  45    pub chunks: Vec<acp::ContentBlock>,
  46    pub checkpoint: Option<Checkpoint>,
  47}
  48
  49#[derive(Debug)]
  50pub struct Checkpoint {
  51    git_checkpoint: GitStoreCheckpoint,
  52    pub show: bool,
  53}
  54
  55impl UserMessage {
  56    fn to_markdown(&self, cx: &App) -> String {
  57        let mut markdown = String::new();
  58        if self
  59            .checkpoint
  60            .as_ref()
  61            .is_some_and(|checkpoint| checkpoint.show)
  62        {
  63            writeln!(markdown, "## User (checkpoint)").unwrap();
  64        } else {
  65            writeln!(markdown, "## User").unwrap();
  66        }
  67        writeln!(markdown).unwrap();
  68        writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
  69        writeln!(markdown).unwrap();
  70        markdown
  71    }
  72}
  73
  74#[derive(Debug, PartialEq)]
  75pub struct AssistantMessage {
  76    pub chunks: Vec<AssistantMessageChunk>,
  77}
  78
  79impl AssistantMessage {
  80    pub fn to_markdown(&self, cx: &App) -> String {
  81        format!(
  82            "## Assistant\n\n{}\n\n",
  83            self.chunks
  84                .iter()
  85                .map(|chunk| chunk.to_markdown(cx))
  86                .join("\n\n")
  87        )
  88    }
  89}
  90
  91#[derive(Debug, PartialEq)]
  92pub enum AssistantMessageChunk {
  93    Message { block: ContentBlock },
  94    Thought { block: ContentBlock },
  95}
  96
  97impl AssistantMessageChunk {
  98    pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
  99        Self::Message {
 100            block: ContentBlock::new(chunk.into(), language_registry, cx),
 101        }
 102    }
 103
 104    fn to_markdown(&self, cx: &App) -> String {
 105        match self {
 106            Self::Message { block } => block.to_markdown(cx).to_string(),
 107            Self::Thought { block } => {
 108                format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
 109            }
 110        }
 111    }
 112}
 113
 114#[derive(Debug)]
 115pub enum AgentThreadEntry {
 116    UserMessage(UserMessage),
 117    AssistantMessage(AssistantMessage),
 118    ToolCall(ToolCall),
 119}
 120
 121impl AgentThreadEntry {
 122    pub fn to_markdown(&self, cx: &App) -> String {
 123        match self {
 124            Self::UserMessage(message) => message.to_markdown(cx),
 125            Self::AssistantMessage(message) => message.to_markdown(cx),
 126            Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
 127        }
 128    }
 129
 130    pub fn user_message(&self) -> Option<&UserMessage> {
 131        if let AgentThreadEntry::UserMessage(message) = self {
 132            Some(message)
 133        } else {
 134            None
 135        }
 136    }
 137
 138    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 139        if let AgentThreadEntry::ToolCall(call) = self {
 140            itertools::Either::Left(call.diffs())
 141        } else {
 142            itertools::Either::Right(std::iter::empty())
 143        }
 144    }
 145
 146    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 147        if let AgentThreadEntry::ToolCall(call) = self {
 148            itertools::Either::Left(call.terminals())
 149        } else {
 150            itertools::Either::Right(std::iter::empty())
 151        }
 152    }
 153
 154    pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
 155        if let AgentThreadEntry::ToolCall(ToolCall {
 156            locations,
 157            resolved_locations,
 158            ..
 159        }) = self
 160        {
 161            Some((
 162                locations.get(ix)?.clone(),
 163                resolved_locations.get(ix)?.clone()?,
 164            ))
 165        } else {
 166            None
 167        }
 168    }
 169}
 170
 171#[derive(Debug)]
 172pub struct ToolCall {
 173    pub id: acp::ToolCallId,
 174    pub label: Entity<Markdown>,
 175    pub kind: acp::ToolKind,
 176    pub content: Vec<ToolCallContent>,
 177    pub status: ToolCallStatus,
 178    pub locations: Vec<acp::ToolCallLocation>,
 179    pub resolved_locations: Vec<Option<AgentLocation>>,
 180    pub raw_input: Option<serde_json::Value>,
 181    pub raw_output: Option<serde_json::Value>,
 182}
 183
 184impl ToolCall {
 185    fn from_acp(
 186        tool_call: acp::ToolCall,
 187        status: ToolCallStatus,
 188        language_registry: Arc<LanguageRegistry>,
 189        terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
 190        cx: &mut App,
 191    ) -> Result<Self> {
 192        let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
 193            first_line.to_owned() + ""
 194        } else {
 195            tool_call.title
 196        };
 197        let mut content = Vec::with_capacity(tool_call.content.len());
 198        for item in tool_call.content {
 199            content.push(ToolCallContent::from_acp(
 200                item,
 201                language_registry.clone(),
 202                terminals,
 203                cx,
 204            )?);
 205        }
 206
 207        let result = Self {
 208            id: tool_call.id,
 209            label: cx
 210                .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
 211            kind: tool_call.kind,
 212            content,
 213            locations: tool_call.locations,
 214            resolved_locations: Vec::default(),
 215            status,
 216            raw_input: tool_call.raw_input,
 217            raw_output: tool_call.raw_output,
 218        };
 219        Ok(result)
 220    }
 221
 222    fn update_fields(
 223        &mut self,
 224        fields: acp::ToolCallUpdateFields,
 225        language_registry: Arc<LanguageRegistry>,
 226        terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
 227        cx: &mut App,
 228    ) -> Result<()> {
 229        let acp::ToolCallUpdateFields {
 230            kind,
 231            status,
 232            title,
 233            content,
 234            locations,
 235            raw_input,
 236            raw_output,
 237        } = fields;
 238
 239        if let Some(kind) = kind {
 240            self.kind = kind;
 241        }
 242
 243        if let Some(status) = status {
 244            self.status = status.into();
 245        }
 246
 247        if let Some(title) = title {
 248            self.label.update(cx, |label, cx| {
 249                if let Some((first_line, _)) = title.split_once("\n") {
 250                    label.replace(first_line.to_owned() + "", cx)
 251                } else {
 252                    label.replace(title, cx);
 253                }
 254            });
 255        }
 256
 257        if let Some(content) = content {
 258            let new_content_len = content.len();
 259            let mut content = content.into_iter();
 260
 261            // Reuse existing content if we can
 262            for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
 263                old.update_from_acp(new, language_registry.clone(), terminals, cx)?;
 264            }
 265            for new in content {
 266                self.content.push(ToolCallContent::from_acp(
 267                    new,
 268                    language_registry.clone(),
 269                    terminals,
 270                    cx,
 271                )?)
 272            }
 273            self.content.truncate(new_content_len);
 274        }
 275
 276        if let Some(locations) = locations {
 277            self.locations = locations;
 278        }
 279
 280        if let Some(raw_input) = raw_input {
 281            self.raw_input = Some(raw_input);
 282        }
 283
 284        if let Some(raw_output) = raw_output {
 285            if self.content.is_empty()
 286                && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
 287            {
 288                self.content
 289                    .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
 290                        markdown,
 291                    }));
 292            }
 293            self.raw_output = Some(raw_output);
 294        }
 295        Ok(())
 296    }
 297
 298    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 299        self.content.iter().filter_map(|content| match content {
 300            ToolCallContent::Diff(diff) => Some(diff),
 301            ToolCallContent::ContentBlock(_) => None,
 302            ToolCallContent::Terminal(_) => None,
 303        })
 304    }
 305
 306    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 307        self.content.iter().filter_map(|content| match content {
 308            ToolCallContent::Terminal(terminal) => Some(terminal),
 309            ToolCallContent::ContentBlock(_) => None,
 310            ToolCallContent::Diff(_) => None,
 311        })
 312    }
 313
 314    fn to_markdown(&self, cx: &App) -> String {
 315        let mut markdown = format!(
 316            "**Tool Call: {}**\nStatus: {}\n\n",
 317            self.label.read(cx).source(),
 318            self.status
 319        );
 320        for content in &self.content {
 321            markdown.push_str(content.to_markdown(cx).as_str());
 322            markdown.push_str("\n\n");
 323        }
 324        markdown
 325    }
 326
 327    async fn resolve_location(
 328        location: acp::ToolCallLocation,
 329        project: WeakEntity<Project>,
 330        cx: &mut AsyncApp,
 331    ) -> Option<AgentLocation> {
 332        let buffer = project
 333            .update(cx, |project, cx| {
 334                project
 335                    .project_path_for_absolute_path(&location.path, cx)
 336                    .map(|path| project.open_buffer(path, cx))
 337            })
 338            .ok()??;
 339        let buffer = buffer.await.log_err()?;
 340        let position = buffer
 341            .update(cx, |buffer, _| {
 342                if let Some(row) = location.line {
 343                    let snapshot = buffer.snapshot();
 344                    let column = snapshot.indent_size_for_line(row).len;
 345                    let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
 346                    snapshot.anchor_before(point)
 347                } else {
 348                    Anchor::MIN
 349                }
 350            })
 351            .ok()?;
 352
 353        Some(AgentLocation {
 354            buffer: buffer.downgrade(),
 355            position,
 356        })
 357    }
 358
 359    fn resolve_locations(
 360        &self,
 361        project: Entity<Project>,
 362        cx: &mut App,
 363    ) -> Task<Vec<Option<AgentLocation>>> {
 364        let locations = self.locations.clone();
 365        project.update(cx, |_, cx| {
 366            cx.spawn(async move |project, cx| {
 367                let mut new_locations = Vec::new();
 368                for location in locations {
 369                    new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
 370                }
 371                new_locations
 372            })
 373        })
 374    }
 375}
 376
 377#[derive(Debug)]
 378pub enum ToolCallStatus {
 379    /// The tool call hasn't started running yet, but we start showing it to
 380    /// the user.
 381    Pending,
 382    /// The tool call is waiting for confirmation from the user.
 383    WaitingForConfirmation {
 384        options: Vec<acp::PermissionOption>,
 385        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
 386    },
 387    /// The tool call is currently running.
 388    InProgress,
 389    /// The tool call completed successfully.
 390    Completed,
 391    /// The tool call failed.
 392    Failed,
 393    /// The user rejected the tool call.
 394    Rejected,
 395    /// The user canceled generation so the tool call was canceled.
 396    Canceled,
 397}
 398
 399impl From<acp::ToolCallStatus> for ToolCallStatus {
 400    fn from(status: acp::ToolCallStatus) -> Self {
 401        match status {
 402            acp::ToolCallStatus::Pending => Self::Pending,
 403            acp::ToolCallStatus::InProgress => Self::InProgress,
 404            acp::ToolCallStatus::Completed => Self::Completed,
 405            acp::ToolCallStatus::Failed => Self::Failed,
 406        }
 407    }
 408}
 409
 410impl Display for ToolCallStatus {
 411    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 412        write!(
 413            f,
 414            "{}",
 415            match self {
 416                ToolCallStatus::Pending => "Pending",
 417                ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
 418                ToolCallStatus::InProgress => "In Progress",
 419                ToolCallStatus::Completed => "Completed",
 420                ToolCallStatus::Failed => "Failed",
 421                ToolCallStatus::Rejected => "Rejected",
 422                ToolCallStatus::Canceled => "Canceled",
 423            }
 424        )
 425    }
 426}
 427
 428#[derive(Debug, PartialEq, Clone)]
 429pub enum ContentBlock {
 430    Empty,
 431    Markdown { markdown: Entity<Markdown> },
 432    ResourceLink { resource_link: acp::ResourceLink },
 433}
 434
 435impl ContentBlock {
 436    pub fn new(
 437        block: acp::ContentBlock,
 438        language_registry: &Arc<LanguageRegistry>,
 439        cx: &mut App,
 440    ) -> Self {
 441        let mut this = Self::Empty;
 442        this.append(block, language_registry, cx);
 443        this
 444    }
 445
 446    pub fn new_combined(
 447        blocks: impl IntoIterator<Item = acp::ContentBlock>,
 448        language_registry: Arc<LanguageRegistry>,
 449        cx: &mut App,
 450    ) -> Self {
 451        let mut this = Self::Empty;
 452        for block in blocks {
 453            this.append(block, &language_registry, cx);
 454        }
 455        this
 456    }
 457
 458    pub fn append(
 459        &mut self,
 460        block: acp::ContentBlock,
 461        language_registry: &Arc<LanguageRegistry>,
 462        cx: &mut App,
 463    ) {
 464        if matches!(self, ContentBlock::Empty)
 465            && let acp::ContentBlock::ResourceLink(resource_link) = block
 466        {
 467            *self = ContentBlock::ResourceLink { resource_link };
 468            return;
 469        }
 470
 471        let new_content = self.block_string_contents(block);
 472
 473        match self {
 474            ContentBlock::Empty => {
 475                *self = Self::create_markdown_block(new_content, language_registry, cx);
 476            }
 477            ContentBlock::Markdown { markdown } => {
 478                markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
 479            }
 480            ContentBlock::ResourceLink { resource_link } => {
 481                let existing_content = Self::resource_link_md(&resource_link.uri);
 482                let combined = format!("{}\n{}", existing_content, new_content);
 483
 484                *self = Self::create_markdown_block(combined, language_registry, cx);
 485            }
 486        }
 487    }
 488
 489    fn create_markdown_block(
 490        content: String,
 491        language_registry: &Arc<LanguageRegistry>,
 492        cx: &mut App,
 493    ) -> ContentBlock {
 494        ContentBlock::Markdown {
 495            markdown: cx
 496                .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
 497        }
 498    }
 499
 500    fn block_string_contents(&self, block: acp::ContentBlock) -> String {
 501        match block {
 502            acp::ContentBlock::Text(text_content) => text_content.text,
 503            acp::ContentBlock::ResourceLink(resource_link) => {
 504                Self::resource_link_md(&resource_link.uri)
 505            }
 506            acp::ContentBlock::Resource(acp::EmbeddedResource {
 507                resource:
 508                    acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
 509                        uri,
 510                        ..
 511                    }),
 512                ..
 513            }) => Self::resource_link_md(&uri),
 514            acp::ContentBlock::Image(image) => Self::image_md(&image),
 515            acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
 516        }
 517    }
 518
 519    fn resource_link_md(uri: &str) -> String {
 520        if let Some(uri) = MentionUri::parse(uri).log_err() {
 521            uri.as_link().to_string()
 522        } else {
 523            uri.to_string()
 524        }
 525    }
 526
 527    fn image_md(_image: &acp::ImageContent) -> String {
 528        "`Image`".into()
 529    }
 530
 531    pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
 532        match self {
 533            ContentBlock::Empty => "",
 534            ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
 535            ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
 536        }
 537    }
 538
 539    pub fn markdown(&self) -> Option<&Entity<Markdown>> {
 540        match self {
 541            ContentBlock::Empty => None,
 542            ContentBlock::Markdown { markdown } => Some(markdown),
 543            ContentBlock::ResourceLink { .. } => None,
 544        }
 545    }
 546
 547    pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
 548        match self {
 549            ContentBlock::ResourceLink { resource_link } => Some(resource_link),
 550            _ => None,
 551        }
 552    }
 553}
 554
 555#[derive(Debug)]
 556pub enum ToolCallContent {
 557    ContentBlock(ContentBlock),
 558    Diff(Entity<Diff>),
 559    Terminal(Entity<Terminal>),
 560}
 561
 562impl ToolCallContent {
 563    pub fn from_acp(
 564        content: acp::ToolCallContent,
 565        language_registry: Arc<LanguageRegistry>,
 566        terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
 567        cx: &mut App,
 568    ) -> Result<Self> {
 569        match content {
 570            acp::ToolCallContent::Content { content } => Ok(Self::ContentBlock(ContentBlock::new(
 571                content,
 572                &language_registry,
 573                cx,
 574            ))),
 575            acp::ToolCallContent::Diff { diff } => Ok(Self::Diff(cx.new(|cx| {
 576                Diff::finalized(
 577                    diff.path.to_string_lossy().into_owned(),
 578                    diff.old_text,
 579                    diff.new_text,
 580                    language_registry,
 581                    cx,
 582                )
 583            }))),
 584            acp::ToolCallContent::Terminal { terminal_id } => terminals
 585                .get(&terminal_id)
 586                .cloned()
 587                .map(Self::Terminal)
 588                .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
 589        }
 590    }
 591
 592    pub fn update_from_acp(
 593        &mut self,
 594        new: acp::ToolCallContent,
 595        language_registry: Arc<LanguageRegistry>,
 596        terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
 597        cx: &mut App,
 598    ) -> Result<()> {
 599        let needs_update = match (&self, &new) {
 600            (Self::Diff(old_diff), acp::ToolCallContent::Diff { diff: new_diff }) => {
 601                old_diff.read(cx).needs_update(
 602                    new_diff.old_text.as_deref().unwrap_or(""),
 603                    &new_diff.new_text,
 604                    cx,
 605                )
 606            }
 607            _ => true,
 608        };
 609
 610        if needs_update {
 611            *self = Self::from_acp(new, language_registry, terminals, cx)?;
 612        }
 613        Ok(())
 614    }
 615
 616    pub fn to_markdown(&self, cx: &App) -> String {
 617        match self {
 618            Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
 619            Self::Diff(diff) => diff.read(cx).to_markdown(cx),
 620            Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
 621        }
 622    }
 623}
 624
 625#[derive(Debug, PartialEq)]
 626pub enum ToolCallUpdate {
 627    UpdateFields(acp::ToolCallUpdate),
 628    UpdateDiff(ToolCallUpdateDiff),
 629    UpdateTerminal(ToolCallUpdateTerminal),
 630}
 631
 632impl ToolCallUpdate {
 633    fn id(&self) -> &acp::ToolCallId {
 634        match self {
 635            Self::UpdateFields(update) => &update.id,
 636            Self::UpdateDiff(diff) => &diff.id,
 637            Self::UpdateTerminal(terminal) => &terminal.id,
 638        }
 639    }
 640}
 641
 642impl From<acp::ToolCallUpdate> for ToolCallUpdate {
 643    fn from(update: acp::ToolCallUpdate) -> Self {
 644        Self::UpdateFields(update)
 645    }
 646}
 647
 648impl From<ToolCallUpdateDiff> for ToolCallUpdate {
 649    fn from(diff: ToolCallUpdateDiff) -> Self {
 650        Self::UpdateDiff(diff)
 651    }
 652}
 653
 654#[derive(Debug, PartialEq)]
 655pub struct ToolCallUpdateDiff {
 656    pub id: acp::ToolCallId,
 657    pub diff: Entity<Diff>,
 658}
 659
 660impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
 661    fn from(terminal: ToolCallUpdateTerminal) -> Self {
 662        Self::UpdateTerminal(terminal)
 663    }
 664}
 665
 666#[derive(Debug, PartialEq)]
 667pub struct ToolCallUpdateTerminal {
 668    pub id: acp::ToolCallId,
 669    pub terminal: Entity<Terminal>,
 670}
 671
 672#[derive(Debug, Default)]
 673pub struct Plan {
 674    pub entries: Vec<PlanEntry>,
 675}
 676
 677#[derive(Debug)]
 678pub struct PlanStats<'a> {
 679    pub in_progress_entry: Option<&'a PlanEntry>,
 680    pub pending: u32,
 681    pub completed: u32,
 682}
 683
 684impl Plan {
 685    pub fn is_empty(&self) -> bool {
 686        self.entries.is_empty()
 687    }
 688
 689    pub fn stats(&self) -> PlanStats<'_> {
 690        let mut stats = PlanStats {
 691            in_progress_entry: None,
 692            pending: 0,
 693            completed: 0,
 694        };
 695
 696        for entry in &self.entries {
 697            match &entry.status {
 698                acp::PlanEntryStatus::Pending => {
 699                    stats.pending += 1;
 700                }
 701                acp::PlanEntryStatus::InProgress => {
 702                    stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
 703                }
 704                acp::PlanEntryStatus::Completed => {
 705                    stats.completed += 1;
 706                }
 707            }
 708        }
 709
 710        stats
 711    }
 712}
 713
 714#[derive(Debug)]
 715pub struct PlanEntry {
 716    pub content: Entity<Markdown>,
 717    pub priority: acp::PlanEntryPriority,
 718    pub status: acp::PlanEntryStatus,
 719}
 720
 721impl PlanEntry {
 722    pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
 723        Self {
 724            content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
 725            priority: entry.priority,
 726            status: entry.status,
 727        }
 728    }
 729}
 730
 731#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 732pub struct TokenUsage {
 733    pub max_tokens: u64,
 734    pub used_tokens: u64,
 735}
 736
 737impl TokenUsage {
 738    pub fn ratio(&self) -> TokenUsageRatio {
 739        #[cfg(debug_assertions)]
 740        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 741            .unwrap_or("0.8".to_string())
 742            .parse()
 743            .unwrap();
 744        #[cfg(not(debug_assertions))]
 745        let warning_threshold: f32 = 0.8;
 746
 747        // When the maximum is unknown because there is no selected model,
 748        // avoid showing the token limit warning.
 749        if self.max_tokens == 0 {
 750            TokenUsageRatio::Normal
 751        } else if self.used_tokens >= self.max_tokens {
 752            TokenUsageRatio::Exceeded
 753        } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
 754            TokenUsageRatio::Warning
 755        } else {
 756            TokenUsageRatio::Normal
 757        }
 758    }
 759}
 760
 761#[derive(Debug, Clone, PartialEq, Eq)]
 762pub enum TokenUsageRatio {
 763    Normal,
 764    Warning,
 765    Exceeded,
 766}
 767
 768#[derive(Debug, Clone)]
 769pub struct RetryStatus {
 770    pub last_error: SharedString,
 771    pub attempt: usize,
 772    pub max_attempts: usize,
 773    pub started_at: Instant,
 774    pub duration: Duration,
 775}
 776
 777pub struct AcpThread {
 778    title: SharedString,
 779    entries: Vec<AgentThreadEntry>,
 780    plan: Plan,
 781    project: Entity<Project>,
 782    action_log: Entity<ActionLog>,
 783    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
 784    send_task: Option<Task<()>>,
 785    connection: Rc<dyn AgentConnection>,
 786    session_id: acp::SessionId,
 787    token_usage: Option<TokenUsage>,
 788    prompt_capabilities: acp::PromptCapabilities,
 789    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
 790    terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
 791}
 792
 793#[derive(Debug)]
 794pub enum AcpThreadEvent {
 795    NewEntry,
 796    TitleUpdated,
 797    TokenUsageUpdated,
 798    EntryUpdated(usize),
 799    EntriesRemoved(Range<usize>),
 800    ToolAuthorizationRequired,
 801    Retry(RetryStatus),
 802    Stopped,
 803    Error,
 804    LoadError(LoadError),
 805    PromptCapabilitiesUpdated,
 806    Refusal,
 807    AvailableCommandsUpdated(Vec<acp::AvailableCommand>),
 808    ModeUpdated(acp::SessionModeId),
 809}
 810
 811impl EventEmitter<AcpThreadEvent> for AcpThread {}
 812
 813#[derive(PartialEq, Eq, Debug)]
 814pub enum ThreadStatus {
 815    Idle,
 816    Generating,
 817}
 818
 819#[derive(Debug, Clone)]
 820pub enum LoadError {
 821    Unsupported {
 822        command: SharedString,
 823        current_version: SharedString,
 824        minimum_version: SharedString,
 825    },
 826    FailedToInstall(SharedString),
 827    Exited {
 828        status: ExitStatus,
 829    },
 830    Other(SharedString),
 831}
 832
 833impl Display for LoadError {
 834    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 835        match self {
 836            LoadError::Unsupported {
 837                command: path,
 838                current_version,
 839                minimum_version,
 840            } => {
 841                write!(
 842                    f,
 843                    "version {current_version} from {path} is not supported (need at least {minimum_version})"
 844                )
 845            }
 846            LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
 847            LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
 848            LoadError::Other(msg) => write!(f, "{msg}"),
 849        }
 850    }
 851}
 852
 853impl Error for LoadError {}
 854
 855impl AcpThread {
 856    pub fn new(
 857        title: impl Into<SharedString>,
 858        connection: Rc<dyn AgentConnection>,
 859        project: Entity<Project>,
 860        action_log: Entity<ActionLog>,
 861        session_id: acp::SessionId,
 862        mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
 863        cx: &mut Context<Self>,
 864    ) -> Self {
 865        let prompt_capabilities = prompt_capabilities_rx.borrow().clone();
 866        let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
 867            loop {
 868                let caps = prompt_capabilities_rx.recv().await?;
 869                this.update(cx, |this, cx| {
 870                    this.prompt_capabilities = caps;
 871                    cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
 872                })?;
 873            }
 874        });
 875
 876        Self {
 877            action_log,
 878            shared_buffers: Default::default(),
 879            entries: Default::default(),
 880            plan: Default::default(),
 881            title: title.into(),
 882            project,
 883            send_task: None,
 884            connection,
 885            session_id,
 886            token_usage: None,
 887            prompt_capabilities,
 888            _observe_prompt_capabilities: task,
 889            terminals: HashMap::default(),
 890        }
 891    }
 892
 893    pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
 894        self.prompt_capabilities.clone()
 895    }
 896
 897    pub fn connection(&self) -> &Rc<dyn AgentConnection> {
 898        &self.connection
 899    }
 900
 901    pub fn action_log(&self) -> &Entity<ActionLog> {
 902        &self.action_log
 903    }
 904
 905    pub fn project(&self) -> &Entity<Project> {
 906        &self.project
 907    }
 908
 909    pub fn title(&self) -> SharedString {
 910        self.title.clone()
 911    }
 912
 913    pub fn entries(&self) -> &[AgentThreadEntry] {
 914        &self.entries
 915    }
 916
 917    pub fn session_id(&self) -> &acp::SessionId {
 918        &self.session_id
 919    }
 920
 921    pub fn status(&self) -> ThreadStatus {
 922        if self.send_task.is_some() {
 923            ThreadStatus::Generating
 924        } else {
 925            ThreadStatus::Idle
 926        }
 927    }
 928
 929    pub fn token_usage(&self) -> Option<&TokenUsage> {
 930        self.token_usage.as_ref()
 931    }
 932
 933    pub fn has_pending_edit_tool_calls(&self) -> bool {
 934        for entry in self.entries.iter().rev() {
 935            match entry {
 936                AgentThreadEntry::UserMessage(_) => return false,
 937                AgentThreadEntry::ToolCall(
 938                    call @ ToolCall {
 939                        status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
 940                        ..
 941                    },
 942                ) if call.diffs().next().is_some() => {
 943                    return true;
 944                }
 945                AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
 946            }
 947        }
 948
 949        false
 950    }
 951
 952    pub fn used_tools_since_last_user_message(&self) -> bool {
 953        for entry in self.entries.iter().rev() {
 954            match entry {
 955                AgentThreadEntry::UserMessage(..) => return false,
 956                AgentThreadEntry::AssistantMessage(..) => continue,
 957                AgentThreadEntry::ToolCall(..) => return true,
 958            }
 959        }
 960
 961        false
 962    }
 963
 964    pub fn handle_session_update(
 965        &mut self,
 966        update: acp::SessionUpdate,
 967        cx: &mut Context<Self>,
 968    ) -> Result<(), acp::Error> {
 969        match update {
 970            acp::SessionUpdate::UserMessageChunk { content } => {
 971                self.push_user_content_block(None, content, cx);
 972            }
 973            acp::SessionUpdate::AgentMessageChunk { content } => {
 974                self.push_assistant_content_block(content, false, cx);
 975            }
 976            acp::SessionUpdate::AgentThoughtChunk { content } => {
 977                self.push_assistant_content_block(content, true, cx);
 978            }
 979            acp::SessionUpdate::ToolCall(tool_call) => {
 980                self.upsert_tool_call(tool_call, cx)?;
 981            }
 982            acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
 983                self.update_tool_call(tool_call_update, cx)?;
 984            }
 985            acp::SessionUpdate::Plan(plan) => {
 986                self.update_plan(plan, cx);
 987            }
 988            acp::SessionUpdate::AvailableCommandsUpdate { available_commands } => {
 989                cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands))
 990            }
 991            acp::SessionUpdate::CurrentModeUpdate { current_mode_id } => {
 992                cx.emit(AcpThreadEvent::ModeUpdated(current_mode_id))
 993            }
 994        }
 995        Ok(())
 996    }
 997
 998    pub fn push_user_content_block(
 999        &mut self,
1000        message_id: Option<UserMessageId>,
1001        chunk: acp::ContentBlock,
1002        cx: &mut Context<Self>,
1003    ) {
1004        let language_registry = self.project.read(cx).languages().clone();
1005        let entries_len = self.entries.len();
1006
1007        if let Some(last_entry) = self.entries.last_mut()
1008            && let AgentThreadEntry::UserMessage(UserMessage {
1009                id,
1010                content,
1011                chunks,
1012                ..
1013            }) = last_entry
1014        {
1015            *id = message_id.or(id.take());
1016            content.append(chunk.clone(), &language_registry, cx);
1017            chunks.push(chunk);
1018            let idx = entries_len - 1;
1019            cx.emit(AcpThreadEvent::EntryUpdated(idx));
1020        } else {
1021            let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
1022            self.push_entry(
1023                AgentThreadEntry::UserMessage(UserMessage {
1024                    id: message_id,
1025                    content,
1026                    chunks: vec![chunk],
1027                    checkpoint: None,
1028                }),
1029                cx,
1030            );
1031        }
1032    }
1033
1034    pub fn push_assistant_content_block(
1035        &mut self,
1036        chunk: acp::ContentBlock,
1037        is_thought: bool,
1038        cx: &mut Context<Self>,
1039    ) {
1040        let language_registry = self.project.read(cx).languages().clone();
1041        let entries_len = self.entries.len();
1042        if let Some(last_entry) = self.entries.last_mut()
1043            && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
1044        {
1045            let idx = entries_len - 1;
1046            cx.emit(AcpThreadEvent::EntryUpdated(idx));
1047            match (chunks.last_mut(), is_thought) {
1048                (Some(AssistantMessageChunk::Message { block }), false)
1049                | (Some(AssistantMessageChunk::Thought { block }), true) => {
1050                    block.append(chunk, &language_registry, cx)
1051                }
1052                _ => {
1053                    let block = ContentBlock::new(chunk, &language_registry, cx);
1054                    if is_thought {
1055                        chunks.push(AssistantMessageChunk::Thought { block })
1056                    } else {
1057                        chunks.push(AssistantMessageChunk::Message { block })
1058                    }
1059                }
1060            }
1061        } else {
1062            let block = ContentBlock::new(chunk, &language_registry, cx);
1063            let chunk = if is_thought {
1064                AssistantMessageChunk::Thought { block }
1065            } else {
1066                AssistantMessageChunk::Message { block }
1067            };
1068
1069            self.push_entry(
1070                AgentThreadEntry::AssistantMessage(AssistantMessage {
1071                    chunks: vec![chunk],
1072                }),
1073                cx,
1074            );
1075        }
1076    }
1077
1078    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1079        self.entries.push(entry);
1080        cx.emit(AcpThreadEvent::NewEntry);
1081    }
1082
1083    pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1084        self.connection.set_title(&self.session_id, cx).is_some()
1085    }
1086
1087    pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1088        if title != self.title {
1089            self.title = title.clone();
1090            cx.emit(AcpThreadEvent::TitleUpdated);
1091            if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1092                return set_title.run(title, cx);
1093            }
1094        }
1095        Task::ready(Ok(()))
1096    }
1097
1098    pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1099        self.token_usage = usage;
1100        cx.emit(AcpThreadEvent::TokenUsageUpdated);
1101    }
1102
1103    pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1104        cx.emit(AcpThreadEvent::Retry(status));
1105    }
1106
1107    pub fn update_tool_call(
1108        &mut self,
1109        update: impl Into<ToolCallUpdate>,
1110        cx: &mut Context<Self>,
1111    ) -> Result<()> {
1112        let update = update.into();
1113        let languages = self.project.read(cx).languages().clone();
1114
1115        let ix = match self.index_for_tool_call(update.id()) {
1116            Some(ix) => ix,
1117            None => {
1118                // Tool call not found - create a failed tool call entry
1119                let failed_tool_call = ToolCall {
1120                    id: update.id().clone(),
1121                    label: cx.new(|cx| Markdown::new("Tool call not found".into(), None, None, cx)),
1122                    kind: acp::ToolKind::Fetch,
1123                    content: vec![ToolCallContent::ContentBlock(ContentBlock::new(
1124                        acp::ContentBlock::Text(acp::TextContent {
1125                            text: "Tool call not found".to_string(),
1126                            annotations: None,
1127                            meta: None,
1128                        }),
1129                        &languages,
1130                        cx,
1131                    ))],
1132                    status: ToolCallStatus::Failed,
1133                    locations: Vec::new(),
1134                    resolved_locations: Vec::new(),
1135                    raw_input: None,
1136                    raw_output: None,
1137                };
1138                self.push_entry(AgentThreadEntry::ToolCall(failed_tool_call), cx);
1139                return Ok(());
1140            }
1141        };
1142        let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1143            unreachable!()
1144        };
1145
1146        match update {
1147            ToolCallUpdate::UpdateFields(update) => {
1148                let location_updated = update.fields.locations.is_some();
1149                call.update_fields(update.fields, languages, &self.terminals, cx)?;
1150                if location_updated {
1151                    self.resolve_locations(update.id, cx);
1152                }
1153            }
1154            ToolCallUpdate::UpdateDiff(update) => {
1155                call.content.clear();
1156                call.content.push(ToolCallContent::Diff(update.diff));
1157            }
1158            ToolCallUpdate::UpdateTerminal(update) => {
1159                call.content.clear();
1160                call.content
1161                    .push(ToolCallContent::Terminal(update.terminal));
1162            }
1163        }
1164
1165        cx.emit(AcpThreadEvent::EntryUpdated(ix));
1166
1167        Ok(())
1168    }
1169
1170    /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1171    pub fn upsert_tool_call(
1172        &mut self,
1173        tool_call: acp::ToolCall,
1174        cx: &mut Context<Self>,
1175    ) -> Result<(), acp::Error> {
1176        let status = tool_call.status.into();
1177        self.upsert_tool_call_inner(tool_call.into(), status, cx)
1178    }
1179
1180    /// Fails if id does not match an existing entry.
1181    pub fn upsert_tool_call_inner(
1182        &mut self,
1183        update: acp::ToolCallUpdate,
1184        status: ToolCallStatus,
1185        cx: &mut Context<Self>,
1186    ) -> Result<(), acp::Error> {
1187        let language_registry = self.project.read(cx).languages().clone();
1188        let id = update.id.clone();
1189
1190        if let Some(ix) = self.index_for_tool_call(&id) {
1191            let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1192                unreachable!()
1193            };
1194
1195            call.update_fields(update.fields, language_registry, &self.terminals, cx)?;
1196            call.status = status;
1197
1198            cx.emit(AcpThreadEvent::EntryUpdated(ix));
1199        } else {
1200            let call = ToolCall::from_acp(
1201                update.try_into()?,
1202                status,
1203                language_registry,
1204                &self.terminals,
1205                cx,
1206            )?;
1207            self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1208        };
1209
1210        self.resolve_locations(id, cx);
1211        Ok(())
1212    }
1213
1214    fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1215        self.entries
1216            .iter()
1217            .enumerate()
1218            .rev()
1219            .find_map(|(index, entry)| {
1220                if let AgentThreadEntry::ToolCall(tool_call) = entry
1221                    && &tool_call.id == id
1222                {
1223                    Some(index)
1224                } else {
1225                    None
1226                }
1227            })
1228    }
1229
1230    fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1231        // The tool call we are looking for is typically the last one, or very close to the end.
1232        // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1233        self.entries
1234            .iter_mut()
1235            .enumerate()
1236            .rev()
1237            .find_map(|(index, tool_call)| {
1238                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1239                    && &tool_call.id == id
1240                {
1241                    Some((index, tool_call))
1242                } else {
1243                    None
1244                }
1245            })
1246    }
1247
1248    pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1249        self.entries
1250            .iter()
1251            .enumerate()
1252            .rev()
1253            .find_map(|(index, tool_call)| {
1254                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1255                    && &tool_call.id == id
1256                {
1257                    Some((index, tool_call))
1258                } else {
1259                    None
1260                }
1261            })
1262    }
1263
1264    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1265        let project = self.project.clone();
1266        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1267            return;
1268        };
1269        let task = tool_call.resolve_locations(project, cx);
1270        cx.spawn(async move |this, cx| {
1271            let resolved_locations = task.await;
1272            this.update(cx, |this, cx| {
1273                let project = this.project.clone();
1274                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1275                    return;
1276                };
1277                if let Some(Some(location)) = resolved_locations.last() {
1278                    project.update(cx, |project, cx| {
1279                        if let Some(agent_location) = project.agent_location() {
1280                            let should_ignore = agent_location.buffer == location.buffer
1281                                && location
1282                                    .buffer
1283                                    .update(cx, |buffer, _| {
1284                                        let snapshot = buffer.snapshot();
1285                                        let old_position =
1286                                            agent_location.position.to_point(&snapshot);
1287                                        let new_position = location.position.to_point(&snapshot);
1288                                        // ignore this so that when we get updates from the edit tool
1289                                        // the position doesn't reset to the startof line
1290                                        old_position.row == new_position.row
1291                                            && old_position.column > new_position.column
1292                                    })
1293                                    .ok()
1294                                    .unwrap_or_default();
1295                            if !should_ignore {
1296                                project.set_agent_location(Some(location.clone()), cx);
1297                            }
1298                        }
1299                    });
1300                }
1301                if tool_call.resolved_locations != resolved_locations {
1302                    tool_call.resolved_locations = resolved_locations;
1303                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
1304                }
1305            })
1306        })
1307        .detach();
1308    }
1309
1310    pub fn request_tool_call_authorization(
1311        &mut self,
1312        tool_call: acp::ToolCallUpdate,
1313        options: Vec<acp::PermissionOption>,
1314        respect_always_allow_setting: bool,
1315        cx: &mut Context<Self>,
1316    ) -> Result<BoxFuture<'static, acp::RequestPermissionOutcome>> {
1317        let (tx, rx) = oneshot::channel();
1318
1319        if respect_always_allow_setting && AgentSettings::get_global(cx).always_allow_tool_actions {
1320            // Don't use AllowAlways, because then if you were to turn off always_allow_tool_actions,
1321            // some tools would (incorrectly) continue to auto-accept.
1322            if let Some(allow_once_option) = options.iter().find_map(|option| {
1323                if matches!(option.kind, acp::PermissionOptionKind::AllowOnce) {
1324                    Some(option.id.clone())
1325                } else {
1326                    None
1327                }
1328            }) {
1329                self.upsert_tool_call_inner(tool_call, ToolCallStatus::Pending, cx)?;
1330                return Ok(async {
1331                    acp::RequestPermissionOutcome::Selected {
1332                        option_id: allow_once_option,
1333                    }
1334                }
1335                .boxed());
1336            }
1337        }
1338
1339        let status = ToolCallStatus::WaitingForConfirmation {
1340            options,
1341            respond_tx: tx,
1342        };
1343
1344        self.upsert_tool_call_inner(tool_call, status, cx)?;
1345        cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1346
1347        let fut = async {
1348            match rx.await {
1349                Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
1350                Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1351            }
1352        }
1353        .boxed();
1354
1355        Ok(fut)
1356    }
1357
1358    pub fn authorize_tool_call(
1359        &mut self,
1360        id: acp::ToolCallId,
1361        option_id: acp::PermissionOptionId,
1362        option_kind: acp::PermissionOptionKind,
1363        cx: &mut Context<Self>,
1364    ) {
1365        let Some((ix, call)) = self.tool_call_mut(&id) else {
1366            return;
1367        };
1368
1369        let new_status = match option_kind {
1370            acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1371                ToolCallStatus::Rejected
1372            }
1373            acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1374                ToolCallStatus::InProgress
1375            }
1376        };
1377
1378        let curr_status = mem::replace(&mut call.status, new_status);
1379
1380        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1381            respond_tx.send(option_id).log_err();
1382        } else if cfg!(debug_assertions) {
1383            panic!("tried to authorize an already authorized tool call");
1384        }
1385
1386        cx.emit(AcpThreadEvent::EntryUpdated(ix));
1387    }
1388
1389    pub fn first_tool_awaiting_confirmation(&self) -> Option<&ToolCall> {
1390        let mut first_tool_call = None;
1391
1392        for entry in self.entries.iter().rev() {
1393            match &entry {
1394                AgentThreadEntry::ToolCall(call) => {
1395                    if let ToolCallStatus::WaitingForConfirmation { .. } = call.status {
1396                        first_tool_call = Some(call);
1397                    } else {
1398                        continue;
1399                    }
1400                }
1401                AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1402                    // Reached the beginning of the turn.
1403                    // If we had pending permission requests in the previous turn, they have been cancelled.
1404                    break;
1405                }
1406            }
1407        }
1408
1409        first_tool_call
1410    }
1411
1412    pub fn plan(&self) -> &Plan {
1413        &self.plan
1414    }
1415
1416    pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1417        let new_entries_len = request.entries.len();
1418        let mut new_entries = request.entries.into_iter();
1419
1420        // Reuse existing markdown to prevent flickering
1421        for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1422            let PlanEntry {
1423                content,
1424                priority,
1425                status,
1426            } = old;
1427            content.update(cx, |old, cx| {
1428                old.replace(new.content, cx);
1429            });
1430            *priority = new.priority;
1431            *status = new.status;
1432        }
1433        for new in new_entries {
1434            self.plan.entries.push(PlanEntry::from_acp(new, cx))
1435        }
1436        self.plan.entries.truncate(new_entries_len);
1437
1438        cx.notify();
1439    }
1440
1441    fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1442        self.plan
1443            .entries
1444            .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1445        cx.notify();
1446    }
1447
1448    #[cfg(any(test, feature = "test-support"))]
1449    pub fn send_raw(
1450        &mut self,
1451        message: &str,
1452        cx: &mut Context<Self>,
1453    ) -> BoxFuture<'static, Result<()>> {
1454        self.send(
1455            vec![acp::ContentBlock::Text(acp::TextContent {
1456                text: message.to_string(),
1457                annotations: None,
1458                meta: None,
1459            })],
1460            cx,
1461        )
1462    }
1463
1464    pub fn send(
1465        &mut self,
1466        message: Vec<acp::ContentBlock>,
1467        cx: &mut Context<Self>,
1468    ) -> BoxFuture<'static, Result<()>> {
1469        let block = ContentBlock::new_combined(
1470            message.clone(),
1471            self.project.read(cx).languages().clone(),
1472            cx,
1473        );
1474        let request = acp::PromptRequest {
1475            prompt: message.clone(),
1476            session_id: self.session_id.clone(),
1477            meta: None,
1478        };
1479        let git_store = self.project.read(cx).git_store().clone();
1480
1481        let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
1482            Some(UserMessageId::new())
1483        } else {
1484            None
1485        };
1486
1487        self.run_turn(cx, async move |this, cx| {
1488            this.update(cx, |this, cx| {
1489                this.push_entry(
1490                    AgentThreadEntry::UserMessage(UserMessage {
1491                        id: message_id.clone(),
1492                        content: block,
1493                        chunks: message,
1494                        checkpoint: None,
1495                    }),
1496                    cx,
1497                );
1498            })
1499            .ok();
1500
1501            let old_checkpoint = git_store
1502                .update(cx, |git, cx| git.checkpoint(cx))?
1503                .await
1504                .context("failed to get old checkpoint")
1505                .log_err();
1506            this.update(cx, |this, cx| {
1507                if let Some((_ix, message)) = this.last_user_message() {
1508                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
1509                        git_checkpoint,
1510                        show: false,
1511                    });
1512                }
1513                this.connection.prompt(message_id, request, cx)
1514            })?
1515            .await
1516        })
1517    }
1518
1519    pub fn can_resume(&self, cx: &App) -> bool {
1520        self.connection.resume(&self.session_id, cx).is_some()
1521    }
1522
1523    pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1524        self.run_turn(cx, async move |this, cx| {
1525            this.update(cx, |this, cx| {
1526                this.connection
1527                    .resume(&this.session_id, cx)
1528                    .map(|resume| resume.run(cx))
1529            })?
1530            .context("resuming a session is not supported")?
1531            .await
1532        })
1533    }
1534
1535    fn run_turn(
1536        &mut self,
1537        cx: &mut Context<Self>,
1538        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
1539    ) -> BoxFuture<'static, Result<()>> {
1540        self.clear_completed_plan_entries(cx);
1541
1542        let (tx, rx) = oneshot::channel();
1543        let cancel_task = self.cancel(cx);
1544
1545        self.send_task = Some(cx.spawn(async move |this, cx| {
1546            cancel_task.await;
1547            tx.send(f(this, cx).await).ok();
1548        }));
1549
1550        cx.spawn(async move |this, cx| {
1551            let response = rx.await;
1552
1553            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
1554                .await?;
1555
1556            this.update(cx, |this, cx| {
1557                this.project
1558                    .update(cx, |project, cx| project.set_agent_location(None, cx));
1559                match response {
1560                    Ok(Err(e)) => {
1561                        this.send_task.take();
1562                        cx.emit(AcpThreadEvent::Error);
1563                        Err(e)
1564                    }
1565                    result => {
1566                        let canceled = matches!(
1567                            result,
1568                            Ok(Ok(acp::PromptResponse {
1569                                stop_reason: acp::StopReason::Cancelled,
1570                                meta: None,
1571                            }))
1572                        );
1573
1574                        // We only take the task if the current prompt wasn't canceled.
1575                        //
1576                        // This prompt may have been canceled because another one was sent
1577                        // while it was still generating. In these cases, dropping `send_task`
1578                        // would cause the next generation to be canceled.
1579                        if !canceled {
1580                            this.send_task.take();
1581                        }
1582
1583                        // Handle refusal - distinguish between user prompt and tool call refusals
1584                        if let Ok(Ok(acp::PromptResponse {
1585                            stop_reason: acp::StopReason::Refusal,
1586                            meta: _,
1587                        })) = result
1588                        {
1589                            if let Some((user_msg_ix, _)) = this.last_user_message() {
1590                                // Check if there's a completed tool call with results after the last user message
1591                                // This indicates the refusal is in response to tool output, not the user's prompt
1592                                let has_completed_tool_call_after_user_msg =
1593                                    this.entries.iter().skip(user_msg_ix + 1).any(|entry| {
1594                                        if let AgentThreadEntry::ToolCall(tool_call) = entry {
1595                                            // Check if the tool call has completed and has output
1596                                            matches!(tool_call.status, ToolCallStatus::Completed)
1597                                                && tool_call.raw_output.is_some()
1598                                        } else {
1599                                            false
1600                                        }
1601                                    });
1602
1603                                if has_completed_tool_call_after_user_msg {
1604                                    // Refusal is due to tool output - don't truncate, just notify
1605                                    // The model refused based on what the tool returned
1606                                    cx.emit(AcpThreadEvent::Refusal);
1607                                } else {
1608                                    // User prompt was refused - truncate back to before the user message
1609                                    let range = user_msg_ix..this.entries.len();
1610                                    if range.start < range.end {
1611                                        this.entries.truncate(user_msg_ix);
1612                                        cx.emit(AcpThreadEvent::EntriesRemoved(range));
1613                                    }
1614                                    cx.emit(AcpThreadEvent::Refusal);
1615                                }
1616                            } else {
1617                                // No user message found, treat as general refusal
1618                                cx.emit(AcpThreadEvent::Refusal);
1619                            }
1620                        }
1621
1622                        cx.emit(AcpThreadEvent::Stopped);
1623                        Ok(())
1624                    }
1625                }
1626            })?
1627        })
1628        .boxed()
1629    }
1630
1631    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1632        let Some(send_task) = self.send_task.take() else {
1633            return Task::ready(());
1634        };
1635
1636        for entry in self.entries.iter_mut() {
1637            if let AgentThreadEntry::ToolCall(call) = entry {
1638                let cancel = matches!(
1639                    call.status,
1640                    ToolCallStatus::Pending
1641                        | ToolCallStatus::WaitingForConfirmation { .. }
1642                        | ToolCallStatus::InProgress
1643                );
1644
1645                if cancel {
1646                    call.status = ToolCallStatus::Canceled;
1647                }
1648            }
1649        }
1650
1651        self.connection.cancel(&self.session_id, cx);
1652
1653        // Wait for the send task to complete
1654        cx.foreground_executor().spawn(send_task)
1655    }
1656
1657    /// Restores the git working tree to the state at the given checkpoint (if one exists)
1658    pub fn restore_checkpoint(
1659        &mut self,
1660        id: UserMessageId,
1661        cx: &mut Context<Self>,
1662    ) -> Task<Result<()>> {
1663        let Some((_, message)) = self.user_message_mut(&id) else {
1664            return Task::ready(Err(anyhow!("message not found")));
1665        };
1666
1667        let checkpoint = message
1668            .checkpoint
1669            .as_ref()
1670            .map(|c| c.git_checkpoint.clone());
1671        let rewind = self.rewind(id.clone(), cx);
1672        let git_store = self.project.read(cx).git_store().clone();
1673
1674        cx.spawn(async move |_, cx| {
1675            rewind.await?;
1676            if let Some(checkpoint) = checkpoint {
1677                git_store
1678                    .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1679                    .await?;
1680            }
1681
1682            Ok(())
1683        })
1684    }
1685
1686    /// Rewinds this thread to before the entry at `index`, removing it and all
1687    /// subsequent entries while rejecting any action_log changes made from that point.
1688    /// Unlike `restore_checkpoint`, this method does not restore from git.
1689    pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1690        let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
1691            return Task::ready(Err(anyhow!("not supported")));
1692        };
1693
1694        cx.spawn(async move |this, cx| {
1695            cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
1696            this.update(cx, |this, cx| {
1697                if let Some((ix, _)) = this.user_message_mut(&id) {
1698                    let range = ix..this.entries.len();
1699                    this.entries.truncate(ix);
1700                    cx.emit(AcpThreadEvent::EntriesRemoved(range));
1701                }
1702                this.action_log()
1703                    .update(cx, |action_log, cx| action_log.reject_all_edits(cx))
1704            })?
1705            .await;
1706            Ok(())
1707        })
1708    }
1709
1710    fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1711        let git_store = self.project.read(cx).git_store().clone();
1712
1713        let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1714            if let Some(checkpoint) = message.checkpoint.as_ref() {
1715                checkpoint.git_checkpoint.clone()
1716            } else {
1717                return Task::ready(Ok(()));
1718            }
1719        } else {
1720            return Task::ready(Ok(()));
1721        };
1722
1723        let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1724        cx.spawn(async move |this, cx| {
1725            let new_checkpoint = new_checkpoint
1726                .await
1727                .context("failed to get new checkpoint")
1728                .log_err();
1729            if let Some(new_checkpoint) = new_checkpoint {
1730                let equal = git_store
1731                    .update(cx, |git, cx| {
1732                        git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1733                    })?
1734                    .await
1735                    .unwrap_or(true);
1736                this.update(cx, |this, cx| {
1737                    let (ix, message) = this.last_user_message().context("no user message")?;
1738                    let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1739                    checkpoint.show = !equal;
1740                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
1741                    anyhow::Ok(())
1742                })??;
1743            }
1744
1745            Ok(())
1746        })
1747    }
1748
1749    fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1750        self.entries
1751            .iter_mut()
1752            .enumerate()
1753            .rev()
1754            .find_map(|(ix, entry)| {
1755                if let AgentThreadEntry::UserMessage(message) = entry {
1756                    Some((ix, message))
1757                } else {
1758                    None
1759                }
1760            })
1761    }
1762
1763    fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1764        self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1765            if let AgentThreadEntry::UserMessage(message) = entry {
1766                if message.id.as_ref() == Some(id) {
1767                    Some((ix, message))
1768                } else {
1769                    None
1770                }
1771            } else {
1772                None
1773            }
1774        })
1775    }
1776
1777    pub fn read_text_file(
1778        &self,
1779        path: PathBuf,
1780        line: Option<u32>,
1781        limit: Option<u32>,
1782        reuse_shared_snapshot: bool,
1783        cx: &mut Context<Self>,
1784    ) -> Task<Result<String, acp::Error>> {
1785        // Args are 1-based, move to 0-based
1786        let line = line.unwrap_or_default().saturating_sub(1);
1787        let limit = limit.unwrap_or(u32::MAX);
1788        let project = self.project.clone();
1789        let action_log = self.action_log.clone();
1790        cx.spawn(async move |this, cx| {
1791            let load = project
1792                .update(cx, |project, cx| {
1793                    let path = project
1794                        .project_path_for_absolute_path(&path, cx)
1795                        .ok_or_else(|| {
1796                            acp::Error::resource_not_found(Some(path.display().to_string()))
1797                        })?;
1798                    Ok(project.open_buffer(path, cx))
1799                })
1800                .map_err(|e| acp::Error::internal_error().with_data(e.to_string()))
1801                .flatten()?;
1802
1803            let buffer = load.await?;
1804
1805            let snapshot = if reuse_shared_snapshot {
1806                this.read_with(cx, |this, _| {
1807                    this.shared_buffers.get(&buffer.clone()).cloned()
1808                })
1809                .log_err()
1810                .flatten()
1811            } else {
1812                None
1813            };
1814
1815            let snapshot = if let Some(snapshot) = snapshot {
1816                snapshot
1817            } else {
1818                action_log.update(cx, |action_log, cx| {
1819                    action_log.buffer_read(buffer.clone(), cx);
1820                })?;
1821
1822                let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot())?;
1823                this.update(cx, |this, _| {
1824                    this.shared_buffers.insert(buffer.clone(), snapshot.clone());
1825                })?;
1826                snapshot
1827            };
1828
1829            let max_point = snapshot.max_point();
1830            let start_position = Point::new(line, 0);
1831
1832            if start_position > max_point {
1833                return Err(acp::Error::invalid_params().with_data(format!(
1834                    "Attempting to read beyond the end of the file, line {}:{}",
1835                    max_point.row + 1,
1836                    max_point.column
1837                )));
1838            }
1839
1840            let start = snapshot.anchor_before(start_position);
1841            let end = snapshot.anchor_before(Point::new(line.saturating_add(limit), 0));
1842
1843            project.update(cx, |project, cx| {
1844                project.set_agent_location(
1845                    Some(AgentLocation {
1846                        buffer: buffer.downgrade(),
1847                        position: start,
1848                    }),
1849                    cx,
1850                );
1851            })?;
1852
1853            Ok(snapshot.text_for_range(start..end).collect::<String>())
1854        })
1855    }
1856
1857    pub fn write_text_file(
1858        &self,
1859        path: PathBuf,
1860        content: String,
1861        cx: &mut Context<Self>,
1862    ) -> Task<Result<()>> {
1863        let project = self.project.clone();
1864        let action_log = self.action_log.clone();
1865        cx.spawn(async move |this, cx| {
1866            let load = project.update(cx, |project, cx| {
1867                let path = project
1868                    .project_path_for_absolute_path(&path, cx)
1869                    .context("invalid path")?;
1870                anyhow::Ok(project.open_buffer(path, cx))
1871            });
1872            let buffer = load??.await?;
1873            let snapshot = this.update(cx, |this, cx| {
1874                this.shared_buffers
1875                    .get(&buffer)
1876                    .cloned()
1877                    .unwrap_or_else(|| buffer.read(cx).snapshot())
1878            })?;
1879            let edits = cx
1880                .background_executor()
1881                .spawn(async move {
1882                    let old_text = snapshot.text();
1883                    text_diff(old_text.as_str(), &content)
1884                        .into_iter()
1885                        .map(|(range, replacement)| {
1886                            (
1887                                snapshot.anchor_after(range.start)
1888                                    ..snapshot.anchor_before(range.end),
1889                                replacement,
1890                            )
1891                        })
1892                        .collect::<Vec<_>>()
1893                })
1894                .await;
1895
1896            project.update(cx, |project, cx| {
1897                project.set_agent_location(
1898                    Some(AgentLocation {
1899                        buffer: buffer.downgrade(),
1900                        position: edits
1901                            .last()
1902                            .map(|(range, _)| range.end)
1903                            .unwrap_or(Anchor::MIN),
1904                    }),
1905                    cx,
1906                );
1907            })?;
1908
1909            let format_on_save = cx.update(|cx| {
1910                action_log.update(cx, |action_log, cx| {
1911                    action_log.buffer_read(buffer.clone(), cx);
1912                });
1913
1914                let format_on_save = buffer.update(cx, |buffer, cx| {
1915                    buffer.edit(edits, None, cx);
1916
1917                    let settings = language::language_settings::language_settings(
1918                        buffer.language().map(|l| l.name()),
1919                        buffer.file(),
1920                        cx,
1921                    );
1922
1923                    settings.format_on_save != FormatOnSave::Off
1924                });
1925                action_log.update(cx, |action_log, cx| {
1926                    action_log.buffer_edited(buffer.clone(), cx);
1927                });
1928                format_on_save
1929            })?;
1930
1931            if format_on_save {
1932                let format_task = project.update(cx, |project, cx| {
1933                    project.format(
1934                        HashSet::from_iter([buffer.clone()]),
1935                        LspFormatTarget::Buffers,
1936                        false,
1937                        FormatTrigger::Save,
1938                        cx,
1939                    )
1940                })?;
1941                format_task.await.log_err();
1942
1943                action_log.update(cx, |action_log, cx| {
1944                    action_log.buffer_edited(buffer.clone(), cx);
1945                })?;
1946            }
1947
1948            project
1949                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1950                .await
1951        })
1952    }
1953
1954    pub fn create_terminal(
1955        &self,
1956        command: String,
1957        args: Vec<String>,
1958        extra_env: Vec<acp::EnvVariable>,
1959        cwd: Option<PathBuf>,
1960        output_byte_limit: Option<u64>,
1961        cx: &mut Context<Self>,
1962    ) -> Task<Result<Entity<Terminal>>> {
1963        let env = match &cwd {
1964            Some(dir) => self.project.update(cx, |project, cx| {
1965                let shell = TerminalSettings::get_global(cx).shell.clone();
1966                project.directory_environment(&shell, dir.as_path().into(), cx)
1967            }),
1968            None => Task::ready(None).shared(),
1969        };
1970        let env = cx.spawn(async move |_, _| {
1971            let mut env = env.await.unwrap_or_default();
1972            // Disables paging for `git` and hopefully other commands
1973            env.insert("PAGER".into(), "".into());
1974            for var in extra_env {
1975                env.insert(var.name, var.value);
1976            }
1977            env
1978        });
1979
1980        let project = self.project.clone();
1981        let language_registry = project.read(cx).languages().clone();
1982
1983        let terminal_id = acp::TerminalId(Uuid::new_v4().to_string().into());
1984        let terminal_task = cx.spawn({
1985            let terminal_id = terminal_id.clone();
1986            async move |_this, cx| {
1987                let env = env.await;
1988                let (task_command, task_args) = ShellBuilder::new(
1989                    project
1990                        .update(cx, |project, cx| {
1991                            project
1992                                .remote_client()
1993                                .and_then(|r| r.read(cx).default_system_shell())
1994                        })?
1995                        .as_deref(),
1996                    &Shell::Program(get_default_system_shell()),
1997                )
1998                .redirect_stdin_to_dev_null()
1999                .build(Some(command.clone()), &args);
2000                let terminal = project
2001                    .update(cx, |project, cx| {
2002                        project.create_terminal_task(
2003                            task::SpawnInTerminal {
2004                                command: Some(task_command),
2005                                args: task_args,
2006                                cwd: cwd.clone(),
2007                                env,
2008                                ..Default::default()
2009                            },
2010                            cx,
2011                        )
2012                    })?
2013                    .await?;
2014
2015                cx.new(|cx| {
2016                    Terminal::new(
2017                        terminal_id,
2018                        &format!("{} {}", command, args.join(" ")),
2019                        cwd,
2020                        output_byte_limit.map(|l| l as usize),
2021                        terminal,
2022                        language_registry,
2023                        cx,
2024                    )
2025                })
2026            }
2027        });
2028
2029        cx.spawn(async move |this, cx| {
2030            let terminal = terminal_task.await?;
2031            this.update(cx, |this, _cx| {
2032                this.terminals.insert(terminal_id, terminal.clone());
2033                terminal
2034            })
2035        })
2036    }
2037
2038    pub fn kill_terminal(
2039        &mut self,
2040        terminal_id: acp::TerminalId,
2041        cx: &mut Context<Self>,
2042    ) -> Result<()> {
2043        self.terminals
2044            .get(&terminal_id)
2045            .context("Terminal not found")?
2046            .update(cx, |terminal, cx| {
2047                terminal.kill(cx);
2048            });
2049
2050        Ok(())
2051    }
2052
2053    pub fn release_terminal(
2054        &mut self,
2055        terminal_id: acp::TerminalId,
2056        cx: &mut Context<Self>,
2057    ) -> Result<()> {
2058        self.terminals
2059            .remove(&terminal_id)
2060            .context("Terminal not found")?
2061            .update(cx, |terminal, cx| {
2062                terminal.kill(cx);
2063            });
2064
2065        Ok(())
2066    }
2067
2068    pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
2069        self.terminals
2070            .get(&terminal_id)
2071            .context("Terminal not found")
2072            .cloned()
2073    }
2074
2075    pub fn to_markdown(&self, cx: &App) -> String {
2076        self.entries.iter().map(|e| e.to_markdown(cx)).collect()
2077    }
2078
2079    pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
2080        cx.emit(AcpThreadEvent::LoadError(error));
2081    }
2082}
2083
2084fn markdown_for_raw_output(
2085    raw_output: &serde_json::Value,
2086    language_registry: &Arc<LanguageRegistry>,
2087    cx: &mut App,
2088) -> Option<Entity<Markdown>> {
2089    match raw_output {
2090        serde_json::Value::Null => None,
2091        serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2092            Markdown::new(
2093                value.to_string().into(),
2094                Some(language_registry.clone()),
2095                None,
2096                cx,
2097            )
2098        })),
2099        serde_json::Value::Number(value) => Some(cx.new(|cx| {
2100            Markdown::new(
2101                value.to_string().into(),
2102                Some(language_registry.clone()),
2103                None,
2104                cx,
2105            )
2106        })),
2107        serde_json::Value::String(value) => Some(cx.new(|cx| {
2108            Markdown::new(
2109                value.clone().into(),
2110                Some(language_registry.clone()),
2111                None,
2112                cx,
2113            )
2114        })),
2115        value => Some(cx.new(|cx| {
2116            Markdown::new(
2117                format!("```json\n{}\n```", value).into(),
2118                Some(language_registry.clone()),
2119                None,
2120                cx,
2121            )
2122        })),
2123    }
2124}
2125
2126#[cfg(test)]
2127mod tests {
2128    use super::*;
2129    use anyhow::anyhow;
2130    use futures::{channel::mpsc, future::LocalBoxFuture, select};
2131    use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2132    use indoc::indoc;
2133    use project::{FakeFs, Fs};
2134    use rand::{distr, prelude::*};
2135    use serde_json::json;
2136    use settings::SettingsStore;
2137    use smol::stream::StreamExt as _;
2138    use std::{
2139        any::Any,
2140        cell::RefCell,
2141        path::Path,
2142        rc::Rc,
2143        sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2144        time::Duration,
2145    };
2146    use util::path;
2147
2148    fn init_test(cx: &mut TestAppContext) {
2149        env_logger::try_init().ok();
2150        cx.update(|cx| {
2151            let settings_store = SettingsStore::test(cx);
2152            cx.set_global(settings_store);
2153            Project::init_settings(cx);
2154            language::init(cx);
2155        });
2156    }
2157
2158    #[gpui::test]
2159    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
2160        init_test(cx);
2161
2162        let fs = FakeFs::new(cx.executor());
2163        let project = Project::test(fs, [], cx).await;
2164        let connection = Rc::new(FakeAgentConnection::new());
2165        let thread = cx
2166            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2167            .await
2168            .unwrap();
2169
2170        // Test creating a new user message
2171        thread.update(cx, |thread, cx| {
2172            thread.push_user_content_block(
2173                None,
2174                acp::ContentBlock::Text(acp::TextContent {
2175                    annotations: None,
2176                    text: "Hello, ".to_string(),
2177                    meta: None,
2178                }),
2179                cx,
2180            );
2181        });
2182
2183        thread.update(cx, |thread, cx| {
2184            assert_eq!(thread.entries.len(), 1);
2185            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2186                assert_eq!(user_msg.id, None);
2187                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
2188            } else {
2189                panic!("Expected UserMessage");
2190            }
2191        });
2192
2193        // Test appending to existing user message
2194        let message_1_id = UserMessageId::new();
2195        thread.update(cx, |thread, cx| {
2196            thread.push_user_content_block(
2197                Some(message_1_id.clone()),
2198                acp::ContentBlock::Text(acp::TextContent {
2199                    annotations: None,
2200                    text: "world!".to_string(),
2201                    meta: None,
2202                }),
2203                cx,
2204            );
2205        });
2206
2207        thread.update(cx, |thread, cx| {
2208            assert_eq!(thread.entries.len(), 1);
2209            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2210                assert_eq!(user_msg.id, Some(message_1_id));
2211                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
2212            } else {
2213                panic!("Expected UserMessage");
2214            }
2215        });
2216
2217        // Test creating new user message after assistant message
2218        thread.update(cx, |thread, cx| {
2219            thread.push_assistant_content_block(
2220                acp::ContentBlock::Text(acp::TextContent {
2221                    annotations: None,
2222                    text: "Assistant response".to_string(),
2223                    meta: None,
2224                }),
2225                false,
2226                cx,
2227            );
2228        });
2229
2230        let message_2_id = UserMessageId::new();
2231        thread.update(cx, |thread, cx| {
2232            thread.push_user_content_block(
2233                Some(message_2_id.clone()),
2234                acp::ContentBlock::Text(acp::TextContent {
2235                    annotations: None,
2236                    text: "New user message".to_string(),
2237                    meta: None,
2238                }),
2239                cx,
2240            );
2241        });
2242
2243        thread.update(cx, |thread, cx| {
2244            assert_eq!(thread.entries.len(), 3);
2245            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
2246                assert_eq!(user_msg.id, Some(message_2_id));
2247                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
2248            } else {
2249                panic!("Expected UserMessage at index 2");
2250            }
2251        });
2252    }
2253
2254    #[gpui::test]
2255    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
2256        init_test(cx);
2257
2258        let fs = FakeFs::new(cx.executor());
2259        let project = Project::test(fs, [], cx).await;
2260        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2261            |_, thread, mut cx| {
2262                async move {
2263                    thread.update(&mut cx, |thread, cx| {
2264                        thread
2265                            .handle_session_update(
2266                                acp::SessionUpdate::AgentThoughtChunk {
2267                                    content: "Thinking ".into(),
2268                                },
2269                                cx,
2270                            )
2271                            .unwrap();
2272                        thread
2273                            .handle_session_update(
2274                                acp::SessionUpdate::AgentThoughtChunk {
2275                                    content: "hard!".into(),
2276                                },
2277                                cx,
2278                            )
2279                            .unwrap();
2280                    })?;
2281                    Ok(acp::PromptResponse {
2282                        stop_reason: acp::StopReason::EndTurn,
2283                        meta: None,
2284                    })
2285                }
2286                .boxed_local()
2287            },
2288        ));
2289
2290        let thread = cx
2291            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2292            .await
2293            .unwrap();
2294
2295        thread
2296            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
2297            .await
2298            .unwrap();
2299
2300        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
2301        assert_eq!(
2302            output,
2303            indoc! {r#"
2304            ## User
2305
2306            Hello from Zed!
2307
2308            ## Assistant
2309
2310            <thinking>
2311            Thinking hard!
2312            </thinking>
2313
2314            "#}
2315        );
2316    }
2317
2318    #[gpui::test]
2319    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
2320        init_test(cx);
2321
2322        let fs = FakeFs::new(cx.executor());
2323        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
2324            .await;
2325        let project = Project::test(fs.clone(), [], cx).await;
2326        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
2327        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
2328        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2329            move |_, thread, mut cx| {
2330                let read_file_tx = read_file_tx.clone();
2331                async move {
2332                    let content = thread
2333                        .update(&mut cx, |thread, cx| {
2334                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2335                        })
2336                        .unwrap()
2337                        .await
2338                        .unwrap();
2339                    assert_eq!(content, "one\ntwo\nthree\n");
2340                    read_file_tx.take().unwrap().send(()).unwrap();
2341                    thread
2342                        .update(&mut cx, |thread, cx| {
2343                            thread.write_text_file(
2344                                path!("/tmp/foo").into(),
2345                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
2346                                cx,
2347                            )
2348                        })
2349                        .unwrap()
2350                        .await
2351                        .unwrap();
2352                    Ok(acp::PromptResponse {
2353                        stop_reason: acp::StopReason::EndTurn,
2354                        meta: None,
2355                    })
2356                }
2357                .boxed_local()
2358            },
2359        ));
2360
2361        let (worktree, pathbuf) = project
2362            .update(cx, |project, cx| {
2363                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2364            })
2365            .await
2366            .unwrap();
2367        let buffer = project
2368            .update(cx, |project, cx| {
2369                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
2370            })
2371            .await
2372            .unwrap();
2373
2374        let thread = cx
2375            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2376            .await
2377            .unwrap();
2378
2379        let request = thread.update(cx, |thread, cx| {
2380            thread.send_raw("Extend the count in /tmp/foo", cx)
2381        });
2382        read_file_rx.await.ok();
2383        buffer.update(cx, |buffer, cx| {
2384            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
2385        });
2386        cx.run_until_parked();
2387        assert_eq!(
2388            buffer.read_with(cx, |buffer, _| buffer.text()),
2389            "zero\none\ntwo\nthree\nfour\nfive\n"
2390        );
2391        assert_eq!(
2392            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
2393            "zero\none\ntwo\nthree\nfour\nfive\n"
2394        );
2395        request.await.unwrap();
2396    }
2397
2398    #[gpui::test]
2399    async fn test_reading_from_line(cx: &mut TestAppContext) {
2400        init_test(cx);
2401
2402        let fs = FakeFs::new(cx.executor());
2403        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\nfour\n"}))
2404            .await;
2405        let project = Project::test(fs.clone(), [], cx).await;
2406        project
2407            .update(cx, |project, cx| {
2408                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2409            })
2410            .await
2411            .unwrap();
2412
2413        let connection = Rc::new(FakeAgentConnection::new());
2414
2415        let thread = cx
2416            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2417            .await
2418            .unwrap();
2419
2420        // Whole file
2421        let content = thread
2422            .update(cx, |thread, cx| {
2423                thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2424            })
2425            .await
2426            .unwrap();
2427
2428        assert_eq!(content, "one\ntwo\nthree\nfour\n");
2429
2430        // Only start line
2431        let content = thread
2432            .update(cx, |thread, cx| {
2433                thread.read_text_file(path!("/tmp/foo").into(), Some(3), None, false, cx)
2434            })
2435            .await
2436            .unwrap();
2437
2438        assert_eq!(content, "three\nfour\n");
2439
2440        // Only limit
2441        let content = thread
2442            .update(cx, |thread, cx| {
2443                thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
2444            })
2445            .await
2446            .unwrap();
2447
2448        assert_eq!(content, "one\ntwo\n");
2449
2450        // Range
2451        let content = thread
2452            .update(cx, |thread, cx| {
2453                thread.read_text_file(path!("/tmp/foo").into(), Some(2), Some(2), false, cx)
2454            })
2455            .await
2456            .unwrap();
2457
2458        assert_eq!(content, "two\nthree\n");
2459
2460        // Invalid
2461        let err = thread
2462            .update(cx, |thread, cx| {
2463                thread.read_text_file(path!("/tmp/foo").into(), Some(6), Some(2), false, cx)
2464            })
2465            .await
2466            .unwrap_err();
2467
2468        assert_eq!(
2469            err.to_string(),
2470            "Invalid params: \"Attempting to read beyond the end of the file, line 5:0\""
2471        );
2472    }
2473
2474    #[gpui::test]
2475    async fn test_reading_empty_file(cx: &mut TestAppContext) {
2476        init_test(cx);
2477
2478        let fs = FakeFs::new(cx.executor());
2479        fs.insert_tree(path!("/tmp"), json!({"foo": ""})).await;
2480        let project = Project::test(fs.clone(), [], cx).await;
2481        project
2482            .update(cx, |project, cx| {
2483                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2484            })
2485            .await
2486            .unwrap();
2487
2488        let connection = Rc::new(FakeAgentConnection::new());
2489
2490        let thread = cx
2491            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2492            .await
2493            .unwrap();
2494
2495        // Whole file
2496        let content = thread
2497            .update(cx, |thread, cx| {
2498                thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2499            })
2500            .await
2501            .unwrap();
2502
2503        assert_eq!(content, "");
2504
2505        // Only start line
2506        let content = thread
2507            .update(cx, |thread, cx| {
2508                thread.read_text_file(path!("/tmp/foo").into(), Some(1), None, false, cx)
2509            })
2510            .await
2511            .unwrap();
2512
2513        assert_eq!(content, "");
2514
2515        // Only limit
2516        let content = thread
2517            .update(cx, |thread, cx| {
2518                thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
2519            })
2520            .await
2521            .unwrap();
2522
2523        assert_eq!(content, "");
2524
2525        // Range
2526        let content = thread
2527            .update(cx, |thread, cx| {
2528                thread.read_text_file(path!("/tmp/foo").into(), Some(1), Some(1), false, cx)
2529            })
2530            .await
2531            .unwrap();
2532
2533        assert_eq!(content, "");
2534
2535        // Invalid
2536        let err = thread
2537            .update(cx, |thread, cx| {
2538                thread.read_text_file(path!("/tmp/foo").into(), Some(5), Some(2), false, cx)
2539            })
2540            .await
2541            .unwrap_err();
2542
2543        assert_eq!(
2544            err.to_string(),
2545            "Invalid params: \"Attempting to read beyond the end of the file, line 1:0\""
2546        );
2547    }
2548    #[gpui::test]
2549    async fn test_reading_non_existing_file(cx: &mut TestAppContext) {
2550        init_test(cx);
2551
2552        let fs = FakeFs::new(cx.executor());
2553        fs.insert_tree(path!("/tmp"), json!({})).await;
2554        let project = Project::test(fs.clone(), [], cx).await;
2555        project
2556            .update(cx, |project, cx| {
2557                project.find_or_create_worktree(path!("/tmp"), true, cx)
2558            })
2559            .await
2560            .unwrap();
2561
2562        let connection = Rc::new(FakeAgentConnection::new());
2563
2564        let thread = cx
2565            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2566            .await
2567            .unwrap();
2568
2569        // Out of project file
2570        let err = thread
2571            .update(cx, |thread, cx| {
2572                thread.read_text_file(path!("/foo").into(), None, None, false, cx)
2573            })
2574            .await
2575            .unwrap_err();
2576
2577        assert_eq!(err.code, acp::ErrorCode::RESOURCE_NOT_FOUND.code);
2578    }
2579
2580    #[gpui::test]
2581    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
2582        init_test(cx);
2583
2584        let fs = FakeFs::new(cx.executor());
2585        let project = Project::test(fs, [], cx).await;
2586        let id = acp::ToolCallId("test".into());
2587
2588        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2589            let id = id.clone();
2590            move |_, thread, mut cx| {
2591                let id = id.clone();
2592                async move {
2593                    thread
2594                        .update(&mut cx, |thread, cx| {
2595                            thread.handle_session_update(
2596                                acp::SessionUpdate::ToolCall(acp::ToolCall {
2597                                    id: id.clone(),
2598                                    title: "Label".into(),
2599                                    kind: acp::ToolKind::Fetch,
2600                                    status: acp::ToolCallStatus::InProgress,
2601                                    content: vec![],
2602                                    locations: vec![],
2603                                    raw_input: None,
2604                                    raw_output: None,
2605                                    meta: None,
2606                                }),
2607                                cx,
2608                            )
2609                        })
2610                        .unwrap()
2611                        .unwrap();
2612                    Ok(acp::PromptResponse {
2613                        stop_reason: acp::StopReason::EndTurn,
2614                        meta: None,
2615                    })
2616                }
2617                .boxed_local()
2618            }
2619        }));
2620
2621        let thread = cx
2622            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2623            .await
2624            .unwrap();
2625
2626        let request = thread.update(cx, |thread, cx| {
2627            thread.send_raw("Fetch https://example.com", cx)
2628        });
2629
2630        run_until_first_tool_call(&thread, cx).await;
2631
2632        thread.read_with(cx, |thread, _| {
2633            assert!(matches!(
2634                thread.entries[1],
2635                AgentThreadEntry::ToolCall(ToolCall {
2636                    status: ToolCallStatus::InProgress,
2637                    ..
2638                })
2639            ));
2640        });
2641
2642        thread.update(cx, |thread, cx| thread.cancel(cx)).await;
2643
2644        thread.read_with(cx, |thread, _| {
2645            assert!(matches!(
2646                &thread.entries[1],
2647                AgentThreadEntry::ToolCall(ToolCall {
2648                    status: ToolCallStatus::Canceled,
2649                    ..
2650                })
2651            ));
2652        });
2653
2654        thread
2655            .update(cx, |thread, cx| {
2656                thread.handle_session_update(
2657                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
2658                        id,
2659                        fields: acp::ToolCallUpdateFields {
2660                            status: Some(acp::ToolCallStatus::Completed),
2661                            ..Default::default()
2662                        },
2663                        meta: None,
2664                    }),
2665                    cx,
2666                )
2667            })
2668            .unwrap();
2669
2670        request.await.unwrap();
2671
2672        thread.read_with(cx, |thread, _| {
2673            assert!(matches!(
2674                thread.entries[1],
2675                AgentThreadEntry::ToolCall(ToolCall {
2676                    status: ToolCallStatus::Completed,
2677                    ..
2678                })
2679            ));
2680        });
2681    }
2682
2683    #[gpui::test]
2684    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
2685        init_test(cx);
2686        let fs = FakeFs::new(cx.background_executor.clone());
2687        fs.insert_tree(path!("/test"), json!({})).await;
2688        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
2689
2690        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2691            move |_, thread, mut cx| {
2692                async move {
2693                    thread
2694                        .update(&mut cx, |thread, cx| {
2695                            thread.handle_session_update(
2696                                acp::SessionUpdate::ToolCall(acp::ToolCall {
2697                                    id: acp::ToolCallId("test".into()),
2698                                    title: "Label".into(),
2699                                    kind: acp::ToolKind::Edit,
2700                                    status: acp::ToolCallStatus::Completed,
2701                                    content: vec![acp::ToolCallContent::Diff {
2702                                        diff: acp::Diff {
2703                                            path: "/test/test.txt".into(),
2704                                            old_text: None,
2705                                            new_text: "foo".into(),
2706                                            meta: None,
2707                                        },
2708                                    }],
2709                                    locations: vec![],
2710                                    raw_input: None,
2711                                    raw_output: None,
2712                                    meta: None,
2713                                }),
2714                                cx,
2715                            )
2716                        })
2717                        .unwrap()
2718                        .unwrap();
2719                    Ok(acp::PromptResponse {
2720                        stop_reason: acp::StopReason::EndTurn,
2721                        meta: None,
2722                    })
2723                }
2724                .boxed_local()
2725            }
2726        }));
2727
2728        let thread = cx
2729            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2730            .await
2731            .unwrap();
2732
2733        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
2734            .await
2735            .unwrap();
2736
2737        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
2738    }
2739
2740    #[gpui::test(iterations = 10)]
2741    async fn test_checkpoints(cx: &mut TestAppContext) {
2742        init_test(cx);
2743        let fs = FakeFs::new(cx.background_executor.clone());
2744        fs.insert_tree(
2745            path!("/test"),
2746            json!({
2747                ".git": {}
2748            }),
2749        )
2750        .await;
2751        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
2752
2753        let simulate_changes = Arc::new(AtomicBool::new(true));
2754        let next_filename = Arc::new(AtomicUsize::new(0));
2755        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2756            let simulate_changes = simulate_changes.clone();
2757            let next_filename = next_filename.clone();
2758            let fs = fs.clone();
2759            move |request, thread, mut cx| {
2760                let fs = fs.clone();
2761                let simulate_changes = simulate_changes.clone();
2762                let next_filename = next_filename.clone();
2763                async move {
2764                    if simulate_changes.load(SeqCst) {
2765                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
2766                        fs.write(Path::new(&filename), b"").await?;
2767                    }
2768
2769                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2770                        panic!("expected text content block");
2771                    };
2772                    thread.update(&mut cx, |thread, cx| {
2773                        thread
2774                            .handle_session_update(
2775                                acp::SessionUpdate::AgentMessageChunk {
2776                                    content: content.text.to_uppercase().into(),
2777                                },
2778                                cx,
2779                            )
2780                            .unwrap();
2781                    })?;
2782                    Ok(acp::PromptResponse {
2783                        stop_reason: acp::StopReason::EndTurn,
2784                        meta: None,
2785                    })
2786                }
2787                .boxed_local()
2788            }
2789        }));
2790        let thread = cx
2791            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2792            .await
2793            .unwrap();
2794
2795        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
2796            .await
2797            .unwrap();
2798        thread.read_with(cx, |thread, cx| {
2799            assert_eq!(
2800                thread.to_markdown(cx),
2801                indoc! {"
2802                    ## User (checkpoint)
2803
2804                    Lorem
2805
2806                    ## Assistant
2807
2808                    LOREM
2809
2810                "}
2811            );
2812        });
2813        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2814
2815        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
2816            .await
2817            .unwrap();
2818        thread.read_with(cx, |thread, cx| {
2819            assert_eq!(
2820                thread.to_markdown(cx),
2821                indoc! {"
2822                    ## User (checkpoint)
2823
2824                    Lorem
2825
2826                    ## Assistant
2827
2828                    LOREM
2829
2830                    ## User (checkpoint)
2831
2832                    ipsum
2833
2834                    ## Assistant
2835
2836                    IPSUM
2837
2838                "}
2839            );
2840        });
2841        assert_eq!(
2842            fs.files(),
2843            vec![
2844                Path::new(path!("/test/file-0")),
2845                Path::new(path!("/test/file-1"))
2846            ]
2847        );
2848
2849        // Checkpoint isn't stored when there are no changes.
2850        simulate_changes.store(false, SeqCst);
2851        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
2852            .await
2853            .unwrap();
2854        thread.read_with(cx, |thread, cx| {
2855            assert_eq!(
2856                thread.to_markdown(cx),
2857                indoc! {"
2858                    ## User (checkpoint)
2859
2860                    Lorem
2861
2862                    ## Assistant
2863
2864                    LOREM
2865
2866                    ## User (checkpoint)
2867
2868                    ipsum
2869
2870                    ## Assistant
2871
2872                    IPSUM
2873
2874                    ## User
2875
2876                    dolor
2877
2878                    ## Assistant
2879
2880                    DOLOR
2881
2882                "}
2883            );
2884        });
2885        assert_eq!(
2886            fs.files(),
2887            vec![
2888                Path::new(path!("/test/file-0")),
2889                Path::new(path!("/test/file-1"))
2890            ]
2891        );
2892
2893        // Rewinding the conversation truncates the history and restores the checkpoint.
2894        thread
2895            .update(cx, |thread, cx| {
2896                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
2897                    panic!("unexpected entries {:?}", thread.entries)
2898                };
2899                thread.restore_checkpoint(message.id.clone().unwrap(), cx)
2900            })
2901            .await
2902            .unwrap();
2903        thread.read_with(cx, |thread, cx| {
2904            assert_eq!(
2905                thread.to_markdown(cx),
2906                indoc! {"
2907                    ## User (checkpoint)
2908
2909                    Lorem
2910
2911                    ## Assistant
2912
2913                    LOREM
2914
2915                "}
2916            );
2917        });
2918        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2919    }
2920
2921    #[gpui::test]
2922    async fn test_tool_result_refusal(cx: &mut TestAppContext) {
2923        use std::sync::atomic::AtomicUsize;
2924        init_test(cx);
2925
2926        let fs = FakeFs::new(cx.executor());
2927        let project = Project::test(fs, None, cx).await;
2928
2929        // Create a connection that simulates refusal after tool result
2930        let prompt_count = Arc::new(AtomicUsize::new(0));
2931        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2932            let prompt_count = prompt_count.clone();
2933            move |_request, thread, mut cx| {
2934                let count = prompt_count.fetch_add(1, SeqCst);
2935                async move {
2936                    if count == 0 {
2937                        // First prompt: Generate a tool call with result
2938                        thread.update(&mut cx, |thread, cx| {
2939                            thread
2940                                .handle_session_update(
2941                                    acp::SessionUpdate::ToolCall(acp::ToolCall {
2942                                        id: acp::ToolCallId("tool1".into()),
2943                                        title: "Test Tool".into(),
2944                                        kind: acp::ToolKind::Fetch,
2945                                        status: acp::ToolCallStatus::Completed,
2946                                        content: vec![],
2947                                        locations: vec![],
2948                                        raw_input: Some(serde_json::json!({"query": "test"})),
2949                                        raw_output: Some(
2950                                            serde_json::json!({"result": "inappropriate content"}),
2951                                        ),
2952                                        meta: None,
2953                                    }),
2954                                    cx,
2955                                )
2956                                .unwrap();
2957                        })?;
2958
2959                        // Now return refusal because of the tool result
2960                        Ok(acp::PromptResponse {
2961                            stop_reason: acp::StopReason::Refusal,
2962                            meta: None,
2963                        })
2964                    } else {
2965                        Ok(acp::PromptResponse {
2966                            stop_reason: acp::StopReason::EndTurn,
2967                            meta: None,
2968                        })
2969                    }
2970                }
2971                .boxed_local()
2972            }
2973        }));
2974
2975        let thread = cx
2976            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2977            .await
2978            .unwrap();
2979
2980        // Track if we see a Refusal event
2981        let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
2982        let saw_refusal_event_captured = saw_refusal_event.clone();
2983        thread.update(cx, |_thread, cx| {
2984            cx.subscribe(
2985                &thread,
2986                move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
2987                    if matches!(event, AcpThreadEvent::Refusal) {
2988                        *saw_refusal_event_captured.lock().unwrap() = true;
2989                    }
2990                },
2991            )
2992            .detach();
2993        });
2994
2995        // Send a user message - this will trigger tool call and then refusal
2996        let send_task = thread.update(cx, |thread, cx| {
2997            thread.send(
2998                vec![acp::ContentBlock::Text(acp::TextContent {
2999                    text: "Hello".into(),
3000                    annotations: None,
3001                    meta: None,
3002                })],
3003                cx,
3004            )
3005        });
3006        cx.background_executor.spawn(send_task).detach();
3007        cx.run_until_parked();
3008
3009        // Verify that:
3010        // 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt)
3011        // 2. The user message was NOT truncated
3012        assert!(
3013            *saw_refusal_event.lock().unwrap(),
3014            "Refusal event should be emitted for tool result refusals"
3015        );
3016
3017        thread.read_with(cx, |thread, _| {
3018            let entries = thread.entries();
3019            assert!(entries.len() >= 2, "Should have user message and tool call");
3020
3021            // Verify user message is still there
3022            assert!(
3023                matches!(entries[0], AgentThreadEntry::UserMessage(_)),
3024                "User message should not be truncated"
3025            );
3026
3027            // Verify tool call is there with result
3028            if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] {
3029                assert!(
3030                    tool_call.raw_output.is_some(),
3031                    "Tool call should have output"
3032                );
3033            } else {
3034                panic!("Expected tool call at index 1");
3035            }
3036        });
3037    }
3038
3039    #[gpui::test]
3040    async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) {
3041        init_test(cx);
3042
3043        let fs = FakeFs::new(cx.executor());
3044        let project = Project::test(fs, None, cx).await;
3045
3046        let refuse_next = Arc::new(AtomicBool::new(false));
3047        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3048            let refuse_next = refuse_next.clone();
3049            move |_request, _thread, _cx| {
3050                if refuse_next.load(SeqCst) {
3051                    async move {
3052                        Ok(acp::PromptResponse {
3053                            stop_reason: acp::StopReason::Refusal,
3054                            meta: None,
3055                        })
3056                    }
3057                    .boxed_local()
3058                } else {
3059                    async move {
3060                        Ok(acp::PromptResponse {
3061                            stop_reason: acp::StopReason::EndTurn,
3062                            meta: None,
3063                        })
3064                    }
3065                    .boxed_local()
3066                }
3067            }
3068        }));
3069
3070        let thread = cx
3071            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3072            .await
3073            .unwrap();
3074
3075        // Track if we see a Refusal event
3076        let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3077        let saw_refusal_event_captured = saw_refusal_event.clone();
3078        thread.update(cx, |_thread, cx| {
3079            cx.subscribe(
3080                &thread,
3081                move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3082                    if matches!(event, AcpThreadEvent::Refusal) {
3083                        *saw_refusal_event_captured.lock().unwrap() = true;
3084                    }
3085                },
3086            )
3087            .detach();
3088        });
3089
3090        // Send a message that will be refused
3091        refuse_next.store(true, SeqCst);
3092        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3093            .await
3094            .unwrap();
3095
3096        // Verify that a Refusal event WAS emitted for user prompt refusal
3097        assert!(
3098            *saw_refusal_event.lock().unwrap(),
3099            "Refusal event should be emitted for user prompt refusals"
3100        );
3101
3102        // Verify the message was truncated (user prompt refusal)
3103        thread.read_with(cx, |thread, cx| {
3104            assert_eq!(thread.to_markdown(cx), "");
3105        });
3106    }
3107
3108    #[gpui::test]
3109    async fn test_refusal(cx: &mut TestAppContext) {
3110        init_test(cx);
3111        let fs = FakeFs::new(cx.background_executor.clone());
3112        fs.insert_tree(path!("/"), json!({})).await;
3113        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
3114
3115        let refuse_next = Arc::new(AtomicBool::new(false));
3116        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3117            let refuse_next = refuse_next.clone();
3118            move |request, thread, mut cx| {
3119                let refuse_next = refuse_next.clone();
3120                async move {
3121                    if refuse_next.load(SeqCst) {
3122                        return Ok(acp::PromptResponse {
3123                            stop_reason: acp::StopReason::Refusal,
3124                            meta: None,
3125                        });
3126                    }
3127
3128                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
3129                        panic!("expected text content block");
3130                    };
3131                    thread.update(&mut cx, |thread, cx| {
3132                        thread
3133                            .handle_session_update(
3134                                acp::SessionUpdate::AgentMessageChunk {
3135                                    content: content.text.to_uppercase().into(),
3136                                },
3137                                cx,
3138                            )
3139                            .unwrap();
3140                    })?;
3141                    Ok(acp::PromptResponse {
3142                        stop_reason: acp::StopReason::EndTurn,
3143                        meta: None,
3144                    })
3145                }
3146                .boxed_local()
3147            }
3148        }));
3149        let thread = cx
3150            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3151            .await
3152            .unwrap();
3153
3154        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3155            .await
3156            .unwrap();
3157        thread.read_with(cx, |thread, cx| {
3158            assert_eq!(
3159                thread.to_markdown(cx),
3160                indoc! {"
3161                    ## User
3162
3163                    hello
3164
3165                    ## Assistant
3166
3167                    HELLO
3168
3169                "}
3170            );
3171        });
3172
3173        // Simulate refusing the second message. The message should be truncated
3174        // when a user prompt is refused.
3175        refuse_next.store(true, SeqCst);
3176        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
3177            .await
3178            .unwrap();
3179        thread.read_with(cx, |thread, cx| {
3180            assert_eq!(
3181                thread.to_markdown(cx),
3182                indoc! {"
3183                    ## User
3184
3185                    hello
3186
3187                    ## Assistant
3188
3189                    HELLO
3190
3191                "}
3192            );
3193        });
3194    }
3195
3196    async fn run_until_first_tool_call(
3197        thread: &Entity<AcpThread>,
3198        cx: &mut TestAppContext,
3199    ) -> usize {
3200        let (mut tx, mut rx) = mpsc::channel::<usize>(1);
3201
3202        let subscription = cx.update(|cx| {
3203            cx.subscribe(thread, move |thread, _, cx| {
3204                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
3205                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
3206                        return tx.try_send(ix).unwrap();
3207                    }
3208                }
3209            })
3210        });
3211
3212        select! {
3213            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
3214                panic!("Timeout waiting for tool call")
3215            }
3216            ix = rx.next().fuse() => {
3217                drop(subscription);
3218                ix.unwrap()
3219            }
3220        }
3221    }
3222
3223    #[derive(Clone, Default)]
3224    struct FakeAgentConnection {
3225        auth_methods: Vec<acp::AuthMethod>,
3226        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
3227        on_user_message: Option<
3228            Rc<
3229                dyn Fn(
3230                        acp::PromptRequest,
3231                        WeakEntity<AcpThread>,
3232                        AsyncApp,
3233                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3234                    + 'static,
3235            >,
3236        >,
3237    }
3238
3239    impl FakeAgentConnection {
3240        fn new() -> Self {
3241            Self {
3242                auth_methods: Vec::new(),
3243                on_user_message: None,
3244                sessions: Arc::default(),
3245            }
3246        }
3247
3248        #[expect(unused)]
3249        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
3250            self.auth_methods = auth_methods;
3251            self
3252        }
3253
3254        fn on_user_message(
3255            mut self,
3256            handler: impl Fn(
3257                acp::PromptRequest,
3258                WeakEntity<AcpThread>,
3259                AsyncApp,
3260            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3261            + 'static,
3262        ) -> Self {
3263            self.on_user_message.replace(Rc::new(handler));
3264            self
3265        }
3266    }
3267
3268    impl AgentConnection for FakeAgentConnection {
3269        fn auth_methods(&self) -> &[acp::AuthMethod] {
3270            &self.auth_methods
3271        }
3272
3273        fn new_thread(
3274            self: Rc<Self>,
3275            project: Entity<Project>,
3276            _cwd: &Path,
3277            cx: &mut App,
3278        ) -> Task<gpui::Result<Entity<AcpThread>>> {
3279            let session_id = acp::SessionId(
3280                rand::rng()
3281                    .sample_iter(&distr::Alphanumeric)
3282                    .take(7)
3283                    .map(char::from)
3284                    .collect::<String>()
3285                    .into(),
3286            );
3287            let action_log = cx.new(|_| ActionLog::new(project.clone()));
3288            let thread = cx.new(|cx| {
3289                AcpThread::new(
3290                    "Test",
3291                    self.clone(),
3292                    project,
3293                    action_log,
3294                    session_id.clone(),
3295                    watch::Receiver::constant(acp::PromptCapabilities {
3296                        image: true,
3297                        audio: true,
3298                        embedded_context: true,
3299                        meta: None,
3300                    }),
3301                    cx,
3302                )
3303            });
3304            self.sessions.lock().insert(session_id, thread.downgrade());
3305            Task::ready(Ok(thread))
3306        }
3307
3308        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
3309            if self.auth_methods().iter().any(|m| m.id == method) {
3310                Task::ready(Ok(()))
3311            } else {
3312                Task::ready(Err(anyhow!("Invalid Auth Method")))
3313            }
3314        }
3315
3316        fn prompt(
3317            &self,
3318            _id: Option<UserMessageId>,
3319            params: acp::PromptRequest,
3320            cx: &mut App,
3321        ) -> Task<gpui::Result<acp::PromptResponse>> {
3322            let sessions = self.sessions.lock();
3323            let thread = sessions.get(&params.session_id).unwrap();
3324            if let Some(handler) = &self.on_user_message {
3325                let handler = handler.clone();
3326                let thread = thread.clone();
3327                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
3328            } else {
3329                Task::ready(Ok(acp::PromptResponse {
3330                    stop_reason: acp::StopReason::EndTurn,
3331                    meta: None,
3332                }))
3333            }
3334        }
3335
3336        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
3337            let sessions = self.sessions.lock();
3338            let thread = sessions.get(session_id).unwrap().clone();
3339
3340            cx.spawn(async move |cx| {
3341                thread
3342                    .update(cx, |thread, cx| thread.cancel(cx))
3343                    .unwrap()
3344                    .await
3345            })
3346            .detach();
3347        }
3348
3349        fn truncate(
3350            &self,
3351            session_id: &acp::SessionId,
3352            _cx: &App,
3353        ) -> Option<Rc<dyn AgentSessionTruncate>> {
3354            Some(Rc::new(FakeAgentSessionEditor {
3355                _session_id: session_id.clone(),
3356            }))
3357        }
3358
3359        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
3360            self
3361        }
3362    }
3363
3364    struct FakeAgentSessionEditor {
3365        _session_id: acp::SessionId,
3366    }
3367
3368    impl AgentSessionTruncate for FakeAgentSessionEditor {
3369        fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
3370            Task::ready(Ok(()))
3371        }
3372    }
3373
3374    #[gpui::test]
3375    async fn test_tool_call_not_found_creates_failed_entry(cx: &mut TestAppContext) {
3376        init_test(cx);
3377
3378        let fs = FakeFs::new(cx.executor());
3379        let project = Project::test(fs, [], cx).await;
3380        let connection = Rc::new(FakeAgentConnection::new());
3381        let thread = cx
3382            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3383            .await
3384            .unwrap();
3385
3386        // Try to update a tool call that doesn't exist
3387        let nonexistent_id = acp::ToolCallId("nonexistent-tool-call".into());
3388        thread.update(cx, |thread, cx| {
3389            let result = thread.handle_session_update(
3390                acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
3391                    id: nonexistent_id.clone(),
3392                    fields: acp::ToolCallUpdateFields {
3393                        status: Some(acp::ToolCallStatus::Completed),
3394                        ..Default::default()
3395                    },
3396                    meta: None,
3397                }),
3398                cx,
3399            );
3400
3401            // The update should succeed (not return an error)
3402            assert!(result.is_ok());
3403
3404            // There should now be exactly one entry in the thread
3405            assert_eq!(thread.entries.len(), 1);
3406
3407            // The entry should be a failed tool call
3408            if let AgentThreadEntry::ToolCall(tool_call) = &thread.entries[0] {
3409                assert_eq!(tool_call.id, nonexistent_id);
3410                assert!(matches!(tool_call.status, ToolCallStatus::Failed));
3411                assert_eq!(tool_call.kind, acp::ToolKind::Fetch);
3412
3413                // Check that the content contains the error message
3414                assert_eq!(tool_call.content.len(), 1);
3415                if let ToolCallContent::ContentBlock(content_block) = &tool_call.content[0] {
3416                    match content_block {
3417                        ContentBlock::Markdown { markdown } => {
3418                            let markdown_text = markdown.read(cx).source();
3419                            assert!(markdown_text.contains("Tool call not found"));
3420                        }
3421                        ContentBlock::Empty => panic!("Expected markdown content, got empty"),
3422                        ContentBlock::ResourceLink { .. } => {
3423                            panic!("Expected markdown content, got resource link")
3424                        }
3425                    }
3426                } else {
3427                    panic!("Expected ContentBlock, got: {:?}", tool_call.content[0]);
3428                }
3429            } else {
3430                panic!("Expected ToolCall entry, got: {:?}", thread.entries[0]);
3431            }
3432        });
3433    }
3434}