acp_thread.rs

   1mod connection;
   2mod diff;
   3mod mention;
   4mod terminal;
   5
   6use collections::HashSet;
   7pub use connection::*;
   8pub use diff::*;
   9use language::language_settings::FormatOnSave;
  10pub use mention::*;
  11use project::lsp_store::{FormatTrigger, LspFormatTarget};
  12use serde::{Deserialize, Serialize};
  13pub use terminal::*;
  14
  15use action_log::ActionLog;
  16use agent_client_protocol as acp;
  17use anyhow::{Context as _, Result, anyhow};
  18use editor::Bias;
  19use futures::{FutureExt, channel::oneshot, future::BoxFuture};
  20use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  21use itertools::Itertools;
  22use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
  23use markdown::Markdown;
  24use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
  25use std::collections::HashMap;
  26use std::error::Error;
  27use std::fmt::{Formatter, Write};
  28use std::ops::Range;
  29use std::process::ExitStatus;
  30use std::rc::Rc;
  31use std::time::{Duration, Instant};
  32use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
  33use ui::App;
  34use util::ResultExt;
  35
  36#[derive(Debug)]
  37pub struct UserMessage {
  38    pub id: Option<UserMessageId>,
  39    pub content: ContentBlock,
  40    pub chunks: Vec<acp::ContentBlock>,
  41    pub checkpoint: Option<Checkpoint>,
  42}
  43
  44#[derive(Debug)]
  45pub struct Checkpoint {
  46    git_checkpoint: GitStoreCheckpoint,
  47    pub show: bool,
  48}
  49
  50impl UserMessage {
  51    fn to_markdown(&self, cx: &App) -> String {
  52        let mut markdown = String::new();
  53        if self
  54            .checkpoint
  55            .as_ref()
  56            .is_some_and(|checkpoint| checkpoint.show)
  57        {
  58            writeln!(markdown, "## User (checkpoint)").unwrap();
  59        } else {
  60            writeln!(markdown, "## User").unwrap();
  61        }
  62        writeln!(markdown).unwrap();
  63        writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
  64        writeln!(markdown).unwrap();
  65        markdown
  66    }
  67}
  68
  69#[derive(Debug, PartialEq)]
  70pub struct AssistantMessage {
  71    pub chunks: Vec<AssistantMessageChunk>,
  72}
  73
  74impl AssistantMessage {
  75    pub fn to_markdown(&self, cx: &App) -> String {
  76        format!(
  77            "## Assistant\n\n{}\n\n",
  78            self.chunks
  79                .iter()
  80                .map(|chunk| chunk.to_markdown(cx))
  81                .join("\n\n")
  82        )
  83    }
  84}
  85
  86#[derive(Debug, PartialEq)]
  87pub enum AssistantMessageChunk {
  88    Message { block: ContentBlock },
  89    Thought { block: ContentBlock },
  90}
  91
  92impl AssistantMessageChunk {
  93    pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
  94        Self::Message {
  95            block: ContentBlock::new(chunk.into(), language_registry, cx),
  96        }
  97    }
  98
  99    fn to_markdown(&self, cx: &App) -> String {
 100        match self {
 101            Self::Message { block } => block.to_markdown(cx).to_string(),
 102            Self::Thought { block } => {
 103                format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
 104            }
 105        }
 106    }
 107}
 108
 109#[derive(Debug)]
 110pub enum AgentThreadEntry {
 111    UserMessage(UserMessage),
 112    AssistantMessage(AssistantMessage),
 113    ToolCall(ToolCall),
 114}
 115
 116impl AgentThreadEntry {
 117    pub fn to_markdown(&self, cx: &App) -> String {
 118        match self {
 119            Self::UserMessage(message) => message.to_markdown(cx),
 120            Self::AssistantMessage(message) => message.to_markdown(cx),
 121            Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
 122        }
 123    }
 124
 125    pub fn user_message(&self) -> Option<&UserMessage> {
 126        if let AgentThreadEntry::UserMessage(message) = self {
 127            Some(message)
 128        } else {
 129            None
 130        }
 131    }
 132
 133    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 134        if let AgentThreadEntry::ToolCall(call) = self {
 135            itertools::Either::Left(call.diffs())
 136        } else {
 137            itertools::Either::Right(std::iter::empty())
 138        }
 139    }
 140
 141    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 142        if let AgentThreadEntry::ToolCall(call) = self {
 143            itertools::Either::Left(call.terminals())
 144        } else {
 145            itertools::Either::Right(std::iter::empty())
 146        }
 147    }
 148
 149    pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
 150        if let AgentThreadEntry::ToolCall(ToolCall {
 151            locations,
 152            resolved_locations,
 153            ..
 154        }) = self
 155        {
 156            Some((
 157                locations.get(ix)?.clone(),
 158                resolved_locations.get(ix)?.clone()?,
 159            ))
 160        } else {
 161            None
 162        }
 163    }
 164}
 165
 166#[derive(Debug)]
 167pub struct ToolCall {
 168    pub id: acp::ToolCallId,
 169    pub label: Entity<Markdown>,
 170    pub kind: acp::ToolKind,
 171    pub content: Vec<ToolCallContent>,
 172    pub status: ToolCallStatus,
 173    pub locations: Vec<acp::ToolCallLocation>,
 174    pub resolved_locations: Vec<Option<AgentLocation>>,
 175    pub raw_input: Option<serde_json::Value>,
 176    pub raw_output: Option<serde_json::Value>,
 177}
 178
 179impl ToolCall {
 180    fn from_acp(
 181        tool_call: acp::ToolCall,
 182        status: ToolCallStatus,
 183        language_registry: Arc<LanguageRegistry>,
 184        cx: &mut App,
 185    ) -> Self {
 186        let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
 187            first_line.to_owned() + ""
 188        } else {
 189            tool_call.title
 190        };
 191        Self {
 192            id: tool_call.id,
 193            label: cx
 194                .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
 195            kind: tool_call.kind,
 196            content: tool_call
 197                .content
 198                .into_iter()
 199                .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
 200                .collect(),
 201            locations: tool_call.locations,
 202            resolved_locations: Vec::default(),
 203            status,
 204            raw_input: tool_call.raw_input,
 205            raw_output: tool_call.raw_output,
 206        }
 207    }
 208
 209    fn update_fields(
 210        &mut self,
 211        fields: acp::ToolCallUpdateFields,
 212        language_registry: Arc<LanguageRegistry>,
 213        cx: &mut App,
 214    ) {
 215        let acp::ToolCallUpdateFields {
 216            kind,
 217            status,
 218            title,
 219            content,
 220            locations,
 221            raw_input,
 222            raw_output,
 223        } = fields;
 224
 225        if let Some(kind) = kind {
 226            self.kind = kind;
 227        }
 228
 229        if let Some(status) = status {
 230            self.status = status.into();
 231        }
 232
 233        if let Some(title) = title {
 234            self.label.update(cx, |label, cx| {
 235                if let Some((first_line, _)) = title.split_once("\n") {
 236                    label.replace(first_line.to_owned() + "", cx)
 237                } else {
 238                    label.replace(title, cx);
 239                }
 240            });
 241        }
 242
 243        if let Some(content) = content {
 244            let new_content_len = content.len();
 245            let mut content = content.into_iter();
 246
 247            // Reuse existing content if we can
 248            for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
 249                old.update_from_acp(new, language_registry.clone(), cx);
 250            }
 251            for new in content {
 252                self.content.push(ToolCallContent::from_acp(
 253                    new,
 254                    language_registry.clone(),
 255                    cx,
 256                ))
 257            }
 258            self.content.truncate(new_content_len);
 259        }
 260
 261        if let Some(locations) = locations {
 262            self.locations = locations;
 263        }
 264
 265        if let Some(raw_input) = raw_input {
 266            self.raw_input = Some(raw_input);
 267        }
 268
 269        if let Some(raw_output) = raw_output {
 270            if self.content.is_empty()
 271                && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
 272            {
 273                self.content
 274                    .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
 275                        markdown,
 276                    }));
 277            }
 278            self.raw_output = Some(raw_output);
 279        }
 280    }
 281
 282    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 283        self.content.iter().filter_map(|content| match content {
 284            ToolCallContent::Diff(diff) => Some(diff),
 285            ToolCallContent::ContentBlock(_) => None,
 286            ToolCallContent::Terminal(_) => None,
 287        })
 288    }
 289
 290    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 291        self.content.iter().filter_map(|content| match content {
 292            ToolCallContent::Terminal(terminal) => Some(terminal),
 293            ToolCallContent::ContentBlock(_) => None,
 294            ToolCallContent::Diff(_) => None,
 295        })
 296    }
 297
 298    fn to_markdown(&self, cx: &App) -> String {
 299        let mut markdown = format!(
 300            "**Tool Call: {}**\nStatus: {}\n\n",
 301            self.label.read(cx).source(),
 302            self.status
 303        );
 304        for content in &self.content {
 305            markdown.push_str(content.to_markdown(cx).as_str());
 306            markdown.push_str("\n\n");
 307        }
 308        markdown
 309    }
 310
 311    async fn resolve_location(
 312        location: acp::ToolCallLocation,
 313        project: WeakEntity<Project>,
 314        cx: &mut AsyncApp,
 315    ) -> Option<AgentLocation> {
 316        let buffer = project
 317            .update(cx, |project, cx| {
 318                project
 319                    .project_path_for_absolute_path(&location.path, cx)
 320                    .map(|path| project.open_buffer(path, cx))
 321            })
 322            .ok()??;
 323        let buffer = buffer.await.log_err()?;
 324        let position = buffer
 325            .update(cx, |buffer, _| {
 326                if let Some(row) = location.line {
 327                    let snapshot = buffer.snapshot();
 328                    let column = snapshot.indent_size_for_line(row).len;
 329                    let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
 330                    snapshot.anchor_before(point)
 331                } else {
 332                    Anchor::MIN
 333                }
 334            })
 335            .ok()?;
 336
 337        Some(AgentLocation {
 338            buffer: buffer.downgrade(),
 339            position,
 340        })
 341    }
 342
 343    fn resolve_locations(
 344        &self,
 345        project: Entity<Project>,
 346        cx: &mut App,
 347    ) -> Task<Vec<Option<AgentLocation>>> {
 348        let locations = self.locations.clone();
 349        project.update(cx, |_, cx| {
 350            cx.spawn(async move |project, cx| {
 351                let mut new_locations = Vec::new();
 352                for location in locations {
 353                    new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
 354                }
 355                new_locations
 356            })
 357        })
 358    }
 359}
 360
 361#[derive(Debug)]
 362pub enum ToolCallStatus {
 363    /// The tool call hasn't started running yet, but we start showing it to
 364    /// the user.
 365    Pending,
 366    /// The tool call is waiting for confirmation from the user.
 367    WaitingForConfirmation {
 368        options: Vec<acp::PermissionOption>,
 369        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
 370    },
 371    /// The tool call is currently running.
 372    InProgress,
 373    /// The tool call completed successfully.
 374    Completed,
 375    /// The tool call failed.
 376    Failed,
 377    /// The user rejected the tool call.
 378    Rejected,
 379    /// The user canceled generation so the tool call was canceled.
 380    Canceled,
 381}
 382
 383impl From<acp::ToolCallStatus> for ToolCallStatus {
 384    fn from(status: acp::ToolCallStatus) -> Self {
 385        match status {
 386            acp::ToolCallStatus::Pending => Self::Pending,
 387            acp::ToolCallStatus::InProgress => Self::InProgress,
 388            acp::ToolCallStatus::Completed => Self::Completed,
 389            acp::ToolCallStatus::Failed => Self::Failed,
 390        }
 391    }
 392}
 393
 394impl Display for ToolCallStatus {
 395    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 396        write!(
 397            f,
 398            "{}",
 399            match self {
 400                ToolCallStatus::Pending => "Pending",
 401                ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
 402                ToolCallStatus::InProgress => "In Progress",
 403                ToolCallStatus::Completed => "Completed",
 404                ToolCallStatus::Failed => "Failed",
 405                ToolCallStatus::Rejected => "Rejected",
 406                ToolCallStatus::Canceled => "Canceled",
 407            }
 408        )
 409    }
 410}
 411
 412#[derive(Debug, PartialEq, Clone)]
 413pub enum ContentBlock {
 414    Empty,
 415    Markdown { markdown: Entity<Markdown> },
 416    ResourceLink { resource_link: acp::ResourceLink },
 417}
 418
 419impl ContentBlock {
 420    pub fn new(
 421        block: acp::ContentBlock,
 422        language_registry: &Arc<LanguageRegistry>,
 423        cx: &mut App,
 424    ) -> Self {
 425        let mut this = Self::Empty;
 426        this.append(block, language_registry, cx);
 427        this
 428    }
 429
 430    pub fn new_combined(
 431        blocks: impl IntoIterator<Item = acp::ContentBlock>,
 432        language_registry: Arc<LanguageRegistry>,
 433        cx: &mut App,
 434    ) -> Self {
 435        let mut this = Self::Empty;
 436        for block in blocks {
 437            this.append(block, &language_registry, cx);
 438        }
 439        this
 440    }
 441
 442    pub fn append(
 443        &mut self,
 444        block: acp::ContentBlock,
 445        language_registry: &Arc<LanguageRegistry>,
 446        cx: &mut App,
 447    ) {
 448        if matches!(self, ContentBlock::Empty)
 449            && let acp::ContentBlock::ResourceLink(resource_link) = block
 450        {
 451            *self = ContentBlock::ResourceLink { resource_link };
 452            return;
 453        }
 454
 455        let new_content = self.block_string_contents(block);
 456
 457        match self {
 458            ContentBlock::Empty => {
 459                *self = Self::create_markdown_block(new_content, language_registry, cx);
 460            }
 461            ContentBlock::Markdown { markdown } => {
 462                markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
 463            }
 464            ContentBlock::ResourceLink { resource_link } => {
 465                let existing_content = Self::resource_link_md(&resource_link.uri);
 466                let combined = format!("{}\n{}", existing_content, new_content);
 467
 468                *self = Self::create_markdown_block(combined, language_registry, cx);
 469            }
 470        }
 471    }
 472
 473    fn create_markdown_block(
 474        content: String,
 475        language_registry: &Arc<LanguageRegistry>,
 476        cx: &mut App,
 477    ) -> ContentBlock {
 478        ContentBlock::Markdown {
 479            markdown: cx
 480                .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
 481        }
 482    }
 483
 484    fn block_string_contents(&self, block: acp::ContentBlock) -> String {
 485        match block {
 486            acp::ContentBlock::Text(text_content) => text_content.text,
 487            acp::ContentBlock::ResourceLink(resource_link) => {
 488                Self::resource_link_md(&resource_link.uri)
 489            }
 490            acp::ContentBlock::Resource(acp::EmbeddedResource {
 491                resource:
 492                    acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
 493                        uri,
 494                        ..
 495                    }),
 496                ..
 497            }) => Self::resource_link_md(&uri),
 498            acp::ContentBlock::Image(image) => Self::image_md(&image),
 499            acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
 500        }
 501    }
 502
 503    fn resource_link_md(uri: &str) -> String {
 504        if let Some(uri) = MentionUri::parse(uri).log_err() {
 505            uri.as_link().to_string()
 506        } else {
 507            uri.to_string()
 508        }
 509    }
 510
 511    fn image_md(_image: &acp::ImageContent) -> String {
 512        "`Image`".into()
 513    }
 514
 515    pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
 516        match self {
 517            ContentBlock::Empty => "",
 518            ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
 519            ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
 520        }
 521    }
 522
 523    pub fn markdown(&self) -> Option<&Entity<Markdown>> {
 524        match self {
 525            ContentBlock::Empty => None,
 526            ContentBlock::Markdown { markdown } => Some(markdown),
 527            ContentBlock::ResourceLink { .. } => None,
 528        }
 529    }
 530
 531    pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
 532        match self {
 533            ContentBlock::ResourceLink { resource_link } => Some(resource_link),
 534            _ => None,
 535        }
 536    }
 537}
 538
 539#[derive(Debug)]
 540pub enum ToolCallContent {
 541    ContentBlock(ContentBlock),
 542    Diff(Entity<Diff>),
 543    Terminal(Entity<Terminal>),
 544}
 545
 546impl ToolCallContent {
 547    pub fn from_acp(
 548        content: acp::ToolCallContent,
 549        language_registry: Arc<LanguageRegistry>,
 550        cx: &mut App,
 551    ) -> Self {
 552        match content {
 553            acp::ToolCallContent::Content { content } => {
 554                Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
 555            }
 556            acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
 557                Diff::finalized(
 558                    diff.path,
 559                    diff.old_text,
 560                    diff.new_text,
 561                    language_registry,
 562                    cx,
 563                )
 564            })),
 565        }
 566    }
 567
 568    pub fn update_from_acp(
 569        &mut self,
 570        new: acp::ToolCallContent,
 571        language_registry: Arc<LanguageRegistry>,
 572        cx: &mut App,
 573    ) {
 574        let needs_update = match (&self, &new) {
 575            (Self::Diff(old_diff), acp::ToolCallContent::Diff { diff: new_diff }) => {
 576                old_diff.read(cx).needs_update(
 577                    new_diff.old_text.as_deref().unwrap_or(""),
 578                    &new_diff.new_text,
 579                    cx,
 580                )
 581            }
 582            _ => true,
 583        };
 584
 585        if needs_update {
 586            *self = Self::from_acp(new, language_registry, cx);
 587        }
 588    }
 589
 590    pub fn to_markdown(&self, cx: &App) -> String {
 591        match self {
 592            Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
 593            Self::Diff(diff) => diff.read(cx).to_markdown(cx),
 594            Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
 595        }
 596    }
 597}
 598
 599#[derive(Debug, PartialEq)]
 600pub enum ToolCallUpdate {
 601    UpdateFields(acp::ToolCallUpdate),
 602    UpdateDiff(ToolCallUpdateDiff),
 603    UpdateTerminal(ToolCallUpdateTerminal),
 604}
 605
 606impl ToolCallUpdate {
 607    fn id(&self) -> &acp::ToolCallId {
 608        match self {
 609            Self::UpdateFields(update) => &update.id,
 610            Self::UpdateDiff(diff) => &diff.id,
 611            Self::UpdateTerminal(terminal) => &terminal.id,
 612        }
 613    }
 614}
 615
 616impl From<acp::ToolCallUpdate> for ToolCallUpdate {
 617    fn from(update: acp::ToolCallUpdate) -> Self {
 618        Self::UpdateFields(update)
 619    }
 620}
 621
 622impl From<ToolCallUpdateDiff> for ToolCallUpdate {
 623    fn from(diff: ToolCallUpdateDiff) -> Self {
 624        Self::UpdateDiff(diff)
 625    }
 626}
 627
 628#[derive(Debug, PartialEq)]
 629pub struct ToolCallUpdateDiff {
 630    pub id: acp::ToolCallId,
 631    pub diff: Entity<Diff>,
 632}
 633
 634impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
 635    fn from(terminal: ToolCallUpdateTerminal) -> Self {
 636        Self::UpdateTerminal(terminal)
 637    }
 638}
 639
 640#[derive(Debug, PartialEq)]
 641pub struct ToolCallUpdateTerminal {
 642    pub id: acp::ToolCallId,
 643    pub terminal: Entity<Terminal>,
 644}
 645
 646#[derive(Debug, Default)]
 647pub struct Plan {
 648    pub entries: Vec<PlanEntry>,
 649}
 650
 651#[derive(Debug)]
 652pub struct PlanStats<'a> {
 653    pub in_progress_entry: Option<&'a PlanEntry>,
 654    pub pending: u32,
 655    pub completed: u32,
 656}
 657
 658impl Plan {
 659    pub fn is_empty(&self) -> bool {
 660        self.entries.is_empty()
 661    }
 662
 663    pub fn stats(&self) -> PlanStats<'_> {
 664        let mut stats = PlanStats {
 665            in_progress_entry: None,
 666            pending: 0,
 667            completed: 0,
 668        };
 669
 670        for entry in &self.entries {
 671            match &entry.status {
 672                acp::PlanEntryStatus::Pending => {
 673                    stats.pending += 1;
 674                }
 675                acp::PlanEntryStatus::InProgress => {
 676                    stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
 677                }
 678                acp::PlanEntryStatus::Completed => {
 679                    stats.completed += 1;
 680                }
 681            }
 682        }
 683
 684        stats
 685    }
 686}
 687
 688#[derive(Debug)]
 689pub struct PlanEntry {
 690    pub content: Entity<Markdown>,
 691    pub priority: acp::PlanEntryPriority,
 692    pub status: acp::PlanEntryStatus,
 693}
 694
 695impl PlanEntry {
 696    pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
 697        Self {
 698            content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
 699            priority: entry.priority,
 700            status: entry.status,
 701        }
 702    }
 703}
 704
 705#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 706pub struct TokenUsage {
 707    pub max_tokens: u64,
 708    pub used_tokens: u64,
 709}
 710
 711impl TokenUsage {
 712    pub fn ratio(&self) -> TokenUsageRatio {
 713        #[cfg(debug_assertions)]
 714        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 715            .unwrap_or("0.8".to_string())
 716            .parse()
 717            .unwrap();
 718        #[cfg(not(debug_assertions))]
 719        let warning_threshold: f32 = 0.8;
 720
 721        // When the maximum is unknown because there is no selected model,
 722        // avoid showing the token limit warning.
 723        if self.max_tokens == 0 {
 724            TokenUsageRatio::Normal
 725        } else if self.used_tokens >= self.max_tokens {
 726            TokenUsageRatio::Exceeded
 727        } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
 728            TokenUsageRatio::Warning
 729        } else {
 730            TokenUsageRatio::Normal
 731        }
 732    }
 733}
 734
 735#[derive(Debug, Clone, PartialEq, Eq)]
 736pub enum TokenUsageRatio {
 737    Normal,
 738    Warning,
 739    Exceeded,
 740}
 741
