acp_thread.rs

   1mod connection;
   2mod diff;
   3mod mention;
   4mod terminal;
   5
   6use collections::HashSet;
   7pub use connection::*;
   8pub use diff::*;
   9use language::language_settings::FormatOnSave;
  10pub use mention::*;
  11use project::lsp_store::{FormatTrigger, LspFormatTarget};
  12use serde::{Deserialize, Serialize};
  13pub use terminal::*;
  14
  15use action_log::ActionLog;
  16use agent_client_protocol as acp;
  17use anyhow::{Context as _, Result, anyhow};
  18use editor::Bias;
  19use futures::{FutureExt, channel::oneshot, future::BoxFuture};
  20use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  21use itertools::Itertools;
  22use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
  23use markdown::Markdown;
  24use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
  25use std::collections::HashMap;
  26use std::error::Error;
  27use std::fmt::{Formatter, Write};
  28use std::ops::Range;
  29use std::process::ExitStatus;
  30use std::rc::Rc;
  31use std::time::{Duration, Instant};
  32use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
  33use ui::App;
  34use util::ResultExt;
  35
/// A user-authored entry in an agent thread.
#[derive(Debug)]
pub struct UserMessage {
    /// Identifier of the message, if one has been supplied.
    pub id: Option<UserMessageId>,
    /// All chunks of the message accumulated into a single content block.
    pub content: ContentBlock,
    /// The individual protocol content blocks that make up the message, in the order received.
    pub chunks: Vec<acp::ContentBlock>,
    /// Checkpoint associated with this message, if any.
    pub checkpoint: Option<Checkpoint>,
}
  43
/// A git checkpoint that can be attached to a [`UserMessage`].
#[derive(Debug)]
pub struct Checkpoint {
    /// The underlying `GitStoreCheckpoint`.
    git_checkpoint: GitStoreCheckpoint,
    /// When `true`, the message's markdown heading is rendered as `## User (checkpoint)`.
    pub show: bool,
}
  49
  50impl UserMessage {
  51    fn to_markdown(&self, cx: &App) -> String {
  52        let mut markdown = String::new();
  53        if self
  54            .checkpoint
  55            .as_ref()
  56            .is_some_and(|checkpoint| checkpoint.show)
  57        {
  58            writeln!(markdown, "## User (checkpoint)").unwrap();
  59        } else {
  60            writeln!(markdown, "## User").unwrap();
  61        }
  62        writeln!(markdown).unwrap();
  63        writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
  64        writeln!(markdown).unwrap();
  65        markdown
  66    }
  67}
  68
/// An assistant-authored entry in an agent thread.
#[derive(Debug, PartialEq)]
pub struct AssistantMessage {
    /// The message and thought chunks that make up this entry, in order.
    pub chunks: Vec<AssistantMessageChunk>,
}
  73
  74impl AssistantMessage {
  75    pub fn to_markdown(&self, cx: &App) -> String {
  76        format!(
  77            "## Assistant\n\n{}\n\n",
  78            self.chunks
  79                .iter()
  80                .map(|chunk| chunk.to_markdown(cx))
  81                .join("\n\n")
  82        )
  83    }
  84}
  85
/// One piece of an [`AssistantMessage`].
#[derive(Debug, PartialEq)]
pub enum AssistantMessageChunk {
    /// Regular message content.
    Message { block: ContentBlock },
    /// Thought content; rendered to markdown wrapped in `<thinking>` tags.
    Thought { block: ContentBlock },
}
  91
  92impl AssistantMessageChunk {
  93    pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
  94        Self::Message {
  95            block: ContentBlock::new(chunk.into(), language_registry, cx),
  96        }
  97    }
  98
  99    fn to_markdown(&self, cx: &App) -> String {
 100        match self {
 101            Self::Message { block } => block.to_markdown(cx).to_string(),
 102            Self::Thought { block } => {
 103                format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
 104            }
 105        }
 106    }
 107}
 108
/// A single entry in an agent thread.
#[derive(Debug)]
pub enum AgentThreadEntry {
    /// A message written by the user.
    UserMessage(UserMessage),
    /// A message (including thoughts) produced by the assistant.
    AssistantMessage(AssistantMessage),
    /// A tool invocation and its current state.
    ToolCall(ToolCall),
}
 115
 116impl AgentThreadEntry {
 117    pub fn to_markdown(&self, cx: &App) -> String {
 118        match self {
 119            Self::UserMessage(message) => message.to_markdown(cx),
 120            Self::AssistantMessage(message) => message.to_markdown(cx),
 121            Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
 122        }
 123    }
 124
 125    pub fn user_message(&self) -> Option<&UserMessage> {
 126        if let AgentThreadEntry::UserMessage(message) = self {
 127            Some(message)
 128        } else {
 129            None
 130        }
 131    }
 132
 133    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 134        if let AgentThreadEntry::ToolCall(call) = self {
 135            itertools::Either::Left(call.diffs())
 136        } else {
 137            itertools::Either::Right(std::iter::empty())
 138        }
 139    }
 140
 141    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 142        if let AgentThreadEntry::ToolCall(call) = self {
 143            itertools::Either::Left(call.terminals())
 144        } else {
 145            itertools::Either::Right(std::iter::empty())
 146        }
 147    }
 148
 149    pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
 150        if let AgentThreadEntry::ToolCall(ToolCall {
 151            locations,
 152            resolved_locations,
 153            ..
 154        }) = self
 155        {
 156            Some((
 157                locations.get(ix)?.clone(),
 158                resolved_locations.get(ix)?.clone()?,
 159            ))
 160        } else {
 161            None
 162        }
 163    }
 164}
 165
/// A tool invocation shown in the thread, together with its content and status.
#[derive(Debug)]
pub struct ToolCall {
    /// Protocol identifier of the tool call.
    pub id: acp::ToolCallId,
    /// Markdown label built from the tool call's title (only its first line when the title
    /// spans several lines).
    pub label: Entity<Markdown>,
    /// Kind of tool, as reported over the protocol.
    pub kind: acp::ToolKind,
    /// Content produced by the tool call: content blocks, diffs, and terminals.
    pub content: Vec<ToolCallContent>,
    /// Current lifecycle state of the call.
    pub status: ToolCallStatus,
    /// Locations reported for the tool call.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Resolved counterparts of `locations`, looked up by the same index; an element is
    /// `None` when that location could not be resolved. Empty right after construction.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    /// Raw input of the tool call, if provided.
    pub raw_input: Option<serde_json::Value>,
    /// Raw output of the tool call, if provided.
    pub raw_output: Option<serde_json::Value>,
}
 178
 179impl ToolCall {
 180    fn from_acp(
 181        tool_call: acp::ToolCall,
 182        status: ToolCallStatus,
 183        language_registry: Arc<LanguageRegistry>,
 184        cx: &mut App,
 185    ) -> Self {
 186        let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
 187            first_line.to_owned() + ""
 188        } else {
 189            tool_call.title
 190        };
 191        Self {
 192            id: tool_call.id,
 193            label: cx
 194                .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
 195            kind: tool_call.kind,
 196            content: tool_call
 197                .content
 198                .into_iter()
 199                .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
 200                .collect(),
 201            locations: tool_call.locations,
 202            resolved_locations: Vec::default(),
 203            status,
 204            raw_input: tool_call.raw_input,
 205            raw_output: tool_call.raw_output,
 206        }
 207    }
 208
 209    fn update_fields(
 210        &mut self,
 211        fields: acp::ToolCallUpdateFields,
 212        language_registry: Arc<LanguageRegistry>,
 213        cx: &mut App,
 214    ) {
 215        let acp::ToolCallUpdateFields {
 216            kind,
 217            status,
 218            title,
 219            content,
 220            locations,
 221            raw_input,
 222            raw_output,
 223        } = fields;
 224
 225        if let Some(kind) = kind {
 226            self.kind = kind;
 227        }
 228
 229        if let Some(status) = status {
 230            self.status = status.into();
 231        }
 232
 233        if let Some(title) = title {
 234            self.label.update(cx, |label, cx| {
 235                if let Some((first_line, _)) = title.split_once("\n") {
 236                    label.replace(first_line.to_owned() + "", cx)
 237                } else {
 238                    label.replace(title, cx);
 239                }
 240            });
 241        }
 242
 243        if let Some(content) = content {
 244            let new_content_len = content.len();
 245            let mut content = content.into_iter();
 246
 247            // Reuse existing content if we can
 248            for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
 249                old.update_from_acp(new, language_registry.clone(), cx);
 250            }
 251            for new in content {
 252                self.content.push(ToolCallContent::from_acp(
 253                    new,
 254                    language_registry.clone(),
 255                    cx,
 256                ))
 257            }
 258            self.content.truncate(new_content_len);
 259        }
 260
 261        if let Some(locations) = locations {
 262            self.locations = locations;
 263        }
 264
 265        if let Some(raw_input) = raw_input {
 266            self.raw_input = Some(raw_input);
 267        }
 268
 269        if let Some(raw_output) = raw_output {
 270            if self.content.is_empty()
 271                && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
 272            {
 273                self.content
 274                    .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
 275                        markdown,
 276                    }));
 277            }
 278            self.raw_output = Some(raw_output);
 279        }
 280    }
 281
 282    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 283        self.content.iter().filter_map(|content| match content {
 284            ToolCallContent::Diff(diff) => Some(diff),
 285            ToolCallContent::ContentBlock(_) => None,
 286            ToolCallContent::Terminal(_) => None,
 287        })
 288    }
 289
 290    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 291        self.content.iter().filter_map(|content| match content {
 292            ToolCallContent::Terminal(terminal) => Some(terminal),
 293            ToolCallContent::ContentBlock(_) => None,
 294            ToolCallContent::Diff(_) => None,
 295        })
 296    }
 297
 298    fn to_markdown(&self, cx: &App) -> String {
 299        let mut markdown = format!(
 300            "**Tool Call: {}**\nStatus: {}\n\n",
 301            self.label.read(cx).source(),
 302            self.status
 303        );
 304        for content in &self.content {
 305            markdown.push_str(content.to_markdown(cx).as_str());
 306            markdown.push_str("\n\n");
 307        }
 308        markdown
 309    }
 310
 311    async fn resolve_location(
 312        location: acp::ToolCallLocation,
 313        project: WeakEntity<Project>,
 314        cx: &mut AsyncApp,
 315    ) -> Option<AgentLocation> {
 316        let buffer = project
 317            .update(cx, |project, cx| {
 318                project
 319                    .project_path_for_absolute_path(&location.path, cx)
 320                    .map(|path| project.open_buffer(path, cx))
 321            })
 322            .ok()??;
 323        let buffer = buffer.await.log_err()?;
 324        let position = buffer
 325            .update(cx, |buffer, _| {
 326                if let Some(row) = location.line {
 327                    let snapshot = buffer.snapshot();
 328                    let column = snapshot.indent_size_for_line(row).len;
 329                    let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
 330                    snapshot.anchor_before(point)
 331                } else {
 332                    Anchor::MIN
 333                }
 334            })
 335            .ok()?;
 336
 337        Some(AgentLocation {
 338            buffer: buffer.downgrade(),
 339            position,
 340        })
 341    }
 342
 343    fn resolve_locations(
 344        &self,
 345        project: Entity<Project>,
 346        cx: &mut App,
 347    ) -> Task<Vec<Option<AgentLocation>>> {
 348        let locations = self.locations.clone();
 349        project.update(cx, |_, cx| {
 350            cx.spawn(async move |project, cx| {
 351                let mut new_locations = Vec::new();
 352                for location in locations {
 353                    new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
 354                }
 355                new_locations
 356            })
 357        })
 358    }
 359}
 360
