acp_thread.rs

   1mod connection;
   2mod diff;
   3mod mention;
   4mod terminal;
   5
   6use collections::HashSet;
   7pub use connection::*;
   8pub use diff::*;
   9use language::language_settings::FormatOnSave;
  10pub use mention::*;
  11use project::lsp_store::{FormatTrigger, LspFormatTarget};
  12use serde::{Deserialize, Serialize};
  13pub use terminal::*;
  14
  15use action_log::ActionLog;
  16use agent_client_protocol as acp;
  17use anyhow::{Context as _, Result, anyhow};
  18use editor::Bias;
  19use futures::{FutureExt, channel::oneshot, future::BoxFuture};
  20use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  21use itertools::Itertools;
  22use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
  23use markdown::Markdown;
  24use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
  25use std::collections::HashMap;
  26use std::error::Error;
  27use std::fmt::{Formatter, Write};
  28use std::ops::Range;
  29use std::process::ExitStatus;
  30use std::rc::Rc;
  31use std::time::{Duration, Instant};
  32use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
  33use ui::App;
  34use util::ResultExt;
  35
  36#[derive(Debug)]
  37pub struct UserMessage {
  38    pub id: Option<UserMessageId>,
  39    pub content: ContentBlock,
  40    pub chunks: Vec<acp::ContentBlock>,
  41    pub checkpoint: Option<Checkpoint>,
  42}
  43
  44#[derive(Debug)]
  45pub struct Checkpoint {
  46    git_checkpoint: GitStoreCheckpoint,
  47    pub show: bool,
  48}
  49
  50impl UserMessage {
  51    fn to_markdown(&self, cx: &App) -> String {
  52        let mut markdown = String::new();
  53        if self
  54            .checkpoint
  55            .as_ref()
  56            .is_some_and(|checkpoint| checkpoint.show)
  57        {
  58            writeln!(markdown, "## User (checkpoint)").unwrap();
  59        } else {
  60            writeln!(markdown, "## User").unwrap();
  61        }
  62        writeln!(markdown).unwrap();
  63        writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
  64        writeln!(markdown).unwrap();
  65        markdown
  66    }
  67}
  68
  69#[derive(Debug, PartialEq)]
  70pub struct AssistantMessage {
  71    pub chunks: Vec<AssistantMessageChunk>,
  72}
  73
  74impl AssistantMessage {
  75    pub fn to_markdown(&self, cx: &App) -> String {
  76        format!(
  77            "## Assistant\n\n{}\n\n",
  78            self.chunks
  79                .iter()
  80                .map(|chunk| chunk.to_markdown(cx))
  81                .join("\n\n")
  82        )
  83    }
  84}
  85
  86#[derive(Debug, PartialEq)]
  87pub enum AssistantMessageChunk {
  88    Message { block: ContentBlock },
  89    Thought { block: ContentBlock },
  90}
  91
  92impl AssistantMessageChunk {
  93    pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
  94        Self::Message {
  95            block: ContentBlock::new(chunk.into(), language_registry, cx),
  96        }
  97    }
  98
  99    fn to_markdown(&self, cx: &App) -> String {
 100        match self {
 101            Self::Message { block } => block.to_markdown(cx).to_string(),
 102            Self::Thought { block } => {
 103                format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
 104            }
 105        }
 106    }
 107}
 108
 109#[derive(Debug)]
 110pub enum AgentThreadEntry {
 111    UserMessage(UserMessage),
 112    AssistantMessage(AssistantMessage),
 113    ToolCall(ToolCall),
 114}
 115
 116impl AgentThreadEntry {
 117    pub fn to_markdown(&self, cx: &App) -> String {
 118        match self {
 119            Self::UserMessage(message) => message.to_markdown(cx),
 120            Self::AssistantMessage(message) => message.to_markdown(cx),
 121            Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
 122        }
 123    }
 124
 125    pub fn user_message(&self) -> Option<&UserMessage> {
 126        if let AgentThreadEntry::UserMessage(message) = self {
 127            Some(message)
 128        } else {
 129            None
 130        }
 131    }
 132
 133    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 134        if let AgentThreadEntry::ToolCall(call) = self {
 135            itertools::Either::Left(call.diffs())
 136        } else {
 137            itertools::Either::Right(std::iter::empty())
 138        }
 139    }
 140
 141    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 142        if let AgentThreadEntry::ToolCall(call) = self {
 143            itertools::Either::Left(call.terminals())
 144        } else {
 145            itertools::Either::Right(std::iter::empty())
 146        }
 147    }
 148
 149    pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
 150        if let AgentThreadEntry::ToolCall(ToolCall {
 151            locations,
 152            resolved_locations,
 153            ..
 154        }) = self
 155        {
 156            Some((
 157                locations.get(ix)?.clone(),
 158                resolved_locations.get(ix)?.clone()?,
 159            ))
 160        } else {
 161            None
 162        }
 163    }
 164}
 165
 166#[derive(Debug)]
 167pub struct ToolCall {
 168    pub id: acp::ToolCallId,
 169    pub label: Entity<Markdown>,
 170    pub kind: acp::ToolKind,
 171    pub content: Vec<ToolCallContent>,
 172    pub status: ToolCallStatus,
 173    pub locations: Vec<acp::ToolCallLocation>,
 174    pub resolved_locations: Vec<Option<AgentLocation>>,
 175    pub raw_input: Option<serde_json::Value>,
 176    pub raw_output: Option<serde_json::Value>,
 177}
 178
 179impl ToolCall {
 180    fn from_acp(
 181        tool_call: acp::ToolCall,
 182        status: ToolCallStatus,
 183        language_registry: Arc<LanguageRegistry>,
 184        cx: &mut App,
 185    ) -> Self {
 186        let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
 187            first_line.to_owned() + ""
 188        } else {
 189            tool_call.title
 190        };
 191        Self {
 192            id: tool_call.id,
 193            label: cx
 194                .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
 195            kind: tool_call.kind,
 196            content: tool_call
 197                .content
 198                .into_iter()
 199                .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
 200                .collect(),
 201            locations: tool_call.locations,
 202            resolved_locations: Vec::default(),
 203            status,
 204            raw_input: tool_call.raw_input,
 205            raw_output: tool_call.raw_output,
 206        }
 207    }
 208
 209    fn update_fields(
 210        &mut self,
 211        fields: acp::ToolCallUpdateFields,
 212        language_registry: Arc<LanguageRegistry>,
 213        cx: &mut App,
 214    ) {
 215        let acp::ToolCallUpdateFields {
 216            kind,
 217            status,
 218            title,
 219            content,
 220            locations,
 221            raw_input,
 222            raw_output,
 223        } = fields;
 224
 225        if let Some(kind) = kind {
 226            self.kind = kind;
 227        }
 228
 229        if let Some(status) = status {
 230            self.status = status.into();
 231        }
 232
 233        if let Some(title) = title {
 234            self.label.update(cx, |label, cx| {
 235                if let Some((first_line, _)) = title.split_once("\n") {
 236                    label.replace(first_line.to_owned() + "", cx)
 237                } else {
 238                    label.replace(title, cx);
 239                }
 240            });
 241        }
 242
 243        if let Some(content) = content {
 244            let new_content_len = content.len();
 245            let mut content = content.into_iter();
 246
 247            // Reuse existing content if we can
 248            for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
 249                old.update_from_acp(new, language_registry.clone(), cx);
 250            }
 251            for new in content {
 252                self.content.push(ToolCallContent::from_acp(
 253                    new,
 254                    language_registry.clone(),
 255                    cx,
 256                ))
 257            }
 258            self.content.truncate(new_content_len);
 259        }
 260
 261        if let Some(locations) = locations {
 262            self.locations = locations;
 263        }
 264
 265        if let Some(raw_input) = raw_input {
 266            self.raw_input = Some(raw_input);
 267        }
 268
 269        if let Some(raw_output) = raw_output {
 270            if self.content.is_empty()
 271                && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
 272            {
 273                self.content
 274                    .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
 275                        markdown,
 276                    }));
 277            }
 278            self.raw_output = Some(raw_output);
 279        }
 280    }
 281
 282    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 283        self.content.iter().filter_map(|content| match content {
 284            ToolCallContent::Diff(diff) => Some(diff),
 285            ToolCallContent::ContentBlock(_) => None,
 286            ToolCallContent::Terminal(_) => None,
 287        })
 288    }
 289
 290    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 291        self.content.iter().filter_map(|content| match content {
 292            ToolCallContent::Terminal(terminal) => Some(terminal),
 293            ToolCallContent::ContentBlock(_) => None,
 294            ToolCallContent::Diff(_) => None,
 295        })
 296    }
 297
 298    fn to_markdown(&self, cx: &App) -> String {
 299        let mut markdown = format!(
 300            "**Tool Call: {}**\nStatus: {}\n\n",
 301            self.label.read(cx).source(),
 302            self.status
 303        );
 304        for content in &self.content {
 305            markdown.push_str(content.to_markdown(cx).as_str());
 306            markdown.push_str("\n\n");
 307        }
 308        markdown
 309    }
 310
 311    async fn resolve_location(
 312        location: acp::ToolCallLocation,
 313        project: WeakEntity<Project>,
 314        cx: &mut AsyncApp,
 315    ) -> Option<AgentLocation> {
 316        let buffer = project
 317            .update(cx, |project, cx| {
 318                project
 319                    .project_path_for_absolute_path(&location.path, cx)
 320                    .map(|path| project.open_buffer(path, cx))
 321            })
 322            .ok()??;
 323        let buffer = buffer.await.log_err()?;
 324        let position = buffer
 325            .update(cx, |buffer, _| {
 326                if let Some(row) = location.line {
 327                    let snapshot = buffer.snapshot();
 328                    let column = snapshot.indent_size_for_line(row).len;
 329                    let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
 330                    snapshot.anchor_before(point)
 331                } else {
 332                    Anchor::MIN
 333                }
 334            })
 335            .ok()?;
 336
 337        Some(AgentLocation {
 338            buffer: buffer.downgrade(),
 339            position,
 340        })
 341    }
 342
 343    fn resolve_locations(
 344        &self,
 345        project: Entity<Project>,
 346        cx: &mut App,
 347    ) -> Task<Vec<Option<AgentLocation>>> {
 348        let locations = self.locations.clone();
 349        project.update(cx, |_, cx| {
 350            cx.spawn(async move |project, cx| {
 351                let mut new_locations = Vec::new();
 352                for location in locations {
 353                    new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
 354                }
 355                new_locations
 356            })
 357        })
 358    }
 359}
 360
 361#[derive(Debug)]
 362pub enum ToolCallStatus {
 363    /// The tool call hasn't started running yet, but we start showing it to
 364    /// the user.
 365    Pending,
 366    /// The tool call is waiting for confirmation from the user.
 367    WaitingForConfirmation {
 368        options: Vec<acp::PermissionOption>,
 369        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
 370    },
 371    /// The tool call is currently running.
 372    InProgress,
 373    /// The tool call completed successfully.
 374    Completed,
 375    /// The tool call failed.
 376    Failed,
 377    /// The user rejected the tool call.
 378    Rejected,
 379    /// The user canceled generation so the tool call was canceled.
 380    Canceled,
 381}
 382
 383impl From<acp::ToolCallStatus> for ToolCallStatus {
 384    fn from(status: acp::ToolCallStatus) -> Self {
 385        match status {
 386            acp::ToolCallStatus::Pending => Self::Pending,
 387            acp::ToolCallStatus::InProgress => Self::InProgress,
 388            acp::ToolCallStatus::Completed => Self::Completed,
 389            acp::ToolCallStatus::Failed => Self::Failed,
 390        }
 391    }
 392}
 393
 394impl Display for ToolCallStatus {
 395    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 396        write!(
 397            f,
 398            "{}",
 399            match self {
 400                ToolCallStatus::Pending => "Pending",
 401                ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
 402                ToolCallStatus::InProgress => "In Progress",
 403                ToolCallStatus::Completed => "Completed",
 404                ToolCallStatus::Failed => "Failed",
 405                ToolCallStatus::Rejected => "Rejected",
 406                ToolCallStatus::Canceled => "Canceled",
 407            }
 408        )
 409    }
 410}
 411
 412#[derive(Debug, PartialEq, Clone)]
 413pub enum ContentBlock {
 414    Empty,
 415    Markdown { markdown: Entity<Markdown> },
 416    ResourceLink { resource_link: acp::ResourceLink },
 417}
 418
 419impl ContentBlock {
 420    pub fn new(
 421        block: acp::ContentBlock,
 422        language_registry: &Arc<LanguageRegistry>,
 423        cx: &mut App,
 424    ) -> Self {
 425        let mut this = Self::Empty;
 426        this.append(block, language_registry, cx);
 427        this
 428    }
 429
 430    pub fn new_combined(
 431        blocks: impl IntoIterator<Item = acp::ContentBlock>,
 432        language_registry: Arc<LanguageRegistry>,
 433        cx: &mut App,
 434    ) -> Self {
 435        let mut this = Self::Empty;
 436        for block in blocks {
 437            this.append(block, &language_registry, cx);
 438        }
 439        this
 440    }
 441
 442    pub fn append(
 443        &mut self,
 444        block: acp::ContentBlock,
 445        language_registry: &Arc<LanguageRegistry>,
 446        cx: &mut App,
 447    ) {
 448        if matches!(self, ContentBlock::Empty)
 449            && let acp::ContentBlock::ResourceLink(resource_link) = block
 450        {
 451            *self = ContentBlock::ResourceLink { resource_link };
 452            return;
 453        }
 454
 455        let new_content = self.block_string_contents(block);
 456
 457        match self {
 458            ContentBlock::Empty => {
 459                *self = Self::create_markdown_block(new_content, language_registry, cx);
 460            }
 461            ContentBlock::Markdown { markdown } => {
 462                markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
 463            }
 464            ContentBlock::ResourceLink { resource_link } => {
 465                let existing_content = Self::resource_link_md(&resource_link.uri);
 466                let combined = format!("{}\n{}", existing_content, new_content);
 467
 468                *self = Self::create_markdown_block(combined, language_registry, cx);
 469            }
 470        }
 471    }
 472
 473    fn create_markdown_block(
 474        content: String,
 475        language_registry: &Arc<LanguageRegistry>,
 476        cx: &mut App,
 477    ) -> ContentBlock {
 478        ContentBlock::Markdown {
 479            markdown: cx
 480                .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
 481        }
 482    }
 483
 484    fn block_string_contents(&self, block: acp::ContentBlock) -> String {
 485        match block {
 486            acp::ContentBlock::Text(text_content) => text_content.text,
 487            acp::ContentBlock::ResourceLink(resource_link) => {
 488                Self::resource_link_md(&resource_link.uri)
 489            }
 490            acp::ContentBlock::Resource(acp::EmbeddedResource {
 491                resource:
 492                    acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
 493                        uri,
 494                        ..
 495                    }),
 496                ..
 497            }) => Self::resource_link_md(&uri),
 498            acp::ContentBlock::Image(image) => Self::image_md(&image),
 499            acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
 500        }
 501    }
 502
 503    fn resource_link_md(uri: &str) -> String {
 504        if let Some(uri) = MentionUri::parse(uri).log_err() {
 505            uri.as_link().to_string()
 506        } else {
 507            uri.to_string()
 508        }
 509    }
 510
 511    fn image_md(_image: &acp::ImageContent) -> String {
 512        "`Image`".into()
 513    }
 514
 515    pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
 516        match self {
 517            ContentBlock::Empty => "",
 518            ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
 519            ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
 520        }
 521    }
 522
 523    pub fn markdown(&self) -> Option<&Entity<Markdown>> {
 524        match self {
 525            ContentBlock::Empty => None,
 526            ContentBlock::Markdown { markdown } => Some(markdown),
 527            ContentBlock::ResourceLink { .. } => None,
 528        }
 529    }
 530
 531    pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
 532        match self {
 533            ContentBlock::ResourceLink { resource_link } => Some(resource_link),
 534            _ => None,
 535        }
 536    }
 537}
 538
 539#[derive(Debug)]
 540pub enum ToolCallContent {
 541    ContentBlock(ContentBlock),
 542    Diff(Entity<Diff>),
 543    Terminal(Entity<Terminal>),
 544}
 545
 546impl ToolCallContent {
 547    pub fn from_acp(
 548        content: acp::ToolCallContent,
 549        language_registry: Arc<LanguageRegistry>,
 550        cx: &mut App,
 551    ) -> Self {
 552        match content {
 553            acp::ToolCallContent::Content { content } => {
 554                Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
 555            }
 556            acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
 557                Diff::finalized(
 558                    diff.path,
 559                    diff.old_text,
 560                    diff.new_text,
 561                    language_registry,
 562                    cx,
 563                )
 564            })),
 565        }
 566    }
 567
 568    pub fn update_from_acp(
 569        &mut self,
 570        new: acp::ToolCallContent,
 571        language_registry: Arc<LanguageRegistry>,
 572        cx: &mut App,
 573    ) {
 574        let needs_update = match (&self, &new) {
 575            (Self::Diff(old_diff), acp::ToolCallContent::Diff { diff: new_diff }) => {
 576                old_diff.read(cx).needs_update(
 577                    new_diff.old_text.as_deref().unwrap_or(""),
 578                    &new_diff.new_text,
 579                    cx,
 580                )
 581            }
 582            _ => true,
 583        };
 584
 585        if needs_update {
 586            *self = Self::from_acp(new, language_registry, cx);
 587        }
 588    }
 589
 590    pub fn to_markdown(&self, cx: &App) -> String {
 591        match self {
 592            Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
 593            Self::Diff(diff) => diff.read(cx).to_markdown(cx),
 594            Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
 595        }
 596    }
 597}
 598
 599#[derive(Debug, PartialEq)]
 600pub enum ToolCallUpdate {
 601    UpdateFields(acp::ToolCallUpdate),
 602    UpdateDiff(ToolCallUpdateDiff),
 603    UpdateTerminal(ToolCallUpdateTerminal),
 604}
 605
 606impl ToolCallUpdate {
 607    fn id(&self) -> &acp::ToolCallId {
 608        match self {
 609            Self::UpdateFields(update) => &update.id,
 610            Self::UpdateDiff(diff) => &diff.id,
 611            Self::UpdateTerminal(terminal) => &terminal.id,
 612        }
 613    }
 614}
 615
 616impl From<acp::ToolCallUpdate> for ToolCallUpdate {
 617    fn from(update: acp::ToolCallUpdate) -> Self {
 618        Self::UpdateFields(update)
 619    }
 620}
 621
 622impl From<ToolCallUpdateDiff> for ToolCallUpdate {
 623    fn from(diff: ToolCallUpdateDiff) -> Self {
 624        Self::UpdateDiff(diff)
 625    }
 626}
 627
 628#[derive(Debug, PartialEq)]
 629pub struct ToolCallUpdateDiff {
 630    pub id: acp::ToolCallId,
 631    pub diff: Entity<Diff>,
 632}
 633
 634impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
 635    fn from(terminal: ToolCallUpdateTerminal) -> Self {
 636        Self::UpdateTerminal(terminal)
 637    }
 638}
 639
 640#[derive(Debug, PartialEq)]
 641pub struct ToolCallUpdateTerminal {
 642    pub id: acp::ToolCallId,
 643    pub terminal: Entity<Terminal>,
 644}
 645
 646#[derive(Debug, Default)]
 647pub struct Plan {
 648    pub entries: Vec<PlanEntry>,
 649}
 650
 651#[derive(Debug)]
 652pub struct PlanStats<'a> {
 653    pub in_progress_entry: Option<&'a PlanEntry>,
 654    pub pending: u32,
 655    pub completed: u32,
 656}
 657
 658impl Plan {
 659    pub fn is_empty(&self) -> bool {
 660        self.entries.is_empty()
 661    }
 662
 663    pub fn stats(&self) -> PlanStats<'_> {
 664        let mut stats = PlanStats {
 665            in_progress_entry: None,
 666            pending: 0,
 667            completed: 0,
 668        };
 669
 670        for entry in &self.entries {
 671            match &entry.status {
 672                acp::PlanEntryStatus::Pending => {
 673                    stats.pending += 1;
 674                }
 675                acp::PlanEntryStatus::InProgress => {
 676                    stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
 677                }
 678                acp::PlanEntryStatus::Completed => {
 679                    stats.completed += 1;
 680                }
 681            }
 682        }
 683
 684        stats
 685    }
 686}
 687
 688#[derive(Debug)]
 689pub struct PlanEntry {
 690    pub content: Entity<Markdown>,
 691    pub priority: acp::PlanEntryPriority,
 692    pub status: acp::PlanEntryStatus,
 693}
 694
 695impl PlanEntry {
 696    pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
 697        Self {
 698            content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
 699            priority: entry.priority,
 700            status: entry.status,
 701        }
 702    }
 703}
 704
