acp_thread.rs

   1mod connection;
   2mod diff;
   3mod mention;
   4mod terminal;
   5
   6use collections::HashSet;
   7pub use connection::*;
   8pub use diff::*;
   9use language::language_settings::FormatOnSave;
  10pub use mention::*;
  11use project::lsp_store::{FormatTrigger, LspFormatTarget};
  12use serde::{Deserialize, Serialize};
  13pub use terminal::*;
  14
  15use action_log::ActionLog;
  16use agent_client_protocol as acp;
  17use anyhow::{Context as _, Result, anyhow};
  18use editor::Bias;
  19use futures::{FutureExt, channel::oneshot, future::BoxFuture};
  20use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  21use itertools::Itertools;
  22use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
  23use markdown::Markdown;
  24use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
  25use std::collections::HashMap;
  26use std::error::Error;
  27use std::fmt::{Formatter, Write};
  28use std::ops::Range;
  29use std::process::ExitStatus;
  30use std::rc::Rc;
  31use std::time::{Duration, Instant};
  32use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
  33use ui::App;
  34use util::ResultExt;
  35
  36#[derive(Debug)]
  37pub struct UserMessage {
  38    pub id: Option<UserMessageId>,
  39    pub content: ContentBlock,
  40    pub chunks: Vec<acp::ContentBlock>,
  41    pub checkpoint: Option<Checkpoint>,
  42}
  43
  44#[derive(Debug)]
  45pub struct Checkpoint {
  46    git_checkpoint: GitStoreCheckpoint,
  47    pub show: bool,
  48}
  49
  50impl UserMessage {
  51    fn to_markdown(&self, cx: &App) -> String {
  52        let mut markdown = String::new();
  53        if self
  54            .checkpoint
  55            .as_ref()
  56            .is_some_and(|checkpoint| checkpoint.show)
  57        {
  58            writeln!(markdown, "## User (checkpoint)").unwrap();
  59        } else {
  60            writeln!(markdown, "## User").unwrap();
  61        }
  62        writeln!(markdown).unwrap();
  63        writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
  64        writeln!(markdown).unwrap();
  65        markdown
  66    }
  67}
  68
  69#[derive(Debug, PartialEq)]
  70pub struct AssistantMessage {
  71    pub chunks: Vec<AssistantMessageChunk>,
  72}
  73
  74impl AssistantMessage {
  75    pub fn to_markdown(&self, cx: &App) -> String {
  76        format!(
  77            "## Assistant\n\n{}\n\n",
  78            self.chunks
  79                .iter()
  80                .map(|chunk| chunk.to_markdown(cx))
  81                .join("\n\n")
  82        )
  83    }
  84}
  85
  86#[derive(Debug, PartialEq)]
  87pub enum AssistantMessageChunk {
  88    Message { block: ContentBlock },
  89    Thought { block: ContentBlock },
  90}
  91
  92impl AssistantMessageChunk {
  93    pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
  94        Self::Message {
  95            block: ContentBlock::new(chunk.into(), language_registry, cx),
  96        }
  97    }
  98
  99    fn to_markdown(&self, cx: &App) -> String {
 100        match self {
 101            Self::Message { block } => block.to_markdown(cx).to_string(),
 102            Self::Thought { block } => {
 103                format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
 104            }
 105        }
 106    }
 107}
 108
 109#[derive(Debug)]
 110pub enum AgentThreadEntry {
 111    UserMessage(UserMessage),
 112    AssistantMessage(AssistantMessage),
 113    ToolCall(ToolCall),
 114}
 115
 116impl AgentThreadEntry {
 117    pub fn to_markdown(&self, cx: &App) -> String {
 118        match self {
 119            Self::UserMessage(message) => message.to_markdown(cx),
 120            Self::AssistantMessage(message) => message.to_markdown(cx),
 121            Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
 122        }
 123    }
 124
 125    pub fn user_message(&self) -> Option<&UserMessage> {
 126        if let AgentThreadEntry::UserMessage(message) = self {
 127            Some(message)
 128        } else {
 129            None
 130        }
 131    }
 132
 133    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 134        if let AgentThreadEntry::ToolCall(call) = self {
 135            itertools::Either::Left(call.diffs())
 136        } else {
 137            itertools::Either::Right(std::iter::empty())
 138        }
 139    }
 140
 141    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 142        if let AgentThreadEntry::ToolCall(call) = self {
 143            itertools::Either::Left(call.terminals())
 144        } else {
 145            itertools::Either::Right(std::iter::empty())
 146        }
 147    }
 148
 149    pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
 150        if let AgentThreadEntry::ToolCall(ToolCall {
 151            locations,
 152            resolved_locations,
 153            ..
 154        }) = self
 155        {
 156            Some((
 157                locations.get(ix)?.clone(),
 158                resolved_locations.get(ix)?.clone()?,
 159            ))
 160        } else {
 161            None
 162        }
 163    }
 164}
 165
 166#[derive(Debug)]
 167pub struct ToolCall {
 168    pub id: acp::ToolCallId,
 169    pub label: Entity<Markdown>,
 170    pub kind: acp::ToolKind,
 171    pub content: Vec<ToolCallContent>,
 172    pub status: ToolCallStatus,
 173    pub locations: Vec<acp::ToolCallLocation>,
 174    pub resolved_locations: Vec<Option<AgentLocation>>,
 175    pub raw_input: Option<serde_json::Value>,
 176    pub raw_output: Option<serde_json::Value>,
 177}
 178
 179impl ToolCall {
 180    fn from_acp(
 181        tool_call: acp::ToolCall,
 182        status: ToolCallStatus,
 183        language_registry: Arc<LanguageRegistry>,
 184        cx: &mut App,
 185    ) -> Self {
 186        Self {
 187            id: tool_call.id,
 188            label: cx.new(|cx| {
 189                Markdown::new(
 190                    tool_call.title.into(),
 191                    Some(language_registry.clone()),
 192                    None,
 193                    cx,
 194                )
 195            }),
 196            kind: tool_call.kind,
 197            content: tool_call
 198                .content
 199                .into_iter()
 200                .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
 201                .collect(),
 202            locations: tool_call.locations,
 203            resolved_locations: Vec::default(),
 204            status,
 205            raw_input: tool_call.raw_input,
 206            raw_output: tool_call.raw_output,
 207        }
 208    }
 209
 210    fn update_fields(
 211        &mut self,
 212        fields: acp::ToolCallUpdateFields,
 213        language_registry: Arc<LanguageRegistry>,
 214        cx: &mut App,
 215    ) {
 216        let acp::ToolCallUpdateFields {
 217            kind,
 218            status,
 219            title,
 220            content,
 221            locations,
 222            raw_input,
 223            raw_output,
 224        } = fields;
 225
 226        if let Some(kind) = kind {
 227            self.kind = kind;
 228        }
 229
 230        if let Some(status) = status {
 231            self.status = status.into();
 232        }
 233
 234        if let Some(title) = title {
 235            self.label.update(cx, |label, cx| {
 236                label.replace(title, cx);
 237            });
 238        }
 239
 240        if let Some(content) = content {
 241            let new_content_len = content.len();
 242            let mut content = content.into_iter();
 243
 244            // Reuse existing content if we can
 245            for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
 246                old.update_from_acp(new, language_registry.clone(), cx);
 247            }
 248            for new in content {
 249                self.content.push(ToolCallContent::from_acp(
 250                    new,
 251                    language_registry.clone(),
 252                    cx,
 253                ))
 254            }
 255            self.content.truncate(new_content_len);
 256        }
 257
 258        if let Some(locations) = locations {
 259            self.locations = locations;
 260        }
 261
 262        if let Some(raw_input) = raw_input {
 263            self.raw_input = Some(raw_input);
 264        }
 265
 266        if let Some(raw_output) = raw_output {
 267            if self.content.is_empty()
 268                && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
 269            {
 270                self.content
 271                    .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
 272                        markdown,
 273                    }));
 274            }
 275            self.raw_output = Some(raw_output);
 276        }
 277    }
 278
 279    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 280        self.content.iter().filter_map(|content| match content {
 281            ToolCallContent::Diff(diff) => Some(diff),
 282            ToolCallContent::ContentBlock(_) => None,
 283            ToolCallContent::Terminal(_) => None,
 284        })
 285    }
 286
 287    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 288        self.content.iter().filter_map(|content| match content {
 289            ToolCallContent::Terminal(terminal) => Some(terminal),
 290            ToolCallContent::ContentBlock(_) => None,
 291            ToolCallContent::Diff(_) => None,
 292        })
 293    }
 294
 295    fn to_markdown(&self, cx: &App) -> String {
 296        let mut markdown = format!(
 297            "**Tool Call: {}**\nStatus: {}\n\n",
 298            self.label.read(cx).source(),
 299            self.status
 300        );
 301        for content in &self.content {
 302            markdown.push_str(content.to_markdown(cx).as_str());
 303            markdown.push_str("\n\n");
 304        }
 305        markdown
 306    }
 307
 308    async fn resolve_location(
 309        location: acp::ToolCallLocation,
 310        project: WeakEntity<Project>,
 311        cx: &mut AsyncApp,
 312    ) -> Option<AgentLocation> {
 313        let buffer = project
 314            .update(cx, |project, cx| {
 315                project
 316                    .project_path_for_absolute_path(&location.path, cx)
 317                    .map(|path| project.open_buffer(path, cx))
 318            })
 319            .ok()??;
 320        let buffer = buffer.await.log_err()?;
 321        let position = buffer
 322            .update(cx, |buffer, _| {
 323                if let Some(row) = location.line {
 324                    let snapshot = buffer.snapshot();
 325                    let column = snapshot.indent_size_for_line(row).len;
 326                    let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
 327                    snapshot.anchor_before(point)
 328                } else {
 329                    Anchor::MIN
 330                }
 331            })
 332            .ok()?;
 333
 334        Some(AgentLocation {
 335            buffer: buffer.downgrade(),
 336            position,
 337        })
 338    }
 339
 340    fn resolve_locations(
 341        &self,
 342        project: Entity<Project>,
 343        cx: &mut App,
 344    ) -> Task<Vec<Option<AgentLocation>>> {
 345        let locations = self.locations.clone();
 346        project.update(cx, |_, cx| {
 347            cx.spawn(async move |project, cx| {
 348                let mut new_locations = Vec::new();
 349                for location in locations {
 350                    new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
 351                }
 352                new_locations
 353            })
 354        })
 355    }
 356}
 357
 358#[derive(Debug)]
 359pub enum ToolCallStatus {
 360    /// The tool call hasn't started running yet, but we start showing it to
 361    /// the user.
 362    Pending,
 363    /// The tool call is waiting for confirmation from the user.
 364    WaitingForConfirmation {
 365        options: Vec<acp::PermissionOption>,
 366        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
 367    },
 368    /// The tool call is currently running.
 369    InProgress,
 370    /// The tool call completed successfully.
 371    Completed,
 372    /// The tool call failed.
 373    Failed,
 374    /// The user rejected the tool call.
 375    Rejected,
 376    /// The user canceled generation so the tool call was canceled.
 377    Canceled,
 378}
 379
 380impl From<acp::ToolCallStatus> for ToolCallStatus {
 381    fn from(status: acp::ToolCallStatus) -> Self {
 382        match status {
 383            acp::ToolCallStatus::Pending => Self::Pending,
 384            acp::ToolCallStatus::InProgress => Self::InProgress,
 385            acp::ToolCallStatus::Completed => Self::Completed,
 386            acp::ToolCallStatus::Failed => Self::Failed,
 387        }
 388    }
 389}
 390
 391impl Display for ToolCallStatus {
 392    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 393        write!(
 394            f,
 395            "{}",
 396            match self {
 397                ToolCallStatus::Pending => "Pending",
 398                ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
 399                ToolCallStatus::InProgress => "In Progress",
 400                ToolCallStatus::Completed => "Completed",
 401                ToolCallStatus::Failed => "Failed",
 402                ToolCallStatus::Rejected => "Rejected",
 403                ToolCallStatus::Canceled => "Canceled",
 404            }
 405        )
 406    }
 407}
 408
 409#[derive(Debug, PartialEq, Clone)]
 410pub enum ContentBlock {
 411    Empty,
 412    Markdown { markdown: Entity<Markdown> },
 413    ResourceLink { resource_link: acp::ResourceLink },
 414}
 415
 416impl ContentBlock {
 417    pub fn new(
 418        block: acp::ContentBlock,
 419        language_registry: &Arc<LanguageRegistry>,
 420        cx: &mut App,
 421    ) -> Self {
 422        let mut this = Self::Empty;
 423        this.append(block, language_registry, cx);
 424        this
 425    }
 426
 427    pub fn new_combined(
 428        blocks: impl IntoIterator<Item = acp::ContentBlock>,
 429        language_registry: Arc<LanguageRegistry>,
 430        cx: &mut App,
 431    ) -> Self {
 432        let mut this = Self::Empty;
 433        for block in blocks {
 434            this.append(block, &language_registry, cx);
 435        }
 436        this
 437    }
 438
 439    pub fn append(
 440        &mut self,
 441        block: acp::ContentBlock,
 442        language_registry: &Arc<LanguageRegistry>,
 443        cx: &mut App,
 444    ) {
 445        if matches!(self, ContentBlock::Empty)
 446            && let acp::ContentBlock::ResourceLink(resource_link) = block
 447        {
 448            *self = ContentBlock::ResourceLink { resource_link };
 449            return;
 450        }
 451
 452        let new_content = self.block_string_contents(block);
 453
 454        match self {
 455            ContentBlock::Empty => {
 456                *self = Self::create_markdown_block(new_content, language_registry, cx);
 457            }
 458            ContentBlock::Markdown { markdown } => {
 459                markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
 460            }
 461            ContentBlock::ResourceLink { resource_link } => {
 462                let existing_content = Self::resource_link_md(&resource_link.uri);
 463                let combined = format!("{}\n{}", existing_content, new_content);
 464
 465                *self = Self::create_markdown_block(combined, language_registry, cx);
 466            }
 467        }
 468    }
 469
 470    fn create_markdown_block(
 471        content: String,
 472        language_registry: &Arc<LanguageRegistry>,
 473        cx: &mut App,
 474    ) -> ContentBlock {
 475        ContentBlock::Markdown {
 476            markdown: cx
 477                .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
 478        }
 479    }
 480
 481    fn block_string_contents(&self, block: acp::ContentBlock) -> String {
 482        match block {
 483            acp::ContentBlock::Text(text_content) => text_content.text,
 484            acp::ContentBlock::ResourceLink(resource_link) => {
 485                Self::resource_link_md(&resource_link.uri)
 486            }
 487            acp::ContentBlock::Resource(acp::EmbeddedResource {
 488                resource:
 489                    acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
 490                        uri,
 491                        ..
 492                    }),
 493                ..
 494            }) => Self::resource_link_md(&uri),
 495            acp::ContentBlock::Image(image) => Self::image_md(&image),
 496            acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
 497        }
 498    }
 499
 500    fn resource_link_md(uri: &str) -> String {
 501        if let Some(uri) = MentionUri::parse(uri).log_err() {
 502            uri.as_link().to_string()
 503        } else {
 504            uri.to_string()
 505        }
 506    }
 507
 508    fn image_md(_image: &acp::ImageContent) -> String {
 509        "`Image`".into()
 510    }
 511
 512    pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
 513        match self {
 514            ContentBlock::Empty => "",
 515            ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
 516            ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
 517        }
 518    }
 519
 520    pub fn markdown(&self) -> Option<&Entity<Markdown>> {
 521        match self {
 522            ContentBlock::Empty => None,
 523            ContentBlock::Markdown { markdown } => Some(markdown),
 524            ContentBlock::ResourceLink { .. } => None,
 525        }
 526    }
 527
 528    pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
 529        match self {
 530            ContentBlock::ResourceLink { resource_link } => Some(resource_link),
 531            _ => None,
 532        }
 533    }
 534}
 535
 536#[derive(Debug)]
 537pub enum ToolCallContent {
 538    ContentBlock(ContentBlock),
 539    Diff(Entity<Diff>),
 540    Terminal(Entity<Terminal>),
 541}
 542
 543impl ToolCallContent {
 544    pub fn from_acp(
 545        content: acp::ToolCallContent,
 546        language_registry: Arc<LanguageRegistry>,
 547        cx: &mut App,
 548    ) -> Self {
 549        match content {
 550            acp::ToolCallContent::Content { content } => {
 551                Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
 552            }
 553            acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
 554                Diff::finalized(
 555                    diff.path,
 556                    diff.old_text,
 557                    diff.new_text,
 558                    language_registry,
 559                    cx,
 560                )
 561            })),
 562        }
 563    }
 564
 565    pub fn update_from_acp(
 566        &mut self,
 567        new: acp::ToolCallContent,
 568        language_registry: Arc<LanguageRegistry>,
 569        cx: &mut App,
 570    ) {
 571        let needs_update = match (&self, &new) {
 572            (Self::Diff(old_diff), acp::ToolCallContent::Diff { diff: new_diff }) => {
 573                old_diff.read(cx).needs_update(
 574                    new_diff.old_text.as_deref().unwrap_or(""),
 575                    &new_diff.new_text,
 576                    cx,
 577                )
 578            }
 579            _ => true,
 580        };
 581
 582        if needs_update {
 583            *self = Self::from_acp(new, language_registry, cx);
 584        }
 585    }
 586
 587    pub fn to_markdown(&self, cx: &App) -> String {
 588        match self {
 589            Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
 590            Self::Diff(diff) => diff.read(cx).to_markdown(cx),
 591            Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
 592        }
 593    }
 594}
 595
 596#[derive(Debug, PartialEq)]
 597pub enum ToolCallUpdate {
 598    UpdateFields(acp::ToolCallUpdate),
 599    UpdateDiff(ToolCallUpdateDiff),
 600    UpdateTerminal(ToolCallUpdateTerminal),
 601}
 602
 603impl ToolCallUpdate {
 604    fn id(&self) -> &acp::ToolCallId {
 605        match self {
 606            Self::UpdateFields(update) => &update.id,
 607            Self::UpdateDiff(diff) => &diff.id,
 608            Self::UpdateTerminal(terminal) => &terminal.id,
 609        }
 610    }
 611}
 612
 613impl From<acp::ToolCallUpdate> for ToolCallUpdate {
 614    fn from(update: acp::ToolCallUpdate) -> Self {
 615        Self::UpdateFields(update)
 616    }
 617}
 618
 619impl From<ToolCallUpdateDiff> for ToolCallUpdate {
 620    fn from(diff: ToolCallUpdateDiff) -> Self {
 621        Self::UpdateDiff(diff)
 622    }
 623}
 624
 625#[derive(Debug, PartialEq)]
 626pub struct ToolCallUpdateDiff {
 627    pub id: acp::ToolCallId,
 628    pub diff: Entity<Diff>,
 629}
 630
 631impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
 632    fn from(terminal: ToolCallUpdateTerminal) -> Self {
 633        Self::UpdateTerminal(terminal)
 634    }
 635}
 636
 637#[derive(Debug, PartialEq)]
 638pub struct ToolCallUpdateTerminal {
 639    pub id: acp::ToolCallId,
 640    pub terminal: Entity<Terminal>,
 641}
 642
 643#[derive(Debug, Default)]
 644pub struct Plan {
 645    pub entries: Vec<PlanEntry>,
 646}
 647
 648#[derive(Debug)]
 649pub struct PlanStats<'a> {
 650    pub in_progress_entry: Option<&'a PlanEntry>,
 651    pub pending: u32,
 652    pub completed: u32,
 653}
 654
 655impl Plan {
 656    pub fn is_empty(&self) -> bool {
 657        self.entries.is_empty()
 658    }
 659
 660    pub fn stats(&self) -> PlanStats<'_> {
 661        let mut stats = PlanStats {
 662            in_progress_entry: None,
 663            pending: 0,
 664            completed: 0,
 665        };
 666
 667        for entry in &self.entries {
 668            match &entry.status {
 669                acp::PlanEntryStatus::Pending => {
 670                    stats.pending += 1;
 671                }
 672                acp::PlanEntryStatus::InProgress => {
 673                    stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
 674                }
 675                acp::PlanEntryStatus::Completed => {
 676                    stats.completed += 1;
 677                }
 678            }
 679        }
 680
 681        stats
 682    }
 683}
 684
 685#[derive(Debug)]
 686pub struct PlanEntry {
 687    pub content: Entity<Markdown>,
 688    pub priority: acp::PlanEntryPriority,
 689    pub status: acp::PlanEntryStatus,
 690}
 691
 692impl PlanEntry {
 693    pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
 694        Self {
 695            content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
 696            priority: entry.priority,
 697            status: entry.status,
 698        }
 699    }
 700}
 701
 702#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 703pub struct TokenUsage {
 704    pub max_tokens: u64,
 705    pub used_tokens: u64,
 706}
 707
 708impl TokenUsage {
 709    pub fn ratio(&self) -> TokenUsageRatio {
 710        #[cfg(debug_assertions)]
 711        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 712            .unwrap_or("0.8".to_string())
 713            .parse()
 714            .unwrap();
 715        #[cfg(not(debug_assertions))]
 716        let warning_threshold: f32 = 0.8;
 717
 718        // When the maximum is unknown because there is no selected model,
 719        // avoid showing the token limit warning.
 720        if self.max_tokens == 0 {
 721            TokenUsageRatio::Normal
 722        } else if self.used_tokens >= self.max_tokens {
 723            TokenUsageRatio::Exceeded
 724        } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
 725            TokenUsageRatio::Warning
 726        } else {
 727            TokenUsageRatio::Normal
 728        }
 729    }
 730}
 731
 732#[derive(Debug, Clone, PartialEq, Eq)]
 733pub enum TokenUsageRatio {
 734    Normal,
 735    Warning,
 736    Exceeded,
 737}
 738
 739#[derive(Debug, Clone)]
 740pub struct RetryStatus {
 741    pub last_error: SharedString,
 742    pub attempt: usize,
 743    pub max_attempts: usize,
 744    pub started_at: Instant,
 745    pub duration: Duration,
 746}
 747
 748pub struct AcpThread {
 749    title: SharedString,
 750    entries: Vec<AgentThreadEntry>,
 751    plan: Plan,
 752    project: Entity<Project>,
 753    action_log: Entity<ActionLog>,
 754    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
 755    send_task: Option<Task<()>>,
 756    connection: Rc<dyn AgentConnection>,
 757    session_id: acp::SessionId,
 758    token_usage: Option<TokenUsage>,
 759}
 760
 761#[derive(Debug)]
 762pub enum AcpThreadEvent {
 763    NewEntry,
 764    TitleUpdated,
 765    TokenUsageUpdated,
 766    EntryUpdated(usize),
 767    EntriesRemoved(Range<usize>),
 768    ToolAuthorizationRequired,
 769    Retry(RetryStatus),
 770    Stopped,
 771    Error,
 772    LoadError(LoadError),
 773}
 774
 775impl EventEmitter<AcpThreadEvent> for AcpThread {}
 776
 777#[derive(PartialEq, Eq)]
 778pub enum ThreadStatus {
 779    Idle,
 780    WaitingForToolConfirmation,
 781    Generating,
 782}
 783
 784#[derive(Debug, Clone)]
 785pub enum LoadError {
 786    NotInstalled {
 787        error_message: SharedString,
 788        install_message: SharedString,
 789        install_command: String,
 790    },
 791    Unsupported {
 792        error_message: SharedString,
 793        upgrade_message: SharedString,
 794        upgrade_command: String,
 795    },
 796    Exited {
 797        status: ExitStatus,
 798    },
 799    Other(SharedString),
 800}
 801
 802impl Display for LoadError {
 803    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 804        match self {
 805            LoadError::NotInstalled { error_message, .. }
 806            | LoadError::Unsupported { error_message, .. } => {
 807                write!(f, "{error_message}")
 808            }
 809            LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
 810            LoadError::Other(msg) => write!(f, "{}", msg),
 811        }
 812    }
 813}
 814
 815impl Error for LoadError {}
 816
 817impl AcpThread {
 818    pub fn new(
 819        title: impl Into<SharedString>,
 820        connection: Rc<dyn AgentConnection>,
 821        project: Entity<Project>,
 822        action_log: Entity<ActionLog>,
 823        session_id: acp::SessionId,
 824    ) -> Self {
 825        Self {
 826            action_log,
 827            shared_buffers: Default::default(),
 828            entries: Default::default(),
 829            plan: Default::default(),
 830            title: title.into(),
 831            project,
 832            send_task: None,
 833            connection,
 834            session_id,
 835            token_usage: None,
 836        }
 837    }
 838
 839    pub fn connection(&self) -> &Rc<dyn AgentConnection> {
 840        &self.connection
 841    }
 842
 843    pub fn action_log(&self) -> &Entity<ActionLog> {
 844        &self.action_log
 845    }
 846
 847    pub fn project(&self) -> &Entity<Project> {
 848        &self.project
 849    }
 850
 851    pub fn title(&self) -> SharedString {
 852        self.title.clone()
 853    }
 854
 855    pub fn entries(&self) -> &[AgentThreadEntry] {
 856        &self.entries
 857    }
 858
 859    pub fn session_id(&self) -> &acp::SessionId {
 860        &self.session_id
 861    }
 862
 863    pub fn status(&self) -> ThreadStatus {
 864        if self.send_task.is_some() {
 865            if self.waiting_for_tool_confirmation() {
 866                ThreadStatus::WaitingForToolConfirmation
 867            } else {
 868                ThreadStatus::Generating
 869            }
 870        } else {
 871            ThreadStatus::Idle
 872        }
 873    }
 874
 875    pub fn token_usage(&self) -> Option<&TokenUsage> {
 876        self.token_usage.as_ref()
 877    }
 878
 879    pub fn has_pending_edit_tool_calls(&self) -> bool {
 880        for entry in self.entries.iter().rev() {
 881            match entry {
 882                AgentThreadEntry::UserMessage(_) => return false,
 883                AgentThreadEntry::ToolCall(
 884                    call @ ToolCall {
 885                        status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
 886                        ..
 887                    },
 888                ) if call.diffs().next().is_some() => {
 889                    return true;
 890                }
 891                AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
 892            }
 893        }
 894
 895        false
 896    }
 897
 898    pub fn used_tools_since_last_user_message(&self) -> bool {
 899        for entry in self.entries.iter().rev() {
 900            match entry {
 901                AgentThreadEntry::UserMessage(..) => return false,
 902                AgentThreadEntry::AssistantMessage(..) => continue,
 903                AgentThreadEntry::ToolCall(..) => return true,
 904            }
 905        }
 906
 907        false
 908    }
 909
 910    pub fn handle_session_update(
 911        &mut self,
 912        update: acp::SessionUpdate,
 913        cx: &mut Context<Self>,
 914    ) -> Result<(), acp::Error> {
 915        match update {
 916            acp::SessionUpdate::UserMessageChunk { content } => {
 917                self.push_user_content_block(None, content, cx);
 918            }
 919            acp::SessionUpdate::AgentMessageChunk { content } => {
 920                self.push_assistant_content_block(content, false, cx);
 921            }
 922            acp::SessionUpdate::AgentThoughtChunk { content } => {
 923                self.push_assistant_content_block(content, true, cx);
 924            }
 925            acp::SessionUpdate::ToolCall(tool_call) => {
 926                self.upsert_tool_call(tool_call, cx)?;
 927            }
 928            acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
 929                self.update_tool_call(tool_call_update, cx)?;
 930            }
 931            acp::SessionUpdate::Plan(plan) => {
 932                self.update_plan(plan, cx);
 933            }
 934        }
 935        Ok(())
 936    }
 937
 938    pub fn push_user_content_block(
 939        &mut self,
 940        message_id: Option<UserMessageId>,
 941        chunk: acp::ContentBlock,
 942        cx: &mut Context<Self>,
 943    ) {
 944        let language_registry = self.project.read(cx).languages().clone();
 945        let entries_len = self.entries.len();
 946
 947        if let Some(last_entry) = self.entries.last_mut()
 948            && let AgentThreadEntry::UserMessage(UserMessage {
 949                id,
 950                content,
 951                chunks,
 952                ..
 953            }) = last_entry
 954        {
 955            *id = message_id.or(id.take());
 956            content.append(chunk.clone(), &language_registry, cx);
 957            chunks.push(chunk);
 958            let idx = entries_len - 1;
 959            cx.emit(AcpThreadEvent::EntryUpdated(idx));
 960        } else {
 961            let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
 962            self.push_entry(
 963                AgentThreadEntry::UserMessage(UserMessage {
 964                    id: message_id,
 965                    content,
 966                    chunks: vec![chunk],
 967                    checkpoint: None,
 968                }),
 969                cx,
 970            );
 971        }
 972    }
 973
 974    pub fn push_assistant_content_block(
 975        &mut self,
 976        chunk: acp::ContentBlock,
 977        is_thought: bool,
 978        cx: &mut Context<Self>,
 979    ) {
 980        let language_registry = self.project.read(cx).languages().clone();
 981        let entries_len = self.entries.len();
 982        if let Some(last_entry) = self.entries.last_mut()
 983            && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
 984        {
 985            let idx = entries_len - 1;
 986            cx.emit(AcpThreadEvent::EntryUpdated(idx));
 987            match (chunks.last_mut(), is_thought) {
 988                (Some(AssistantMessageChunk::Message { block }), false)
 989                | (Some(AssistantMessageChunk::Thought { block }), true) => {
 990                    block.append(chunk, &language_registry, cx)
 991                }
 992                _ => {
 993                    let block = ContentBlock::new(chunk, &language_registry, cx);
 994                    if is_thought {
 995                        chunks.push(AssistantMessageChunk::Thought { block })
 996                    } else {
 997                        chunks.push(AssistantMessageChunk::Message { block })
 998                    }
 999                }
1000            }
1001        } else {
1002            let block = ContentBlock::new(chunk, &language_registry, cx);
1003            let chunk = if is_thought {
1004                AssistantMessageChunk::Thought { block }
1005            } else {
1006                AssistantMessageChunk::Message { block }
1007            };
1008
1009            self.push_entry(
1010                AgentThreadEntry::AssistantMessage(AssistantMessage {
1011                    chunks: vec![chunk],
1012                }),
1013                cx,
1014            );
1015        }
1016    }
1017
1018    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1019        self.entries.push(entry);
1020        cx.emit(AcpThreadEvent::NewEntry);
1021    }
1022
1023    pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1024        self.connection.set_title(&self.session_id, cx).is_some()
1025    }
1026
1027    pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1028        if title != self.title {
1029            self.title = title.clone();
1030            cx.emit(AcpThreadEvent::TitleUpdated);
1031            if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1032                return set_title.run(title, cx);
1033            }
1034        }
1035        Task::ready(Ok(()))
1036    }
1037
1038    pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1039        self.token_usage = usage;
1040        cx.emit(AcpThreadEvent::TokenUsageUpdated);
1041    }
1042
1043    pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1044        cx.emit(AcpThreadEvent::Retry(status));
1045    }
1046
1047    pub fn update_tool_call(
1048        &mut self,
1049        update: impl Into<ToolCallUpdate>,
1050        cx: &mut Context<Self>,
1051    ) -> Result<()> {
1052        let update = update.into();
1053        let languages = self.project.read(cx).languages().clone();
1054
1055        let (ix, current_call) = self
1056            .tool_call_mut(update.id())
1057            .context("Tool call not found")?;
1058        match update {
1059            ToolCallUpdate::UpdateFields(update) => {
1060                let location_updated = update.fields.locations.is_some();
1061                current_call.update_fields(update.fields, languages, cx);
1062                if location_updated {
1063                    self.resolve_locations(update.id, cx);
1064                }
1065            }
1066            ToolCallUpdate::UpdateDiff(update) => {
1067                current_call.content.clear();
1068                current_call
1069                    .content
1070                    .push(ToolCallContent::Diff(update.diff));
1071            }
1072            ToolCallUpdate::UpdateTerminal(update) => {
1073                current_call.content.clear();
1074                current_call
1075                    .content
1076                    .push(ToolCallContent::Terminal(update.terminal));
1077            }
1078        }
1079
1080        cx.emit(AcpThreadEvent::EntryUpdated(ix));
1081
1082        Ok(())
1083    }
1084
1085    /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1086    pub fn upsert_tool_call(
1087        &mut self,
1088        tool_call: acp::ToolCall,
1089        cx: &mut Context<Self>,
1090    ) -> Result<(), acp::Error> {
1091        let status = tool_call.status.into();
1092        self.upsert_tool_call_inner(tool_call.into(), status, cx)
1093    }
1094
1095    /// Fails if id does not match an existing entry.
1096    pub fn upsert_tool_call_inner(
1097        &mut self,
1098        tool_call_update: acp::ToolCallUpdate,
1099        status: ToolCallStatus,
1100        cx: &mut Context<Self>,
1101    ) -> Result<(), acp::Error> {
1102        let language_registry = self.project.read(cx).languages().clone();
1103        let id = tool_call_update.id.clone();
1104
1105        if let Some((ix, current_call)) = self.tool_call_mut(&id) {
1106            current_call.update_fields(tool_call_update.fields, language_registry, cx);
1107            current_call.status = status;
1108
1109            cx.emit(AcpThreadEvent::EntryUpdated(ix));
1110        } else {
1111            let call =
1112                ToolCall::from_acp(tool_call_update.try_into()?, status, language_registry, cx);
1113            self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1114        };
1115
1116        self.resolve_locations(id, cx);
1117        Ok(())
1118    }
1119
1120    fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1121        // The tool call we are looking for is typically the last one, or very close to the end.
1122        // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1123        self.entries
1124            .iter_mut()
1125            .enumerate()
1126            .rev()
1127            .find_map(|(index, tool_call)| {
1128                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1129                    && &tool_call.id == id
1130                {
1131                    Some((index, tool_call))
1132                } else {
1133                    None
1134                }
1135            })
1136    }
1137
1138    pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1139        self.entries
1140            .iter()
1141            .enumerate()
1142            .rev()
1143            .find_map(|(index, tool_call)| {
1144                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1145                    && &tool_call.id == id
1146                {
1147                    Some((index, tool_call))
1148                } else {
1149                    None
1150                }
1151            })
1152    }
1153
1154    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1155        let project = self.project.clone();
1156        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1157            return;
1158        };
1159        let task = tool_call.resolve_locations(project, cx);
1160        cx.spawn(async move |this, cx| {
1161            let resolved_locations = task.await;
1162            this.update(cx, |this, cx| {
1163                let project = this.project.clone();
1164                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1165                    return;
1166                };
1167                if let Some(Some(location)) = resolved_locations.last() {
1168                    project.update(cx, |project, cx| {
1169                        if let Some(agent_location) = project.agent_location() {
1170                            let should_ignore = agent_location.buffer == location.buffer
1171                                && location
1172                                    .buffer
1173                                    .update(cx, |buffer, _| {
1174                                        let snapshot = buffer.snapshot();
1175                                        let old_position =
1176                                            agent_location.position.to_point(&snapshot);
1177                                        let new_position = location.position.to_point(&snapshot);
1178                                        // ignore this so that when we get updates from the edit tool
1179                                        // the position doesn't reset to the startof line
1180                                        old_position.row == new_position.row
1181                                            && old_position.column > new_position.column
1182                                    })
1183                                    .ok()
1184                                    .unwrap_or_default();
1185                            if !should_ignore {
1186                                project.set_agent_location(Some(location.clone()), cx);
1187                            }
1188                        }
1189                    });
1190                }
1191                if tool_call.resolved_locations != resolved_locations {
1192                    tool_call.resolved_locations = resolved_locations;
1193                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
1194                }
1195            })
1196        })
1197        .detach();
1198    }
1199
1200    pub fn request_tool_call_authorization(
1201        &mut self,
1202        tool_call: acp::ToolCallUpdate,
1203        options: Vec<acp::PermissionOption>,
1204        cx: &mut Context<Self>,
1205    ) -> Result<oneshot::Receiver<acp::PermissionOptionId>, acp::Error> {
1206        let (tx, rx) = oneshot::channel();
1207
1208        let status = ToolCallStatus::WaitingForConfirmation {
1209            options,
1210            respond_tx: tx,
1211        };
1212
1213        self.upsert_tool_call_inner(tool_call, status, cx)?;
1214        cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1215        Ok(rx)
1216    }
1217
1218    pub fn authorize_tool_call(
1219        &mut self,
1220        id: acp::ToolCallId,
1221        option_id: acp::PermissionOptionId,
1222        option_kind: acp::PermissionOptionKind,
1223        cx: &mut Context<Self>,
1224    ) {
1225        let Some((ix, call)) = self.tool_call_mut(&id) else {
1226            return;
1227        };
1228
1229        let new_status = match option_kind {
1230            acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1231                ToolCallStatus::Rejected
1232            }
1233            acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1234                ToolCallStatus::InProgress
1235            }
1236        };
1237
1238        let curr_status = mem::replace(&mut call.status, new_status);
1239
1240        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1241            respond_tx.send(option_id).log_err();
1242        } else if cfg!(debug_assertions) {
1243            panic!("tried to authorize an already authorized tool call");
1244        }
1245
1246        cx.emit(AcpThreadEvent::EntryUpdated(ix));
1247    }
1248
1249    /// Returns true if the last turn is awaiting tool authorization
1250    pub fn waiting_for_tool_confirmation(&self) -> bool {
1251        for entry in self.entries.iter().rev() {
1252            match &entry {
1253                AgentThreadEntry::ToolCall(call) => match call.status {
1254                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
1255                    ToolCallStatus::Pending
1256                    | ToolCallStatus::InProgress
1257                    | ToolCallStatus::Completed
1258                    | ToolCallStatus::Failed
1259                    | ToolCallStatus::Rejected
1260                    | ToolCallStatus::Canceled => continue,
1261                },
1262                AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1263                    // Reached the beginning of the turn
1264                    return false;
1265                }
1266            }
1267        }
1268        false
1269    }
1270
1271    pub fn plan(&self) -> &Plan {
1272        &self.plan
1273    }
1274
    /// Replaces the current plan with `request`, matching entries by position.
    ///
    /// Entries present in both plans keep their existing markdown entity, whose text
    /// is replaced in place, and have their priority and status overwritten. Extra new
    /// entries are appended and surplus old entries are dropped. Finally `cx.notify()`
    /// is called.
    pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
        let new_entries_len = request.entries.len();
        let mut new_entries = request.entries.into_iter();

        // Reuse existing markdown to prevent flickering
        // `zip` polls the old-entry iterator first, so when the old plan is shorter no
        // new entry is pulled and lost; `by_ref()` leaves the rest for the loop below.
        for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
            let PlanEntry {
                content,
                priority,
                status,
            } = old;
            content.update(cx, |old, cx| {
                old.replace(new.content, cx);
            });
            *priority = new.priority;
            *status = new.status;
        }
        // Any new entries beyond the old plan's length become fresh entries.
        for new in new_entries {
            self.plan.entries.push(PlanEntry::from_acp(new, cx))
        }
        // Drop old entries beyond the new plan's length (no-op if the plan grew).
        self.plan.entries.truncate(new_entries_len);

        cx.notify();
    }
