acp_thread.rs

   1mod connection;
   2mod diff;
   3mod mention;
   4mod terminal;
   5
   6use agent_settings::AgentSettings;
   7use collections::HashSet;
   8pub use connection::*;
   9pub use diff::*;
  10use language::language_settings::FormatOnSave;
  11pub use mention::*;
  12use project::lsp_store::{FormatTrigger, LspFormatTarget};
  13use serde::{Deserialize, Serialize};
  14use settings::Settings as _;
  15use task::{Shell, ShellBuilder};
  16pub use terminal::*;
  17
  18use action_log::ActionLog;
  19use agent_client_protocol::{self as acp};
  20use anyhow::{Context as _, Result, anyhow};
  21use editor::Bias;
  22use futures::{FutureExt, channel::oneshot, future::BoxFuture};
  23use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  24use itertools::Itertools;
  25use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
  26use markdown::Markdown;
  27use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
  28use std::collections::HashMap;
  29use std::error::Error;
  30use std::fmt::{Formatter, Write};
  31use std::ops::Range;
  32use std::process::ExitStatus;
  33use std::rc::Rc;
  34use std::time::{Duration, Instant};
  35use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
  36use ui::App;
  37use util::{ResultExt, get_default_system_shell};
  38use uuid::Uuid;
  39
  40#[derive(Debug)]
  41pub struct UserMessage {
  42    pub id: Option<UserMessageId>,
  43    pub content: ContentBlock,
  44    pub chunks: Vec<acp::ContentBlock>,
  45    pub checkpoint: Option<Checkpoint>,
  46}
  47
  48#[derive(Debug)]
  49pub struct Checkpoint {
  50    git_checkpoint: GitStoreCheckpoint,
  51    pub show: bool,
  52}
  53
  54impl UserMessage {
  55    fn to_markdown(&self, cx: &App) -> String {
  56        let mut markdown = String::new();
  57        if self
  58            .checkpoint
  59            .as_ref()
  60            .is_some_and(|checkpoint| checkpoint.show)
  61        {
  62            writeln!(markdown, "## User (checkpoint)").unwrap();
  63        } else {
  64            writeln!(markdown, "## User").unwrap();
  65        }
  66        writeln!(markdown).unwrap();
  67        writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
  68        writeln!(markdown).unwrap();
  69        markdown
  70    }
  71}
  72
  73#[derive(Debug, PartialEq)]
  74pub struct AssistantMessage {
  75    pub chunks: Vec<AssistantMessageChunk>,
  76}
  77
  78impl AssistantMessage {
  79    pub fn to_markdown(&self, cx: &App) -> String {
  80        format!(
  81            "## Assistant\n\n{}\n\n",
  82            self.chunks
  83                .iter()
  84                .map(|chunk| chunk.to_markdown(cx))
  85                .join("\n\n")
  86        )
  87    }
  88}
  89
  90#[derive(Debug, PartialEq)]
  91pub enum AssistantMessageChunk {
  92    Message { block: ContentBlock },
  93    Thought { block: ContentBlock },
  94}
  95
  96impl AssistantMessageChunk {
  97    pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
  98        Self::Message {
  99            block: ContentBlock::new(chunk.into(), language_registry, cx),
 100        }
 101    }
 102
 103    fn to_markdown(&self, cx: &App) -> String {
 104        match self {
 105            Self::Message { block } => block.to_markdown(cx).to_string(),
 106            Self::Thought { block } => {
 107                format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
 108            }
 109        }
 110    }
 111}
 112
 113#[derive(Debug)]
 114pub enum AgentThreadEntry {
 115    UserMessage(UserMessage),
 116    AssistantMessage(AssistantMessage),
 117    ToolCall(ToolCall),
 118}
 119
 120impl AgentThreadEntry {
 121    pub fn to_markdown(&self, cx: &App) -> String {
 122        match self {
 123            Self::UserMessage(message) => message.to_markdown(cx),
 124            Self::AssistantMessage(message) => message.to_markdown(cx),
 125            Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
 126        }
 127    }
 128
 129    pub fn user_message(&self) -> Option<&UserMessage> {
 130        if let AgentThreadEntry::UserMessage(message) = self {
 131            Some(message)
 132        } else {
 133            None
 134        }
 135    }
 136
 137    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 138        if let AgentThreadEntry::ToolCall(call) = self {
 139            itertools::Either::Left(call.diffs())
 140        } else {
 141            itertools::Either::Right(std::iter::empty())
 142        }
 143    }
 144
 145    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 146        if let AgentThreadEntry::ToolCall(call) = self {
 147            itertools::Either::Left(call.terminals())
 148        } else {
 149            itertools::Either::Right(std::iter::empty())
 150        }
 151    }
 152
 153    pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
 154        if let AgentThreadEntry::ToolCall(ToolCall {
 155            locations,
 156            resolved_locations,
 157            ..
 158        }) = self
 159        {
 160            Some((
 161                locations.get(ix)?.clone(),
 162                resolved_locations.get(ix)?.clone()?,
 163            ))
 164        } else {
 165            None
 166        }
 167    }
 168}
 169
 170#[derive(Debug)]
 171pub struct ToolCall {
 172    pub id: acp::ToolCallId,
 173    pub label: Entity<Markdown>,
 174    pub kind: acp::ToolKind,
 175    pub content: Vec<ToolCallContent>,
 176    pub status: ToolCallStatus,
 177    pub locations: Vec<acp::ToolCallLocation>,
 178    pub resolved_locations: Vec<Option<AgentLocation>>,
 179    pub raw_input: Option<serde_json::Value>,
 180    pub raw_output: Option<serde_json::Value>,
 181}
 182
 183impl ToolCall {
 184    fn from_acp(
 185        tool_call: acp::ToolCall,
 186        status: ToolCallStatus,
 187        language_registry: Arc<LanguageRegistry>,
 188        terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
 189        cx: &mut App,
 190    ) -> Result<Self> {
 191        let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
 192            first_line.to_owned() + ""
 193        } else {
 194            tool_call.title
 195        };
 196        let mut content = Vec::with_capacity(tool_call.content.len());
 197        for item in tool_call.content {
 198            content.push(ToolCallContent::from_acp(
 199                item,
 200                language_registry.clone(),
 201                terminals,
 202                cx,
 203            )?);
 204        }
 205
 206        let result = Self {
 207            id: tool_call.id,
 208            label: cx
 209                .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
 210            kind: tool_call.kind,
 211            content,
 212            locations: tool_call.locations,
 213            resolved_locations: Vec::default(),
 214            status,
 215            raw_input: tool_call.raw_input,
 216            raw_output: tool_call.raw_output,
 217        };
 218        Ok(result)
 219    }
 220
 221    fn update_fields(
 222        &mut self,
 223        fields: acp::ToolCallUpdateFields,
 224        language_registry: Arc<LanguageRegistry>,
 225        terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
 226        cx: &mut App,
 227    ) -> Result<()> {
 228        let acp::ToolCallUpdateFields {
 229            kind,
 230            status,
 231            title,
 232            content,
 233            locations,
 234            raw_input,
 235            raw_output,
 236        } = fields;
 237
 238        if let Some(kind) = kind {
 239            self.kind = kind;
 240        }
 241
 242        if let Some(status) = status {
 243            self.status = status.into();
 244        }
 245
 246        if let Some(title) = title {
 247            self.label.update(cx, |label, cx| {
 248                if let Some((first_line, _)) = title.split_once("\n") {
 249                    label.replace(first_line.to_owned() + "", cx)
 250                } else {
 251                    label.replace(title, cx);
 252                }
 253            });
 254        }
 255
 256        if let Some(content) = content {
 257            let new_content_len = content.len();
 258            let mut content = content.into_iter();
 259
 260            // Reuse existing content if we can
 261            for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
 262                old.update_from_acp(new, language_registry.clone(), terminals, cx)?;
 263            }
 264            for new in content {
 265                self.content.push(ToolCallContent::from_acp(
 266                    new,
 267                    language_registry.clone(),
 268                    terminals,
 269                    cx,
 270                )?)
 271            }
 272            self.content.truncate(new_content_len);
 273        }
 274
 275        if let Some(locations) = locations {
 276            self.locations = locations;
 277        }
 278
 279        if let Some(raw_input) = raw_input {
 280            self.raw_input = Some(raw_input);
 281        }
 282
 283        if let Some(raw_output) = raw_output {
 284            if self.content.is_empty()
 285                && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
 286            {
 287                self.content
 288                    .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
 289                        markdown,
 290                    }));
 291            }
 292            self.raw_output = Some(raw_output);
 293        }
 294        Ok(())
 295    }
 296
 297    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 298        self.content.iter().filter_map(|content| match content {
 299            ToolCallContent::Diff(diff) => Some(diff),
 300            ToolCallContent::ContentBlock(_) => None,
 301            ToolCallContent::Terminal(_) => None,
 302        })
 303    }
 304
 305    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 306        self.content.iter().filter_map(|content| match content {
 307            ToolCallContent::Terminal(terminal) => Some(terminal),
 308            ToolCallContent::ContentBlock(_) => None,
 309            ToolCallContent::Diff(_) => None,
 310        })
 311    }
 312
 313    fn to_markdown(&self, cx: &App) -> String {
 314        let mut markdown = format!(
 315            "**Tool Call: {}**\nStatus: {}\n\n",
 316            self.label.read(cx).source(),
 317            self.status
 318        );
 319        for content in &self.content {
 320            markdown.push_str(content.to_markdown(cx).as_str());
 321            markdown.push_str("\n\n");
 322        }
 323        markdown
 324    }
 325
 326    async fn resolve_location(
 327        location: acp::ToolCallLocation,
 328        project: WeakEntity<Project>,
 329        cx: &mut AsyncApp,
 330    ) -> Option<AgentLocation> {
 331        let buffer = project
 332            .update(cx, |project, cx| {
 333                project
 334                    .project_path_for_absolute_path(&location.path, cx)
 335                    .map(|path| project.open_buffer(path, cx))
 336            })
 337            .ok()??;
 338        let buffer = buffer.await.log_err()?;
 339        let position = buffer
 340            .update(cx, |buffer, _| {
 341                if let Some(row) = location.line {
 342                    let snapshot = buffer.snapshot();
 343                    let column = snapshot.indent_size_for_line(row).len;
 344                    let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
 345                    snapshot.anchor_before(point)
 346                } else {
 347                    Anchor::MIN
 348                }
 349            })
 350            .ok()?;
 351
 352        Some(AgentLocation {
 353            buffer: buffer.downgrade(),
 354            position,
 355        })
 356    }
 357
 358    fn resolve_locations(
 359        &self,
 360        project: Entity<Project>,
 361        cx: &mut App,
 362    ) -> Task<Vec<Option<AgentLocation>>> {
 363        let locations = self.locations.clone();
 364        project.update(cx, |_, cx| {
 365            cx.spawn(async move |project, cx| {
 366                let mut new_locations = Vec::new();
 367                for location in locations {
 368                    new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
 369                }
 370                new_locations
 371            })
 372        })
 373    }
 374}
 375
 376#[derive(Debug)]
 377pub enum ToolCallStatus {
 378    /// The tool call hasn't started running yet, but we start showing it to
 379    /// the user.
 380    Pending,
 381    /// The tool call is waiting for confirmation from the user.
 382    WaitingForConfirmation {
 383        options: Vec<acp::PermissionOption>,
 384        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
 385    },
 386    /// The tool call is currently running.
 387    InProgress,
 388    /// The tool call completed successfully.
 389    Completed,
 390    /// The tool call failed.
 391    Failed,
 392    /// The user rejected the tool call.
 393    Rejected,
 394    /// The user canceled generation so the tool call was canceled.
 395    Canceled,
 396}
 397
 398impl From<acp::ToolCallStatus> for ToolCallStatus {
 399    fn from(status: acp::ToolCallStatus) -> Self {
 400        match status {
 401            acp::ToolCallStatus::Pending => Self::Pending,
 402            acp::ToolCallStatus::InProgress => Self::InProgress,
 403            acp::ToolCallStatus::Completed => Self::Completed,
 404            acp::ToolCallStatus::Failed => Self::Failed,
 405        }
 406    }
 407}
 408
 409impl Display for ToolCallStatus {
 410    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 411        write!(
 412            f,
 413            "{}",
 414            match self {
 415                ToolCallStatus::Pending => "Pending",
 416                ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
 417                ToolCallStatus::InProgress => "In Progress",
 418                ToolCallStatus::Completed => "Completed",
 419                ToolCallStatus::Failed => "Failed",
 420                ToolCallStatus::Rejected => "Rejected",
 421                ToolCallStatus::Canceled => "Canceled",
 422            }
 423        )
 424    }
 425}
 426
 427#[derive(Debug, PartialEq, Clone)]
 428pub enum ContentBlock {
 429    Empty,
 430    Markdown { markdown: Entity<Markdown> },
 431    ResourceLink { resource_link: acp::ResourceLink },
 432}
 433
 434impl ContentBlock {
 435    pub fn new(
 436        block: acp::ContentBlock,
 437        language_registry: &Arc<LanguageRegistry>,
 438        cx: &mut App,
 439    ) -> Self {
 440        let mut this = Self::Empty;
 441        this.append(block, language_registry, cx);
 442        this
 443    }
 444
 445    pub fn new_combined(
 446        blocks: impl IntoIterator<Item = acp::ContentBlock>,
 447        language_registry: Arc<LanguageRegistry>,
 448        cx: &mut App,
 449    ) -> Self {
 450        let mut this = Self::Empty;
 451        for block in blocks {
 452            this.append(block, &language_registry, cx);
 453        }
 454        this
 455    }
 456
 457    pub fn append(
 458        &mut self,
 459        block: acp::ContentBlock,
 460        language_registry: &Arc<LanguageRegistry>,
 461        cx: &mut App,
 462    ) {
 463        if matches!(self, ContentBlock::Empty)
 464            && let acp::ContentBlock::ResourceLink(resource_link) = block
 465        {
 466            *self = ContentBlock::ResourceLink { resource_link };
 467            return;
 468        }
 469
 470        let new_content = self.block_string_contents(block);
 471
 472        match self {
 473            ContentBlock::Empty => {
 474                *self = Self::create_markdown_block(new_content, language_registry, cx);
 475            }
 476            ContentBlock::Markdown { markdown } => {
 477                markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
 478            }
 479            ContentBlock::ResourceLink { resource_link } => {
 480                let existing_content = Self::resource_link_md(&resource_link.uri);
 481                let combined = format!("{}\n{}", existing_content, new_content);
 482
 483                *self = Self::create_markdown_block(combined, language_registry, cx);
 484            }
 485        }
 486    }
 487
 488    fn create_markdown_block(
 489        content: String,
 490        language_registry: &Arc<LanguageRegistry>,
 491        cx: &mut App,
 492    ) -> ContentBlock {
 493        ContentBlock::Markdown {
 494            markdown: cx
 495                .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
 496        }
 497    }
 498
 499    fn block_string_contents(&self, block: acp::ContentBlock) -> String {
 500        match block {
 501            acp::ContentBlock::Text(text_content) => text_content.text,
 502            acp::ContentBlock::ResourceLink(resource_link) => {
 503                Self::resource_link_md(&resource_link.uri)
 504            }
 505            acp::ContentBlock::Resource(acp::EmbeddedResource {
 506                resource:
 507                    acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
 508                        uri,
 509                        ..
 510                    }),
 511                ..
 512            }) => Self::resource_link_md(&uri),
 513            acp::ContentBlock::Image(image) => Self::image_md(&image),
 514            acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
 515        }
 516    }
 517
 518    fn resource_link_md(uri: &str) -> String {
 519        if let Some(uri) = MentionUri::parse(uri).log_err() {
 520            uri.as_link().to_string()
 521        } else {
 522            uri.to_string()
 523        }
 524    }
 525
 526    fn image_md(_image: &acp::ImageContent) -> String {
 527        "`Image`".into()
 528    }
 529
 530    pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
 531        match self {
 532            ContentBlock::Empty => "",
 533            ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
 534            ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
 535        }
 536    }
 537
 538    pub fn markdown(&self) -> Option<&Entity<Markdown>> {
 539        match self {
 540            ContentBlock::Empty => None,
 541            ContentBlock::Markdown { markdown } => Some(markdown),
 542            ContentBlock::ResourceLink { .. } => None,
 543        }
 544    }
 545
 546    pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
 547        match self {
 548            ContentBlock::ResourceLink { resource_link } => Some(resource_link),
 549            _ => None,
 550        }
 551    }
 552}
 553
 554#[derive(Debug)]
 555pub enum ToolCallContent {
 556    ContentBlock(ContentBlock),
 557    Diff(Entity<Diff>),
 558    Terminal(Entity<Terminal>),
 559}
 560
 561impl ToolCallContent {
 562    pub fn from_acp(
 563        content: acp::ToolCallContent,
 564        language_registry: Arc<LanguageRegistry>,
 565        terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
 566        cx: &mut App,
 567    ) -> Result<Self> {
 568        match content {
 569            acp::ToolCallContent::Content { content } => Ok(Self::ContentBlock(ContentBlock::new(
 570                content,
 571                &language_registry,
 572                cx,
 573            ))),
 574            acp::ToolCallContent::Diff { diff } => Ok(Self::Diff(cx.new(|cx| {
 575                Diff::finalized(
 576                    diff.path,
 577                    diff.old_text,
 578                    diff.new_text,
 579                    language_registry,
 580                    cx,
 581                )
 582            }))),
 583            acp::ToolCallContent::Terminal { terminal_id } => terminals
 584                .get(&terminal_id)
 585                .cloned()
 586                .map(Self::Terminal)
 587                .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
 588        }
 589    }
 590
 591    pub fn update_from_acp(
 592        &mut self,
 593        new: acp::ToolCallContent,
 594        language_registry: Arc<LanguageRegistry>,
 595        terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
 596        cx: &mut App,
 597    ) -> Result<()> {
 598        let needs_update = match (&self, &new) {
 599            (Self::Diff(old_diff), acp::ToolCallContent::Diff { diff: new_diff }) => {
 600                old_diff.read(cx).needs_update(
 601                    new_diff.old_text.as_deref().unwrap_or(""),
 602                    &new_diff.new_text,
 603                    cx,
 604                )
 605            }
 606            _ => true,
 607        };
 608
 609        if needs_update {
 610            *self = Self::from_acp(new, language_registry, terminals, cx)?;
 611        }
 612        Ok(())
 613    }
 614
 615    pub fn to_markdown(&self, cx: &App) -> String {
 616        match self {
 617            Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
 618            Self::Diff(diff) => diff.read(cx).to_markdown(cx),
 619            Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
 620        }
 621    }
 622}
 623
 624#[derive(Debug, PartialEq)]
 625pub enum ToolCallUpdate {
 626    UpdateFields(acp::ToolCallUpdate),
 627    UpdateDiff(ToolCallUpdateDiff),
 628    UpdateTerminal(ToolCallUpdateTerminal),
 629}
 630
 631impl ToolCallUpdate {
 632    fn id(&self) -> &acp::ToolCallId {
 633        match self {
 634            Self::UpdateFields(update) => &update.id,
 635            Self::UpdateDiff(diff) => &diff.id,
 636            Self::UpdateTerminal(terminal) => &terminal.id,
 637        }
 638    }
 639}
 640
 641impl From<acp::ToolCallUpdate> for ToolCallUpdate {
 642    fn from(update: acp::ToolCallUpdate) -> Self {
 643        Self::UpdateFields(update)
 644    }
 645}
 646
 647impl From<ToolCallUpdateDiff> for ToolCallUpdate {
 648    fn from(diff: ToolCallUpdateDiff) -> Self {
 649        Self::UpdateDiff(diff)
 650    }
 651}
 652
 653#[derive(Debug, PartialEq)]
 654pub struct ToolCallUpdateDiff {
 655    pub id: acp::ToolCallId,
 656    pub diff: Entity<Diff>,
 657}
 658
 659impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
 660    fn from(terminal: ToolCallUpdateTerminal) -> Self {
 661        Self::UpdateTerminal(terminal)
 662    }
 663}
 664
 665#[derive(Debug, PartialEq)]
 666pub struct ToolCallUpdateTerminal {
 667    pub id: acp::ToolCallId,
 668    pub terminal: Entity<Terminal>,
 669}
 670
 671#[derive(Debug, Default)]
 672pub struct Plan {
 673    pub entries: Vec<PlanEntry>,
 674}
 675
 676#[derive(Debug)]
 677pub struct PlanStats<'a> {
 678    pub in_progress_entry: Option<&'a PlanEntry>,
 679    pub pending: u32,
 680    pub completed: u32,
 681}
 682
 683impl Plan {
 684    pub fn is_empty(&self) -> bool {
 685        self.entries.is_empty()
 686    }
 687
 688    pub fn stats(&self) -> PlanStats<'_> {
 689        let mut stats = PlanStats {
 690            in_progress_entry: None,
 691            pending: 0,
 692            completed: 0,
 693        };
 694
 695        for entry in &self.entries {
 696            match &entry.status {
 697                acp::PlanEntryStatus::Pending => {
 698                    stats.pending += 1;
 699                }
 700                acp::PlanEntryStatus::InProgress => {
 701                    stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
 702                }
 703                acp::PlanEntryStatus::Completed => {
 704                    stats.completed += 1;
 705                }
 706            }
 707        }
 708
 709        stats
 710    }
 711}
 712
 713#[derive(Debug)]
 714pub struct PlanEntry {
 715    pub content: Entity<Markdown>,
 716    pub priority: acp::PlanEntryPriority,
 717    pub status: acp::PlanEntryStatus,
 718}
 719
 720impl PlanEntry {
 721    pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
 722        Self {
 723            content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
 724            priority: entry.priority,
 725            status: entry.status,
 726        }
 727    }
 728}
 729