/// Retry progress, reported through [`AcpThreadEvent::Retry`].
#[derive(Debug, Clone)]
pub struct RetryStatus {
    // NOTE(review): presumably the error that triggered this retry — confirm at the producer.
    pub last_error: SharedString,
    pub attempt: usize,
    pub max_attempts: usize,
    pub started_at: Instant,
    // NOTE(review): unclear from this file whether this is the backoff delay or elapsed time.
    pub duration: Duration,
}
 750
/// A single agent session.
///
/// Holds the transcript entries, the plan, and the connection used to reach the agent.
pub struct AcpThread {
    title: SharedString,
    entries: Vec<AgentThreadEntry>,
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// In-flight send task; `status()` reports the thread as idle when this is `None`.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
    token_usage: Option<TokenUsage>,
    /// Latest prompt capabilities, refreshed by `_observe_prompt_capabilities`.
    prompt_capabilities: acp::PromptCapabilities,
    /// Task that copies capability updates from the watch channel into
    /// `prompt_capabilities`. It is stored on the thread, presumably so that it is not
    /// cancelled by being dropped.
    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
}
 765
/// Notifications emitted by an [`AcpThread`].
#[derive(Debug)]
pub enum AcpThreadEvent {
    NewEntry,
    TitleUpdated,
    TokenUsageUpdated,
    /// The entry at this index was modified in place.
    EntryUpdated(usize),
    EntriesRemoved(Range<usize>),
    ToolAuthorizationRequired,
    Retry(RetryStatus),
    Stopped,
    Error,
    LoadError(LoadError),
    /// Emitted after `prompt_capabilities` is replaced with a newer value.
    PromptCapabilitiesUpdated,
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}
 782
