acp_thread.rs

   1mod connection;
   2mod diff;
   3mod mention;
   4mod terminal;
   5
   6use collections::HashSet;
   7pub use connection::*;
   8pub use diff::*;
   9use language::language_settings::FormatOnSave;
  10pub use mention::*;
  11use project::lsp_store::{FormatTrigger, LspFormatTarget};
  12use serde::{Deserialize, Serialize};
  13pub use terminal::*;
  14
  15use action_log::ActionLog;
  16use agent_client_protocol as acp;
  17use anyhow::{Context as _, Result, anyhow};
  18use editor::Bias;
  19use futures::{FutureExt, channel::oneshot, future::BoxFuture};
  20use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  21use itertools::Itertools;
  22use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
  23use markdown::Markdown;
  24use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
  25use std::collections::HashMap;
  26use std::error::Error;
  27use std::fmt::{Formatter, Write};
  28use std::ops::Range;
  29use std::process::ExitStatus;
  30use std::rc::Rc;
  31use std::time::{Duration, Instant};
  32use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
  33use ui::App;
  34use util::ResultExt;
  35
  36#[derive(Debug)]
  37pub struct UserMessage {
  38    pub id: Option<UserMessageId>,
  39    pub content: ContentBlock,
  40    pub chunks: Vec<acp::ContentBlock>,
  41    pub checkpoint: Option<Checkpoint>,
  42}
  43
  44#[derive(Debug)]
  45pub struct Checkpoint {
  46    git_checkpoint: GitStoreCheckpoint,
  47    pub show: bool,
  48}
  49
  50impl UserMessage {
  51    fn to_markdown(&self, cx: &App) -> String {
  52        let mut markdown = String::new();
  53        if self
  54            .checkpoint
  55            .as_ref()
  56            .is_some_and(|checkpoint| checkpoint.show)
  57        {
  58            writeln!(markdown, "## User (checkpoint)").unwrap();
  59        } else {
  60            writeln!(markdown, "## User").unwrap();
  61        }
  62        writeln!(markdown).unwrap();
  63        writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
  64        writeln!(markdown).unwrap();
  65        markdown
  66    }
  67}
  68
  69#[derive(Debug, PartialEq)]
  70pub struct AssistantMessage {
  71    pub chunks: Vec<AssistantMessageChunk>,
  72}
  73
  74impl AssistantMessage {
  75    pub fn to_markdown(&self, cx: &App) -> String {
  76        format!(
  77            "## Assistant\n\n{}\n\n",
  78            self.chunks
  79                .iter()
  80                .map(|chunk| chunk.to_markdown(cx))
  81                .join("\n\n")
  82        )
  83    }
  84}
  85
  86#[derive(Debug, PartialEq)]
  87pub enum AssistantMessageChunk {
  88    Message { block: ContentBlock },
  89    Thought { block: ContentBlock },
  90}
  91
  92impl AssistantMessageChunk {
  93    pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
  94        Self::Message {
  95            block: ContentBlock::new(chunk.into(), language_registry, cx),
  96        }
  97    }
  98
  99    fn to_markdown(&self, cx: &App) -> String {
 100        match self {
 101            Self::Message { block } => block.to_markdown(cx).to_string(),
 102            Self::Thought { block } => {
 103                format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
 104            }
 105        }
 106    }
 107}
 108
 109#[derive(Debug)]
 110pub enum AgentThreadEntry {
 111    UserMessage(UserMessage),
 112    AssistantMessage(AssistantMessage),
 113    ToolCall(ToolCall),
 114}
 115
 116impl AgentThreadEntry {
 117    pub fn to_markdown(&self, cx: &App) -> String {
 118        match self {
 119            Self::UserMessage(message) => message.to_markdown(cx),
 120            Self::AssistantMessage(message) => message.to_markdown(cx),
 121            Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
 122        }
 123    }
 124
 125    pub fn user_message(&self) -> Option<&UserMessage> {
 126        if let AgentThreadEntry::UserMessage(message) = self {
 127            Some(message)
 128        } else {
 129            None
 130        }
 131    }
 132
 133    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 134        if let AgentThreadEntry::ToolCall(call) = self {
 135            itertools::Either::Left(call.diffs())
 136        } else {
 137            itertools::Either::Right(std::iter::empty())
 138        }
 139    }
 140
 141    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 142        if let AgentThreadEntry::ToolCall(call) = self {
 143            itertools::Either::Left(call.terminals())
 144        } else {
 145            itertools::Either::Right(std::iter::empty())
 146        }
 147    }
 148
 149    pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
 150        if let AgentThreadEntry::ToolCall(ToolCall {
 151            locations,
 152            resolved_locations,
 153            ..
 154        }) = self
 155        {
 156            Some((
 157                locations.get(ix)?.clone(),
 158                resolved_locations.get(ix)?.clone()?,
 159            ))
 160        } else {
 161            None
 162        }
 163    }
 164}
 165
 166#[derive(Debug)]
 167pub struct ToolCall {
 168    pub id: acp::ToolCallId,
 169    pub label: Entity<Markdown>,
 170    pub kind: acp::ToolKind,
 171    pub content: Vec<ToolCallContent>,
 172    pub status: ToolCallStatus,
 173    pub locations: Vec<acp::ToolCallLocation>,
 174    pub resolved_locations: Vec<Option<AgentLocation>>,
 175    pub raw_input: Option<serde_json::Value>,
 176    pub raw_output: Option<serde_json::Value>,
 177}
 178
 179impl ToolCall {
 180    fn from_acp(
 181        tool_call: acp::ToolCall,
 182        status: ToolCallStatus,
 183        language_registry: Arc<LanguageRegistry>,
 184        cx: &mut App,
 185    ) -> Self {
 186        Self {
 187            id: tool_call.id,
 188            label: cx.new(|cx| {
 189                Markdown::new(
 190                    tool_call.title.into(),
 191                    Some(language_registry.clone()),
 192                    None,
 193                    cx,
 194                )
 195            }),
 196            kind: tool_call.kind,
 197            content: tool_call
 198                .content
 199                .into_iter()
 200                .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
 201                .collect(),
 202            locations: tool_call.locations,
 203            resolved_locations: Vec::default(),
 204            status,
 205            raw_input: tool_call.raw_input,
 206            raw_output: tool_call.raw_output,
 207        }
 208    }
 209
 210    fn update_fields(
 211        &mut self,
 212        fields: acp::ToolCallUpdateFields,
 213        language_registry: Arc<LanguageRegistry>,
 214        cx: &mut App,
 215    ) {
 216        let acp::ToolCallUpdateFields {
 217            kind,
 218            status,
 219            title,
 220            content,
 221            locations,
 222            raw_input,
 223            raw_output,
 224        } = fields;
 225
 226        if let Some(kind) = kind {
 227            self.kind = kind;
 228        }
 229
 230        if let Some(status) = status {
 231            self.status = status.into();
 232        }
 233
 234        if let Some(title) = title {
 235            self.label.update(cx, |label, cx| {
 236                label.replace(title, cx);
 237            });
 238        }
 239
 240        if let Some(content) = content {
 241            let new_content_len = content.len();
 242            let mut content = content.into_iter();
 243
 244            // Reuse existing content if we can
 245            for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
 246                old.update_from_acp(new, language_registry.clone(), cx);
 247            }
 248            for new in content {
 249                self.content.push(ToolCallContent::from_acp(
 250                    new,
 251                    language_registry.clone(),
 252                    cx,
 253                ))
 254            }
 255            self.content.truncate(new_content_len);
 256        }
 257
 258        if let Some(locations) = locations {
 259            self.locations = locations;
 260        }
 261
 262        if let Some(raw_input) = raw_input {
 263            self.raw_input = Some(raw_input);
 264        }
 265
 266        if let Some(raw_output) = raw_output {
 267            if self.content.is_empty()
 268                && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
 269            {
 270                self.content
 271                    .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
 272                        markdown,
 273                    }));
 274            }
 275            self.raw_output = Some(raw_output);
 276        }
 277    }
 278
 279    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 280        self.content.iter().filter_map(|content| match content {
 281            ToolCallContent::Diff(diff) => Some(diff),
 282            ToolCallContent::ContentBlock(_) => None,
 283            ToolCallContent::Terminal(_) => None,
 284        })
 285    }
 286
 287    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 288        self.content.iter().filter_map(|content| match content {
 289            ToolCallContent::Terminal(terminal) => Some(terminal),
 290            ToolCallContent::ContentBlock(_) => None,
 291            ToolCallContent::Diff(_) => None,
 292        })
 293    }
 294
 295    fn to_markdown(&self, cx: &App) -> String {
 296        let mut markdown = format!(
 297            "**Tool Call: {}**\nStatus: {}\n\n",
 298            self.label.read(cx).source(),
 299            self.status
 300        );
 301        for content in &self.content {
 302            markdown.push_str(content.to_markdown(cx).as_str());
 303            markdown.push_str("\n\n");
 304        }
 305        markdown
 306    }
 307
 308    async fn resolve_location(
 309        location: acp::ToolCallLocation,
 310        project: WeakEntity<Project>,
 311        cx: &mut AsyncApp,
 312    ) -> Option<AgentLocation> {
 313        let buffer = project
 314            .update(cx, |project, cx| {
 315                project
 316                    .project_path_for_absolute_path(&location.path, cx)
 317                    .map(|path| project.open_buffer(path, cx))
 318            })
 319            .ok()??;
 320        let buffer = buffer.await.log_err()?;
 321        let position = buffer
 322            .update(cx, |buffer, _| {
 323                if let Some(row) = location.line {
 324                    let snapshot = buffer.snapshot();
 325                    let column = snapshot.indent_size_for_line(row).len;
 326                    let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
 327                    snapshot.anchor_before(point)
 328                } else {
 329                    Anchor::MIN
 330                }
 331            })
 332            .ok()?;
 333
 334        Some(AgentLocation {
 335            buffer: buffer.downgrade(),
 336            position,
 337        })
 338    }
 339
 340    fn resolve_locations(
 341        &self,
 342        project: Entity<Project>,
 343        cx: &mut App,
 344    ) -> Task<Vec<Option<AgentLocation>>> {
 345        let locations = self.locations.clone();
 346        project.update(cx, |_, cx| {
 347            cx.spawn(async move |project, cx| {
 348                let mut new_locations = Vec::new();
 349                for location in locations {
 350                    new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
 351                }
 352                new_locations
 353            })
 354        })
 355    }
 356}
 357
 358#[derive(Debug)]
 359pub enum ToolCallStatus {
 360    /// The tool call hasn't started running yet, but we start showing it to
 361    /// the user.
 362    Pending,
 363    /// The tool call is waiting for confirmation from the user.
 364    WaitingForConfirmation {
 365        options: Vec<acp::PermissionOption>,
 366        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
 367    },
 368    /// The tool call is currently running.
 369    InProgress,
 370    /// The tool call completed successfully.
 371    Completed,
 372    /// The tool call failed.
 373    Failed,
 374    /// The user rejected the tool call.
 375    Rejected,
 376    /// The user canceled generation so the tool call was canceled.
 377    Canceled,
 378}
 379
 380impl From<acp::ToolCallStatus> for ToolCallStatus {
 381    fn from(status: acp::ToolCallStatus) -> Self {
 382        match status {
 383            acp::ToolCallStatus::Pending => Self::Pending,
 384            acp::ToolCallStatus::InProgress => Self::InProgress,
 385            acp::ToolCallStatus::Completed => Self::Completed,
 386            acp::ToolCallStatus::Failed => Self::Failed,
 387        }
 388    }
 389}
 390
 391impl Display for ToolCallStatus {
 392    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 393        write!(
 394            f,
 395            "{}",
 396            match self {
 397                ToolCallStatus::Pending => "Pending",
 398                ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
 399                ToolCallStatus::InProgress => "In Progress",
 400                ToolCallStatus::Completed => "Completed",
 401                ToolCallStatus::Failed => "Failed",
 402                ToolCallStatus::Rejected => "Rejected",
 403                ToolCallStatus::Canceled => "Canceled",
 404            }
 405        )
 406    }
 407}
 408
 409#[derive(Debug, PartialEq, Clone)]
 410pub enum ContentBlock {
 411    Empty,
 412    Markdown { markdown: Entity<Markdown> },
 413    ResourceLink { resource_link: acp::ResourceLink },
 414}
 415
 416impl ContentBlock {
 417    pub fn new(
 418        block: acp::ContentBlock,
 419        language_registry: &Arc<LanguageRegistry>,
 420        cx: &mut App,
 421    ) -> Self {
 422        let mut this = Self::Empty;
 423        this.append(block, language_registry, cx);
 424        this
 425    }
 426
 427    pub fn new_combined(
 428        blocks: impl IntoIterator<Item = acp::ContentBlock>,
 429        language_registry: Arc<LanguageRegistry>,
 430        cx: &mut App,
 431    ) -> Self {
 432        let mut this = Self::Empty;
 433        for block in blocks {
 434            this.append(block, &language_registry, cx);
 435        }
 436        this
 437    }
 438
 439    pub fn append(
 440        &mut self,
 441        block: acp::ContentBlock,
 442        language_registry: &Arc<LanguageRegistry>,
 443        cx: &mut App,
 444    ) {
 445        if matches!(self, ContentBlock::Empty)
 446            && let acp::ContentBlock::ResourceLink(resource_link) = block
 447        {
 448            *self = ContentBlock::ResourceLink { resource_link };
 449            return;
 450        }
 451
 452        let new_content = self.block_string_contents(block);
 453
 454        match self {
 455            ContentBlock::Empty => {
 456                *self = Self::create_markdown_block(new_content, language_registry, cx);
 457            }
 458            ContentBlock::Markdown { markdown } => {
 459                markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
 460            }
 461            ContentBlock::ResourceLink { resource_link } => {
 462                let existing_content = Self::resource_link_md(&resource_link.uri);
 463                let combined = format!("{}\n{}", existing_content, new_content);
 464
 465                *self = Self::create_markdown_block(combined, language_registry, cx);
 466            }
 467        }
 468    }
 469
 470    fn create_markdown_block(
 471        content: String,
 472        language_registry: &Arc<LanguageRegistry>,
 473        cx: &mut App,
 474    ) -> ContentBlock {
 475        ContentBlock::Markdown {
 476            markdown: cx
 477                .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
 478        }
 479    }
 480
 481    fn block_string_contents(&self, block: acp::ContentBlock) -> String {
 482        match block {
 483            acp::ContentBlock::Text(text_content) => text_content.text,
 484            acp::ContentBlock::ResourceLink(resource_link) => {
 485                Self::resource_link_md(&resource_link.uri)
 486            }
 487            acp::ContentBlock::Resource(acp::EmbeddedResource {
 488                resource:
 489                    acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
 490                        uri,
 491                        ..
 492                    }),
 493                ..
 494            }) => Self::resource_link_md(&uri),
 495            acp::ContentBlock::Image(image) => Self::image_md(&image),
 496            acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
 497        }
 498    }
 499
 500    fn resource_link_md(uri: &str) -> String {
 501        if let Some(uri) = MentionUri::parse(uri).log_err() {
 502            uri.as_link().to_string()
 503        } else {
 504            uri.to_string()
 505        }
 506    }
 507
 508    fn image_md(_image: &acp::ImageContent) -> String {
 509        "`Image`".into()
 510    }
 511
 512    pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
 513        match self {
 514            ContentBlock::Empty => "",
 515            ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
 516            ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
 517        }
 518    }
 519
 520    pub fn markdown(&self) -> Option<&Entity<Markdown>> {
 521        match self {
 522            ContentBlock::Empty => None,
 523            ContentBlock::Markdown { markdown } => Some(markdown),
 524            ContentBlock::ResourceLink { .. } => None,
 525        }
 526    }
 527
 528    pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
 529        match self {
 530            ContentBlock::ResourceLink { resource_link } => Some(resource_link),
 531            _ => None,
 532        }
 533    }
 534}
 535
 536#[derive(Debug)]
 537pub enum ToolCallContent {
 538    ContentBlock(ContentBlock),
 539    Diff(Entity<Diff>),
 540    Terminal(Entity<Terminal>),
 541}
 542
 543impl ToolCallContent {
 544    pub fn from_acp(
 545        content: acp::ToolCallContent,
 546        language_registry: Arc<LanguageRegistry>,
 547        cx: &mut App,
 548    ) -> Self {
 549        match content {
 550            acp::ToolCallContent::Content { content } => {
 551                Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
 552            }
 553            acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
 554                Diff::finalized(
 555                    diff.path,
 556                    diff.old_text,
 557                    diff.new_text,
 558                    language_registry,
 559                    cx,
 560                )
 561            })),
 562        }
 563    }
 564
 565    pub fn update_from_acp(
 566        &mut self,
 567        new: acp::ToolCallContent,
 568        language_registry: Arc<LanguageRegistry>,
 569        cx: &mut App,
 570    ) {
 571        let needs_update = match (&self, &new) {
 572            (Self::Diff(old_diff), acp::ToolCallContent::Diff { diff: new_diff }) => {
 573                old_diff.read(cx).needs_update(
 574                    new_diff.old_text.as_deref().unwrap_or(""),
 575                    &new_diff.new_text,
 576                    cx,
 577                )
 578            }
 579            _ => true,
 580        };
 581
 582        if needs_update {
 583            *self = Self::from_acp(new, language_registry, cx);
 584        }
 585    }
 586
 587    pub fn to_markdown(&self, cx: &App) -> String {
 588        match self {
 589            Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
 590            Self::Diff(diff) => diff.read(cx).to_markdown(cx),
 591            Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
 592        }
 593    }
 594}
 595
 596#[derive(Debug, PartialEq)]
 597pub enum ToolCallUpdate {
 598    UpdateFields(acp::ToolCallUpdate),
 599    UpdateDiff(ToolCallUpdateDiff),
 600    UpdateTerminal(ToolCallUpdateTerminal),
 601}
 602
 603impl ToolCallUpdate {
 604    fn id(&self) -> &acp::ToolCallId {
 605        match self {
 606            Self::UpdateFields(update) => &update.id,
 607            Self::UpdateDiff(diff) => &diff.id,
 608            Self::UpdateTerminal(terminal) => &terminal.id,
 609        }
 610    }
 611}
 612
 613impl From<acp::ToolCallUpdate> for ToolCallUpdate {
 614    fn from(update: acp::ToolCallUpdate) -> Self {
 615        Self::UpdateFields(update)
 616    }
 617}
 618
 619impl From<ToolCallUpdateDiff> for ToolCallUpdate {
 620    fn from(diff: ToolCallUpdateDiff) -> Self {
 621        Self::UpdateDiff(diff)
 622    }
 623}
 624
 625#[derive(Debug, PartialEq)]
 626pub struct ToolCallUpdateDiff {
 627    pub id: acp::ToolCallId,
 628    pub diff: Entity<Diff>,
 629}
 630
 631impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
 632    fn from(terminal: ToolCallUpdateTerminal) -> Self {
 633        Self::UpdateTerminal(terminal)
 634    }
 635}
 636
 637#[derive(Debug, PartialEq)]
 638pub struct ToolCallUpdateTerminal {
 639    pub id: acp::ToolCallId,
 640    pub terminal: Entity<Terminal>,
 641}
 642
 643#[derive(Debug, Default)]
 644pub struct Plan {
 645    pub entries: Vec<PlanEntry>,
 646}
 647
 648#[derive(Debug)]
 649pub struct PlanStats<'a> {
 650    pub in_progress_entry: Option<&'a PlanEntry>,
 651    pub pending: u32,
 652    pub completed: u32,
 653}
 654
 655impl Plan {
 656    pub fn is_empty(&self) -> bool {
 657        self.entries.is_empty()
 658    }
 659
 660    pub fn stats(&self) -> PlanStats<'_> {
 661        let mut stats = PlanStats {
 662            in_progress_entry: None,
 663            pending: 0,
 664            completed: 0,
 665        };
 666
 667        for entry in &self.entries {
 668            match &entry.status {
 669                acp::PlanEntryStatus::Pending => {
 670                    stats.pending += 1;
 671                }
 672                acp::PlanEntryStatus::InProgress => {
 673                    stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
 674                }
 675                acp::PlanEntryStatus::Completed => {
 676                    stats.completed += 1;
 677                }
 678            }
 679        }
 680
 681        stats
 682    }
 683}
 684
 685#[derive(Debug)]
 686pub struct PlanEntry {
 687    pub content: Entity<Markdown>,
 688    pub priority: acp::PlanEntryPriority,
 689    pub status: acp::PlanEntryStatus,
 690}
 691
 692impl PlanEntry {
 693    pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
 694        Self {
 695            content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
 696            priority: entry.priority,
 697            status: entry.status,
 698        }
 699    }
 700}
 701
 702#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 703pub struct TokenUsage {
 704    pub max_tokens: u64,
 705    pub used_tokens: u64,
 706}
 707
 708impl TokenUsage {
 709    pub fn ratio(&self) -> TokenUsageRatio {
 710        #[cfg(debug_assertions)]
 711        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 712            .unwrap_or("0.8".to_string())
 713            .parse()
 714            .unwrap();
 715        #[cfg(not(debug_assertions))]
 716        let warning_threshold: f32 = 0.8;
 717
 718        // When the maximum is unknown because there is no selected model,
 719        // avoid showing the token limit warning.
 720        if self.max_tokens == 0 {
 721            TokenUsageRatio::Normal
 722        } else if self.used_tokens >= self.max_tokens {
 723            TokenUsageRatio::Exceeded
 724        } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
 725            TokenUsageRatio::Warning
 726        } else {
 727            TokenUsageRatio::Normal
 728        }
 729    }
 730}
 731
 732#[derive(Debug, Clone, PartialEq, Eq)]
 733pub enum TokenUsageRatio {
 734    Normal,
 735    Warning,
 736    Exceeded,
 737}
 738
 739#[derive(Debug, Clone)]
 740pub struct RetryStatus {
 741    pub last_error: SharedString,
 742    pub attempt: usize,
 743    pub max_attempts: usize,
 744    pub started_at: Instant,
 745    pub duration: Duration,
 746}
 747
 748pub struct AcpThread {
 749    title: SharedString,
 750    entries: Vec<AgentThreadEntry>,
 751    plan: Plan,
 752    project: Entity<Project>,
 753    action_log: Entity<ActionLog>,
 754    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
 755    send_task: Option<Task<()>>,
 756    connection: Rc<dyn AgentConnection>,
 757    session_id: acp::SessionId,
 758    token_usage: Option<TokenUsage>,
 759    prompt_capabilities: acp::PromptCapabilities,
 760    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
 761}
 762
 763#[derive(Debug)]
 764pub enum AcpThreadEvent {
 765    NewEntry,
 766    TitleUpdated,
 767    TokenUsageUpdated,
 768    EntryUpdated(usize),
 769    EntriesRemoved(Range<usize>),
 770    ToolAuthorizationRequired,
 771    Retry(RetryStatus),
 772    Stopped,
 773    Error,
 774    LoadError(LoadError),
 775    PromptCapabilitiesUpdated,
 776}
 777
 778impl EventEmitter<AcpThreadEvent> for AcpThread {}
 779
 780#[derive(PartialEq, Eq, Debug)]
 781pub enum ThreadStatus {
 782    Idle,
 783    WaitingForToolConfirmation,
 784    Generating,
 785}
 786
 787#[derive(Debug, Clone)]
 788pub enum LoadError {
 789    NotInstalled {
 790        error_message: SharedString,
 791        install_message: SharedString,
 792        install_command: String,
 793    },
 794    Unsupported {
 795        error_message: SharedString,
 796        upgrade_message: SharedString,
 797        upgrade_command: String,
 798    },
 799    Exited {
 800        status: ExitStatus,
 801    },
 802    Other(SharedString),
 803}
 804
 805impl Display for LoadError {
 806    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 807        match self {
 808            LoadError::NotInstalled { error_message, .. }
 809            | LoadError::Unsupported { error_message, .. } => {
 810                write!(f, "{error_message}")
 811            }
 812            LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
 813            LoadError::Other(msg) => write!(f, "{}", msg),
 814        }
 815    }
 816}
 817
 818impl Error for LoadError {}
 819
 820impl AcpThread {
 821    pub fn new(
 822        title: impl Into<SharedString>,
 823        connection: Rc<dyn AgentConnection>,
 824        project: Entity<Project>,
 825        action_log: Entity<ActionLog>,
 826        session_id: acp::SessionId,
 827        mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
 828        cx: &mut Context<Self>,
 829    ) -> Self {
 830        let prompt_capabilities = *prompt_capabilities_rx.borrow();
 831        let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
 832            loop {
 833                let caps = prompt_capabilities_rx.recv().await?;
 834                this.update(cx, |this, cx| {
 835                    this.prompt_capabilities = caps;
 836                    cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
 837                })?;
 838            }
 839        });
 840
 841        Self {
 842            action_log,
 843            shared_buffers: Default::default(),
 844            entries: Default::default(),
 845            plan: Default::default(),
 846            title: title.into(),
 847            project,
 848            send_task: None,
 849            connection,
 850            session_id,
 851            token_usage: None,
 852            prompt_capabilities,
 853            _observe_prompt_capabilities: task,
 854        }
 855    }
 856
 857    pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
 858        self.prompt_capabilities
 859    }
 860
 861    pub fn connection(&self) -> &Rc<dyn AgentConnection> {
 862        &self.connection
 863    }
 864
 865    pub fn action_log(&self) -> &Entity<ActionLog> {
 866        &self.action_log
 867    }
 868
 869    pub fn project(&self) -> &Entity<Project> {
 870        &self.project
 871    }
 872
 873    pub fn title(&self) -> SharedString {
 874        self.title.clone()
 875    }
 876
 877    pub fn entries(&self) -> &[AgentThreadEntry] {
 878        &self.entries
 879    }
 880
 881    pub fn session_id(&self) -> &acp::SessionId {
 882        &self.session_id
 883    }
 884
 885    pub fn status(&self) -> ThreadStatus {
 886        if self.send_task.is_some() {
 887            if self.waiting_for_tool_confirmation() {
 888                ThreadStatus::WaitingForToolConfirmation
 889            } else {
 890                ThreadStatus::Generating
 891            }
 892        } else {
 893            ThreadStatus::Idle
 894        }
 895    }
 896
 897    pub fn token_usage(&self) -> Option<&TokenUsage> {
 898        self.token_usage.as_ref()
 899    }
 900
 901    pub fn has_pending_edit_tool_calls(&self) -> bool {
 902        for entry in self.entries.iter().rev() {
 903            match entry {
 904                AgentThreadEntry::UserMessage(_) => return false,
 905                AgentThreadEntry::ToolCall(
 906                    call @ ToolCall {
 907                        status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
 908                        ..
 909                    },
 910                ) if call.diffs().next().is_some() => {
 911                    return true;
 912                }
 913                AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
 914            }
 915        }
 916
 917        false
 918    }
 919
 920    pub fn used_tools_since_last_user_message(&self) -> bool {
 921        for entry in self.entries.iter().rev() {
 922            match entry {
 923                AgentThreadEntry::UserMessage(..) => return false,
 924                AgentThreadEntry::AssistantMessage(..) => continue,
 925                AgentThreadEntry::ToolCall(..) => return true,
 926            }
 927        }
 928
 929        false
 930    }
 931
 932    pub fn handle_session_update(
 933        &mut self,
 934        update: acp::SessionUpdate,
 935        cx: &mut Context<Self>,
 936    ) -> Result<(), acp::Error> {
 937        match update {
 938            acp::SessionUpdate::UserMessageChunk { content } => {
 939                self.push_user_content_block(None, content, cx);
 940            }
 941            acp::SessionUpdate::AgentMessageChunk { content } => {
 942                self.push_assistant_content_block(content, false, cx);
 943            }
 944            acp::SessionUpdate::AgentThoughtChunk { content } => {
 945                self.push_assistant_content_block(content, true, cx);
 946            }
 947            acp::SessionUpdate::ToolCall(tool_call) => {
 948                self.upsert_tool_call(tool_call, cx)?;
 949            }
 950            acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
 951                self.update_tool_call(tool_call_update, cx)?;
 952            }
 953            acp::SessionUpdate::Plan(plan) => {
 954                self.update_plan(plan, cx);
 955            }
 956        }
 957        Ok(())
 958    }
 959
 960    pub fn push_user_content_block(
 961        &mut self,
 962        message_id: Option<UserMessageId>,
 963        chunk: acp::ContentBlock,
 964        cx: &mut Context<Self>,
 965    ) {
 966        let language_registry = self.project.read(cx).languages().clone();
 967        let entries_len = self.entries.len();
 968
 969        if let Some(last_entry) = self.entries.last_mut()
 970            && let AgentThreadEntry::UserMessage(UserMessage {
 971                id,
 972                content,
 973                chunks,
 974                ..
 975            }) = last_entry
 976        {
 977            *id = message_id.or(id.take());
 978            content.append(chunk.clone(), &language_registry, cx);
 979            chunks.push(chunk);
 980            let idx = entries_len - 1;
 981            cx.emit(AcpThreadEvent::EntryUpdated(idx));
 982        } else {
 983            let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
 984            self.push_entry(
 985                AgentThreadEntry::UserMessage(UserMessage {
 986                    id: message_id,
 987                    content,
 988                    chunks: vec![chunk],
 989                    checkpoint: None,
 990                }),
 991                cx,
 992            );
 993        }
 994    }
 995
 996    pub fn push_assistant_content_block(
 997        &mut self,
 998        chunk: acp::ContentBlock,
 999        is_thought: bool,
1000        cx: &mut Context<Self>,
1001    ) {
1002        let language_registry = self.project.read(cx).languages().clone();
1003        let entries_len = self.entries.len();
1004        if let Some(last_entry) = self.entries.last_mut()
1005            && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
1006        {
1007            let idx = entries_len - 1;
1008            cx.emit(AcpThreadEvent::EntryUpdated(idx));
1009            match (chunks.last_mut(), is_thought) {
1010                (Some(AssistantMessageChunk::Message { block }), false)
1011                | (Some(AssistantMessageChunk::Thought { block }), true) => {
1012                    block.append(chunk, &language_registry, cx)
1013                }
1014                _ => {
1015                    let block = ContentBlock::new(chunk, &language_registry, cx);
1016                    if is_thought {
1017                        chunks.push(AssistantMessageChunk::Thought { block })
1018                    } else {
1019                        chunks.push(AssistantMessageChunk::Message { block })
1020                    }
1021                }
1022            }
1023        } else {
1024            let block = ContentBlock::new(chunk, &language_registry, cx);
1025            let chunk = if is_thought {
1026                AssistantMessageChunk::Thought { block }
1027            } else {
1028                AssistantMessageChunk::Message { block }
1029            };
1030
1031            self.push_entry(
1032                AgentThreadEntry::AssistantMessage(AssistantMessage {
1033                    chunks: vec![chunk],
1034                }),
1035                cx,
1036            );
1037        }
1038    }
1039
1040    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1041        self.entries.push(entry);
1042        cx.emit(AcpThreadEvent::NewEntry);
1043    }
1044
1045    pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1046        self.connection.set_title(&self.session_id, cx).is_some()
1047    }
1048
1049    pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1050        if title != self.title {
1051            self.title = title.clone();
1052            cx.emit(AcpThreadEvent::TitleUpdated);
1053            if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1054                return set_title.run(title, cx);
1055            }
1056        }
1057        Task::ready(Ok(()))
1058    }
1059
1060    pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1061        self.token_usage = usage;
1062        cx.emit(AcpThreadEvent::TokenUsageUpdated);
1063    }
1064
1065    pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1066        cx.emit(AcpThreadEvent::Retry(status));
1067    }
1068
1069    pub fn update_tool_call(
1070        &mut self,
1071        update: impl Into<ToolCallUpdate>,
1072        cx: &mut Context<Self>,
1073    ) -> Result<()> {
1074        let update = update.into();
1075        let languages = self.project.read(cx).languages().clone();
1076
1077        let (ix, current_call) = self
1078            .tool_call_mut(update.id())
1079            .context("Tool call not found")?;
1080        match update {
1081            ToolCallUpdate::UpdateFields(update) => {
1082                let location_updated = update.fields.locations.is_some();
1083                current_call.update_fields(update.fields, languages, cx);
1084                if location_updated {
1085                    self.resolve_locations(update.id, cx);
1086                }
1087            }
1088            ToolCallUpdate::UpdateDiff(update) => {
1089                current_call.content.clear();
1090                current_call
1091                    .content
1092                    .push(ToolCallContent::Diff(update.diff));
1093            }
1094            ToolCallUpdate::UpdateTerminal(update) => {
1095                current_call.content.clear();
1096                current_call
1097                    .content
1098                    .push(ToolCallContent::Terminal(update.terminal));
1099            }
1100        }
1101
1102        cx.emit(AcpThreadEvent::EntryUpdated(ix));
1103
1104        Ok(())
1105    }
1106
1107    /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1108    pub fn upsert_tool_call(
1109        &mut self,
1110        tool_call: acp::ToolCall,
1111        cx: &mut Context<Self>,
1112    ) -> Result<(), acp::Error> {
1113        let status = tool_call.status.into();
1114        self.upsert_tool_call_inner(tool_call.into(), status, cx)
1115    }
1116
1117    /// Fails if id does not match an existing entry.
1118    pub fn upsert_tool_call_inner(
1119        &mut self,
1120        tool_call_update: acp::ToolCallUpdate,
1121        status: ToolCallStatus,
1122        cx: &mut Context<Self>,
1123    ) -> Result<(), acp::Error> {
1124        let language_registry = self.project.read(cx).languages().clone();
1125        let id = tool_call_update.id.clone();
1126
1127        if let Some((ix, current_call)) = self.tool_call_mut(&id) {
1128            current_call.update_fields(tool_call_update.fields, language_registry, cx);
1129            current_call.status = status;
1130
1131            cx.emit(AcpThreadEvent::EntryUpdated(ix));
1132        } else {
1133            let call =
1134                ToolCall::from_acp(tool_call_update.try_into()?, status, language_registry, cx);
1135            self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1136        };
1137
1138        self.resolve_locations(id, cx);
1139        Ok(())
1140    }
1141
1142    fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1143        // The tool call we are looking for is typically the last one, or very close to the end.
1144        // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1145        self.entries
1146            .iter_mut()
1147            .enumerate()
1148            .rev()
1149            .find_map(|(index, tool_call)| {
1150                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1151                    && &tool_call.id == id
1152                {
1153                    Some((index, tool_call))
1154                } else {
1155                    None
1156                }
1157            })
1158    }
1159
1160    pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1161        self.entries
1162            .iter()
1163            .enumerate()
1164            .rev()
1165            .find_map(|(index, tool_call)| {
1166                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1167                    && &tool_call.id == id
1168                {
1169                    Some((index, tool_call))
1170                } else {
1171                    None
1172                }
1173            })
1174    }
1175
1176    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1177        let project = self.project.clone();
1178        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1179            return;
1180        };
1181        let task = tool_call.resolve_locations(project, cx);
1182        cx.spawn(async move |this, cx| {
1183            let resolved_locations = task.await;
1184            this.update(cx, |this, cx| {
1185                let project = this.project.clone();
1186                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1187                    return;
1188                };
1189                if let Some(Some(location)) = resolved_locations.last() {
1190                    project.update(cx, |project, cx| {
1191                        if let Some(agent_location) = project.agent_location() {
1192                            let should_ignore = agent_location.buffer == location.buffer
1193                                && location
1194                                    .buffer
1195                                    .update(cx, |buffer, _| {
1196                                        let snapshot = buffer.snapshot();
1197                                        let old_position =
1198                                            agent_location.position.to_point(&snapshot);
1199                                        let new_position = location.position.to_point(&snapshot);
1200                                        // ignore this so that when we get updates from the edit tool
1201                                        // the position doesn't reset to the startof line
1202                                        old_position.row == new_position.row
1203                                            && old_position.column > new_position.column
1204                                    })
1205                                    .ok()
1206                                    .unwrap_or_default();
1207                            if !should_ignore {
1208                                project.set_agent_location(Some(location.clone()), cx);
1209                            }
1210                        }
1211                    });
1212                }
1213                if tool_call.resolved_locations != resolved_locations {
1214                    tool_call.resolved_locations = resolved_locations;
1215                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
1216                }
1217            })
1218        })
1219        .detach();
1220    }
1221
1222    pub fn request_tool_call_authorization(
1223        &mut self,
1224        tool_call: acp::ToolCallUpdate,
1225        options: Vec<acp::PermissionOption>,
1226        cx: &mut Context<Self>,
1227    ) -> Result<oneshot::Receiver<acp::PermissionOptionId>, acp::Error> {
1228        let (tx, rx) = oneshot::channel();
1229
1230        let status = ToolCallStatus::WaitingForConfirmation {
1231            options,
1232            respond_tx: tx,
1233        };
1234
1235        self.upsert_tool_call_inner(tool_call, status, cx)?;
1236        cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1237        Ok(rx)
1238    }
1239
1240    pub fn authorize_tool_call(
1241        &mut self,
1242        id: acp::ToolCallId,
1243        option_id: acp::PermissionOptionId,
1244        option_kind: acp::PermissionOptionKind,
1245        cx: &mut Context<Self>,
1246    ) {
1247        let Some((ix, call)) = self.tool_call_mut(&id) else {
1248            return;
1249        };
1250
1251        let new_status = match option_kind {
1252            acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1253                ToolCallStatus::Rejected
1254            }
1255            acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1256                ToolCallStatus::InProgress
1257            }
1258        };
1259
1260        let curr_status = mem::replace(&mut call.status, new_status);
1261
1262        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1263            respond_tx.send(option_id).log_err();
1264        } else if cfg!(debug_assertions) {
1265            panic!("tried to authorize an already authorized tool call");
1266        }
1267
1268        cx.emit(AcpThreadEvent::EntryUpdated(ix));
1269    }
1270
1271    /// Returns true if the last turn is awaiting tool authorization
1272    pub fn waiting_for_tool_confirmation(&self) -> bool {
1273        for entry in self.entries.iter().rev() {
1274            match &entry {
1275                AgentThreadEntry::ToolCall(call) => match call.status {
1276                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
1277                    ToolCallStatus::Pending
1278                    | ToolCallStatus::InProgress
1279                    | ToolCallStatus::Completed
1280                    | ToolCallStatus::Failed
1281                    | ToolCallStatus::Rejected
1282                    | ToolCallStatus::Canceled => continue,
1283                },
1284                AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1285                    // Reached the beginning of the turn
1286                    return false;
1287                }
1288            }
1289        }
1290        false
1291    }
1292
1293    pub fn plan(&self) -> &Plan {
1294        &self.plan
1295    }
1296
1297    pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1298        let new_entries_len = request.entries.len();
1299        let mut new_entries = request.entries.into_iter();
1300
1301        // Reuse existing markdown to prevent flickering
1302        for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1303            let PlanEntry {
1304                content,
1305                priority,
1306                status,
1307            } = old;
1308            content.update(cx, |old, cx| {
1309                old.replace(new.content, cx);
1310            });
1311            *priority = new.priority;
1312            *status = new.status;
1313        }
1314        for new in new_entries {
1315            self.plan.entries.push(PlanEntry::from_acp(new, cx))
1316        }
1317        self.plan.entries.truncate(new_entries_len);
1318
1319        cx.notify();
1320    }
1321
1322    fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1323        self.plan
1324            .entries
1325            .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1326        cx.notify();
1327    }
1328
1329    #[cfg(any(test, feature = "test-support"))]
1330    pub fn send_raw(
1331        &mut self,
1332        message: &str,
1333        cx: &mut Context<Self>,
1334    ) -> BoxFuture<'static, Result<()>> {
1335        self.send(
1336            vec![acp::ContentBlock::Text(acp::TextContent {
1337                text: message.to_string(),
1338                annotations: None,
1339            })],
1340            cx,
1341        )
1342    }
1343
    /// Sends `message` to the agent as a new user prompt and runs a turn for it.
    ///
    /// The user entry is pushed from inside the turn, i.e. only after any in-flight turn has
    /// been canceled and awaited by `run_turn`. A git checkpoint is then taken and attached to
    /// that message so the turn can later be rewound; failing to take it is logged and the
    /// message simply has no checkpoint. Finally the prompt is forwarded to the connection.
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            cx,
        );
        let request = acp::PromptRequest {
            prompt: message.clone(),
            session_id: self.session_id.clone(),
        };
        let git_store = self.project.read(cx).git_store().clone();

        // Messages only get an id when the connection can truncate its history, since the id
        // is what `rewind` passes to the connection's truncate operation.
        let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
            Some(UserMessageId::new())
        } else {
            None
        };

        self.run_turn(cx, async move |this, cx| {
            this.update(cx, |this, cx| {
                this.push_entry(
                    AgentThreadEntry::UserMessage(UserMessage {
                        id: message_id.clone(),
                        content: block,
                        chunks: message,
                        checkpoint: None,
                    }),
                    cx,
                );
            })
            .ok();

            // Snapshot the repository state before the agent starts changing files.
            let old_checkpoint = git_store
                .update(cx, |git, cx| git.checkpoint(cx))?
                .await
                .context("failed to get old checkpoint")
                .log_err();
            this.update(cx, |this, cx| {
                // Hidden until `update_last_checkpoint` detects that the turn changed something.
                if let Some((_ix, message)) = this.last_user_message() {
                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
                        git_checkpoint,
                        show: false,
                    });
                }
                this.connection.prompt(message_id, request, cx)
            })?
            .await
        })
    }
