acp_thread.rs

   1mod connection;
   2mod diff;
   3mod mention;
   4mod terminal;
   5
   6use ::terminal::TerminalBuilder;
   7use ::terminal::terminal_settings::{AlternateScroll, CursorShape};
   8use agent_settings::AgentSettings;
   9use collections::HashSet;
  10pub use connection::*;
  11pub use diff::*;
  12use language::language_settings::FormatOnSave;
  13pub use mention::*;
  14use project::lsp_store::{FormatTrigger, LspFormatTarget};
  15use serde::{Deserialize, Serialize};
  16use settings::Settings as _;
  17use task::{Shell, ShellBuilder};
  18pub use terminal::*;
  19
  20use action_log::ActionLog;
  21use agent_client_protocol::{self as acp};
  22use anyhow::{Context as _, Result, anyhow};
  23use editor::Bias;
  24use futures::{FutureExt, channel::oneshot, future::BoxFuture};
  25use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  26use itertools::Itertools;
  27use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
  28use markdown::Markdown;
  29use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
  30use std::collections::HashMap;
  31use std::error::Error;
  32use std::fmt::{Formatter, Write};
  33use std::ops::Range;
  34use std::process::ExitStatus;
  35use std::rc::Rc;
  36use std::time::{Duration, Instant};
  37use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
  38use ui::App;
  39use util::{ResultExt, get_default_system_shell};
  40use uuid::Uuid;
  41
/// A user-authored entry in a thread.
#[derive(Debug)]
pub struct UserMessage {
    /// Identifier of the message, when one is known.
    pub id: Option<UserMessageId>,
    /// Content rendered by `to_markdown` when the message is exported.
    pub content: ContentBlock,
    /// Protocol-level content blocks stored alongside `content`.
    pub chunks: Vec<acp::ContentBlock>,
    /// Checkpoint attached to this message, if any.
    pub checkpoint: Option<Checkpoint>,
}
  49
/// A git checkpoint attached to a [`UserMessage`].
#[derive(Debug)]
pub struct Checkpoint {
    /// The underlying git store checkpoint.
    git_checkpoint: GitStoreCheckpoint,
    /// When `true`, the owning message's markdown heading is rendered as
    /// `## User (checkpoint)` instead of `## User`.
    pub show: bool,
}
  55
  56impl UserMessage {
  57    fn to_markdown(&self, cx: &App) -> String {
  58        let mut markdown = String::new();
  59        if self
  60            .checkpoint
  61            .as_ref()
  62            .is_some_and(|checkpoint| checkpoint.show)
  63        {
  64            writeln!(markdown, "## User (checkpoint)").unwrap();
  65        } else {
  66            writeln!(markdown, "## User").unwrap();
  67        }
  68        writeln!(markdown).unwrap();
  69        writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
  70        writeln!(markdown).unwrap();
  71        markdown
  72    }
  73}
  74
/// An assistant-authored entry in a thread, stored as an ordered list of chunks.
#[derive(Debug, PartialEq)]
pub struct AssistantMessage {
    /// The message and thought chunks, in the order they are rendered.
    pub chunks: Vec<AssistantMessageChunk>,
}
  79
  80impl AssistantMessage {
  81    pub fn to_markdown(&self, cx: &App) -> String {
  82        format!(
  83            "## Assistant\n\n{}\n\n",
  84            self.chunks
  85                .iter()
  86                .map(|chunk| chunk.to_markdown(cx))
  87                .join("\n\n")
  88        )
  89    }
  90}
  91
/// A single piece of an [`AssistantMessage`].
#[derive(Debug, PartialEq)]
pub enum AssistantMessageChunk {
    /// Regular message content; rendered as-is in markdown.
    Message { block: ContentBlock },
    /// Thought content; rendered wrapped in `<thinking>` tags in markdown.
    Thought { block: ContentBlock },
}
  97
  98impl AssistantMessageChunk {
  99    pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
 100        Self::Message {
 101            block: ContentBlock::new(chunk.into(), language_registry, cx),
 102        }
 103    }
 104
 105    fn to_markdown(&self, cx: &App) -> String {
 106        match self {
 107            Self::Message { block } => block.to_markdown(cx).to_string(),
 108            Self::Thought { block } => {
 109                format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
 110            }
 111        }
 112    }
 113}
 114
/// One entry in a thread: a user message, an assistant message, or a tool call.
#[derive(Debug)]
pub enum AgentThreadEntry {
    /// A message written by the user.
    UserMessage(UserMessage),
    /// A message produced by the assistant.
    AssistantMessage(AssistantMessage),
    /// A tool invocation together with its content and status.
    ToolCall(ToolCall),
}
 121
 122impl AgentThreadEntry {
 123    pub fn to_markdown(&self, cx: &App) -> String {
 124        match self {
 125            Self::UserMessage(message) => message.to_markdown(cx),
 126            Self::AssistantMessage(message) => message.to_markdown(cx),
 127            Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
 128        }
 129    }
 130
 131    pub fn user_message(&self) -> Option<&UserMessage> {
 132        if let AgentThreadEntry::UserMessage(message) = self {
 133            Some(message)
 134        } else {
 135            None
 136        }
 137    }
 138
 139    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 140        if let AgentThreadEntry::ToolCall(call) = self {
 141            itertools::Either::Left(call.diffs())
 142        } else {
 143            itertools::Either::Right(std::iter::empty())
 144        }
 145    }
 146
 147    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 148        if let AgentThreadEntry::ToolCall(call) = self {
 149            itertools::Either::Left(call.terminals())
 150        } else {
 151            itertools::Either::Right(std::iter::empty())
 152        }
 153    }
 154
 155    pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
 156        if let AgentThreadEntry::ToolCall(ToolCall {
 157            locations,
 158            resolved_locations,
 159            ..
 160        }) = self
 161        {
 162            Some((
 163                locations.get(ix)?.clone(),
 164                resolved_locations.get(ix)?.clone()?,
 165            ))
 166        } else {
 167            None
 168        }
 169    }
 170}
 171
/// A tool invocation tracked by the thread.
#[derive(Debug)]
pub struct ToolCall {
    /// Protocol identifier of the tool call.
    pub id: acp::ToolCallId,
    /// Markdown entity built from the tool call's title.
    pub label: Entity<Markdown>,
    /// Kind of tool, as reported by the protocol.
    pub kind: acp::ToolKind,
    /// Content items (content blocks, diffs, terminals) attached to the call.
    pub content: Vec<ToolCallContent>,
    /// Current status of the call.
    pub status: ToolCallStatus,
    /// Locations reported for the call.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Resolved form of `locations`; looked up by the same index in
    /// `AgentThreadEntry::location`, with `None` for unresolved entries.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    /// Raw input JSON, when provided.
    pub raw_input: Option<serde_json::Value>,
    /// Raw output JSON, when provided.
    pub raw_output: Option<serde_json::Value>,
}
 184
 185impl ToolCall {
 186    fn from_acp(
 187        tool_call: acp::ToolCall,
 188        status: ToolCallStatus,
 189        language_registry: Arc<LanguageRegistry>,
 190        terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
 191        cx: &mut App,
 192    ) -> Result<Self> {
 193        let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
 194            first_line.to_owned() + ""
 195        } else {
 196            tool_call.title
 197        };
 198        let mut content = Vec::with_capacity(tool_call.content.len());
 199        for item in tool_call.content {
 200            content.push(ToolCallContent::from_acp(
 201                item,
 202                language_registry.clone(),
 203                terminals,
 204                cx,
 205            )?);
 206        }
 207
 208        let result = Self {
 209            id: tool_call.id,
 210            label: cx
 211                .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
 212            kind: tool_call.kind,
 213            content,
 214            locations: tool_call.locations,
 215            resolved_locations: Vec::default(),
 216            status,
 217            raw_input: tool_call.raw_input,
 218            raw_output: tool_call.raw_output,
 219        };
 220        Ok(result)
 221    }
 222
 223    fn update_fields(
 224        &mut self,
 225        fields: acp::ToolCallUpdateFields,
 226        language_registry: Arc<LanguageRegistry>,
 227        terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
 228        cx: &mut App,
 229    ) -> Result<()> {
 230        let acp::ToolCallUpdateFields {
 231            kind,
 232            status,
 233            title,
 234            content,
 235            locations,
 236            raw_input,
 237            raw_output,
 238        } = fields;
 239
 240        if let Some(kind) = kind {
 241            self.kind = kind;
 242        }
 243
 244        if let Some(status) = status {
 245            self.status = status.into();
 246        }
 247
 248        if let Some(title) = title {
 249            self.label.update(cx, |label, cx| {
 250                if let Some((first_line, _)) = title.split_once("\n") {
 251                    label.replace(first_line.to_owned() + "", cx)
 252                } else {
 253                    label.replace(title, cx);
 254                }
 255            });
 256        }
 257
 258        if let Some(content) = content {
 259            let new_content_len = content.len();
 260            let mut content = content.into_iter();
 261
 262            // Reuse existing content if we can
 263            for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
 264                old.update_from_acp(new, language_registry.clone(), terminals, cx)?;
 265            }
 266            for new in content {
 267                self.content.push(ToolCallContent::from_acp(
 268                    new,
 269                    language_registry.clone(),
 270                    terminals,
 271                    cx,
 272                )?)
 273            }
 274            self.content.truncate(new_content_len);
 275        }
 276
 277        if let Some(locations) = locations {
 278            self.locations = locations;
 279        }
 280
 281        if let Some(raw_input) = raw_input {
 282            self.raw_input = Some(raw_input);
 283        }
 284
 285        if let Some(raw_output) = raw_output {
 286            if self.content.is_empty()
 287                && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
 288            {
 289                self.content
 290                    .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
 291                        markdown,
 292                    }));
 293            }
 294            self.raw_output = Some(raw_output);
 295        }
 296        Ok(())
 297    }
 298
 299    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 300        self.content.iter().filter_map(|content| match content {
 301            ToolCallContent::Diff(diff) => Some(diff),
 302            ToolCallContent::ContentBlock(_) => None,
 303            ToolCallContent::Terminal(_) => None,
 304        })
 305    }
 306
 307    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 308        self.content.iter().filter_map(|content| match content {
 309            ToolCallContent::Terminal(terminal) => Some(terminal),
 310            ToolCallContent::ContentBlock(_) => None,
 311            ToolCallContent::Diff(_) => None,
 312        })
 313    }
 314
 315    fn to_markdown(&self, cx: &App) -> String {
 316        let mut markdown = format!(
 317            "**Tool Call: {}**\nStatus: {}\n\n",
 318            self.label.read(cx).source(),
 319            self.status
 320        );
 321        for content in &self.content {
 322            markdown.push_str(content.to_markdown(cx).as_str());
 323            markdown.push_str("\n\n");
 324        }
 325        markdown
 326    }
 327
 328    async fn resolve_location(
 329        location: acp::ToolCallLocation,
 330        project: WeakEntity<Project>,
 331        cx: &mut AsyncApp,
 332    ) -> Option<AgentLocation> {
 333        let buffer = project
 334            .update(cx, |project, cx| {
 335                project
 336                    .project_path_for_absolute_path(&location.path, cx)
 337                    .map(|path| project.open_buffer(path, cx))
 338            })
 339            .ok()??;
 340        let buffer = buffer.await.log_err()?;
 341        let position = buffer
 342            .update(cx, |buffer, _| {
 343                if let Some(row) = location.line {
 344                    let snapshot = buffer.snapshot();
 345                    let column = snapshot.indent_size_for_line(row).len;
 346                    let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
 347                    snapshot.anchor_before(point)
 348                } else {
 349                    Anchor::MIN
 350                }
 351            })
 352            .ok()?;
 353
 354        Some(AgentLocation {
 355            buffer: buffer.downgrade(),
 356            position,
 357        })
 358    }
 359
 360    fn resolve_locations(
 361        &self,
 362        project: Entity<Project>,
 363        cx: &mut App,
 364    ) -> Task<Vec<Option<AgentLocation>>> {
 365        let locations = self.locations.clone();
 366        project.update(cx, |_, cx| {
 367            cx.spawn(async move |project, cx| {
 368                let mut new_locations = Vec::new();
 369                for location in locations {
 370                    new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
 371                }
 372                new_locations
 373            })
 374        })
 375    }
 376}
 377
