acp_thread.rs

   1mod connection;
   2mod diff;
   3mod mention;
   4mod terminal;
   5
   6use collections::HashSet;
   7pub use connection::*;
   8pub use diff::*;
   9use language::language_settings::FormatOnSave;
  10pub use mention::*;
  11use project::lsp_store::{FormatTrigger, LspFormatTarget};
  12use serde::{Deserialize, Serialize};
  13pub use terminal::*;
  14
  15use action_log::ActionLog;
  16use agent_client_protocol as acp;
  17use anyhow::{Context as _, Result, anyhow};
  18use editor::Bias;
  19use futures::{FutureExt, channel::oneshot, future::BoxFuture};
  20use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  21use itertools::Itertools;
  22use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
  23use markdown::Markdown;
  24use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
  25use std::collections::HashMap;
  26use std::error::Error;
  27use std::fmt::{Formatter, Write};
  28use std::ops::Range;
  29use std::process::ExitStatus;
  30use std::rc::Rc;
  31use std::time::{Duration, Instant};
  32use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
  33use ui::App;
  34use util::ResultExt;
  35
/// A message authored by the user within an agent thread.
#[derive(Debug)]
pub struct UserMessage {
    /// Identifier of this message, when one has been assigned.
    pub id: Option<UserMessageId>,
    /// Combined, renderable content accumulated from every chunk received so far.
    pub content: ContentBlock,
    /// The individual protocol content blocks, in the order they were pushed.
    pub chunks: Vec<acp::ContentBlock>,
    /// Checkpoint attached to this message, if any.
    pub checkpoint: Option<Checkpoint>,
}
  43
/// A git-store checkpoint attached to a [`UserMessage`].
#[derive(Debug)]
pub struct Checkpoint {
    /// Handle to the checkpoint held by the git store.
    git_checkpoint: GitStoreCheckpoint,
    /// When `true`, the owning message is rendered under a
    /// `## User (checkpoint)` heading in markdown output.
    pub show: bool,
}
  49
  50impl UserMessage {
  51    fn to_markdown(&self, cx: &App) -> String {
  52        let mut markdown = String::new();
  53        if self
  54            .checkpoint
  55            .as_ref()
  56            .is_some_and(|checkpoint| checkpoint.show)
  57        {
  58            writeln!(markdown, "## User (checkpoint)").unwrap();
  59        } else {
  60            writeln!(markdown, "## User").unwrap();
  61        }
  62        writeln!(markdown).unwrap();
  63        writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
  64        writeln!(markdown).unwrap();
  65        markdown
  66    }
  67}
  68
/// A message produced by the agent, stored as an ordered list of chunks.
#[derive(Debug, PartialEq)]
pub struct AssistantMessage {
    /// Message and thought chunks, in order.
    pub chunks: Vec<AssistantMessageChunk>,
}
  73
  74impl AssistantMessage {
  75    pub fn to_markdown(&self, cx: &App) -> String {
  76        format!(
  77            "## Assistant\n\n{}\n\n",
  78            self.chunks
  79                .iter()
  80                .map(|chunk| chunk.to_markdown(cx))
  81                .join("\n\n")
  82        )
  83    }
  84}
  85
/// One chunk of an [`AssistantMessage`].
#[derive(Debug, PartialEq)]
pub enum AssistantMessageChunk {
    /// Regular message content.
    Message { block: ContentBlock },
    /// Thought content; rendered inside `<thinking>` tags in markdown output.
    Thought { block: ContentBlock },
}
  91
  92impl AssistantMessageChunk {
  93    pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
  94        Self::Message {
  95            block: ContentBlock::new(chunk.into(), language_registry, cx),
  96        }
  97    }
  98
  99    fn to_markdown(&self, cx: &App) -> String {
 100        match self {
 101            Self::Message { block } => block.to_markdown(cx).to_string(),
 102            Self::Thought { block } => {
 103                format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
 104            }
 105        }
 106    }
 107}
 108
/// A single entry in an agent thread.
#[derive(Debug)]
pub enum AgentThreadEntry {
    /// A message from the user.
    UserMessage(UserMessage),
    /// A message from the agent.
    AssistantMessage(AssistantMessage),
    /// A tool invocation and its current state.
    ToolCall(ToolCall),
}
 115
 116impl AgentThreadEntry {
 117    pub fn to_markdown(&self, cx: &App) -> String {
 118        match self {
 119            Self::UserMessage(message) => message.to_markdown(cx),
 120            Self::AssistantMessage(message) => message.to_markdown(cx),
 121            Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
 122        }
 123    }
 124
 125    pub fn user_message(&self) -> Option<&UserMessage> {
 126        if let AgentThreadEntry::UserMessage(message) = self {
 127            Some(message)
 128        } else {
 129            None
 130        }
 131    }
 132
 133    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 134        if let AgentThreadEntry::ToolCall(call) = self {
 135            itertools::Either::Left(call.diffs())
 136        } else {
 137            itertools::Either::Right(std::iter::empty())
 138        }
 139    }
 140
 141    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 142        if let AgentThreadEntry::ToolCall(call) = self {
 143            itertools::Either::Left(call.terminals())
 144        } else {
 145            itertools::Either::Right(std::iter::empty())
 146        }
 147    }
 148
 149    pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
 150        if let AgentThreadEntry::ToolCall(ToolCall {
 151            locations,
 152            resolved_locations,
 153            ..
 154        }) = self
 155        {
 156            Some((
 157                locations.get(ix)?.clone(),
 158                resolved_locations.get(ix)?.clone()?,
 159            ))
 160        } else {
 161            None
 162        }
 163    }
 164}
 165
/// A tool invocation tracked within an agent thread.
#[derive(Debug)]
pub struct ToolCall {
    /// Protocol identifier of the tool call.
    pub id: acp::ToolCallId,
    /// Markdown entity holding the (single-line) title shown for the call.
    pub label: Entity<Markdown>,
    /// Kind of tool, as reported over the protocol.
    pub kind: acp::ToolKind,
    /// Content produced by the call: content blocks, diffs and terminals.
    pub content: Vec<ToolCallContent>,
    /// Current lifecycle state of the call.
    pub status: ToolCallStatus,
    /// Locations reported for the call.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Resolved form of `locations`; an entry is `None` when resolution failed.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    /// Raw input of the tool call, when provided.
    pub raw_input: Option<serde_json::Value>,
    /// Raw output of the tool call, when provided.
    pub raw_output: Option<serde_json::Value>,
}
 178
 179impl ToolCall {
 180    fn from_acp(
 181        tool_call: acp::ToolCall,
 182        status: ToolCallStatus,
 183        language_registry: Arc<LanguageRegistry>,
 184        cx: &mut App,
 185    ) -> Self {
 186        let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
 187            first_line.to_owned() + ""
 188        } else {
 189            tool_call.title
 190        };
 191        Self {
 192            id: tool_call.id,
 193            label: cx
 194                .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
 195            kind: tool_call.kind,
 196            content: tool_call
 197                .content
 198                .into_iter()
 199                .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
 200                .collect(),
 201            locations: tool_call.locations,
 202            resolved_locations: Vec::default(),
 203            status,
 204            raw_input: tool_call.raw_input,
 205            raw_output: tool_call.raw_output,
 206        }
 207    }
 208
 209    fn update_fields(
 210        &mut self,
 211        fields: acp::ToolCallUpdateFields,
 212        language_registry: Arc<LanguageRegistry>,
 213        cx: &mut App,
 214    ) {
 215        let acp::ToolCallUpdateFields {
 216            kind,
 217            status,
 218            title,
 219            content,
 220            locations,
 221            raw_input,
 222            raw_output,
 223        } = fields;
 224
 225        if let Some(kind) = kind {
 226            self.kind = kind;
 227        }
 228
 229        if let Some(status) = status {
 230            self.status = status.into();
 231        }
 232
 233        if let Some(title) = title {
 234            self.label.update(cx, |label, cx| {
 235                if let Some((first_line, _)) = title.split_once("\n") {
 236                    label.replace(first_line.to_owned() + "", cx)
 237                } else {
 238                    label.replace(title, cx);
 239                }
 240            });
 241        }
 242
 243        if let Some(content) = content {
 244            let new_content_len = content.len();
 245            let mut content = content.into_iter();
 246
 247            // Reuse existing content if we can
 248            for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
 249                old.update_from_acp(new, language_registry.clone(), cx);
 250            }
 251            for new in content {
 252                self.content.push(ToolCallContent::from_acp(
 253                    new,
 254                    language_registry.clone(),
 255                    cx,
 256                ))
 257            }
 258            self.content.truncate(new_content_len);
 259        }
 260
 261        if let Some(locations) = locations {
 262            self.locations = locations;
 263        }
 264
 265        if let Some(raw_input) = raw_input {
 266            self.raw_input = Some(raw_input);
 267        }
 268
 269        if let Some(raw_output) = raw_output {
 270            if self.content.is_empty()
 271                && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
 272            {
 273                self.content
 274                    .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
 275                        markdown,
 276                    }));
 277            }
 278            self.raw_output = Some(raw_output);
 279        }
 280    }
 281
 282    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 283        self.content.iter().filter_map(|content| match content {
 284            ToolCallContent::Diff(diff) => Some(diff),
 285            ToolCallContent::ContentBlock(_) => None,
 286            ToolCallContent::Terminal(_) => None,
 287        })
 288    }
 289
 290    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 291        self.content.iter().filter_map(|content| match content {
 292            ToolCallContent::Terminal(terminal) => Some(terminal),
 293            ToolCallContent::ContentBlock(_) => None,
 294            ToolCallContent::Diff(_) => None,
 295        })
 296    }
 297
 298    fn to_markdown(&self, cx: &App) -> String {
 299        let mut markdown = format!(
 300            "**Tool Call: {}**\nStatus: {}\n\n",
 301            self.label.read(cx).source(),
 302            self.status
 303        );
 304        for content in &self.content {
 305            markdown.push_str(content.to_markdown(cx).as_str());
 306            markdown.push_str("\n\n");
 307        }
 308        markdown
 309    }
 310
 311    async fn resolve_location(
 312        location: acp::ToolCallLocation,
 313        project: WeakEntity<Project>,
 314        cx: &mut AsyncApp,
 315    ) -> Option<AgentLocation> {
 316        let buffer = project
 317            .update(cx, |project, cx| {
 318                project
 319                    .project_path_for_absolute_path(&location.path, cx)
 320                    .map(|path| project.open_buffer(path, cx))
 321            })
 322            .ok()??;
 323        let buffer = buffer.await.log_err()?;
 324        let position = buffer
 325            .update(cx, |buffer, _| {
 326                if let Some(row) = location.line {
 327                    let snapshot = buffer.snapshot();
 328                    let column = snapshot.indent_size_for_line(row).len;
 329                    let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
 330                    snapshot.anchor_before(point)
 331                } else {
 332                    Anchor::MIN
 333                }
 334            })
 335            .ok()?;
 336
 337        Some(AgentLocation {
 338            buffer: buffer.downgrade(),
 339            position,
 340        })
 341    }
 342
 343    fn resolve_locations(
 344        &self,
 345        project: Entity<Project>,
 346        cx: &mut App,
 347    ) -> Task<Vec<Option<AgentLocation>>> {
 348        let locations = self.locations.clone();
 349        project.update(cx, |_, cx| {
 350            cx.spawn(async move |project, cx| {
 351                let mut new_locations = Vec::new();
 352                for location in locations {
 353                    new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
 354                }
 355                new_locations
 356            })
 357        })
 358    }
 359}
 360