1299
1300    fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1301        self.plan
1302            .entries
1303            .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1304        cx.notify();
1305    }
1306
1307    #[cfg(any(test, feature = "test-support"))]
1308    pub fn send_raw(
1309        &mut self,
1310        message: &str,
1311        cx: &mut Context<Self>,
1312    ) -> BoxFuture<'static, Result<()>> {
1313        self.send(
1314            vec![acp::ContentBlock::Text(acp::TextContent {
1315                text: message.to_string(),
1316                annotations: None,
1317            })],
1318            cx,
1319        )
1320    }
1321
1322    pub fn send(
1323        &mut self,
1324        message: Vec<acp::ContentBlock>,
1325        cx: &mut Context<Self>,
1326    ) -> BoxFuture<'static, Result<()>> {
1327        let block = ContentBlock::new_combined(
1328            message.clone(),
1329            self.project.read(cx).languages().clone(),
1330            cx,
1331        );
1332        let request = acp::PromptRequest {
1333            prompt: message.clone(),
1334            session_id: self.session_id.clone(),
1335        };
1336        let git_store = self.project.read(cx).git_store().clone();
1337
1338        let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
1339            Some(UserMessageId::new())
1340        } else {
1341            None
1342        };
1343
1344        self.run_turn(cx, async move |this, cx| {
1345            this.update(cx, |this, cx| {
1346                this.push_entry(
1347                    AgentThreadEntry::UserMessage(UserMessage {
1348                        id: message_id.clone(),
1349                        content: block,
1350                        chunks: message,
1351                        checkpoint: None,
1352                    }),
1353                    cx,
1354                );
1355            })
1356            .ok();
1357
1358            let old_checkpoint = git_store
1359                .update(cx, |git, cx| git.checkpoint(cx))?
1360                .await
1361                .context("failed to get old checkpoint")
1362                .log_err();
1363            this.update(cx, |this, cx| {
1364                if let Some((_ix, message)) = this.last_user_message() {
1365                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
1366                        git_checkpoint,
1367                        show: false,
1368                    });
1369                }
1370                this.connection.prompt(message_id, request, cx)
1371            })?
1372            .await
1373        })
1374    }
1375
1376    pub fn can_resume(&self, cx: &App) -> bool {
1377        self.connection.resume(&self.session_id, cx).is_some()
1378    }
1379
1380    pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1381        self.run_turn(cx, async move |this, cx| {
1382            this.update(cx, |this, cx| {
1383                this.connection
1384                    .resume(&this.session_id, cx)
1385                    .map(|resume| resume.run(cx))
1386            })?
1387            .context("resuming a session is not supported")?
1388            .await
1389        })
1390    }
1391
    /// Starts a new turn that runs `f`, returning a future that resolves when the turn ends.
    ///
    /// Before `f` runs, completed plan entries are cleared and any in-flight turn is
    /// canceled; `f` only starts once that cancellation has finished. The task running
    /// `f` is stored in `send_task` so a later `cancel` can stop it.
    ///
    /// When `f` finishes (or its task is dropped), the last user message's checkpoint
    /// visibility is refreshed and the agent location is cleared. Then:
    /// - if `f` returned an error, `send_task` is cleared, `AcpThreadEvent::Error` is
    ///   emitted, and the error is returned;
    /// - otherwise, `send_task` is cleared unless the turn was canceled; if the agent
    ///   refused, entries from the last user message onward are removed; and
    ///   `AcpThreadEvent::Stopped` is emitted.
    ///
    /// NOTE(review): if refreshing the checkpoint fails, the error returns early and
    /// neither `Error` nor `Stopped` is emitted, and the agent location isn't cleared —
    /// confirm this is intended.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<()>> {
        self.clear_completed_plan_entries(cx);

        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            // Let the previous turn wind down before starting this one.
            cancel_task.await;
            // The receiver is gone if the returned future was dropped; ignore that.
            tx.send(f(this, cx).await).ok();
        }));

        cx.spawn(async move |this, cx| {
            // `Err(Canceled)` here means the send task was dropped before replying.
            let response = rx.await;

            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));
                match response {
                    Ok(Err(e)) => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        Err(e)
                    }
                    result => {
                        let canceled = matches!(
                            result,
                            Ok(Ok(acp::PromptResponse {
                                stop_reason: acp::StopReason::Cancelled
                            }))
                        );

                        // We only take the task if the current prompt wasn't canceled.
                        //
                        // This prompt may have been canceled because another one was sent
                        // while it was still generating. In these cases, dropping `send_task`
                        // would cause the next generation to be canceled.
                        if !canceled {
                            this.send_task.take();
                        }

                        // Truncate entries if the last prompt was refused.
                        if let Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                        })) = result
                            && let Some((ix, _)) = this.last_user_message()
                        {
                            let range = ix..this.entries.len();
                            this.entries.truncate(ix);
                            cx.emit(AcpThreadEvent::EntriesRemoved(range));
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(())
                    }
                }
            })?
        })
        .boxed()
    }