/// Status of a [`ToolCall`] as tracked by the thread.
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The tool call hasn't started running yet, but we start showing it to
    /// the user.
    Pending,
    /// The tool call is waiting for confirmation from the user.
    WaitingForConfirmation {
        /// Permission options offered for this confirmation.
        options: Vec<acp::PermissionOption>,
        /// One-shot sender that carries a permission option id.
        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
    },
    /// The tool call is currently running.
    InProgress,
    /// The tool call completed successfully.
    Completed,
    /// The tool call failed.
    Failed,
    /// The user rejected the tool call.
    Rejected,
    /// The user canceled generation so the tool call was canceled.
    Canceled,
}
 399
 400impl From<acp::ToolCallStatus> for ToolCallStatus {
 401    fn from(status: acp::ToolCallStatus) -> Self {
 402        match status {
 403            acp::ToolCallStatus::Pending => Self::Pending,
 404            acp::ToolCallStatus::InProgress => Self::InProgress,
 405            acp::ToolCallStatus::Completed => Self::Completed,
 406            acp::ToolCallStatus::Failed => Self::Failed,
 407        }
 408    }
 409}
 410
 411impl Display for ToolCallStatus {
 412    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 413        write!(
 414            f,
 415            "{}",
 416            match self {
 417                ToolCallStatus::Pending => "Pending",
 418                ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
 419                ToolCallStatus::InProgress => "In Progress",
 420                ToolCallStatus::Completed => "Completed",
 421                ToolCallStatus::Failed => "Failed",
 422                ToolCallStatus::Rejected => "Rejected",
 423                ToolCallStatus::Canceled => "Canceled",
 424            }
 425        )
 426    }
 427}
 428
/// Displayable content accumulated from protocol content blocks.
#[derive(Debug, PartialEq, Clone)]
pub enum ContentBlock {
    /// Nothing has been appended yet.
    Empty,
    /// Text content held in a markdown entity.
    Markdown { markdown: Entity<Markdown> },
    /// A lone resource link; becomes `Markdown` once more content is appended.
    ResourceLink { resource_link: acp::ResourceLink },
}
 435
 436impl ContentBlock {
 437    pub fn new(
 438        block: acp::ContentBlock,
 439        language_registry: &Arc<LanguageRegistry>,
 440        cx: &mut App,
 441    ) -> Self {
 442        let mut this = Self::Empty;
 443        this.append(block, language_registry, cx);
 444        this
 445    }
 446
 447    pub fn new_combined(
 448        blocks: impl IntoIterator<Item = acp::ContentBlock>,
 449        language_registry: Arc<LanguageRegistry>,
 450        cx: &mut App,
 451    ) -> Self {
 452        let mut this = Self::Empty;
 453        for block in blocks {
 454            this.append(block, &language_registry, cx);
 455        }
 456        this
 457    }
 458
 459    pub fn append(
 460        &mut self,
 461        block: acp::ContentBlock,
 462        language_registry: &Arc<LanguageRegistry>,
 463        cx: &mut App,
 464    ) {
 465        if matches!(self, ContentBlock::Empty)
 466            && let acp::ContentBlock::ResourceLink(resource_link) = block
 467        {
 468            *self = ContentBlock::ResourceLink { resource_link };
 469            return;
 470        }
 471
 472        let new_content = self.block_string_contents(block);
 473
 474        match self {
 475            ContentBlock::Empty => {
 476                *self = Self::create_markdown_block(new_content, language_registry, cx);
 477            }
 478            ContentBlock::Markdown { markdown } => {
 479                markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
 480            }
 481            ContentBlock::ResourceLink { resource_link } => {
 482                let existing_content = Self::resource_link_md(&resource_link.uri);
 483                let combined = format!("{}\n{}", existing_content, new_content);
 484
 485                *self = Self::create_markdown_block(combined, language_registry, cx);
 486            }
 487        }
 488    }
 489
 490    fn create_markdown_block(
 491        content: String,
 492        language_registry: &Arc<LanguageRegistry>,
 493        cx: &mut App,
 494    ) -> ContentBlock {
 495        ContentBlock::Markdown {
 496            markdown: cx
 497                .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
 498        }
 499    }
 500
 501    fn block_string_contents(&self, block: acp::ContentBlock) -> String {
 502        match block {
 503            acp::ContentBlock::Text(text_content) => text_content.text,
 504            acp::ContentBlock::ResourceLink(resource_link) => {
 505                Self::resource_link_md(&resource_link.uri)
 506            }
 507            acp::ContentBlock::Resource(acp::EmbeddedResource {
 508                resource:
 509                    acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
 510                        uri,
 511                        ..
 512                    }),
 513                ..
 514            }) => Self::resource_link_md(&uri),
 515            acp::ContentBlock::Image(image) => Self::image_md(&image),
 516            acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
 517        }
 518    }
 519
 520    fn resource_link_md(uri: &str) -> String {
 521        if let Some(uri) = MentionUri::parse(uri).log_err() {
 522            uri.as_link().to_string()
 523        } else {
 524            uri.to_string()
 525        }
 526    }
 527
 528    fn image_md(_image: &acp::ImageContent) -> String {
 529        "`Image`".into()
 530    }
 531
 532    pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
 533        match self {
 534            ContentBlock::Empty => "",
 535            ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
 536            ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
 537        }
 538    }
 539
 540    pub fn markdown(&self) -> Option<&Entity<Markdown>> {
 541        match self {
 542            ContentBlock::Empty => None,
 543            ContentBlock::Markdown { markdown } => Some(markdown),
 544            ContentBlock::ResourceLink { .. } => None,
 545        }
 546    }
 547
 548    pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
 549        match self {
 550            ContentBlock::ResourceLink { resource_link } => Some(resource_link),
 551            _ => None,
 552        }
 553    }
 554}
 555
/// A single content item attached to a [`ToolCall`].
#[derive(Debug)]
pub enum ToolCallContent {
    /// Text-like content.
    ContentBlock(ContentBlock),
    /// A diff entity.
    Diff(Entity<Diff>),
    /// A terminal entity.
    Terminal(Entity<Terminal>),
}
 562
 563impl ToolCallContent {
 564    pub fn from_acp(
 565        content: acp::ToolCallContent,
 566        language_registry: Arc<LanguageRegistry>,
 567        terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
 568        cx: &mut App,
 569    ) -> Result<Self> {
 570        match content {
 571            acp::ToolCallContent::Content { content } => Ok(Self::ContentBlock(ContentBlock::new(
 572                content,
 573                &language_registry,
 574                cx,
 575            ))),
 576            acp::ToolCallContent::Diff { diff } => Ok(Self::Diff(cx.new(|cx| {
 577                Diff::finalized(
 578                    diff.path,
 579                    diff.old_text,
 580                    diff.new_text,
 581                    language_registry,
 582                    cx,
 583                )
 584            }))),
 585            acp::ToolCallContent::Terminal { terminal_id } => terminals
 586                .get(&terminal_id)
 587                .cloned()
 588                .map(Self::Terminal)
 589                .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
 590        }
 591    }
 592
 593    pub fn update_from_acp(
 594        &mut self,
 595        new: acp::ToolCallContent,
 596        language_registry: Arc<LanguageRegistry>,
 597        terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
 598        cx: &mut App,
 599    ) -> Result<()> {
 600        let needs_update = match (&self, &new) {
 601            (Self::Diff(old_diff), acp::ToolCallContent::Diff { diff: new_diff }) => {
 602                old_diff.read(cx).needs_update(
 603                    new_diff.old_text.as_deref().unwrap_or(""),
 604                    &new_diff.new_text,
 605                    cx,
 606                )
 607            }
 608            _ => true,
 609        };
 610
 611        if needs_update {
 612            *self = Self::from_acp(new, language_registry, terminals, cx)?;
 613        }
 614        Ok(())
 615    }
 616
 617    pub fn to_markdown(&self, cx: &App) -> String {
 618        match self {
 619            Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
 620            Self::Diff(diff) => diff.read(cx).to_markdown(cx),
 621            Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
 622        }
 623    }
 624}
 625
/// An update targeting an existing tool call, identified by its id.
#[derive(Debug, PartialEq)]
pub enum ToolCallUpdate {
    /// A protocol-level field update.
    UpdateFields(acp::ToolCallUpdate),
    /// An update carrying a diff entity.
    UpdateDiff(ToolCallUpdateDiff),
    /// An update carrying a terminal entity.
    UpdateTerminal(ToolCallUpdateTerminal),
}
 632
 633impl ToolCallUpdate {
 634    fn id(&self) -> &acp::ToolCallId {
 635        match self {
 636            Self::UpdateFields(update) => &update.id,
 637            Self::UpdateDiff(diff) => &diff.id,
 638            Self::UpdateTerminal(terminal) => &terminal.id,
 639        }
 640    }
 641}
 642
/// Wraps a protocol update as [`ToolCallUpdate::UpdateFields`].
impl From<acp::ToolCallUpdate> for ToolCallUpdate {
    fn from(update: acp::ToolCallUpdate) -> Self {
        Self::UpdateFields(update)
    }
}
 648
/// Wraps a diff update as [`ToolCallUpdate::UpdateDiff`].
impl From<ToolCallUpdateDiff> for ToolCallUpdate {
    fn from(diff: ToolCallUpdateDiff) -> Self {
        Self::UpdateDiff(diff)
    }
}
 654
/// Payload of [`ToolCallUpdate::UpdateDiff`].
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateDiff {
    /// Id of the tool call being updated.
    pub id: acp::ToolCallId,
    /// The diff entity carried by the update.
    pub diff: Entity<Diff>,
}
 660
/// Wraps a terminal update as [`ToolCallUpdate::UpdateTerminal`].
impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
    fn from(terminal: ToolCallUpdateTerminal) -> Self {
        Self::UpdateTerminal(terminal)
    }
}
 666
/// Payload of [`ToolCallUpdate::UpdateTerminal`].
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateTerminal {
    /// Id of the tool call being updated.
    pub id: acp::ToolCallId,
    /// The terminal entity carried by the update.
    pub terminal: Entity<Terminal>,
}
 672
/// The plan of a thread: an ordered list of entries. Empty by default.
#[derive(Debug, Default)]
pub struct Plan {
    /// The plan's entries, in order.
    pub entries: Vec<PlanEntry>,
}
 677
/// Summary of a [`Plan`], as computed by `Plan::stats`.
#[derive(Debug)]
pub struct PlanStats<'a> {
    /// The first entry whose status is in progress, if any.
    pub in_progress_entry: Option<&'a PlanEntry>,
    /// Number of entries whose status is pending.
    pub pending: u32,
    /// Number of entries whose status is completed.
    pub completed: u32,
}
 684
 685impl Plan {
 686    pub fn is_empty(&self) -> bool {
 687        self.entries.is_empty()
 688    }
 689
 690    pub fn stats(&self) -> PlanStats<'_> {
 691        let mut stats = PlanStats {
 692            in_progress_entry: None,
 693            pending: 0,
 694            completed: 0,
 695        };
 696
 697        for entry in &self.entries {
 698            match &entry.status {
 699                acp::PlanEntryStatus::Pending => {
 700                    stats.pending += 1;
 701                }
 702                acp::PlanEntryStatus::InProgress => {
 703                    stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
 704                }
 705                acp::PlanEntryStatus::Completed => {
 706                    stats.completed += 1;
 707                }
 708            }
 709        }
 710
 711        stats
 712    }
 713}
 714