1397
1398    pub fn can_resume(&self, cx: &App) -> bool {
1399        self.connection.resume(&self.session_id, cx).is_some()
1400    }
1401
1402    pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1403        self.run_turn(cx, async move |this, cx| {
1404            this.update(cx, |this, cx| {
1405                this.connection
1406                    .resume(&this.session_id, cx)
1407                    .map(|resume| resume.run(cx))
1408            })?
1409            .context("resuming a session is not supported")?
1410            .await
1411        })
1412    }
1413
    /// Runs one agent turn driven by `f` and returns a future for its outcome.
    ///
    /// Completed plan entries are cleared and any in-flight turn is canceled; `f` only starts
    /// once that cancellation has finished. `f` runs inside `send_task`, so a later `cancel`
    /// can stop it. When `f` finishes, the last user message's checkpoint visibility is
    /// refreshed and the agent location is cleared. Then either `Error` is emitted (when `f`
    /// returned an error, which is propagated) or `Stopped` is emitted. A `Refusal` stop reason
    /// additionally removes the last user message and every entry after it.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<()>> {
        self.clear_completed_plan_entries(cx);

        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            cancel_task.await;
            tx.send(f(this, cx).await).ok();
        }));

        cx.spawn(async move |this, cx| {
            // `Err(_)` here means `tx` was dropped before `f` produced a result; that case is
            // handled by the catch-all arm below and reported as a normal stop.
            let response = rx.await;

            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));
                match response {
                    Ok(Err(e)) => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        Err(e)
                    }
                    result => {
                        let canceled = matches!(
                            result,
                            Ok(Ok(acp::PromptResponse {
                                stop_reason: acp::StopReason::Cancelled
                            }))
                        );

                        // We only take the task if the current prompt wasn't canceled.
                        //
                        // This prompt may have been canceled because another one was sent
                        // while it was still generating. In these cases, dropping `send_task`
                        // would cause the next generation to be canceled.
                        if !canceled {
                            this.send_task.take();
                        }

                        // Truncate entries if the last prompt was refused.
                        if let Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                        })) = result
                            && let Some((ix, _)) = this.last_user_message()
                        {
                            let range = ix..this.entries.len();
                            this.entries.truncate(ix);
                            cx.emit(AcpThreadEvent::EntriesRemoved(range));
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(())
                    }
                }
            })?
        })
        .boxed()
    }