/// Token consumption of a thread relative to the model's limit.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Token limit; `ratio` treats 0 as "unknown limit".
    pub max_tokens: u64,
    /// Tokens used so far.
    pub used_tokens: u64,
}
 710
 711impl TokenUsage {
 712    pub fn ratio(&self) -> TokenUsageRatio {
 713        #[cfg(debug_assertions)]
 714        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 715            .unwrap_or("0.8".to_string())
 716            .parse()
 717            .unwrap();
 718        #[cfg(not(debug_assertions))]
 719        let warning_threshold: f32 = 0.8;
 720
 721        // When the maximum is unknown because there is no selected model,
 722        // avoid showing the token limit warning.
 723        if self.max_tokens == 0 {
 724            TokenUsageRatio::Normal
 725        } else if self.used_tokens >= self.max_tokens {
 726            TokenUsageRatio::Exceeded
 727        } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
 728            TokenUsageRatio::Warning
 729        } else {
 730            TokenUsageRatio::Normal
 731        }
 732    }
 733}
 734
 735#[derive(Debug, Clone, PartialEq, Eq)]
 736pub enum TokenUsageRatio {
 737    Normal,
 738    Warning,
 739    Exceeded,
 740}
 741
/// Retry progress carried by [`AcpThreadEvent::Retry`].
#[derive(Debug, Clone)]
pub struct RetryStatus {
    pub last_error: SharedString,
    pub attempt: usize,
    pub max_attempts: usize,
    pub started_at: Instant,
    // NOTE(review): presumably the delay before the next attempt — confirm against the emitter.
    pub duration: Duration,
}