/// Token accounting for the thread's context window.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Size of the context window; may be 0 when the maximum is unknown (e.g. no model
    /// is selected), in which case no limit warning is shown.
    pub max_tokens: u64,
    /// Tokens used so far.
    pub used_tokens: u64,
}
 735
 736impl TokenUsage {
 737    pub fn ratio(&self) -> TokenUsageRatio {
 738        #[cfg(debug_assertions)]
 739        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 740            .unwrap_or("0.8".to_string())
 741            .parse()
 742            .unwrap();
 743        #[cfg(not(debug_assertions))]
 744        let warning_threshold: f32 = 0.8;
 745
 746        // When the maximum is unknown because there is no selected model,
 747        // avoid showing the token limit warning.
 748        if self.max_tokens == 0 {
 749            TokenUsageRatio::Normal
 750        } else if self.used_tokens >= self.max_tokens {
 751            TokenUsageRatio::Exceeded
 752        } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
 753            TokenUsageRatio::Warning
 754        } else {
 755            TokenUsageRatio::Normal
 756        }
 757    }
 758}
 759
 760#[derive(Debug, Clone, PartialEq, Eq)]
 761pub enum TokenUsageRatio {
 762    Normal,
 763    Warning,
 764    Exceeded,
 765}
 766
/// Information about an in-flight retry, carried by `AcpThreadEvent::Retry`.
#[derive(Debug, Clone)]
pub struct RetryStatus {
    /// Message of the error that triggered the retry.
    pub last_error: SharedString,
    /// Current attempt number (NOTE(review): confirm whether 0- or 1-based).
    pub attempt: usize,
    /// Maximum number of attempts that will be made.
    pub max_attempts: usize,
    /// When this retry status was created.
    pub started_at: Instant,
    /// NOTE(review): presumably the delay before the next attempt — confirm with producer.
    pub duration: Duration,
}
 775