/// A single entry of a [`Plan`].
#[derive(Debug)]
pub struct PlanEntry {
    /// Markdown entity holding the entry's content.
    pub content: Entity<Markdown>,
    /// Priority of the entry, as reported by the protocol.
    pub priority: acp::PlanEntryPriority,
    /// Status of the entry, as reported by the protocol.
    pub status: acp::PlanEntryStatus,
}
 721
 722impl PlanEntry {
 723    pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
 724        Self {
 725            content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
 726            priority: entry.priority,
 727            status: entry.status,
 728        }
 729    }
 730}
 731
/// Token consumption of a thread relative to its limit.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Maximum number of tokens; `0` is treated by `ratio` as "unknown".
    pub max_tokens: u64,
    /// Number of tokens used so far.
    pub used_tokens: u64,
}
 737
 738impl TokenUsage {
 739    pub fn ratio(&self) -> TokenUsageRatio {
 740        #[cfg(debug_assertions)]
 741        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 742            .unwrap_or("0.8".to_string())
 743            .parse()
 744            .unwrap();
 745        #[cfg(not(debug_assertions))]
 746        let warning_threshold: f32 = 0.8;
 747
 748        // When the maximum is unknown because there is no selected model,
 749        // avoid showing the token limit warning.
 750        if self.max_tokens == 0 {
 751            TokenUsageRatio::Normal
 752        } else if self.used_tokens >= self.max_tokens {
 753            TokenUsageRatio::Exceeded
 754        } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
 755            TokenUsageRatio::Warning
 756        } else {
 757            TokenUsageRatio::Normal
 758        }
 759    }
 760}
 761
/// Classification of token usage produced by `TokenUsage::ratio`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenUsageRatio {
    /// Usage is below the warning threshold, or the limit is unknown.
    Normal,
    /// Usage has reached the warning threshold but not the limit.
    Warning,
    /// Usage has reached or passed the limit.
    Exceeded,
}
 768
/// Retry information carried by `AcpThreadEvent::Retry`.
#[derive(Debug, Clone)]
pub struct RetryStatus {
    /// Text of the most recent error.
    pub last_error: SharedString,
    /// Attempt counter. NOTE(review): whether this is 0- or 1-based is not
    /// visible here — confirm with the producer.
    pub attempt: usize,
    /// Upper bound on the number of attempts.
    pub max_attempts: usize,
    /// Instant associated with the start of the retry.
    pub started_at: Instant,
    /// Duration associated with the retry; presumably the wait before the
    /// next attempt — TODO confirm.
    pub duration: Duration,
}
 777
/// State of a single agent thread: its entries, plan, and the connection and
/// session it belongs to.
pub struct AcpThread {
    /// Title of the thread.
    title: SharedString,
    /// Entries of the thread, in order.
    entries: Vec<AgentThreadEntry>,
    /// Current plan; empty by default.
    plan: Plan,
    /// Project the thread operates on.
    project: Entity<Project>,
    /// Action log associated with the thread.
    action_log: Entity<ActionLog>,
    /// Buffer snapshots keyed by buffer entity.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// Present while `status()` reports `ThreadStatus::Generating`.
    send_task: Option<Task<()>>,
    /// Connection to the agent.
    connection: Rc<dyn AgentConnection>,
    /// Protocol session id of the thread.
    session_id: acp::SessionId,
    /// Latest token usage, when known.
    token_usage: Option<TokenUsage>,
    /// Most recently observed prompt capabilities.
    prompt_capabilities: acp::PromptCapabilities,
    /// Task spawned in `new` that keeps `prompt_capabilities` up to date.
    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
    /// Terminals known to the thread, keyed by terminal id.
    terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
}
 793
/// Events emitted by [`AcpThread`].
#[derive(Debug)]
pub enum AcpThreadEvent {
    /// An entry was added.
    NewEntry,
    /// The title changed.
    TitleUpdated,
    /// The token usage changed.
    TokenUsageUpdated,
    /// The entry at the given index changed.
    EntryUpdated(usize),
    /// The entries in the given index range were removed.
    EntriesRemoved(Range<usize>),
    /// A tool call requires authorization.
    ToolAuthorizationRequired,
    /// A retry is taking place; carries its status.
    Retry(RetryStatus),
    /// The thread stopped.
    Stopped,
    /// An error occurred.
    Error,
    /// A load error occurred.
    LoadError(LoadError),
    /// Emitted after new prompt capabilities were received and stored.
    PromptCapabilitiesUpdated,
    /// A refusal occurred.
    Refusal,
    /// Emitted for a session update carrying a new list of available commands.
    AvailableCommandsUpdated(Vec<acp::AvailableCommand>),
    /// Emitted for a session update carrying a new current mode id.
    ModeUpdated(acp::SessionModeId),
}
 811
// Lets `AcpThread` emit `AcpThreadEvent`s through its context.
impl EventEmitter<AcpThreadEvent> for AcpThread {}
 813
/// Whether a thread currently has a send task.
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    /// No send task is present.
    Idle,
    /// A send task is present.
    Generating,
}
 819
/// Errors reported through `AcpThreadEvent::LoadError`.
#[derive(Debug, Clone)]
pub enum LoadError {
    /// The found version is older than the minimum supported one.
    Unsupported {
        /// The command the version was obtained from.
        command: SharedString,
        /// The version that was found.
        current_version: SharedString,
        /// The minimum version that is required.
        minimum_version: SharedString,
    },
    /// Installation failed; carries the failure message.
    FailedToInstall(SharedString),
    /// The server exited.
    Exited {
        /// Exit status of the process.
        status: ExitStatus,
    },
    /// Any other failure; carries the message.
    Other(SharedString),
}
 833
 834impl Display for LoadError {
 835    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 836        match self {
 837            LoadError::Unsupported {
 838                command: path,
 839                current_version,
 840                minimum_version,
 841            } => {
 842                write!(
 843                    f,
 844                    "version {current_version} from {path} is not supported (need at least {minimum_version})"
 845                )
 846            }
 847            LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
 848            LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
 849            LoadError::Other(msg) => write!(f, "{msg}"),
 850        }
 851    }
 852}
 853