/// State of one agent session: its entries, plan, token usage and the
/// connection used to talk to the agent.
pub struct AcpThread {
    title: SharedString,
    entries: Vec<AgentThreadEntry>,
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    // NOTE(review): buffer snapshots keyed by buffer entity; how they are used is not visible here.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    // While `Some`, `status()` reports the thread as generating (or waiting for tool confirmation).
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
    token_usage: Option<TokenUsage>,
    // Latest value received from the prompt-capabilities channel passed to `new`.
    prompt_capabilities: acp::PromptCapabilities,
    // Task spawned in `new` that copies capability updates into `prompt_capabilities`.
    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
}

/// Events emitted by an [`AcpThread`].
#[derive(Debug)]
pub enum AcpThreadEvent {
    NewEntry,
    TitleUpdated,
    TokenUsageUpdated,
    /// The entry at this index changed.
    EntryUpdated(usize),
    // NOTE(review): presumably the index range of the removed entries — confirm with the emitter.
    EntriesRemoved(Range<usize>),
    ToolAuthorizationRequired,
    Retry(RetryStatus),
    Stopped,
    Error,
    LoadError(LoadError),
    /// Emitted after `prompt_capabilities` has been replaced with a new value.
    PromptCapabilitiesUpdated,
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}
 782
 783#[derive(PartialEq, Eq, Debug)]
 784pub enum ThreadStatus {
 785    Idle,
 786    WaitingForToolConfirmation,
 787    Generating,
 788}
 789
 790#[derive(Debug, Clone)]
 791pub enum LoadError {
 792    NotInstalled,
 793    Unsupported {
 794        command: SharedString,
 795        current_version: SharedString,
 796    },
 797    Exited {
 798        status: ExitStatus,
 799    },
 800    Other(SharedString),
 801}
 802
 803impl Display for LoadError {
 804    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 805        match self {
 806            LoadError::NotInstalled => write!(f, "not installed"),
 807            LoadError::Unsupported {
 808                command: path,
 809                current_version,
 810            } => {
 811                write!(f, "version {current_version} from {path} is not supported")
 812            }
 813            LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
 814            LoadError::Other(msg) => write!(f, "{}", msg),
 815        }
 816    }
 817}
 818
 819impl Error for LoadError {}
 820
 821impl AcpThread {
 822    pub fn new(
 823        title: impl Into<SharedString>,
 824        connection: Rc<dyn AgentConnection>,
 825        project: Entity<Project>,
 826        action_log: Entity<ActionLog>,
 827        session_id: acp::SessionId,
 828        mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
 829        cx: &mut Context<Self>,
 830    ) -> Self {
 831        let prompt_capabilities = *prompt_capabilities_rx.borrow();
 832        let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
 833            loop {
 834                let caps = prompt_capabilities_rx.recv().await?;
 835                this.update(cx, |this, cx| {
 836                    this.prompt_capabilities = caps;
 837                    cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
 838                })?;
 839            }
 840        });
 841
 842        Self {
 843            action_log,
 844            shared_buffers: Default::default(),
 845            entries: Default::default(),
 846            plan: Default::default(),
 847            title: title.into(),
 848            project,
 849            send_task: None,
 850            connection,
 851            session_id,
 852            token_usage: None,
 853            prompt_capabilities,
 854            _observe_prompt_capabilities: task,
 855        }
 856    }
 857
 858    pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
 859        self.prompt_capabilities
 860    }
 861
 862    pub fn connection(&self) -> &Rc<dyn AgentConnection> {
 863        &self.connection
 864    }
 865
 866    pub fn action_log(&self) -> &Entity<ActionLog> {
 867        &self.action_log
 868    }
 869
 870    pub fn project(&self) -> &Entity<Project> {
 871        &self.project
 872    }
 873
 874    pub fn title(&self) -> SharedString {
 875        self.title.clone()
 876    }
 877
 878    pub fn entries(&self) -> &[AgentThreadEntry] {
 879        &self.entries
 880    }
 881
 882    pub fn session_id(&self) -> &acp::SessionId {
 883        &self.session_id
 884    }
 885
 886    pub fn status(&self) -> ThreadStatus {
 887        if self.send_task.is_some() {
 888            if self.waiting_for_tool_confirmation() {
 889                ThreadStatus::WaitingForToolConfirmation
 890            } else {
 891                ThreadStatus::Generating
 892            }
 893        } else {
 894            ThreadStatus::Idle
 895        }
 896    }
 897
 898    pub fn token_usage(&self) -> Option<&TokenUsage> {
 899        self.token_usage.as_ref()
 900    }
 901
 902    pub fn has_pending_edit_tool_calls(&self) -> bool {
 903        for entry in self.entries.iter().rev() {
 904            match entry {
 905                AgentThreadEntry::UserMessage(_) => return false,
 906                AgentThreadEntry::ToolCall(
 907                    call @ ToolCall {
 908                        status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
 909                        ..
 910                    },
 911                ) if call.diffs().next().is_some() => {
 912                    return true;
 913                }
 914                AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
 915            }
 916        }
 917
 918        false
 919    }
 920
 921    pub fn used_tools_since_last_user_message(&self) -> bool {
 922        for entry in self.entries.iter().rev() {
 923            match entry {
 924                AgentThreadEntry::UserMessage(..) => return false,
 925                AgentThreadEntry::AssistantMessage(..) => continue,
 926                AgentThreadEntry::ToolCall(..) => return true,
 927            }
 928        }
 929
 930        false
 931    }
 932
 933    pub fn handle_session_update(
 934        &mut self,
 935        update: acp::SessionUpdate,
 936        cx: &mut Context<Self>,
 937    ) -> Result<(), acp::Error> {
 938        match update {
 939            acp::SessionUpdate::UserMessageChunk { content } => {
 940                self.push_user_content_block(None, content, cx);
 941            }
 942            acp::SessionUpdate::AgentMessageChunk { content } => {
 943                self.push_assistant_content_block(content, false, cx);
 944            }
 945            acp::SessionUpdate::AgentThoughtChunk { content } => {
 946                self.push_assistant_content_block(content, true, cx);
 947            }
 948            acp::SessionUpdate::ToolCall(tool_call) => {
 949                self.upsert_tool_call(tool_call, cx)?;
 950            }
 951            acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
 952                self.update_tool_call(tool_call_update, cx)?;
 953            }
 954            acp::SessionUpdate::Plan(plan) => {
 955                self.update_plan(plan, cx);
 956            }
 957        }
 958        Ok(())
 959    }
 960
 961    pub fn push_user_content_block(
 962        &mut self,
 963        message_id: Option<UserMessageId>,
 964        chunk: acp::ContentBlock,
 965        cx: &mut Context<Self>,
 966    ) {
 967        let language_registry = self.project.read(cx).languages().clone();
 968        let entries_len = self.entries.len();
 969
 970        if let Some(last_entry) = self.entries.last_mut()
 971            && let AgentThreadEntry::UserMessage(UserMessage {
 972                id,
 973                content,
 974                chunks,
 975                ..
 976            }) = last_entry
 977        {
 978            *id = message_id.or(id.take());
 979            content.append(chunk.clone(), &language_registry, cx);
 980            chunks.push(chunk);
 981            let idx = entries_len - 1;
 982            cx.emit(AcpThreadEvent::EntryUpdated(idx));
 983        } else {
 984            let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
 985            self.push_entry(
 986                AgentThreadEntry::UserMessage(UserMessage {
 987                    id: message_id,
 988                    content,
 989                    chunks: vec![chunk],
 990                    checkpoint: None,
 991                }),
 992                cx,
 993            );
 994        }
 995    }
 996
 997    pub fn push_assistant_content_block(
 998        &mut self,
 999        chunk: acp::ContentBlock,
1000        is_thought: bool,
1001        cx: &mut Context<Self>,
1002    ) {
1003        let language_registry = self.project.read(cx).languages().clone();
1004        let entries_len = self.entries.len();
1005        if let Some(last_entry) = self.entries.last_mut()
1006            && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
1007        {
1008            let idx = entries_len - 1;
1009            cx.emit(AcpThreadEvent::EntryUpdated(idx));
1010            match (chunks.last_mut(), is_thought) {
1011                (Some(AssistantMessageChunk::Message { block }), false)
1012                | (Some(AssistantMessageChunk::Thought { block }), true) => {
1013                    block.append(chunk, &language_registry, cx)
1014                }
1015                _ => {
1016                    let block = ContentBlock::new(chunk, &language_registry, cx);
1017                    if is_thought {
1018                        chunks.push(AssistantMessageChunk::Thought { block })
1019                    } else {
1020                        chunks.push(AssistantMessageChunk::Message { block })
1021                    }
1022                }
1023            }
1024        } else {
1025            let block = ContentBlock::new(chunk, &language_registry, cx);
1026            let chunk = if is_thought {
1027                AssistantMessageChunk::Thought { block }
1028            } else {
1029                AssistantMessageChunk::Message { block }
1030            };
1031
1032            self.push_entry(
1033                AgentThreadEntry::AssistantMessage(AssistantMessage {
1034                    chunks: vec![chunk],
1035                }),
1036                cx,
1037            );
1038        }
1039    }
1040
1041    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1042        self.entries.push(entry);
1043        cx.emit(AcpThreadEvent::NewEntry);
1044    }
1045
1046    pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1047        self.connection.set_title(&self.session_id, cx).is_some()
1048    }
1049
1050    pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1051        if title != self.title {
1052            self.title = title.clone();
1053            cx.emit(AcpThreadEvent::TitleUpdated);
1054            if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1055                return set_title.run(title, cx);
1056            }
1057        }
1058        Task::ready(Ok(()))
1059    }
1060
1061    pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1062        self.token_usage = usage;
1063        cx.emit(AcpThreadEvent::TokenUsageUpdated);
1064    }
1065
1066    pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1067        cx.emit(AcpThreadEvent::Retry(status));
1068    }
1069
1070    pub fn update_tool_call(
1071        &mut self,
1072        update: impl Into<ToolCallUpdate>,
1073        cx: &mut Context<Self>,
1074    ) -> Result<()> {
1075        let update = update.into();
1076        let languages = self.project.read(cx).languages().clone();
1077
1078        let (ix, current_call) = self
1079            .tool_call_mut(update.id())
1080            .context("Tool call not found")?;
1081        match update {
1082            ToolCallUpdate::UpdateFields(update) => {
1083                let location_updated = update.fields.locations.is_some();
1084                current_call.update_fields(update.fields, languages, cx);
1085                if location_updated {
1086                    self.resolve_locations(update.id, cx);
1087                }
1088            }
1089            ToolCallUpdate::UpdateDiff(update) => {
1090                current_call.content.clear();
1091                current_call
1092                    .content
1093                    .push(ToolCallContent::Diff(update.diff));
1094            }
1095            ToolCallUpdate::UpdateTerminal(update) => {
1096                current_call.content.clear();
1097                current_call
1098                    .content
1099                    .push(ToolCallContent::Terminal(update.terminal));
1100            }
1101        }
1102
1103        cx.emit(AcpThreadEvent::EntryUpdated(ix));
1104
1105        Ok(())
1106    }
1107
1108    /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1109    pub fn upsert_tool_call(
1110        &mut self,
1111        tool_call: acp::ToolCall,
1112        cx: &mut Context<Self>,
1113    ) -> Result<(), acp::Error> {
1114        let status = tool_call.status.into();
1115        self.upsert_tool_call_inner(tool_call.into(), status, cx)
1116    }
1117
1118    /// Fails if id does not match an existing entry.
1119    pub fn upsert_tool_call_inner(
1120        &mut self,
1121        tool_call_update: acp::ToolCallUpdate,
1122        status: ToolCallStatus,
1123        cx: &mut Context<Self>,
1124    ) -> Result<(), acp::Error> {
1125        let language_registry = self.project.read(cx).languages().clone();
1126        let id = tool_call_update.id.clone();
1127
1128        if let Some((ix, current_call)) = self.tool_call_mut(&id) {
1129            current_call.update_fields(tool_call_update.fields, language_registry, cx);
1130            current_call.status = status;
1131
1132            cx.emit(AcpThreadEvent::EntryUpdated(ix));
1133        } else {
1134            let call =
1135                ToolCall::from_acp(tool_call_update.try_into()?, status, language_registry, cx);
1136            self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1137        };
1138
1139        self.resolve_locations(id, cx);
1140        Ok(())
1141    }
1142
1143    fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1144        // The tool call we are looking for is typically the last one, or very close to the end.
1145        // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1146        self.entries
1147            .iter_mut()
1148            .enumerate()
1149            .rev()
1150            .find_map(|(index, tool_call)| {
1151                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1152                    && &tool_call.id == id
1153                {
1154                    Some((index, tool_call))
1155                } else {
1156                    None
1157                }
1158            })
1159    }
1160
1161    pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1162        self.entries
1163            .iter()
1164            .enumerate()
1165            .rev()
1166            .find_map(|(index, tool_call)| {
1167                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1168                    && &tool_call.id == id
1169                {
1170                    Some((index, tool_call))
1171                } else {
1172                    None
1173                }
1174            })
1175    }
1176
1177    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1178        let project = self.project.clone();
1179        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1180            return;
1181        };
1182        let task = tool_call.resolve_locations(project, cx);
1183        cx.spawn(async move |this, cx| {
1184            let resolved_locations = task.await;
1185            this.update(cx, |this, cx| {
1186                let project = this.project.clone();
1187                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1188                    return;
1189                };
1190                if let Some(Some(location)) = resolved_locations.last() {
1191                    project.update(cx, |project, cx| {
1192                        if let Some(agent_location) = project.agent_location() {
1193                            let should_ignore = agent_location.buffer == location.buffer
1194                                && location
1195                                    .buffer
1196                                    .update(cx, |buffer, _| {
1197                                        let snapshot = buffer.snapshot();
1198                                        let old_position =
1199                                            agent_location.position.to_point(&snapshot);
1200                                        let new_position = location.position.to_point(&snapshot);
1201                                        // ignore this so that when we get updates from the edit tool
1202                                        // the position doesn't reset to the startof line
1203                                        old_position.row == new_position.row
1204                                            && old_position.column > new_position.column
1205                                    })
1206                                    .ok()
1207                                    .unwrap_or_default();
1208                            if !should_ignore {
1209                                project.set_agent_location(Some(location.clone()), cx);
1210                            }
1211                        }
1212                    });
1213                }
1214                if tool_call.resolved_locations != resolved_locations {
1215                    tool_call.resolved_locations = resolved_locations;
1216                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
1217                }
1218            })
1219        })
1220        .detach();
1221    }
1222
1223    pub fn request_tool_call_authorization(
1224        &mut self,
1225        tool_call: acp::ToolCallUpdate,
1226        options: Vec<acp::PermissionOption>,
1227        cx: &mut Context<Self>,
1228    ) -> Result<oneshot::Receiver<acp::PermissionOptionId>, acp::Error> {
1229        let (tx, rx) = oneshot::channel();
1230
1231        let status = ToolCallStatus::WaitingForConfirmation {
1232            options,
1233            respond_tx: tx,
1234        };
1235
1236        self.upsert_tool_call_inner(tool_call, status, cx)?;
1237        cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1238        Ok(rx)
1239    }
1240
1241    pub fn authorize_tool_call(
1242        &mut self,
1243        id: acp::ToolCallId,
1244        option_id: acp::PermissionOptionId,
1245        option_kind: acp::PermissionOptionKind,
1246        cx: &mut Context<Self>,
1247    ) {
1248        let Some((ix, call)) = self.tool_call_mut(&id) else {
1249            return;
1250        };
1251
1252        let new_status = match option_kind {
1253            acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1254                ToolCallStatus::Rejected
1255            }
1256            acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1257                ToolCallStatus::InProgress
1258            }
1259        };
1260
1261        let curr_status = mem::replace(&mut call.status, new_status);
1262
1263        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1264            respond_tx.send(option_id).log_err();
1265        } else if cfg!(debug_assertions) {
1266            panic!("tried to authorize an already authorized tool call");
1267        }
1268
1269        cx.emit(AcpThreadEvent::EntryUpdated(ix));
1270    }
1271
1272    /// Returns true if the last turn is awaiting tool authorization
1273    pub fn waiting_for_tool_confirmation(&self) -> bool {
1274        for entry in self.entries.iter().rev() {
1275            match &entry {
1276                AgentThreadEntry::ToolCall(call) => match call.status {
1277                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
1278                    ToolCallStatus::Pending
1279                    | ToolCallStatus::InProgress
1280                    | ToolCallStatus::Completed
1281                    | ToolCallStatus::Failed
1282                    | ToolCallStatus::Rejected
1283                    | ToolCallStatus::Canceled => continue,
1284                },
1285                AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1286                    // Reached the beginning of the turn
1287                    return false;
1288                }
1289            }
1290        }
1291        false
1292    }
1293
1294    pub fn plan(&self) -> &Plan {
1295        &self.plan
1296    }
1297
1298    pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1299        let new_entries_len = request.entries.len();
1300        let mut new_entries = request.entries.into_iter();
1301
1302        // Reuse existing markdown to prevent flickering
1303        for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1304            let PlanEntry {
1305                content,
1306                priority,
1307                status,
1308            } = old;
1309            content.update(cx, |old, cx| {
1310                old.replace(new.content, cx);
1311            });
1312            *priority = new.priority;
1313            *status = new.status;
1314        }
1315        for new in new_entries {
1316            self.plan.entries.push(PlanEntry::from_acp(new, cx))
1317        }
1318        self.plan.entries.truncate(new_entries_len);
1319
1320        cx.notify();
1321    }
1322
1323    fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1324        self.plan
1325            .entries
1326            .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1327        cx.notify();
1328    }
1329
1330    #[cfg(any(test, feature = "test-support"))]
1331    pub fn send_raw(
1332        &mut self,
1333        message: &str,
1334        cx: &mut Context<Self>,
1335    ) -> BoxFuture<'static, Result<()>> {
1336        self.send(
1337            vec![acp::ContentBlock::Text(acp::TextContent {
1338                text: message.to_string(),
1339                annotations: None,
1340            })],
1341            cx,
1342        )
1343    }
1344
    /// Sends `message` as a new user prompt, running it as a turn via `run_turn`.
    ///
    /// Inside the turn (after any in-flight turn has been canceled) the message is appended as
    /// a user entry. When the connection supports truncation, the entry gets a fresh
    /// `UserMessageId`, which `rewind` later uses to locate it. A git checkpoint is then taken
    /// and attached, hidden, to the last user message before the prompt is sent. A failure to
    /// take the checkpoint is logged and the prompt proceeds without one.
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            cx,
        );
        let request = acp::PromptRequest {
            prompt: message.clone(),
            session_id: self.session_id.clone(),
        };
        let git_store = self.project.read(cx).git_store().clone();

        // Message ids are only assigned when the connection can truncate the session, since
        // truncation (rewind) is what consumes them.
        let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
            Some(UserMessageId::new())
        } else {
            None
        };

        self.run_turn(cx, async move |this, cx| {
            // Best-effort: if the thread is gone, the entry is simply not recorded.
            this.update(cx, |this, cx| {
                this.push_entry(
                    AgentThreadEntry::UserMessage(UserMessage {
                        id: message_id.clone(),
                        content: block,
                        chunks: message,
                        checkpoint: None,
                    }),
                    cx,
                );
            })
            .ok();

            // Snapshot the repository before the agent acts; `log_err` turns failure into `None`.
            let old_checkpoint = git_store
                .update(cx, |git, cx| git.checkpoint(cx))?
                .await
                .context("failed to get old checkpoint")
                .log_err();
            this.update(cx, |this, cx| {
                // The checkpoint starts hidden; `update_last_checkpoint` decides whether to show it.
                if let Some((_ix, message)) = this.last_user_message() {
                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
                        git_checkpoint,
                        show: false,
                    });
                }
                this.connection.prompt(message_id, request, cx)
            })?
            .await
        })
    }