/// Lifecycle state of a [`ToolCall`].
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The tool call hasn't started running yet, but we start showing it to
    /// the user.
    Pending,
    /// The tool call is waiting for confirmation from the user.
    WaitingForConfirmation {
        /// Permission options offered for this call.
        options: Vec<acp::PermissionOption>,
        /// One-shot sender carrying the id of the chosen permission option.
        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
    },
    /// The tool call is currently running.
    InProgress,
    /// The tool call completed successfully.
    Completed,
    /// The tool call failed.
    Failed,
    /// The user rejected the tool call.
    Rejected,
    /// The user canceled generation so the tool call was canceled.
    Canceled,
}
 382
 383impl From<acp::ToolCallStatus> for ToolCallStatus {
 384    fn from(status: acp::ToolCallStatus) -> Self {
 385        match status {
 386            acp::ToolCallStatus::Pending => Self::Pending,
 387            acp::ToolCallStatus::InProgress => Self::InProgress,
 388            acp::ToolCallStatus::Completed => Self::Completed,
 389            acp::ToolCallStatus::Failed => Self::Failed,
 390        }
 391    }
 392}
 393
 394impl Display for ToolCallStatus {
 395    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 396        write!(
 397            f,
 398            "{}",
 399            match self {
 400                ToolCallStatus::Pending => "Pending",
 401                ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
 402                ToolCallStatus::InProgress => "In Progress",
 403                ToolCallStatus::Completed => "Completed",
 404                ToolCallStatus::Failed => "Failed",
 405                ToolCallStatus::Rejected => "Rejected",
 406                ToolCallStatus::Canceled => "Canceled",
 407            }
 408        )
 409    }
 410}
 411
/// Renderable content accumulated from protocol content blocks.
#[derive(Debug, PartialEq, Clone)]
pub enum ContentBlock {
    /// Nothing has been appended yet.
    Empty,
    /// Markdown text held in a markdown entity.
    Markdown { markdown: Entity<Markdown> },
    /// A single resource link; converted to markdown if more content is appended.
    ResourceLink { resource_link: acp::ResourceLink },
}
 418
 419impl ContentBlock {
 420    pub fn new(
 421        block: acp::ContentBlock,
 422        language_registry: &Arc<LanguageRegistry>,
 423        cx: &mut App,
 424    ) -> Self {
 425        let mut this = Self::Empty;
 426        this.append(block, language_registry, cx);
 427        this
 428    }
 429
 430    pub fn new_combined(
 431        blocks: impl IntoIterator<Item = acp::ContentBlock>,
 432        language_registry: Arc<LanguageRegistry>,
 433        cx: &mut App,
 434    ) -> Self {
 435        let mut this = Self::Empty;
 436        for block in blocks {
 437            this.append(block, &language_registry, cx);
 438        }
 439        this
 440    }
 441
 442    pub fn append(
 443        &mut self,
 444        block: acp::ContentBlock,
 445        language_registry: &Arc<LanguageRegistry>,
 446        cx: &mut App,
 447    ) {
 448        if matches!(self, ContentBlock::Empty)
 449            && let acp::ContentBlock::ResourceLink(resource_link) = block
 450        {
 451            *self = ContentBlock::ResourceLink { resource_link };
 452            return;
 453        }
 454
 455        let new_content = self.block_string_contents(block);
 456
 457        match self {
 458            ContentBlock::Empty => {
 459                *self = Self::create_markdown_block(new_content, language_registry, cx);
 460            }
 461            ContentBlock::Markdown { markdown } => {
 462                markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
 463            }
 464            ContentBlock::ResourceLink { resource_link } => {
 465                let existing_content = Self::resource_link_md(&resource_link.uri);
 466                let combined = format!("{}\n{}", existing_content, new_content);
 467
 468                *self = Self::create_markdown_block(combined, language_registry, cx);
 469            }
 470        }
 471    }
 472
 473    fn create_markdown_block(
 474        content: String,
 475        language_registry: &Arc<LanguageRegistry>,
 476        cx: &mut App,
 477    ) -> ContentBlock {
 478        ContentBlock::Markdown {
 479            markdown: cx
 480                .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
 481        }
 482    }
 483
 484    fn block_string_contents(&self, block: acp::ContentBlock) -> String {
 485        match block {
 486            acp::ContentBlock::Text(text_content) => text_content.text,
 487            acp::ContentBlock::ResourceLink(resource_link) => {
 488                Self::resource_link_md(&resource_link.uri)
 489            }
 490            acp::ContentBlock::Resource(acp::EmbeddedResource {
 491                resource:
 492                    acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
 493                        uri,
 494                        ..
 495                    }),
 496                ..
 497            }) => Self::resource_link_md(&uri),
 498            acp::ContentBlock::Image(image) => Self::image_md(&image),
 499            acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
 500        }
 501    }
 502
 503    fn resource_link_md(uri: &str) -> String {
 504        if let Some(uri) = MentionUri::parse(uri).log_err() {
 505            uri.as_link().to_string()
 506        } else {
 507            uri.to_string()
 508        }
 509    }
 510
 511    fn image_md(_image: &acp::ImageContent) -> String {
 512        "`Image`".into()
 513    }
 514
 515    pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
 516        match self {
 517            ContentBlock::Empty => "",
 518            ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
 519            ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
 520        }
 521    }
 522
 523    pub fn markdown(&self) -> Option<&Entity<Markdown>> {
 524        match self {
 525            ContentBlock::Empty => None,
 526            ContentBlock::Markdown { markdown } => Some(markdown),
 527            ContentBlock::ResourceLink { .. } => None,
 528        }
 529    }
 530
 531    pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
 532        match self {
 533            ContentBlock::ResourceLink { resource_link } => Some(resource_link),
 534            _ => None,
 535        }
 536    }
 537}
 538
/// One piece of content attached to a [`ToolCall`].
#[derive(Debug)]
pub enum ToolCallContent {
    /// Markdown or resource-link content.
    ContentBlock(ContentBlock),
    /// A diff entity.
    Diff(Entity<Diff>),
    /// A terminal entity.
    Terminal(Entity<Terminal>),
}
 545
 546impl ToolCallContent {
 547    pub fn from_acp(
 548        content: acp::ToolCallContent,
 549        language_registry: Arc<LanguageRegistry>,
 550        cx: &mut App,
 551    ) -> Self {
 552        match content {
 553            acp::ToolCallContent::Content { content } => {
 554                Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
 555            }
 556            acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
 557                Diff::finalized(
 558                    diff.path,
 559                    diff.old_text,
 560                    diff.new_text,
 561                    language_registry,
 562                    cx,
 563                )
 564            })),
 565        }
 566    }
 567
 568    pub fn update_from_acp(
 569        &mut self,
 570        new: acp::ToolCallContent,
 571        language_registry: Arc<LanguageRegistry>,
 572        cx: &mut App,
 573    ) {
 574        let needs_update = match (&self, &new) {
 575            (Self::Diff(old_diff), acp::ToolCallContent::Diff { diff: new_diff }) => {
 576                old_diff.read(cx).needs_update(
 577                    new_diff.old_text.as_deref().unwrap_or(""),
 578                    &new_diff.new_text,
 579                    cx,
 580                )
 581            }
 582            _ => true,
 583        };
 584
 585        if needs_update {
 586            *self = Self::from_acp(new, language_registry, cx);
 587        }
 588    }
 589
 590    pub fn to_markdown(&self, cx: &App) -> String {
 591        match self {
 592            Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
 593            Self::Diff(diff) => diff.read(cx).to_markdown(cx),
 594            Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
 595        }
 596    }
 597}
 598
/// An update targeting an existing tool call.
#[derive(Debug, PartialEq)]
pub enum ToolCallUpdate {
    /// A protocol-level update of the tool call's fields.
    UpdateFields(acp::ToolCallUpdate),
    /// An update carrying a diff entity.
    UpdateDiff(ToolCallUpdateDiff),
    /// An update carrying a terminal entity.
    UpdateTerminal(ToolCallUpdateTerminal),
}
 605
 606impl ToolCallUpdate {
 607    fn id(&self) -> &acp::ToolCallId {
 608        match self {
 609            Self::UpdateFields(update) => &update.id,
 610            Self::UpdateDiff(diff) => &diff.id,
 611            Self::UpdateTerminal(terminal) => &terminal.id,
 612        }
 613    }
 614}
 615
 616impl From<acp::ToolCallUpdate> for ToolCallUpdate {
 617    fn from(update: acp::ToolCallUpdate) -> Self {
 618        Self::UpdateFields(update)
 619    }
 620}
 621
 622impl From<ToolCallUpdateDiff> for ToolCallUpdate {
 623    fn from(diff: ToolCallUpdateDiff) -> Self {
 624        Self::UpdateDiff(diff)
 625    }
 626}
 627
/// Payload of [`ToolCallUpdate::UpdateDiff`].
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateDiff {
    /// Id of the tool call being updated.
    pub id: acp::ToolCallId,
    /// The diff entity carried by the update.
    pub diff: Entity<Diff>,
}
 633
 634impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
 635    fn from(terminal: ToolCallUpdateTerminal) -> Self {
 636        Self::UpdateTerminal(terminal)
 637    }
 638}
 639
/// Payload of [`ToolCallUpdate::UpdateTerminal`].
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateTerminal {
    /// Id of the tool call being updated.
    pub id: acp::ToolCallId,
    /// The terminal entity carried by the update.
    pub terminal: Entity<Terminal>,
}
 645