/// Lifecycle state of a [`ToolCall`] as tracked by the thread.
///
/// `Pending`, `InProgress`, `Completed`, and `Failed` can be produced from the protocol's
/// `acp::ToolCallStatus`; the remaining variants exist only on this side.
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The tool call hasn't started running yet, but we start showing it to
    /// the user.
    Pending,
    /// The tool call is waiting for confirmation from the user.
    WaitingForConfirmation {
        /// The permission options offered to the user.
        options: Vec<acp::PermissionOption>,
        /// Channel on which the chosen option's id is sent.
        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
    },
    /// The tool call is currently running.
    InProgress,
    /// The tool call completed successfully.
    Completed,
    /// The tool call failed.
    Failed,
    /// The user rejected the tool call.
    Rejected,
    /// The user canceled generation so the tool call was canceled.
    Canceled,
}
 382
 383impl From<acp::ToolCallStatus> for ToolCallStatus {
 384    fn from(status: acp::ToolCallStatus) -> Self {
 385        match status {
 386            acp::ToolCallStatus::Pending => Self::Pending,
 387            acp::ToolCallStatus::InProgress => Self::InProgress,
 388            acp::ToolCallStatus::Completed => Self::Completed,
 389            acp::ToolCallStatus::Failed => Self::Failed,
 390        }
 391    }
 392}
 393
 394impl Display for ToolCallStatus {
 395    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 396        write!(
 397            f,
 398            "{}",
 399            match self {
 400                ToolCallStatus::Pending => "Pending",
 401                ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
 402                ToolCallStatus::InProgress => "In Progress",
 403                ToolCallStatus::Completed => "Completed",
 404                ToolCallStatus::Failed => "Failed",
 405                ToolCallStatus::Rejected => "Rejected",
 406                ToolCallStatus::Canceled => "Canceled",
 407            }
 408        )
 409    }
 410}
 411
/// Displayable content accumulated from one or more protocol content blocks.
#[derive(Debug, PartialEq, Clone)]
pub enum ContentBlock {
    /// Nothing has been appended yet.
    Empty,
    /// Content held as a markdown entity.
    Markdown { markdown: Entity<Markdown> },
    /// A single resource link. Used only while the link is the sole content; appending
    /// anything further converts the block to `Markdown`.
    ResourceLink { resource_link: acp::ResourceLink },
}
 418
 419impl ContentBlock {
 420    pub fn new(
 421        block: acp::ContentBlock,
 422        language_registry: &Arc<LanguageRegistry>,
 423        cx: &mut App,
 424    ) -> Self {
 425        let mut this = Self::Empty;
 426        this.append(block, language_registry, cx);
 427        this
 428    }
 429
 430    pub fn new_combined(
 431        blocks: impl IntoIterator<Item = acp::ContentBlock>,
 432        language_registry: Arc<LanguageRegistry>,
 433        cx: &mut App,
 434    ) -> Self {
 435        let mut this = Self::Empty;
 436        for block in blocks {
 437            this.append(block, &language_registry, cx);
 438        }
 439        this
 440    }
 441
 442    pub fn append(
 443        &mut self,
 444        block: acp::ContentBlock,
 445        language_registry: &Arc<LanguageRegistry>,
 446        cx: &mut App,
 447    ) {
 448        if matches!(self, ContentBlock::Empty)
 449            && let acp::ContentBlock::ResourceLink(resource_link) = block
 450        {
 451            *self = ContentBlock::ResourceLink { resource_link };
 452            return;
 453        }
 454
 455        let new_content = self.block_string_contents(block);
 456
 457        match self {
 458            ContentBlock::Empty => {
 459                *self = Self::create_markdown_block(new_content, language_registry, cx);
 460            }
 461            ContentBlock::Markdown { markdown } => {
 462                markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
 463            }
 464            ContentBlock::ResourceLink { resource_link } => {
 465                let existing_content = Self::resource_link_md(&resource_link.uri);
 466                let combined = format!("{}\n{}", existing_content, new_content);
 467
 468                *self = Self::create_markdown_block(combined, language_registry, cx);
 469            }
 470        }
 471    }
 472
 473    fn create_markdown_block(
 474        content: String,
 475        language_registry: &Arc<LanguageRegistry>,
 476        cx: &mut App,
 477    ) -> ContentBlock {
 478        ContentBlock::Markdown {
 479            markdown: cx
 480                .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
 481        }
 482    }
 483
 484    fn block_string_contents(&self, block: acp::ContentBlock) -> String {
 485        match block {
 486            acp::ContentBlock::Text(text_content) => text_content.text,
 487            acp::ContentBlock::ResourceLink(resource_link) => {
 488                Self::resource_link_md(&resource_link.uri)
 489            }
 490            acp::ContentBlock::Resource(acp::EmbeddedResource {
 491                resource:
 492                    acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
 493                        uri,
 494                        ..
 495                    }),
 496                ..
 497            }) => Self::resource_link_md(&uri),
 498            acp::ContentBlock::Image(image) => Self::image_md(&image),
 499            acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
 500        }
 501    }
 502
 503    fn resource_link_md(uri: &str) -> String {
 504        if let Some(uri) = MentionUri::parse(uri).log_err() {
 505            uri.as_link().to_string()
 506        } else {
 507            uri.to_string()
 508        }
 509    }
 510
 511    fn image_md(_image: &acp::ImageContent) -> String {
 512        "`Image`".into()
 513    }
 514
 515    pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
 516        match self {
 517            ContentBlock::Empty => "",
 518            ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
 519            ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
 520        }
 521    }
 522
 523    pub fn markdown(&self) -> Option<&Entity<Markdown>> {
 524        match self {
 525            ContentBlock::Empty => None,
 526            ContentBlock::Markdown { markdown } => Some(markdown),
 527            ContentBlock::ResourceLink { .. } => None,
 528        }
 529    }
 530
 531    pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
 532        match self {
 533            ContentBlock::ResourceLink { resource_link } => Some(resource_link),
 534            _ => None,
 535        }
 536    }
 537}
 538
/// One item of content attached to a [`ToolCall`].
#[derive(Debug)]
pub enum ToolCallContent {
    /// Text-like content.
    ContentBlock(ContentBlock),
    /// A diff entity.
    Diff(Entity<Diff>),
    /// A terminal entity.
    Terminal(Entity<Terminal>),
}
 545
 546impl ToolCallContent {
 547    pub fn from_acp(
 548        content: acp::ToolCallContent,
 549        language_registry: Arc<LanguageRegistry>,
 550        cx: &mut App,
 551    ) -> Self {
 552        match content {
 553            acp::ToolCallContent::Content { content } => {
 554                Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
 555            }
 556            acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
 557                Diff::finalized(
 558                    diff.path,
 559                    diff.old_text,
 560                    diff.new_text,
 561                    language_registry,
 562                    cx,
 563                )
 564            })),
 565        }
 566    }
 567
 568    pub fn update_from_acp(
 569        &mut self,
 570        new: acp::ToolCallContent,
 571        language_registry: Arc<LanguageRegistry>,
 572        cx: &mut App,
 573    ) {
 574        let needs_update = match (&self, &new) {
 575            (Self::Diff(old_diff), acp::ToolCallContent::Diff { diff: new_diff }) => {
 576                old_diff.read(cx).needs_update(
 577                    new_diff.old_text.as_deref().unwrap_or(""),
 578                    &new_diff.new_text,
 579                    cx,
 580                )
 581            }
 582            _ => true,
 583        };
 584
 585        if needs_update {
 586            *self = Self::from_acp(new, language_registry, cx);
 587        }
 588    }
 589
 590    pub fn to_markdown(&self, cx: &App) -> String {
 591        match self {
 592            Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
 593            Self::Diff(diff) => diff.read(cx).to_markdown(cx),
 594            Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
 595        }
 596    }
 597}
 598
/// An update targeting an existing tool call, identified by the id carried in each variant.
#[derive(Debug, PartialEq)]
pub enum ToolCallUpdate {
    /// A protocol-level update.
    UpdateFields(acp::ToolCallUpdate),
    /// An update carrying a diff entity.
    UpdateDiff(ToolCallUpdateDiff),
    /// An update carrying a terminal entity.
    UpdateTerminal(ToolCallUpdateTerminal),
}
 605
 606impl ToolCallUpdate {
 607    fn id(&self) -> &acp::ToolCallId {
 608        match self {
 609            Self::UpdateFields(update) => &update.id,
 610            Self::UpdateDiff(diff) => &diff.id,
 611            Self::UpdateTerminal(terminal) => &terminal.id,
 612        }
 613    }
 614}
 615
 616impl From<acp::ToolCallUpdate> for ToolCallUpdate {
 617    fn from(update: acp::ToolCallUpdate) -> Self {
 618        Self::UpdateFields(update)
 619    }
 620}
 621
 622impl From<ToolCallUpdateDiff> for ToolCallUpdate {
 623    fn from(diff: ToolCallUpdateDiff) -> Self {
 624        Self::UpdateDiff(diff)
 625    }
 626}
 627
/// Payload of [`ToolCallUpdate::UpdateDiff`].
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateDiff {
    /// Id of the tool call the diff belongs to.
    pub id: acp::ToolCallId,
    /// The diff entity carried by the update.
    pub diff: Entity<Diff>,
}
 633
 634impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
 635    fn from(terminal: ToolCallUpdateTerminal) -> Self {
 636        Self::UpdateTerminal(terminal)
 637    }
 638}
 639