1398
1399    pub fn can_resume(&self, cx: &App) -> bool {
1400        self.connection.resume(&self.session_id, cx).is_some()
1401    }
1402
1403    pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1404        self.run_turn(cx, async move |this, cx| {
1405            this.update(cx, |this, cx| {
1406                this.connection
1407                    .resume(&this.session_id, cx)
1408                    .map(|resume| resume.run(cx))
1409            })?
1410            .context("resuming a session is not supported")?
1411            .await
1412        })
1413    }
1414
    /// Starts a new turn driven by `f`, canceling any turn that is still running.
    ///
    /// Completed plan entries are cleared, the previous turn is canceled, and `f` runs (after
    /// that cancellation finishes) in a task stored in `send_task`, where a later `cancel` can
    /// take and await it. The returned future waits for `f`'s result, refreshes the last user
    /// message's checkpoint visibility, clears the agent location, and then:
    /// - if `f` returned an error: clears `send_task`, emits `Error`, and returns the error;
    /// - otherwise (including when the turn's task was dropped before producing a result):
    ///   emits `Stopped`, first removing the last user message and everything after it if the
    ///   agent refused the prompt.
    ///
    /// NOTE(review): an error from `update_last_checkpoint` returns early via `?`, skipping the
    /// agent-location reset and the `Error`/`Stopped` events — confirm this is intended.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<()>> {
        self.clear_completed_plan_entries(cx);

        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            // Don't start the new turn until the previous one has fully wound down.
            cancel_task.await;
            // The receiver may already be gone; the result is then simply discarded.
            tx.send(f(this, cx).await).ok();
        }));

        cx.spawn(async move |this, cx| {
            // `Err` here means the turn's task was dropped before it sent a result.
            let response = rx.await;

            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));
                match response {
                    Ok(Err(e)) => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        Err(e)
                    }
                    result => {
                        let canceled = matches!(
                            result,
                            Ok(Ok(acp::PromptResponse {
                                stop_reason: acp::StopReason::Cancelled
                            }))
                        );

                        // We only take the task if the current prompt wasn't canceled.
                        //
                        // This prompt may have been canceled because another one was sent
                        // while it was still generating. In these cases, dropping `send_task`
                        // would cause the next generation to be canceled.
                        if !canceled {
                            this.send_task.take();
                        }

                        // Truncate entries if the last prompt was refused.
                        if let Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                        })) = result
                            && let Some((ix, _)) = this.last_user_message()
                        {
                            let range = ix..this.entries.len();
                            this.entries.truncate(ix);
                            cx.emit(AcpThreadEvent::EntriesRemoved(range));
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(())
                    }
                }
            })?
        })
        .boxed()
    }