// Marker impl so `LoadError` can be used as a standard error type.
impl Error for LoadError {}
 855
 856impl AcpThread {
 857    pub fn new(
 858        title: impl Into<SharedString>,
 859        connection: Rc<dyn AgentConnection>,
 860        project: Entity<Project>,
 861        action_log: Entity<ActionLog>,
 862        session_id: acp::SessionId,
 863        mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
 864        cx: &mut Context<Self>,
 865    ) -> Self {
 866        let prompt_capabilities = prompt_capabilities_rx.borrow().clone();
 867        let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
 868            loop {
 869                let caps = prompt_capabilities_rx.recv().await?;
 870                this.update(cx, |this, cx| {
 871                    this.prompt_capabilities = caps;
 872                    cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
 873                })?;
 874            }
 875        });
 876
 877        Self {
 878            action_log,
 879            shared_buffers: Default::default(),
 880            entries: Default::default(),
 881            plan: Default::default(),
 882            title: title.into(),
 883            project,
 884            send_task: None,
 885            connection,
 886            session_id,
 887            token_usage: None,
 888            prompt_capabilities,
 889            _observe_prompt_capabilities: task,
 890            terminals: HashMap::default(),
 891        }
 892    }
 893
    /// Returns a clone of the thread's current prompt capabilities.
    pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
        self.prompt_capabilities.clone()
    }
 897
    /// Returns the agent connection this thread uses.
    pub fn connection(&self) -> &Rc<dyn AgentConnection> {
        &self.connection
    }
 901
    /// Returns the action log associated with this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
 905
    /// Returns the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
 909
    /// Returns a clone of the thread's title.
    pub fn title(&self) -> SharedString {
        self.title.clone()
    }
 913
    /// Returns the thread's entries, in order.
    pub fn entries(&self) -> &[AgentThreadEntry] {
        &self.entries
    }
 917
    /// Returns the protocol session id of this thread.
    pub fn session_id(&self) -> &acp::SessionId {
        &self.session_id
    }
 921
 922    pub fn status(&self) -> ThreadStatus {
 923        if self.send_task.is_some() {
 924            ThreadStatus::Generating
 925        } else {
 926            ThreadStatus::Idle
 927        }
 928    }
 929
    /// Returns the latest token usage, or `None` when none has been recorded.
    pub fn token_usage(&self) -> Option<&TokenUsage> {
        self.token_usage.as_ref()
    }
 933
 934    pub fn has_pending_edit_tool_calls(&self) -> bool {
 935        for entry in self.entries.iter().rev() {
 936            match entry {
 937                AgentThreadEntry::UserMessage(_) => return false,
 938                AgentThreadEntry::ToolCall(
 939                    call @ ToolCall {
 940                        status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
 941                        ..
 942                    },
 943                ) if call.diffs().next().is_some() => {
 944                    return true;
 945                }
 946                AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
 947            }
 948        }
 949
 950        false
 951    }
 952
 953    pub fn used_tools_since_last_user_message(&self) -> bool {
 954        for entry in self.entries.iter().rev() {
 955            match entry {
 956                AgentThreadEntry::UserMessage(..) => return false,
 957                AgentThreadEntry::AssistantMessage(..) => continue,
 958                AgentThreadEntry::ToolCall(..) => return true,
 959            }
 960        }
 961
 962        false
 963    }
 964
 965    pub fn handle_session_update(
 966        &mut self,
 967        update: acp::SessionUpdate,
 968        cx: &mut Context<Self>,
 969    ) -> Result<(), acp::Error> {
 970        match update {
 971            acp::SessionUpdate::UserMessageChunk { content } => {
 972                self.push_user_content_block(None, content, cx);
 973            }
 974            acp::SessionUpdate::AgentMessageChunk { content } => {
 975                self.push_assistant_content_block(content, false, cx);
 976            }
 977            acp::SessionUpdate::AgentThoughtChunk { content } => {
 978                self.push_assistant_content_block(content, true, cx);
 979            }
 980            acp::SessionUpdate::ToolCall(tool_call) => {
 981                self.upsert_tool_call(tool_call, cx)?;
 982            }
 983            acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
 984                self.update_tool_call(tool_call_update, cx)?;
 985            }
 986            acp::SessionUpdate::Plan(plan) => {
 987                self.update_plan(plan, cx);
 988            }
 989            acp::SessionUpdate::AvailableCommandsUpdate { available_commands } => {
 990                cx.emit(AcpThreadEvent::AvailableCommandsUpdated(available_commands))
 991            }
 992            acp::SessionUpdate::CurrentModeUpdate { current_mode_id } => {
 993                cx.emit(AcpThreadEvent::ModeUpdated(current_mode_id))
 994            }
 995        }
 996        Ok(())
 997    }
 998
 999    pub fn push_user_content_block(
1000        &mut self,
1001        message_id: Option<UserMessageId>,
1002        chunk: acp::ContentBlock,
1003        cx: &mut Context<Self>,
1004    ) {
1005        let language_registry = self.project.read(cx).languages().clone();
1006        let entries_len = self.entries.len();
1007
1008        if let Some(last_entry) = self.entries.last_mut()
1009            && let AgentThreadEntry::UserMessage(UserMessage {
1010                id,
1011                content,
1012                chunks,
1013                ..
1014            }) = last_entry
1015        {
1016            *id = message_id.or(id.take());
1017            content.append(chunk.clone(), &language_registry, cx);
1018            chunks.push(chunk);
1019            let idx = entries_len - 1;
1020            cx.emit(AcpThreadEvent::EntryUpdated(idx));
1021        } else {
1022            let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
1023            self.push_entry(
1024                AgentThreadEntry::UserMessage(UserMessage {
1025                    id: message_id,
1026                    content,
1027                    chunks: vec![chunk],
1028                    checkpoint: None,
1029                }),
1030                cx,
1031            );
1032        }
1033    }
1034
1035    pub fn push_assistant_content_block(
1036        &mut self,
1037        chunk: acp::ContentBlock,
1038        is_thought: bool,
1039        cx: &mut Context<Self>,
1040    ) {
1041        let language_registry = self.project.read(cx).languages().clone();
1042        let entries_len = self.entries.len();
1043        if let Some(last_entry) = self.entries.last_mut()
1044            && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
1045        {
1046            let idx = entries_len - 1;
1047            cx.emit(AcpThreadEvent::EntryUpdated(idx));
1048            match (chunks.last_mut(), is_thought) {
1049                (Some(AssistantMessageChunk::Message { block }), false)
1050                | (Some(AssistantMessageChunk::Thought { block }), true) => {
1051                    block.append(chunk, &language_registry, cx)
1052                }
1053                _ => {
1054                    let block = ContentBlock::new(chunk, &language_registry, cx);
1055                    if is_thought {
1056                        chunks.push(AssistantMessageChunk::Thought { block })
1057                    } else {
1058                        chunks.push(AssistantMessageChunk::Message { block })
1059                    }
1060                }
1061            }
1062        } else {
1063            let block = ContentBlock::new(chunk, &language_registry, cx);
1064            let chunk = if is_thought {
1065                AssistantMessageChunk::Thought { block }
1066            } else {
1067                AssistantMessageChunk::Message { block }
1068            };
1069
1070            self.push_entry(
1071                AgentThreadEntry::AssistantMessage(AssistantMessage {
1072                    chunks: vec![chunk],
1073                }),
1074                cx,
1075            );
1076        }
1077    }
1078
1079    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1080        self.entries.push(entry);
1081        cx.emit(AcpThreadEvent::NewEntry);
1082    }
1083
1084    pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1085        self.connection.set_title(&self.session_id, cx).is_some()
1086    }
1087
1088    pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1089        if title != self.title {
1090            self.title = title.clone();
1091            cx.emit(AcpThreadEvent::TitleUpdated);
1092            if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1093                return set_title.run(title, cx);
1094            }
1095        }
1096        Task::ready(Ok(()))
1097    }
1098
1099    pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1100        self.token_usage = usage;
1101        cx.emit(AcpThreadEvent::TokenUsageUpdated);
1102    }
1103
1104    pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1105        cx.emit(AcpThreadEvent::Retry(status));
1106    }
1107
1108    pub fn update_tool_call(
1109        &mut self,
1110        update: impl Into<ToolCallUpdate>,
1111        cx: &mut Context<Self>,
1112    ) -> Result<()> {
1113        let update = update.into();
1114        let languages = self.project.read(cx).languages().clone();
1115
1116        let ix = match self.index_for_tool_call(update.id()) {
1117            Some(ix) => ix,
1118            None => {
1119                // Tool call not found - create a failed tool call entry
1120                let failed_tool_call = ToolCall {
1121                    id: update.id().clone(),
1122                    label: cx.new(|cx| Markdown::new("Tool call not found".into(), None, None, cx)),
1123                    kind: acp::ToolKind::Fetch,
1124                    content: vec![ToolCallContent::ContentBlock(ContentBlock::new(
1125                        acp::ContentBlock::Text(acp::TextContent {
1126                            text: "Tool call not found".to_string(),
1127                            annotations: None,
1128                            meta: None,
1129                        }),
1130                        &languages,
1131                        cx,
1132                    ))],
1133                    status: ToolCallStatus::Failed,
1134                    locations: Vec::new(),
1135                    resolved_locations: Vec::new(),
1136                    raw_input: None,
1137                    raw_output: None,
1138                };
1139                self.push_entry(AgentThreadEntry::ToolCall(failed_tool_call), cx);
1140                return Ok(());
1141            }
1142        };
1143        let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1144            unreachable!()
1145        };
1146
1147        match update {
1148            ToolCallUpdate::UpdateFields(update) => {
1149                // Check if there's terminal output in the meta field
1150                let terminal_output_result = update
1151                    .meta
1152                    .as_ref()
1153                    .and_then(|meta| meta.get("terminal_output"))
1154                    .and_then(|terminal_output| {
1155                        match (
1156                            terminal_output.get("terminal_id").and_then(|v| v.as_str()),
1157                            terminal_output.get("data").and_then(|v| v.as_str()),
1158                        ) {
1159                            (Some(terminal_id_str), Some(data_str)) => {
1160                                let data = data_str.as_bytes().to_vec();
1161                                let terminal_id = acp::TerminalId(terminal_id_str.into());
1162                                Some((terminal_id, data))
1163                            }
1164                            _ => None,
1165                        }
1166                    });
1167
1168                let location_updated = update.fields.locations.is_some();
1169                call.update_fields(update.fields, languages, &self.terminals, cx)?;
1170
1171                if let Some((terminal_id, data)) = terminal_output_result {
1172                    // Silently ignore errors - terminal output streaming is best-effort
1173                    let _ = self.write_terminal_output(terminal_id, &data, cx);
1174                }
1175                if location_updated {
1176                    self.resolve_locations(update.id, cx);
1177                }
1178            }
1179            ToolCallUpdate::UpdateDiff(update) => {
1180                call.content.clear();
1181                call.content.push(ToolCallContent::Diff(update.diff));
1182            }
1183            ToolCallUpdate::UpdateTerminal(update) => {
1184                call.content.clear();
1185                call.content
1186                    .push(ToolCallContent::Terminal(update.terminal));
1187            }
1188        }
1189
1190        cx.emit(AcpThreadEvent::EntryUpdated(ix));
1191
1192        Ok(())
1193    }
1194
1195    /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1196    pub fn upsert_tool_call(
1197        &mut self,
1198        tool_call: acp::ToolCall,
1199        cx: &mut Context<Self>,
1200    ) -> Result<(), acp::Error> {
1201        let status = tool_call.status.into();
1202        self.upsert_tool_call_inner(tool_call.into(), status, cx)
1203    }
1204
1205    /// Fails if id does not match an existing entry.
1206    pub fn upsert_tool_call_inner(
1207        &mut self,
1208        update: acp::ToolCallUpdate,
1209        status: ToolCallStatus,
1210        cx: &mut Context<Self>,
1211    ) -> Result<(), acp::Error> {
1212        let language_registry = self.project.read(cx).languages().clone();
1213        let id = update.id.clone();
1214
1215        if let Some(ix) = self.index_for_tool_call(&id) {
1216            let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1217                unreachable!()
1218            };
1219
1220            call.update_fields(update.fields, language_registry, &self.terminals, cx)?;
1221            call.status = status;
1222
1223            cx.emit(AcpThreadEvent::EntryUpdated(ix));
1224        } else {
1225            let call = ToolCall::from_acp(
1226                update.try_into()?,
1227                status,
1228                language_registry,
1229                &self.terminals,
1230                cx,
1231            )?;
1232            self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1233        };
1234
1235        self.resolve_locations(id, cx);
1236        Ok(())
1237    }
1238
1239    fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1240        self.entries
1241            .iter()
1242            .enumerate()
1243            .rev()
1244            .find_map(|(index, entry)| {
1245                if let AgentThreadEntry::ToolCall(tool_call) = entry
1246                    && &tool_call.id == id
1247                {
1248                    Some(index)
1249                } else {
1250                    None
1251                }
1252            })
1253    }
1254
1255    fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1256        // The tool call we are looking for is typically the last one, or very close to the end.
1257        // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1258        self.entries
1259            .iter_mut()
1260            .enumerate()
1261            .rev()
1262            .find_map(|(index, tool_call)| {
1263                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1264                    && &tool_call.id == id
1265                {
1266                    Some((index, tool_call))
1267                } else {
1268                    None
1269                }
1270            })
1271    }
1272
1273    pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1274        self.entries
1275            .iter()
1276            .enumerate()
1277            .rev()
1278            .find_map(|(index, tool_call)| {
1279                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1280                    && &tool_call.id == id
1281                {
1282                    Some((index, tool_call))
1283                } else {
1284                    None
1285                }
1286            })
1287    }
1288
1289    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1290        let project = self.project.clone();
1291        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1292            return;
1293        };
1294        let task = tool_call.resolve_locations(project, cx);
1295        cx.spawn(async move |this, cx| {
1296            let resolved_locations = task.await;
1297            this.update(cx, |this, cx| {
1298                let project = this.project.clone();
1299                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1300                    return;
1301                };
1302                if let Some(Some(location)) = resolved_locations.last() {
1303                    project.update(cx, |project, cx| {
1304                        if let Some(agent_location) = project.agent_location() {
1305                            let should_ignore = agent_location.buffer == location.buffer
1306                                && location
1307                                    .buffer
1308                                    .update(cx, |buffer, _| {
1309                                        let snapshot = buffer.snapshot();
1310                                        let old_position =
1311                                            agent_location.position.to_point(&snapshot);
1312                                        let new_position = location.position.to_point(&snapshot);
1313                                        // ignore this so that when we get updates from the edit tool
1314                                        // the position doesn't reset to the startof line
1315                                        old_position.row == new_position.row
1316                                            && old_position.column > new_position.column
1317                                    })
1318                                    .ok()
1319                                    .unwrap_or_default();
1320                            if !should_ignore {
1321                                project.set_agent_location(Some(location.clone()), cx);
1322                            }
1323                        }
1324                    });
1325                }
1326                if tool_call.resolved_locations != resolved_locations {
1327                    tool_call.resolved_locations = resolved_locations;
1328                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
1329                }
1330            })
1331        })
1332        .detach();
1333    }
1334
1335    pub fn request_tool_call_authorization(
1336        &mut self,
1337        tool_call: acp::ToolCallUpdate,
1338        options: Vec<acp::PermissionOption>,
1339        respect_always_allow_setting: bool,
1340        cx: &mut Context<Self>,
1341    ) -> Result<BoxFuture<'static, acp::RequestPermissionOutcome>> {
1342        let (tx, rx) = oneshot::channel();
1343
1344        if respect_always_allow_setting && AgentSettings::get_global(cx).always_allow_tool_actions {
1345            // Don't use AllowAlways, because then if you were to turn off always_allow_tool_actions,
1346            // some tools would (incorrectly) continue to auto-accept.
1347            if let Some(allow_once_option) = options.iter().find_map(|option| {
1348                if matches!(option.kind, acp::PermissionOptionKind::AllowOnce) {
1349                    Some(option.id.clone())
1350                } else {
1351                    None
1352                }
1353            }) {
1354                self.upsert_tool_call_inner(tool_call, ToolCallStatus::Pending, cx)?;
1355                return Ok(async {
1356                    acp::RequestPermissionOutcome::Selected {
1357                        option_id: allow_once_option,
1358                    }
1359                }
1360                .boxed());
1361            }
1362        }
1363
1364        let status = ToolCallStatus::WaitingForConfirmation {
1365            options,
1366            respond_tx: tx,
1367        };
1368
1369        self.upsert_tool_call_inner(tool_call, status, cx)?;
1370        cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1371
1372        let fut = async {
1373            match rx.await {
1374                Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
1375                Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1376            }
1377        }
1378        .boxed();
1379
1380        Ok(fut)
1381    }
1382
1383    pub fn authorize_tool_call(
1384        &mut self,
1385        id: acp::ToolCallId,
1386        option_id: acp::PermissionOptionId,
1387        option_kind: acp::PermissionOptionKind,
1388        cx: &mut Context<Self>,
1389    ) {
1390        let Some((ix, call)) = self.tool_call_mut(&id) else {
1391            return;
1392        };
1393
1394        let new_status = match option_kind {
1395            acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1396                ToolCallStatus::Rejected
1397            }
1398            acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1399                ToolCallStatus::InProgress
1400            }
1401        };
1402
1403        let curr_status = mem::replace(&mut call.status, new_status);
1404
1405        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1406            respond_tx.send(option_id).log_err();
1407        } else if cfg!(debug_assertions) {
1408            panic!("tried to authorize an already authorized tool call");
1409        }
1410
1411        cx.emit(AcpThreadEvent::EntryUpdated(ix));
1412    }
1413
1414    pub fn first_tool_awaiting_confirmation(&self) -> Option<&ToolCall> {
1415        let mut first_tool_call = None;
1416
1417        for entry in self.entries.iter().rev() {
1418            match &entry {
1419                AgentThreadEntry::ToolCall(call) => {
1420                    if let ToolCallStatus::WaitingForConfirmation { .. } = call.status {
1421                        first_tool_call = Some(call);
1422                    } else {
1423                        continue;
1424                    }
1425                }
1426                AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1427                    // Reached the beginning of the turn.
1428                    // If we had pending permission requests in the previous turn, they have been cancelled.
1429                    break;
1430                }
1431            }
1432        }
1433
1434        first_tool_call
1435    }
1436
1437    pub fn plan(&self) -> &Plan {
1438        &self.plan
1439    }
1440
1441    pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1442        let new_entries_len = request.entries.len();
1443        let mut new_entries = request.entries.into_iter();
1444
1445        // Reuse existing markdown to prevent flickering
1446        for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1447            let PlanEntry {
1448                content,
1449                priority,
1450                status,
1451            } = old;
1452            content.update(cx, |old, cx| {
1453                old.replace(new.content, cx);
1454            });
1455            *priority = new.priority;
1456            *status = new.status;
1457        }
1458        for new in new_entries {
1459            self.plan.entries.push(PlanEntry::from_acp(new, cx))
1460        }
1461        self.plan.entries.truncate(new_entries_len);
1462
1463        cx.notify();
1464    }
1465
1466    fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1467        self.plan
1468            .entries
1469            .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1470        cx.notify();
1471    }
1472
1473    #[cfg(any(test, feature = "test-support"))]
1474    pub fn send_raw(
1475        &mut self,
1476        message: &str,
1477        cx: &mut Context<Self>,
1478    ) -> BoxFuture<'static, Result<()>> {
1479        self.send(
1480            vec![acp::ContentBlock::Text(acp::TextContent {
1481                text: message.to_string(),
1482                annotations: None,
1483                meta: None,
1484            })],
1485            cx,
1486        )
1487    }
1488
1489    pub fn send(
1490        &mut self,
1491        message: Vec<acp::ContentBlock>,
1492        cx: &mut Context<Self>,
1493    ) -> BoxFuture<'static, Result<()>> {
1494        let block = ContentBlock::new_combined(
1495            message.clone(),
1496            self.project.read(cx).languages().clone(),
1497            cx,
1498        );
1499        let request = acp::PromptRequest {
1500            prompt: message.clone(),
1501            session_id: self.session_id.clone(),
1502            meta: None,
1503        };
1504        let git_store = self.project.read(cx).git_store().clone();
1505
1506        let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
1507            Some(UserMessageId::new())
1508        } else {
1509            None
1510        };
1511
1512        self.run_turn(cx, async move |this, cx| {
1513            this.update(cx, |this, cx| {
1514                this.push_entry(
1515                    AgentThreadEntry::UserMessage(UserMessage {
1516                        id: message_id.clone(),
1517                        content: block,
1518                        chunks: message,
1519                        checkpoint: None,
1520                    }),
1521                    cx,
1522                );
1523            })
1524            .ok();
1525
1526            let old_checkpoint = git_store
1527                .update(cx, |git, cx| git.checkpoint(cx))?
1528                .await
1529                .context("failed to get old checkpoint")
1530                .log_err();
1531            this.update(cx, |this, cx| {
1532                if let Some((_ix, message)) = this.last_user_message() {
1533                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
1534                        git_checkpoint,
1535                        show: false,
1536                    });
1537                }
1538                this.connection.prompt(message_id, request, cx)
1539            })?
1540            .await
1541        })
1542    }
1543
1544    pub fn can_resume(&self, cx: &App) -> bool {
1545        self.connection.resume(&self.session_id, cx).is_some()
1546    }
1547
1548    pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1549        self.run_turn(cx, async move |this, cx| {
1550            this.update(cx, |this, cx| {
1551                this.connection
1552                    .resume(&this.session_id, cx)
1553                    .map(|resume| resume.run(cx))
1554            })?
1555            .context("resuming a session is not supported")?
1556            .await
1557        })
1558    }
1559
1560    fn run_turn(
1561        &mut self,
1562        cx: &mut Context<Self>,
1563        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
1564    ) -> BoxFuture<'static, Result<()>> {
1565        self.clear_completed_plan_entries(cx);
1566
1567        let (tx, rx) = oneshot::channel();
1568        let cancel_task = self.cancel(cx);
1569
1570        self.send_task = Some(cx.spawn(async move |this, cx| {
1571            cancel_task.await;
1572            tx.send(f(this, cx).await).ok();
1573        }));
1574
1575        cx.spawn(async move |this, cx| {
1576            let response = rx.await;
1577
1578            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
1579                .await?;
1580
1581            this.update(cx, |this, cx| {
1582                this.project
1583                    .update(cx, |project, cx| project.set_agent_location(None, cx));
1584                match response {
1585                    Ok(Err(e)) => {
1586                        this.send_task.take();
1587                        cx.emit(AcpThreadEvent::Error);
1588                        Err(e)
1589                    }
1590                    result => {
1591                        let canceled = matches!(
1592                            result,
1593                            Ok(Ok(acp::PromptResponse {
1594                                stop_reason: acp::StopReason::Cancelled,
1595                                meta: None,
1596                            }))
1597                        );
1598
1599                        // We only take the task if the current prompt wasn't canceled.
1600                        //
1601                        // This prompt may have been canceled because another one was sent
1602                        // while it was still generating. In these cases, dropping `send_task`
1603                        // would cause the next generation to be canceled.
1604                        if !canceled {
1605                            this.send_task.take();
1606                        }
1607
1608                        // Handle refusal - distinguish between user prompt and tool call refusals
1609                        if let Ok(Ok(acp::PromptResponse {
1610                            stop_reason: acp::StopReason::Refusal,
1611                            meta: _,
1612                        })) = result
1613                        {
1614                            if let Some((user_msg_ix, _)) = this.last_user_message() {
1615                                // Check if there's a completed tool call with results after the last user message
1616                                // This indicates the refusal is in response to tool output, not the user's prompt
1617                                let has_completed_tool_call_after_user_msg =
1618                                    this.entries.iter().skip(user_msg_ix + 1).any(|entry| {
1619                                        if let AgentThreadEntry::ToolCall(tool_call) = entry {
1620                                            // Check if the tool call has completed and has output
1621                                            matches!(tool_call.status, ToolCallStatus::Completed)
1622                                                && tool_call.raw_output.is_some()
1623                                        } else {
1624                                            false
1625                                        }
1626                                    });
1627
1628                                if has_completed_tool_call_after_user_msg {
1629                                    // Refusal is due to tool output - don't truncate, just notify
1630                                    // The model refused based on what the tool returned
1631                                    cx.emit(AcpThreadEvent::Refusal);
1632                                } else {
1633                                    // User prompt was refused - truncate back to before the user message
1634                                    let range = user_msg_ix..this.entries.len();
1635                                    if range.start < range.end {
1636                                        this.entries.truncate(user_msg_ix);
1637                                        cx.emit(AcpThreadEvent::EntriesRemoved(range));
1638                                    }
1639                                    cx.emit(AcpThreadEvent::Refusal);
1640                                }
1641                            } else {
1642                                // No user message found, treat as general refusal
1643                                cx.emit(AcpThreadEvent::Refusal);
1644                            }
1645                        }
1646
1647                        cx.emit(AcpThreadEvent::Stopped);
1648                        Ok(())
1649                    }
1650                }
1651            })?
1652        })
1653        .boxed()
1654    }
1655
1656    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1657        let Some(send_task) = self.send_task.take() else {
1658            return Task::ready(());
1659        };
1660
1661        for entry in self.entries.iter_mut() {
1662            if let AgentThreadEntry::ToolCall(call) = entry {
1663                let cancel = matches!(
1664                    call.status,
1665                    ToolCallStatus::Pending
1666                        | ToolCallStatus::WaitingForConfirmation { .. }
1667                        | ToolCallStatus::InProgress
1668                );
1669
1670                if cancel {
1671                    call.status = ToolCallStatus::Canceled;
1672                }
1673            }
1674        }
1675
1676        self.connection.cancel(&self.session_id, cx);
1677
1678        // Wait for the send task to complete
1679        cx.foreground_executor().spawn(send_task)
1680    }
1681
1682    /// Restores the git working tree to the state at the given checkpoint (if one exists)
1683    pub fn restore_checkpoint(
1684        &mut self,
1685        id: UserMessageId,
1686        cx: &mut Context<Self>,
1687    ) -> Task<Result<()>> {
1688        let Some((_, message)) = self.user_message_mut(&id) else {
1689            return Task::ready(Err(anyhow!("message not found")));
1690        };
1691
1692        let checkpoint = message
1693            .checkpoint
1694            .as_ref()
1695            .map(|c| c.git_checkpoint.clone());
1696        let rewind = self.rewind(id.clone(), cx);
1697        let git_store = self.project.read(cx).git_store().clone();
1698
1699        cx.spawn(async move |_, cx| {
1700            rewind.await?;
1701            if let Some(checkpoint) = checkpoint {
1702                git_store
1703                    .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1704                    .await?;
1705            }
1706
1707            Ok(())
1708        })
1709    }
1710
1711    /// Rewinds this thread to before the entry at `index`, removing it and all
1712    /// subsequent entries while rejecting any action_log changes made from that point.
1713    /// Unlike `restore_checkpoint`, this method does not restore from git.
1714    pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1715        let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
1716            return Task::ready(Err(anyhow!("not supported")));
1717        };
1718
1719        cx.spawn(async move |this, cx| {
1720            cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
1721            this.update(cx, |this, cx| {
1722                if let Some((ix, _)) = this.user_message_mut(&id) {
1723                    let range = ix..this.entries.len();
1724                    this.entries.truncate(ix);
1725                    cx.emit(AcpThreadEvent::EntriesRemoved(range));
1726                }
1727                this.action_log()
1728                    .update(cx, |action_log, cx| action_log.reject_all_edits(cx))
1729            })?
1730            .await;
1731            Ok(())
1732        })
1733    }
1734
1735    fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1736        let git_store = self.project.read(cx).git_store().clone();
1737
1738        let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1739            if let Some(checkpoint) = message.checkpoint.as_ref() {
1740                checkpoint.git_checkpoint.clone()
1741            } else {
1742                return Task::ready(Ok(()));
1743            }
1744        } else {
1745            return Task::ready(Ok(()));
1746        };
1747
1748        let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1749        cx.spawn(async move |this, cx| {
1750            let new_checkpoint = new_checkpoint
1751                .await
1752                .context("failed to get new checkpoint")
1753                .log_err();
1754            if let Some(new_checkpoint) = new_checkpoint {
1755                let equal = git_store
1756                    .update(cx, |git, cx| {
1757                        git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1758                    })?
1759                    .await
1760                    .unwrap_or(true);
1761                this.update(cx, |this, cx| {
1762                    let (ix, message) = this.last_user_message().context("no user message")?;
1763                    let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1764                    checkpoint.show = !equal;
1765                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
1766                    anyhow::Ok(())
1767                })??;
1768            }
1769
1770            Ok(())
1771        })
1772    }
1773
1774    fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1775        self.entries
1776            .iter_mut()
1777            .enumerate()
1778            .rev()
1779            .find_map(|(ix, entry)| {
1780                if let AgentThreadEntry::UserMessage(message) = entry {
1781                    Some((ix, message))
1782                } else {
1783                    None
1784                }
1785            })
1786    }
1787
1788    fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1789        self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1790            if let AgentThreadEntry::UserMessage(message) = entry {
1791                if message.id.as_ref() == Some(id) {
1792                    Some((ix, message))
1793                } else {
1794                    None
1795                }
1796            } else {
1797                None
1798            }
1799        })
1800    }
1801
1802    pub fn read_text_file(
1803        &self,
1804        path: PathBuf,
1805        line: Option<u32>,
1806        limit: Option<u32>,
1807        reuse_shared_snapshot: bool,
1808        cx: &mut Context<Self>,
1809    ) -> Task<Result<String>> {
1810        // Args are 1-based, move to 0-based
1811        let line = line.unwrap_or_default().saturating_sub(1);
1812        let limit = limit.unwrap_or(u32::MAX);
1813        let project = self.project.clone();
1814        let action_log = self.action_log.clone();
1815        cx.spawn(async move |this, cx| {
1816            let load = project.update(cx, |project, cx| {
1817                let path = project
1818                    .project_path_for_absolute_path(&path, cx)
1819                    .context("invalid path")?;
1820                anyhow::Ok(project.open_buffer(path, cx))
1821            });
1822            let buffer = load??.await?;
1823
1824            let snapshot = if reuse_shared_snapshot {
1825                this.read_with(cx, |this, _| {
1826                    this.shared_buffers.get(&buffer.clone()).cloned()
1827                })
1828                .log_err()
1829                .flatten()
1830            } else {
1831                None
1832            };
1833
1834            let snapshot = if let Some(snapshot) = snapshot {
1835                snapshot
1836            } else {
1837                action_log.update(cx, |action_log, cx| {
1838                    action_log.buffer_read(buffer.clone(), cx);
1839                })?;
1840
1841                let snapshot = buffer.update(cx, |buffer, _| buffer.snapshot())?;
1842                this.update(cx, |this, _| {
1843                    this.shared_buffers.insert(buffer.clone(), snapshot.clone());
1844                })?;
1845                snapshot
1846            };
1847
1848            let max_point = snapshot.max_point();
1849            if line >= max_point.row {
1850                anyhow::bail!(
1851                    "Attempting to read beyond the end of the file, line {}:{}",
1852                    max_point.row + 1,
1853                    max_point.column
1854                );
1855            }
1856
1857            let start = snapshot.anchor_before(Point::new(line, 0));
1858            let end = snapshot.anchor_before(Point::new(line.saturating_add(limit), 0));
1859
1860            project.update(cx, |project, cx| {
1861                project.set_agent_location(
1862                    Some(AgentLocation {
1863                        buffer: buffer.downgrade(),
1864                        position: start,
1865                    }),
1866                    cx,
1867                );
1868            })?;
1869
1870            Ok(snapshot.text_for_range(start..end).collect::<String>())
1871        })
1872    }
1873
    /// Replaces the contents of the file at the absolute `path` with `content`
    /// and saves it.
    ///
    /// The new content is diffed against the snapshot stored in `shared_buffers`
    /// for this buffer (or the buffer's current snapshot when none is stored) and
    /// the resulting edits are applied as anchored ranges. The action log is
    /// notified of the read and of the edit, the buffer is formatted when
    /// `format_on_save` is enabled for it, and finally the buffer is saved.
    ///
    /// # Errors
    ///
    /// Fails when `path` does not map to a project path, when the buffer cannot
    /// be opened, or when saving fails. Formatting errors are only logged.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Prefer the stored snapshot as the diff base; fall back to the live buffer.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Diff on the background executor and express each edit as an anchor
            // range in the base snapshot.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            // Point the agent location at the end of the last edit, or at
            // `Anchor::MIN` when the diff produced no edits.
            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::MIN),
                    }),
                    cx,
                );
            })?;

            let format_on_save = cx.update(|cx| {
                // The action log is told about a read before the edit is applied
                // and about the edit afterwards.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            })?;

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                })?;
                // Formatting is best-effort: a failure is logged and the save proceeds.
                format_task.await.log_err();

                // Report the buffer as edited again once formatting has run.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                })?;
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1970
    /// Creates a terminal for `command` with `args`, registers it in
    /// `self.terminals` under a freshly generated id, and resolves to its entity.
    ///
    /// * `extra_env` - variables inserted last into the environment, so they
    ///   override the directory environment and the `PAGER` default.
    /// * `cwd` - working directory; also selects the directory environment to load.
    /// * `output_byte_limit` - forwarded to the `Terminal` constructor.
    /// * `is_display_only` - when `true`, the terminal is built with
    ///   `TerminalBuilder::new_display_only` instead of spawning a terminal task
    ///   through the project.
    pub fn create_terminal(
        &self,
        command: String,
        args: Vec<String>,
        extra_env: Vec<acp::EnvVariable>,
        cwd: Option<PathBuf>,
        output_byte_limit: Option<u64>,
        is_display_only: bool,
        cx: &mut Context<Self>,
    ) -> Task<Result<Entity<Terminal>>> {
        // Load the environment of `cwd` when one is given; otherwise start empty.
        let env = match &cwd {
            Some(dir) => self.project.update(cx, |project, cx| {
                project.directory_environment(dir.as_path().into(), cx)
            }),
            None => Task::ready(None).shared(),
        };

        let env = cx.spawn(async move |_, _| {
            let mut env = env.await.unwrap_or_default();
            // On unix, set PAGER to `cat` (presumably so commands do not block
            // waiting on an interactive pager — confirm).
            if cfg!(unix) {
                env.insert("PAGER".into(), "cat".into());
            }
            for var in extra_env {
                env.insert(var.name, var.value);
            }
            env
        });

        let project = self.project.clone();
        let language_registry = project.read(cx).languages().clone();

        let terminal_id = acp::TerminalId(Uuid::new_v4().to_string().into());
        let terminal_task = cx.spawn({
            // The task gets its own copy of the id; the original is used below
            // to register the terminal.
            let terminal_id = terminal_id.clone();
            async move |_this, cx| {
                let env = env.await;
                // Wrap the command in a shell invocation, using the remote
                // client's default shell when the project has one.
                let (command, args) = ShellBuilder::new(
                    project
                        .update(cx, |project, cx| {
                            project
                                .remote_client()
                                .and_then(|r| r.read(cx).default_system_shell())
                        })?
                        .as_deref(),
                    &Shell::Program(get_default_system_shell()),
                )
                .redirect_stdin_to_dev_null()
                .build(Some(command), &args);

                let terminal = if is_display_only {
                    cx.update(|cx| {
                        TerminalBuilder::new_display_only(
                            Some(format!("Display: {}", command).into()),
                            CursorShape::Block,
                            AlternateScroll::On,
                            // NOTE(review): presumably a scrollback line cap — confirm
                            // against `TerminalBuilder::new_display_only`.
                            Some(10_000),
                            cx,
                        )
                    })??
                } else {
                    project
                        .update(cx, |project, cx| {
                            project.create_terminal_task(
                                task::SpawnInTerminal {
                                    command: Some(command.clone()),
                                    args: args.clone(),
                                    cwd: cwd.clone(),
                                    env,
                                    ..Default::default()
                                },
                                cx,
                            )
                        })?
                        .await?
                };

                if is_display_only {
                    // For display-only terminals, we need special handling
                    cx.new(|cx| {
                        Terminal::new_display_only(
                            terminal_id,
                            &format!("{} {}", command, args.join(" ")),
                            cwd,
                            output_byte_limit.map(|l| l as usize),
                            terminal,
                            cx,
                        )
                    })
                } else {
                    cx.new(|cx| {
                        Terminal::new(
                            terminal_id,
                            &format!("{} {}", command, args.join(" ")),
                            cwd,
                            output_byte_limit.map(|l| l as usize),
                            terminal,
                            language_registry,
                            cx,
                        )
                    })
                }
            }
        });

        // Register the terminal on the thread once it has been created.
        cx.spawn(async move |this, cx| {
            let terminal = terminal_task.await?;
            this.update(cx, |this, _cx| {
                this.terminals.insert(terminal_id, terminal.clone());
                terminal
            })
        })
    }