/// Coarse state of an [`AcpThread`], as returned by `AcpThread::status`.
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    /// No send task is running.
    Idle,
    /// A send task is running and `waiting_for_tool_confirmation()` reports `true`.
    WaitingForToolConfirmation,
    /// A send task is running and no tool confirmation is pending.
    Generating,
}
 789
 790#[derive(Debug, Clone)]
 791pub enum LoadError {
 792    Unsupported {
 793        command: SharedString,
 794        current_version: SharedString,
 795        minimum_version: SharedString,
 796    },
 797    FailedToInstall(SharedString),
 798    Exited {
 799        status: ExitStatus,
 800    },
 801    Other(SharedString),
 802}
 803
 804impl Display for LoadError {
 805    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 806        match self {
 807            LoadError::Unsupported {
 808                command: path,
 809                current_version,
 810                minimum_version,
 811            } => {
 812                write!(
 813                    f,
 814                    "version {current_version} from {path} is not supported (need at least {minimum_version})"
 815                )
 816            }
 817            LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
 818            LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
 819            LoadError::Other(msg) => write!(f, "{msg}"),
 820        }
 821    }
 822}
 823
 824impl Error for LoadError {}
 825
 826impl AcpThread {
 827    pub fn new(
 828        title: impl Into<SharedString>,
 829        connection: Rc<dyn AgentConnection>,
 830        project: Entity<Project>,
 831        action_log: Entity<ActionLog>,
 832        session_id: acp::SessionId,
 833        mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
 834        cx: &mut Context<Self>,
 835    ) -> Self {
 836        let prompt_capabilities = *prompt_capabilities_rx.borrow();
 837        let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
 838            loop {
 839                let caps = prompt_capabilities_rx.recv().await?;
 840                this.update(cx, |this, cx| {
 841                    this.prompt_capabilities = caps;
 842                    cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
 843                })?;
 844            }
 845        });
 846
 847        Self {
 848            action_log,
 849            shared_buffers: Default::default(),
 850            entries: Default::default(),
 851            plan: Default::default(),
 852            title: title.into(),
 853            project,
 854            send_task: None,
 855            connection,
 856            session_id,
 857            token_usage: None,
 858            prompt_capabilities,
 859            _observe_prompt_capabilities: task,
 860        }
 861    }
 862
 863    pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
 864        self.prompt_capabilities
 865    }
 866
 867    pub fn connection(&self) -> &Rc<dyn AgentConnection> {
 868        &self.connection
 869    }
 870
 871    pub fn action_log(&self) -> &Entity<ActionLog> {
 872        &self.action_log
 873    }
 874
 875    pub fn project(&self) -> &Entity<Project> {
 876        &self.project
 877    }
 878
 879    pub fn title(&self) -> SharedString {
 880        self.title.clone()
 881    }
 882
 883    pub fn entries(&self) -> &[AgentThreadEntry] {
 884        &self.entries
 885    }
 886
 887    pub fn session_id(&self) -> &acp::SessionId {
 888        &self.session_id
 889    }
 890
 891    pub fn status(&self) -> ThreadStatus {
 892        if self.send_task.is_some() {
 893            if self.waiting_for_tool_confirmation() {
 894                ThreadStatus::WaitingForToolConfirmation
 895            } else {
 896                ThreadStatus::Generating
 897            }
 898        } else {
 899            ThreadStatus::Idle
 900        }
 901    }
 902
 903    pub fn token_usage(&self) -> Option<&TokenUsage> {
 904        self.token_usage.as_ref()
 905    }
 906
 907    pub fn has_pending_edit_tool_calls(&self) -> bool {
 908        for entry in self.entries.iter().rev() {
 909            match entry {
 910                AgentThreadEntry::UserMessage(_) => return false,
 911                AgentThreadEntry::ToolCall(
 912                    call @ ToolCall {
 913                        status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
 914                        ..
 915                    },
 916                ) if call.diffs().next().is_some() => {
 917                    return true;
 918                }
 919                AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
 920            }
 921        }
 922
 923        false
 924    }
 925
 926    pub fn used_tools_since_last_user_message(&self) -> bool {
 927        for entry in self.entries.iter().rev() {
 928            match entry {
 929                AgentThreadEntry::UserMessage(..) => return false,
 930                AgentThreadEntry::AssistantMessage(..) => continue,
 931                AgentThreadEntry::ToolCall(..) => return true,
 932            }
 933        }
 934
 935        false
 936    }
 937
 938    pub fn handle_session_update(
 939        &mut self,
 940        update: acp::SessionUpdate,
 941        cx: &mut Context<Self>,
 942    ) -> Result<(), acp::Error> {
 943        match update {
 944            acp::SessionUpdate::UserMessageChunk { content } => {
 945                self.push_user_content_block(None, content, cx);
 946            }
 947            acp::SessionUpdate::AgentMessageChunk { content } => {
 948                self.push_assistant_content_block(content, false, cx);
 949            }
 950            acp::SessionUpdate::AgentThoughtChunk { content } => {
 951                self.push_assistant_content_block(content, true, cx);
 952            }
 953            acp::SessionUpdate::ToolCall(tool_call) => {
 954                self.upsert_tool_call(tool_call, cx)?;
 955            }
 956            acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
 957                self.update_tool_call(tool_call_update, cx)?;
 958            }
 959            acp::SessionUpdate::Plan(plan) => {
 960                self.update_plan(plan, cx);
 961            }
 962        }
 963        Ok(())
 964    }
 965
 966    pub fn push_user_content_block(
 967        &mut self,
 968        message_id: Option<UserMessageId>,
 969        chunk: acp::ContentBlock,
 970        cx: &mut Context<Self>,
 971    ) {
 972        let language_registry = self.project.read(cx).languages().clone();
 973        let entries_len = self.entries.len();
 974
 975        if let Some(last_entry) = self.entries.last_mut()
 976            && let AgentThreadEntry::UserMessage(UserMessage {
 977                id,
 978                content,
 979                chunks,
 980                ..
 981            }) = last_entry
 982        {
 983            *id = message_id.or(id.take());
 984            content.append(chunk.clone(), &language_registry, cx);
 985            chunks.push(chunk);
 986            let idx = entries_len - 1;
 987            cx.emit(AcpThreadEvent::EntryUpdated(idx));
 988        } else {
 989            let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
 990            self.push_entry(
 991                AgentThreadEntry::UserMessage(UserMessage {
 992                    id: message_id,
 993                    content,
 994                    chunks: vec![chunk],
 995                    checkpoint: None,
 996                }),
 997                cx,
 998            );
 999        }
1000    }
1001
1002    pub fn push_assistant_content_block(
1003        &mut self,
1004        chunk: acp::ContentBlock,
1005        is_thought: bool,
1006        cx: &mut Context<Self>,
1007    ) {
1008        let language_registry = self.project.read(cx).languages().clone();
1009        let entries_len = self.entries.len();
1010        if let Some(last_entry) = self.entries.last_mut()
1011            && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
1012        {
1013            let idx = entries_len - 1;
1014            cx.emit(AcpThreadEvent::EntryUpdated(idx));
1015            match (chunks.last_mut(), is_thought) {
1016                (Some(AssistantMessageChunk::Message { block }), false)
1017                | (Some(AssistantMessageChunk::Thought { block }), true) => {
1018                    block.append(chunk, &language_registry, cx)
1019                }
1020                _ => {
1021                    let block = ContentBlock::new(chunk, &language_registry, cx);
1022                    if is_thought {
1023                        chunks.push(AssistantMessageChunk::Thought { block })
1024                    } else {
1025                        chunks.push(AssistantMessageChunk::Message { block })
1026                    }
1027                }
1028            }
1029        } else {
1030            let block = ContentBlock::new(chunk, &language_registry, cx);
1031            let chunk = if is_thought {
1032                AssistantMessageChunk::Thought { block }
1033            } else {
1034                AssistantMessageChunk::Message { block }
1035            };
1036
1037            self.push_entry(
1038                AgentThreadEntry::AssistantMessage(AssistantMessage {
1039                    chunks: vec![chunk],
1040                }),
1041                cx,
1042            );
1043        }
1044    }
1045
1046    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1047        self.entries.push(entry);
1048        cx.emit(AcpThreadEvent::NewEntry);
1049    }
1050
1051    pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1052        self.connection.set_title(&self.session_id, cx).is_some()
1053    }
1054
1055    pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1056        if title != self.title {
1057            self.title = title.clone();
1058            cx.emit(AcpThreadEvent::TitleUpdated);
1059            if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1060                return set_title.run(title, cx);
1061            }
1062        }
1063        Task::ready(Ok(()))
1064    }
1065
1066    pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1067        self.token_usage = usage;
1068        cx.emit(AcpThreadEvent::TokenUsageUpdated);
1069    }
1070
1071    pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1072        cx.emit(AcpThreadEvent::Retry(status));
1073    }
1074
1075    pub fn update_tool_call(
1076        &mut self,
1077        update: impl Into<ToolCallUpdate>,
1078        cx: &mut Context<Self>,
1079    ) -> Result<()> {
1080        let update = update.into();
1081        let languages = self.project.read(cx).languages().clone();
1082
1083        let (ix, current_call) = self
1084            .tool_call_mut(update.id())
1085            .context("Tool call not found")?;
1086        match update {
1087            ToolCallUpdate::UpdateFields(update) => {
1088                let location_updated = update.fields.locations.is_some();
1089                current_call.update_fields(update.fields, languages, cx);
1090                if location_updated {
1091                    self.resolve_locations(update.id, cx);
1092                }
1093            }
1094            ToolCallUpdate::UpdateDiff(update) => {
1095                current_call.content.clear();
1096                current_call
1097                    .content
1098                    .push(ToolCallContent::Diff(update.diff));
1099            }
1100            ToolCallUpdate::UpdateTerminal(update) => {
1101                current_call.content.clear();
1102                current_call
1103                    .content
1104                    .push(ToolCallContent::Terminal(update.terminal));
1105            }
1106        }
1107
1108        cx.emit(AcpThreadEvent::EntryUpdated(ix));
1109
1110        Ok(())
1111    }
1112
1113    /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1114    pub fn upsert_tool_call(
1115        &mut self,
1116        tool_call: acp::ToolCall,
1117        cx: &mut Context<Self>,
1118    ) -> Result<(), acp::Error> {
1119        let status = tool_call.status.into();
1120        self.upsert_tool_call_inner(tool_call.into(), status, cx)
1121    }
1122
1123    /// Fails if id does not match an existing entry.
1124    pub fn upsert_tool_call_inner(
1125        &mut self,
1126        tool_call_update: acp::ToolCallUpdate,
1127        status: ToolCallStatus,
1128        cx: &mut Context<Self>,
1129    ) -> Result<(), acp::Error> {
1130        let language_registry = self.project.read(cx).languages().clone();
1131        let id = tool_call_update.id.clone();
1132
1133        if let Some((ix, current_call)) = self.tool_call_mut(&id) {
1134            current_call.update_fields(tool_call_update.fields, language_registry, cx);
1135            current_call.status = status;
1136
1137            cx.emit(AcpThreadEvent::EntryUpdated(ix));
1138        } else {
1139            let call =
1140                ToolCall::from_acp(tool_call_update.try_into()?, status, language_registry, cx);
1141            self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1142        };
1143
1144        self.resolve_locations(id, cx);
1145        Ok(())
1146    }
1147
1148    fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1149        // The tool call we are looking for is typically the last one, or very close to the end.
1150        // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1151        self.entries
1152            .iter_mut()
1153            .enumerate()
1154            .rev()
1155            .find_map(|(index, tool_call)| {
1156                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1157                    && &tool_call.id == id
1158                {
1159                    Some((index, tool_call))
1160                } else {
1161                    None
1162                }
1163            })
1164    }
1165
1166    pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1167        self.entries
1168            .iter()
1169            .enumerate()
1170            .rev()
1171            .find_map(|(index, tool_call)| {
1172                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1173                    && &tool_call.id == id
1174                {
1175                    Some((index, tool_call))
1176                } else {
1177                    None
1178                }
1179            })
1180    }
1181
1182    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1183        let project = self.project.clone();
1184        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1185            return;
1186        };
1187        let task = tool_call.resolve_locations(project, cx);
1188        cx.spawn(async move |this, cx| {
1189            let resolved_locations = task.await;
1190            this.update(cx, |this, cx| {
1191                let project = this.project.clone();
1192                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1193                    return;
1194                };
1195                if let Some(Some(location)) = resolved_locations.last() {
1196                    project.update(cx, |project, cx| {
1197                        if let Some(agent_location) = project.agent_location() {
1198                            let should_ignore = agent_location.buffer == location.buffer
1199                                && location
1200                                    .buffer
1201                                    .update(cx, |buffer, _| {
1202                                        let snapshot = buffer.snapshot();
1203                                        let old_position =
1204                                            agent_location.position.to_point(&snapshot);
1205                                        let new_position = location.position.to_point(&snapshot);
1206                                        // ignore this so that when we get updates from the edit tool
1207                                        // the position doesn't reset to the startof line
1208                                        old_position.row == new_position.row
1209                                            && old_position.column > new_position.column
1210                                    })
1211                                    .ok()
1212                                    .unwrap_or_default();
1213                            if !should_ignore {
1214                                project.set_agent_location(Some(location.clone()), cx);
1215                            }
1216                        }
1217                    });
1218                }
1219                if tool_call.resolved_locations != resolved_locations {
1220                    tool_call.resolved_locations = resolved_locations;
1221                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
1222                }
1223            })
1224        })
1225        .detach();
1226    }
1227
1228    pub fn request_tool_call_authorization(
1229        &mut self,
1230        tool_call: acp::ToolCallUpdate,
1231        options: Vec<acp::PermissionOption>,
1232        cx: &mut Context<Self>,
1233    ) -> Result<oneshot::Receiver<acp::PermissionOptionId>, acp::Error> {
1234        let (tx, rx) = oneshot::channel();
1235
1236        let status = ToolCallStatus::WaitingForConfirmation {
1237            options,
1238            respond_tx: tx,
1239        };
1240
1241        self.upsert_tool_call_inner(tool_call, status, cx)?;
1242        cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1243        Ok(rx)
1244    }
1245
1246    pub fn authorize_tool_call(
1247        &mut self,
1248        id: acp::ToolCallId,
1249        option_id: acp::PermissionOptionId,
1250        option_kind: acp::PermissionOptionKind,
1251        cx: &mut Context<Self>,
1252    ) {
1253        let Some((ix, call)) = self.tool_call_mut(&id) else {
1254            return;
1255        };
1256
1257        let new_status = match option_kind {
1258            acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1259                ToolCallStatus::Rejected
1260            }
1261            acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1262                ToolCallStatus::InProgress
1263            }
1264        };
1265
1266        let curr_status = mem::replace(&mut call.status, new_status);
1267
1268        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1269            respond_tx.send(option_id).log_err();
1270        } else if cfg!(debug_assertions) {
1271            panic!("tried to authorize an already authorized tool call");
1272        }
1273
1274        cx.emit(AcpThreadEvent::EntryUpdated(ix));
1275    }
1276
1277    /// Returns true if the last turn is awaiting tool authorization
1278    pub fn waiting_for_tool_confirmation(&self) -> bool {
1279        for entry in self.entries.iter().rev() {
1280            match &entry {
1281                AgentThreadEntry::ToolCall(call) => match call.status {
1282                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
1283                    ToolCallStatus::Pending
1284                    | ToolCallStatus::InProgress
1285                    | ToolCallStatus::Completed
1286                    | ToolCallStatus::Failed
1287                    | ToolCallStatus::Rejected
1288                    | ToolCallStatus::Canceled => continue,
1289                },
1290                AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1291                    // Reached the beginning of the turn
1292                    return false;
1293                }
1294            }
1295        }
1296        false
1297    }
1298
1299    pub fn plan(&self) -> &Plan {
1300        &self.plan
1301    }
1302
1303    pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1304        let new_entries_len = request.entries.len();
1305        let mut new_entries = request.entries.into_iter();
1306
1307        // Reuse existing markdown to prevent flickering
1308        for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1309            let PlanEntry {
1310                content,
1311                priority,
1312                status,
1313            } = old;
1314            content.update(cx, |old, cx| {
1315                old.replace(new.content, cx);
1316            });
1317            *priority = new.priority;
1318            *status = new.status;
1319        }
1320        for new in new_entries {
1321            self.plan.entries.push(PlanEntry::from_acp(new, cx))
1322        }
1323        self.plan.entries.truncate(new_entries_len);
1324
1325        cx.notify();
1326    }
1327
1328    fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1329        self.plan
1330            .entries
1331            .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1332        cx.notify();
1333    }
1334
1335    #[cfg(any(test, feature = "test-support"))]
1336    pub fn send_raw(
1337        &mut self,
1338        message: &str,
1339        cx: &mut Context<Self>,
1340    ) -> BoxFuture<'static, Result<()>> {
1341        self.send(
1342            vec![acp::ContentBlock::Text(acp::TextContent {
1343                text: message.to_string(),
1344                annotations: None,
1345            })],
1346            cx,
1347        )
1348    }
1349
    /// Sends `message` as a new user prompt and runs a turn for it via `run_turn`.
    ///
    /// When the turn starts, the user message entry is pushed. A git checkpoint
    /// of the project is then taken and attached to that message with
    /// `show: false`, and the prompt is forwarded to the connection.
    /// `update_last_checkpoint` later decides whether the checkpoint is shown.
    ///
    /// A [`UserMessageId`] is generated only when the connection supports
    /// truncation. That id is what `rewind` later passes to the truncate call.
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            cx,
        );
        let request = acp::PromptRequest {
            prompt: message.clone(),
            session_id: self.session_id.clone(),
        };
        let git_store = self.project.read(cx).git_store().clone();

        let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
            Some(UserMessageId::new())
        } else {
            None
        };

        self.run_turn(cx, async move |this, cx| {
            // A dropped thread is tolerated here (`.ok()`); the later `this.update`
            // below is what surfaces that as an error.
            this.update(cx, |this, cx| {
                this.push_entry(
                    AgentThreadEntry::UserMessage(UserMessage {
                        id: message_id.clone(),
                        content: block,
                        chunks: message,
                        checkpoint: None,
                    }),
                    cx,
                );
            })
            .ok();

            // Checkpoint failures are logged and turn into `None` rather than
            // aborting the prompt.
            let old_checkpoint = git_store
                .update(cx, |git, cx| git.checkpoint(cx))?
                .await
                .context("failed to get old checkpoint")
                .log_err();
            this.update(cx, |this, cx| {
                if let Some((_ix, message)) = this.last_user_message() {
                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
                        git_checkpoint,
                        show: false,
                    });
                }
                this.connection.prompt(message_id, request, cx)
            })?
            .await
        })
    }