1481
1482    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1483        let Some(send_task) = self.send_task.take() else {
1484            return Task::ready(());
1485        };
1486
1487        for entry in self.entries.iter_mut() {
1488            if let AgentThreadEntry::ToolCall(call) = entry {
1489                let cancel = matches!(
1490                    call.status,
1491                    ToolCallStatus::Pending
1492                        | ToolCallStatus::WaitingForConfirmation { .. }
1493                        | ToolCallStatus::InProgress
1494                );
1495
1496                if cancel {
1497                    call.status = ToolCallStatus::Canceled;
1498                }
1499            }
1500        }
1501
1502        self.connection.cancel(&self.session_id, cx);
1503
1504        // Wait for the send task to complete
1505        cx.foreground_executor().spawn(send_task)
1506    }
1507
    /// Rewinds this thread to just before the user message identified by `id`, removing that
    /// message and all subsequent entries.
    ///
    /// If the message recorded a git checkpoint, the project is restored to it first. The
    /// connection's truncation is then run for `id`, and finally the local entries are
    /// truncated and `EntriesRemoved` is emitted. Fails with "not supported" when the
    /// connection cannot truncate, and with "message not found" when no user message has `id`.
    pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
        let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
            return Task::ready(Err(anyhow!("not supported")));
        };
        let Some(message) = self.user_message(&id) else {
            return Task::ready(Err(anyhow!("message not found")));
        };

        let checkpoint = message
            .checkpoint
            .as_ref()
            .map(|c| c.git_checkpoint.clone());

        let git_store = self.project.read(cx).git_store().clone();
        cx.spawn(async move |this, cx| {
            // If restoring fails, `?` returns before the session or transcript are truncated.
            if let Some(checkpoint) = checkpoint {
                git_store
                    .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
                    .await?;
            }

            cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
            this.update(cx, |this, cx| {
                // Look the message up again: entries may have changed while awaiting.
                if let Some((ix, _)) = this.user_message_mut(&id) {
                    let range = ix..this.entries.len();
                    this.entries.truncate(ix);
                    cx.emit(AcpThreadEvent::EntriesRemoved(range));
                }
            })
        })
    }