1458
1459    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1460        let Some(send_task) = self.send_task.take() else {
1461            return Task::ready(());
1462        };
1463
1464        for entry in self.entries.iter_mut() {
1465            if let AgentThreadEntry::ToolCall(call) = entry {
1466                let cancel = matches!(
1467                    call.status,
1468                    ToolCallStatus::Pending
1469                        | ToolCallStatus::WaitingForConfirmation { .. }
1470                        | ToolCallStatus::InProgress
1471                );
1472
1473                if cancel {
1474                    call.status = ToolCallStatus::Canceled;
1475                }
1476            }
1477        }
1478
1479        self.connection.cancel(&self.session_id, cx);
1480
1481        // Wait for the send task to complete
1482        cx.foreground_executor().spawn(send_task)
1483    }
1484
1485    /// Rewinds this thread to before the entry at `index`, removing it and all
1486    /// subsequent entries while reverting any changes made from that point.
1487    pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1488        let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
1489            return Task::ready(Err(anyhow!("not supported")));
1490        };
1491        let Some(message) = self.user_message(&id) else {
1492            return Task::ready(Err(anyhow!("message not found")));
1493        };
1494
1495        let checkpoint = message
1496            .checkpoint
1497            .as_ref()
1498            .map(|c| c.git_checkpoint.clone());
1499
1500        let git_store = self.project.read(cx).git_store().clone();
1501        cx.spawn(async move |this, cx| {
1502            if let Some(checkpoint) = checkpoint {
1503                git_store
1504                    .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1505                    .await?;
1506            }
1507
1508            cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
1509            this.update(cx, |this, cx| {
1510                if let Some((ix, _)) = this.user_message_mut(&id) {
1511                    let range = ix..this.entries.len();
1512                    this.entries.truncate(ix);
1513                    cx.emit(AcpThreadEvent::EntriesRemoved(range));
1514                }
1515            })
1516        })
1517    }
1518
1519    fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1520        let git_store = self.project.read(cx).git_store().clone();
1521
1522        let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1523            if let Some(checkpoint) = message.checkpoint.as_ref() {
1524                checkpoint.git_checkpoint.clone()
1525            } else {
1526                return Task::ready(Ok(()));
1527            }
1528        } else {
1529            return Task::ready(Ok(()));
1530        };
1531
1532        let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1533        cx.spawn(async move |this, cx| {
1534            let new_checkpoint = new_checkpoint
1535                .await
1536                .context("failed to get new checkpoint")
1537                .log_err();
1538            if let Some(new_checkpoint) = new_checkpoint {
1539                let equal = git_store
1540                    .update(cx, |git, cx| {
1541                        git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1542                    })?
1543                    .await
1544                    .unwrap_or(true);
1545                this.update(cx, |this, cx| {
1546                    let (ix, message) = this.last_user_message().context("no user message")?;
1547                    let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1548                    checkpoint.show = !equal;
1549                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
1550                    anyhow::Ok(())
1551                })??;
1552            }
1553
1554            Ok(())
1555        })
1556    }
1557
1558    fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1559        self.entries
1560            .iter_mut()
1561            .enumerate()
1562            .rev()
1563            .find_map(|(ix, entry)| {
1564                if let AgentThreadEntry::UserMessage(message) = entry {
1565                    Some((ix, message))
1566                } else {
1567                    None
1568                }
1569            })
1570    }
1571
1572    fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
1573        self.entries.iter().find_map(|entry| {
1574            if let AgentThreadEntry::UserMessage(message) = entry {
1575                if message.id.as_ref() == Some(id) {
1576                    Some(message)
1577                } else {
1578                    None
1579                }
1580            } else {
1581                None
1582            }
1583        })
1584    }
1585
1586    fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1587        self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1588            if let AgentThreadEntry::UserMessage(message) = entry {
1589                if message.id.as_ref() == Some(id) {
1590                    Some((ix, message))
1591                } else {
1592                    None
1593                }
1594            } else {
1595                None
1596            }
1597        })
1598    }
1599
1600    pub fn read_text_file(
1601        &self,
1602        path: PathBuf,
1603        line: Option<u32>,
1604        limit: Option<u32>,
1605        reuse_shared_snapshot: bool,
1606        cx: &mut Context<Self>,
1607    ) -> Task<Result<String>> {
1608        let project = self.project.clone();
1609        let action_log = self.action_log.clone();
1610        cx.spawn(async move |this, cx| {
1611            let load = project.update(cx, |project, cx| {
1612                let path = project
1613                    .project_path_for_absolute_path(&path, cx)
1614                    .context("invalid path")?;
1615                anyhow::Ok(project.open_buffer(path, cx))
1616            });
1617            let buffer = load??.await?;
1618
1619            let snapshot = if reuse_shared_snapshot {
1620                this.read_with(cx, |this, _| {
1621                    this.shared_buffers.get(&buffer.clone()).cloned()
1622                })
1623                .log_err()
1624                .flatten()
1625            } else {
1626                None
1627            };
1628
1629            let snapshot = if let Some(snapshot) = snapshot {
1630                snapshot
1631            } else {
1632                action_log.update(cx, |action_log, cx| {
1633                    action_log.buffer_read(buffer.clone(), cx);
1634                })?;
1635                project.update(cx, |project, cx| {
1636                    let position = buffer
1637                        .read(cx)
1638                        .snapshot()
1639                        .anchor_before(Point::new(line.unwrap_or_default(), 0));
1640                    project.set_agent_location(
1641                        Some(AgentLocation {
1642                            buffer: buffer.downgrade(),
1643                            position,
1644                        }),
1645                        cx,
1646                    );
1647                })?;
1648
1649                buffer.update(cx, |buffer, _| buffer.snapshot())?
1650            };
1651
1652            this.update(cx, |this, _| {
1653                let text = snapshot.text();
1654                this.shared_buffers.insert(buffer.clone(), snapshot);
1655                if line.is_none() && limit.is_none() {
1656                    return Ok(text);
1657                }
1658                let limit = limit.unwrap_or(u32::MAX) as usize;
1659                let Some(line) = line else {
1660                    return Ok(text.lines().take(limit).collect::<String>());
1661                };
1662
1663                let count = text.lines().count();
1664                if count < line as usize {
1665                    anyhow::bail!("There are only {} lines", count);
1666                }
1667                Ok(text
1668                    .lines()
1669                    .skip(line as usize + 1)
1670                    .take(limit)
1671                    .collect::<String>())
1672            })?
1673        })
1674    }
1675
    /// Replaces the contents of the project file at `path` with `content` and saves it.
    ///
    /// The baseline is the snapshot this thread previously stored in `shared_buffers`
    /// for the buffer, or the buffer's current snapshot if there is none. `content` is
    /// diffed against that baseline so that only the changed ranges are edited.
    ///
    /// Steps, in order:
    /// - move the agent location to the end of the last edit;
    /// - record the read and the edit in the action log;
    /// - format the buffer when the language's `format_on_save` setting is not `Off`;
    /// - save the buffer.
    ///
    /// Formatting errors are logged and ignored. Other failures (invalid path, load or
    /// save errors, dropped entities) are returned.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Diff against the snapshot the agent was given, if any.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Compute the minimal edit set off the main thread. Edits are expressed as
            // anchors into `snapshot` rather than raw offsets, presumably so they still
            // land correctly if the buffer changed after `snapshot` was taken.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            // Point the agent location at the end of the last edit (start of buffer if
            // nothing changed).
            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::MIN),
                    }),
                    cx,
                );
            })?;

            // Apply the edits between `buffer_read` and `buffer_edited` so the action log
            // sees this write; also report whether format-on-save is enabled.
            let format_on_save = cx.update(|cx| {
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            })?;

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                })?;
                // Formatting is best-effort: a failure is logged, not propagated.
                format_task.await.log_err();

                // Record the edit again so the action log includes the formatter's changes.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                })?;
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1772
1773    pub fn to_markdown(&self, cx: &App) -> String {
1774        self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1775    }
1776
1777    pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
1778        cx.emit(AcpThreadEvent::LoadError(error));
1779    }
1780}
1781
1782fn markdown_for_raw_output(
1783    raw_output: &serde_json::Value,
1784    language_registry: &Arc<LanguageRegistry>,
1785    cx: &mut App,
1786) -> Option<Entity<Markdown>> {
1787    match raw_output {
1788        serde_json::Value::Null => None,
1789        serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1790            Markdown::new(
1791                value.to_string().into(),
1792                Some(language_registry.clone()),
1793                None,
1794                cx,
1795            )
1796        })),
1797        serde_json::Value::Number(value) => Some(cx.new(|cx| {
1798            Markdown::new(
1799                value.to_string().into(),
1800                Some(language_registry.clone()),
1801                None,
1802                cx,
1803            )
1804        })),
1805        serde_json::Value::String(value) => Some(cx.new(|cx| {
1806            Markdown::new(
1807                value.clone().into(),
1808                Some(language_registry.clone()),
1809                None,
1810                cx,
1811            )
1812        })),
1813        value => Some(cx.new(|cx| {
1814            Markdown::new(
1815                format!("```json\n{}\n```", value).into(),
1816                Some(language_registry.clone()),
1817                None,
1818                cx,
1819            )
1820        })),
1821    }
1822}
1823
1824#[cfg(test)]
1825mod tests {
1826    use super::*;
1827    use anyhow::anyhow;
1828    use futures::{channel::mpsc, future::LocalBoxFuture, select};
1829    use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
1830    use indoc::indoc;
1831    use project::{FakeFs, Fs};
1832    use rand::Rng as _;
1833    use serde_json::json;
1834    use settings::SettingsStore;
1835    use smol::stream::StreamExt as _;
1836    use std::{
1837        any::Any,
1838        cell::RefCell,
1839        path::Path,
1840        rc::Rc,
1841        sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
1842        time::Duration,
1843    };
1844    use util::path;
1845
1846    fn init_test(cx: &mut TestAppContext) {
1847        env_logger::try_init().ok();
1848        cx.update(|cx| {
1849            let settings_store = SettingsStore::test(cx);
1850            cx.set_global(settings_store);
1851            Project::init_settings(cx);
1852            language::init(cx);
1853        });
1854    }
1855
    /// Checks how `push_user_content_block` groups chunks into user messages: a chunk creates a
    /// new message in an empty thread, is appended to a trailing user message (which then
    /// carries the supplied id), and creates a new message again once an assistant message
    /// follows.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                None,
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Hello, ".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, None);
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        let message_1_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_1_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "world!".to_string(),
                }),
                cx,
            );
        });

        // Still a single entry: the chunk was merged and the message adopted the new id.
        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, Some(message_1_id));
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Assistant response".to_string(),
                }),
                false,
                cx,
            );
        });

        let message_2_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_2_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "New user message".to_string(),
                }),
                cx,
            );
        });

        // Expected entries: user, assistant, user.
        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.id, Some(message_2_id));
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }
1947
    /// Consecutive `AgentThoughtChunk` updates within one turn should be merged into a single
    /// `<thinking>` section of the assistant message.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        // Scripted agent: emits two thought chunks, then ends the turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "Thinking ".into(),
                                },
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "hard!".into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
2010
    /// The agent rewrites `/tmp/foo` while the user edits the same buffer. The agent reads the
    /// file, the test then inserts "zero\n" at the top, and finally the agent writes its
    /// extended content. Both changes must be present in the buffer and in the file on disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test that the agent's read finished, so the user edit happens between the
        // agent's read and its write.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    // The agent writes content derived from its (now stale) read.
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        // Open the same buffer the agent will touch so the test can edit it as the user.
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        // User edit lands after the agent's read and before its write.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
2089
    /// A tool call that `cancel` marked as `Canceled` should still accept a later `Completed`
    /// status update sent by the agent.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        // Scripted agent: starts an in-progress fetch tool call, then ends the turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // entries[0] is the user message; entries[1] is the tool call.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::InProgress,
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // A late completion from the agent after cancellation.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Completed,
                    ..
                })
            ));
        });
    }