1403
1404    pub fn can_resume(&self, cx: &App) -> bool {
1405        self.connection.resume(&self.session_id, cx).is_some()
1406    }
1407
1408    pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1409        self.run_turn(cx, async move |this, cx| {
1410            this.update(cx, |this, cx| {
1411                this.connection
1412                    .resume(&this.session_id, cx)
1413                    .map(|resume| resume.run(cx))
1414            })?
1415            .context("resuming a session is not supported")?
1416            .await
1417        })
1418    }
1419
    /// Starts a new turn driven by `f`, canceling any turn that is still running.
    ///
    /// Completed plan entries are cleared first. `f` runs inside `send_task`
    /// once the previous turn's cancellation has finished. Its result goes
    /// through a oneshot channel to the returned future, which then:
    /// - refreshes the last user message's checkpoint;
    /// - clears the project's agent location;
    /// - emits [`AcpThreadEvent::Error`] if `f` failed, otherwise
    ///   [`AcpThreadEvent::Stopped`].
    ///
    /// A `Refusal` stop reason also removes the last user message and
    /// everything after it.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<()>> {
        self.clear_completed_plan_entries(cx);

        let (tx, rx) = oneshot::channel();
        // Takes the previous `send_task` (if any) and resolves once it has finished.
        let cancel_task = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            cancel_task.await;
            // The receiver may already have been dropped; that is not an error here.
            tx.send(f(this, cx).await).ok();
        }));

        cx.spawn(async move |this, cx| {
            // `Err(Canceled)` means `send_task` was dropped before `f` produced a result.
            let response = rx.await;

            // NOTE(review): if this fails, we return early without clearing the agent
            // location or emitting Stopped/Error — confirm that is intended.
            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));
                match response {
                    Ok(Err(e)) => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        Err(e)
                    }
                    // Either a successful response or a dropped sender (`Err(Canceled)`).
                    result => {
                        let canceled = matches!(
                            result,
                            Ok(Ok(acp::PromptResponse {
                                stop_reason: acp::StopReason::Cancelled
                            }))
                        );

                        // We only take the task if the current prompt wasn't canceled.
                        //
                        // This prompt may have been canceled because another one was sent
                        // while it was still generating. In these cases, dropping `send_task`
                        // would cause the next generation to be canceled.
                        if !canceled {
                            this.send_task.take();
                        }

                        // Truncate entries if the last prompt was refused.
                        if let Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                        })) = result
                            && let Some((ix, _)) = this.last_user_message()
                        {
                            let range = ix..this.entries.len();
                            this.entries.truncate(ix);
                            cx.emit(AcpThreadEvent::EntriesRemoved(range));
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(())
                    }
                }
            })?
        })
        .boxed()
    }