1541
1542    fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1543        let git_store = self.project.read(cx).git_store().clone();
1544
1545        let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1546            if let Some(checkpoint) = message.checkpoint.as_ref() {
1547                checkpoint.git_checkpoint.clone()
1548            } else {
1549                return Task::ready(Ok(()));
1550            }
1551        } else {
1552            return Task::ready(Ok(()));
1553        };
1554
1555        let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1556        cx.spawn(async move |this, cx| {
1557            let new_checkpoint = new_checkpoint
1558                .await
1559                .context("failed to get new checkpoint")
1560                .log_err();
1561            if let Some(new_checkpoint) = new_checkpoint {
1562                let equal = git_store
1563                    .update(cx, |git, cx| {
1564                        git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1565                    })?
1566                    .await
1567                    .unwrap_or(true);
1568                this.update(cx, |this, cx| {
1569                    let (ix, message) = this.last_user_message().context("no user message")?;
1570                    let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1571                    checkpoint.show = !equal;
1572                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
1573                    anyhow::Ok(())
1574                })??;
1575            }
1576
1577            Ok(())
1578        })
1579    }
1580
1581    fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1582        self.entries
1583            .iter_mut()
1584            .enumerate()
1585            .rev()
1586            .find_map(|(ix, entry)| {
1587                if let AgentThreadEntry::UserMessage(message) = entry {
1588                    Some((ix, message))
1589                } else {
1590                    None
1591                }
1592            })
1593    }
1594
1595    fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
1596        self.entries.iter().find_map(|entry| {
1597            if let AgentThreadEntry::UserMessage(message) = entry {
1598                if message.id.as_ref() == Some(id) {
1599                    Some(message)
1600                } else {
1601                    None
1602                }
1603            } else {
1604                None
1605            }
1606        })
1607    }
1608
1609    fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1610        self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1611            if let AgentThreadEntry::UserMessage(message) = entry {
1612                if message.id.as_ref() == Some(id) {
1613                    Some((ix, message))
1614                } else {
1615                    None
1616                }
1617            } else {
1618                None
1619            }
1620        })
1621    }
1622
1623    pub fn read_text_file(
1624        &self,
1625        path: PathBuf,
1626        line: Option<u32>,
1627        limit: Option<u32>,
1628        reuse_shared_snapshot: bool,
1629        cx: &mut Context<Self>,
1630    ) -> Task<Result<String>> {
1631        let project = self.project.clone();
1632        let action_log = self.action_log.clone();
1633        cx.spawn(async move |this, cx| {
1634            let load = project.update(cx, |project, cx| {
1635                let path = project
1636                    .project_path_for_absolute_path(&path, cx)
1637                    .context("invalid path")?;
1638                anyhow::Ok(project.open_buffer(path, cx))
1639            });
1640            let buffer = load??.await?;
1641
1642            let snapshot = if reuse_shared_snapshot {
1643                this.read_with(cx, |this, _| {
1644                    this.shared_buffers.get(&buffer.clone()).cloned()
1645                })
1646                .log_err()
1647                .flatten()
1648            } else {
1649                None
1650            };
1651
1652            let snapshot = if let Some(snapshot) = snapshot {
1653                snapshot
1654            } else {
1655                action_log.update(cx, |action_log, cx| {
1656                    action_log.buffer_read(buffer.clone(), cx);
1657                })?;
1658                project.update(cx, |project, cx| {
1659                    let position = buffer
1660                        .read(cx)
1661                        .snapshot()
1662                        .anchor_before(Point::new(line.unwrap_or_default(), 0));
1663                    project.set_agent_location(
1664                        Some(AgentLocation {
1665                            buffer: buffer.downgrade(),
1666                            position,
1667                        }),
1668                        cx,
1669                    );
1670                })?;
1671
1672                buffer.update(cx, |buffer, _| buffer.snapshot())?
1673            };
1674
1675            this.update(cx, |this, _| {
1676                let text = snapshot.text();
1677                this.shared_buffers.insert(buffer.clone(), snapshot);
1678                if line.is_none() && limit.is_none() {
1679                    return Ok(text);
1680                }
1681                let limit = limit.unwrap_or(u32::MAX) as usize;
1682                let Some(line) = line else {
1683                    return Ok(text.lines().take(limit).collect::<String>());
1684                };
1685
1686                let count = text.lines().count();
1687                if count < line as usize {
1688                    anyhow::bail!("There are only {} lines", count);
1689                }
1690                Ok(text
1691                    .lines()
1692                    .skip(line as usize + 1)
1693                    .take(limit)
1694                    .collect::<String>())
1695            })?
1696        })
1697    }
1698
    /// Replaces the contents of the file at `path` with `content` on behalf of the agent, then
    /// saves it.
    ///
    /// The file is opened as a project buffer. `content` is diffed against the snapshot
    /// recorded in `shared_buffers` (as `read_text_file` does), or against the buffer's current
    /// snapshot if none was recorded, and applied as edits. The agent location is moved to the
    /// end of the last edit. The action log is told about the read and the edit, and the buffer
    /// is formatted if format-on-save is not `Off` for its language/file. Finally the buffer is
    /// saved.
    ///
    /// # Errors
    /// Fails if `path` is not in the project, the buffer cannot be opened or saved, or the
    /// thread/project entities are gone. Formatting errors are only logged.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Prefer the snapshot previously shared with the agent as the diff base.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Diff on the background executor, turning the diff's ranges into anchors in
            // `snapshot` so they can be applied to the buffer afterwards.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            // Point the agent at the end of the last edit (or the buffer start if no edits).
            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::MIN),
                    }),
                    cx,
                );
            })?;

            // Record the read, apply the edits, record the edit, and decide whether to format.
            let format_on_save = cx.update(|cx| {
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            })?;

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                })?;
                // Formatting is best-effort: a failure is logged and the save still happens.
                format_task.await.log_err();

                // Report the formatter's changes to the action log as well.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                })?;
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1795
1796    pub fn to_markdown(&self, cx: &App) -> String {
1797        self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1798    }
1799
1800    pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
1801        cx.emit(AcpThreadEvent::LoadError(error));
1802    }
1803}
1804
1805fn markdown_for_raw_output(
1806    raw_output: &serde_json::Value,
1807    language_registry: &Arc<LanguageRegistry>,
1808    cx: &mut App,
1809) -> Option<Entity<Markdown>> {
1810    match raw_output {
1811        serde_json::Value::Null => None,
1812        serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1813            Markdown::new(
1814                value.to_string().into(),
1815                Some(language_registry.clone()),
1816                None,
1817                cx,
1818            )
1819        })),
1820        serde_json::Value::Number(value) => Some(cx.new(|cx| {
1821            Markdown::new(
1822                value.to_string().into(),
1823                Some(language_registry.clone()),
1824                None,
1825                cx,
1826            )
1827        })),
1828        serde_json::Value::String(value) => Some(cx.new(|cx| {
1829            Markdown::new(
1830                value.clone().into(),
1831                Some(language_registry.clone()),
1832                None,
1833                cx,
1834            )
1835        })),
1836        value => Some(cx.new(|cx| {
1837            Markdown::new(
1838                format!("```json\n{}\n```", value).into(),
1839                Some(language_registry.clone()),
1840                None,
1841                cx,
1842            )
1843        })),
1844    }
1845}
1846
1847#[cfg(test)]
1848mod tests {
1849    use super::*;
1850    use anyhow::anyhow;
1851    use futures::{channel::mpsc, future::LocalBoxFuture, select};
1852    use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
1853    use indoc::indoc;
1854    use project::{FakeFs, Fs};
1855    use rand::Rng as _;
1856    use serde_json::json;
1857    use settings::SettingsStore;
1858    use smol::stream::StreamExt as _;
1859    use std::{
1860        any::Any,
1861        cell::RefCell,
1862        path::Path,
1863        rc::Rc,
1864        sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
1865        time::Duration,
1866    };
1867    use util::path;
1868
1869    fn init_test(cx: &mut TestAppContext) {
1870        env_logger::try_init().ok();
1871        cx.update(|cx| {
1872            let settings_store = SettingsStore::test(cx);
1873            cx.set_global(settings_store);
1874            Project::init_settings(cx);
1875            language::init(cx);
1876        });
1877    }
1878
1879    #[gpui::test]
1880    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1881        init_test(cx);
1882
1883        let fs = FakeFs::new(cx.executor());
1884        let project = Project::test(fs, [], cx).await;
1885        let connection = Rc::new(FakeAgentConnection::new());
1886        let thread = cx
1887            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
1888            .await
1889            .unwrap();
1890
1891        // Test creating a new user message
1892        thread.update(cx, |thread, cx| {
1893            thread.push_user_content_block(
1894                None,
1895                acp::ContentBlock::Text(acp::TextContent {
1896                    annotations: None,
1897                    text: "Hello, ".to_string(),
1898                }),
1899                cx,
1900            );
1901        });
1902
1903        thread.update(cx, |thread, cx| {
1904            assert_eq!(thread.entries.len(), 1);
1905            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1906                assert_eq!(user_msg.id, None);
1907                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1908            } else {
1909                panic!("Expected UserMessage");
1910            }
1911        });
1912
1913        // Test appending to existing user message
1914        let message_1_id = UserMessageId::new();
1915        thread.update(cx, |thread, cx| {
1916            thread.push_user_content_block(
1917                Some(message_1_id.clone()),
1918                acp::ContentBlock::Text(acp::TextContent {
1919                    annotations: None,
1920                    text: "world!".to_string(),
1921                }),
1922                cx,
1923            );
1924        });
1925
1926        thread.update(cx, |thread, cx| {
1927            assert_eq!(thread.entries.len(), 1);
1928            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1929                assert_eq!(user_msg.id, Some(message_1_id));
1930                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1931            } else {
1932                panic!("Expected UserMessage");
1933            }
1934        });
1935
1936        // Test creating new user message after assistant message
1937        thread.update(cx, |thread, cx| {
1938            thread.push_assistant_content_block(
1939                acp::ContentBlock::Text(acp::TextContent {
1940                    annotations: None,
1941                    text: "Assistant response".to_string(),
1942                }),
1943                false,
1944                cx,
1945            );
1946        });
1947
1948        let message_2_id = UserMessageId::new();
1949        thread.update(cx, |thread, cx| {
1950            thread.push_user_content_block(
1951                Some(message_2_id.clone()),
1952                acp::ContentBlock::Text(acp::TextContent {
1953                    annotations: None,
1954                    text: "New user message".to_string(),
1955                }),
1956                cx,
1957            );
1958        });
1959
1960        thread.update(cx, |thread, cx| {
1961            assert_eq!(thread.entries.len(), 3);
1962            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1963                assert_eq!(user_msg.id, Some(message_2_id));
1964                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1965            } else {
1966                panic!("Expected UserMessage at index 2");
1967            }
1968        });
1969    }
1970
1971    #[gpui::test]
1972    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
1973        init_test(cx);
1974
1975        let fs = FakeFs::new(cx.executor());
1976        let project = Project::test(fs, [], cx).await;
1977        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1978            |_, thread, mut cx| {
1979                async move {
1980                    thread.update(&mut cx, |thread, cx| {
1981                        thread
1982                            .handle_session_update(
1983                                acp::SessionUpdate::AgentThoughtChunk {
1984                                    content: "Thinking ".into(),
1985                                },
1986                                cx,
1987                            )
1988                            .unwrap();
1989                        thread
1990                            .handle_session_update(
1991                                acp::SessionUpdate::AgentThoughtChunk {
1992                                    content: "hard!".into(),
1993                                },
1994                                cx,
1995                            )
1996                            .unwrap();
1997                    })?;
1998                    Ok(acp::PromptResponse {
1999                        stop_reason: acp::StopReason::EndTurn,
2000                    })
2001                }
2002                .boxed_local()
2003            },
2004        ));
2005
2006        let thread = cx
2007            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2008            .await
2009            .unwrap();
2010
2011        thread
2012            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
2013            .await
2014            .unwrap();
2015
2016        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
2017        assert_eq!(
2018            output,
2019            indoc! {r#"
2020            ## User
2021
2022            Hello from Zed!
2023
2024            ## Assistant
2025
2026            <thinking>
2027            Thinking hard!
2028            </thinking>
2029
2030            "#}
2031        );
2032    }
2033
    /// Simulates the agent rewriting a file while the user edits the same buffer.
    ///
    /// The fake agent reads `/tmp/foo`, signals the test through `read_file_tx`,
    /// then writes back an extended version of the content. Once the read has
    /// happened, the test inserts `zero\n` at the start of the open buffer. Both
    /// the user's insertion and the agent's extension must end up in the buffer
    /// and in the file on disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // One-shot signal from the fake agent to the test: "the file has been read".
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        // Shared so the `Fn` handler can take the sender out (it is used only once).
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    // Let the test make its concurrent edit before the agent writes.
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        // Open the buffer up front so the test holds a handle it can edit directly.
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        // After the agent has read the file, make a user edit ahead of its write.
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        // The agent's extension must be merged with (not overwrite) the user's
        // insertion, both in the buffer and in the saved file.
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
2112
    /// Checks that a tool call marked `Canceled` by `cancel` can still be moved to
    /// `Completed` by a later `ToolCallUpdate` from the agent.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        // Fake agent: reports one in-progress Fetch tool call, then ends the turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        // Wait for the tool call entry; the assertions below expect it at index 1.
        run_until_first_tool_call(&thread, cx).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::InProgress,
                    ..
                })
            ));
        });

        // Canceling the turn marks the in-flight tool call as Canceled.
        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // A late completion update from the agent still applies to the canceled call.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Completed,
                    ..
                })
            ));
        });
    }