2083
2084    pub fn kill_terminal(
2085        &mut self,
2086        terminal_id: acp::TerminalId,
2087        cx: &mut Context<Self>,
2088    ) -> Result<()> {
2089        self.terminals
2090            .get(&terminal_id)
2091            .context("Terminal not found")?
2092            .update(cx, |terminal, cx| {
2093                terminal.kill(cx);
2094            });
2095
2096        Ok(())
2097    }
2098
2099    pub fn release_terminal(
2100        &mut self,
2101        terminal_id: acp::TerminalId,
2102        cx: &mut Context<Self>,
2103    ) -> Result<()> {
2104        self.terminals
2105            .remove(&terminal_id)
2106            .context("Terminal not found")?
2107            .update(cx, |terminal, cx| {
2108                terminal.kill(cx);
2109            });
2110
2111        Ok(())
2112    }
2113
2114    pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
2115        self.terminals
2116            .get(&terminal_id)
2117            .cloned()
2118            .context("Terminal not found")
2119    }
2120
2121    pub fn write_terminal_output(
2122        &mut self,
2123        terminal_id: acp::TerminalId,
2124        output: &[u8],
2125        cx: &mut Context<Self>,
2126    ) -> Result<()> {
2127        let terminal = self
2128            .terminals
2129            .get(&terminal_id)
2130            .context("Terminal not found")?;
2131
2132        terminal.update(cx, |terminal, cx| {
2133            terminal.write_output(output, cx);
2134        });
2135
2136        Ok(())
2137    }
2138
2139    pub fn to_markdown(&self, cx: &App) -> String {
2140        self.entries.iter().map(|e| e.to_markdown(cx)).collect()
2141    }
2142
2143    pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
2144        cx.emit(AcpThreadEvent::LoadError(error));
2145    }
2146}
2147
2148fn markdown_for_raw_output(
2149    raw_output: &serde_json::Value,
2150    language_registry: &Arc<LanguageRegistry>,
2151    cx: &mut App,
2152) -> Option<Entity<Markdown>> {
2153    match raw_output {
2154        serde_json::Value::Null => None,
2155        serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2156            Markdown::new(
2157                value.to_string().into(),
2158                Some(language_registry.clone()),
2159                None,
2160                cx,
2161            )
2162        })),
2163        serde_json::Value::Number(value) => Some(cx.new(|cx| {
2164            Markdown::new(
2165                value.to_string().into(),
2166                Some(language_registry.clone()),
2167                None,
2168                cx,
2169            )
2170        })),
2171        serde_json::Value::String(value) => Some(cx.new(|cx| {
2172            Markdown::new(
2173                value.clone().into(),
2174                Some(language_registry.clone()),
2175                None,
2176                cx,
2177            )
2178        })),
2179        value => Some(cx.new(|cx| {
2180            Markdown::new(
2181                format!("```json\n{}\n```", value).into(),
2182                Some(language_registry.clone()),
2183                None,
2184                cx,
2185            )
2186        })),
2187    }
2188}
2189
2190#[cfg(test)]
2191mod tests {
2192    use super::*;
2193    use anyhow::anyhow;
2194    use futures::{channel::mpsc, future::LocalBoxFuture, select};
2195    use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2196    use indoc::indoc;
2197    use project::{FakeFs, Fs};
2198    use rand::{distr, prelude::*};
2199    use serde_json::json;
2200    use settings::SettingsStore;
2201    use smol::stream::StreamExt as _;
2202    use std::{
2203        any::Any,
2204        cell::RefCell,
2205        path::Path,
2206        rc::Rc,
2207        sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2208        time::Duration,
2209    };
2210    use util::path;
2211
    /// Shared setup for the tests in this module: tries to initialize logging
    /// (ignoring the result), installs a test `SettingsStore` as a global, then
    /// initializes project settings and the `language` crate.
    fn init_test(cx: &mut TestAppContext) {
        // `try_init` returns an error when a logger is already set; that is fine here.
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }
2221
    // Checks how `push_user_content_block` groups chunks into entries:
    // consecutive user chunks are merged into a single message (which takes the
    // id of the later chunk), while a user chunk that follows an assistant
    // chunk starts a new entry.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                None,
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Hello, ".to_string(),
                    meta: None,
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, None);
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        let message_1_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_1_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "world!".to_string(),
                    meta: None,
                }),
                cx,
            );
        });

        // Still a single entry: the text was appended and the id was updated.
        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, Some(message_1_id));
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Assistant response".to_string(),
                    meta: None,
                }),
                false,
                cx,
            );
        });

        let message_2_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_2_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "New user message".to_string(),
                    meta: None,
                }),
                cx,
            );
        });

        // Entries are now: user message, assistant message, new user message.
        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.id, Some(message_2_id));
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }
2317
    // Checks that two consecutive `AgentThoughtChunk` updates are rendered as a
    // single `<thinking>` block in the thread's markdown output.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        // The fake agent answers every user message with two thought chunks.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "Thinking ".into(),
                                },
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "hard!".into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                        meta: None,
                    })
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
2381
    // Checks that an agent write does not clobber an edit the user made after
    // the agent read the file: the agent reads "one..three", the user then
    // inserts "zero" at the top, and the agent's write of "one..five" must
    // leave both changes in the buffer and on disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test body once the agent has finished reading the file.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                        meta: None,
                    })
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        // Wait for the agent's read, then make the user edit before its write lands.
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
2461
    // Exercises the `line` / `limit` arguments of `read_text_file` on a
    // four-line file: whole file, start line only, limit only, a bounded range,
    // and a start line past the end of the file (which must fail).
    #[gpui::test]
    async fn test_reading_from_line(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\nfour\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();

        let connection = Rc::new(FakeAgentConnection::new());

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        // Whole file
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "one\ntwo\nthree\nfour\n");

        // Only start line
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(3), None, false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "three\nfour\n");

        // Only limit
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), None, Some(2), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "one\ntwo\n");

        // Range
        let content = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(2), Some(2), false, cx)
            })
            .await
            .unwrap();

        assert_eq!(content, "two\nthree\n");

        // Invalid
        // Line 5 is the empty row after the final newline, so the read is rejected.
        let err = thread
            .update(cx, |thread, cx| {
                thread.read_text_file(path!("/tmp/foo").into(), Some(5), Some(2), false, cx)
            })
            .await
            .unwrap_err();

        assert_eq!(
            err.to_string(),
            "Attempting to read beyond the end of the file, line 5:0"
        );
    }