1486
1487    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1488        let Some(send_task) = self.send_task.take() else {
1489            return Task::ready(());
1490        };
1491
1492        for entry in self.entries.iter_mut() {
1493            if let AgentThreadEntry::ToolCall(call) = entry {
1494                let cancel = matches!(
1495                    call.status,
1496                    ToolCallStatus::Pending
1497                        | ToolCallStatus::WaitingForConfirmation { .. }
1498                        | ToolCallStatus::InProgress
1499                );
1500
1501                if cancel {
1502                    call.status = ToolCallStatus::Canceled;
1503                }
1504            }
1505        }
1506
1507        self.connection.cancel(&self.session_id, cx);
1508
1509        // Wait for the send task to complete
1510        cx.foreground_executor().spawn(send_task)
1511    }
1512
1513    /// Rewinds this thread to before the entry at `index`, removing it and all
1514    /// subsequent entries while reverting any changes made from that point.
1515    pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1516        let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
1517            return Task::ready(Err(anyhow!("not supported")));
1518        };
1519        let Some(message) = self.user_message(&id) else {
1520            return Task::ready(Err(anyhow!("message not found")));
1521        };
1522
1523        let checkpoint = message
1524            .checkpoint
1525            .as_ref()
1526            .map(|c| c.git_checkpoint.clone());
1527
1528        let git_store = self.project.read(cx).git_store().clone();
1529        cx.spawn(async move |this, cx| {
1530            if let Some(checkpoint) = checkpoint {
1531                git_store
1532                    .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1533                    .await?;
1534            }
1535
1536            cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
1537            this.update(cx, |this, cx| {
1538                if let Some((ix, _)) = this.user_message_mut(&id) {
1539                    let range = ix..this.entries.len();
1540                    this.entries.truncate(ix);
1541                    cx.emit(AcpThreadEvent::EntriesRemoved(range));
1542                }
1543            })
1544        })
1545    }
1546
1547    fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1548        let git_store = self.project.read(cx).git_store().clone();
1549
1550        let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1551            if let Some(checkpoint) = message.checkpoint.as_ref() {
1552                checkpoint.git_checkpoint.clone()
1553            } else {
1554                return Task::ready(Ok(()));
1555            }
1556        } else {
1557            return Task::ready(Ok(()));
1558        };
1559
1560        let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1561        cx.spawn(async move |this, cx| {
1562            let new_checkpoint = new_checkpoint
1563                .await
1564                .context("failed to get new checkpoint")
1565                .log_err();
1566            if let Some(new_checkpoint) = new_checkpoint {
1567                let equal = git_store
1568                    .update(cx, |git, cx| {
1569                        git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1570                    })?
1571                    .await
1572                    .unwrap_or(true);
1573                this.update(cx, |this, cx| {
1574                    let (ix, message) = this.last_user_message().context("no user message")?;
1575                    let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1576                    checkpoint.show = !equal;
1577                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
1578                    anyhow::Ok(())
1579                })??;
1580            }
1581
1582            Ok(())
1583        })
1584    }
1585
1586    fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1587        self.entries
1588            .iter_mut()
1589            .enumerate()
1590            .rev()
1591            .find_map(|(ix, entry)| {
1592                if let AgentThreadEntry::UserMessage(message) = entry {
1593                    Some((ix, message))
1594                } else {
1595                    None
1596                }
1597            })
1598    }
1599
1600    fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
1601        self.entries.iter().find_map(|entry| {
1602            if let AgentThreadEntry::UserMessage(message) = entry {
1603                if message.id.as_ref() == Some(id) {
1604                    Some(message)
1605                } else {
1606                    None
1607                }
1608            } else {
1609                None
1610            }
1611        })
1612    }
1613
1614    fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1615        self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1616            if let AgentThreadEntry::UserMessage(message) = entry {
1617                if message.id.as_ref() == Some(id) {
1618                    Some((ix, message))
1619                } else {
1620                    None
1621                }
1622            } else {
1623                None
1624            }
1625        })
1626    }
1627
1628    pub fn read_text_file(
1629        &self,
1630        path: PathBuf,
1631        line: Option<u32>,
1632        limit: Option<u32>,
1633        reuse_shared_snapshot: bool,
1634        cx: &mut Context<Self>,
1635    ) -> Task<Result<String>> {
1636        let project = self.project.clone();
1637        let action_log = self.action_log.clone();
1638        cx.spawn(async move |this, cx| {
1639            let load = project.update(cx, |project, cx| {
1640                let path = project
1641                    .project_path_for_absolute_path(&path, cx)
1642                    .context("invalid path")?;
1643                anyhow::Ok(project.open_buffer(path, cx))
1644            });
1645            let buffer = load??.await?;
1646
1647            let snapshot = if reuse_shared_snapshot {
1648                this.read_with(cx, |this, _| {
1649                    this.shared_buffers.get(&buffer.clone()).cloned()
1650                })
1651                .log_err()
1652                .flatten()
1653            } else {
1654                None
1655            };
1656
1657            let snapshot = if let Some(snapshot) = snapshot {
1658                snapshot
1659            } else {
1660                action_log.update(cx, |action_log, cx| {
1661                    action_log.buffer_read(buffer.clone(), cx);
1662                })?;
1663                project.update(cx, |project, cx| {
1664                    let position = buffer
1665                        .read(cx)
1666                        .snapshot()
1667                        .anchor_before(Point::new(line.unwrap_or_default(), 0));
1668                    project.set_agent_location(
1669                        Some(AgentLocation {
1670                            buffer: buffer.downgrade(),
1671                            position,
1672                        }),
1673                        cx,
1674                    );
1675                })?;
1676
1677                buffer.update(cx, |buffer, _| buffer.snapshot())?
1678            };
1679
1680            this.update(cx, |this, _| {
1681                let text = snapshot.text();
1682                this.shared_buffers.insert(buffer.clone(), snapshot);
1683                if line.is_none() && limit.is_none() {
1684                    return Ok(text);
1685                }
1686                let limit = limit.unwrap_or(u32::MAX) as usize;
1687                let Some(line) = line else {
1688                    return Ok(text.lines().take(limit).collect::<String>());
1689                };
1690
1691                let count = text.lines().count();
1692                if count < line as usize {
1693                    anyhow::bail!("There are only {} lines", count);
1694                }
1695                Ok(text
1696                    .lines()
1697                    .skip(line as usize + 1)
1698                    .take(limit)
1699                    .collect::<String>())
1700            })?
1701        })
1702    }
1703
    /// Replaces the contents of the file at `path` with `content` on behalf of
    /// the agent, then saves the buffer.
    ///
    /// The new text is diffed, on the background executor, against the snapshot
    /// recorded in `shared_buffers` for this buffer. If there is none, the
    /// buffer's current snapshot is used. The resulting edits are applied to
    /// the buffer.
    ///
    /// The agent location is moved to the end of the last edit, or to
    /// `Anchor::MIN` when there are no edits. The edit is reported to the action
    /// log. The buffer is formatted when its language settings' `format_on_save`
    /// is not `Off`, and finally the buffer is saved.
    ///
    /// # Errors
    /// Fails if `path` does not map to a project path, the buffer cannot be
    /// opened or saved, or an entity update fails. Formatting errors are logged
    /// and otherwise ignored.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Prefer the snapshot last recorded by `read_text_file`; fall back to
            // the live buffer when none was recorded.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Diff off the main thread and turn the offset ranges into anchors of
            // `snapshot` so they can be applied to the live buffer below.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::MIN),
                    }),
                    cx,
                );
            })?;

            let format_on_save = cx.update(|cx| {
                // NOTE(review): the buffer is marked as read before the edit,
                // presumably so the action log has a baseline to attribute the
                // agent's edit against — confirm.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            })?;

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                })?;
                // Formatting is best-effort: failures are logged, not propagated.
                format_task.await.log_err();

                // Report the formatter's changes to the action log as well.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                })?;
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1800
1801    pub fn to_markdown(&self, cx: &App) -> String {
1802        self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1803    }
1804
1805    pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
1806        cx.emit(AcpThreadEvent::LoadError(error));
1807    }
1808}
1809
1810fn markdown_for_raw_output(
1811    raw_output: &serde_json::Value,
1812    language_registry: &Arc<LanguageRegistry>,
1813    cx: &mut App,
1814) -> Option<Entity<Markdown>> {
1815    match raw_output {
1816        serde_json::Value::Null => None,
1817        serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1818            Markdown::new(
1819                value.to_string().into(),
1820                Some(language_registry.clone()),
1821                None,
1822                cx,
1823            )
1824        })),
1825        serde_json::Value::Number(value) => Some(cx.new(|cx| {
1826            Markdown::new(
1827                value.to_string().into(),
1828                Some(language_registry.clone()),
1829                None,
1830                cx,
1831            )
1832        })),
1833        serde_json::Value::String(value) => Some(cx.new(|cx| {
1834            Markdown::new(
1835                value.clone().into(),
1836                Some(language_registry.clone()),
1837                None,
1838                cx,
1839            )
1840        })),
1841        value => Some(cx.new(|cx| {
1842            Markdown::new(
1843                format!("```json\n{}\n```", value).into(),
1844                Some(language_registry.clone()),
1845                None,
1846                cx,
1847            )
1848        })),
1849    }
1850}
1851
1852#[cfg(test)]
1853mod tests {
1854    use super::*;
1855    use anyhow::anyhow;
1856    use futures::{channel::mpsc, future::LocalBoxFuture, select};
1857    use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
1858    use indoc::indoc;
1859    use project::{FakeFs, Fs};
1860    use rand::Rng as _;
1861    use serde_json::json;
1862    use settings::SettingsStore;
1863    use smol::stream::StreamExt as _;
1864    use std::{
1865        any::Any,
1866        cell::RefCell,
1867        path::Path,
1868        rc::Rc,
1869        sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
1870        time::Duration,
1871    };
1872    use util::path;
1873
1874    fn init_test(cx: &mut TestAppContext) {
1875        env_logger::try_init().ok();
1876        cx.update(|cx| {
1877            let settings_store = SettingsStore::test(cx);
1878            cx.set_global(settings_store);
1879            Project::init_settings(cx);
1880            language::init(cx);
1881        });
1882    }
1883
    /// Exercises `push_user_content_block`:
    /// - a chunk with no id starts a user message;
    /// - a chunk with an id is appended to the trailing user message and
    ///   assigns it that id;
    /// - a chunk that arrives after an assistant message starts a new user
    ///   message entry.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                None,
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Hello, ".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, None);
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        let message_1_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_1_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "world!".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, Some(message_1_id));
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Assistant response".to_string(),
                }),
                false,
                cx,
            );
        });

        let message_2_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_2_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "New user message".to_string(),
                }),
                cx,
            );
        });

        // Expected entries: user, assistant, user.
        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.id, Some(message_2_id));
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }
1975
    /// Two consecutive `AgentThoughtChunk` updates are merged into a single
    /// `<thinking>` section of one assistant message in the markdown output.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        // The fake agent answers every prompt with two thought chunks.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "Thinking ".into(),
                                },
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "hard!".into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
2038
    /// The agent reads `/tmp/foo`, the user then inserts a line at the top of
    /// the open buffer, and the agent writes back an extended version of what it
    /// read. Both the buffer and the file on disk must keep the user's line and
    /// also contain the agent's additions.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Fired by the fake agent once its read has completed. It is meant to
        // let the user edit happen between the agent's read and its write.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        // Simulated user edit, made after the agent has read the file.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