2212
2213    #[gpui::test]
2214    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
2215        init_test(cx);
2216        let fs = FakeFs::new(cx.background_executor.clone());
2217        fs.insert_tree(path!("/test"), json!({})).await;
2218        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
2219
2220        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2221            move |_, thread, mut cx| {
2222                async move {
2223                    thread
2224                        .update(&mut cx, |thread, cx| {
2225                            thread.handle_session_update(
2226                                acp::SessionUpdate::ToolCall(acp::ToolCall {
2227                                    id: acp::ToolCallId("test".into()),
2228                                    title: "Label".into(),
2229                                    kind: acp::ToolKind::Edit,
2230                                    status: acp::ToolCallStatus::Completed,
2231                                    content: vec![acp::ToolCallContent::Diff {
2232                                        diff: acp::Diff {
2233                                            path: "/test/test.txt".into(),
2234                                            old_text: None,
2235                                            new_text: "foo".into(),
2236                                        },
2237                                    }],
2238                                    locations: vec![],
2239                                    raw_input: None,
2240                                    raw_output: None,
2241                                }),
2242                                cx,
2243                            )
2244                        })
2245                        .unwrap()
2246                        .unwrap();
2247                    Ok(acp::PromptResponse {
2248                        stop_reason: acp::StopReason::EndTurn,
2249                    })
2250                }
2251                .boxed_local()
2252            }
2253        }));
2254
2255        let thread = cx
2256            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2257            .await
2258            .unwrap();
2259
2260        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
2261            .await
2262            .unwrap();
2263
2264        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
2265    }
2266
2267    #[gpui::test(iterations = 10)]
2268    async fn test_checkpoints(cx: &mut TestAppContext) {
2269        init_test(cx);
2270        let fs = FakeFs::new(cx.background_executor.clone());
2271        fs.insert_tree(
2272            path!("/test"),
2273            json!({
2274                ".git": {}
2275            }),
2276        )
2277        .await;
2278        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
2279
2280        let simulate_changes = Arc::new(AtomicBool::new(true));
2281        let next_filename = Arc::new(AtomicUsize::new(0));
2282        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2283            let simulate_changes = simulate_changes.clone();
2284            let next_filename = next_filename.clone();
2285            let fs = fs.clone();
2286            move |request, thread, mut cx| {
2287                let fs = fs.clone();
2288                let simulate_changes = simulate_changes.clone();
2289                let next_filename = next_filename.clone();
2290                async move {
2291                    if simulate_changes.load(SeqCst) {
2292                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
2293                        fs.write(Path::new(&filename), b"").await?;
2294                    }
2295
2296                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2297                        panic!("expected text content block");
2298                    };
2299                    thread.update(&mut cx, |thread, cx| {
2300                        thread
2301                            .handle_session_update(
2302                                acp::SessionUpdate::AgentMessageChunk {
2303                                    content: content.text.to_uppercase().into(),
2304                                },
2305                                cx,
2306                            )
2307                            .unwrap();
2308                    })?;
2309                    Ok(acp::PromptResponse {
2310                        stop_reason: acp::StopReason::EndTurn,
2311                    })
2312                }
2313                .boxed_local()
2314            }
2315        }));
2316        let thread = cx
2317            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2318            .await
2319            .unwrap();
2320
2321        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
2322            .await
2323            .unwrap();
2324        thread.read_with(cx, |thread, cx| {
2325            assert_eq!(
2326                thread.to_markdown(cx),
2327                indoc! {"
2328                    ## User (checkpoint)
2329
2330                    Lorem
2331
2332                    ## Assistant
2333
2334                    LOREM
2335
2336                "}
2337            );
2338        });
2339        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2340
2341        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
2342            .await
2343            .unwrap();
2344        thread.read_with(cx, |thread, cx| {
2345            assert_eq!(
2346                thread.to_markdown(cx),
2347                indoc! {"
2348                    ## User (checkpoint)
2349
2350                    Lorem
2351
2352                    ## Assistant
2353
2354                    LOREM
2355
2356                    ## User (checkpoint)
2357
2358                    ipsum
2359
2360                    ## Assistant
2361
2362                    IPSUM
2363
2364                "}
2365            );
2366        });
2367        assert_eq!(
2368            fs.files(),
2369            vec![
2370                Path::new(path!("/test/file-0")),
2371                Path::new(path!("/test/file-1"))
2372            ]
2373        );
2374
2375        // Checkpoint isn't stored when there are no changes.
2376        simulate_changes.store(false, SeqCst);
2377        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
2378            .await
2379            .unwrap();
2380        thread.read_with(cx, |thread, cx| {
2381            assert_eq!(
2382                thread.to_markdown(cx),
2383                indoc! {"
2384                    ## User (checkpoint)
2385
2386                    Lorem
2387
2388                    ## Assistant
2389
2390                    LOREM
2391
2392                    ## User (checkpoint)
2393
2394                    ipsum
2395
2396                    ## Assistant
2397
2398                    IPSUM
2399
2400                    ## User
2401
2402                    dolor
2403
2404                    ## Assistant
2405
2406                    DOLOR
2407
2408                "}
2409            );
2410        });
2411        assert_eq!(
2412            fs.files(),
2413            vec![
2414                Path::new(path!("/test/file-0")),
2415                Path::new(path!("/test/file-1"))
2416            ]
2417        );
2418
2419        // Rewinding the conversation truncates the history and restores the checkpoint.
2420        thread
2421            .update(cx, |thread, cx| {
2422                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
2423                    panic!("unexpected entries {:?}", thread.entries)
2424                };
2425                thread.rewind(message.id.clone().unwrap(), cx)
2426            })
2427            .await
2428            .unwrap();
2429        thread.read_with(cx, |thread, cx| {
2430            assert_eq!(
2431                thread.to_markdown(cx),
2432                indoc! {"
2433                    ## User (checkpoint)
2434
2435                    Lorem
2436
2437                    ## Assistant
2438
2439                    LOREM
2440
2441                "}
2442            );
2443        });
2444        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2445    }
2446
    /// Checks that a turn ending with `StopReason::Refusal` is dropped: the thread
    /// is truncated back to its state before the refused user message was sent.
    #[gpui::test]
    async fn test_refusal(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/"), json!({})).await;
        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;

        // When set, the fake agent refuses the next prompt instead of answering it.
        let refuse_next = Arc::new(AtomicBool::new(false));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let refuse_next = refuse_next.clone();
            move |request, thread, mut cx| {
                let refuse_next = refuse_next.clone();
                async move {
                    if refuse_next.load(SeqCst) {
                        return Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                        });
                    }

                    // Otherwise echo the prompt's first text block back in upper case.
                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk {
                                    content: content.text.to_uppercase().into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User

                    hello

                    ## Assistant

                    HELLO

                "}
            );
        });

        // Simulate refusing the second message, ensuring the conversation gets
        // truncated to before sending it.
        refuse_next.store(true, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User

                    hello

                    ## Assistant

                    HELLO

                "}
            );
        });
    }