2537
2538    #[gpui::test]
2539    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
2540        init_test(cx);
2541
2542        let fs = FakeFs::new(cx.executor());
2543        let project = Project::test(fs, [], cx).await;
2544        let id = acp::ToolCallId("test".into());
2545
2546        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2547            let id = id.clone();
2548            move |_, thread, mut cx| {
2549                let id = id.clone();
2550                async move {
2551                    thread
2552                        .update(&mut cx, |thread, cx| {
2553                            thread.handle_session_update(
2554                                acp::SessionUpdate::ToolCall(acp::ToolCall {
2555                                    id: id.clone(),
2556                                    title: "Label".into(),
2557                                    kind: acp::ToolKind::Fetch,
2558                                    status: acp::ToolCallStatus::InProgress,
2559                                    content: vec![],
2560                                    locations: vec![],
2561                                    raw_input: None,
2562                                    raw_output: None,
2563                                    meta: None,
2564                                }),
2565                                cx,
2566                            )
2567                        })
2568                        .unwrap()
2569                        .unwrap();
2570                    Ok(acp::PromptResponse {
2571                        stop_reason: acp::StopReason::EndTurn,
2572                        meta: None,
2573                    })
2574                }
2575                .boxed_local()
2576            }
2577        }));
2578
2579        let thread = cx
2580            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2581            .await
2582            .unwrap();
2583
2584        let request = thread.update(cx, |thread, cx| {
2585            thread.send_raw("Fetch https://example.com", cx)
2586        });
2587
2588        run_until_first_tool_call(&thread, cx).await;
2589
2590        thread.read_with(cx, |thread, _| {
2591            assert!(matches!(
2592                thread.entries[1],
2593                AgentThreadEntry::ToolCall(ToolCall {
2594                    status: ToolCallStatus::InProgress,
2595                    ..
2596                })
2597            ));
2598        });
2599
2600        thread.update(cx, |thread, cx| thread.cancel(cx)).await;
2601
2602        thread.read_with(cx, |thread, _| {
2603            assert!(matches!(
2604                &thread.entries[1],
2605                AgentThreadEntry::ToolCall(ToolCall {
2606                    status: ToolCallStatus::Canceled,
2607                    ..
2608                })
2609            ));
2610        });
2611
2612        thread
2613            .update(cx, |thread, cx| {
2614                thread.handle_session_update(
2615                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
2616                        id,
2617                        fields: acp::ToolCallUpdateFields {
2618                            status: Some(acp::ToolCallStatus::Completed),
2619                            ..Default::default()
2620                        },
2621                        meta: None,
2622                    }),
2623                    cx,
2624                )
2625            })
2626            .unwrap();
2627
2628        request.await.unwrap();
2629
2630        thread.read_with(cx, |thread, _| {
2631            assert!(matches!(
2632                thread.entries[1],
2633                AgentThreadEntry::ToolCall(ToolCall {
2634                    status: ToolCallStatus::Completed,
2635                    ..
2636                })
2637            ));
2638        });
2639    }
2640
2641    #[gpui::test]
2642    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
2643        init_test(cx);
2644        let fs = FakeFs::new(cx.background_executor.clone());
2645        fs.insert_tree(path!("/test"), json!({})).await;
2646        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
2647
2648        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2649            move |_, thread, mut cx| {
2650                async move {
2651                    thread
2652                        .update(&mut cx, |thread, cx| {
2653                            thread.handle_session_update(
2654                                acp::SessionUpdate::ToolCall(acp::ToolCall {
2655                                    id: acp::ToolCallId("test".into()),
2656                                    title: "Label".into(),
2657                                    kind: acp::ToolKind::Edit,
2658                                    status: acp::ToolCallStatus::Completed,
2659                                    content: vec![acp::ToolCallContent::Diff {
2660                                        diff: acp::Diff {
2661                                            path: "/test/test.txt".into(),
2662                                            old_text: None,
2663                                            new_text: "foo".into(),
2664                                            meta: None,
2665                                        },
2666                                    }],
2667                                    locations: vec![],
2668                                    raw_input: None,
2669                                    raw_output: None,
2670                                    meta: None,
2671                                }),
2672                                cx,
2673                            )
2674                        })
2675                        .unwrap()
2676                        .unwrap();
2677                    Ok(acp::PromptResponse {
2678                        stop_reason: acp::StopReason::EndTurn,
2679                        meta: None,
2680                    })
2681                }
2682                .boxed_local()
2683            }
2684        }));
2685
2686        let thread = cx
2687            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2688            .await
2689            .unwrap();
2690
2691        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
2692            .await
2693            .unwrap();
2694
2695        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
2696    }
2697
    /// Exercises checkpoint bookkeeping across several turns.
    ///
    /// The fake agent optionally writes a new empty file per prompt and echoes the prompt
    /// back in upper case. The test checks that:
    /// - turns during which a file was written render as `## User (checkpoint)`,
    /// - a turn without file changes renders as plain `## User`,
    /// - restoring the checkpoint of the second user message truncates the thread back to
    ///   the first exchange and leaves only `file-0` on the fake filesystem.
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        // NOTE(review): the `.git` entry presumably makes the fake worktree look like a
        // git repository so that checkpoints can be taken; confirm against `FakeFs`.
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        // `simulate_changes` toggles whether the agent touches the filesystem;
        // `next_filename` numbers the files it creates (`file-0`, `file-1`, ...).
        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                // Fresh clones per prompt so the returned future can own them.
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    if simulate_changes.load(SeqCst) {
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    // Echo the first prompt block back, upper-cased, as the agent's reply.
                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk {
                                    content: content.text.to_uppercase().into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                        meta: None,
                    })
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // First turn: the agent writes `file-0`, so the user message shows a checkpoint.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);

        // Second turn: the agent writes `file-1`, producing a second checkpoint.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                    ## User

                    dolor

                    ## Assistant

                    DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        thread
            .update(cx, |thread, cx| {
                // Entry 2 is the second user message ("ipsum"); entries 0 and 1 are the
                // first user message and its assistant reply.
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.restore_checkpoint(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        // `file-1` was created after the restored checkpoint, so it is gone again.
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
    }
2878
2879    #[gpui::test]
2880    async fn test_tool_result_refusal(cx: &mut TestAppContext) {
2881        use std::sync::atomic::AtomicUsize;
2882        init_test(cx);
2883
2884        let fs = FakeFs::new(cx.executor());
2885        let project = Project::test(fs, None, cx).await;
2886
2887        // Create a connection that simulates refusal after tool result
2888        let prompt_count = Arc::new(AtomicUsize::new(0));
2889        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2890            let prompt_count = prompt_count.clone();
2891            move |_request, thread, mut cx| {
2892                let count = prompt_count.fetch_add(1, SeqCst);
2893                async move {
2894                    if count == 0 {
2895                        // First prompt: Generate a tool call with result
2896                        thread.update(&mut cx, |thread, cx| {
2897                            thread
2898                                .handle_session_update(
2899                                    acp::SessionUpdate::ToolCall(acp::ToolCall {
2900                                        id: acp::ToolCallId("tool1".into()),
2901                                        title: "Test Tool".into(),
2902                                        kind: acp::ToolKind::Fetch,
2903                                        status: acp::ToolCallStatus::Completed,
2904                                        content: vec![],
2905                                        locations: vec![],
2906                                        raw_input: Some(serde_json::json!({"query": "test"})),
2907                                        raw_output: Some(
2908                                            serde_json::json!({"result": "inappropriate content"}),
2909                                        ),
2910                                        meta: None,
2911                                    }),
2912                                    cx,
2913                                )
2914                                .unwrap();
2915                        })?;
2916
2917                        // Now return refusal because of the tool result
2918                        Ok(acp::PromptResponse {
2919                            stop_reason: acp::StopReason::Refusal,
2920                            meta: None,
2921                        })
2922                    } else {
2923                        Ok(acp::PromptResponse {
2924                            stop_reason: acp::StopReason::EndTurn,
2925                            meta: None,
2926                        })
2927                    }
2928                }
2929                .boxed_local()
2930            }
2931        }));
2932
2933        let thread = cx
2934            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2935            .await
2936            .unwrap();
2937
2938        // Track if we see a Refusal event
2939        let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
2940        let saw_refusal_event_captured = saw_refusal_event.clone();
2941        thread.update(cx, |_thread, cx| {
2942            cx.subscribe(
2943                &thread,
2944                move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
2945                    if matches!(event, AcpThreadEvent::Refusal) {
2946                        *saw_refusal_event_captured.lock().unwrap() = true;
2947                    }
2948                },
2949            )
2950            .detach();
2951        });
2952
2953        // Send a user message - this will trigger tool call and then refusal
2954        let send_task = thread.update(cx, |thread, cx| {
2955            thread.send(
2956                vec![acp::ContentBlock::Text(acp::TextContent {
2957                    text: "Hello".into(),
2958                    annotations: None,
2959                    meta: None,
2960                })],
2961                cx,
2962            )
2963        });
2964        cx.background_executor.spawn(send_task).detach();
2965        cx.run_until_parked();
2966
2967        // Verify that:
2968        // 1. A Refusal event WAS emitted (because it's a tool result refusal, not user prompt)
2969        // 2. The user message was NOT truncated
2970        assert!(
2971            *saw_refusal_event.lock().unwrap(),
2972            "Refusal event should be emitted for tool result refusals"
2973        );
2974
2975        thread.read_with(cx, |thread, _| {
2976            let entries = thread.entries();
2977            assert!(entries.len() >= 2, "Should have user message and tool call");
2978
2979            // Verify user message is still there
2980            assert!(
2981                matches!(entries[0], AgentThreadEntry::UserMessage(_)),
2982                "User message should not be truncated"
2983            );
2984
2985            // Verify tool call is there with result
2986            if let AgentThreadEntry::ToolCall(tool_call) = &entries[1] {
2987                assert!(
2988                    tool_call.raw_output.is_some(),
2989                    "Tool call should have output"
2990                );
2991            } else {
2992                panic!("Expected tool call at index 1");
2993            }
2994        });
2995    }
2996
2997    #[gpui::test]
2998    async fn test_user_prompt_refusal_emits_event(cx: &mut TestAppContext) {
2999        init_test(cx);
3000
3001        let fs = FakeFs::new(cx.executor());
3002        let project = Project::test(fs, None, cx).await;
3003
3004        let refuse_next = Arc::new(AtomicBool::new(false));
3005        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3006            let refuse_next = refuse_next.clone();
3007            move |_request, _thread, _cx| {
3008                if refuse_next.load(SeqCst) {
3009                    async move {
3010                        Ok(acp::PromptResponse {
3011                            stop_reason: acp::StopReason::Refusal,
3012                            meta: None,
3013                        })
3014                    }
3015                    .boxed_local()
3016                } else {
3017                    async move {
3018                        Ok(acp::PromptResponse {
3019                            stop_reason: acp::StopReason::EndTurn,
3020                            meta: None,
3021                        })
3022                    }
3023                    .boxed_local()
3024                }
3025            }
3026        }));
3027
3028        let thread = cx
3029            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3030            .await
3031            .unwrap();
3032
3033        // Track if we see a Refusal event
3034        let saw_refusal_event = Arc::new(std::sync::Mutex::new(false));
3035        let saw_refusal_event_captured = saw_refusal_event.clone();
3036        thread.update(cx, |_thread, cx| {
3037            cx.subscribe(
3038                &thread,
3039                move |_thread, _event_thread, event: &AcpThreadEvent, _cx| {
3040                    if matches!(event, AcpThreadEvent::Refusal) {
3041                        *saw_refusal_event_captured.lock().unwrap() = true;
3042                    }
3043                },
3044            )
3045            .detach();
3046        });
3047
3048        // Send a message that will be refused
3049        refuse_next.store(true, SeqCst);
3050        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3051            .await
3052            .unwrap();
3053
3054        // Verify that a Refusal event WAS emitted for user prompt refusal
3055        assert!(
3056            *saw_refusal_event.lock().unwrap(),
3057            "Refusal event should be emitted for user prompt refusals"
3058        );
3059
3060        // Verify the message was truncated (user prompt refusal)
3061        thread.read_with(cx, |thread, cx| {
3062            assert_eq!(thread.to_markdown(cx), "");
3063        });
3064    }
3065
3066    #[gpui::test]
3067    async fn test_refusal(cx: &mut TestAppContext) {
3068        init_test(cx);
3069        let fs = FakeFs::new(cx.background_executor.clone());
3070        fs.insert_tree(path!("/"), json!({})).await;
3071        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
3072
3073        let refuse_next = Arc::new(AtomicBool::new(false));
3074        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
3075            let refuse_next = refuse_next.clone();
3076            move |request, thread, mut cx| {
3077                let refuse_next = refuse_next.clone();
3078                async move {
3079                    if refuse_next.load(SeqCst) {
3080                        return Ok(acp::PromptResponse {
3081                            stop_reason: acp::StopReason::Refusal,
3082                            meta: None,
3083                        });
3084                    }
3085
3086                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
3087                        panic!("expected text content block");
3088                    };
3089                    thread.update(&mut cx, |thread, cx| {
3090                        thread
3091                            .handle_session_update(
3092                                acp::SessionUpdate::AgentMessageChunk {
3093                                    content: content.text.to_uppercase().into(),
3094                                },
3095                                cx,
3096                            )
3097                            .unwrap();
3098                    })?;
3099                    Ok(acp::PromptResponse {
3100                        stop_reason: acp::StopReason::EndTurn,
3101                        meta: None,
3102                    })
3103                }
3104                .boxed_local()
3105            }
3106        }));
3107        let thread = cx
3108            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3109            .await
3110            .unwrap();
3111
3112        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
3113            .await
3114            .unwrap();
3115        thread.read_with(cx, |thread, cx| {
3116            assert_eq!(
3117                thread.to_markdown(cx),
3118                indoc! {"
3119                    ## User
3120
3121                    hello
3122
3123                    ## Assistant
3124
3125                    HELLO
3126
3127                "}
3128            );
3129        });
3130
3131        // Simulate refusing the second message. The message should be truncated
3132        // when a user prompt is refused.
3133        refuse_next.store(true, SeqCst);
3134        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
3135            .await
3136            .unwrap();
3137        thread.read_with(cx, |thread, cx| {
3138            assert_eq!(
3139                thread.to_markdown(cx),
3140                indoc! {"
3141                    ## User
3142
3143                    hello
3144
3145                    ## Assistant
3146
3147                    HELLO
3148
3149                "}
3150            );
3151        });
3152    }
3153
3154    async fn run_until_first_tool_call(
3155        thread: &Entity<AcpThread>,
3156        cx: &mut TestAppContext,
3157    ) -> usize {
3158        let (mut tx, mut rx) = mpsc::channel::<usize>(1);
3159
3160        let subscription = cx.update(|cx| {
3161            cx.subscribe(thread, move |thread, _, cx| {
3162                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
3163                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
3164                        return tx.try_send(ix).unwrap();
3165                    }
3166                }
3167            })
3168        });
3169
3170        select! {
3171            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
3172                panic!("Timeout waiting for tool call")
3173            }
3174            ix = rx.next().fuse() => {
3175                drop(subscription);
3176                ix.unwrap()
3177            }
3178        }
3179    }
3180
3181    #[derive(Clone, Default)]
3182    struct FakeAgentConnection {
3183        auth_methods: Vec<acp::AuthMethod>,
3184        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
3185        on_user_message: Option<
3186            Rc<
3187                dyn Fn(
3188                        acp::PromptRequest,
3189                        WeakEntity<AcpThread>,
3190                        AsyncApp,
3191                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3192                    + 'static,
3193            >,
3194        >,
3195    }
3196
3197    impl FakeAgentConnection {
3198        fn new() -> Self {
3199            Self {
3200                auth_methods: Vec::new(),
3201                on_user_message: None,
3202                sessions: Arc::default(),
3203            }
3204        }
3205
3206        #[expect(unused)]
3207        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
3208            self.auth_methods = auth_methods;
3209            self
3210        }
3211
3212        fn on_user_message(
3213            mut self,
3214            handler: impl Fn(
3215                acp::PromptRequest,
3216                WeakEntity<AcpThread>,
3217                AsyncApp,
3218            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
3219            + 'static,
3220        ) -> Self {
3221            self.on_user_message.replace(Rc::new(handler));
3222            self
3223        }
3224    }
3225
3226    impl AgentConnection for FakeAgentConnection {
3227        fn auth_methods(&self) -> &[acp::AuthMethod] {
3228            &self.auth_methods
3229        }
3230
3231        fn new_thread(
3232            self: Rc<Self>,
3233            project: Entity<Project>,
3234            _cwd: &Path,
3235            cx: &mut App,
3236        ) -> Task<gpui::Result<Entity<AcpThread>>> {
3237            let session_id = acp::SessionId(
3238                rand::rng()
3239                    .sample_iter(&distr::Alphanumeric)
3240                    .take(7)
3241                    .map(char::from)
3242                    .collect::<String>()
3243                    .into(),
3244            );
3245            let action_log = cx.new(|_| ActionLog::new(project.clone()));
3246            let thread = cx.new(|cx| {
3247                AcpThread::new(
3248                    "Test",
3249                    self.clone(),
3250                    project,
3251                    action_log,
3252                    session_id.clone(),
3253                    watch::Receiver::constant(acp::PromptCapabilities {
3254                        image: true,
3255                        audio: true,
3256                        embedded_context: true,
3257                        meta: None,
3258                    }),
3259                    cx,
3260                )
3261            });
3262            self.sessions.lock().insert(session_id, thread.downgrade());
3263            Task::ready(Ok(thread))
3264        }
3265
3266        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
3267            if self.auth_methods().iter().any(|m| m.id == method) {
3268                Task::ready(Ok(()))
3269            } else {
3270                Task::ready(Err(anyhow!("Invalid Auth Method")))
3271            }
3272        }
3273
3274        fn prompt(
3275            &self,
3276            _id: Option<UserMessageId>,
3277            params: acp::PromptRequest,
3278            cx: &mut App,
3279        ) -> Task<gpui::Result<acp::PromptResponse>> {
3280            let sessions = self.sessions.lock();
3281            let thread = sessions.get(&params.session_id).unwrap();
3282            if let Some(handler) = &self.on_user_message {
3283                let handler = handler.clone();
3284                let thread = thread.clone();
3285                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
3286            } else {
3287                Task::ready(Ok(acp::PromptResponse {
3288                    stop_reason: acp::StopReason::EndTurn,
3289                    meta: None,
3290                }))
3291            }
3292        }
3293
3294        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
3295            let sessions = self.sessions.lock();
3296            let thread = sessions.get(session_id).unwrap().clone();
3297
3298            cx.spawn(async move |cx| {
3299                thread
3300                    .update(cx, |thread, cx| thread.cancel(cx))
3301                    .unwrap()
3302                    .await
3303            })
3304            .detach();
3305        }
3306
3307        fn truncate(
3308            &self,
3309            session_id: &acp::SessionId,
3310            _cx: &App,
3311        ) -> Option<Rc<dyn AgentSessionTruncate>> {
3312            Some(Rc::new(FakeAgentSessionEditor {
3313                _session_id: session_id.clone(),
3314            }))
3315        }
3316
3317        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
3318            self
3319        }
3320    }
3321
    /// Truncation handle handed out by `FakeAgentConnection::truncate`.
    struct FakeAgentSessionEditor {
        // Set from the session id passed to `truncate`; not read by `run` below.
        _session_id: acp::SessionId,
    }

    impl AgentSessionTruncate for FakeAgentSessionEditor {
        /// Returns an already-resolved successful task without using the message id or
        /// the app context.
        fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
            Task::ready(Ok(()))
        }
    }