2189
2190    #[gpui::test]
2191    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
2192        init_test(cx);
2193        let fs = FakeFs::new(cx.background_executor.clone());
2194        fs.insert_tree(path!("/test"), json!({})).await;
2195        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
2196
2197        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2198            move |_, thread, mut cx| {
2199                async move {
2200                    thread
2201                        .update(&mut cx, |thread, cx| {
2202                            thread.handle_session_update(
2203                                acp::SessionUpdate::ToolCall(acp::ToolCall {
2204                                    id: acp::ToolCallId("test".into()),
2205                                    title: "Label".into(),
2206                                    kind: acp::ToolKind::Edit,
2207                                    status: acp::ToolCallStatus::Completed,
2208                                    content: vec![acp::ToolCallContent::Diff {
2209                                        diff: acp::Diff {
2210                                            path: "/test/test.txt".into(),
2211                                            old_text: None,
2212                                            new_text: "foo".into(),
2213                                        },
2214                                    }],
2215                                    locations: vec![],
2216                                    raw_input: None,
2217                                    raw_output: None,
2218                                }),
2219                                cx,
2220                            )
2221                        })
2222                        .unwrap()
2223                        .unwrap();
2224                    Ok(acp::PromptResponse {
2225                        stop_reason: acp::StopReason::EndTurn,
2226                    })
2227                }
2228                .boxed_local()
2229            }
2230        }));
2231
2232        let thread = cx
2233            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2234            .await
2235            .unwrap();
2236
2237        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
2238            .await
2239            .unwrap();
2240
2241        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
2242    }
2243
    /// Exercises git checkpoints on user messages (10 iterations). A message whose turn created
    /// a file is rendered as "## User (checkpoint)", and a turn that changed nothing is not.
    /// Rewinding to an earlier message truncates the history and restores the files as they
    /// were at that checkpoint.
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        // When `simulate_changes` is set, each turn creates a new `/test/file-N`.
        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    if simulate_changes.load(SeqCst) {
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    // The agent's reply is the user's text uppercased.
                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk {
                                    content: content.text.to_uppercase().into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                    ## User

                    dolor

                    ## Assistant

                    DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        // entries[2] is the "ipsum" user message, whose turn created file-1.
        thread
            .update(cx, |thread, cx| {
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.rewind(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
    }
2423
    /// When the agent answers a prompt with `StopReason::Refusal`, the refused user message is
    /// dropped, so the thread reads exactly as it did before that message was sent.
    #[gpui::test]
    async fn test_refusal(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/"), json!({})).await;
        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;

        // When set, the scripted agent refuses instead of echoing the prompt in uppercase.
        let refuse_next = Arc::new(AtomicBool::new(false));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let refuse_next = refuse_next.clone();
            move |request, thread, mut cx| {
                let refuse_next = refuse_next.clone();
                async move {
                    if refuse_next.load(SeqCst) {
                        return Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                        });
                    }

                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk {
                                    content: content.text.to_uppercase().into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User

                    hello

                    ## Assistant

                    HELLO

                "}
            );
        });

        // Simulate refusing the second message, ensuring the conversation gets
        // truncated to before sending it.
        refuse_next.store(true, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User

                    hello

                    ## Assistant

                    HELLO

                "}
            );
        });
    }