/// State of a conversation with an agent over the Agent Client Protocol.
pub struct AcpThread {
    /// Display title of the thread.
    title: SharedString,
    /// Messages and tool calls, in order.
    entries: Vec<AgentThreadEntry>,
    /// The agent's current plan.
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    /// NOTE(review): presumably snapshots of buffers already shared with the agent — confirm.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// Set while a prompt is being processed; `status()` reports `Generating` when `Some`.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
    token_usage: Option<TokenUsage>,
    /// Latest prompt capabilities, kept up to date by `_observe_prompt_capabilities`.
    prompt_capabilities: acp::PromptCapabilities,
    /// Task spawned in `new` that copies capability updates into `prompt_capabilities`.
    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
    /// Terminals by id; tool-call terminal content is resolved against a map of this shape.
    terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
}
 791
/// Events emitted by [`AcpThread`].
#[derive(Debug)]
pub enum AcpThreadEvent {
    NewEntry,
    TitleUpdated,
    TokenUsageUpdated,
    /// The entry at this index changed (NOTE(review): index assumed — confirm at emit sites).
    EntryUpdated(usize),
    /// Entries in this index range were removed.
    EntriesRemoved(Range<usize>),
    ToolAuthorizationRequired,
    Retry(RetryStatus),
    Stopped,
    Error,
    LoadError(LoadError),
    /// Emitted after `prompt_capabilities` is replaced by a new value from the agent.
    PromptCapabilitiesUpdated,
    Refusal,
    /// Emitted with the agent's new list of available commands.
    AvailableCommandsUpdated(Vec<acp::AvailableCommand>),
    /// Emitted with the agent's new current session mode.
    ModeUpdated(acp::SessionModeId),
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}