/// Payload of [`ToolCallUpdate::UpdateTerminal`].
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateTerminal {
    /// Id of the tool call the terminal belongs to.
    pub id: acp::ToolCallId,
    /// The terminal entity carried by the update.
    pub terminal: Entity<Terminal>,
}
 645
/// The agent's plan: a list of entries. The default plan is empty.
#[derive(Debug, Default)]
pub struct Plan {
    /// The plan's entries.
    pub entries: Vec<PlanEntry>,
}
 650
/// Summary of a [`Plan`], as computed by [`Plan::stats`].
#[derive(Debug)]
pub struct PlanStats<'a> {
    /// The first entry whose status is in progress, if any.
    pub in_progress_entry: Option<&'a PlanEntry>,
    /// Number of entries whose status is pending.
    pub pending: u32,
    /// Number of entries whose status is completed.
    pub completed: u32,
}
 657
 658impl Plan {
 659    pub fn is_empty(&self) -> bool {
 660        self.entries.is_empty()
 661    }
 662
 663    pub fn stats(&self) -> PlanStats<'_> {
 664        let mut stats = PlanStats {
 665            in_progress_entry: None,
 666            pending: 0,
 667            completed: 0,
 668        };
 669
 670        for entry in &self.entries {
 671            match &entry.status {
 672                acp::PlanEntryStatus::Pending => {
 673                    stats.pending += 1;
 674                }
 675                acp::PlanEntryStatus::InProgress => {
 676                    stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
 677                }
 678                acp::PlanEntryStatus::Completed => {
 679                    stats.completed += 1;
 680                }
 681            }
 682        }
 683
 684        stats
 685    }
 686}
 687
/// A single entry of a [`Plan`].
#[derive(Debug)]
pub struct PlanEntry {
    /// The entry's text, held as a markdown entity.
    pub content: Entity<Markdown>,
    /// Priority of the entry, as reported over the protocol.
    pub priority: acp::PlanEntryPriority,
    /// Status of the entry, as reported over the protocol.
    pub status: acp::PlanEntryStatus,
}
 694
 695impl PlanEntry {
 696    pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
 697        Self {
 698            content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
 699            priority: entry.priority,
 700            status: entry.status,
 701        }
 702    }
 703}
 704
/// Token consumption of a thread relative to its limit.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Maximum number of tokens; `0` is treated as "unknown" by [`TokenUsage::ratio`].
    pub max_tokens: u64,
    /// Number of tokens used so far.
    pub used_tokens: u64,
}
 710
 711impl TokenUsage {
 712    pub fn ratio(&self) -> TokenUsageRatio {
 713        #[cfg(debug_assertions)]
 714        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 715            .unwrap_or("0.8".to_string())
 716            .parse()
 717            .unwrap();
 718        #[cfg(not(debug_assertions))]
 719        let warning_threshold: f32 = 0.8;
 720
 721        // When the maximum is unknown because there is no selected model,
 722        // avoid showing the token limit warning.
 723        if self.max_tokens == 0 {
 724            TokenUsageRatio::Normal
 725        } else if self.used_tokens >= self.max_tokens {
 726            TokenUsageRatio::Exceeded
 727        } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
 728            TokenUsageRatio::Warning
 729        } else {
 730            TokenUsageRatio::Normal
 731        }
 732    }
 733}
 734
/// Classification of token usage produced by [`TokenUsage::ratio`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenUsageRatio {
    /// Usage is below the warning threshold, or the maximum is unknown.
    Normal,
    /// Usage has reached the warning threshold but is still below the maximum.
    Warning,
    /// Usage has reached or passed the maximum.
    Exceeded,
}
 741
/// Payload of [`AcpThreadEvent::Retry`].
///
/// NOTE(review): the field descriptions below are inferred from names and types only; the
/// code that fills them in is not part of this section — confirm against the producer.
#[derive(Debug, Clone)]
pub struct RetryStatus {
    /// Presumably the message of the error that triggered the retry.
    pub last_error: SharedString,
    /// Presumably the number of the current attempt.
    pub attempt: usize,
    /// Presumably the maximum number of attempts.
    pub max_attempts: usize,
    /// Presumably the instant at which the wait before this retry started.
    pub started_at: Instant,
    /// Presumably how long to wait before retrying.
    pub duration: Duration,
}
 750
/// State of a single agent session: its entries, plan, and the handles used to reach the
/// agent connection and the project.
pub struct AcpThread {
    /// Title of the thread.
    title: SharedString,
    /// Thread entries; the most recent entry is last.
    entries: Vec<AgentThreadEntry>,
    /// The agent's current plan.
    plan: Plan,
    /// The project this thread operates on.
    project: Entity<Project>,
    /// The action log associated with this thread.
    action_log: Entity<ActionLog>,
    /// Buffers mapped to a snapshot of their contents.
    // NOTE(review): presumably the snapshot taken when the buffer was shared with the agent;
    // the code that populates this map is outside this section — confirm.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// While this is `Some`, `status()` reports the thread as not idle.
    send_task: Option<Task<()>>,
    /// Connection to the agent.
    connection: Rc<dyn AgentConnection>,
    /// Protocol session id of this thread.
    session_id: acp::SessionId,
    /// Latest known token usage, if any.
    token_usage: Option<TokenUsage>,
    /// Most recently observed prompt capabilities.
    prompt_capabilities: acp::PromptCapabilities,
    /// Task that keeps `prompt_capabilities` up to date; held so it stays alive with the
    /// thread.
    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
}
 765
/// Events emitted by [`AcpThread`].
///
/// NOTE(review): only the emission sites of `EntryUpdated` and `PromptCapabilitiesUpdated`
/// are visible in this section; the other descriptions are inferred from the variant names
/// and payload types — confirm against the emitting code.
#[derive(Debug)]
pub enum AcpThreadEvent {
    /// Presumably emitted when an entry is added to the thread.
    NewEntry,
    /// Presumably emitted when the thread's title changes.
    TitleUpdated,
    /// Presumably emitted when the thread's token usage changes.
    TokenUsageUpdated,
    /// The entry at the given index changed.
    EntryUpdated(usize),
    /// Presumably the entries in the given index range were removed.
    EntriesRemoved(Range<usize>),
    /// Presumably a tool call is waiting for the user's authorization.
    ToolAuthorizationRequired,
    /// Carries a [`RetryStatus`].
    Retry(RetryStatus),
    /// Presumably emitted when generation stops.
    Stopped,
    /// Presumably emitted when an error occurs.
    Error,
    /// Carries a [`LoadError`].
    LoadError(LoadError),
    /// The thread's prompt capabilities were replaced with a newly received value.
    PromptCapabilitiesUpdated,
}
 780
// Marks `AcpThread` as an emitter of `AcpThreadEvent`s (emitted through `cx.emit`).
impl EventEmitter<AcpThreadEvent> for AcpThread {}
 782
/// Coarse status of an [`AcpThread`], as reported by `AcpThread::status`.
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    /// No send task is present.
    Idle,
    /// A send task is present and the thread is waiting for a tool confirmation.
    WaitingForToolConfirmation,
    /// A send task is present and the thread is not waiting for a tool confirmation.
    Generating,
}
 789
/// Errors carried by [`AcpThreadEvent::LoadError`].
///
/// The `Display` output is the `error_message` for `NotInstalled` and `Unsupported`, a
/// "Server exited with status ..." line for `Exited`, and the wrapped message for `Other`.
#[derive(Debug, Clone)]
pub enum LoadError {
    /// NOTE(review): presumably the agent is not installed; the install fields are not read
    /// in this section — confirm their use at the call sites.
    NotInstalled {
        /// Message shown by `Display`.
        error_message: SharedString,
        install_message: SharedString,
        install_command: String,
    },
    /// NOTE(review): presumably the installed agent version is not supported; the upgrade
    /// fields are not read in this section — confirm their use at the call sites.
    Unsupported {
        /// Message shown by `Display`.
        error_message: SharedString,
        upgrade_message: SharedString,
        upgrade_command: String,
    },
    /// The server process exited with the given status.
    Exited {
        status: ExitStatus,
    },
    /// Any other error, described by its message.
    Other(SharedString),
}
 807
 808impl Display for LoadError {
 809    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 810        match self {
 811            LoadError::NotInstalled { error_message, .. }
 812            | LoadError::Unsupported { error_message, .. } => {
 813                write!(f, "{error_message}")
 814            }
 815            LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
 816            LoadError::Other(msg) => write!(f, "{}", msg),
 817        }
 818    }
 819}
 820