/// The agent's plan: an ordered list of entries. Empty by default.
#[derive(Debug, Default)]
pub struct Plan {
    /// Plan entries, in order.
    pub entries: Vec<PlanEntry>,
}
 650
/// Summary of a [`Plan`], as computed by [`Plan::stats`].
#[derive(Debug)]
pub struct PlanStats<'a> {
    /// The first entry whose status is in progress, if any.
    pub in_progress_entry: Option<&'a PlanEntry>,
    /// Number of pending entries.
    pub pending: u32,
    /// Number of completed entries.
    pub completed: u32,
}
 657
 658impl Plan {
 659    pub fn is_empty(&self) -> bool {
 660        self.entries.is_empty()
 661    }
 662
 663    pub fn stats(&self) -> PlanStats<'_> {
 664        let mut stats = PlanStats {
 665            in_progress_entry: None,
 666            pending: 0,
 667            completed: 0,
 668        };
 669
 670        for entry in &self.entries {
 671            match &entry.status {
 672                acp::PlanEntryStatus::Pending => {
 673                    stats.pending += 1;
 674                }
 675                acp::PlanEntryStatus::InProgress => {
 676                    stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
 677                }
 678                acp::PlanEntryStatus::Completed => {
 679                    stats.completed += 1;
 680                }
 681            }
 682        }
 683
 684        stats
 685    }
 686}
 687
/// A single entry of a [`Plan`].
#[derive(Debug)]
pub struct PlanEntry {
    /// Markdown entity holding the entry's text.
    pub content: Entity<Markdown>,
    /// Priority as reported over the protocol.
    pub priority: acp::PlanEntryPriority,
    /// Status as reported over the protocol.
    pub status: acp::PlanEntryStatus,
}
 694
 695impl PlanEntry {
 696    pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
 697        Self {
 698            content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
 699            priority: entry.priority,
 700            status: entry.status,
 701        }
 702    }
 703}
 704
/// Token consumption of a thread relative to its limit.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Maximum number of tokens; `0` is treated as "unknown" by [`TokenUsage::ratio`].
    pub max_tokens: u64,
    /// Number of tokens used so far.
    pub used_tokens: u64,
}
 710
 711impl TokenUsage {
 712    pub fn ratio(&self) -> TokenUsageRatio {
 713        #[cfg(debug_assertions)]
 714        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 715            .unwrap_or("0.8".to_string())
 716            .parse()
 717            .unwrap();
 718        #[cfg(not(debug_assertions))]
 719        let warning_threshold: f32 = 0.8;
 720
 721        // When the maximum is unknown because there is no selected model,
 722        // avoid showing the token limit warning.
 723        if self.max_tokens == 0 {
 724            TokenUsageRatio::Normal
 725        } else if self.used_tokens >= self.max_tokens {
 726            TokenUsageRatio::Exceeded
 727        } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
 728            TokenUsageRatio::Warning
 729        } else {
 730            TokenUsageRatio::Normal
 731        }
 732    }
 733}
 734
/// Classification of token usage relative to the limit.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenUsageRatio {
    /// Usage is below the warning threshold, or the limit is unknown.
    Normal,
    /// Usage has reached the warning threshold but not the limit.
    Warning,
    /// Usage has reached or passed the limit.
    Exceeded,
}
 741
/// Details carried by [`AcpThreadEvent::Retry`].
// NOTE(review): the field meanings below are read off the names and types
// only; nothing in this part of the file populates them, so confirm against
// the producer.
#[derive(Debug, Clone)]
pub struct RetryStatus {
    /// Presumably the message of the error that triggered the retry.
    pub last_error: SharedString,
    /// Presumably the number of the current attempt.
    pub attempt: usize,
    /// Presumably the maximum number of attempts.
    pub max_attempts: usize,
    /// An instant associated with the retry (presumably when it started).
    pub started_at: Instant,
    /// A duration associated with the retry (presumably the wait before it).
    pub duration: Duration,
}
 750
/// State of a single agent thread (one session on an agent connection).
pub struct AcpThread {
    /// Title of the thread.
    title: SharedString,
    /// Messages and tool calls, in order.
    entries: Vec<AgentThreadEntry>,
    /// The agent's current plan.
    plan: Plan,
    /// Project this thread operates on.
    project: Entity<Project>,
    /// Action log entity associated with the thread.
    action_log: Entity<ActionLog>,
    /// Buffer snapshots keyed by buffer entity.
    // NOTE(review): how these are populated is not visible in this part of the file.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// In-flight send task; the thread is considered idle while this is `None`.
    send_task: Option<Task<()>>,
    /// Connection to the agent.
    connection: Rc<dyn AgentConnection>,
    /// Protocol session id of this thread.
    session_id: acp::SessionId,
    /// Latest known token usage, if any.
    token_usage: Option<TokenUsage>,
    /// Most recently observed prompt capabilities.
    prompt_capabilities: acp::PromptCapabilities,
    /// Background task that keeps `prompt_capabilities` in sync with the
    /// watch channel handed to `new`; held so the task stays alive.
    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
}
 765
/// Events emitted by [`AcpThread`].
#[derive(Debug)]
pub enum AcpThreadEvent {
    /// A new entry was added to the thread.
    NewEntry,
    /// The thread title changed.
    TitleUpdated,
    /// The token usage changed.
    TokenUsageUpdated,
    /// The entry at the given index was updated.
    EntryUpdated(usize),
    /// The entries in the given index range were removed.
    EntriesRemoved(Range<usize>),
    /// A tool call requires authorization.
    ToolAuthorizationRequired,
    /// A retry is under way; carries its status.
    Retry(RetryStatus),
    /// The thread stopped.
    Stopped,
    /// An error occurred.
    Error,
    /// Loading failed with the given error.
    LoadError(LoadError),
    /// The prompt capabilities stored on the thread were replaced.
    PromptCapabilitiesUpdated,
}
 780
// Allows `AcpThread` to emit `AcpThreadEvent` values through its context.
impl EventEmitter<AcpThreadEvent> for AcpThread {}
 782
/// Coarse activity state of an [`AcpThread`], as reported by `AcpThread::status`.
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    /// No send task is running.
    Idle,
    /// A send task is running and a tool call awaits confirmation.
    WaitingForToolConfirmation,
    /// A send task is running and nothing awaits confirmation.
    Generating,
}
 789
/// Errors reported through [`AcpThreadEvent::LoadError`].
#[derive(Debug, Clone)]
pub enum LoadError {
    /// The found version is older than the minimum supported one.
    Unsupported {
        /// Where the unsupported version came from (shown as "from {command}").
        command: SharedString,
        /// The version that was found.
        current_version: SharedString,
        /// The oldest supported version.
        minimum_version: SharedString,
    },
    /// Installation failed; carries the message.
    FailedToInstall(SharedString),
    /// The server exited with the given status.
    Exited {
        status: ExitStatus,
    },
    /// Any other failure, described by its message.
    Other(SharedString),
}
 803
 804impl Display for LoadError {
 805    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 806        match self {
 807            LoadError::Unsupported {
 808                command: path,
 809                current_version,
 810                minimum_version,
 811            } => {
 812                write!(
 813                    f,
 814                    "version {current_version} from {path} is not supported (need at least {minimum_version})"
 815                )
 816            }
 817            LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
 818            LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
 819            LoadError::Other(msg) => write!(f, "{msg}"),
 820        }
 821    }
 822}
 823