3331
3332    #[gpui::test]
3333    async fn test_tool_call_not_found_creates_failed_entry(cx: &mut TestAppContext) {
3334        init_test(cx);
3335
3336        let fs = FakeFs::new(cx.executor());
3337        let project = Project::test(fs, [], cx).await;
3338        let connection = Rc::new(FakeAgentConnection::new());
3339        let thread = cx
3340            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
3341            .await
3342            .unwrap();
3343
3344        // Try to update a tool call that doesn't exist
3345        let nonexistent_id = acp::ToolCallId("nonexistent-tool-call".into());
3346        thread.update(cx, |thread, cx| {
3347            let result = thread.handle_session_update(
3348                acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
3349                    id: nonexistent_id.clone(),
3350                    fields: acp::ToolCallUpdateFields {
3351                        status: Some(acp::ToolCallStatus::Completed),
3352                        ..Default::default()
3353                    },
3354                    meta: None,
3355                }),
3356                cx,
3357            );
3358
3359            // The update should succeed (not return an error)
3360            assert!(result.is_ok());
3361
3362            // There should now be exactly one entry in the thread
3363            assert_eq!(thread.entries.len(), 1);
3364
3365            // The entry should be a failed tool call
3366            if let AgentThreadEntry::ToolCall(tool_call) = &thread.entries[0] {
3367                assert_eq!(tool_call.id, nonexistent_id);
3368                assert!(matches!(tool_call.status, ToolCallStatus::Failed));
3369                assert_eq!(tool_call.kind, acp::ToolKind::Fetch);
3370
3371                // Check that the content contains the error message
3372                assert_eq!(tool_call.content.len(), 1);
3373                if let ToolCallContent::ContentBlock(content_block) = &tool_call.content[0] {
3374                    match content_block {
3375                        ContentBlock::Markdown { markdown } => {
3376                            let markdown_text = markdown.read(cx).source();
3377                            assert!(markdown_text.contains("Tool call not found"));
3378                        }
3379                        ContentBlock::Empty => panic!("Expected markdown content, got empty"),
3380                        ContentBlock::ResourceLink { .. } => {
3381                            panic!("Expected markdown content, got resource link")
3382                        }
3383                    }
3384                } else {
3385                    panic!("Expected ContentBlock, got: {:?}", tool_call.content[0]);
3386                }
3387            } else {
3388                panic!("Expected ToolCall entry, got: {:?}", thread.entries[0]);
3389            }
3390        });
3391    }
3392}