// `LoadError` has no underlying source error, so the trait's default methods are sufficient.
impl Error for LoadError {}
 822
 823impl AcpThread {
 824    pub fn new(
 825        title: impl Into<SharedString>,
 826        connection: Rc<dyn AgentConnection>,
 827        project: Entity<Project>,
 828        action_log: Entity<ActionLog>,
 829        session_id: acp::SessionId,
 830        mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
 831        cx: &mut Context<Self>,
 832    ) -> Self {
 833        let prompt_capabilities = *prompt_capabilities_rx.borrow();
 834        let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
 835            loop {
 836                let caps = prompt_capabilities_rx.recv().await?;
 837                this.update(cx, |this, cx| {
 838                    this.prompt_capabilities = caps;
 839                    cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
 840                })?;
 841            }
 842        });
 843
 844        Self {
 845            action_log,
 846            shared_buffers: Default::default(),
 847            entries: Default::default(),
 848            plan: Default::default(),
 849            title: title.into(),
 850            project,
 851            send_task: None,
 852            connection,
 853            session_id,
 854            token_usage: None,
 855            prompt_capabilities,
 856            _observe_prompt_capabilities: task,
 857        }
 858    }
 859
    /// Returns a copy of the most recently observed prompt capabilities.
    pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
        self.prompt_capabilities
    }
 863
    /// Returns the agent connection used by this thread.
    pub fn connection(&self) -> &Rc<dyn AgentConnection> {
        &self.connection
    }
 867
    /// Returns the action log associated with this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
 871
    /// Returns the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
 875
    /// Returns a clone of the thread's title.
    pub fn title(&self) -> SharedString {
        self.title.clone()
    }
 879
    /// Returns the thread's entries; the most recent entry is last.
    pub fn entries(&self) -> &[AgentThreadEntry] {
        &self.entries
    }
 883
    /// Returns the protocol session id of this thread.
    pub fn session_id(&self) -> &acp::SessionId {
        &self.session_id
    }
 887
 888    pub fn status(&self) -> ThreadStatus {
 889        if self.send_task.is_some() {
 890            if self.waiting_for_tool_confirmation() {
 891                ThreadStatus::WaitingForToolConfirmation
 892            } else {
 893                ThreadStatus::Generating
 894            }
 895        } else {
 896            ThreadStatus::Idle
 897        }
 898    }
 899
    /// Returns the latest known token usage, if any.
    pub fn token_usage(&self) -> Option<&TokenUsage> {
        self.token_usage.as_ref()
    }
 903
 904    pub fn has_pending_edit_tool_calls(&self) -> bool {
 905        for entry in self.entries.iter().rev() {
 906            match entry {
 907                AgentThreadEntry::UserMessage(_) => return false,
 908                AgentThreadEntry::ToolCall(
 909                    call @ ToolCall {
 910                        status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
 911                        ..
 912                    },
 913                ) if call.diffs().next().is_some() => {
 914                    return true;
 915                }
 916                AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
 917            }
 918        }
 919
 920        false
 921    }
 922
 923    pub fn used_tools_since_last_user_message(&self) -> bool {
 924        for entry in self.entries.iter().rev() {
 925            match entry {
 926                AgentThreadEntry::UserMessage(..) => return false,
 927                AgentThreadEntry::AssistantMessage(..) => continue,
 928                AgentThreadEntry::ToolCall(..) => return true,
 929            }
 930        }
 931
 932        false
 933    }
 934
 935    pub fn handle_session_update(
 936        &mut self,
 937        update: acp::SessionUpdate,
 938        cx: &mut Context<Self>,
 939    ) -> Result<(), acp::Error> {
 940        match update {
 941            acp::SessionUpdate::UserMessageChunk { content } => {
 942                self.push_user_content_block(None, content, cx);
 943            }
 944            acp::SessionUpdate::AgentMessageChunk { content } => {
 945                self.push_assistant_content_block(content, false, cx);
 946            }
 947            acp::SessionUpdate::AgentThoughtChunk { content } => {
 948                self.push_assistant_content_block(content, true, cx);
 949            }
 950            acp::SessionUpdate::ToolCall(tool_call) => {
 951                self.upsert_tool_call(tool_call, cx)?;
 952            }
 953            acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
 954                self.update_tool_call(tool_call_update, cx)?;
 955            }
 956            acp::SessionUpdate::Plan(plan) => {
 957                self.update_plan(plan, cx);
 958            }
 959        }
 960        Ok(())
 961    }
 962
 963    pub fn push_user_content_block(
 964        &mut self,
 965        message_id: Option<UserMessageId>,
 966        chunk: acp::ContentBlock,
 967        cx: &mut Context<Self>,
 968    ) {
 969        let language_registry = self.project.read(cx).languages().clone();
 970        let entries_len = self.entries.len();
 971
 972        if let Some(last_entry) = self.entries.last_mut()
 973            && let AgentThreadEntry::UserMessage(UserMessage {
 974                id,
 975                content,
 976                chunks,
 977                ..
 978            }) = last_entry
 979        {
 980            *id = message_id.or(id.take());
 981            content.append(chunk.clone(), &language_registry, cx);
 982            chunks.push(chunk);
 983            let idx = entries_len - 1;
 984            cx.emit(AcpThreadEvent::EntryUpdated(idx));
 985        } else {
 986            let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
 987            self.push_entry(
 988                AgentThreadEntry::UserMessage(UserMessage {
 989                    id: message_id,
 990                    content,
 991                    chunks: vec![chunk],
 992                    checkpoint: None,
 993                }),
 994                cx,
 995            );
 996        }
 997    }
 998
 999    pub fn push_assistant_content_block(
1000        &mut self,
1001        chunk: acp::ContentBlock,
1002        is_thought: bool,
1003        cx: &mut Context<Self>,
1004    ) {
1005        let language_registry = self.project.read(cx).languages().clone();
1006        let entries_len = self.entries.len();
1007        if let Some(last_entry) = self.entries.last_mut()
1008            && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
1009        {
1010            let idx = entries_len - 1;
1011            cx.emit(AcpThreadEvent::EntryUpdated(idx));
1012            match (chunks.last_mut(), is_thought) {
1013                (Some(AssistantMessageChunk::Message { block }), false)
1014                | (Some(AssistantMessageChunk::Thought { block }), true) => {
1015                    block.append(chunk, &language_registry, cx)
1016                }
1017                _ => {
1018                    let block = ContentBlock::new(chunk, &language_registry, cx);
1019                    if is_thought {
1020                        chunks.push(AssistantMessageChunk::Thought { block })
1021                    } else {
1022                        chunks.push(AssistantMessageChunk::Message { block })
1023                    }
1024                }
1025            }
1026        } else {
1027            let block = ContentBlock::new(chunk, &language_registry, cx);
1028            let chunk = if is_thought {
1029                AssistantMessageChunk::Thought { block }
1030            } else {
1031                AssistantMessageChunk::Message { block }
1032            };
1033
1034            self.push_entry(
1035                AgentThreadEntry::AssistantMessage(AssistantMessage {
1036                    chunks: vec![chunk],
1037                }),
1038                cx,
1039            );
1040        }
1041    }
1042
1043    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1044        self.entries.push(entry);
1045        cx.emit(AcpThreadEvent::NewEntry);
1046    }
1047
1048    pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1049        self.connection.set_title(&self.session_id, cx).is_some()
1050    }
1051
1052    pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1053        if title != self.title {
1054            self.title = title.clone();
1055            cx.emit(AcpThreadEvent::TitleUpdated);
1056            if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1057                return set_title.run(title, cx);
1058            }
1059        }
1060        Task::ready(Ok(()))
1061    }
1062
1063    pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1064        self.token_usage = usage;
1065        cx.emit(AcpThreadEvent::TokenUsageUpdated);
1066    }
1067
1068    pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1069        cx.emit(AcpThreadEvent::Retry(status));
1070    }
1071
1072    pub fn update_tool_call(
1073        &mut self,
1074        update: impl Into<ToolCallUpdate>,
1075        cx: &mut Context<Self>,
1076    ) -> Result<()> {
1077        let update = update.into();
1078        let languages = self.project.read(cx).languages().clone();
1079
1080        let (ix, current_call) = self
1081            .tool_call_mut(update.id())
1082            .context("Tool call not found")?;
1083        match update {
1084            ToolCallUpdate::UpdateFields(update) => {
1085                let location_updated = update.fields.locations.is_some();
1086                current_call.update_fields(update.fields, languages, cx);
1087                if location_updated {
1088                    self.resolve_locations(update.id, cx);
1089                }
1090            }
1091            ToolCallUpdate::UpdateDiff(update) => {
1092                current_call.content.clear();
1093                current_call
1094                    .content
1095                    .push(ToolCallContent::Diff(update.diff));
1096            }
1097            ToolCallUpdate::UpdateTerminal(update) => {
1098                current_call.content.clear();
1099                current_call
1100                    .content
1101                    .push(ToolCallContent::Terminal(update.terminal));
1102            }
1103        }
1104
1105        cx.emit(AcpThreadEvent::EntryUpdated(ix));
1106
1107        Ok(())
1108    }
1109
1110    /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1111    pub fn upsert_tool_call(
1112        &mut self,
1113        tool_call: acp::ToolCall,
1114        cx: &mut Context<Self>,
1115    ) -> Result<(), acp::Error> {
1116        let status = tool_call.status.into();
1117        self.upsert_tool_call_inner(tool_call.into(), status, cx)
1118    }
1119
1120    /// Fails if id does not match an existing entry.
1121    pub fn upsert_tool_call_inner(
1122        &mut self,
1123        tool_call_update: acp::ToolCallUpdate,
1124        status: ToolCallStatus,
1125        cx: &mut Context<Self>,
1126    ) -> Result<(), acp::Error> {
1127        let language_registry = self.project.read(cx).languages().clone();
1128        let id = tool_call_update.id.clone();
1129
1130        if let Some((ix, current_call)) = self.tool_call_mut(&id) {
1131            current_call.update_fields(tool_call_update.fields, language_registry, cx);
1132            current_call.status = status;
1133
1134            cx.emit(AcpThreadEvent::EntryUpdated(ix));
1135        } else {
1136            let call =
1137                ToolCall::from_acp(tool_call_update.try_into()?, status, language_registry, cx);
1138            self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1139        };
1140
1141        self.resolve_locations(id, cx);
1142        Ok(())
1143    }
1144
1145    fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1146        // The tool call we are looking for is typically the last one, or very close to the end.
1147        // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1148        self.entries
1149            .iter_mut()
1150            .enumerate()
1151            .rev()
1152            .find_map(|(index, tool_call)| {
1153                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1154                    && &tool_call.id == id
1155                {
1156                    Some((index, tool_call))
1157                } else {
1158                    None
1159                }
1160            })
1161    }
1162
1163    pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1164        self.entries
1165            .iter()
1166            .enumerate()
1167            .rev()
1168            .find_map(|(index, tool_call)| {
1169                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1170                    && &tool_call.id == id
1171                {
1172                    Some((index, tool_call))
1173                } else {
1174                    None
1175                }
1176            })
1177    }
1178
1179    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1180        let project = self.project.clone();
1181        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1182            return;
1183        };
1184        let task = tool_call.resolve_locations(project, cx);
1185        cx.spawn(async move |this, cx| {
1186            let resolved_locations = task.await;
1187            this.update(cx, |this, cx| {
1188                let project = this.project.clone();
1189                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1190                    return;
1191                };
1192                if let Some(Some(location)) = resolved_locations.last() {
1193                    project.update(cx, |project, cx| {
1194                        if let Some(agent_location) = project.agent_location() {
1195                            let should_ignore = agent_location.buffer == location.buffer
1196                                && location
1197                                    .buffer
1198                                    .update(cx, |buffer, _| {
1199                                        let snapshot = buffer.snapshot();
1200                                        let old_position =
1201                                            agent_location.position.to_point(&snapshot);
1202                                        let new_position = location.position.to_point(&snapshot);
1203                                        // ignore this so that when we get updates from the edit tool
1204                                        // the position doesn't reset to the startof line
1205                                        old_position.row == new_position.row
1206                                            && old_position.column > new_position.column
1207                                    })
1208                                    .ok()
1209                                    .unwrap_or_default();
1210                            if !should_ignore {
1211                                project.set_agent_location(Some(location.clone()), cx);
1212                            }
1213                        }
1214                    });
1215                }
1216                if tool_call.resolved_locations != resolved_locations {
1217                    tool_call.resolved_locations = resolved_locations;
1218                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
1219                }
1220            })
1221        })
1222        .detach();
1223    }
1224
1225    pub fn request_tool_call_authorization(
1226        &mut self,
1227        tool_call: acp::ToolCallUpdate,
1228        options: Vec<acp::PermissionOption>,
1229        cx: &mut Context<Self>,
1230    ) -> Result<oneshot::Receiver<acp::PermissionOptionId>, acp::Error> {
1231        let (tx, rx) = oneshot::channel();
1232
1233        let status = ToolCallStatus::WaitingForConfirmation {
1234            options,
1235            respond_tx: tx,
1236        };
1237
1238        self.upsert_tool_call_inner(tool_call, status, cx)?;
1239        cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1240        Ok(rx)
1241    }
1242
1243    pub fn authorize_tool_call(
1244        &mut self,
1245        id: acp::ToolCallId,
1246        option_id: acp::PermissionOptionId,
1247        option_kind: acp::PermissionOptionKind,
1248        cx: &mut Context<Self>,
1249    ) {
1250        let Some((ix, call)) = self.tool_call_mut(&id) else {
1251            return;
1252        };
1253
1254        let new_status = match option_kind {
1255            acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1256                ToolCallStatus::Rejected
1257            }
1258            acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1259                ToolCallStatus::InProgress
1260            }
1261        };
1262
1263        let curr_status = mem::replace(&mut call.status, new_status);
1264
1265        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1266            respond_tx.send(option_id).log_err();
1267        } else if cfg!(debug_assertions) {
1268            panic!("tried to authorize an already authorized tool call");
1269        }
1270
1271        cx.emit(AcpThreadEvent::EntryUpdated(ix));
1272    }
1273
1274    /// Returns true if the last turn is awaiting tool authorization
1275    pub fn waiting_for_tool_confirmation(&self) -> bool {
1276        for entry in self.entries.iter().rev() {
1277            match &entry {
1278                AgentThreadEntry::ToolCall(call) => match call.status {
1279                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
1280                    ToolCallStatus::Pending
1281                    | ToolCallStatus::InProgress
1282                    | ToolCallStatus::Completed
1283                    | ToolCallStatus::Failed
1284                    | ToolCallStatus::Rejected
1285                    | ToolCallStatus::Canceled => continue,
1286                },
1287                AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1288                    // Reached the beginning of the turn
1289                    return false;
1290                }
1291            }
1292        }
1293        false
1294    }
1295
1296    pub fn plan(&self) -> &Plan {
1297        &self.plan
1298    }
1299
1300    pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1301        let new_entries_len = request.entries.len();
1302        let mut new_entries = request.entries.into_iter();
1303
1304        // Reuse existing markdown to prevent flickering
1305        for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1306            let PlanEntry {
1307                content,
1308                priority,
1309                status,
1310            } = old;
1311            content.update(cx, |old, cx| {
1312                old.replace(new.content, cx);
1313            });
1314            *priority = new.priority;
1315            *status = new.status;
1316        }
1317        for new in new_entries {
1318            self.plan.entries.push(PlanEntry::from_acp(new, cx))
1319        }
1320        self.plan.entries.truncate(new_entries_len);
1321
1322        cx.notify();
1323    }
1324
1325    fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1326        self.plan
1327            .entries
1328            .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1329        cx.notify();
1330    }
1331
1332    #[cfg(any(test, feature = "test-support"))]
1333    pub fn send_raw(
1334        &mut self,
1335        message: &str,
1336        cx: &mut Context<Self>,
1337    ) -> BoxFuture<'static, Result<()>> {
1338        self.send(
1339            vec![acp::ContentBlock::Text(acp::TextContent {
1340                text: message.to_string(),
1341                annotations: None,
1342            })],
1343            cx,
1344        )
1345    }
1346
1347    pub fn send(
1348        &mut self,
1349        message: Vec<acp::ContentBlock>,
1350        cx: &mut Context<Self>,
1351    ) -> BoxFuture<'static, Result<()>> {
1352        let block = ContentBlock::new_combined(
1353            message.clone(),
1354            self.project.read(cx).languages().clone(),
1355            cx,
1356        );
1357        let request = acp::PromptRequest {
1358            prompt: message.clone(),
1359            session_id: self.session_id.clone(),
1360        };
1361        let git_store = self.project.read(cx).git_store().clone();
1362
1363        let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
1364            Some(UserMessageId::new())
1365        } else {
1366            None
1367        };
1368
1369        self.run_turn(cx, async move |this, cx| {
1370            this.update(cx, |this, cx| {
1371                this.push_entry(
1372                    AgentThreadEntry::UserMessage(UserMessage {
1373                        id: message_id.clone(),
1374                        content: block,
1375                        chunks: message,
1376                        checkpoint: None,
1377                    }),
1378                    cx,
1379                );
1380            })
1381            .ok();
1382
1383            let old_checkpoint = git_store
1384                .update(cx, |git, cx| git.checkpoint(cx))?
1385                .await
1386                .context("failed to get old checkpoint")
1387                .log_err();
1388            this.update(cx, |this, cx| {
1389                if let Some((_ix, message)) = this.last_user_message() {
1390                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
1391                        git_checkpoint,
1392                        show: false,
1393                    });
1394                }
1395                this.connection.prompt(message_id, request, cx)
1396            })?
1397            .await
1398        })
1399    }
1400
1401    pub fn can_resume(&self, cx: &App) -> bool {
1402        self.connection.resume(&self.session_id, cx).is_some()
1403    }
1404
1405    pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1406        self.run_turn(cx, async move |this, cx| {
1407            this.update(cx, |this, cx| {
1408                this.connection
1409                    .resume(&this.session_id, cx)
1410                    .map(|resume| resume.run(cx))
1411            })?
1412            .context("resuming a session is not supported")?
1413            .await
1414        })
1415    }
1416
1417    fn run_turn(
1418        &mut self,
1419        cx: &mut Context<Self>,
1420        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
1421    ) -> BoxFuture<'static, Result<()>> {
1422        self.clear_completed_plan_entries(cx);
1423
1424        let (tx, rx) = oneshot::channel();
1425        let cancel_task = self.cancel(cx);
1426
1427        self.send_task = Some(cx.spawn(async move |this, cx| {
1428            cancel_task.await;
1429            tx.send(f(this, cx).await).ok();
1430        }));
1431
1432        cx.spawn(async move |this, cx| {
1433            let response = rx.await;
1434
1435            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
1436                .await?;
1437
1438            this.update(cx, |this, cx| {
1439                this.project
1440                    .update(cx, |project, cx| project.set_agent_location(None, cx));
1441                match response {
1442                    Ok(Err(e)) => {
1443                        this.send_task.take();
1444                        cx.emit(AcpThreadEvent::Error);
1445                        Err(e)
1446                    }
1447                    result => {
1448                        let canceled = matches!(
1449                            result,
1450                            Ok(Ok(acp::PromptResponse {
1451                                stop_reason: acp::StopReason::Cancelled
1452                            }))
1453                        );
1454
1455                        // We only take the task if the current prompt wasn't canceled.
1456                        //
1457                        // This prompt may have been canceled because another one was sent
1458                        // while it was still generating. In these cases, dropping `send_task`
1459                        // would cause the next generation to be canceled.
1460                        if !canceled {
1461                            this.send_task.take();
1462                        }
1463
1464                        // Truncate entries if the last prompt was refused.
1465                        if let Ok(Ok(acp::PromptResponse {
1466                            stop_reason: acp::StopReason::Refusal,
1467                        })) = result
1468                            && let Some((ix, _)) = this.last_user_message()
1469                        {
1470                            let range = ix..this.entries.len();
1471                            this.entries.truncate(ix);
1472                            cx.emit(AcpThreadEvent::EntriesRemoved(range));
1473                        }
1474
1475                        cx.emit(AcpThreadEvent::Stopped);
1476                        Ok(())
1477                    }
1478                }
1479            })?
1480        })
1481        .boxed()
1482    }
1483
1484    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1485        let Some(send_task) = self.send_task.take() else {
1486            return Task::ready(());
1487        };
1488
1489        for entry in self.entries.iter_mut() {
1490            if let AgentThreadEntry::ToolCall(call) = entry {
1491                let cancel = matches!(
1492                    call.status,
1493                    ToolCallStatus::Pending
1494                        | ToolCallStatus::WaitingForConfirmation { .. }
1495                        | ToolCallStatus::InProgress
1496                );
1497
1498                if cancel {
1499                    call.status = ToolCallStatus::Canceled;
1500                }
1501            }
1502        }
1503
1504        self.connection.cancel(&self.session_id, cx);
1505
1506        // Wait for the send task to complete
1507        cx.foreground_executor().spawn(send_task)
1508    }
1509
1510    /// Rewinds this thread to before the entry at `index`, removing it and all
1511    /// subsequent entries while reverting any changes made from that point.
1512    pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1513        let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
1514            return Task::ready(Err(anyhow!("not supported")));
1515        };
1516        let Some(message) = self.user_message(&id) else {
1517            return Task::ready(Err(anyhow!("message not found")));
1518        };
1519
1520        let checkpoint = message
1521            .checkpoint
1522            .as_ref()
1523            .map(|c| c.git_checkpoint.clone());
1524
1525        let git_store = self.project.read(cx).git_store().clone();
1526        cx.spawn(async move |this, cx| {
1527            if let Some(checkpoint) = checkpoint {
1528                git_store
1529                    .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1530                    .await?;
1531            }
1532
1533            cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
1534            this.update(cx, |this, cx| {
1535                if let Some((ix, _)) = this.user_message_mut(&id) {
1536                    let range = ix..this.entries.len();
1537                    this.entries.truncate(ix);
1538                    cx.emit(AcpThreadEvent::EntriesRemoved(range));
1539                }
1540            })
1541        })
1542    }
1543
1544    fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1545        let git_store = self.project.read(cx).git_store().clone();
1546
1547        let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1548            if let Some(checkpoint) = message.checkpoint.as_ref() {
1549                checkpoint.git_checkpoint.clone()
1550            } else {
1551                return Task::ready(Ok(()));
1552            }
1553        } else {
1554            return Task::ready(Ok(()));
1555        };
1556
1557        let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1558        cx.spawn(async move |this, cx| {
1559            let new_checkpoint = new_checkpoint
1560                .await
1561                .context("failed to get new checkpoint")
1562                .log_err();
1563            if let Some(new_checkpoint) = new_checkpoint {
1564                let equal = git_store
1565                    .update(cx, |git, cx| {
1566                        git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1567                    })?
1568                    .await
1569                    .unwrap_or(true);
1570                this.update(cx, |this, cx| {
1571                    let (ix, message) = this.last_user_message().context("no user message")?;
1572                    let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1573                    checkpoint.show = !equal;
1574                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
1575                    anyhow::Ok(())
1576                })??;
1577            }
1578
1579            Ok(())
1580        })
1581    }
1582
1583    fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1584        self.entries
1585            .iter_mut()
1586            .enumerate()
1587            .rev()
1588            .find_map(|(ix, entry)| {
1589                if let AgentThreadEntry::UserMessage(message) = entry {
1590                    Some((ix, message))
1591                } else {
1592                    None
1593                }
1594            })
1595    }
1596
1597    fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
1598        self.entries.iter().find_map(|entry| {
1599            if let AgentThreadEntry::UserMessage(message) = entry {
1600                if message.id.as_ref() == Some(id) {
1601                    Some(message)
1602                } else {
1603                    None
1604                }
1605            } else {
1606                None
1607            }
1608        })
1609    }
1610
1611    fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1612        self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1613            if let AgentThreadEntry::UserMessage(message) = entry {
1614                if message.id.as_ref() == Some(id) {
1615                    Some((ix, message))
1616                } else {
1617                    None
1618                }
1619            } else {
1620                None
1621            }
1622        })
1623    }
1624
1625    pub fn read_text_file(
1626        &self,
1627        path: PathBuf,
1628        line: Option<u32>,
1629        limit: Option<u32>,
1630        reuse_shared_snapshot: bool,
1631        cx: &mut Context<Self>,
1632    ) -> Task<Result<String>> {
1633        let project = self.project.clone();
1634        let action_log = self.action_log.clone();
1635        cx.spawn(async move |this, cx| {
1636            let load = project.update(cx, |project, cx| {
1637                let path = project
1638                    .project_path_for_absolute_path(&path, cx)
1639                    .context("invalid path")?;
1640                anyhow::Ok(project.open_buffer(path, cx))
1641            });
1642            let buffer = load??.await?;
1643
1644            let snapshot = if reuse_shared_snapshot {
1645                this.read_with(cx, |this, _| {
1646                    this.shared_buffers.get(&buffer.clone()).cloned()
1647                })
1648                .log_err()
1649                .flatten()
1650            } else {
1651                None
1652            };
1653
1654            let snapshot = if let Some(snapshot) = snapshot {
1655                snapshot
1656            } else {
1657                action_log.update(cx, |action_log, cx| {
1658                    action_log.buffer_read(buffer.clone(), cx);
1659                })?;
1660                project.update(cx, |project, cx| {
1661                    let position = buffer
1662                        .read(cx)
1663                        .snapshot()
1664                        .anchor_before(Point::new(line.unwrap_or_default(), 0));
1665                    project.set_agent_location(
1666                        Some(AgentLocation {
1667                            buffer: buffer.downgrade(),
1668                            position,
1669                        }),
1670                        cx,
1671                    );
1672                })?;
1673
1674                buffer.update(cx, |buffer, _| buffer.snapshot())?
1675            };
1676
1677            this.update(cx, |this, _| {
1678                let text = snapshot.text();
1679                this.shared_buffers.insert(buffer.clone(), snapshot);
1680                if line.is_none() && limit.is_none() {
1681                    return Ok(text);
1682                }
1683                let limit = limit.unwrap_or(u32::MAX) as usize;
1684                let Some(line) = line else {
1685                    return Ok(text.lines().take(limit).collect::<String>());
1686                };
1687
1688                let count = text.lines().count();
1689                if count < line as usize {
1690                    anyhow::bail!("There are only {} lines", count);
1691                }
1692                Ok(text
1693                    .lines()
1694                    .skip(line as usize + 1)
1695                    .take(limit)
1696                    .collect::<String>())
1697            })?
1698        })
1699    }
1700
1701    pub fn write_text_file(
1702        &self,
1703        path: PathBuf,
1704        content: String,
1705        cx: &mut Context<Self>,
1706    ) -> Task<Result<()>> {
1707        let project = self.project.clone();
1708        let action_log = self.action_log.clone();
1709        cx.spawn(async move |this, cx| {
1710            let load = project.update(cx, |project, cx| {
1711                let path = project
1712                    .project_path_for_absolute_path(&path, cx)
1713                    .context("invalid path")?;
1714                anyhow::Ok(project.open_buffer(path, cx))
1715            });
1716            let buffer = load??.await?;
1717            let snapshot = this.update(cx, |this, cx| {
1718                this.shared_buffers
1719                    .get(&buffer)
1720                    .cloned()
1721                    .unwrap_or_else(|| buffer.read(cx).snapshot())
1722            })?;
1723            let edits = cx
1724                .background_executor()
1725                .spawn(async move {
1726                    let old_text = snapshot.text();
1727                    text_diff(old_text.as_str(), &content)
1728                        .into_iter()
1729                        .map(|(range, replacement)| {
1730                            (
1731                                snapshot.anchor_after(range.start)
1732                                    ..snapshot.anchor_before(range.end),
1733                                replacement,
1734                            )
1735                        })
1736                        .collect::<Vec<_>>()
1737                })
1738                .await;
1739
1740            project.update(cx, |project, cx| {
1741                project.set_agent_location(
1742                    Some(AgentLocation {
1743                        buffer: buffer.downgrade(),
1744                        position: edits
1745                            .last()
1746                            .map(|(range, _)| range.end)
1747                            .unwrap_or(Anchor::MIN),
1748                    }),
1749                    cx,
1750                );
1751            })?;
1752
1753            let format_on_save = cx.update(|cx| {
1754                action_log.update(cx, |action_log, cx| {
1755                    action_log.buffer_read(buffer.clone(), cx);
1756                });
1757
1758                let format_on_save = buffer.update(cx, |buffer, cx| {
1759                    buffer.edit(edits, None, cx);
1760
1761                    let settings = language::language_settings::language_settings(
1762                        buffer.language().map(|l| l.name()),
1763                        buffer.file(),
1764                        cx,
1765                    );
1766
1767                    settings.format_on_save != FormatOnSave::Off
1768                });
1769                action_log.update(cx, |action_log, cx| {
1770                    action_log.buffer_edited(buffer.clone(), cx);
1771                });
1772                format_on_save
1773            })?;
1774
1775            if format_on_save {
1776                let format_task = project.update(cx, |project, cx| {
1777                    project.format(
1778                        HashSet::from_iter([buffer.clone()]),
1779                        LspFormatTarget::Buffers,
1780                        false,
1781                        FormatTrigger::Save,
1782                        cx,
1783                    )
1784                })?;
1785                format_task.await.log_err();
1786
1787                action_log.update(cx, |action_log, cx| {
1788                    action_log.buffer_edited(buffer.clone(), cx);
1789                })?;
1790            }
1791
1792            project
1793                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1794                .await
1795        })
1796    }
1797
1798    pub fn to_markdown(&self, cx: &App) -> String {
1799        self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1800    }
1801
1802    pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
1803        cx.emit(AcpThreadEvent::LoadError(error));
1804    }
1805}
1806
1807fn markdown_for_raw_output(
1808    raw_output: &serde_json::Value,
1809    language_registry: &Arc<LanguageRegistry>,
1810    cx: &mut App,
1811) -> Option<Entity<Markdown>> {
1812    match raw_output {
1813        serde_json::Value::Null => None,
1814        serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1815            Markdown::new(
1816                value.to_string().into(),
1817                Some(language_registry.clone()),
1818                None,
1819                cx,
1820            )
1821        })),
1822        serde_json::Value::Number(value) => Some(cx.new(|cx| {
1823            Markdown::new(
1824                value.to_string().into(),
1825                Some(language_registry.clone()),
1826                None,
1827                cx,
1828            )
1829        })),
1830        serde_json::Value::String(value) => Some(cx.new(|cx| {
1831            Markdown::new(
1832                value.clone().into(),
1833                Some(language_registry.clone()),
1834                None,
1835                cx,
1836            )
1837        })),
1838        value => Some(cx.new(|cx| {
1839            Markdown::new(
1840                format!("```json\n{}\n```", value).into(),
1841                Some(language_registry.clone()),
1842                None,
1843                cx,
1844            )
1845        })),
1846    }
1847}
1848
1849#[cfg(test)]
1850mod tests {
1851    use super::*;
1852    use anyhow::anyhow;
1853    use futures::{channel::mpsc, future::LocalBoxFuture, select};
1854    use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
1855    use indoc::indoc;
1856    use project::{FakeFs, Fs};
1857    use rand::Rng as _;
1858    use serde_json::json;
1859    use settings::SettingsStore;
1860    use smol::stream::StreamExt as _;
1861    use std::{
1862        any::Any,
1863        cell::RefCell,
1864        path::Path,
1865        rc::Rc,
1866        sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
1867        time::Duration,
1868    };
1869    use util::path;
1870
1871    fn init_test(cx: &mut TestAppContext) {
1872        env_logger::try_init().ok();
1873        cx.update(|cx| {
1874            let settings_store = SettingsStore::test(cx);
1875            cx.set_global(settings_store);
1876            Project::init_settings(cx);
1877            language::init(cx);
1878        });
1879    }
1880
1881    #[gpui::test]
1882    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1883        init_test(cx);
1884
1885        let fs = FakeFs::new(cx.executor());
1886        let project = Project::test(fs, [], cx).await;
1887        let connection = Rc::new(FakeAgentConnection::new());
1888        let thread = cx
1889            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
1890            .await
1891            .unwrap();
1892
1893        // Test creating a new user message
1894        thread.update(cx, |thread, cx| {
1895            thread.push_user_content_block(
1896                None,
1897                acp::ContentBlock::Text(acp::TextContent {
1898                    annotations: None,
1899                    text: "Hello, ".to_string(),
1900                }),
1901                cx,
1902            );
1903        });
1904
1905        thread.update(cx, |thread, cx| {
1906            assert_eq!(thread.entries.len(), 1);
1907            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1908                assert_eq!(user_msg.id, None);
1909                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1910            } else {
1911                panic!("Expected UserMessage");
1912            }
1913        });
1914
1915        // Test appending to existing user message
1916        let message_1_id = UserMessageId::new();
1917        thread.update(cx, |thread, cx| {
1918            thread.push_user_content_block(
1919                Some(message_1_id.clone()),
1920                acp::ContentBlock::Text(acp::TextContent {
1921                    annotations: None,
1922                    text: "world!".to_string(),
1923                }),
1924                cx,
1925            );
1926        });
1927
1928        thread.update(cx, |thread, cx| {
1929            assert_eq!(thread.entries.len(), 1);
1930            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1931                assert_eq!(user_msg.id, Some(message_1_id));
1932                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1933            } else {
1934                panic!("Expected UserMessage");
1935            }
1936        });
1937
1938        // Test creating new user message after assistant message
1939        thread.update(cx, |thread, cx| {
1940            thread.push_assistant_content_block(
1941                acp::ContentBlock::Text(acp::TextContent {
1942                    annotations: None,
1943                    text: "Assistant response".to_string(),
1944                }),
1945                false,
1946                cx,
1947            );
1948        });
1949
1950        let message_2_id = UserMessageId::new();
1951        thread.update(cx, |thread, cx| {
1952            thread.push_user_content_block(
1953                Some(message_2_id.clone()),
1954                acp::ContentBlock::Text(acp::TextContent {
1955                    annotations: None,
1956                    text: "New user message".to_string(),
1957                }),
1958                cx,
1959            );
1960        });
1961
1962        thread.update(cx, |thread, cx| {
1963            assert_eq!(thread.entries.len(), 3);
1964            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1965                assert_eq!(user_msg.id, Some(message_2_id));
1966                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1967            } else {
1968                panic!("Expected UserMessage at index 2");
1969            }
1970        });
1971    }
1972
1973    #[gpui::test]
1974    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
1975        init_test(cx);
1976
1977        let fs = FakeFs::new(cx.executor());
1978        let project = Project::test(fs, [], cx).await;
1979        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1980            |_, thread, mut cx| {
1981                async move {
1982                    thread.update(&mut cx, |thread, cx| {
1983                        thread
1984                            .handle_session_update(
1985                                acp::SessionUpdate::AgentThoughtChunk {
1986                                    content: "Thinking ".into(),
1987                                },
1988                                cx,
1989                            )
1990                            .unwrap();
1991                        thread
1992                            .handle_session_update(
1993                                acp::SessionUpdate::AgentThoughtChunk {
1994                                    content: "hard!".into(),
1995                                },
1996                                cx,
1997                            )
1998                            .unwrap();
1999                    })?;
2000                    Ok(acp::PromptResponse {
2001                        stop_reason: acp::StopReason::EndTurn,
2002                    })
2003                }
2004                .boxed_local()
2005            },
2006        ));
2007
2008        let thread = cx
2009            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2010            .await
2011            .unwrap();
2012
2013        thread
2014            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
2015            .await
2016            .unwrap();
2017
2018        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
2019        assert_eq!(
2020            output,
2021            indoc! {r#"
2022            ## User
2023
2024            Hello from Zed!
2025
2026            ## Assistant
2027
2028            <thinking>
2029            Thinking hard!
2030            </thinking>
2031
2032            "#}
2033        );
2034    }
2035
2036    #[gpui::test]
2037    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
2038        init_test(cx);
2039
2040        let fs = FakeFs::new(cx.executor());
2041        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
2042            .await;
2043        let project = Project::test(fs.clone(), [], cx).await;
2044        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
2045        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
2046        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2047            move |_, thread, mut cx| {
2048                let read_file_tx = read_file_tx.clone();
2049                async move {
2050                    let content = thread
2051                        .update(&mut cx, |thread, cx| {
2052                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2053                        })
2054                        .unwrap()
2055                        .await
2056                        .unwrap();
2057                    assert_eq!(content, "one\ntwo\nthree\n");
2058                    read_file_tx.take().unwrap().send(()).unwrap();
2059                    thread
2060                        .update(&mut cx, |thread, cx| {
2061                            thread.write_text_file(
2062                                path!("/tmp/foo").into(),
2063                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
2064                                cx,
2065                            )
2066                        })
2067                        .unwrap()
2068                        .await
2069                        .unwrap();
2070                    Ok(acp::PromptResponse {
2071                        stop_reason: acp::StopReason::EndTurn,
2072                    })
2073                }
2074                .boxed_local()
2075            },
2076        ));
2077
2078        let (worktree, pathbuf) = project
2079            .update(cx, |project, cx| {
2080                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2081            })
2082            .await
2083            .unwrap();
2084        let buffer = project
2085            .update(cx, |project, cx| {
2086                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
2087            })
2088            .await
2089            .unwrap();
2090
2091        let thread = cx
2092            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2093            .await
2094            .unwrap();
2095
2096        let request = thread.update(cx, |thread, cx| {
2097            thread.send_raw("Extend the count in /tmp/foo", cx)
2098        });
2099        read_file_rx.await.ok();
2100        buffer.update(cx, |buffer, cx| {
2101            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
2102        });
2103        cx.run_until_parked();
2104        assert_eq!(
2105            buffer.read_with(cx, |buffer, _| buffer.text()),
2106            "zero\none\ntwo\nthree\nfour\nfive\n"
2107        );
2108        assert_eq!(
2109            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
2110            "zero\none\ntwo\nthree\nfour\nfive\n"
2111        );
2112        request.await.unwrap();
2113    }
2114
2115    #[gpui::test]
2116    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
2117        init_test(cx);
2118
2119        let fs = FakeFs::new(cx.executor());
2120        let project = Project::test(fs, [], cx).await;
2121        let id = acp::ToolCallId("test".into());
2122
2123        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2124            let id = id.clone();
2125            move |_, thread, mut cx| {
2126                let id = id.clone();
2127                async move {
2128                    thread
2129                        .update(&mut cx, |thread, cx| {
2130                            thread.handle_session_update(
2131                                acp::SessionUpdate::ToolCall(acp::ToolCall {
2132                                    id: id.clone(),
2133                                    title: "Label".into(),
2134                                    kind: acp::ToolKind::Fetch,
2135                                    status: acp::ToolCallStatus::InProgress,
2136                                    content: vec![],
2137                                    locations: vec![],
2138                                    raw_input: None,
2139                                    raw_output: None,
2140                                }),
2141                                cx,
2142                            )
2143                        })
2144                        .unwrap()
2145                        .unwrap();
2146                    Ok(acp::PromptResponse {
2147                        stop_reason: acp::StopReason::EndTurn,
2148                    })
2149                }
2150                .boxed_local()
2151            }
2152        }));
2153
2154        let thread = cx
2155            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2156            .await
2157            .unwrap();
2158
2159        let request = thread.update(cx, |thread, cx| {
2160            thread.send_raw("Fetch https://example.com", cx)
2161        });
2162
2163        run_until_first_tool_call(&thread, cx).await;
2164
2165        thread.read_with(cx, |thread, _| {
2166            assert!(matches!(
2167                thread.entries[1],
2168                AgentThreadEntry::ToolCall(ToolCall {
2169                    status: ToolCallStatus::InProgress,
2170                    ..
2171                })
2172            ));
2173        });
2174
2175        thread.update(cx, |thread, cx| thread.cancel(cx)).await;
2176
2177        thread.read_with(cx, |thread, _| {
2178            assert!(matches!(
2179                &thread.entries[1],
2180                AgentThreadEntry::ToolCall(ToolCall {
2181                    status: ToolCallStatus::Canceled,
2182                    ..
2183                })
2184            ));
2185        });
2186
2187        thread
2188            .update(cx, |thread, cx| {
2189                thread.handle_session_update(
2190                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
2191                        id,
2192                        fields: acp::ToolCallUpdateFields {
2193                            status: Some(acp::ToolCallStatus::Completed),
2194                            ..Default::default()
2195                        },
2196                    }),
2197                    cx,
2198                )
2199            })
2200            .unwrap();
2201
2202        request.await.unwrap();
2203
2204        thread.read_with(cx, |thread, _| {
2205            assert!(matches!(
2206                thread.entries[1],
2207                AgentThreadEntry::ToolCall(ToolCall {
2208                    status: ToolCallStatus::Completed,
2209                    ..
2210                })
2211            ));
2212        });
2213    }
2214
2215    #[gpui::test]
2216    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
2217        init_test(cx);
2218        let fs = FakeFs::new(cx.background_executor.clone());
2219        fs.insert_tree(path!("/test"), json!({})).await;
2220        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
2221
2222        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2223            move |_, thread, mut cx| {
2224                async move {
2225                    thread
2226                        .update(&mut cx, |thread, cx| {
2227                            thread.handle_session_update(
2228                                acp::SessionUpdate::ToolCall(acp::ToolCall {
2229                                    id: acp::ToolCallId("test".into()),
2230                                    title: "Label".into(),
2231                                    kind: acp::ToolKind::Edit,
2232                                    status: acp::ToolCallStatus::Completed,
2233                                    content: vec![acp::ToolCallContent::Diff {
2234                                        diff: acp::Diff {
2235                                            path: "/test/test.txt".into(),
2236                                            old_text: None,
2237                                            new_text: "foo".into(),
2238                                        },
2239                                    }],
2240                                    locations: vec![],
2241                                    raw_input: None,
2242                                    raw_output: None,
2243                                }),
2244                                cx,
2245                            )
2246                        })
2247                        .unwrap()
2248                        .unwrap();
2249                    Ok(acp::PromptResponse {
2250                        stop_reason: acp::StopReason::EndTurn,
2251                    })
2252                }
2253                .boxed_local()
2254            }
2255        }));
2256
2257        let thread = cx
2258            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2259            .await
2260            .unwrap();
2261
2262        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
2263            .await
2264            .unwrap();
2265
2266        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
2267    }
2268
2269    #[gpui::test(iterations = 10)]
2270    async fn test_checkpoints(cx: &mut TestAppContext) {
2271        init_test(cx);
2272        let fs = FakeFs::new(cx.background_executor.clone());
2273        fs.insert_tree(
2274            path!("/test"),
2275            json!({
2276                ".git": {}
2277            }),
2278        )
2279        .await;
2280        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
2281
2282        let simulate_changes = Arc::new(AtomicBool::new(true));
2283        let next_filename = Arc::new(AtomicUsize::new(0));
2284        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2285            let simulate_changes = simulate_changes.clone();
2286            let next_filename = next_filename.clone();
2287            let fs = fs.clone();
2288            move |request, thread, mut cx| {
2289                let fs = fs.clone();
2290                let simulate_changes = simulate_changes.clone();
2291                let next_filename = next_filename.clone();
2292                async move {
2293                    if simulate_changes.load(SeqCst) {
2294                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
2295                        fs.write(Path::new(&filename), b"").await?;
2296                    }
2297
2298                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2299                        panic!("expected text content block");
2300                    };
2301                    thread.update(&mut cx, |thread, cx| {
2302                        thread
2303                            .handle_session_update(
2304                                acp::SessionUpdate::AgentMessageChunk {
2305                                    content: content.text.to_uppercase().into(),
2306                                },
2307                                cx,
2308                            )
2309                            .unwrap();
2310                    })?;
2311                    Ok(acp::PromptResponse {
2312                        stop_reason: acp::StopReason::EndTurn,
2313                    })
2314                }
2315                .boxed_local()
2316            }
2317        }));
2318        let thread = cx
2319            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2320            .await
2321            .unwrap();
2322
2323        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
2324            .await
2325            .unwrap();
2326        thread.read_with(cx, |thread, cx| {
2327            assert_eq!(
2328                thread.to_markdown(cx),
2329                indoc! {"
2330                    ## User (checkpoint)
2331
2332                    Lorem
2333
2334                    ## Assistant
2335
2336                    LOREM
2337
2338                "}
2339            );
2340        });
2341        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2342
2343        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
2344            .await
2345            .unwrap();
2346        thread.read_with(cx, |thread, cx| {
2347            assert_eq!(
2348                thread.to_markdown(cx),
2349                indoc! {"
2350                    ## User (checkpoint)
2351
2352                    Lorem
2353
2354                    ## Assistant
2355
2356                    LOREM
2357
2358                    ## User (checkpoint)
2359
2360                    ipsum
2361
2362                    ## Assistant
2363
2364                    IPSUM
2365
2366                "}
2367            );
2368        });
2369        assert_eq!(
2370            fs.files(),
2371            vec![
2372                Path::new(path!("/test/file-0")),
2373                Path::new(path!("/test/file-1"))
2374            ]
2375        );
2376
2377        // Checkpoint isn't stored when there are no changes.
2378        simulate_changes.store(false, SeqCst);
2379        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
2380            .await
2381            .unwrap();
2382        thread.read_with(cx, |thread, cx| {
2383            assert_eq!(
2384                thread.to_markdown(cx),
2385                indoc! {"
2386                    ## User (checkpoint)
2387
2388                    Lorem
2389
2390                    ## Assistant
2391
2392                    LOREM
2393
2394                    ## User (checkpoint)
2395
2396                    ipsum
2397
2398                    ## Assistant
2399
2400                    IPSUM
2401
2402                    ## User
2403
2404                    dolor
2405
2406                    ## Assistant
2407
2408                    DOLOR
2409
2410                "}
2411            );
2412        });
2413        assert_eq!(
2414            fs.files(),
2415            vec![
2416                Path::new(path!("/test/file-0")),
2417                Path::new(path!("/test/file-1"))
2418            ]
2419        );
2420
2421        // Rewinding the conversation truncates the history and restores the checkpoint.
2422        thread
2423            .update(cx, |thread, cx| {
2424                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
2425                    panic!("unexpected entries {:?}", thread.entries)
2426                };
2427                thread.rewind(message.id.clone().unwrap(), cx)
2428            })
2429            .await
2430            .unwrap();
2431        thread.read_with(cx, |thread, cx| {
2432            assert_eq!(
2433                thread.to_markdown(cx),
2434                indoc! {"
2435                    ## User (checkpoint)
2436
2437                    Lorem
2438
2439                    ## Assistant
2440
2441                    LOREM
2442
2443                "}
2444            );
2445        });
2446        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2447    }
2448
2449    #[gpui::test]
2450    async fn test_refusal(cx: &mut TestAppContext) {
2451        init_test(cx);
2452        let fs = FakeFs::new(cx.background_executor.clone());
2453        fs.insert_tree(path!("/"), json!({})).await;
2454        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
2455
2456        let refuse_next = Arc::new(AtomicBool::new(false));
2457        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2458            let refuse_next = refuse_next.clone();
2459            move |request, thread, mut cx| {
2460                let refuse_next = refuse_next.clone();
2461                async move {
2462                    if refuse_next.load(SeqCst) {
2463                        return Ok(acp::PromptResponse {
2464                            stop_reason: acp::StopReason::Refusal,
2465                        });
2466                    }
2467
2468                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2469                        panic!("expected text content block");
2470                    };
2471                    thread.update(&mut cx, |thread, cx| {
2472                        thread
2473                            .handle_session_update(
2474                                acp::SessionUpdate::AgentMessageChunk {
2475                                    content: content.text.to_uppercase().into(),
2476                                },
2477                                cx,
2478                            )
2479                            .unwrap();
2480                    })?;
2481                    Ok(acp::PromptResponse {
2482                        stop_reason: acp::StopReason::EndTurn,
2483                    })
2484                }
2485                .boxed_local()
2486            }
2487        }));
2488        let thread = cx
2489            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2490            .await
2491            .unwrap();
2492
2493        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
2494            .await
2495            .unwrap();
2496        thread.read_with(cx, |thread, cx| {
2497            assert_eq!(
2498                thread.to_markdown(cx),
2499                indoc! {"
2500                    ## User
2501
2502                    hello
2503
2504                    ## Assistant
2505
2506                    HELLO
2507
2508                "}
2509            );
2510        });
2511
2512        // Simulate refusing the second message, ensuring the conversation gets
2513        // truncated to before sending it.
2514        refuse_next.store(true, SeqCst);
2515        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
2516            .await
2517            .unwrap();
2518        thread.read_with(cx, |thread, cx| {
2519            assert_eq!(
2520                thread.to_markdown(cx),
2521                indoc! {"
2522                    ## User
2523
2524                    hello
2525
2526                    ## Assistant
2527
2528                    HELLO
2529
2530                "}
2531            );
2532        });
2533    }
2534
2535    async fn run_until_first_tool_call(
2536        thread: &Entity<AcpThread>,
2537        cx: &mut TestAppContext,
2538    ) -> usize {
2539        let (mut tx, mut rx) = mpsc::channel::<usize>(1);
2540
2541        let subscription = cx.update(|cx| {
2542            cx.subscribe(thread, move |thread, _, cx| {
2543                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
2544                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
2545                        return tx.try_send(ix).unwrap();
2546                    }
2547                }
2548            })
2549        });
2550
2551        select! {
2552            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
2553                panic!("Timeout waiting for tool call")
2554            }
2555            ix = rx.next().fuse() => {
2556                drop(subscription);
2557                ix.unwrap()
2558            }
2559        }
2560    }
2561
2562    #[derive(Clone, Default)]
2563    struct FakeAgentConnection {
2564        auth_methods: Vec<acp::AuthMethod>,
2565        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
2566        on_user_message: Option<
2567            Rc<
2568                dyn Fn(
2569                        acp::PromptRequest,
2570                        WeakEntity<AcpThread>,
2571                        AsyncApp,
2572                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2573                    + 'static,
2574            >,
2575        >,
2576    }
2577
2578    impl FakeAgentConnection {
2579        fn new() -> Self {
2580            Self {
2581                auth_methods: Vec::new(),
2582                on_user_message: None,
2583                sessions: Arc::default(),
2584            }
2585        }
2586
2587        #[expect(unused)]
2588        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
2589            self.auth_methods = auth_methods;
2590            self
2591        }
2592
2593        fn on_user_message(
2594            mut self,
2595            handler: impl Fn(
2596                acp::PromptRequest,
2597                WeakEntity<AcpThread>,
2598                AsyncApp,
2599            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2600            + 'static,
2601        ) -> Self {
2602            self.on_user_message.replace(Rc::new(handler));
2603            self
2604        }
2605    }
2606
2607    impl AgentConnection for FakeAgentConnection {
2608        fn auth_methods(&self) -> &[acp::AuthMethod] {
2609            &self.auth_methods
2610        }
2611
2612        fn new_thread(
2613            self: Rc<Self>,
2614            project: Entity<Project>,
2615            _cwd: &Path,
2616            cx: &mut App,
2617        ) -> Task<gpui::Result<Entity<AcpThread>>> {
2618            let session_id = acp::SessionId(
2619                rand::thread_rng()
2620                    .sample_iter(&rand::distributions::Alphanumeric)
2621                    .take(7)
2622                    .map(char::from)
2623                    .collect::<String>()
2624                    .into(),
2625            );
2626            let action_log = cx.new(|_| ActionLog::new(project.clone()));
2627            let thread = cx.new(|cx| {
2628                AcpThread::new(
2629                    "Test",
2630                    self.clone(),
2631                    project,
2632                    action_log,
2633                    session_id.clone(),
2634                    watch::Receiver::constant(acp::PromptCapabilities {
2635                        image: true,
2636                        audio: true,
2637                        embedded_context: true,
2638                    }),
2639                    cx,
2640                )
2641            });
2642            self.sessions.lock().insert(session_id, thread.downgrade());
2643            Task::ready(Ok(thread))
2644        }
2645
2646        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
2647            if self.auth_methods().iter().any(|m| m.id == method) {
2648                Task::ready(Ok(()))
2649            } else {
2650                Task::ready(Err(anyhow!("Invalid Auth Method")))
2651            }
2652        }
2653
2654        fn prompt(
2655            &self,
2656            _id: Option<UserMessageId>,
2657            params: acp::PromptRequest,
2658            cx: &mut App,
2659        ) -> Task<gpui::Result<acp::PromptResponse>> {
2660            let sessions = self.sessions.lock();
2661            let thread = sessions.get(&params.session_id).unwrap();
2662            if let Some(handler) = &self.on_user_message {
2663                let handler = handler.clone();
2664                let thread = thread.clone();
2665                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
2666            } else {
2667                Task::ready(Ok(acp::PromptResponse {
2668                    stop_reason: acp::StopReason::EndTurn,
2669                }))
2670            }
2671        }
2672
2673        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
2674            let sessions = self.sessions.lock();
2675            let thread = sessions.get(session_id).unwrap().clone();
2676
2677            cx.spawn(async move |cx| {
2678                thread
2679                    .update(cx, |thread, cx| thread.cancel(cx))
2680                    .unwrap()
2681                    .await
2682            })
2683            .detach();
2684        }
2685
2686        fn truncate(
2687            &self,
2688            session_id: &acp::SessionId,
2689            _cx: &App,
2690        ) -> Option<Rc<dyn AgentSessionTruncate>> {
2691            Some(Rc::new(FakeAgentSessionEditor {
2692                _session_id: session_id.clone(),
2693            }))
2694        }
2695
2696        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
2697            self
2698        }
2699    }
2700
2701    struct FakeAgentSessionEditor {
2702        _session_id: acp::SessionId,
2703    }
2704
2705    impl AgentSessionTruncate for FakeAgentSessionEditor {
2706        fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
2707            Task::ready(Ok(()))
2708        }
2709    }
2710}