2117
    /// A tool call that was marked `Canceled` by `cancel` can still be moved to
    /// `Completed` by a later `ToolCallUpdate` from the agent.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        // The fake agent starts a single in-progress Fetch tool call per prompt.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // Entry 0 is the user message; entry 1 is the tool call.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::InProgress,
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // The agent reports completion after the cancellation.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Completed,
                    ..
                })
            ));
        });
    }
2217
2218    #[gpui::test]
2219    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
2220        init_test(cx);
2221        let fs = FakeFs::new(cx.background_executor.clone());
2222        fs.insert_tree(path!("/test"), json!({})).await;
2223        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
2224
2225        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2226            move |_, thread, mut cx| {
2227                async move {
2228                    thread
2229                        .update(&mut cx, |thread, cx| {
2230                            thread.handle_session_update(
2231                                acp::SessionUpdate::ToolCall(acp::ToolCall {
2232                                    id: acp::ToolCallId("test".into()),
2233                                    title: "Label".into(),
2234                                    kind: acp::ToolKind::Edit,
2235                                    status: acp::ToolCallStatus::Completed,
2236                                    content: vec![acp::ToolCallContent::Diff {
2237                                        diff: acp::Diff {
2238                                            path: "/test/test.txt".into(),
2239                                            old_text: None,
2240                                            new_text: "foo".into(),
2241                                        },
2242                                    }],
2243                                    locations: vec![],
2244                                    raw_input: None,
2245                                    raw_output: None,
2246                                }),
2247                                cx,
2248                            )
2249                        })
2250                        .unwrap()
2251                        .unwrap();
2252                    Ok(acp::PromptResponse {
2253                        stop_reason: acp::StopReason::EndTurn,
2254                    })
2255                }
2256                .boxed_local()
2257            }
2258        }));
2259
2260        let thread = cx
2261            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2262            .await
2263            .unwrap();
2264
2265        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
2266            .await
2267            .unwrap();
2268
2269        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
2270    }
2271
2272    #[gpui::test(iterations = 10)]
2273    async fn test_checkpoints(cx: &mut TestAppContext) {
2274        init_test(cx);
2275        let fs = FakeFs::new(cx.background_executor.clone());
2276        fs.insert_tree(
2277            path!("/test"),
2278            json!({
2279                ".git": {}
2280            }),
2281        )
2282        .await;
2283        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
2284
2285        let simulate_changes = Arc::new(AtomicBool::new(true));
2286        let next_filename = Arc::new(AtomicUsize::new(0));
2287        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2288            let simulate_changes = simulate_changes.clone();
2289            let next_filename = next_filename.clone();
2290            let fs = fs.clone();
2291            move |request, thread, mut cx| {
2292                let fs = fs.clone();
2293                let simulate_changes = simulate_changes.clone();
2294                let next_filename = next_filename.clone();
2295                async move {
2296                    if simulate_changes.load(SeqCst) {
2297                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
2298                        fs.write(Path::new(&filename), b"").await?;
2299                    }
2300
2301                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2302                        panic!("expected text content block");
2303                    };
2304                    thread.update(&mut cx, |thread, cx| {
2305                        thread
2306                            .handle_session_update(
2307                                acp::SessionUpdate::AgentMessageChunk {
2308                                    content: content.text.to_uppercase().into(),
2309                                },
2310                                cx,
2311                            )
2312                            .unwrap();
2313                    })?;
2314                    Ok(acp::PromptResponse {
2315                        stop_reason: acp::StopReason::EndTurn,
2316                    })
2317                }
2318                .boxed_local()
2319            }
2320        }));
2321        let thread = cx
2322            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2323            .await
2324            .unwrap();
2325
2326        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
2327            .await
2328            .unwrap();
2329        thread.read_with(cx, |thread, cx| {
2330            assert_eq!(
2331                thread.to_markdown(cx),
2332                indoc! {"
2333                    ## User (checkpoint)
2334
2335                    Lorem
2336
2337                    ## Assistant
2338
2339                    LOREM
2340
2341                "}
2342            );
2343        });
2344        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2345
2346        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
2347            .await
2348            .unwrap();
2349        thread.read_with(cx, |thread, cx| {
2350            assert_eq!(
2351                thread.to_markdown(cx),
2352                indoc! {"
2353                    ## User (checkpoint)
2354
2355                    Lorem
2356
2357                    ## Assistant
2358
2359                    LOREM
2360
2361                    ## User (checkpoint)
2362
2363                    ipsum
2364
2365                    ## Assistant
2366
2367                    IPSUM
2368
2369                "}
2370            );
2371        });
2372        assert_eq!(
2373            fs.files(),
2374            vec![
2375                Path::new(path!("/test/file-0")),
2376                Path::new(path!("/test/file-1"))
2377            ]
2378        );
2379
2380        // Checkpoint isn't stored when there are no changes.
2381        simulate_changes.store(false, SeqCst);
2382        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
2383            .await
2384            .unwrap();
2385        thread.read_with(cx, |thread, cx| {
2386            assert_eq!(
2387                thread.to_markdown(cx),
2388                indoc! {"
2389                    ## User (checkpoint)
2390
2391                    Lorem
2392
2393                    ## Assistant
2394
2395                    LOREM
2396
2397                    ## User (checkpoint)
2398
2399                    ipsum
2400
2401                    ## Assistant
2402
2403                    IPSUM
2404
2405                    ## User
2406
2407                    dolor
2408
2409                    ## Assistant
2410
2411                    DOLOR
2412
2413                "}
2414            );
2415        });
2416        assert_eq!(
2417            fs.files(),
2418            vec![
2419                Path::new(path!("/test/file-0")),
2420                Path::new(path!("/test/file-1"))
2421            ]
2422        );
2423
2424        // Rewinding the conversation truncates the history and restores the checkpoint.
2425        thread
2426            .update(cx, |thread, cx| {
2427                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
2428                    panic!("unexpected entries {:?}", thread.entries)
2429                };
2430                thread.rewind(message.id.clone().unwrap(), cx)
2431            })
2432            .await
2433            .unwrap();
2434        thread.read_with(cx, |thread, cx| {
2435            assert_eq!(
2436                thread.to_markdown(cx),
2437                indoc! {"
2438                    ## User (checkpoint)
2439
2440                    Lorem
2441
2442                    ## Assistant
2443
2444                    LOREM
2445
2446                "}
2447            );
2448        });
2449        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2450    }
2451
    /// When the agent replies with `StopReason::Refusal`, the refused user
    /// message is dropped. The conversation is left as it was before that
    /// message was sent.
    #[gpui::test]
    async fn test_refusal(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/"), json!({})).await;
        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;

        // While `refuse_next` is false, the fake agent echoes the prompt in
        // upper case. Once it is set, the agent refuses without producing output.
        let refuse_next = Arc::new(AtomicBool::new(false));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let refuse_next = refuse_next.clone();
            move |request, thread, mut cx| {
                let refuse_next = refuse_next.clone();
                async move {
                    if refuse_next.load(SeqCst) {
                        return Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                        });
                    }

                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk {
                                    content: content.text.to_uppercase().into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User

                    hello

                    ## Assistant

                    HELLO

                "}
            );
        });

        // Simulate refusing the second message, ensuring the conversation gets
        // truncated to before sending it.
        refuse_next.store(true, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User

                    hello

                    ## Assistant

                    HELLO

                "}
            );
        });
    }