2532
2533    async fn run_until_first_tool_call(
2534        thread: &Entity<AcpThread>,
2535        cx: &mut TestAppContext,
2536    ) -> usize {
2537        let (mut tx, mut rx) = mpsc::channel::<usize>(1);
2538
2539        let subscription = cx.update(|cx| {
2540            cx.subscribe(thread, move |thread, _, cx| {
2541                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
2542                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
2543                        return tx.try_send(ix).unwrap();
2544                    }
2545                }
2546            })
2547        });
2548
2549        select! {
2550            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
2551                panic!("Timeout waiting for tool call")
2552            }
2553            ix = rx.next().fuse() => {
2554                drop(subscription);
2555                ix.unwrap()
2556            }
2557        }
2558    }
2559
    /// Scriptable in-memory `AgentConnection` used by the tests in this module.
    ///
    /// Threads created through `new_thread` are recorded in `sessions`; tests
    /// install an `on_user_message` handler to script how the "agent" responds.
    #[derive(Clone, Default)]
    struct FakeAgentConnection {
        /// Auth methods reported by `auth_methods` and accepted by `authenticate`.
        auth_methods: Vec<acp::AuthMethod>,
        /// Weak handles to the threads created by `new_thread`, keyed by session id.
        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
        /// Optional test-supplied handler given the prompt request, the target thread
        /// and an async app context; it produces the prompt response.
        on_user_message: Option<
            Rc<
                dyn Fn(
                        acp::PromptRequest,
                        WeakEntity<AcpThread>,
                        AsyncApp,
                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
                    + 'static,
            >,
        >,
    }
2575
2576    impl FakeAgentConnection {
2577        fn new() -> Self {
2578            Self {
2579                auth_methods: Vec::new(),
2580                on_user_message: None,
2581                sessions: Arc::default(),
2582            }
2583        }
2584
2585        #[expect(unused)]
2586        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
2587            self.auth_methods = auth_methods;
2588            self
2589        }
2590
2591        fn on_user_message(
2592            mut self,
2593            handler: impl Fn(
2594                acp::PromptRequest,
2595                WeakEntity<AcpThread>,
2596                AsyncApp,
2597            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2598            + 'static,
2599        ) -> Self {
2600            self.on_user_message.replace(Rc::new(handler));
2601            self
2602        }
2603    }
2604
2605    impl AgentConnection for FakeAgentConnection {
2606        fn auth_methods(&self) -> &[acp::AuthMethod] {
2607            &self.auth_methods
2608        }
2609
2610        fn new_thread(
2611            self: Rc<Self>,
2612            project: Entity<Project>,
2613            _cwd: &Path,
2614            cx: &mut App,
2615        ) -> Task<gpui::Result<Entity<AcpThread>>> {
2616            let session_id = acp::SessionId(
2617                rand::thread_rng()
2618                    .sample_iter(&rand::distributions::Alphanumeric)
2619                    .take(7)
2620                    .map(char::from)
2621                    .collect::<String>()
2622                    .into(),
2623            );
2624            let action_log = cx.new(|_| ActionLog::new(project.clone()));
2625            let thread = cx.new(|cx| {
2626                AcpThread::new(
2627                    "Test",
2628                    self.clone(),
2629                    project,
2630                    action_log,
2631                    session_id.clone(),
2632                    watch::Receiver::constant(acp::PromptCapabilities {
2633                        image: true,
2634                        audio: true,
2635                        embedded_context: true,
2636                    }),
2637                    cx,
2638                )
2639            });
2640            self.sessions.lock().insert(session_id, thread.downgrade());
2641            Task::ready(Ok(thread))
2642        }
2643
2644        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
2645            if self.auth_methods().iter().any(|m| m.id == method) {
2646                Task::ready(Ok(()))
2647            } else {
2648                Task::ready(Err(anyhow!("Invalid Auth Method")))
2649            }
2650        }
2651
2652        fn prompt(
2653            &self,
2654            _id: Option<UserMessageId>,
2655            params: acp::PromptRequest,
2656            cx: &mut App,
2657        ) -> Task<gpui::Result<acp::PromptResponse>> {
2658            let sessions = self.sessions.lock();
2659            let thread = sessions.get(&params.session_id).unwrap();
2660            if let Some(handler) = &self.on_user_message {
2661                let handler = handler.clone();
2662                let thread = thread.clone();
2663                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
2664            } else {
2665                Task::ready(Ok(acp::PromptResponse {
2666                    stop_reason: acp::StopReason::EndTurn,
2667                }))
2668            }
2669        }
2670
2671        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
2672            let sessions = self.sessions.lock();
2673            let thread = sessions.get(session_id).unwrap().clone();
2674
2675            cx.spawn(async move |cx| {
2676                thread
2677                    .update(cx, |thread, cx| thread.cancel(cx))
2678                    .unwrap()
2679                    .await
2680            })
2681            .detach();
2682        }
2683
2684        fn truncate(
2685            &self,
2686            session_id: &acp::SessionId,
2687            _cx: &App,
2688        ) -> Option<Rc<dyn AgentSessionTruncate>> {
2689            Some(Rc::new(FakeAgentSessionEditor {
2690                _session_id: session_id.clone(),
2691            }))
2692        }
2693
2694        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
2695            self
2696        }
2697    }
2698
    /// Truncation handle handed out by the fake connection's `truncate`.
    ///
    /// It only records the session it was created for.
    struct FakeAgentSessionEditor {
        // Session this handle belongs to; currently unused (hence the leading underscore).
        _session_id: acp::SessionId,
    }
2702
2703    impl AgentSessionTruncate for FakeAgentSessionEditor {
2704        fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
2705            Task::ready(Ok(()))
2706        }
2707    }
2708}