/// Whether the thread is currently generating a response.
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    /// No send task is running.
    Idle,
    /// A send task is running.
    Generating,
}
 817
 818#[derive(Debug, Clone)]
 819pub enum LoadError {
 820    Unsupported {
 821        command: SharedString,
 822        current_version: SharedString,
 823        minimum_version: SharedString,
 824    },
 825    FailedToInstall(SharedString),
 826    Exited {
 827        status: ExitStatus,
 828    },
 829    Other(SharedString),
 830}
 831
 832impl Display for LoadError {
 833    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 834        match self {
 835            LoadError::Unsupported {
 836                command: path,
 837                current_version,
 838                minimum_version,
 839            } => {
 840                write!(
 841                    f,
 842                    "version {current_version} from {path} is not supported (need at least {minimum_version})"
 843                )
 844            }
 845            LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
 846            LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
 847            LoadError::Other(msg) => write!(f, "{msg}"),
 848        }
 849    }
 850}
 851
 852impl Error for LoadError {}
 853
 854impl AcpThread {
 855    pub fn new(
 856        title: impl Into<SharedString>,
 857        connection: Rc<dyn AgentConnection>,
 858        project: Entity<Project>,
 859        action_log: Entity<ActionLog>,
 860        session_id: acp::SessionId,
 861        mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
 862        cx: &mut Context<Self>,
 863    ) -> Self {
 864        let prompt_capabilities = prompt_capabilities_rx.borrow().clone();
 865        let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
 866            loop {
 867                let caps = prompt_capabilities_rx.recv().await?;
 868                this.update(cx, |this, cx| {
 869                    this.prompt_capabilities = caps;
 870                    cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
 871                })?;
 872            }
 873        });
 874
 875        Self {
 876            action_log,
 877            shared_buffers: Default::default(),
 878            entries: Default::default(),
 879            plan: Default::default(),
 880            title: title.into(),
 881            project,
 882            send_task: None,
 883            connection,
 884            session_id,
 885            token_usage: None,
 886            prompt_capabilities,
 887            _observe_prompt_capabilities: task,
 888            terminals: HashMap::default(),
 889        }
 890    }
 891
    /// Returns a copy of the most recently received prompt capabilities.
    pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
        self.prompt_capabilities.clone()
    }
 895
    /// The agent connection this thread talks to.
    pub fn connection(&self) -> &Rc<dyn AgentConnection> {
        &self.connection
    }
 899
    /// The action log associated with this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
 903
    /// The project this thread operates in.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
 907
    /// The thread's display title.
    pub fn title(&self) -> SharedString {
        self.title.clone()
    }
 911
    /// All entries of the thread, in order.
    pub fn entries(&self) -> &[AgentThreadEntry] {
        &self.entries
    }
 915
    /// The ACP session id backing this thread.
    pub fn session_id(&self) -> &acp::SessionId {
        &self.session_id
    }
 919
 920    pub fn status(&self) -> ThreadStatus {
 921        if self.send_task.is_some() {
 922            ThreadStatus::Generating
 923        } else {
 924            ThreadStatus::Idle
 925        }
 926    }
 927
    /// The latest token usage, if any has been recorded.
    pub fn token_usage(&self) -> Option<&TokenUsage> {
        self.token_usage.as_ref()
    }
 931
 932    pub fn has_pending_edit_tool_calls(&self) -> bool {
 933        for entry in self.entries.iter().rev() {
 934            match entry {
 935                AgentThreadEntry::UserMessage(_) => return false,
 936                AgentThreadEntry::ToolCall(
 937                    call @ ToolCall {
 938                        status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
 939                        ..
 940                    },
 941                ) if call.diffs().next().is_some() => {
 942                    return true;
 943                }
 944                AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
 945            }
 946        }
 947
 948        false
 949    }
 950
 951    pub fn used_tools_since_last_user_message(&self) -> bool {
 952        for entry in self.entries.iter().rev() {
 953            match entry {
 954                AgentThreadEntry::UserMessage(..) => return false,
 955                AgentThreadEntry::AssistantMessage(..) => continue,
 956                AgentThreadEntry::ToolCall(..) => return true,
 957            }
 958        }
 959
 960        false
 961    }
 962
 963    pub fn handle_session_update(
 964        &mut self,
 965        update: acp::SessionUpdate,
 966        cx: &mut Context<Self>,
 967    ) -> Result<(), acp::Error> {
 968        match update {
 969            acp::SessionUpdate::UserMessageChunk { content } => {
 970                self.push_user_content_block(None, content, cx);
 971            }
 972            acp::SessionUpdate::AgentMessageChunk { content } => {
 973                self.push_assistant_content_block(content, false, cx);
 974            }
 975            acp::SessionUpdate::AgentThoughtChunk { content } => {
 976                self.push_assistant_content_block(content, true, cx);
 977            }
 978            acp::SessionUpdate::ToolCall(tool_call) => {
 979                self.upsert_tool_call(tool_call, cx)?;
 980            }
 981            acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
 982                self.update_tool_call(tool_call_update, cx)?;
 983            }
 984            acp::SessionUpdate::Plan(plan) => {
 985                self.update_plan(plan, cx);
 986            }
 987            acp::SessionUpdate::AvailableCommandsUpdate { available_commands } => {
 988                cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands))
 989            }
 990            acp::SessionUpdate::CurrentModeUpdate { current_mode_id } => {
 991                cx.emit(AcpThreadEvent::ModeUpdated(current_mode_id))
 992            }
 993        }
 994        Ok(())
 995    }
 996
 997    pub fn push_user_content_block(
 998        &mut self,
 999        message_id: Option<UserMessageId>,
1000        chunk: acp::ContentBlock,
1001        cx: &mut Context<Self>,
1002    ) {
1003        let language_registry = self.project.read(cx).languages().clone();
1004        let entries_len = self.entries.len();
1005
1006        if let Some(last_entry) = self.entries.last_mut()
1007            && let AgentThreadEntry::UserMessage(UserMessage {
1008                id,
1009                content,
1010                chunks,
1011                ..
1012            }) = last_entry
1013        {
1014            *id = message_id.or(id.take());
1015            content.append(chunk.clone(), &language_registry, cx);
1016            chunks.push(chunk);
1017            let idx = entries_len - 1;
1018            cx.emit(AcpThreadEvent::EntryUpdated(idx));
1019        } else {
1020            let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
1021            self.push_entry(
1022                AgentThreadEntry::UserMessage(UserMessage {
1023                    id: message_id,
1024                    content,
1025                    chunks: vec![chunk],
1026                    checkpoint: None,
1027                }),
1028                cx,
1029            );
1030        }
1031    }
1032
1033    pub fn push_assistant_content_block(
1034        &mut self,
1035        chunk: acp::ContentBlock,
1036        is_thought: bool,
1037        cx: &mut Context<Self>,
1038    ) {
1039        let language_registry = self.project.read(cx).languages().clone();
1040        let entries_len = self.entries.len();
1041        if let Some(last_entry) = self.entries.last_mut()
1042            && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
1043        {
1044            let idx = entries_len - 1;
1045            cx.emit(AcpThreadEvent::EntryUpdated(idx));
1046            match (chunks.last_mut(), is_thought) {
1047                (Some(AssistantMessageChunk::Message { block }), false)
1048                | (Some(AssistantMessageChunk::Thought { block }), true) => {
1049                    block.append(chunk, &language_registry, cx)
1050                }
1051                _ => {
1052                    let block = ContentBlock::new(chunk, &language_registry, cx);
1053                    if is_thought {
1054                        chunks.push(AssistantMessageChunk::Thought { block })
1055                    } else {
1056                        chunks.push(AssistantMessageChunk::Message { block })
1057                    }
1058                }
1059            }
1060        } else {
1061            let block = ContentBlock::new(chunk, &language_registry, cx);
1062            let chunk = if is_thought {
1063                AssistantMessageChunk::Thought { block }
1064            } else {
1065                AssistantMessageChunk::Message { block }
1066            };
1067
1068            self.push_entry(
1069                AgentThreadEntry::AssistantMessage(AssistantMessage {
1070                    chunks: vec![chunk],
1071                }),
1072                cx,
1073            );
1074        }
1075    }
1076
1077    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1078        self.entries.push(entry);
1079        cx.emit(AcpThreadEvent::NewEntry);
1080    }
1081
1082    pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1083        self.connection.set_title(&self.session_id, cx).is_some()
1084    }
1085
1086    pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1087        if title != self.title {
1088            self.title = title.clone();
1089            cx.emit(AcpThreadEvent::TitleUpdated);
1090            if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1091                return set_title.run(title, cx);
1092            }
1093        }
1094        Task::ready(Ok(()))
1095    }
1096
1097    pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1098        self.token_usage = usage;
1099        cx.emit(AcpThreadEvent::TokenUsageUpdated);
1100    }
1101
1102    pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1103        cx.emit(AcpThreadEvent::Retry(status));
1104    }
1105
1106    pub fn update_tool_call(
1107        &mut self,
1108        update: impl Into<ToolCallUpdate>,
1109        cx: &mut Context<Self>,
1110    ) -> Result<()> {
1111        let update = update.into();
1112        let languages = self.project.read(cx).languages().clone();
1113
1114        let ix = match self.index_for_tool_call(update.id()) {
1115            Some(ix) => ix,
1116            None => {
1117                // Tool call not found - create a failed tool call entry
1118                let failed_tool_call = ToolCall {
1119                    id: update.id().clone(),
1120                    label: cx.new(|cx| Markdown::new("Tool call not found".into(), None, None, cx)),
1121                    kind: acp::ToolKind::Fetch,
1122                    content: vec![ToolCallContent::ContentBlock(ContentBlock::new(
1123                        acp::ContentBlock::Text(acp::TextContent {
1124                            text: "Tool call not found".to_string(),
1125                            annotations: None,
1126                            meta: None,
1127                        }),
1128                        &languages,
1129                        cx,
1130                    ))],
1131                    status: ToolCallStatus::Failed,
1132                    locations: Vec::new(),
1133                    resolved_locations: Vec::new(),
1134                    raw_input: None,
1135                    raw_output: None,
1136                };
1137                self.push_entry(AgentThreadEntry::ToolCall(failed_tool_call), cx);
1138                return Ok(());
1139            }
1140        };
1141        let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1142            unreachable!()
1143        };
1144
1145        match update {
1146            ToolCallUpdate::UpdateFields(update) => {
1147                let location_updated = update.fields.locations.is_some();
1148                call.update_fields(update.fields, languages, &self.terminals, cx)?;
1149                if location_updated {
1150                    self.resolve_locations(update.id, cx);
1151                }
1152            }
1153            ToolCallUpdate::UpdateDiff(update) => {
1154                call.content.clear();
1155                call.content.push(ToolCallContent::Diff(update.diff));
1156            }
1157            ToolCallUpdate::UpdateTerminal(update) => {
1158                call.content.clear();
1159                call.content
1160                    .push(ToolCallContent::Terminal(update.terminal));
1161            }
1162        }
1163
1164        cx.emit(AcpThreadEvent::EntryUpdated(ix));
1165
1166        Ok(())
1167    }
1168
1169    /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1170    pub fn upsert_tool_call(
1171        &mut self,
1172        tool_call: acp::ToolCall,
1173        cx: &mut Context<Self>,
1174    ) -> Result<(), acp::Error> {
1175        let status = tool_call.status.into();
1176        self.upsert_tool_call_inner(tool_call.into(), status, cx)
1177    }
1178
1179    /// Fails if id does not match an existing entry.
1180    pub fn upsert_tool_call_inner(
1181        &mut self,
1182        update: acp::ToolCallUpdate,
1183        status: ToolCallStatus,
1184        cx: &mut Context<Self>,
1185    ) -> Result<(), acp::Error> {
1186        let language_registry = self.project.read(cx).languages().clone();
1187        let id = update.id.clone();
1188
1189        if let Some(ix) = self.index_for_tool_call(&id) {
1190            let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1191                unreachable!()
1192            };
1193
1194            call.update_fields(update.fields, language_registry, &self.terminals, cx)?;
1195            call.status = status;
1196
1197            cx.emit(AcpThreadEvent::EntryUpdated(ix));
1198        } else {
1199            let call = ToolCall::from_acp(
1200                update.try_into()?,
1201                status,
1202                language_registry,
1203                &self.terminals,
1204                cx,
1205            )?;
1206            self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1207        };
1208
1209        self.resolve_locations(id, cx);
1210        Ok(())
1211    }
1212
1213    fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1214        self.entries
1215            .iter()
1216            .enumerate()
1217            .rev()
1218            .find_map(|(index, entry)| {
1219                if let AgentThreadEntry::ToolCall(tool_call) = entry
1220                    && &tool_call.id == id
1221                {
1222                    Some(index)
1223                } else {
1224                    None
1225                }
1226            })
1227    }
1228
1229    fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1230        // The tool call we are looking for is typically the last one, or very close to the end.
1231        // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1232        self.entries
1233            .iter_mut()
1234            .enumerate()
1235            .rev()
1236            .find_map(|(index, tool_call)| {
1237                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1238                    && &tool_call.id == id
1239                {
1240                    Some((index, tool_call))
1241                } else {
1242                    None
1243                }
1244            })
1245    }
1246
1247    pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1248        self.entries
1249            .iter()
1250            .enumerate()
1251            .rev()
1252            .find_map(|(index, tool_call)| {
1253                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1254                    && &tool_call.id == id
1255                {
1256                    Some((index, tool_call))
1257                } else {
1258                    None
1259                }
1260            })
1261    }
1262
1263    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1264        let project = self.project.clone();
1265        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1266            return;
1267        };
1268        let task = tool_call.resolve_locations(project, cx);
1269        cx.spawn(async move |this, cx| {
1270            let resolved_locations = task.await;
1271            this.update(cx, |this, cx| {
1272                let project = this.project.clone();
1273                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1274                    return;
1275                };
1276                if let Some(Some(location)) = resolved_locations.last() {
1277                    project.update(cx, |project, cx| {
1278                        if let Some(agent_location) = project.agent_location() {
1279                            let should_ignore = agent_location.buffer == location.buffer
1280                                && location
1281                                    .buffer
1282                                    .update(cx, |buffer, _| {
1283                                        let snapshot = buffer.snapshot();
1284                                        let old_position =
1285                                            agent_location.position.to_point(&snapshot);
1286                                        let new_position = location.position.to_point(&snapshot);
1287                                        // ignore this so that when we get updates from the edit tool
1288                                        // the position doesn't reset to the startof line
1289                                        old_position.row == new_position.row
1290                                            && old_position.column > new_position.column
1291                                    })
1292                                    .ok()
1293                                    .unwrap_or_default();
1294                            if !should_ignore {
1295                                project.set_agent_location(Some(location.clone()), cx);
1296                            }
1297                        }
1298                    });
1299                }
1300                if tool_call.resolved_locations != resolved_locations {
1301                    tool_call.resolved_locations = resolved_locations;
1302                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
1303                }
1304            })
1305        })
1306        .detach();
1307    }
1308
1309    pub fn request_tool_call_authorization(
1310        &mut self,
1311        tool_call: acp::ToolCallUpdate,
1312        options: Vec<acp::PermissionOption>,
1313        respect_always_allow_setting: bool,
1314        cx: &mut Context<Self>,
1315    ) -> Result<BoxFuture<'static, acp::RequestPermissionOutcome>> {
1316        let (tx, rx) = oneshot::channel();
1317
1318        if respect_always_allow_setting && AgentSettings::get_global(cx).always_allow_tool_actions {
1319            // Don't use AllowAlways, because then if you were to turn off always_allow_tool_actions,
1320            // some tools would (incorrectly) continue to auto-accept.
1321            if let Some(allow_once_option) = options.iter().find_map(|option| {
1322                if matches!(option.kind, acp::PermissionOptionKind::AllowOnce) {
1323                    Some(option.id.clone())
1324                } else {
1325                    None
1326                }
1327            }) {
1328                self.upsert_tool_call_inner(tool_call, ToolCallStatus::Pending, cx)?;
1329                return Ok(async {
1330                    acp::RequestPermissionOutcome::Selected {
1331                        option_id: allow_once_option,
1332                    }
1333                }
1334                .boxed());
1335            }
1336        }
1337
1338        let status = ToolCallStatus::WaitingForConfirmation {
1339            options,
1340            respond_tx: tx,
1341        };
1342
1343        self.upsert_tool_call_inner(tool_call, status, cx)?;
1344        cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1345
1346        let fut = async {
1347            match rx.await {
1348                Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
1349                Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1350            }
1351        }
1352        .boxed();
1353
1354        Ok(fut)
1355    }
1356
1357    pub fn authorize_tool_call(
1358        &mut self,
1359        id: acp::ToolCallId,
1360        option_id: acp::PermissionOptionId,
1361        option_kind: acp::PermissionOptionKind,
1362        cx: &mut Context<Self>,
1363    ) {
1364        let Some((ix, call)) = self.tool_call_mut(&id) else {
1365            return;
1366        };
1367
1368        let new_status = match option_kind {
1369            acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1370                ToolCallStatus::Rejected
1371            }
1372            acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1373                ToolCallStatus::InProgress
1374            }
1375        };
1376
1377        let curr_status = mem::replace(&mut call.status, new_status);
1378
1379        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1380            respond_tx.send(option_id).log_err();
1381        } else if cfg!(debug_assertions) {
1382            panic!("tried to authorize an already authorized tool call");
1383        }
1384
1385        cx.emit(AcpThreadEvent::EntryUpdated(ix));
1386    }
1387
1388    pub fn first_tool_awaiting_confirmation(&self) -> Option<&ToolCall> {
1389        let mut first_tool_call = None;
1390
1391        for entry in self.entries.iter().rev() {
1392            match &entry {
1393                AgentThreadEntry::ToolCall(call) => {
1394                    if let ToolCallStatus::WaitingForConfirmation { .. } = call.status {
1395                        first_tool_call = Some(call);
1396                    } else {
1397                        continue;
1398                    }
1399                }
1400                AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1401                    // Reached the beginning of the turn.
1402                    // If we had pending permission requests in the previous turn, they have been cancelled.
1403                    break;
1404                }
1405            }
1406        }
1407
1408        first_tool_call
1409    }
1410
1411    pub fn plan(&self) -> &Plan {
1412        &self.plan
1413    }
1414
1415    pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1416        let new_entries_len = request.entries.len();
1417        let mut new_entries = request.entries.into_iter();
1418
1419        // Reuse existing markdown to prevent flickering
1420        for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1421            let PlanEntry {
1422                content,
1423                priority,
1424                status,
1425            } = old;
1426            content.update(cx, |old, cx| {
1427                old.replace(new.content, cx);
1428            });
1429            *priority = new.priority;
1430            *status = new.status;
1431        }
1432        for new in new_entries {
1433            self.plan.entries.push(PlanEntry::from_acp(new, cx))
1434        }
1435        self.plan.entries.truncate(new_entries_len);
1436
1437        cx.notify();
1438    }
1439
1440    fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1441        self.plan
1442            .entries
1443            .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1444        cx.notify();
1445    }
1446
1447    #[cfg(any(test, feature = "test-support"))]
1448    pub fn send_raw(
1449        &mut self,
1450        message: &str,
1451        cx: &mut Context<Self>,
1452    ) -> BoxFuture<'static, Result<()>> {
1453        self.send(
1454            vec![acp::ContentBlock::Text(acp::TextContent {
1455                text: message.to_string(),
1456                annotations: None,
1457                meta: None,
1458            })],
1459            cx,
1460        )
1461    }
1462
1463    pub fn send(
1464        &mut self,
1465        message: Vec<acp::ContentBlock>,
1466        cx: &mut Context<Self>,
1467    ) -> BoxFuture<'static, Result<()>> {
1468        let block = ContentBlock::new_combined(
1469            message.clone(),
1470            self.project.read(cx).languages().clone(),
1471            cx,
1472        );
1473        let request = acp::PromptRequest {
1474            prompt: message.clone(),
1475            session_id: self.session_id.clone(),
1476            meta: None,
1477        };
1478        let git_store = self.project.read(cx).git_store().clone();
1479
1480        let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
1481            Some(UserMessageId::new())
1482        } else {
1483            None
1484        };
1485
1486        self.run_turn(cx, async move |this, cx| {
1487            this.update(cx, |this, cx| {
1488                this.push_entry(
1489                    AgentThreadEntry::UserMessage(UserMessage {
1490                        id: message_id.clone(),
1491                        content: block,
1492                        chunks: message,
1493                        checkpoint: None,
1494                    }),
1495                    cx,
1496                );
1497            })
1498            .ok();
1499
1500            let old_checkpoint = git_store
1501                .update(cx, |git, cx| git.checkpoint(cx))?
1502                .await
1503                .context("failed to get old checkpoint")
1504                .log_err();
1505            this.update(cx, |this, cx| {
1506                if let Some((_ix, message)) = this.last_user_message() {
1507                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
1508                        git_checkpoint,
1509                        show: false,
1510                    });
1511                }
1512                this.connection.prompt(message_id, request, cx)
1513            })?
1514            .await
1515        })
1516    }
1517
1518    pub fn can_resume(&self, cx: &App) -> bool {
1519        self.connection.resume(&self.session_id, cx).is_some()
1520    }
1521
1522    pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1523        self.run_turn(cx, async move |this, cx| {
1524            this.update(cx, |this, cx| {
1525                this.connection
1526                    .resume(&this.session_id, cx)
1527                    .map(|resume| resume.run(cx))
1528            })?
1529            .context("resuming a session is not supported")?
1530            .await
1531        })
1532    }
1533
1534    fn run_turn(
1535        &mut self,
1536        cx: &mut Context<Self>,
1537        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
1538    ) -> BoxFuture<'static, Result<()>> {
1539        self.clear_completed_plan_entries(cx);
1540
1541        let (tx, rx) = oneshot::channel();
1542        let cancel_task = self.cancel(cx);
1543
1544        self.send_task = Some(cx.spawn(async move |this, cx| {
1545            cancel_task.await;
1546            tx.send(f(this, cx).await).ok();
1547        }));
1548
1549        cx.spawn(async move |this, cx| {
1550            let response = rx.await;
1551
1552            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
1553                .await?;
1554
1555            this.update(cx, |this, cx| {
1556                this.project
1557                    .update(cx, |project, cx| project.set_agent_location(None, cx));
1558                match response {
1559                    Ok(Err(e)) => {
1560                        this.send_task.take();
1561                        cx.emit(AcpThreadEvent::Error);
1562                        Err(e)
1563                    }
1564                    result => {
1565                        let canceled = matches!(
1566                            result,
1567                            Ok(Ok(acp::PromptResponse {
1568                                stop_reason: acp::StopReason::Cancelled,
1569                                meta: None,
1570                            }))
1571                        );
1572
1573                        // We only take the task if the current prompt wasn't canceled.
1574                        //
1575                        // This prompt may have been canceled because another one was sent
1576                        // while it was still generating. In these cases, dropping `send_task`
1577                        // would cause the next generation to be canceled.
1578                        if !canceled {
1579                            this.send_task.take();
1580                        }
1581
1582                        // Handle refusal - distinguish between user prompt and tool call refusals
1583                        if let Ok(Ok(acp::PromptResponse {
1584                            stop_reason: acp::StopReason::Refusal,
1585                            meta: _,
1586                        })) = result
1587                        {
1588                            if let Some((user_msg_ix, _)) = this.last_user_message() {
1589                                // Check if there's a completed tool call with results after the last user message
1590                                // This indicates the refusal is in response to tool output, not the user's prompt
1591                                let has_completed_tool_call_after_user_msg =
1592                                    this.entries.iter().skip(user_msg_ix + 1).any(|entry| {
1593                                        if let AgentThreadEntry::ToolCall(tool_call) = entry {
1594                                            // Check if the tool call has completed and has output
1595                                            matches!(tool_call.status, ToolCallStatus::Completed)
1596                                                && tool_call.raw_output.is_some()
1597                                        } else {
1598                                            false
1599                                        }
1600                                    });
1601
1602                                if has_completed_tool_call_after_user_msg {
1603                                    // Refusal is due to tool output - don't truncate, just notify
1604                                    // The model refused based on what the tool returned
1605                                    cx.emit(AcpThreadEvent::Refusal);
1606                                } else {
1607                                    // User prompt was refused - truncate back to before the user message
1608                                    let range = user_msg_ix..this.entries.len();
1609                                    if range.start < range.end {
1610                                        this.entries.truncate(user_msg_ix);
1611                                        cx.emit(AcpThreadEvent::EntriesRemoved(range));
1612                                    }
1613                                    cx.emit(AcpThreadEvent::Refusal);
1614                                }
1615                            } else {
1616                                // No user message found, treat as general refusal
1617                                cx.emit(AcpThreadEvent::Refusal);
1618                            }
1619                        }
1620
1621                        cx.emit(AcpThreadEvent::Stopped);
1622                        Ok(())
1623                    }
1624                }
1625            })?
1626        })
1627        .boxed()
1628    }
1629
1630    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1631        let Some(send_task) = self.send_task.take() else {
1632            return Task::ready(());
1633        };
1634
1635        for entry in self.entries.iter_mut() {
1636            if let AgentThreadEntry::ToolCall(call) = entry {
1637                let cancel = matches!(
1638                    call.status,
1639                    ToolCallStatus::Pending
1640                        | ToolCallStatus::WaitingForConfirmation { .. }
1641                        | ToolCallStatus::InProgress
1642                );
1643
1644                if cancel {
1645                    call.status = ToolCallStatus::Canceled;
1646                }
1647            }
1648        }
1649
1650        self.connection.cancel(&self.session_id, cx);
1651
1652        // Wait for the send task to complete
1653        cx.foreground_executor().spawn(send_task)
1654    }
1655
1656    /// Restores the git working tree to the state at the given checkpoint (if one exists)
1657    pub fn restore_checkpoint(
1658        &mut self,
1659        id: UserMessageId,
1660        cx: &mut Context<Self>,
1661    ) -> Task<Result<()>> {
1662        let Some((_, message)) = self.user_message_mut(&id) else {
1663            return Task::ready(Err(anyhow!("message not found")));
1664        };
1665
1666        let checkpoint = message
1667            .checkpoint
1668            .as_ref()
1669            .map(|c| c.git_checkpoint.clone());
1670        let rewind = self.rewind(id.clone(), cx);
1671        let git_store = self.project.read(cx).git_store().clone();
1672
1673        cx.spawn(async move |_, cx| {
1674            rewind.await?;
1675            if let Some(checkpoint) = checkpoint {
1676                git_store
1677                    .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1678                    .await?;
1679            }
1680
1681            Ok(())
1682        })
1683    }
1684
1685    /// Rewinds this thread to before the entry at `index`, removing it and all
1686    /// subsequent entries while rejecting any action_log changes made from that point.
1687    /// Unlike `restore_checkpoint`, this method does not restore from git.
1688    pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1689        let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
1690            return Task::ready(Err(anyhow!("not supported")));
1691        };
1692
1693        cx.spawn(async move |this, cx| {
1694            cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
1695            this.update(cx, |this, cx| {
1696                if let Some((ix, _)) = this.user_message_mut(&id) {
1697                    let range = ix..this.entries.len();
1698                    this.entries.truncate(ix);
1699                    cx.emit(AcpThreadEvent::EntriesRemoved(range));
1700                }
1701                this.action_log()
1702                    .update(cx, |action_log, cx| action_log.reject_all_edits(cx))
1703            })?
1704            .await;
1705            Ok(())
1706        })
1707    }
1708
1709    fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1710        let git_store = self.project.read(cx).git_store().clone();
1711
1712        let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1713            if let Some(checkpoint) = message.checkpoint.as_ref() {
1714                checkpoint.git_checkpoint.clone()
1715            } else {
1716                return Task::ready(Ok(()));
1717            }
1718        } else {
1719            return Task::ready(Ok(()));
1720        };
1721
1722        let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1723        cx.spawn(async move |this, cx| {
1724            let new_checkpoint = new_checkpoint
1725                .await
1726                .context("failed to get new checkpoint")
1727                .log_err();
1728            if let Some(new_checkpoint) = new_checkpoint {
1729                let equal = git_store
1730                    .update(cx, |git, cx| {
1731                        git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1732                    })?
1733                    .await
1734                    .unwrap_or(true);
1735                this.update(cx, |this, cx| {
1736                    let (ix, message) = this.last_user_message().context("no user message")?;
1737                    let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1738                    checkpoint.show = !equal;
1739                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
1740                    anyhow::Ok(())
1741                })??;
1742            }
1743
1744            Ok(())
1745        })
1746    }
1747
1748    fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1749        self.entries
1750            .iter_mut()
1751            .enumerate()
1752            .rev()
1753            .find_map(|(ix, entry)| {
1754                if let AgentThreadEntry::UserMessage(message) = entry {
1755                    Some((ix, message))
1756                } else {
1757                    None
1758                }
1759            })
1760    }
1761
1762    fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1763        self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1764            if let AgentThreadEntry::UserMessage(message) = entry {
1765                if message.id.as_ref() == Some(id) {
1766                    Some((ix, message))
1767                } else {
1768                    None
1769                }
1770            } else {
1771                None
1772            }
1773        })
1774    }
1775
1776    pub fn read_text_file(
1777        &self,
1778        path: PathBuf,
1779        line: Option<u32>,
1780        limit: Option<u32>,
1781        reuse_shared_snapshot: bool,
1782        cx: &mut Context<Self>,
1783    ) -> Task<Result<String>> {
1784        // Args are 1-based, move to 0-based
1785        let line = line.unwrap_or_default().saturating_sub(1);
1786        let limit = limit.unwrap_or(u32::MAX);
1787        let project = self.project.clone();
1788        let action_log = self.action_log.clone();
1789        cx.spawn(async move |this, cx| {
1790            let load = project.update(cx, |project, cx| {
1791                let path = project
1792                    .project_path_for_absolute_path(&path, cx)
1793                    .context("invalid path")?;
1794                anyhow::Ok(project.open_buffer(path, cx))
1795            });
1796            let buffer = load??.await?;
1797
1798            let snapshot = if reuse_shared_snapshot {
1799                this.read_with(cx, |this, _| {
1800                    this.shared_buffers.get(&buffer.clone()).cloned()
1801                })
1802                .log_err()
1803                .flatten()
1804            } else {
1805                None
1806            };
1807
1808            let snapshot = if let Some(snapshot) = snapshot {
1809                snapshot
1810            } else {
1811                action_log.update(cx, |action_log, cx| {
1812                    action_log.buffer_read(buffer.clone(), cx);
1813                })?;
1814
1815                let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot())?;
1816                this.update(cx, |this, _| {
1817                    this.shared_buffers.insert(buffer.clone(), snapshot.clone());
1818                })?;
1819                snapshot
1820            };
1821
1822            let max_point = snapshot.max_point();
1823            let start_position = Point::new(line, 0);
1824
1825            if start_position > max_point {
1826                anyhow::bail!(
1827                    "Attempting to read beyond the end of the file, line {}:{}",
1828                    max_point.row + 1,
1829                    max_point.column
1830                );
1831            }
1832
1833            let start = snapshot.anchor_before(start_position);
1834            let end = snapshot.anchor_before(Point::new(line.saturating_add(limit), 0));
1835
1836            project.update(cx, |project, cx| {
1837                project.set_agent_location(
1838                    Some(AgentLocation {
1839                        buffer: buffer.downgrade(),
1840                        position: start,
1841                    }),
1842                    cx,
1843                );
1844            })?;
1845
1846            Ok(snapshot.text_for_range(start..end).collect::<String>())
1847        })
1848    }
1849
    /// Replaces the contents of the file at `path` with `content` and saves it.
    ///
    /// The new text is diffed against the snapshot stored for this buffer in
    /// `shared_buffers` (falling back to the buffer's current snapshot). The
    /// resulting edits are anchored to that snapshot, so changes made to the
    /// buffer since then are preserved rather than overwritten. The action log
    /// is notified of the read and the edits, and the agent location moves to
    /// the end of the last edit. If the buffer's language settings have
    /// format-on-save enabled, the buffer is formatted before it is saved.
    ///
    /// # Errors
    /// Fails if the path cannot be mapped into the project ("invalid path"),
    /// the buffer cannot be opened, an entity update fails, or saving fails.
    /// Formatting errors are only logged.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Prefer the snapshot previously stored by `read_text_file`, so the
            // diff is relative to the text that was read.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Compute the diff on the background executor. Anchors are taken
            // against `snapshot` so they still resolve after concurrent edits.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        // End of the last edit, or the buffer start if nothing changed.
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::MIN),
                    }),
                    cx,
                );
            })?;

            // Record the read, apply the edits, then record the edit, all in a
            // single synchronous update; also report whether to format.
            let format_on_save = cx.update(|cx| {
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            })?;

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                })?;
                // A formatting failure is logged but does not prevent saving.
                format_task.await.log_err();

                // Report the buffer again so the action log also sees any
                // changes made by the formatter.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                })?;
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1946
1947    pub fn create_terminal(
1948        &self,
1949        command: String,
1950        args: Vec<String>,
1951        extra_env: Vec<acp::EnvVariable>,
1952        cwd: Option<PathBuf>,
1953        output_byte_limit: Option<u64>,
1954        cx: &mut Context<Self>,
1955    ) -> Task<Result<Entity<Terminal>>> {
1956        let env = match &cwd {
1957            Some(dir) => self.project.update(cx, |project, cx| {
1958                project.directory_environment(dir.as_path().into(), cx)
1959            }),
1960            None => Task::ready(None).shared(),
1961        };
1962
1963        let env = cx.spawn(async move |_, _| {
1964            let mut env = env.await.unwrap_or_default();
1965            if cfg!(unix) {
1966                env.insert("PAGER".into(), "cat".into());
1967            }
1968            for var in extra_env {
1969                env.insert(var.name, var.value);
1970            }
1971            env
1972        });
1973
1974        let project = self.project.clone();
1975        let language_registry = project.read(cx).languages().clone();
1976
1977        let terminal_id = acp::TerminalId(Uuid::new_v4().to_string().into());
1978        let terminal_task = cx.spawn({
1979            let terminal_id = terminal_id.clone();
1980            async move |_this, cx| {
1981                let env = env.await;
1982                let (command, args) = ShellBuilder::new(
1983                    project
1984                        .update(cx, |project, cx| {
1985                            project
1986                                .remote_client()
1987                                .and_then(|r| r.read(cx).default_system_shell())
1988                        })?
1989                        .as_deref(),
1990                    &Shell::Program(get_default_system_shell()),
1991                )
1992                .redirect_stdin_to_dev_null()
1993                .build(Some(command), &args);
1994                let terminal = project
1995                    .update(cx, |project, cx| {
1996                        project.create_terminal_task(
1997                            task::SpawnInTerminal {
1998                                command: Some(command.clone()),
1999                                args: args.clone(),
2000                                cwd: cwd.clone(),
2001                                env,
2002                                ..Default::default()
2003                            },
2004                            cx,
2005                        )
2006                    })?
2007                    .await?;
2008
2009                cx.new(|cx| {
2010                    Terminal::new(
2011                        terminal_id,
2012                        &format!("{} {}", command, args.join(" ")),
2013                        cwd,
2014                        output_byte_limit.map(|l| l as usize),
2015                        terminal,
2016                        language_registry,
2017                        cx,
2018                    )
2019                })
2020            }
2021        });
2022
2023        cx.spawn(async move |this, cx| {
2024            let terminal = terminal_task.await?;
2025            this.update(cx, |this, _cx| {
2026                this.terminals.insert(terminal_id, terminal.clone());
2027                terminal
2028            })
2029        })
2030    }
2031
2032    pub fn kill_terminal(
2033        &mut self,
2034        terminal_id: acp::TerminalId,
2035        cx: &mut Context<Self>,
2036    ) -> Result<()> {
2037        self.terminals
2038            .get(&terminal_id)
2039            .context("Terminal not found")?
2040            .update(cx, |terminal, cx| {
2041                terminal.kill(cx);
2042            });
2043
2044        Ok(())
2045    }
2046
2047    pub fn release_terminal(
2048        &mut self,
2049        terminal_id: acp::TerminalId,
2050        cx: &mut Context<Self>,
2051    ) -> Result<()> {
2052        self.terminals
2053            .remove(&terminal_id)
2054            .context("Terminal not found")?
2055            .update(cx, |terminal, cx| {
2056                terminal.kill(cx);
2057            });
2058
2059        Ok(())
2060    }
2061
2062    pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
2063        self.terminals
2064            .get(&terminal_id)
2065            .context("Terminal not found")
2066            .cloned()
2067    }
2068
2069    pub fn to_markdown(&self, cx: &App) -> String {
2070        self.entries.iter().map(|e| e.to_markdown(cx)).collect()
2071    }
2072
2073    pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
2074        cx.emit(AcpThreadEvent::LoadError(error));
2075    }
2076}
2077
2078fn markdown_for_raw_output(
2079    raw_output: &serde_json::Value,
2080    language_registry: &Arc<LanguageRegistry>,
2081    cx: &mut App,
2082) -> Option<Entity<Markdown>> {
2083    match raw_output {
2084        serde_json::Value::Null => None,
2085        serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2086            Markdown::new(
2087                value.to_string().into(),
2088                Some(language_registry.clone()),
2089                None,
2090                cx,
2091            )
2092        })),
2093        serde_json::Value::Number(value) => Some(cx.new(|cx| {
2094            Markdown::new(
2095                value.to_string().into(),
2096                Some(language_registry.clone()),
2097                None,
2098                cx,
2099            )
2100        })),
2101        serde_json::Value::String(value) => Some(cx.new(|cx| {
2102            Markdown::new(
2103                value.clone().into(),
2104                Some(language_registry.clone()),
2105                None,
2106                cx,
2107            )
2108        })),
2109        value => Some(cx.new(|cx| {
2110            Markdown::new(
2111                format!("```json\n{}\n```", value).into(),
2112                Some(language_registry.clone()),
2113                None,
2114                cx,
2115            )
2116        })),
2117    }
2118}
2119
2120#[cfg(test)]
2121mod tests {
2122    use super::*;
2123    use anyhow::anyhow;
2124    use futures::{channel::mpsc, future::LocalBoxFuture, select};
2125    use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2126    use indoc::indoc;
2127    use project::{FakeFs, Fs};
2128    use rand::{distr, prelude::*};
2129    use serde_json::json;
2130    use settings::SettingsStore;
2131    use smol::stream::StreamExt as _;
2132    use std::{
2133        any::Any,
2134        cell::RefCell,
2135        path::Path,
2136        rc::Rc,
2137        sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2138        time::Duration,
2139    };
2140    use util::path;
2141
2142    fn init_test(cx: &mut TestAppContext) {
2143        env_logger::try_init().ok();
2144        cx.update(|cx| {
2145            let settings_store = SettingsStore::test(cx);
2146            cx.set_global(settings_store);
2147            Project::init_settings(cx);
2148            language::init(cx);
2149        });
2150    }
2151
    // Verifies how user content blocks are grouped into entries:
    // - the first block creates a user message with no id;
    // - a following block is appended to that trailing message, which adopts
    //   the id supplied with it;
    // - a block pushed after an assistant message starts a new user entry.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                None,
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Hello, ".to_string(),
                    meta: None,
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, None);
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        let message_1_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_1_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "world!".to_string(),
                    meta: None,
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, Some(message_1_id));
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Assistant response".to_string(),
                    meta: None,
                }),
                false,
                cx,
            );
        });

        let message_2_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_2_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "New user message".to_string(),
                    meta: None,
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.id, Some(message_2_id));
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }
2247
    // Two consecutive `AgentThoughtChunk` session updates must be merged into a
    // single `<thinking>` section in the thread's Markdown export.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        // The fake agent answers every prompt with two thought chunks.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "Thinking ".into(),
                                },
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "hard!".into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                        meta: None,
                    })
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
2311
    // The agent reads /tmp/foo and then writes an extended version of it. In
    // between, once the read has completed, the user inserts a line at the top
    // of the same buffer. Because `write_text_file` diffs against the snapshot
    // stored by the read, the user's edit must survive in both the buffer and
    // the file on disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test body once the agent's read has finished.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                        meta: None,
                    })
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        // The user's concurrent edit, made after the agent's read.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
2391
    // Exercises `read_text_file`'s 1-based `line` and `limit` arguments on a
    // four-line file, plus the error reported when the start line is past the
    // end of the file.
    #[gpui::test]
    async fn test_reading_from_line(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\nfour\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();

        let connection = Rc::new(FakeAgentConnection::new());

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        // Whole file
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "one\ntwo\nthree\nfour\n");

        // Only start line
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(3), None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "three\nfour\n");

        // Only limit
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "one\ntwo\n");

        // Range
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(2), Some(2), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "two\nthree\n");

        // Invalid: the error names the file's last position (row 5, column 0,
        // after the trailing newline).
        let err = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(6), Some(2), false, cx)
            })
            .await
            .unwrap_err();

        assert_eq!(
            err.to_string(),
            "Attempting to read beyond the end of the file, line 5:0"
        );
    }