2537
2538    async fn run_until_first_tool_call(
2539        thread: &Entity<AcpThread>,
2540        cx: &mut TestAppContext,
2541    ) -> usize {
2542        let (mut tx, mut rx) = mpsc::channel::<usize>(1);
2543
2544        let subscription = cx.update(|cx| {
2545            cx.subscribe(thread, move |thread, _, cx| {
2546                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
2547                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
2548                        return tx.try_send(ix).unwrap();
2549                    }
2550                }
2551            })
2552        });
2553
2554        select! {
2555            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
2556                panic!("Timeout waiting for tool call")
2557            }
2558            ix = rx.next().fuse() => {
2559                drop(subscription);
2560                ix.unwrap()
2561            }
2562        }
2563    }
2564
    /// In-memory `AgentConnection` used by the tests in this module.
    ///
    /// Threads created through `new_thread` are recorded in `sessions`.
    /// NOTE(review): the `prompt` implementation is outside this view, and
    /// presumably forwards each prompt to `on_user_message` — confirm.
    #[derive(Clone, Default)]
    struct FakeAgentConnection {
        // Methods returned by `auth_methods()`. `authenticate()` accepts only
        // ids from this list.
        auth_methods: Vec<acp::AuthMethod>,
        // Weak handles to every thread created by `new_thread`, keyed by session id.
        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
        // Optional handler called with the prompt request, a weak handle to the
        // thread, and an async app context. It produces the prompt response.
        on_user_message: Option<
            Rc<
                dyn Fn(
                        acp::PromptRequest,
                        WeakEntity<AcpThread>,
                        AsyncApp,
                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
                    + 'static,
            >,
        >,
    }
2580
2581    impl FakeAgentConnection {
2582        fn new() -> Self {
2583            Self {
2584                auth_methods: Vec::new(),
2585                on_user_message: None,
2586                sessions: Arc::default(),
2587            }
2588        }
2589
2590        #[expect(unused)]
2591        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
2592            self.auth_methods = auth_methods;
2593            self
2594        }
2595
2596        fn on_user_message(
2597            mut self,
2598            handler: impl Fn(
2599                acp::PromptRequest,
2600                WeakEntity<AcpThread>,
2601                AsyncApp,
2602            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2603            + 'static,
2604        ) -> Self {
2605            self.on_user_message.replace(Rc::new(handler));
2606            self
2607        }
2608    }
2609
2610    impl AgentConnection for FakeAgentConnection {
2611        fn auth_methods(&self) -> &[acp::AuthMethod] {
2612            &self.auth_methods
2613        }
2614
2615        fn new_thread(
2616            self: Rc<Self>,
2617            project: Entity<Project>,
2618            _cwd: &Path,
2619            cx: &mut App,
2620        ) -> Task<gpui::Result<Entity<AcpThread>>> {
2621            let session_id = acp::SessionId(
2622                rand::thread_rng()
2623                    .sample_iter(&rand::distributions::Alphanumeric)
2624                    .take(7)
2625                    .map(char::from)
2626                    .collect::<String>()
2627                    .into(),
2628            );
2629            let action_log = cx.new(|_| ActionLog::new(project.clone()));
2630            let thread = cx.new(|cx| {
2631                AcpThread::new(
2632                    "Test",
2633                    self.clone(),
2634                    project,
2635                    action_log,
2636                    session_id.clone(),
2637                    watch::Receiver::constant(acp::PromptCapabilities {
2638                        image: true,
2639                        audio: true,
2640                        embedded_context: true,
2641                    }),
2642                    cx,
2643                )
2644            });
2645            self.sessions.lock().insert(session_id, thread.downgrade());
2646            Task::ready(Ok(thread))
2647        }
2648
2649        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
2650            if self.auth_methods().iter().any(|m| m.id == method) {
2651                Task::ready(Ok(()))
2652            } else {
2653                Task::ready(Err(anyhow!("Invalid Auth Method")))
2654            }
2655        }
2656
2657        fn prompt(
2658            &self,
2659            _id: Option<UserMessageId>,
2660            params: acp::PromptRequest,
2661            cx: &mut App,
2662        ) -> Task<gpui::Result<acp::PromptResponse>> {
2663            let sessions = self.sessions.lock();
2664            let thread = sessions.get(&params.session_id).unwrap();
2665            if let Some(handler) = &self.on_user_message {
2666                let handler = handler.clone();
2667                let thread = thread.clone();
2668                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
2669            } else {
2670                Task::ready(Ok(acp::PromptResponse {
2671                    stop_reason: acp::StopReason::EndTurn,
2672                }))
2673            }
2674        }
2675
2676        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
2677            let sessions = self.sessions.lock();
2678            let thread = sessions.get(session_id).unwrap().clone();
2679
2680            cx.spawn(async move |cx| {
2681                thread
2682                    .update(cx, |thread, cx| thread.cancel(cx))
2683                    .unwrap()
2684                    .await
2685            })
2686            .detach();
2687        }
2688
2689        fn truncate(
2690            &self,
2691            session_id: &acp::SessionId,
2692            _cx: &App,
2693        ) -> Option<Rc<dyn AgentSessionTruncate>> {
2694            Some(Rc::new(FakeAgentSessionEditor {
2695                _session_id: session_id.clone(),
2696            }))
2697        }
2698
2699        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
2700            self
2701        }
2702    }
2703
2704    struct FakeAgentSessionEditor {
2705        _session_id: acp::SessionId,
2706    }
2707
2708    impl AgentSessionTruncate for FakeAgentSessionEditor {
2709        fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
2710            Task::ready(Ok(()))
2711        }
2712    }
2713}