2509
2510    async fn run_until_first_tool_call(
2511        thread: &Entity<AcpThread>,
2512        cx: &mut TestAppContext,
2513    ) -> usize {
2514        let (mut tx, mut rx) = mpsc::channel::<usize>(1);
2515
2516        let subscription = cx.update(|cx| {
2517            cx.subscribe(thread, move |thread, _, cx| {
2518                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
2519                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
2520                        return tx.try_send(ix).unwrap();
2521                    }
2522                }
2523            })
2524        });
2525
2526        select! {
2527            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
2528                panic!("Timeout waiting for tool call")
2529            }
2530            ix = rx.next().fuse() => {
2531                drop(subscription);
2532                ix.unwrap()
2533            }
2534        }
2535    }
2536
    /// In-memory `AgentConnection` used by these tests. Threads it creates are tracked by
    /// session id so `prompt` and `cancel` can find them, and an optional `on_user_message`
    /// callback scripts the agent's behavior for each prompt.
    #[derive(Clone, Default)]
    struct FakeAgentConnection {
        /// Reported by `auth_methods`; `authenticate` accepts only ids listed here.
        auth_methods: Vec<acp::AuthMethod>,
        /// Weak handles to the threads created by `new_thread`, keyed by session id.
        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
        /// Scripted agent invoked by `prompt` with the request, the target thread, and an async
        /// app context; its result becomes the prompt response. `None` ends each turn at once.
        on_user_message: Option<
            Rc<
                dyn Fn(
                        acp::PromptRequest,
                        WeakEntity<AcpThread>,
                        AsyncApp,
                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
                    + 'static,
            >,
        >,
    }
2552
2553    impl FakeAgentConnection {
2554        fn new() -> Self {
2555            Self {
2556                auth_methods: Vec::new(),
2557                on_user_message: None,
2558                sessions: Arc::default(),
2559            }
2560        }
2561
2562        #[expect(unused)]
2563        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
2564            self.auth_methods = auth_methods;
2565            self
2566        }
2567
2568        fn on_user_message(
2569            mut self,
2570            handler: impl Fn(
2571                acp::PromptRequest,
2572                WeakEntity<AcpThread>,
2573                AsyncApp,
2574            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2575            + 'static,
2576        ) -> Self {
2577            self.on_user_message.replace(Rc::new(handler));
2578            self
2579        }
2580    }
2581
2582    impl AgentConnection for FakeAgentConnection {
2583        fn auth_methods(&self) -> &[acp::AuthMethod] {
2584            &self.auth_methods
2585        }
2586
2587        fn new_thread(
2588            self: Rc<Self>,
2589            project: Entity<Project>,
2590            _cwd: &Path,
2591            cx: &mut App,
2592        ) -> Task<gpui::Result<Entity<AcpThread>>> {
2593            let session_id = acp::SessionId(
2594                rand::thread_rng()
2595                    .sample_iter(&rand::distributions::Alphanumeric)
2596                    .take(7)
2597                    .map(char::from)
2598                    .collect::<String>()
2599                    .into(),
2600            );
2601            let action_log = cx.new(|_| ActionLog::new(project.clone()));
2602            let thread = cx.new(|_cx| {
2603                AcpThread::new(
2604                    "Test",
2605                    self.clone(),
2606                    project,
2607                    action_log,
2608                    session_id.clone(),
2609                )
2610            });
2611            self.sessions.lock().insert(session_id, thread.downgrade());
2612            Task::ready(Ok(thread))
2613        }
2614
2615        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
2616            if self.auth_methods().iter().any(|m| m.id == method) {
2617                Task::ready(Ok(()))
2618            } else {
2619                Task::ready(Err(anyhow!("Invalid Auth Method")))
2620            }
2621        }
2622
2623        fn prompt(
2624            &self,
2625            _id: Option<UserMessageId>,
2626            params: acp::PromptRequest,
2627            cx: &mut App,
2628        ) -> Task<gpui::Result<acp::PromptResponse>> {
2629            let sessions = self.sessions.lock();
2630            let thread = sessions.get(&params.session_id).unwrap();
2631            if let Some(handler) = &self.on_user_message {
2632                let handler = handler.clone();
2633                let thread = thread.clone();
2634                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
2635            } else {
2636                Task::ready(Ok(acp::PromptResponse {
2637                    stop_reason: acp::StopReason::EndTurn,
2638                }))
2639            }
2640        }
2641
2642        fn prompt_capabilities(&self) -> acp::PromptCapabilities {
2643            acp::PromptCapabilities {
2644                image: true,
2645                audio: true,
2646                embedded_context: true,
2647            }
2648        }
2649
2650        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
2651            let sessions = self.sessions.lock();
2652            let thread = sessions.get(session_id).unwrap().clone();
2653
2654            cx.spawn(async move |cx| {
2655                thread
2656                    .update(cx, |thread, cx| thread.cancel(cx))
2657                    .unwrap()
2658                    .await
2659            })
2660            .detach();
2661        }
2662
2663        fn truncate(
2664            &self,
2665            session_id: &acp::SessionId,
2666            _cx: &App,
2667        ) -> Option<Rc<dyn AgentSessionTruncate>> {
2668            Some(Rc::new(FakeAgentSessionEditor {
2669                _session_id: session_id.clone(),
2670            }))
2671        }
2672
2673        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
2674            self
2675        }
2676    }
2677
    /// Stub truncation handle returned by `FakeAgentConnection::truncate`.
    struct FakeAgentSessionEditor {
        // Session this handle was created for. `run` does not read it; the
        // leading underscore keeps the dead-code lint quiet.
        _session_id: acp::SessionId,
    }
2681
2682    impl AgentSessionTruncate for FakeAgentSessionEditor {
2683        fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
2684            Task::ready(Ok(()))
2685        }
2686    }
2687}