// Uses the default `Error` methods; the message comes from the `Display` impl.
impl Error for LoadError {}
 825
 826impl AcpThread {
 827    pub fn new(
 828        title: impl Into<SharedString>,
 829        connection: Rc<dyn AgentConnection>,
 830        project: Entity<Project>,
 831        action_log: Entity<ActionLog>,
 832        session_id: acp::SessionId,
 833        mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
 834        cx: &mut Context<Self>,
 835    ) -> Self {
 836        let prompt_capabilities = *prompt_capabilities_rx.borrow();
 837        let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
 838            loop {
 839                let caps = prompt_capabilities_rx.recv().await?;
 840                this.update(cx, |this, cx| {
 841                    this.prompt_capabilities = caps;
 842                    cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
 843                })?;
 844            }
 845        });
 846
 847        Self {
 848            action_log,
 849            shared_buffers: Default::default(),
 850            entries: Default::default(),
 851            plan: Default::default(),
 852            title: title.into(),
 853            project,
 854            send_task: None,
 855            connection,
 856            session_id,
 857            token_usage: None,
 858            prompt_capabilities,
 859            _observe_prompt_capabilities: task,
 860        }
 861    }
 862
    /// Returns a copy of the most recently observed prompt capabilities.
    pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
        self.prompt_capabilities
    }
 866
    /// Returns the agent connection this thread uses.
    pub fn connection(&self) -> &Rc<dyn AgentConnection> {
        &self.connection
    }
 870
    /// Returns true if the agent supports custom slash commands, according to
    /// the most recently observed prompt capabilities.
    pub fn supports_custom_commands(&self) -> bool {
        self.prompt_capabilities.supports_custom_commands
    }
 875
    /// Returns the action log entity associated with this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
 879
    /// Returns the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
 883
    /// Returns a clone of the thread title.
    pub fn title(&self) -> SharedString {
        self.title.clone()
    }
 887
    /// Returns the thread's entries, in order.
    pub fn entries(&self) -> &[AgentThreadEntry] {
        &self.entries
    }
 891
    /// Returns the protocol session id of this thread.
    pub fn session_id(&self) -> &acp::SessionId {
        &self.session_id
    }
 895
 896    pub fn status(&self) -> ThreadStatus {
 897        if self.send_task.is_some() {
 898            if self.waiting_for_tool_confirmation() {
 899                ThreadStatus::WaitingForToolConfirmation
 900            } else {
 901                ThreadStatus::Generating
 902            }
 903        } else {
 904            ThreadStatus::Idle
 905        }
 906    }
 907
    /// Returns the latest known token usage, if any.
    pub fn token_usage(&self) -> Option<&TokenUsage> {
        self.token_usage.as_ref()
    }
 911
 912    pub fn has_pending_edit_tool_calls(&self) -> bool {
 913        for entry in self.entries.iter().rev() {
 914            match entry {
 915                AgentThreadEntry::UserMessage(_) => return false,
 916                AgentThreadEntry::ToolCall(
 917                    call @ ToolCall {
 918                        status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
 919                        ..
 920                    },
 921                ) if call.diffs().next().is_some() => {
 922                    return true;
 923                }
 924                AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
 925            }
 926        }
 927
 928        false
 929    }
 930
 931    pub fn used_tools_since_last_user_message(&self) -> bool {
 932        for entry in self.entries.iter().rev() {
 933            match entry {
 934                AgentThreadEntry::UserMessage(..) => return false,
 935                AgentThreadEntry::AssistantMessage(..) => continue,
 936                AgentThreadEntry::ToolCall(..) => return true,
 937            }
 938        }
 939
 940        false
 941    }
 942
    /// Applies a session update received from the agent to this thread.
    ///
    /// Message chunks are appended to the thread, tool calls are upserted or
    /// updated, and plans replace the current plan. Errors from the tool-call
    /// handlers are propagated; all other updates succeed.
    pub fn handle_session_update(
        &mut self,
        update: acp::SessionUpdate,
        cx: &mut Context<Self>,
    ) -> Result<(), acp::Error> {
        match update {
            acp::SessionUpdate::UserMessageChunk { content } => {
                // Chunks arriving from the agent carry no message id.
                self.push_user_content_block(None, content, cx);
            }
            acp::SessionUpdate::AgentMessageChunk { content } => {
                // `false`: regular message content, not a thought.
                self.push_assistant_content_block(content, false, cx);
            }
            acp::SessionUpdate::AgentThoughtChunk { content } => {
                // `true`: this chunk is a thought.
                self.push_assistant_content_block(content, true, cx);
            }
            acp::SessionUpdate::ToolCall(tool_call) => {
                self.upsert_tool_call(tool_call, cx)?;
            }
            acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
                self.update_tool_call(tool_call_update, cx)?;
            }
            acp::SessionUpdate::Plan(plan) => {
                self.update_plan(plan, cx);
            }
        }
        Ok(())
    }
 970
 971    pub fn push_user_content_block(
 972        &mut self,
 973        message_id: Option<UserMessageId>,
 974        chunk: acp::ContentBlock,
 975        cx: &mut Context<Self>,
 976    ) {
 977        let language_registry = self.project.read(cx).languages().clone();
 978        let entries_len = self.entries.len();
 979
 980        if let Some(last_entry) = self.entries.last_mut()
 981            && let AgentThreadEntry::UserMessage(UserMessage {
 982                id,
 983                content,
 984                chunks,
 985                ..
 986            }) = last_entry
 987        {
 988            *id = message_id.or(id.take());
 989            content.append(chunk.clone(), &language_registry, cx);
 990            chunks.push(chunk);
 991            let idx = entries_len - 1;
 992            cx.emit(AcpThreadEvent::EntryUpdated(idx));
 993        } else {
 994            let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
 995            self.push_entry(
 996                AgentThreadEntry::UserMessage(UserMessage {
 997                    id: message_id,
 998                    content,
 999                    chunks: vec![chunk],
1000                    checkpoint: None,
1001                }),
1002                cx,
1003            );
1004        }
1005    }
1006
1007    pub fn push_assistant_content_block(
1008        &mut self,
1009        chunk: acp::ContentBlock,
1010        is_thought: bool,
1011        cx: &mut Context<Self>,
1012    ) {
1013        let language_registry = self.project.read(cx).languages().clone();
1014        let entries_len = self.entries.len();
1015        if let Some(last_entry) = self.entries.last_mut()
1016            && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
1017        {
1018            let idx = entries_len - 1;
1019            cx.emit(AcpThreadEvent::EntryUpdated(idx));
1020            match (chunks.last_mut(), is_thought) {
1021                (Some(AssistantMessageChunk::Message { block }), false)
1022                | (Some(AssistantMessageChunk::Thought { block }), true) => {
1023                    block.append(chunk, &language_registry, cx)
1024                }
1025                _ => {
1026                    let block = ContentBlock::new(chunk, &language_registry, cx);
1027                    if is_thought {
1028                        chunks.push(AssistantMessageChunk::Thought { block })
1029                    } else {
1030                        chunks.push(AssistantMessageChunk::Message { block })
1031                    }
1032                }
1033            }
1034        } else {
1035            let block = ContentBlock::new(chunk, &language_registry, cx);
1036            let chunk = if is_thought {
1037                AssistantMessageChunk::Thought { block }
1038            } else {
1039                AssistantMessageChunk::Message { block }
1040            };
1041
1042            self.push_entry(
1043                AgentThreadEntry::AssistantMessage(AssistantMessage {
1044                    chunks: vec![chunk],
1045                }),
1046                cx,
1047            );
1048        }
1049    }
1050
1051    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1052        self.entries.push(entry);
1053        cx.emit(AcpThreadEvent::NewEntry);
1054    }
1055
1056    pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1057        self.connection.set_title(&self.session_id, cx).is_some()
1058    }
1059
1060    pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1061        if title != self.title {
1062            self.title = title.clone();
1063            cx.emit(AcpThreadEvent::TitleUpdated);
1064            if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1065                return set_title.run(title, cx);
1066            }
1067        }
1068        Task::ready(Ok(()))
1069    }
1070
1071    pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1072        self.token_usage = usage;
1073        cx.emit(AcpThreadEvent::TokenUsageUpdated);
1074    }
1075
1076    pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1077        cx.emit(AcpThreadEvent::Retry(status));
1078    }
1079
1080    pub fn update_tool_call(
1081        &mut self,
1082        update: impl Into<ToolCallUpdate>,
1083        cx: &mut Context<Self>,
1084    ) -> Result<()> {
1085        let update = update.into();
1086        let languages = self.project.read(cx).languages().clone();
1087
1088        let (ix, current_call) = self
1089            .tool_call_mut(update.id())
1090            .context("Tool call not found")?;
1091        match update {
1092            ToolCallUpdate::UpdateFields(update) => {
1093                let location_updated = update.fields.locations.is_some();
1094                current_call.update_fields(update.fields, languages, cx);
1095                if location_updated {
1096                    self.resolve_locations(update.id, cx);
1097                }
1098            }
1099            ToolCallUpdate::UpdateDiff(update) => {
1100                current_call.content.clear();
1101                current_call
1102                    .content
1103                    .push(ToolCallContent::Diff(update.diff));
1104            }
1105            ToolCallUpdate::UpdateTerminal(update) => {
1106                current_call.content.clear();
1107                current_call
1108                    .content
1109                    .push(ToolCallContent::Terminal(update.terminal));
1110            }
1111        }
1112
1113        cx.emit(AcpThreadEvent::EntryUpdated(ix));
1114
1115        Ok(())
1116    }
1117
1118    /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1119    pub fn upsert_tool_call(
1120        &mut self,
1121        tool_call: acp::ToolCall,
1122        cx: &mut Context<Self>,
1123    ) -> Result<(), acp::Error> {
1124        let status = tool_call.status.into();
1125        self.upsert_tool_call_inner(tool_call.into(), status, cx)
1126    }
1127
1128    /// Fails if id does not match an existing entry.
1129    pub fn upsert_tool_call_inner(
1130        &mut self,
1131        tool_call_update: acp::ToolCallUpdate,
1132        status: ToolCallStatus,
1133        cx: &mut Context<Self>,
1134    ) -> Result<(), acp::Error> {
1135        let language_registry = self.project.read(cx).languages().clone();
1136        let id = tool_call_update.id.clone();
1137
1138        if let Some((ix, current_call)) = self.tool_call_mut(&id) {
1139            current_call.update_fields(tool_call_update.fields, language_registry, cx);
1140            current_call.status = status;
1141
1142            cx.emit(AcpThreadEvent::EntryUpdated(ix));
1143        } else {
1144            let call =
1145                ToolCall::from_acp(tool_call_update.try_into()?, status, language_registry, cx);
1146            self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1147        };
1148
1149        self.resolve_locations(id, cx);
1150        Ok(())
1151    }
1152
1153    fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1154        // The tool call we are looking for is typically the last one, or very close to the end.
1155        // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1156        self.entries
1157            .iter_mut()
1158            .enumerate()
1159            .rev()
1160            .find_map(|(index, tool_call)| {
1161                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1162                    && &tool_call.id == id
1163                {
1164                    Some((index, tool_call))
1165                } else {
1166                    None
1167                }
1168            })
1169    }
1170
1171    pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1172        self.entries
1173            .iter()
1174            .enumerate()
1175            .rev()
1176            .find_map(|(index, tool_call)| {
1177                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1178                    && &tool_call.id == id
1179                {
1180                    Some((index, tool_call))
1181                } else {
1182                    None
1183                }
1184            })
1185    }
1186
1187    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1188        let project = self.project.clone();
1189        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1190            return;
1191        };
1192        let task = tool_call.resolve_locations(project, cx);
1193        cx.spawn(async move |this, cx| {
1194            let resolved_locations = task.await;
1195            this.update(cx, |this, cx| {
1196                let project = this.project.clone();
1197                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1198                    return;
1199                };
1200                if let Some(Some(location)) = resolved_locations.last() {
1201                    project.update(cx, |project, cx| {
1202                        if let Some(agent_location) = project.agent_location() {
1203                            let should_ignore = agent_location.buffer == location.buffer
1204                                && location
1205                                    .buffer
1206                                    .update(cx, |buffer, _| {
1207                                        let snapshot = buffer.snapshot();
1208                                        let old_position =
1209                                            agent_location.position.to_point(&snapshot);
1210                                        let new_position = location.position.to_point(&snapshot);
1211                                        // ignore this so that when we get updates from the edit tool
1212                                        // the position doesn't reset to the startof line
1213                                        old_position.row == new_position.row
1214                                            && old_position.column > new_position.column
1215                                    })
1216                                    .ok()
1217                                    .unwrap_or_default();
1218                            if !should_ignore {
1219                                project.set_agent_location(Some(location.clone()), cx);
1220                            }
1221                        }
1222                    });
1223                }
1224                if tool_call.resolved_locations != resolved_locations {
1225                    tool_call.resolved_locations = resolved_locations;
1226                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
1227                }
1228            })
1229        })
1230        .detach();
1231    }
1232
1233    pub fn request_tool_call_authorization(
1234        &mut self,
1235        tool_call: acp::ToolCallUpdate,
1236        options: Vec<acp::PermissionOption>,
1237        cx: &mut Context<Self>,
1238    ) -> Result<oneshot::Receiver<acp::PermissionOptionId>, acp::Error> {
1239        let (tx, rx) = oneshot::channel();
1240
1241        let status = ToolCallStatus::WaitingForConfirmation {
1242            options,
1243            respond_tx: tx,
1244        };
1245
1246        self.upsert_tool_call_inner(tool_call, status, cx)?;
1247        cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1248        Ok(rx)
1249    }
1250
1251    pub fn authorize_tool_call(
1252        &mut self,
1253        id: acp::ToolCallId,
1254        option_id: acp::PermissionOptionId,
1255        option_kind: acp::PermissionOptionKind,
1256        cx: &mut Context<Self>,
1257    ) {
1258        let Some((ix, call)) = self.tool_call_mut(&id) else {
1259            return;
1260        };
1261
1262        let new_status = match option_kind {
1263            acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1264                ToolCallStatus::Rejected
1265            }
1266            acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1267                ToolCallStatus::InProgress
1268            }
1269        };
1270
1271        let curr_status = mem::replace(&mut call.status, new_status);
1272
1273        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1274            respond_tx.send(option_id).log_err();
1275        } else if cfg!(debug_assertions) {
1276            panic!("tried to authorize an already authorized tool call");
1277        }
1278
1279        cx.emit(AcpThreadEvent::EntryUpdated(ix));
1280    }
1281
1282    /// Returns true if the last turn is awaiting tool authorization
1283    pub fn waiting_for_tool_confirmation(&self) -> bool {
1284        for entry in self.entries.iter().rev() {
1285            match &entry {
1286                AgentThreadEntry::ToolCall(call) => match call.status {
1287                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
1288                    ToolCallStatus::Pending
1289                    | ToolCallStatus::InProgress
1290                    | ToolCallStatus::Completed
1291                    | ToolCallStatus::Failed
1292                    | ToolCallStatus::Rejected
1293                    | ToolCallStatus::Canceled => continue,
1294                },
1295                AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1296                    // Reached the beginning of the turn
1297                    return false;
1298                }
1299            }
1300        }
1301        false
1302    }
1303
1304    pub fn plan(&self) -> &Plan {
1305        &self.plan
1306    }
1307
1308    pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1309        let new_entries_len = request.entries.len();
1310        let mut new_entries = request.entries.into_iter();
1311
1312        // Reuse existing markdown to prevent flickering
1313        for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1314            let PlanEntry {
1315                content,
1316                priority,
1317                status,
1318            } = old;
1319            content.update(cx, |old, cx| {
1320                old.replace(new.content, cx);
1321            });
1322            *priority = new.priority;
1323            *status = new.status;
1324        }
1325        for new in new_entries {
1326            self.plan.entries.push(PlanEntry::from_acp(new, cx))
1327        }
1328        self.plan.entries.truncate(new_entries_len);
1329
1330        cx.notify();
1331    }
1332
1333    fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1334        self.plan
1335            .entries
1336            .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1337        cx.notify();
1338    }
1339
1340    #[cfg(any(test, feature = "test-support"))]
1341    pub fn send_raw(
1342        &mut self,
1343        message: &str,
1344        cx: &mut Context<Self>,
1345    ) -> BoxFuture<'static, Result<()>> {
1346        self.send(
1347            vec![acp::ContentBlock::Text(acp::TextContent {
1348                text: message.to_string(),
1349                annotations: None,
1350            })],
1351            cx,
1352        )
1353    }
1354
1355    pub fn send(
1356        &mut self,
1357        message: Vec<acp::ContentBlock>,
1358        cx: &mut Context<Self>,
1359    ) -> BoxFuture<'static, Result<()>> {
1360        let block = ContentBlock::new_combined(
1361            message.clone(),
1362            self.project.read(cx).languages().clone(),
1363            cx,
1364        );
1365        let request = acp::PromptRequest {
1366            prompt: message.clone(),
1367            session_id: self.session_id.clone(),
1368        };
1369        let git_store = self.project.read(cx).git_store().clone();
1370
1371        let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
1372            Some(UserMessageId::new())
1373        } else {
1374            None
1375        };
1376
1377        self.run_turn(cx, async move |this, cx| {
1378            this.update(cx, |this, cx| {
1379                this.push_entry(
1380                    AgentThreadEntry::UserMessage(UserMessage {
1381                        id: message_id.clone(),
1382                        content: block,
1383                        chunks: message,
1384                        checkpoint: None,
1385                    }),
1386                    cx,
1387                );
1388            })
1389            .ok();
1390
1391            let old_checkpoint = git_store
1392                .update(cx, |git, cx| git.checkpoint(cx))?
1393                .await
1394                .context("failed to get old checkpoint")
1395                .log_err();
1396            this.update(cx, |this, cx| {
1397                if let Some((_ix, message)) = this.last_user_message() {
1398                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
1399                        git_checkpoint,
1400                        show: false,
1401                    });
1402                }
1403                this.connection.prompt(message_id, request, cx)
1404            })?
1405            .await
1406        })
1407    }
1408
1409    pub fn can_resume(&self, cx: &App) -> bool {
1410        self.connection.resume(&self.session_id, cx).is_some()
1411    }
1412
1413    pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1414        self.run_turn(cx, async move |this, cx| {
1415            this.update(cx, |this, cx| {
1416                this.connection
1417                    .resume(&this.session_id, cx)
1418                    .map(|resume| resume.run(cx))
1419            })?
1420            .context("resuming a session is not supported")?
1421            .await
1422        })
1423    }
1424
1425    fn run_turn(
1426        &mut self,
1427        cx: &mut Context<Self>,
1428        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
1429    ) -> BoxFuture<'static, Result<()>> {
1430        self.clear_completed_plan_entries(cx);
1431
1432        let (tx, rx) = oneshot::channel();
1433        let cancel_task = self.cancel(cx);
1434
1435        self.send_task = Some(cx.spawn(async move |this, cx| {
1436            cancel_task.await;
1437            tx.send(f(this, cx).await).ok();
1438        }));
1439
1440        cx.spawn(async move |this, cx| {
1441            let response = rx.await;
1442
1443            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
1444                .await?;
1445
1446            this.update(cx, |this, cx| {
1447                this.project
1448                    .update(cx, |project, cx| project.set_agent_location(None, cx));
1449                match response {
1450                    Ok(Err(e)) => {
1451                        this.send_task.take();
1452                        cx.emit(AcpThreadEvent::Error);
1453                        Err(e)
1454                    }
1455                    result => {
1456                        let canceled = matches!(
1457                            result,
1458                            Ok(Ok(acp::PromptResponse {
1459                                stop_reason: acp::StopReason::Cancelled
1460                            }))
1461                        );
1462
1463                        // We only take the task if the current prompt wasn't canceled.
1464                        //
1465                        // This prompt may have been canceled because another one was sent
1466                        // while it was still generating. In these cases, dropping `send_task`
1467                        // would cause the next generation to be canceled.
1468                        if !canceled {
1469                            this.send_task.take();
1470                        }
1471
1472                        // Truncate entries if the last prompt was refused.
1473                        if let Ok(Ok(acp::PromptResponse {
1474                            stop_reason: acp::StopReason::Refusal,
1475                        })) = result
1476                            && let Some((ix, _)) = this.last_user_message()
1477                        {
1478                            let range = ix..this.entries.len();
1479                            this.entries.truncate(ix);
1480                            cx.emit(AcpThreadEvent::EntriesRemoved(range));
1481                        }
1482
1483                        cx.emit(AcpThreadEvent::Stopped);
1484                        Ok(())
1485                    }
1486                }
1487            })?
1488        })
1489        .boxed()
1490    }
1491
1492    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1493        let Some(send_task) = self.send_task.take() else {
1494            return Task::ready(());
1495        };
1496
1497        for entry in self.entries.iter_mut() {
1498            if let AgentThreadEntry::ToolCall(call) = entry {
1499                let cancel = matches!(
1500                    call.status,
1501                    ToolCallStatus::Pending
1502                        | ToolCallStatus::WaitingForConfirmation { .. }
1503                        | ToolCallStatus::InProgress
1504                );
1505
1506                if cancel {
1507                    call.status = ToolCallStatus::Canceled;
1508                }
1509            }
1510        }
1511
1512        self.connection.cancel(&self.session_id, cx);
1513
1514        // Wait for the send task to complete
1515        cx.foreground_executor().spawn(send_task)
1516    }
1517
1518    /// Rewinds this thread to before the entry at `index`, removing it and all
1519    /// subsequent entries while reverting any changes made from that point.
1520    pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1521        let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
1522            return Task::ready(Err(anyhow!("not supported")));
1523        };
1524        let Some(message) = self.user_message(&id) else {
1525            return Task::ready(Err(anyhow!("message not found")));
1526        };
1527
1528        let checkpoint = message
1529            .checkpoint
1530            .as_ref()
1531            .map(|c| c.git_checkpoint.clone());
1532
1533        let git_store = self.project.read(cx).git_store().clone();
1534        cx.spawn(async move |this, cx| {
1535            if let Some(checkpoint) = checkpoint {
1536                git_store
1537                    .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1538                    .await?;
1539            }
1540
1541            cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
1542            this.update(cx, |this, cx| {
1543                if let Some((ix, _)) = this.user_message_mut(&id) {
1544                    let range = ix..this.entries.len();
1545                    this.entries.truncate(ix);
1546                    cx.emit(AcpThreadEvent::EntriesRemoved(range));
1547                }
1548            })
1549        })
1550    }
1551
1552    fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1553        let git_store = self.project.read(cx).git_store().clone();
1554
1555        let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1556            if let Some(checkpoint) = message.checkpoint.as_ref() {
1557                checkpoint.git_checkpoint.clone()
1558            } else {
1559                return Task::ready(Ok(()));
1560            }
1561        } else {
1562            return Task::ready(Ok(()));
1563        };
1564
1565        let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1566        cx.spawn(async move |this, cx| {
1567            let new_checkpoint = new_checkpoint
1568                .await
1569                .context("failed to get new checkpoint")
1570                .log_err();
1571            if let Some(new_checkpoint) = new_checkpoint {
1572                let equal = git_store
1573                    .update(cx, |git, cx| {
1574                        git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1575                    })?
1576                    .await
1577                    .unwrap_or(true);
1578                this.update(cx, |this, cx| {
1579                    let (ix, message) = this.last_user_message().context("no user message")?;
1580                    let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1581                    checkpoint.show = !equal;
1582                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
1583                    anyhow::Ok(())
1584                })??;
1585            }
1586
1587            Ok(())
1588        })
1589    }
1590
1591    fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1592        self.entries
1593            .iter_mut()
1594            .enumerate()
1595            .rev()
1596            .find_map(|(ix, entry)| {
1597                if let AgentThreadEntry::UserMessage(message) = entry {
1598                    Some((ix, message))
1599                } else {
1600                    None
1601                }
1602            })
1603    }
1604
1605    fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
1606        self.entries.iter().find_map(|entry| {
1607            if let AgentThreadEntry::UserMessage(message) = entry {
1608                if message.id.as_ref() == Some(id) {
1609                    Some(message)
1610                } else {
1611                    None
1612                }
1613            } else {
1614                None
1615            }
1616        })
1617    }
1618
1619    fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1620        self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1621            if let AgentThreadEntry::UserMessage(message) = entry {
1622                if message.id.as_ref() == Some(id) {
1623                    Some((ix, message))
1624                } else {
1625                    None
1626                }
1627            } else {
1628                None
1629            }
1630        })
1631    }
1632
1633    pub fn read_text_file(
1634        &self,
1635        path: PathBuf,
1636        line: Option<u32>,
1637        limit: Option<u32>,
1638        reuse_shared_snapshot: bool,
1639        cx: &mut Context<Self>,
1640    ) -> Task<Result<String>> {
1641        let project = self.project.clone();
1642        let action_log = self.action_log.clone();
1643        cx.spawn(async move |this, cx| {
1644            let load = project.update(cx, |project, cx| {
1645                let path = project
1646                    .project_path_for_absolute_path(&path, cx)
1647                    .context("invalid path")?;
1648                anyhow::Ok(project.open_buffer(path, cx))
1649            });
1650            let buffer = load??.await?;
1651
1652            let snapshot = if reuse_shared_snapshot {
1653                this.read_with(cx, |this, _| {
1654                    this.shared_buffers.get(&buffer.clone()).cloned()
1655                })
1656                .log_err()
1657                .flatten()
1658            } else {
1659                None
1660            };
1661
1662            let snapshot = if let Some(snapshot) = snapshot {
1663                snapshot
1664            } else {
1665                action_log.update(cx, |action_log, cx| {
1666                    action_log.buffer_read(buffer.clone(), cx);
1667                })?;
1668                project.update(cx, |project, cx| {
1669                    let position = buffer
1670                        .read(cx)
1671                        .snapshot()
1672                        .anchor_before(Point::new(line.unwrap_or_default(), 0));
1673                    project.set_agent_location(
1674                        Some(AgentLocation {
1675                            buffer: buffer.downgrade(),
1676                            position,
1677                        }),
1678                        cx,
1679                    );
1680                })?;
1681
1682                buffer.update(cx, |buffer, _| buffer.snapshot())?
1683            };
1684
1685            this.update(cx, |this, _| {
1686                let text = snapshot.text();
1687                this.shared_buffers.insert(buffer.clone(), snapshot);
1688                if line.is_none() && limit.is_none() {
1689                    return Ok(text);
1690                }
1691                let limit = limit.unwrap_or(u32::MAX) as usize;
1692                let Some(line) = line else {
1693                    return Ok(text.lines().take(limit).collect::<String>());
1694                };
1695
1696                let count = text.lines().count();
1697                if count < line as usize {
1698                    anyhow::bail!("There are only {} lines", count);
1699                }
1700                Ok(text
1701                    .lines()
1702                    .skip(line as usize + 1)
1703                    .take(limit)
1704                    .collect::<String>())
1705            })?
1706        })
1707    }
1708
1709    pub fn write_text_file(
1710        &self,
1711        path: PathBuf,
1712        content: String,
1713        cx: &mut Context<Self>,
1714    ) -> Task<Result<()>> {
1715        let project = self.project.clone();
1716        let action_log = self.action_log.clone();
1717        cx.spawn(async move |this, cx| {
1718            let load = project.update(cx, |project, cx| {
1719                let path = project
1720                    .project_path_for_absolute_path(&path, cx)
1721                    .context("invalid path")?;
1722                anyhow::Ok(project.open_buffer(path, cx))
1723            });
1724            let buffer = load??.await?;
1725            let snapshot = this.update(cx, |this, cx| {
1726                this.shared_buffers
1727                    .get(&buffer)
1728                    .cloned()
1729                    .unwrap_or_else(|| buffer.read(cx).snapshot())
1730            })?;
1731            let edits = cx
1732                .background_executor()
1733                .spawn(async move {
1734                    let old_text = snapshot.text();
1735                    text_diff(old_text.as_str(), &content)
1736                        .into_iter()
1737                        .map(|(range, replacement)| {
1738                            (
1739                                snapshot.anchor_after(range.start)
1740                                    ..snapshot.anchor_before(range.end),
1741                                replacement,
1742                            )
1743                        })
1744                        .collect::<Vec<_>>()
1745                })
1746                .await;
1747
1748            project.update(cx, |project, cx| {
1749                project.set_agent_location(
1750                    Some(AgentLocation {
1751                        buffer: buffer.downgrade(),
1752                        position: edits
1753                            .last()
1754                            .map(|(range, _)| range.end)
1755                            .unwrap_or(Anchor::MIN),
1756                    }),
1757                    cx,
1758                );
1759            })?;
1760
1761            let format_on_save = cx.update(|cx| {
1762                action_log.update(cx, |action_log, cx| {
1763                    action_log.buffer_read(buffer.clone(), cx);
1764                });
1765
1766                let format_on_save = buffer.update(cx, |buffer, cx| {
1767                    buffer.edit(edits, None, cx);
1768
1769                    let settings = language::language_settings::language_settings(
1770                        buffer.language().map(|l| l.name()),
1771                        buffer.file(),
1772                        cx,
1773                    );
1774
1775                    settings.format_on_save != FormatOnSave::Off
1776                });
1777                action_log.update(cx, |action_log, cx| {
1778                    action_log.buffer_edited(buffer.clone(), cx);
1779                });
1780                format_on_save
1781            })?;
1782
1783            if format_on_save {
1784                let format_task = project.update(cx, |project, cx| {
1785                    project.format(
1786                        HashSet::from_iter([buffer.clone()]),
1787                        LspFormatTarget::Buffers,
1788                        false,
1789                        FormatTrigger::Save,
1790                        cx,
1791                    )
1792                })?;
1793                format_task.await.log_err();
1794
1795                action_log.update(cx, |action_log, cx| {
1796                    action_log.buffer_edited(buffer.clone(), cx);
1797                })?;
1798            }
1799
1800            project
1801                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1802                .await
1803        })
1804    }
1805
1806    pub fn to_markdown(&self, cx: &App) -> String {
1807        self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1808    }
1809
1810    pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
1811        cx.emit(AcpThreadEvent::LoadError(error));
1812    }
1813}
1814
1815fn markdown_for_raw_output(
1816    raw_output: &serde_json::Value,
1817    language_registry: &Arc<LanguageRegistry>,
1818    cx: &mut App,
1819) -> Option<Entity<Markdown>> {
1820    match raw_output {
1821        serde_json::Value::Null => None,
1822        serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1823            Markdown::new(
1824                value.to_string().into(),
1825                Some(language_registry.clone()),
1826                None,
1827                cx,
1828            )
1829        })),
1830        serde_json::Value::Number(value) => Some(cx.new(|cx| {
1831            Markdown::new(
1832                value.to_string().into(),
1833                Some(language_registry.clone()),
1834                None,
1835                cx,
1836            )
1837        })),
1838        serde_json::Value::String(value) => Some(cx.new(|cx| {
1839            Markdown::new(
1840                value.clone().into(),
1841                Some(language_registry.clone()),
1842                None,
1843                cx,
1844            )
1845        })),
1846        value => Some(cx.new(|cx| {
1847            Markdown::new(
1848                format!("```json\n{}\n```", value).into(),
1849                Some(language_registry.clone()),
1850                None,
1851                cx,
1852            )
1853        })),
1854    }
1855}
1856
1857#[cfg(test)]
1858mod tests {
1859    use super::*;
1860    use anyhow::anyhow;
1861    use futures::{channel::mpsc, future::LocalBoxFuture, select};
1862    use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
1863    use indoc::indoc;
1864    use project::{FakeFs, Fs};
1865    use rand::Rng as _;
1866    use serde_json::json;
1867    use settings::SettingsStore;
1868    use smol::stream::StreamExt as _;
1869    use std::{
1870        any::Any,
1871        cell::RefCell,
1872        path::Path,
1873        rc::Rc,
1874        sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
1875        time::Duration,
1876    };
1877    use util::path;
1878
1879    fn init_test(cx: &mut TestAppContext) {
1880        env_logger::try_init().ok();
1881        cx.update(|cx| {
1882            let settings_store = SettingsStore::test(cx);
1883            cx.set_global(settings_store);
1884            Project::init_settings(cx);
1885            language::init(cx);
1886        });
1887    }
1888
1889    #[gpui::test]
1890    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1891        init_test(cx);
1892
1893        let fs = FakeFs::new(cx.executor());
1894        let project = Project::test(fs, [], cx).await;
1895        let connection = Rc::new(FakeAgentConnection::new());
1896        let thread = cx
1897            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
1898            .await
1899            .unwrap();
1900
1901        // Test creating a new user message
1902        thread.update(cx, |thread, cx| {
1903            thread.push_user_content_block(
1904                None,
1905                acp::ContentBlock::Text(acp::TextContent {
1906                    annotations: None,
1907                    text: "Hello, ".to_string(),
1908                }),
1909                cx,
1910            );
1911        });
1912
1913        thread.update(cx, |thread, cx| {
1914            assert_eq!(thread.entries.len(), 1);
1915            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1916                assert_eq!(user_msg.id, None);
1917                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1918            } else {
1919                panic!("Expected UserMessage");
1920            }
1921        });
1922
1923        // Test appending to existing user message
1924        let message_1_id = UserMessageId::new();
1925        thread.update(cx, |thread, cx| {
1926            thread.push_user_content_block(
1927                Some(message_1_id.clone()),
1928                acp::ContentBlock::Text(acp::TextContent {
1929                    annotations: None,
1930                    text: "world!".to_string(),
1931                }),
1932                cx,
1933            );
1934        });
1935
1936        thread.update(cx, |thread, cx| {
1937            assert_eq!(thread.entries.len(), 1);
1938            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1939                assert_eq!(user_msg.id, Some(message_1_id));
1940                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1941            } else {
1942                panic!("Expected UserMessage");
1943            }
1944        });
1945
1946        // Test creating new user message after assistant message
1947        thread.update(cx, |thread, cx| {
1948            thread.push_assistant_content_block(
1949                acp::ContentBlock::Text(acp::TextContent {
1950                    annotations: None,
1951                    text: "Assistant response".to_string(),
1952                }),
1953                false,
1954                cx,
1955            );
1956        });
1957
1958        let message_2_id = UserMessageId::new();
1959        thread.update(cx, |thread, cx| {
1960            thread.push_user_content_block(
1961                Some(message_2_id.clone()),
1962                acp::ContentBlock::Text(acp::TextContent {
1963                    annotations: None,
1964                    text: "New user message".to_string(),
1965                }),
1966                cx,
1967            );
1968        });
1969
1970        thread.update(cx, |thread, cx| {
1971            assert_eq!(thread.entries.len(), 3);
1972            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1973                assert_eq!(user_msg.id, Some(message_2_id));
1974                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1975            } else {
1976                panic!("Expected UserMessage at index 2");
1977            }
1978        });
1979    }
1980
1981    #[gpui::test]
1982    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
1983        init_test(cx);
1984
1985        let fs = FakeFs::new(cx.executor());
1986        let project = Project::test(fs, [], cx).await;
1987        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1988            |_, thread, mut cx| {
1989                async move {
1990                    thread.update(&mut cx, |thread, cx| {
1991                        thread
1992                            .handle_session_update(
1993                                acp::SessionUpdate::AgentThoughtChunk {
1994                                    content: "Thinking ".into(),
1995                                },
1996                                cx,
1997                            )
1998                            .unwrap();
1999                        thread
2000                            .handle_session_update(
2001                                acp::SessionUpdate::AgentThoughtChunk {
2002                                    content: "hard!".into(),
2003                                },
2004                                cx,
2005                            )
2006                            .unwrap();
2007                    })?;
2008                    Ok(acp::PromptResponse {
2009                        stop_reason: acp::StopReason::EndTurn,
2010                    })
2011                }
2012                .boxed_local()
2013            },
2014        ));
2015
2016        let thread = cx
2017            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2018            .await
2019            .unwrap();
2020
2021        thread
2022            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
2023            .await
2024            .unwrap();
2025
2026        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
2027        assert_eq!(
2028            output,
2029            indoc! {r#"
2030            ## User
2031
2032            Hello from Zed!
2033
2034            ## Assistant
2035
2036            <thinking>
2037            Thinking hard!
2038            </thinking>
2039
2040            "#}
2041        );
2042    }
2043
2044    #[gpui::test]
2045    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
2046        init_test(cx);
2047
2048        let fs = FakeFs::new(cx.executor());
2049        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
2050            .await;
2051        let project = Project::test(fs.clone(), [], cx).await;
2052        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
2053        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
2054        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2055            move |_, thread, mut cx| {
2056                let read_file_tx = read_file_tx.clone();
2057                async move {
2058                    let content = thread
2059                        .update(&mut cx, |thread, cx| {
2060                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2061                        })
2062                        .unwrap()
2063                        .await
2064                        .unwrap();
2065                    assert_eq!(content, "one\ntwo\nthree\n");
2066                    read_file_tx.take().unwrap().send(()).unwrap();
2067                    thread
2068                        .update(&mut cx, |thread, cx| {
2069                            thread.write_text_file(
2070                                path!("/tmp/foo").into(),
2071                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
2072                                cx,
2073                            )
2074                        })
2075                        .unwrap()
2076                        .await
2077                        .unwrap();
2078                    Ok(acp::PromptResponse {
2079                        stop_reason: acp::StopReason::EndTurn,
2080                    })
2081                }
2082                .boxed_local()
2083            },
2084        ));
2085
2086        let (worktree, pathbuf) = project
2087            .update(cx, |project, cx| {
2088                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2089            })
2090            .await
2091            .unwrap();
2092        let buffer = project
2093            .update(cx, |project, cx| {
2094                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
2095            })
2096            .await
2097            .unwrap();
2098
2099        let thread = cx
2100            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2101            .await
2102            .unwrap();
2103
2104        let request = thread.update(cx, |thread, cx| {
2105            thread.send_raw("Extend the count in /tmp/foo", cx)
2106        });
2107        read_file_rx.await.ok();
2108        buffer.update(cx, |buffer, cx| {
2109            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
2110        });
2111        cx.run_until_parked();
2112        assert_eq!(
2113            buffer.read_with(cx, |buffer, _| buffer.text()),
2114            "zero\none\ntwo\nthree\nfour\nfive\n"
2115        );
2116        assert_eq!(
2117            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
2118            "zero\none\ntwo\nthree\nfour\nfive\n"
2119        );
2120        request.await.unwrap();
2121    }
2122
2123    #[gpui::test]
2124    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
2125        init_test(cx);
2126
2127        let fs = FakeFs::new(cx.executor());
2128        let project = Project::test(fs, [], cx).await;
2129        let id = acp::ToolCallId("test".into());
2130
2131        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2132            let id = id.clone();
2133            move |_, thread, mut cx| {
2134                let id = id.clone();
2135                async move {
2136                    thread
2137                        .update(&mut cx, |thread, cx| {
2138                            thread.handle_session_update(
2139                                acp::SessionUpdate::ToolCall(acp::ToolCall {
2140                                    id: id.clone(),
2141                                    title: "Label".into(),
2142                                    kind: acp::ToolKind::Fetch,
2143                                    status: acp::ToolCallStatus::InProgress,
2144                                    content: vec![],
2145                                    locations: vec![],
2146                                    raw_input: None,
2147                                    raw_output: None,
2148                                }),
2149                                cx,
2150                            )
2151                        })
2152                        .unwrap()
2153                        .unwrap();
2154                    Ok(acp::PromptResponse {
2155                        stop_reason: acp::StopReason::EndTurn,
2156                    })
2157                }
2158                .boxed_local()
2159            }
2160        }));
2161
2162        let thread = cx
2163            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2164            .await
2165            .unwrap();
2166
2167        let request = thread.update(cx, |thread, cx| {
2168            thread.send_raw("Fetch https://example.com", cx)
2169        });
2170
2171        run_until_first_tool_call(&thread, cx).await;
2172
2173        thread.read_with(cx, |thread, _| {
2174            assert!(matches!(
2175                thread.entries[1],
2176                AgentThreadEntry::ToolCall(ToolCall {
2177                    status: ToolCallStatus::InProgress,
2178                    ..
2179                })
2180            ));
2181        });
2182
2183        thread.update(cx, |thread, cx| thread.cancel(cx)).await;
2184
2185        thread.read_with(cx, |thread, _| {
2186            assert!(matches!(
2187                &thread.entries[1],
2188                AgentThreadEntry::ToolCall(ToolCall {
2189                    status: ToolCallStatus::Canceled,
2190                    ..
2191                })
2192            ));
2193        });
2194
2195        thread
2196            .update(cx, |thread, cx| {
2197                thread.handle_session_update(
2198                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
2199                        id,
2200                        fields: acp::ToolCallUpdateFields {
2201                            status: Some(acp::ToolCallStatus::Completed),
2202                            ..Default::default()
2203                        },
2204                    }),
2205                    cx,
2206                )
2207            })
2208            .unwrap();
2209
2210        request.await.unwrap();
2211
2212        thread.read_with(cx, |thread, _| {
2213            assert!(matches!(
2214                thread.entries[1],
2215                AgentThreadEntry::ToolCall(ToolCall {
2216                    status: ToolCallStatus::Completed,
2217                    ..
2218                })
2219            ));
2220        });
2221    }
2222
2223    #[gpui::test]
2224    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
2225        init_test(cx);
2226        let fs = FakeFs::new(cx.background_executor.clone());
2227        fs.insert_tree(path!("/test"), json!({})).await;
2228        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
2229
2230        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2231            move |_, thread, mut cx| {
2232                async move {
2233                    thread
2234                        .update(&mut cx, |thread, cx| {
2235                            thread.handle_session_update(
2236                                acp::SessionUpdate::ToolCall(acp::ToolCall {
2237                                    id: acp::ToolCallId("test".into()),
2238                                    title: "Label".into(),
2239                                    kind: acp::ToolKind::Edit,
2240                                    status: acp::ToolCallStatus::Completed,
2241                                    content: vec![acp::ToolCallContent::Diff {
2242                                        diff: acp::Diff {
2243                                            path: "/test/test.txt".into(),
2244                                            old_text: None,
2245                                            new_text: "foo".into(),
2246                                        },
2247                                    }],
2248                                    locations: vec![],
2249                                    raw_input: None,
2250                                    raw_output: None,
2251                                }),
2252                                cx,
2253                            )
2254                        })
2255                        .unwrap()
2256                        .unwrap();
2257                    Ok(acp::PromptResponse {
2258                        stop_reason: acp::StopReason::EndTurn,
2259                    })
2260                }
2261                .boxed_local()
2262            }
2263        }));
2264
2265        let thread = cx
2266            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2267            .await
2268            .unwrap();
2269
2270        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
2271            .await
2272            .unwrap();
2273
2274        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
2275    }
2276
2277    #[gpui::test(iterations = 10)]
2278    async fn test_checkpoints(cx: &mut TestAppContext) {
2279        init_test(cx);
2280        let fs = FakeFs::new(cx.background_executor.clone());
2281        fs.insert_tree(
2282            path!("/test"),
2283            json!({
2284                ".git": {}
2285            }),
2286        )
2287        .await;
2288        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
2289
2290        let simulate_changes = Arc::new(AtomicBool::new(true));
2291        let next_filename = Arc::new(AtomicUsize::new(0));
2292        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2293            let simulate_changes = simulate_changes.clone();
2294            let next_filename = next_filename.clone();
2295            let fs = fs.clone();
2296            move |request, thread, mut cx| {
2297                let fs = fs.clone();
2298                let simulate_changes = simulate_changes.clone();
2299                let next_filename = next_filename.clone();
2300                async move {
2301                    if simulate_changes.load(SeqCst) {
2302                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
2303                        fs.write(Path::new(&filename), b"").await?;
2304                    }
2305
2306                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2307                        panic!("expected text content block");
2308                    };
2309                    thread.update(&mut cx, |thread, cx| {
2310                        thread
2311                            .handle_session_update(
2312                                acp::SessionUpdate::AgentMessageChunk {
2313                                    content: content.text.to_uppercase().into(),
2314                                },
2315                                cx,
2316                            )
2317                            .unwrap();
2318                    })?;
2319                    Ok(acp::PromptResponse {
2320                        stop_reason: acp::StopReason::EndTurn,
2321                    })
2322                }
2323                .boxed_local()
2324            }
2325        }));
2326        let thread = cx
2327            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2328            .await
2329            .unwrap();
2330
2331        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
2332            .await
2333            .unwrap();
2334        thread.read_with(cx, |thread, cx| {
2335            assert_eq!(
2336                thread.to_markdown(cx),
2337                indoc! {"
2338                    ## User (checkpoint)
2339
2340                    Lorem
2341
2342                    ## Assistant
2343
2344                    LOREM
2345
2346                "}
2347            );
2348        });
2349        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2350
2351        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
2352            .await
2353            .unwrap();
2354        thread.read_with(cx, |thread, cx| {
2355            assert_eq!(
2356                thread.to_markdown(cx),
2357                indoc! {"
2358                    ## User (checkpoint)
2359
2360                    Lorem
2361
2362                    ## Assistant
2363
2364                    LOREM
2365
2366                    ## User (checkpoint)
2367
2368                    ipsum
2369
2370                    ## Assistant
2371
2372                    IPSUM
2373
2374                "}
2375            );
2376        });
2377        assert_eq!(
2378            fs.files(),
2379            vec![
2380                Path::new(path!("/test/file-0")),
2381                Path::new(path!("/test/file-1"))
2382            ]
2383        );
2384
2385        // Checkpoint isn't stored when there are no changes.
2386        simulate_changes.store(false, SeqCst);
2387        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
2388            .await
2389            .unwrap();
2390        thread.read_with(cx, |thread, cx| {
2391            assert_eq!(
2392                thread.to_markdown(cx),
2393                indoc! {"
2394                    ## User (checkpoint)
2395
2396                    Lorem
2397
2398                    ## Assistant
2399
2400                    LOREM
2401
2402                    ## User (checkpoint)
2403
2404                    ipsum
2405
2406                    ## Assistant
2407
2408                    IPSUM
2409
2410                    ## User
2411
2412                    dolor
2413
2414                    ## Assistant
2415
2416                    DOLOR
2417
2418                "}
2419            );
2420        });
2421        assert_eq!(
2422            fs.files(),
2423            vec![
2424                Path::new(path!("/test/file-0")),
2425                Path::new(path!("/test/file-1"))
2426            ]
2427        );
2428
2429        // Rewinding the conversation truncates the history and restores the checkpoint.
2430        thread
2431            .update(cx, |thread, cx| {
2432                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
2433                    panic!("unexpected entries {:?}", thread.entries)
2434                };
2435                thread.rewind(message.id.clone().unwrap(), cx)
2436            })
2437            .await
2438            .unwrap();
2439        thread.read_with(cx, |thread, cx| {
2440            assert_eq!(
2441                thread.to_markdown(cx),
2442                indoc! {"
2443                    ## User (checkpoint)
2444
2445                    Lorem
2446
2447                    ## Assistant
2448
2449                    LOREM
2450
2451                "}
2452            );
2453        });
2454        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2455    }
2456
2457    #[gpui::test]
2458    async fn test_refusal(cx: &mut TestAppContext) {
2459        init_test(cx);
2460        let fs = FakeFs::new(cx.background_executor.clone());
2461        fs.insert_tree(path!("/"), json!({})).await;
2462        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
2463
2464        let refuse_next = Arc::new(AtomicBool::new(false));
2465        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2466            let refuse_next = refuse_next.clone();
2467            move |request, thread, mut cx| {
2468                let refuse_next = refuse_next.clone();
2469                async move {
2470                    if refuse_next.load(SeqCst) {
2471                        return Ok(acp::PromptResponse {
2472                            stop_reason: acp::StopReason::Refusal,
2473                        });
2474                    }
2475
2476                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2477                        panic!("expected text content block");
2478                    };
2479                    thread.update(&mut cx, |thread, cx| {
2480                        thread
2481                            .handle_session_update(
2482                                acp::SessionUpdate::AgentMessageChunk {
2483                                    content: content.text.to_uppercase().into(),
2484                                },
2485                                cx,
2486                            )
2487                            .unwrap();
2488                    })?;
2489                    Ok(acp::PromptResponse {
2490                        stop_reason: acp::StopReason::EndTurn,
2491                    })
2492                }
2493                .boxed_local()
2494            }
2495        }));
2496        let thread = cx
2497            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2498            .await
2499            .unwrap();
2500
2501        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
2502            .await
2503            .unwrap();
2504        thread.read_with(cx, |thread, cx| {
2505            assert_eq!(
2506                thread.to_markdown(cx),
2507                indoc! {"
2508                    ## User
2509
2510                    hello
2511
2512                    ## Assistant
2513
2514                    HELLO
2515
2516                "}
2517            );
2518        });
2519
2520        // Simulate refusing the second message, ensuring the conversation gets
2521        // truncated to before sending it.
2522        refuse_next.store(true, SeqCst);
2523        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
2524            .await
2525            .unwrap();
2526        thread.read_with(cx, |thread, cx| {
2527            assert_eq!(
2528                thread.to_markdown(cx),
2529                indoc! {"
2530                    ## User
2531
2532                    hello
2533
2534                    ## Assistant
2535
2536                    HELLO
2537
2538                "}
2539            );
2540        });
2541    }
2542
2543    async fn run_until_first_tool_call(
2544        thread: &Entity<AcpThread>,
2545        cx: &mut TestAppContext,
2546    ) -> usize {
2547        let (mut tx, mut rx) = mpsc::channel::<usize>(1);
2548
2549        let subscription = cx.update(|cx| {
2550            cx.subscribe(thread, move |thread, _, cx| {
2551                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
2552                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
2553                        return tx.try_send(ix).unwrap();
2554                    }
2555                }
2556            })
2557        });
2558
2559        select! {
2560            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
2561                panic!("Timeout waiting for tool call")
2562            }
2563            ix = rx.next().fuse() => {
2564                drop(subscription);
2565                ix.unwrap()
2566            }
2567        }
2568    }
2569
2570    #[derive(Clone, Default)]
2571    struct FakeAgentConnection {
2572        auth_methods: Vec<acp::AuthMethod>,
2573        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
2574        on_user_message: Option<
2575            Rc<
2576                dyn Fn(
2577                        acp::PromptRequest,
2578                        WeakEntity<AcpThread>,
2579                        AsyncApp,
2580                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2581                    + 'static,
2582            >,
2583        >,
2584    }
2585
2586    impl FakeAgentConnection {
2587        fn new() -> Self {
2588            Self {
2589                auth_methods: Vec::new(),
2590                on_user_message: None,
2591                sessions: Arc::default(),
2592            }
2593        }
2594
2595        #[expect(unused)]
2596        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
2597            self.auth_methods = auth_methods;
2598            self
2599        }
2600
2601        fn on_user_message(
2602            mut self,
2603            handler: impl Fn(
2604                acp::PromptRequest,
2605                WeakEntity<AcpThread>,
2606                AsyncApp,
2607            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2608            + 'static,
2609        ) -> Self {
2610            self.on_user_message.replace(Rc::new(handler));
2611            self
2612        }
2613    }
2614
2615    impl AgentConnection for FakeAgentConnection {
2616        fn auth_methods(&self) -> &[acp::AuthMethod] {
2617            &self.auth_methods
2618        }
2619
2620        fn new_thread(
2621            self: Rc<Self>,
2622            project: Entity<Project>,
2623            _cwd: &Path,
2624            cx: &mut App,
2625        ) -> Task<gpui::Result<Entity<AcpThread>>> {
2626            let session_id = acp::SessionId(
2627                rand::thread_rng()
2628                    .sample_iter(&rand::distributions::Alphanumeric)
2629                    .take(7)
2630                    .map(char::from)
2631                    .collect::<String>()
2632                    .into(),
2633            );
2634            let action_log = cx.new(|_| ActionLog::new(project.clone()));
2635            let thread = cx.new(|cx| {
2636                AcpThread::new(
2637                    "Test",
2638                    self.clone(),
2639                    project,
2640                    action_log,
2641                    session_id.clone(),
2642                    watch::Receiver::constant(acp::PromptCapabilities {
2643                        image: true,
2644                        audio: true,
2645                        embedded_context: true,
2646                        supports_custom_commands: false,
2647                    }),
2648                    cx,
2649                )
2650            });
2651            self.sessions.lock().insert(session_id, thread.downgrade());
2652            Task::ready(Ok(thread))
2653        }
2654
2655        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
2656            if self.auth_methods().iter().any(|m| m.id == method) {
2657                Task::ready(Ok(()))
2658            } else {
2659                Task::ready(Err(anyhow!("Invalid Auth Method")))
2660            }
2661        }
2662
2663        fn prompt(
2664            &self,
2665            _id: Option<UserMessageId>,
2666            params: acp::PromptRequest,
2667            cx: &mut App,
2668        ) -> Task<gpui::Result<acp::PromptResponse>> {
2669            let sessions = self.sessions.lock();
2670            let thread = sessions.get(&params.session_id).unwrap();
2671            if let Some(handler) = &self.on_user_message {
2672                let handler = handler.clone();
2673                let thread = thread.clone();
2674                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
2675            } else {
2676                Task::ready(Ok(acp::PromptResponse {
2677                    stop_reason: acp::StopReason::EndTurn,
2678                }))
2679            }
2680        }
2681
2682        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
2683            let sessions = self.sessions.lock();
2684            let thread = sessions.get(session_id).unwrap().clone();
2685
2686            cx.spawn(async move |cx| {
2687                thread
2688                    .update(cx, |thread, cx| thread.cancel(cx))
2689                    .unwrap()
2690                    .await
2691            })
2692            .detach();
2693        }
2694
2695        fn truncate(
2696            &self,
2697            session_id: &acp::SessionId,
2698            _cx: &App,
2699        ) -> Option<Rc<dyn AgentSessionTruncate>> {
2700            Some(Rc::new(FakeAgentSessionEditor {
2701                _session_id: session_id.clone(),
2702            }))
2703        }
2704
2705        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
2706            self
2707        }
2708    }
2709
    /// Fake `AgentSessionTruncate` handle handed out by
    /// `FakeAgentConnection::truncate`.
    struct FakeAgentSessionEditor {
        /// Session this handle was created for. Stored only; the `run`
        /// implementation visible in this module does not read it.
        _session_id: acp::SessionId,
    }
2713
2714    impl AgentSessionTruncate for FakeAgentSessionEditor {
2715        fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
2716            Task::ready(Ok(()))
2717        }
2718    }
2719}