2467
    // Runs the same `read_text_file` cases against an empty file. Every
    // in-range read returns "", and starting past the end reports position 1:0.
    #[gpui::test]
    async fn test_reading_empty_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": ""})).await;
        let project = Project::test(fs.clone(), [], cx).await;
        project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();

        let connection = Rc::new(FakeAgentConnection::new());

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        // Whole file
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Only start line
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(1), None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Only limit
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Range
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(1), Some(1), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "");

        // Invalid
        let err = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(5), Some(2), false, cx)
            })
            .await
            .unwrap_err();

        assert_eq!(
            err.to_string(),
            "Attempting to read beyond the end of the file, line 1:0"
        );
    }
2542
2543    #[gpui::test]
2544    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
2545        init_test(cx);
2546
2547        let fs = FakeFs::new(cx.executor());
2548        let project = Project::test(fs, [], cx).await;
2549        let id = acp::ToolCallId("test".into());
2550
2551        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2552            let id = id.clone();
2553            move |_, thread, mut cx| {
2554                let id = id.clone();
2555                async move {
2556                    thread
2557                        .update(&mut cx, |thread, cx| {
2558                            thread.handle_session_update(
2559                                acp::SessionUpdate::ToolCall(acp::ToolCall {
2560                                    id: id.clone(),
2561                                    title: "Label".into(),
2562                                    kind: acp::ToolKind::Fetch,
2563                                    status: acp::ToolCallStatus::InProgress,
2564                                    content: vec![],
2565                                    locations: vec![],
2566                                    raw_input: None,
2567                                    raw_output: None,
2568                                    meta: None,
2569                                }),
2570                                cx,
2571                            )
2572                        })
2573                        .unwrap()
2574                        .unwrap();
2575                    Ok(acp::PromptResponse {
2576                        stop_reason: acp::StopReason::EndTurn,
2577                        meta: None,
2578                    })
2579                }
2580                .boxed_local()
2581            }
2582        }));
2583
2584        let thread = cx
2585            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2586            .await
2587            .unwrap();
2588
2589        let request = thread.update(cx, |thread, cx| {
2590            thread.send_raw("Fetch https://example.com", cx)
2591        });
2592
2593        run_until_first_tool_call(&thread, cx).await;
2594
2595        thread.read_with(cx, |thread, _| {
2596            assert!(matches!(
2597                thread.entries[1],
2598                AgentThreadEntry::ToolCall(ToolCall {
2599                    status: ToolCallStatus::InProgress,
2600                    ..
2601                })
2602            ));
2603        });
2604
2605        thread.update(cx, |thread, cx| thread.cancel(cx)).await;
2606
2607        thread.read_with(cx, |thread, _| {
2608            assert!(matches!(
2609                &thread.entries[1],
2610                AgentThreadEntry::ToolCall(ToolCall {
2611                    status: ToolCallStatus::Canceled,
2612                    ..
2613                })
2614            ));
2615        });
2616
2617        thread
2618            .update(cx, |thread, cx| {
2619                thread.handle_session_update(
2620                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
2621                        id,
2622                        fields: acp::ToolCallUpdateFields {
2623                            status: Some(acp::ToolCallStatus::Completed),
2624                            ..Default::default()
2625                        },
2626                        meta: None,
2627                    }),
2628                    cx,
2629                )
2630            })
2631            .unwrap();
2632
2633        request.await.unwrap();
2634
2635        thread.read_with(cx, |thread, _| {
2636            assert!(matches!(
2637                thread.entries[1],
2638                AgentThreadEntry::ToolCall(ToolCall {
2639                    status: ToolCallStatus::Completed,
2640                    ..
2641                })
2642            ));
2643        });
2644    }
2645
2646    #[gpui::test]
2647    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
2648        init_test(cx);
2649        let fs = FakeFs::new(cx.background_executor.clone());
2650        fs.insert_tree(path!("/test"), json!({})).await;
2651        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
2652
2653        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2654            move |_, thread, mut cx| {
2655                async move {
2656                    thread
2657                        .update(&mut cx, |thread, cx| {
2658                            thread.handle_session_update(
2659                                acp::SessionUpdate::ToolCall(acp::ToolCall {
2660                                    id: acp::ToolCallId("test".into()),
2661                                    title: "Label".into(),
2662                                    kind: acp::ToolKind::Edit,
2663                                    status: acp::ToolCallStatus::Completed,
2664                                    content: vec![acp::ToolCallContent::Diff {
2665                                        diff: acp::Diff {
2666                                            path: "/test/test.txt".into(),
2667                                            old_text: None,
2668                                            new_text: "foo".into(),
2669                                            meta: None,
2670                                        },
2671                                    }],
2672                                    locations: vec![],
2673                                    raw_input: None,
2674                                    raw_output: None,
2675                                    meta: None,
2676                                }),
2677                                cx,
2678                            )
2679                        })
2680                        .unwrap()
2681                        .unwrap();
2682                    Ok(acp::PromptResponse {
2683                        stop_reason: acp::StopReason::EndTurn,
2684                        meta: None,
2685                    })
2686                }
2687                .boxed_local()
2688            }
2689        }));
2690
2691        let thread = cx
2692            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2693            .await
2694            .unwrap();
2695
2696        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
2697            .await
2698            .unwrap();
2699
2700        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
2701    }
2702
2703    #[gpui::test(iterations = 10)]
2704    async fn test_checkpoints(cx: &mut TestAppContext) {
2705        init_test(cx);
2706        let fs = FakeFs::new(cx.background_executor.clone());
2707        fs.insert_tree(
2708            path!("/test"),
2709            json!({
2710                ".git": {}
2711            }),
2712        )
2713        .await;
2714        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
2715
2716        let simulate_changes = Arc::new(AtomicBool::new(true));
2717        let next_filename = Arc::new(AtomicUsize::new(0));
2718        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2719            let simulate_changes = simulate_changes.clone();
2720            let next_filename = next_filename.clone();
2721            let fs = fs.clone();
2722            move |request, thread, mut cx| {
2723                let fs = fs.clone();
2724                let simulate_changes = simulate_changes.clone();
2725                let next_filename = next_filename.clone();
2726                async move {
2727                    if simulate_changes.load(SeqCst) {
2728                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
2729                        fs.write(Path::new(&filename), b"").await?;
2730                    }
2731
2732                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2733                        panic!("expected text content block");
2734                    };
2735                    thread.update(&mut cx, |thread, cx| {
2736                        thread
2737                            .handle_session_update(
2738                                acp::SessionUpdate::AgentMessageChunk {
2739                                    content: content.text.to_uppercase().into(),
2740                                },
2741                                cx,
2742                            )
2743                            .unwrap();
2744                    })?;
2745                    Ok(acp::PromptResponse {
2746                        stop_reason: acp::StopReason::EndTurn,
2747                        meta: None,
2748                    })
2749                }
2750                .boxed_local()
2751            }
2752        }));
2753        let thread = cx
2754            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2755            .await
2756            .unwrap();
2757
2758        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
2759            .await
2760            .unwrap();
2761        thread.read_with(cx, |thread, cx| {
2762            assert_eq!(
2763                thread.to_markdown(cx),
2764                indoc! {"
2765                    ## User (checkpoint)
2766
2767                    Lorem
2768
2769                    ## Assistant
2770
2771                    LOREM
2772
2773                "}
2774            );
2775        });
2776        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2777
2778        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
2779            .await
2780            .unwrap();
2781        thread.read_with(cx, |thread, cx| {
2782            assert_eq!(
2783                thread.to_markdown(cx),
2784                indoc! {"
2785                    ## User (checkpoint)
2786
2787                    Lorem
2788
2789                    ## Assistant
2790
2791                    LOREM
2792
2793                    ## User (checkpoint)
2794
2795                    ipsum
2796
2797                    ## Assistant
2798
2799                    IPSUM
2800
2801                "}
2802            );
2803        });
2804        assert_eq!(
2805            fs.files(),
2806            vec![
2807                Path::new(path!("/test/file-0")),
2808                Path::new(path!("/test/file-1"))
2809            ]
2810        );
2811
2812        // Checkpoint isn't stored when there are no changes.
2813        simulate_changes.store(false, SeqCst);
2814        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
2815            .await
2816            .unwrap();
2817        thread.read_with(cx, |thread, cx| {
2818            assert_eq!(
2819                thread.to_markdown(cx),
2820                indoc! {"
2821                    ## User (checkpoint)
2822
2823                    Lorem
2824
2825                    ## Assistant
2826
2827                    LOREM
2828
2829                    ## User (checkpoint)
2830
2831                    ipsum
2832
2833                    ## Assistant
2834
2835                    IPSUM
2836
2837                    ## User
2838
2839                    dolor
2840
2841                    ## Assistant
2842
2843                    DOLOR
2844
2845                "}
2846            );
2847        });
2848        assert_eq!(
2849            fs.files(),
2850            vec![
2851                Path::new(path!("/test/file-0")),
2852                Path::new(path!("/test/file-1"))
2853            ]
2854        );
2855
2856        // Rewinding the conversation truncates the history and restores the checkpoint.
2857        thread
2858            .update(cx, |thread, cx| {
2859                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
2860                    panic!("unexpected entries {:?}", thread.entries)
2861                };
2862                thread.restore_checkpoint(message.id.clone().unwrap(), cx)
2863            })
2864            .await
2865            .unwrap();
2866        thread.read_with(cx, |thread, cx| {
2867            assert_eq!(
2868                thread.to_markdown(cx),
2869                indoc! {"
2870                    ## User (checkpoint)
2871
2872                    Lorem
2873
2874                    ## Assistant
2875
2876                    LOREM
2877
2878                "}
2879            );
2880        });
2881        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2882    }
2883
2884    #[gpui::test]
2885    async fn test_tool_result_refusal(cx: &mut TestAppContext) {
2886        use std::sync::atomic::AtomicUsize;
2887        init_test(cx);
2888
2889        let fs = FakeFs::new(cx.executor());
2890        let project = Project::test(fs, None, cx).await;
2891
2892        // Create a connection that simulates refusal after tool result
2893        let prompt_count = Arc::new(AtomicUsize::new(0));
2894        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2895            let prompt_count = prompt_count.clone();
2896            move |_request, thread, mut cx| {
2897                let count = prompt_count.fetch_add(1, SeqCst);
2898                async move {
2899                    if count == 0 {
2900                        // First prompt: Generate a tool call with result
2901                        thread.update(&mut cx, |thread, cx| {
2902                            thread
2903                                .handle_session_update(
2904                                    acp::SessionUpdate::ToolCall(acp::ToolCall {
2905                                        id: acp::ToolCallId("tool1".into()),
2906                                        title: "Test Tool".into(),
2907                                        kind: acp::ToolKind::Fetch,
2908                                        status: acp::ToolCallStatus::Completed,
2909                                        content: vec![],
2910                                        locations: vec![],
2911                                        raw_input: Some(serde_json::json!({"query": "test"})),
2912                                        raw_output: Some(
2913                                            serde_json::json!({"result": "inappropriate content"}),
2914                                        ),
2915                                        meta: None,
2916                                    }),
2917                                    cx,
2918                                )
2919                                .unwrap();
2920                        })?;
2921
2922                        // Now return refusal because of the tool result
2923                        Ok(acp::PromptResponse {
2924                            stop_reason: acp::StopReason::Refusal,
2925                            meta: None,
2926                        })
2927                    } else {
2928                        Ok(acp::PromptResponse {
2929                            stop_reason: acp::StopReason::EndTurn,
2930                            meta: None,
2931                        })
2932                    }
2933                }
2934                .boxed_local()
2935            }
2936        }));
2937
2938        let thread = cx
2939            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2940            .await
2941            .unwrap();
2942
2943        // Track if we see a Refusal event
2944        let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
2945        let saw_refusal_event_captured = saw_refusal_event.clone();
2946        thread.update(cx, |_thread, cx| {
2947            cx.subscribe(
2948                &thread,
2949                move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
2950                    if matches!(event, AcpThreadEvent::Refusal) {
2951                        *saw_refusal_event_captured.lock().unwrap() = true;
2952                    }
2953                },
2954            )
2955            .detach();
2956        });
2957
2958        // Send a user message - this will trigger tool call and then refusal
2959        let send_task = thread.update(cx, |thread, cx| {
2960            thread.send(
2961                vec![acp::ContentBlock::Text(acp::TextContent {
2962                    text: "Hello".into(),
2963                    annotations: None,
2964                    meta: None,
2965                })],
2966                cx,
2967            )
2968        });
2969        cx.background_executor.spawn(send_task).detach();
2970        cx.run_until_parked();
2971
2972        // Verify that:
2973        // 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt)
2974        // 2. The user message was NOT truncated
2975        assert!(
2976            *saw_refusal_event.lock().unwrap(),
2977            "Refusal event should be emitted for tool result refusals"
2978        );
2979
2980        thread.read_with(cx, |thread, _| {
2981            let entries = thread.entries();
2982            assert!(entries.len() >= 2, "Should have user message and tool call");
2983
2984            // Verify user message is still there
2985            assert!(
2986                matches!(entries[0], AgentThreadEntry::UserMessage(_)),
2987                "User message should not be truncated"
2988            );
2989
2990            // Verify tool call is there with result
2991            if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] {
2992                assert!(
2993                    tool_call.raw_output.is_some(),
2994                    "Tool call should have output"
2995                );
2996            } else {
2997                panic!("Expected tool call at index 1");
2998            }
2999        });
3000    }
3001
3002    #[gpui::test]
3003    async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) {
3004        init_test(cx);
3005
3006        let fs = FakeFs::new(cx.executor());
3007        let project = Project::test(fs, None, cx).await;
3008
3009        let refuse_next = Arc::new(AtomicBool::new(false));
3010        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3011            let refuse_next = refuse_next.clone();
3012            move |_request, _thread, _cx| {
3013                if refuse_next.load(SeqCst) {
3014                    async move {
3015                        Ok(acp::PromptResponse {
3016                            stop_reason: acp::StopReason::Refusal,
3017                            meta: None,
3018                        })
3019                    }
3020                    .boxed_local()
3021                } else {
3022                    async move {
3023                        Ok(acp::PromptResponse {
3024                            stop_reason: acp::StopReason::EndTurn,
3025                            meta: None,
3026                        })
3027                    }
3028                    .boxed_local()
3029                }
3030            }
3031        }));
3032
3033        let thread = cx
3034            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3035            .await
3036            .unwrap();
3037
3038        // Track if we see a Refusal event
3039        let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3040        let saw_refusal_event_captured = saw_refusal_event.clone();
3041        thread.update(cx, |_thread, cx| {
3042            cx.subscribe(
3043                &thread,
3044                move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3045                    if matches!(event, AcpThreadEvent::Refusal) {
3046                        *saw_refusal_event_captured.lock().unwrap() = true;
3047                    }
3048                },
3049            )
3050            .detach();
3051        });
3052
3053        // Send a message that will be refused
3054        refuse_next.store(true, SeqCst);
3055        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3056            .await
3057            .unwrap();
3058
3059        // Verify that a Refusal event WAS emitted for user prompt refusal
3060        assert!(
3061            *saw_refusal_event.lock().unwrap(),
3062            "Refusal event should be emitted for user prompt refusals"
3063        );
3064
3065        // Verify the message was truncated (user prompt refusal)
3066        thread.read_with(cx, |thread, cx| {
3067            assert_eq!(thread.to_markdown(cx), "");
3068        });
3069    }
3070
3071    #[gpui::test]
3072    async fn test_refusal(cx: &mut TestAppContext) {
3073        init_test(cx);
3074        let fs = FakeFs::new(cx.background_executor.clone());
3075        fs.insert_tree(path!("/"), json!({})).await;
3076        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
3077
3078        let refuse_next = Arc::new(AtomicBool::new(false));
3079        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3080            let refuse_next = refuse_next.clone();
3081            move |request, thread, mut cx| {
3082                let refuse_next = refuse_next.clone();
3083                async move {
3084                    if refuse_next.load(SeqCst) {
3085                        return Ok(acp::PromptResponse {
3086                            stop_reason: acp::StopReason::Refusal,
3087                            meta: None,
3088                        });
3089                    }
3090
3091                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
3092                        panic!("expected text content block");
3093                    };
3094                    thread.update(&mut cx, |thread, cx| {
3095                        thread
3096                            .handle_session_update(
3097                                acp::SessionUpdate::AgentMessageChunk {
3098                                    content: content.text.to_uppercase().into(),
3099                                },
3100                                cx,
3101                            )
3102                            .unwrap();
3103                    })?;
3104                    Ok(acp::PromptResponse {
3105                        stop_reason: acp::StopReason::EndTurn,
3106                        meta: None,
3107                    })
3108                }
3109                .boxed_local()
3110            }
3111        }));
3112        let thread = cx
3113            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3114            .await
3115            .unwrap();
3116
3117        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3118            .await
3119            .unwrap();
3120        thread.read_with(cx, |thread, cx| {
3121            assert_eq!(
3122                thread.to_markdown(cx),
3123                indoc! {"
3124                    ## User
3125
3126                    hello
3127
3128                    ## Assistant
3129
3130                    HELLO
3131
3132                "}
3133            );
3134        });
3135
3136        // Simulate refusing the second message. The message should be truncated
3137        // when a user prompt is refused.
3138        refuse_next.store(true, SeqCst);
3139        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
3140            .await
3141            .unwrap();
3142        thread.read_with(cx, |thread, cx| {
3143            assert_eq!(
3144                thread.to_markdown(cx),
3145                indoc! {"
3146                    ## User
3147
3148                    hello
3149
3150                    ## Assistant
3151
3152                    HELLO
3153
3154                "}
3155            );
3156        });
3157    }
3158
3159    async fn run_until_first_tool_call(
3160        thread: &Entity<AcpThread>,
3161        cx: &mut TestAppContext,
3162    ) -> usize {
3163        let (mut tx, mut rx) = mpsc::channel::<usize>(1);
3164
3165        let subscription = cx.update(|cx| {
3166            cx.subscribe(thread, move |thread, _, cx| {
3167                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
3168                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
3169                        return tx.try_send(ix).unwrap();
3170                    }
3171                }
3172            })
3173        });
3174
3175        select! {
3176            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
3177                panic!("Timeout waiting for tool call")
3178            }
3179            ix = rx.next().fuse() => {
3180                drop(subscription);
3181                ix.unwrap()
3182            }
3183        }
3184    }
3185
3186    #[derive(Clone, Default)]
3187    struct FakeAgentConnection {
3188        auth_methods: Vec<acp::AuthMethod>,
3189        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
3190        on_user_message: Option<
3191            Rc<
3192                dyn Fn(
3193                        acp::PromptRequest,
3194                        WeakEntity<AcpThread>,
3195                        AsyncApp,
3196                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3197                    + 'static,
3198            >,
3199        >,
3200    }
3201
3202    impl FakeAgentConnection {
3203        fn new() -> Self {
3204            Self {
3205                auth_methods: Vec::new(),
3206                on_user_message: None,
3207                sessions: Arc::default(),
3208            }
3209        }
3210
3211        #[expect(unused)]
3212        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
3213            self.auth_methods = auth_methods;
3214            self
3215        }
3216
3217        fn on_user_message(
3218            mut self,
3219            handler: impl Fn(
3220                acp::PromptRequest,
3221                WeakEntity<AcpThread>,
3222                AsyncApp,
3223            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3224            + 'static,
3225        ) -> Self {
3226            self.on_user_message.replace(Rc::new(handler));
3227            self
3228        }
3229    }
3230
3231    impl AgentConnection for FakeAgentConnection {
3232        fn auth_methods(&self) -> &[acp::AuthMethod] {
3233            &self.auth_methods
3234        }
3235
3236        fn new_thread(
3237            self: Rc<Self>,
3238            project: Entity<Project>,
3239            _cwd: &Path,
3240            cx: &mut App,
3241        ) -> Task<gpui::Result<Entity<AcpThread>>> {
3242            let session_id = acp::SessionId(
3243                rand::rng()
3244                    .sample_iter(&distr::Alphanumeric)
3245                    .take(7)
3246                    .map(char::from)
3247                    .collect::<String>()
3248                    .into(),
3249            );
3250            let action_log = cx.new(|_| ActionLog::new(project.clone()));
3251            let thread = cx.new(|cx| {
3252                AcpThread::new(
3253                    "Test",
3254                    self.clone(),
3255                    project,
3256                    action_log,
3257                    session_id.clone(),
3258                    watch::Receiver::constant(acp::PromptCapabilities {
3259                        image: true,
3260                        audio: true,
3261                        embedded_context: true,
3262                        meta: None,
3263                    }),
3264                    cx,
3265                )
3266            });
3267            self.sessions.lock().insert(session_id, thread.downgrade());
3268            Task::ready(Ok(thread))
3269        }
3270
3271        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
3272            if self.auth_methods().iter().any(|m| m.id == method) {
3273                Task::ready(Ok(()))
3274            } else {
3275                Task::ready(Err(anyhow!("Invalid Auth Method")))
3276            }
3277        }
3278
3279        fn prompt(
3280            &self,
3281            _id: Option<UserMessageId>,
3282            params: acp::PromptRequest,
3283            cx: &mut App,
3284        ) -> Task<gpui::Result<acp::PromptResponse>> {
3285            let sessions = self.sessions.lock();
3286            let thread = sessions.get(&params.session_id).unwrap();
3287            if let Some(handler) = &self.on_user_message {
3288                let handler = handler.clone();
3289                let thread = thread.clone();
3290                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
3291            } else {
3292                Task::ready(Ok(acp::PromptResponse {
3293                    stop_reason: acp::StopReason::EndTurn,
3294                    meta: None,
3295                }))
3296            }
3297        }
3298
3299        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
3300            let sessions = self.sessions.lock();
3301            let thread = sessions.get(session_id).unwrap().clone();
3302
3303            cx.spawn(async move |cx| {
3304                thread
3305                    .update(cx, |thread, cx| thread.cancel(cx))
3306                    .unwrap()
3307                    .await
3308            })
3309            .detach();
3310        }
3311
3312        fn truncate(
3313            &self,
3314            session_id: &acp::SessionId,
3315            _cx: &App,
3316        ) -> Option<Rc<dyn AgentSessionTruncate>> {
3317            Some(Rc::new(FakeAgentSessionEditor {
3318                _session_id: session_id.clone(),
3319            }))
3320        }
3321
3322        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
3323            self
3324        }
3325    }
3326
3327    struct FakeAgentSessionEditor {
3328        _session_id: acp::SessionId,
3329    }
3330
3331    impl AgentSessionTruncate for FakeAgentSessionEditor {
3332        fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
3333            Task::ready(Ok(()))
3334        }
3335    }
3336
3337    #[gpui::test]
3338    async fn test_tool_call_not_found_creates_failed_entry(cx: &mut TestAppContext) {
3339        init_test(cx);
3340
3341        let fs = FakeFs::new(cx.executor());
3342        let project = Project::test(fs, [], cx).await;
3343        let connection = Rc::new(FakeAgentConnection::new());
3344        let thread = cx
3345            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3346            .await
3347            .unwrap();
3348
3349        // Try to update a tool call that doesn't exist
3350        let nonexistent_id = acp::ToolCallId("nonexistent-tool-call".into());
3351        thread.update(cx, |thread, cx| {
3352            let result = thread.handle_session_update(
3353                acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
3354                    id: nonexistent_id.clone(),
3355                    fields: acp::ToolCallUpdateFields {
3356                        status: Some(acp::ToolCallStatus::Completed),
3357                        ..Default::default()
3358                    },
3359                    meta: None,
3360                }),
3361                cx,
3362            );
3363
3364            // The update should succeed (not return an error)
3365            assert!(result.is_ok());
3366
3367            // There should now be exactly one entry in the thread
3368            assert_eq!(thread.entries.len(), 1);
3369
3370            // The entry should be a failed tool call
3371            if let AgentThreadEntry::ToolCall(tool_call) = &thread.entries[0] {
3372                assert_eq!(tool_call.id, nonexistent_id);
3373                assert!(matches!(tool_call.status, ToolCallStatus::Failed));
3374                assert_eq!(tool_call.kind, acp::ToolKind::Fetch);
3375
3376                // Check that the content contains the error message
3377                assert_eq!(tool_call.content.len(), 1);
3378                if let ToolCallContent::ContentBlock(content_block) = &tool_call.content[0] {
3379                    match content_block {
3380                        ContentBlock::Markdown { markdown } => {
3381                            let markdown_text = markdown.read(cx).source();
3382                            assert!(markdown_text.contains("Tool call not found"));
3383                        }
3384                        ContentBlock::Empty => panic!("Expected markdown content, got empty"),
3385                        ContentBlock::ResourceLink { .. } => {
3386                            panic!("Expected markdown content, got resource link")
3387                        }
3388                    }
3389                } else {
3390                    panic!("Expected ContentBlock, got: {:?}", tool_call.content[0]);
3391                }
3392            } else {
3393                panic!("Expected ToolCall entry, got: {:?}", thread.entries[0]);
3394            }
3395        });
3396    }
3397}