1480
1481    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1482        let Some(send_task) = self.send_task.take() else {
1483            return Task::ready(());
1484        };
1485
1486        for entry in self.entries.iter_mut() {
1487            if let AgentThreadEntry::ToolCall(call) = entry {
1488                let cancel = matches!(
1489                    call.status,
1490                    ToolCallStatus::Pending
1491                        | ToolCallStatus::WaitingForConfirmation { .. }
1492                        | ToolCallStatus::InProgress
1493                );
1494
1495                if cancel {
1496                    call.status = ToolCallStatus::Canceled;
1497                }
1498            }
1499        }
1500
1501        self.connection.cancel(&self.session_id, cx);
1502
1503        // Wait for the send task to complete
1504        cx.foreground_executor().spawn(send_task)
1505    }
1506
1507    /// Rewinds this thread to before the entry at `index`, removing it and all
1508    /// subsequent entries while reverting any changes made from that point.
1509    pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1510        let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
1511            return Task::ready(Err(anyhow!("not supported")));
1512        };
1513        let Some(message) = self.user_message(&id) else {
1514            return Task::ready(Err(anyhow!("message not found")));
1515        };
1516
1517        let checkpoint = message
1518            .checkpoint
1519            .as_ref()
1520            .map(|c| c.git_checkpoint.clone());
1521
1522        let git_store = self.project.read(cx).git_store().clone();
1523        cx.spawn(async move |this, cx| {
1524            if let Some(checkpoint) = checkpoint {
1525                git_store
1526                    .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1527                    .await?;
1528            }
1529
1530            cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
1531            this.update(cx, |this, cx| {
1532                if let Some((ix, _)) = this.user_message_mut(&id) {
1533                    let range = ix..this.entries.len();
1534                    this.entries.truncate(ix);
1535                    cx.emit(AcpThreadEvent::EntriesRemoved(range));
1536                }
1537            })
1538        })
1539    }
1540
1541    fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1542        let git_store = self.project.read(cx).git_store().clone();
1543
1544        let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1545            if let Some(checkpoint) = message.checkpoint.as_ref() {
1546                checkpoint.git_checkpoint.clone()
1547            } else {
1548                return Task::ready(Ok(()));
1549            }
1550        } else {
1551            return Task::ready(Ok(()));
1552        };
1553
1554        let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1555        cx.spawn(async move |this, cx| {
1556            let new_checkpoint = new_checkpoint
1557                .await
1558                .context("failed to get new checkpoint")
1559                .log_err();
1560            if let Some(new_checkpoint) = new_checkpoint {
1561                let equal = git_store
1562                    .update(cx, |git, cx| {
1563                        git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1564                    })?
1565                    .await
1566                    .unwrap_or(true);
1567                this.update(cx, |this, cx| {
1568                    let (ix, message) = this.last_user_message().context("no user message")?;
1569                    let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1570                    checkpoint.show = !equal;
1571                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
1572                    anyhow::Ok(())
1573                })??;
1574            }
1575
1576            Ok(())
1577        })
1578    }
1579
1580    fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1581        self.entries
1582            .iter_mut()
1583            .enumerate()
1584            .rev()
1585            .find_map(|(ix, entry)| {
1586                if let AgentThreadEntry::UserMessage(message) = entry {
1587                    Some((ix, message))
1588                } else {
1589                    None
1590                }
1591            })
1592    }
1593
1594    fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
1595        self.entries.iter().find_map(|entry| {
1596            if let AgentThreadEntry::UserMessage(message) = entry {
1597                if message.id.as_ref() == Some(id) {
1598                    Some(message)
1599                } else {
1600                    None
1601                }
1602            } else {
1603                None
1604            }
1605        })
1606    }
1607
1608    fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1609        self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1610            if let AgentThreadEntry::UserMessage(message) = entry {
1611                if message.id.as_ref() == Some(id) {
1612                    Some((ix, message))
1613                } else {
1614                    None
1615                }
1616            } else {
1617                None
1618            }
1619        })
1620    }
1621
1622    pub fn read_text_file(
1623        &self,
1624        path: PathBuf,
1625        line: Option<u32>,
1626        limit: Option<u32>,
1627        reuse_shared_snapshot: bool,
1628        cx: &mut Context<Self>,
1629    ) -> Task<Result<String>> {
1630        let project = self.project.clone();
1631        let action_log = self.action_log.clone();
1632        cx.spawn(async move |this, cx| {
1633            let load = project.update(cx, |project, cx| {
1634                let path = project
1635                    .project_path_for_absolute_path(&path, cx)
1636                    .context("invalid path")?;
1637                anyhow::Ok(project.open_buffer(path, cx))
1638            });
1639            let buffer = load??.await?;
1640
1641            let snapshot = if reuse_shared_snapshot {
1642                this.read_with(cx, |this, _| {
1643                    this.shared_buffers.get(&buffer.clone()).cloned()
1644                })
1645                .log_err()
1646                .flatten()
1647            } else {
1648                None
1649            };
1650
1651            let snapshot = if let Some(snapshot) = snapshot {
1652                snapshot
1653            } else {
1654                action_log.update(cx, |action_log, cx| {
1655                    action_log.buffer_read(buffer.clone(), cx);
1656                })?;
1657                project.update(cx, |project, cx| {
1658                    let position = buffer
1659                        .read(cx)
1660                        .snapshot()
1661                        .anchor_before(Point::new(line.unwrap_or_default(), 0));
1662                    project.set_agent_location(
1663                        Some(AgentLocation {
1664                            buffer: buffer.downgrade(),
1665                            position,
1666                        }),
1667                        cx,
1668                    );
1669                })?;
1670
1671                buffer.update(cx, |buffer, _| buffer.snapshot())?
1672            };
1673
1674            this.update(cx, |this, _| {
1675                let text = snapshot.text();
1676                this.shared_buffers.insert(buffer.clone(), snapshot);
1677                if line.is_none() && limit.is_none() {
1678                    return Ok(text);
1679                }
1680                let limit = limit.unwrap_or(u32::MAX) as usize;
1681                let Some(line) = line else {
1682                    return Ok(text.lines().take(limit).collect::<String>());
1683                };
1684
1685                let count = text.lines().count();
1686                if count < line as usize {
1687                    anyhow::bail!("There are only {} lines", count);
1688                }
1689                Ok(text
1690                    .lines()
1691                    .skip(line as usize + 1)
1692                    .take(limit)
1693                    .collect::<String>())
1694            })?
1695        })
1696    }
1697
    /// Writes `content` to the file at `path` on behalf of the agent and saves it.
    ///
    /// The new text is diffed (on the background executor) against the snapshot previously
    /// shared with the agent for this buffer, or the buffer's current snapshot if none was
    /// shared. The resulting minimal edits are applied through anchors. The agent location moves
    /// to the end of the last edit (or the buffer start when nothing changed). The action log
    /// sees a read before the edit and an edit afterwards. If the buffer's language settings
    /// have format-on-save enabled, the buffer is formatted (errors logged, not propagated) and
    /// the action log is notified again before the buffer is saved.
    ///
    /// Fails if `path` is not inside the project, the buffer cannot be opened, or saving fails.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Diff against what the agent last saw, so its edits land relative to that text.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            // Anchors keep each edit attached to the right text even if the
                            // live buffer has moved on from `snapshot`.
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::MIN),
                    }),
                    cx,
                );
            })?;

            let format_on_save = cx.update(|cx| {
                // Record a read first so the action log has a baseline for the edit below.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            })?;

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                })?;
                // Formatting is best-effort: a failure is logged but the write still succeeds.
                format_task.await.log_err();

                // Formatting changed the buffer again, so report it as another edit.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                })?;
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1794
1795    pub fn to_markdown(&self, cx: &App) -> String {
1796        self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1797    }
1798
1799    pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
1800        cx.emit(AcpThreadEvent::LoadError(error));
1801    }
1802}
1803
1804fn markdown_for_raw_output(
1805    raw_output: &serde_json::Value,
1806    language_registry: &Arc<LanguageRegistry>,
1807    cx: &mut App,
1808) -> Option<Entity<Markdown>> {
1809    match raw_output {
1810        serde_json::Value::Null => None,
1811        serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1812            Markdown::new(
1813                value.to_string().into(),
1814                Some(language_registry.clone()),
1815                None,
1816                cx,
1817            )
1818        })),
1819        serde_json::Value::Number(value) => Some(cx.new(|cx| {
1820            Markdown::new(
1821                value.to_string().into(),
1822                Some(language_registry.clone()),
1823                None,
1824                cx,
1825            )
1826        })),
1827        serde_json::Value::String(value) => Some(cx.new(|cx| {
1828            Markdown::new(
1829                value.clone().into(),
1830                Some(language_registry.clone()),
1831                None,
1832                cx,
1833            )
1834        })),
1835        value => Some(cx.new(|cx| {
1836            Markdown::new(
1837                format!("```json\n{}\n```", value).into(),
1838                Some(language_registry.clone()),
1839                None,
1840                cx,
1841            )
1842        })),
1843    }
1844}
1845
1846#[cfg(test)]
1847mod tests {
1848    use super::*;
1849    use anyhow::anyhow;
1850    use futures::{channel::mpsc, future::LocalBoxFuture, select};
1851    use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
1852    use indoc::indoc;
1853    use project::{FakeFs, Fs};
1854    use rand::Rng as _;
1855    use serde_json::json;
1856    use settings::SettingsStore;
1857    use smol::stream::StreamExt as _;
1858    use std::{
1859        any::Any,
1860        cell::RefCell,
1861        path::Path,
1862        rc::Rc,
1863        sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
1864        time::Duration,
1865    };
1866    use util::path;
1867
1868    fn init_test(cx: &mut TestAppContext) {
1869        env_logger::try_init().ok();
1870        cx.update(|cx| {
1871            let settings_store = SettingsStore::test(cx);
1872            cx.set_global(settings_store);
1873            Project::init_settings(cx);
1874            language::init(cx);
1875        });
1876    }
1877
1878    #[gpui::test]
1879    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1880        init_test(cx);
1881
1882        let fs = FakeFs::new(cx.executor());
1883        let project = Project::test(fs, [], cx).await;
1884        let connection = Rc::new(FakeAgentConnection::new());
1885        let thread = cx
1886            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
1887            .await
1888            .unwrap();
1889
1890        // Test creating a new user message
1891        thread.update(cx, |thread, cx| {
1892            thread.push_user_content_block(
1893                None,
1894                acp::ContentBlock::Text(acp::TextContent {
1895                    annotations: None,
1896                    text: "Hello, ".to_string(),
1897                }),
1898                cx,
1899            );
1900        });
1901
1902        thread.update(cx, |thread, cx| {
1903            assert_eq!(thread.entries.len(), 1);
1904            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1905                assert_eq!(user_msg.id, None);
1906                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1907            } else {
1908                panic!("Expected UserMessage");
1909            }
1910        });
1911
1912        // Test appending to existing user message
1913        let message_1_id = UserMessageId::new();
1914        thread.update(cx, |thread, cx| {
1915            thread.push_user_content_block(
1916                Some(message_1_id.clone()),
1917                acp::ContentBlock::Text(acp::TextContent {
1918                    annotations: None,
1919                    text: "world!".to_string(),
1920                }),
1921                cx,
1922            );
1923        });
1924
1925        thread.update(cx, |thread, cx| {
1926            assert_eq!(thread.entries.len(), 1);
1927            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1928                assert_eq!(user_msg.id, Some(message_1_id));
1929                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1930            } else {
1931                panic!("Expected UserMessage");
1932            }
1933        });
1934
1935        // Test creating new user message after assistant message
1936        thread.update(cx, |thread, cx| {
1937            thread.push_assistant_content_block(
1938                acp::ContentBlock::Text(acp::TextContent {
1939                    annotations: None,
1940                    text: "Assistant response".to_string(),
1941                }),
1942                false,
1943                cx,
1944            );
1945        });
1946
1947        let message_2_id = UserMessageId::new();
1948        thread.update(cx, |thread, cx| {
1949            thread.push_user_content_block(
1950                Some(message_2_id.clone()),
1951                acp::ContentBlock::Text(acp::TextContent {
1952                    annotations: None,
1953                    text: "New user message".to_string(),
1954                }),
1955                cx,
1956            );
1957        });
1958
1959        thread.update(cx, |thread, cx| {
1960            assert_eq!(thread.entries.len(), 3);
1961            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1962                assert_eq!(user_msg.id, Some(message_2_id));
1963                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1964            } else {
1965                panic!("Expected UserMessage at index 2");
1966            }
1967        });
1968    }
1969
1970    #[gpui::test]
1971    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
1972        init_test(cx);
1973
1974        let fs = FakeFs::new(cx.executor());
1975        let project = Project::test(fs, [], cx).await;
1976        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1977            |_, thread, mut cx| {
1978                async move {
1979                    thread.update(&mut cx, |thread, cx| {
1980                        thread
1981                            .handle_session_update(
1982                                acp::SessionUpdate::AgentThoughtChunk {
1983                                    content: "Thinking ".into(),
1984                                },
1985                                cx,
1986                            )
1987                            .unwrap();
1988                        thread
1989                            .handle_session_update(
1990                                acp::SessionUpdate::AgentThoughtChunk {
1991                                    content: "hard!".into(),
1992                                },
1993                                cx,
1994                            )
1995                            .unwrap();
1996                    })?;
1997                    Ok(acp::PromptResponse {
1998                        stop_reason: acp::StopReason::EndTurn,
1999                    })
2000                }
2001                .boxed_local()
2002            },
2003        ));
2004
2005        let thread = cx
2006            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2007            .await
2008            .unwrap();
2009
2010        thread
2011            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
2012            .await
2013            .unwrap();
2014
2015        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
2016        assert_eq!(
2017            output,
2018            indoc! {r#"
2019            ## User
2020
2021            Hello from Zed!
2022
2023            ## Assistant
2024
2025            <thinking>
2026            Thinking hard!
2027            </thinking>
2028
2029            "#}
2030        );
2031    }
2032
2033    #[gpui::test]
2034    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
2035        init_test(cx);
2036
2037        let fs = FakeFs::new(cx.executor());
2038        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
2039            .await;
2040        let project = Project::test(fs.clone(), [], cx).await;
2041        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
2042        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
2043        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2044            move |_, thread, mut cx| {
2045                let read_file_tx = read_file_tx.clone();
2046                async move {
2047                    let content = thread
2048                        .update(&mut cx, |thread, cx| {
2049                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2050                        })
2051                        .unwrap()
2052                        .await
2053                        .unwrap();
2054                    assert_eq!(content, "one\ntwo\nthree\n");
2055                    read_file_tx.take().unwrap().send(()).unwrap();
2056                    thread
2057                        .update(&mut cx, |thread, cx| {
2058                            thread.write_text_file(
2059                                path!("/tmp/foo").into(),
2060                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
2061                                cx,
2062                            )
2063                        })
2064                        .unwrap()
2065                        .await
2066                        .unwrap();
2067                    Ok(acp::PromptResponse {
2068                        stop_reason: acp::StopReason::EndTurn,
2069                    })
2070                }
2071                .boxed_local()
2072            },
2073        ));
2074
2075        let (worktree, pathbuf) = project
2076            .update(cx, |project, cx| {
2077                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2078            })
2079            .await
2080            .unwrap();
2081        let buffer = project
2082            .update(cx, |project, cx| {
2083                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
2084            })
2085            .await
2086            .unwrap();
2087
2088        let thread = cx
2089            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2090            .await
2091            .unwrap();
2092
2093        let request = thread.update(cx, |thread, cx| {
2094            thread.send_raw("Extend the count in /tmp/foo", cx)
2095        });
2096        read_file_rx.await.ok();
2097        buffer.update(cx, |buffer, cx| {
2098            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
2099        });
2100        cx.run_until_parked();
2101        assert_eq!(
2102            buffer.read_with(cx, |buffer, _| buffer.text()),
2103            "zero\none\ntwo\nthree\nfour\nfive\n"
2104        );
2105        assert_eq!(
2106            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
2107            "zero\none\ntwo\nthree\nfour\nfive\n"
2108        );
2109        request.await.unwrap();
2110    }
2111
2112    #[gpui::test]
2113    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
2114        init_test(cx);
2115
2116        let fs = FakeFs::new(cx.executor());
2117        let project = Project::test(fs, [], cx).await;
2118        let id = acp::ToolCallId("test".into());
2119
2120        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2121            let id = id.clone();
2122            move |_, thread, mut cx| {
2123                let id = id.clone();
2124                async move {
2125                    thread
2126                        .update(&mut cx, |thread, cx| {
2127                            thread.handle_session_update(
2128                                acp::SessionUpdate::ToolCall(acp::ToolCall {
2129                                    id: id.clone(),
2130                                    title: "Label".into(),
2131                                    kind: acp::ToolKind::Fetch,
2132                                    status: acp::ToolCallStatus::InProgress,
2133                                    content: vec![],
2134                                    locations: vec![],
2135                                    raw_input: None,
2136                                    raw_output: None,
2137                                }),
2138                                cx,
2139                            )
2140                        })
2141                        .unwrap()
2142                        .unwrap();
2143                    Ok(acp::PromptResponse {
2144                        stop_reason: acp::StopReason::EndTurn,
2145                    })
2146                }
2147                .boxed_local()
2148            }
2149        }));
2150
2151        let thread = cx
2152            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2153            .await
2154            .unwrap();
2155
2156        let request = thread.update(cx, |thread, cx| {
2157            thread.send_raw("Fetch https://example.com", cx)
2158        });
2159
2160        run_until_first_tool_call(&thread, cx).await;
2161
2162        thread.read_with(cx, |thread, _| {
2163            assert!(matches!(
2164                thread.entries[1],
2165                AgentThreadEntry::ToolCall(ToolCall {
2166                    status: ToolCallStatus::InProgress,
2167                    ..
2168                })
2169            ));
2170        });
2171
2172        thread.update(cx, |thread, cx| thread.cancel(cx)).await;
2173
2174        thread.read_with(cx, |thread, _| {
2175            assert!(matches!(
2176                &thread.entries[1],
2177                AgentThreadEntry::ToolCall(ToolCall {
2178                    status: ToolCallStatus::Canceled,
2179                    ..
2180                })
2181            ));
2182        });
2183
2184        thread
2185            .update(cx, |thread, cx| {
2186                thread.handle_session_update(
2187                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
2188                        id,
2189                        fields: acp::ToolCallUpdateFields {
2190                            status: Some(acp::ToolCallStatus::Completed),
2191                            ..Default::default()
2192                        },
2193                    }),
2194                    cx,
2195                )
2196            })
2197            .unwrap();
2198
2199        request.await.unwrap();
2200
2201        thread.read_with(cx, |thread, _| {
2202            assert!(matches!(
2203                thread.entries[1],
2204                AgentThreadEntry::ToolCall(ToolCall {
2205                    status: ToolCallStatus::Completed,
2206                    ..
2207                })
2208            ));
2209        });
2210    }
2211
2212    #[gpui::test]
2213    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
2214        init_test(cx);
2215        let fs = FakeFs::new(cx.background_executor.clone());
2216        fs.insert_tree(path!("/test"), json!({})).await;
2217        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
2218
2219        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2220            move |_, thread, mut cx| {
2221                async move {
2222                    thread
2223                        .update(&mut cx, |thread, cx| {
2224                            thread.handle_session_update(
2225                                acp::SessionUpdate::ToolCall(acp::ToolCall {
2226                                    id: acp::ToolCallId("test".into()),
2227                                    title: "Label".into(),
2228                                    kind: acp::ToolKind::Edit,
2229                                    status: acp::ToolCallStatus::Completed,
2230                                    content: vec![acp::ToolCallContent::Diff {
2231                                        diff: acp::Diff {
2232                                            path: "/test/test.txt".into(),
2233                                            old_text: None,
2234                                            new_text: "foo".into(),
2235                                        },
2236                                    }],
2237                                    locations: vec![],
2238                                    raw_input: None,
2239                                    raw_output: None,
2240                                }),
2241                                cx,
2242                            )
2243                        })
2244                        .unwrap()
2245                        .unwrap();
2246                    Ok(acp::PromptResponse {
2247                        stop_reason: acp::StopReason::EndTurn,
2248                    })
2249                }
2250                .boxed_local()
2251            }
2252        }));
2253
2254        let thread = cx
2255            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2256            .await
2257            .unwrap();
2258
2259        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
2260            .await
2261            .unwrap();
2262
2263        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
2264    }
2265
2266    #[gpui::test(iterations = 10)]
2267    async fn test_checkpoints(cx: &mut TestAppContext) {
2268        init_test(cx);
2269        let fs = FakeFs::new(cx.background_executor.clone());
2270        fs.insert_tree(
2271            path!("/test"),
2272            json!({
2273                ".git": {}
2274            }),
2275        )
2276        .await;
2277        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
2278
2279        let simulate_changes = Arc::new(AtomicBool::new(true));
2280        let next_filename = Arc::new(AtomicUsize::new(0));
2281        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2282            let simulate_changes = simulate_changes.clone();
2283            let next_filename = next_filename.clone();
2284            let fs = fs.clone();
2285            move |request, thread, mut cx| {
2286                let fs = fs.clone();
2287                let simulate_changes = simulate_changes.clone();
2288                let next_filename = next_filename.clone();
2289                async move {
2290                    if simulate_changes.load(SeqCst) {
2291                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
2292                        fs.write(Path::new(&filename), b"").await?;
2293                    }
2294
2295                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2296                        panic!("expected text content block");
2297                    };
2298                    thread.update(&mut cx, |thread, cx| {
2299                        thread
2300                            .handle_session_update(
2301                                acp::SessionUpdate::AgentMessageChunk {
2302                                    content: content.text.to_uppercase().into(),
2303                                },
2304                                cx,
2305                            )
2306                            .unwrap();
2307                    })?;
2308                    Ok(acp::PromptResponse {
2309                        stop_reason: acp::StopReason::EndTurn,
2310                    })
2311                }
2312                .boxed_local()
2313            }
2314        }));
2315        let thread = cx
2316            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2317            .await
2318            .unwrap();
2319
2320        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
2321            .await
2322            .unwrap();
2323        thread.read_with(cx, |thread, cx| {
2324            assert_eq!(
2325                thread.to_markdown(cx),
2326                indoc! {"
2327                    ## User (checkpoint)
2328
2329                    Lorem
2330
2331                    ## Assistant
2332
2333                    LOREM
2334
2335                "}
2336            );
2337        });
2338        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2339
2340        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
2341            .await
2342            .unwrap();
2343        thread.read_with(cx, |thread, cx| {
2344            assert_eq!(
2345                thread.to_markdown(cx),
2346                indoc! {"
2347                    ## User (checkpoint)
2348
2349                    Lorem
2350
2351                    ## Assistant
2352
2353                    LOREM
2354
2355                    ## User (checkpoint)
2356
2357                    ipsum
2358
2359                    ## Assistant
2360
2361                    IPSUM
2362
2363                "}
2364            );
2365        });
2366        assert_eq!(
2367            fs.files(),
2368            vec![
2369                Path::new(path!("/test/file-0")),
2370                Path::new(path!("/test/file-1"))
2371            ]
2372        );
2373
2374        // Checkpoint isn't stored when there are no changes.
2375        simulate_changes.store(false, SeqCst);
2376        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
2377            .await
2378            .unwrap();
2379        thread.read_with(cx, |thread, cx| {
2380            assert_eq!(
2381                thread.to_markdown(cx),
2382                indoc! {"
2383                    ## User (checkpoint)
2384
2385                    Lorem
2386
2387                    ## Assistant
2388
2389                    LOREM
2390
2391                    ## User (checkpoint)
2392
2393                    ipsum
2394
2395                    ## Assistant
2396
2397                    IPSUM
2398
2399                    ## User
2400
2401                    dolor
2402
2403                    ## Assistant
2404
2405                    DOLOR
2406
2407                "}
2408            );
2409        });
2410        assert_eq!(
2411            fs.files(),
2412            vec![
2413                Path::new(path!("/test/file-0")),
2414                Path::new(path!("/test/file-1"))
2415            ]
2416        );
2417
2418        // Rewinding the conversation truncates the history and restores the checkpoint.
2419        thread
2420            .update(cx, |thread, cx| {
2421                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
2422                    panic!("unexpected entries {:?}", thread.entries)
2423                };
2424                thread.rewind(message.id.clone().unwrap(), cx)
2425            })
2426            .await
2427            .unwrap();
2428        thread.read_with(cx, |thread, cx| {
2429            assert_eq!(
2430                thread.to_markdown(cx),
2431                indoc! {"
2432                    ## User (checkpoint)
2433
2434                    Lorem
2435
2436                    ## Assistant
2437
2438                    LOREM
2439
2440                "}
2441            );
2442        });
2443        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2444    }
2445
2446    #[gpui::test]
2447    async fn test_refusal(cx: &mut TestAppContext) {
2448        init_test(cx);
2449        let fs = FakeFs::new(cx.background_executor.clone());
2450        fs.insert_tree(path!("/"), json!({})).await;
2451        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
2452
2453        let refuse_next = Arc::new(AtomicBool::new(false));
2454        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2455            let refuse_next = refuse_next.clone();
2456            move |request, thread, mut cx| {
2457                let refuse_next = refuse_next.clone();
2458                async move {
2459                    if refuse_next.load(SeqCst) {
2460                        return Ok(acp::PromptResponse {
2461                            stop_reason: acp::StopReason::Refusal,
2462                        });
2463                    }
2464
2465                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2466                        panic!("expected text content block");
2467                    };
2468                    thread.update(&mut cx, |thread, cx| {
2469                        thread
2470                            .handle_session_update(
2471                                acp::SessionUpdate::AgentMessageChunk {
2472                                    content: content.text.to_uppercase().into(),
2473                                },
2474                                cx,
2475                            )
2476                            .unwrap();
2477                    })?;
2478                    Ok(acp::PromptResponse {
2479                        stop_reason: acp::StopReason::EndTurn,
2480                    })
2481                }
2482                .boxed_local()
2483            }
2484        }));
2485        let thread = cx
2486            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2487            .await
2488            .unwrap();
2489
2490        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
2491            .await
2492            .unwrap();
2493        thread.read_with(cx, |thread, cx| {
2494            assert_eq!(
2495                thread.to_markdown(cx),
2496                indoc! {"
2497                    ## User
2498
2499                    hello
2500
2501                    ## Assistant
2502
2503                    HELLO
2504
2505                "}
2506            );
2507        });
2508
2509        // Simulate refusing the second message, ensuring the conversation gets
2510        // truncated to before sending it.
2511        refuse_next.store(true, SeqCst);
2512        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
2513            .await
2514            .unwrap();
2515        thread.read_with(cx, |thread, cx| {
2516            assert_eq!(
2517                thread.to_markdown(cx),
2518                indoc! {"
2519                    ## User
2520
2521                    hello
2522
2523                    ## Assistant
2524
2525                    HELLO
2526
2527                "}
2528            );
2529        });
2530    }
2531
2532    async fn run_until_first_tool_call(
2533        thread: &Entity<AcpThread>,
2534        cx: &mut TestAppContext,
2535    ) -> usize {
2536        let (mut tx, mut rx) = mpsc::channel::<usize>(1);
2537
2538        let subscription = cx.update(|cx| {
2539            cx.subscribe(thread, move |thread, _, cx| {
2540                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
2541                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
2542                        return tx.try_send(ix).unwrap();
2543                    }
2544                }
2545            })
2546        });
2547
2548        select! {
2549            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
2550                panic!("Timeout waiting for tool call")
2551            }
2552            ix = rx.next().fuse() => {
2553                drop(subscription);
2554                ix.unwrap()
2555            }
2556        }
2557    }
2558
2559    #[derive(Clone, Default)]
2560    struct FakeAgentConnection {
2561        auth_methods: Vec<acp::AuthMethod>,
2562        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
2563        on_user_message: Option<
2564            Rc<
2565                dyn Fn(
2566                        acp::PromptRequest,
2567                        WeakEntity<AcpThread>,
2568                        AsyncApp,
2569                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2570                    + 'static,
2571            >,
2572        >,
2573    }
2574
2575    impl FakeAgentConnection {
2576        fn new() -> Self {
2577            Self {
2578                auth_methods: Vec::new(),
2579                on_user_message: None,
2580                sessions: Arc::default(),
2581            }
2582        }
2583
2584        #[expect(unused)]
2585        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
2586            self.auth_methods = auth_methods;
2587            self
2588        }
2589
2590        fn on_user_message(
2591            mut self,
2592            handler: impl Fn(
2593                acp::PromptRequest,
2594                WeakEntity<AcpThread>,
2595                AsyncApp,
2596            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2597            + 'static,
2598        ) -> Self {
2599            self.on_user_message.replace(Rc::new(handler));
2600            self
2601        }
2602    }
2603
2604    impl AgentConnection for FakeAgentConnection {
2605        fn auth_methods(&self) -> &[acp::AuthMethod] {
2606            &self.auth_methods
2607        }
2608
2609        fn new_thread(
2610            self: Rc<Self>,
2611            project: Entity<Project>,
2612            _cwd: &Path,
2613            cx: &mut App,
2614        ) -> Task<gpui::Result<Entity<AcpThread>>> {
2615            let session_id = acp::SessionId(
2616                rand::thread_rng()
2617                    .sample_iter(&rand::distributions::Alphanumeric)
2618                    .take(7)
2619                    .map(char::from)
2620                    .collect::<String>()
2621                    .into(),
2622            );
2623            let action_log = cx.new(|_| ActionLog::new(project.clone()));
2624            let thread = cx.new(|cx| {
2625                AcpThread::new(
2626                    "Test",
2627                    self.clone(),
2628                    project,
2629                    action_log,
2630                    session_id.clone(),
2631                    watch::Receiver::constant(acp::PromptCapabilities {
2632                        image: true,
2633                        audio: true,
2634                        embedded_context: true,
2635                    }),
2636                    cx,
2637                )
2638            });
2639            self.sessions.lock().insert(session_id, thread.downgrade());
2640            Task::ready(Ok(thread))
2641        }
2642
2643        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
2644            if self.auth_methods().iter().any(|m| m.id == method) {
2645                Task::ready(Ok(()))
2646            } else {
2647                Task::ready(Err(anyhow!("Invalid Auth Method")))
2648            }
2649        }
2650
2651        fn prompt(
2652            &self,
2653            _id: Option<UserMessageId>,
2654            params: acp::PromptRequest,
2655            cx: &mut App,
2656        ) -> Task<gpui::Result<acp::PromptResponse>> {
2657            let sessions = self.sessions.lock();
2658            let thread = sessions.get(&params.session_id).unwrap();
2659            if let Some(handler) = &self.on_user_message {
2660                let handler = handler.clone();
2661                let thread = thread.clone();
2662                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
2663            } else {
2664                Task::ready(Ok(acp::PromptResponse {
2665                    stop_reason: acp::StopReason::EndTurn,
2666                }))
2667            }
2668        }
2669
2670        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
2671            let sessions = self.sessions.lock();
2672            let thread = sessions.get(session_id).unwrap().clone();
2673
2674            cx.spawn(async move |cx| {
2675                thread
2676                    .update(cx, |thread, cx| thread.cancel(cx))
2677                    .unwrap()
2678                    .await
2679            })
2680            .detach();
2681        }
2682
2683        fn truncate(
2684            &self,
2685            session_id: &acp::SessionId,
2686            _cx: &App,
2687        ) -> Option<Rc<dyn AgentSessionTruncate>> {
2688            Some(Rc::new(FakeAgentSessionEditor {
2689                _session_id: session_id.clone(),
2690            }))
2691        }
2692
2693        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
2694            self
2695        }
2696    }
2697
2698    struct FakeAgentSessionEditor {
2699        _session_id: acp::SessionId,
2700    }
2701
2702    impl AgentSessionTruncate for FakeAgentSessionEditor {
2703        fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
2704            Task::ready(Ok(()))
2705        }
2706    }
2707}