acp_thread.rs

   1mod connection;
   2mod diff;
   3mod mention;
   4mod terminal;
   5
   6use collections::HashSet;
   7pub use connection::*;
   8pub use diff::*;
   9use language::language_settings::FormatOnSave;
  10pub use mention::*;
  11use project::lsp_store::{FormatTrigger, LspFormatTarget};
  12use serde::{Deserialize, Serialize};
  13pub use terminal::*;
  14
  15use action_log::ActionLog;
  16use agent_client_protocol as acp;
  17use anyhow::{Context as _, Result, anyhow};
  18use editor::Bias;
  19use futures::{FutureExt, channel::oneshot, future::BoxFuture};
  20use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  21use itertools::Itertools;
  22use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
  23use markdown::Markdown;
  24use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
  25use std::collections::HashMap;
  26use std::error::Error;
  27use std::fmt::{Formatter, Write};
  28use std::ops::Range;
  29use std::process::ExitStatus;
  30use std::rc::Rc;
  31use std::time::{Duration, Instant};
  32use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
  33use ui::App;
  34use util::ResultExt;
  35
  36#[derive(Debug)]
  37pub struct UserMessage {
  38    pub id: Option<UserMessageId>,
  39    pub content: ContentBlock,
  40    pub chunks: Vec<acp::ContentBlock>,
  41    pub checkpoint: Option<Checkpoint>,
  42}
  43
  44#[derive(Debug)]
  45pub struct Checkpoint {
  46    git_checkpoint: GitStoreCheckpoint,
  47    pub show: bool,
  48}
  49
  50impl UserMessage {
  51    fn to_markdown(&self, cx: &App) -> String {
  52        let mut markdown = String::new();
  53        if self
  54            .checkpoint
  55            .as_ref()
  56            .is_some_and(|checkpoint| checkpoint.show)
  57        {
  58            writeln!(markdown, "## User (checkpoint)").unwrap();
  59        } else {
  60            writeln!(markdown, "## User").unwrap();
  61        }
  62        writeln!(markdown).unwrap();
  63        writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
  64        writeln!(markdown).unwrap();
  65        markdown
  66    }
  67}
  68
  69#[derive(Debug, PartialEq)]
  70pub struct AssistantMessage {
  71    pub chunks: Vec<AssistantMessageChunk>,
  72}
  73
  74impl AssistantMessage {
  75    pub fn to_markdown(&self, cx: &App) -> String {
  76        format!(
  77            "## Assistant\n\n{}\n\n",
  78            self.chunks
  79                .iter()
  80                .map(|chunk| chunk.to_markdown(cx))
  81                .join("\n\n")
  82        )
  83    }
  84}
  85
  86#[derive(Debug, PartialEq)]
  87pub enum AssistantMessageChunk {
  88    Message { block: ContentBlock },
  89    Thought { block: ContentBlock },
  90}
  91
  92impl AssistantMessageChunk {
  93    pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
  94        Self::Message {
  95            block: ContentBlock::new(chunk.into(), language_registry, cx),
  96        }
  97    }
  98
  99    fn to_markdown(&self, cx: &App) -> String {
 100        match self {
 101            Self::Message { block } => block.to_markdown(cx).to_string(),
 102            Self::Thought { block } => {
 103                format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
 104            }
 105        }
 106    }
 107}
 108
 109#[derive(Debug)]
 110pub enum AgentThreadEntry {
 111    UserMessage(UserMessage),
 112    AssistantMessage(AssistantMessage),
 113    ToolCall(ToolCall),
 114}
 115
 116impl AgentThreadEntry {
 117    pub fn to_markdown(&self, cx: &App) -> String {
 118        match self {
 119            Self::UserMessage(message) => message.to_markdown(cx),
 120            Self::AssistantMessage(message) => message.to_markdown(cx),
 121            Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
 122        }
 123    }
 124
 125    pub fn user_message(&self) -> Option<&UserMessage> {
 126        if let AgentThreadEntry::UserMessage(message) = self {
 127            Some(message)
 128        } else {
 129            None
 130        }
 131    }
 132
 133    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 134        if let AgentThreadEntry::ToolCall(call) = self {
 135            itertools::Either::Left(call.diffs())
 136        } else {
 137            itertools::Either::Right(std::iter::empty())
 138        }
 139    }
 140
 141    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 142        if let AgentThreadEntry::ToolCall(call) = self {
 143            itertools::Either::Left(call.terminals())
 144        } else {
 145            itertools::Either::Right(std::iter::empty())
 146        }
 147    }
 148
 149    pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
 150        if let AgentThreadEntry::ToolCall(ToolCall {
 151            locations,
 152            resolved_locations,
 153            ..
 154        }) = self
 155        {
 156            Some((
 157                locations.get(ix)?.clone(),
 158                resolved_locations.get(ix)?.clone()?,
 159            ))
 160        } else {
 161            None
 162        }
 163    }
 164}
 165
 166#[derive(Debug)]
 167pub struct ToolCall {
 168    pub id: acp::ToolCallId,
 169    pub label: Entity<Markdown>,
 170    pub kind: acp::ToolKind,
 171    pub content: Vec<ToolCallContent>,
 172    pub status: ToolCallStatus,
 173    pub locations: Vec<acp::ToolCallLocation>,
 174    pub resolved_locations: Vec<Option<AgentLocation>>,
 175    pub raw_input: Option<serde_json::Value>,
 176    pub raw_output: Option<serde_json::Value>,
 177}
 178
 179impl ToolCall {
 180    fn from_acp(
 181        tool_call: acp::ToolCall,
 182        status: ToolCallStatus,
 183        language_registry: Arc<LanguageRegistry>,
 184        cx: &mut App,
 185    ) -> Self {
 186        Self {
 187            id: tool_call.id,
 188            label: cx.new(|cx| {
 189                Markdown::new(
 190                    tool_call.title.into(),
 191                    Some(language_registry.clone()),
 192                    None,
 193                    cx,
 194                )
 195            }),
 196            kind: tool_call.kind,
 197            content: tool_call
 198                .content
 199                .into_iter()
 200                .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
 201                .collect(),
 202            locations: tool_call.locations,
 203            resolved_locations: Vec::default(),
 204            status,
 205            raw_input: tool_call.raw_input,
 206            raw_output: tool_call.raw_output,
 207        }
 208    }
 209
 210    fn update_fields(
 211        &mut self,
 212        fields: acp::ToolCallUpdateFields,
 213        language_registry: Arc<LanguageRegistry>,
 214        cx: &mut App,
 215    ) {
 216        let acp::ToolCallUpdateFields {
 217            kind,
 218            status,
 219            title,
 220            content,
 221            locations,
 222            raw_input,
 223            raw_output,
 224        } = fields;
 225
 226        if let Some(kind) = kind {
 227            self.kind = kind;
 228        }
 229
 230        if let Some(status) = status {
 231            self.status = status.into();
 232        }
 233
 234        if let Some(title) = title {
 235            self.label.update(cx, |label, cx| {
 236                label.replace(title, cx);
 237            });
 238        }
 239
 240        if let Some(content) = content {
 241            self.content = content
 242                .into_iter()
 243                .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
 244                .collect();
 245        }
 246
 247        if let Some(locations) = locations {
 248            self.locations = locations;
 249        }
 250
 251        if let Some(raw_input) = raw_input {
 252            self.raw_input = Some(raw_input);
 253        }
 254
 255        if let Some(raw_output) = raw_output {
 256            if self.content.is_empty()
 257                && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
 258            {
 259                self.content
 260                    .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
 261                        markdown,
 262                    }));
 263            }
 264            self.raw_output = Some(raw_output);
 265        }
 266    }
 267
 268    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 269        self.content.iter().filter_map(|content| match content {
 270            ToolCallContent::Diff(diff) => Some(diff),
 271            ToolCallContent::ContentBlock(_) => None,
 272            ToolCallContent::Terminal(_) => None,
 273        })
 274    }
 275
 276    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 277        self.content.iter().filter_map(|content| match content {
 278            ToolCallContent::Terminal(terminal) => Some(terminal),
 279            ToolCallContent::ContentBlock(_) => None,
 280            ToolCallContent::Diff(_) => None,
 281        })
 282    }
 283
 284    fn to_markdown(&self, cx: &App) -> String {
 285        let mut markdown = format!(
 286            "**Tool Call: {}**\nStatus: {}\n\n",
 287            self.label.read(cx).source(),
 288            self.status
 289        );
 290        for content in &self.content {
 291            markdown.push_str(content.to_markdown(cx).as_str());
 292            markdown.push_str("\n\n");
 293        }
 294        markdown
 295    }
 296
 297    async fn resolve_location(
 298        location: acp::ToolCallLocation,
 299        project: WeakEntity<Project>,
 300        cx: &mut AsyncApp,
 301    ) -> Option<AgentLocation> {
 302        let buffer = project
 303            .update(cx, |project, cx| {
 304                project
 305                    .project_path_for_absolute_path(&location.path, cx)
 306                    .map(|path| project.open_buffer(path, cx))
 307            })
 308            .ok()??;
 309        let buffer = buffer.await.log_err()?;
 310        let position = buffer
 311            .update(cx, |buffer, _| {
 312                if let Some(row) = location.line {
 313                    let snapshot = buffer.snapshot();
 314                    let column = snapshot.indent_size_for_line(row).len;
 315                    let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
 316                    snapshot.anchor_before(point)
 317                } else {
 318                    Anchor::MIN
 319                }
 320            })
 321            .ok()?;
 322
 323        Some(AgentLocation {
 324            buffer: buffer.downgrade(),
 325            position,
 326        })
 327    }
 328
 329    fn resolve_locations(
 330        &self,
 331        project: Entity<Project>,
 332        cx: &mut App,
 333    ) -> Task<Vec<Option<AgentLocation>>> {
 334        let locations = self.locations.clone();
 335        project.update(cx, |_, cx| {
 336            cx.spawn(async move |project, cx| {
 337                let mut new_locations = Vec::new();
 338                for location in locations {
 339                    new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
 340                }
 341                new_locations
 342            })
 343        })
 344    }
 345}
 346
 347#[derive(Debug)]
 348pub enum ToolCallStatus {
 349    /// The tool call hasn't started running yet, but we start showing it to
 350    /// the user.
 351    Pending,
 352    /// The tool call is waiting for confirmation from the user.
 353    WaitingForConfirmation {
 354        options: Vec<acp::PermissionOption>,
 355        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
 356    },
 357    /// The tool call is currently running.
 358    InProgress,
 359    /// The tool call completed successfully.
 360    Completed,
 361    /// The tool call failed.
 362    Failed,
 363    /// The user rejected the tool call.
 364    Rejected,
 365    /// The user canceled generation so the tool call was canceled.
 366    Canceled,
 367}
 368
 369impl From<acp::ToolCallStatus> for ToolCallStatus {
 370    fn from(status: acp::ToolCallStatus) -> Self {
 371        match status {
 372            acp::ToolCallStatus::Pending => Self::Pending,
 373            acp::ToolCallStatus::InProgress => Self::InProgress,
 374            acp::ToolCallStatus::Completed => Self::Completed,
 375            acp::ToolCallStatus::Failed => Self::Failed,
 376        }
 377    }
 378}
 379
 380impl Display for ToolCallStatus {
 381    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 382        write!(
 383            f,
 384            "{}",
 385            match self {
 386                ToolCallStatus::Pending => "Pending",
 387                ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
 388                ToolCallStatus::InProgress => "In Progress",
 389                ToolCallStatus::Completed => "Completed",
 390                ToolCallStatus::Failed => "Failed",
 391                ToolCallStatus::Rejected => "Rejected",
 392                ToolCallStatus::Canceled => "Canceled",
 393            }
 394        )
 395    }
 396}
 397
 398#[derive(Debug, PartialEq, Clone)]
 399pub enum ContentBlock {
 400    Empty,
 401    Markdown { markdown: Entity<Markdown> },
 402    ResourceLink { resource_link: acp::ResourceLink },
 403}
 404
 405impl ContentBlock {
 406    pub fn new(
 407        block: acp::ContentBlock,
 408        language_registry: &Arc<LanguageRegistry>,
 409        cx: &mut App,
 410    ) -> Self {
 411        let mut this = Self::Empty;
 412        this.append(block, language_registry, cx);
 413        this
 414    }
 415
 416    pub fn new_combined(
 417        blocks: impl IntoIterator<Item = acp::ContentBlock>,
 418        language_registry: Arc<LanguageRegistry>,
 419        cx: &mut App,
 420    ) -> Self {
 421        let mut this = Self::Empty;
 422        for block in blocks {
 423            this.append(block, &language_registry, cx);
 424        }
 425        this
 426    }
 427
 428    pub fn append(
 429        &mut self,
 430        block: acp::ContentBlock,
 431        language_registry: &Arc<LanguageRegistry>,
 432        cx: &mut App,
 433    ) {
 434        if matches!(self, ContentBlock::Empty)
 435            && let acp::ContentBlock::ResourceLink(resource_link) = block
 436        {
 437            *self = ContentBlock::ResourceLink { resource_link };
 438            return;
 439        }
 440
 441        let new_content = self.block_string_contents(block);
 442
 443        match self {
 444            ContentBlock::Empty => {
 445                *self = Self::create_markdown_block(new_content, language_registry, cx);
 446            }
 447            ContentBlock::Markdown { markdown } => {
 448                markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
 449            }
 450            ContentBlock::ResourceLink { resource_link } => {
 451                let existing_content = Self::resource_link_md(&resource_link.uri);
 452                let combined = format!("{}\n{}", existing_content, new_content);
 453
 454                *self = Self::create_markdown_block(combined, language_registry, cx);
 455            }
 456        }
 457    }
 458
 459    fn create_markdown_block(
 460        content: String,
 461        language_registry: &Arc<LanguageRegistry>,
 462        cx: &mut App,
 463    ) -> ContentBlock {
 464        ContentBlock::Markdown {
 465            markdown: cx
 466                .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
 467        }
 468    }
 469
 470    fn block_string_contents(&self, block: acp::ContentBlock) -> String {
 471        match block {
 472            acp::ContentBlock::Text(text_content) => text_content.text,
 473            acp::ContentBlock::ResourceLink(resource_link) => {
 474                Self::resource_link_md(&resource_link.uri)
 475            }
 476            acp::ContentBlock::Resource(acp::EmbeddedResource {
 477                resource:
 478                    acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
 479                        uri,
 480                        ..
 481                    }),
 482                ..
 483            }) => Self::resource_link_md(&uri),
 484            acp::ContentBlock::Image(image) => Self::image_md(&image),
 485            acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
 486        }
 487    }
 488
 489    fn resource_link_md(uri: &str) -> String {
 490        if let Some(uri) = MentionUri::parse(uri).log_err() {
 491            uri.as_link().to_string()
 492        } else {
 493            uri.to_string()
 494        }
 495    }
 496
 497    fn image_md(_image: &acp::ImageContent) -> String {
 498        "`Image`".into()
 499    }
 500
 501    fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
 502        match self {
 503            ContentBlock::Empty => "",
 504            ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
 505            ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
 506        }
 507    }
 508
 509    pub fn markdown(&self) -> Option<&Entity<Markdown>> {
 510        match self {
 511            ContentBlock::Empty => None,
 512            ContentBlock::Markdown { markdown } => Some(markdown),
 513            ContentBlock::ResourceLink { .. } => None,
 514        }
 515    }
 516
 517    pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
 518        match self {
 519            ContentBlock::ResourceLink { resource_link } => Some(resource_link),
 520            _ => None,
 521        }
 522    }
 523}
 524
 525#[derive(Debug)]
 526pub enum ToolCallContent {
 527    ContentBlock(ContentBlock),
 528    Diff(Entity<Diff>),
 529    Terminal(Entity<Terminal>),
 530}
 531
 532impl ToolCallContent {
 533    pub fn from_acp(
 534        content: acp::ToolCallContent,
 535        language_registry: Arc<LanguageRegistry>,
 536        cx: &mut App,
 537    ) -> Self {
 538        match content {
 539            acp::ToolCallContent::Content { content } => {
 540                Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
 541            }
 542            acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
 543                Diff::finalized(
 544                    diff.path,
 545                    diff.old_text,
 546                    diff.new_text,
 547                    language_registry,
 548                    cx,
 549                )
 550            })),
 551        }
 552    }
 553
 554    pub fn to_markdown(&self, cx: &App) -> String {
 555        match self {
 556            Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
 557            Self::Diff(diff) => diff.read(cx).to_markdown(cx),
 558            Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
 559        }
 560    }
 561}
 562
 563#[derive(Debug, PartialEq)]
 564pub enum ToolCallUpdate {
 565    UpdateFields(acp::ToolCallUpdate),
 566    UpdateDiff(ToolCallUpdateDiff),
 567    UpdateTerminal(ToolCallUpdateTerminal),
 568}
 569
 570impl ToolCallUpdate {
 571    fn id(&self) -> &acp::ToolCallId {
 572        match self {
 573            Self::UpdateFields(update) => &update.id,
 574            Self::UpdateDiff(diff) => &diff.id,
 575            Self::UpdateTerminal(terminal) => &terminal.id,
 576        }
 577    }
 578}
 579
 580impl From<acp::ToolCallUpdate> for ToolCallUpdate {
 581    fn from(update: acp::ToolCallUpdate) -> Self {
 582        Self::UpdateFields(update)
 583    }
 584}
 585
 586impl From<ToolCallUpdateDiff> for ToolCallUpdate {
 587    fn from(diff: ToolCallUpdateDiff) -> Self {
 588        Self::UpdateDiff(diff)
 589    }
 590}
 591
 592#[derive(Debug, PartialEq)]
 593pub struct ToolCallUpdateDiff {
 594    pub id: acp::ToolCallId,
 595    pub diff: Entity<Diff>,
 596}
 597
 598impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
 599    fn from(terminal: ToolCallUpdateTerminal) -> Self {
 600        Self::UpdateTerminal(terminal)
 601    }
 602}
 603
 604#[derive(Debug, PartialEq)]
 605pub struct ToolCallUpdateTerminal {
 606    pub id: acp::ToolCallId,
 607    pub terminal: Entity<Terminal>,
 608}
 609
 610#[derive(Debug, Default)]
 611pub struct Plan {
 612    pub entries: Vec<PlanEntry>,
 613}
 614
 615#[derive(Debug)]
 616pub struct PlanStats<'a> {
 617    pub in_progress_entry: Option<&'a PlanEntry>,
 618    pub pending: u32,
 619    pub completed: u32,
 620}
 621
 622impl Plan {
 623    pub fn is_empty(&self) -> bool {
 624        self.entries.is_empty()
 625    }
 626
 627    pub fn stats(&self) -> PlanStats<'_> {
 628        let mut stats = PlanStats {
 629            in_progress_entry: None,
 630            pending: 0,
 631            completed: 0,
 632        };
 633
 634        for entry in &self.entries {
 635            match &entry.status {
 636                acp::PlanEntryStatus::Pending => {
 637                    stats.pending += 1;
 638                }
 639                acp::PlanEntryStatus::InProgress => {
 640                    stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
 641                }
 642                acp::PlanEntryStatus::Completed => {
 643                    stats.completed += 1;
 644                }
 645            }
 646        }
 647
 648        stats
 649    }
 650}
 651
 652#[derive(Debug)]
 653pub struct PlanEntry {
 654    pub content: Entity<Markdown>,
 655    pub priority: acp::PlanEntryPriority,
 656    pub status: acp::PlanEntryStatus,
 657}
 658
 659impl PlanEntry {
 660    pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
 661        Self {
 662            content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
 663            priority: entry.priority,
 664            status: entry.status,
 665        }
 666    }
 667}
 668
 669#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 670pub struct TokenUsage {
 671    pub max_tokens: u64,
 672    pub used_tokens: u64,
 673}
 674
 675impl TokenUsage {
 676    pub fn ratio(&self) -> TokenUsageRatio {
 677        #[cfg(debug_assertions)]
 678        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 679            .unwrap_or("0.8".to_string())
 680            .parse()
 681            .unwrap();
 682        #[cfg(not(debug_assertions))]
 683        let warning_threshold: f32 = 0.8;
 684
 685        // When the maximum is unknown because there is no selected model,
 686        // avoid showing the token limit warning.
 687        if self.max_tokens == 0 {
 688            TokenUsageRatio::Normal
 689        } else if self.used_tokens >= self.max_tokens {
 690            TokenUsageRatio::Exceeded
 691        } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
 692            TokenUsageRatio::Warning
 693        } else {
 694            TokenUsageRatio::Normal
 695        }
 696    }
 697}
 698
 699#[derive(Debug, Clone, PartialEq, Eq)]
 700pub enum TokenUsageRatio {
 701    Normal,
 702    Warning,
 703    Exceeded,
 704}
 705
 706#[derive(Debug, Clone)]
 707pub struct RetryStatus {
 708    pub last_error: SharedString,
 709    pub attempt: usize,
 710    pub max_attempts: usize,
 711    pub started_at: Instant,
 712    pub duration: Duration,
 713}
 714
 715pub struct AcpThread {
 716    title: SharedString,
 717    entries: Vec<AgentThreadEntry>,
 718    plan: Plan,
 719    project: Entity<Project>,
 720    action_log: Entity<ActionLog>,
 721    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
 722    send_task: Option<Task<()>>,
 723    connection: Rc<dyn AgentConnection>,
 724    session_id: acp::SessionId,
 725    token_usage: Option<TokenUsage>,
 726}
 727
 728#[derive(Debug)]
 729pub enum AcpThreadEvent {
 730    NewEntry,
 731    TitleUpdated,
 732    TokenUsageUpdated,
 733    EntryUpdated(usize),
 734    EntriesRemoved(Range<usize>),
 735    ToolAuthorizationRequired,
 736    Retry(RetryStatus),
 737    Stopped,
 738    Error,
 739    LoadError(LoadError),
 740}
 741
 742impl EventEmitter<AcpThreadEvent> for AcpThread {}
 743
 744#[derive(PartialEq, Eq)]
 745pub enum ThreadStatus {
 746    Idle,
 747    WaitingForToolConfirmation,
 748    Generating,
 749}
 750
 751#[derive(Debug, Clone)]
 752pub enum LoadError {
 753    NotInstalled {
 754        error_message: SharedString,
 755        install_message: SharedString,
 756        install_command: String,
 757    },
 758    Unsupported {
 759        error_message: SharedString,
 760        upgrade_message: SharedString,
 761        upgrade_command: String,
 762    },
 763    Exited {
 764        status: ExitStatus,
 765    },
 766    Other(SharedString),
 767}
 768
 769impl Display for LoadError {
 770    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 771        match self {
 772            LoadError::NotInstalled { error_message, .. }
 773            | LoadError::Unsupported { error_message, .. } => {
 774                write!(f, "{error_message}")
 775            }
 776            LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
 777            LoadError::Other(msg) => write!(f, "{}", msg),
 778        }
 779    }
 780}
 781
 782impl Error for LoadError {}
 783
 784impl AcpThread {
 785    pub fn new(
 786        title: impl Into<SharedString>,
 787        connection: Rc<dyn AgentConnection>,
 788        project: Entity<Project>,
 789        action_log: Entity<ActionLog>,
 790        session_id: acp::SessionId,
 791    ) -> Self {
 792        Self {
 793            action_log,
 794            shared_buffers: Default::default(),
 795            entries: Default::default(),
 796            plan: Default::default(),
 797            title: title.into(),
 798            project,
 799            send_task: None,
 800            connection,
 801            session_id,
 802            token_usage: None,
 803        }
 804    }
 805
 806    pub fn connection(&self) -> &Rc<dyn AgentConnection> {
 807        &self.connection
 808    }
 809
 810    pub fn action_log(&self) -> &Entity<ActionLog> {
 811        &self.action_log
 812    }
 813
 814    pub fn project(&self) -> &Entity<Project> {
 815        &self.project
 816    }
 817
 818    pub fn title(&self) -> SharedString {
 819        self.title.clone()
 820    }
 821
 822    pub fn entries(&self) -> &[AgentThreadEntry] {
 823        &self.entries
 824    }
 825
 826    pub fn session_id(&self) -> &acp::SessionId {
 827        &self.session_id
 828    }
 829
 830    pub fn status(&self) -> ThreadStatus {
 831        if self.send_task.is_some() {
 832            if self.waiting_for_tool_confirmation() {
 833                ThreadStatus::WaitingForToolConfirmation
 834            } else {
 835                ThreadStatus::Generating
 836            }
 837        } else {
 838            ThreadStatus::Idle
 839        }
 840    }
 841
 842    pub fn token_usage(&self) -> Option<&TokenUsage> {
 843        self.token_usage.as_ref()
 844    }
 845
 846    pub fn has_pending_edit_tool_calls(&self) -> bool {
 847        for entry in self.entries.iter().rev() {
 848            match entry {
 849                AgentThreadEntry::UserMessage(_) => return false,
 850                AgentThreadEntry::ToolCall(
 851                    call @ ToolCall {
 852                        status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
 853                        ..
 854                    },
 855                ) if call.diffs().next().is_some() => {
 856                    return true;
 857                }
 858                AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
 859            }
 860        }
 861
 862        false
 863    }
 864
 865    pub fn used_tools_since_last_user_message(&self) -> bool {
 866        for entry in self.entries.iter().rev() {
 867            match entry {
 868                AgentThreadEntry::UserMessage(..) => return false,
 869                AgentThreadEntry::AssistantMessage(..) => continue,
 870                AgentThreadEntry::ToolCall(..) => return true,
 871            }
 872        }
 873
 874        false
 875    }
 876
 877    pub fn handle_session_update(
 878        &mut self,
 879        update: acp::SessionUpdate,
 880        cx: &mut Context<Self>,
 881    ) -> Result<(), acp::Error> {
 882        match update {
 883            acp::SessionUpdate::UserMessageChunk { content } => {
 884                self.push_user_content_block(None, content, cx);
 885            }
 886            acp::SessionUpdate::AgentMessageChunk { content } => {
 887                self.push_assistant_content_block(content, false, cx);
 888            }
 889            acp::SessionUpdate::AgentThoughtChunk { content } => {
 890                self.push_assistant_content_block(content, true, cx);
 891            }
 892            acp::SessionUpdate::ToolCall(tool_call) => {
 893                self.upsert_tool_call(tool_call, cx)?;
 894            }
 895            acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
 896                self.update_tool_call(tool_call_update, cx)?;
 897            }
 898            acp::SessionUpdate::Plan(plan) => {
 899                self.update_plan(plan, cx);
 900            }
 901        }
 902        Ok(())
 903    }
 904
 905    pub fn push_user_content_block(
 906        &mut self,
 907        message_id: Option<UserMessageId>,
 908        chunk: acp::ContentBlock,
 909        cx: &mut Context<Self>,
 910    ) {
 911        let language_registry = self.project.read(cx).languages().clone();
 912        let entries_len = self.entries.len();
 913
 914        if let Some(last_entry) = self.entries.last_mut()
 915            && let AgentThreadEntry::UserMessage(UserMessage {
 916                id,
 917                content,
 918                chunks,
 919                ..
 920            }) = last_entry
 921        {
 922            *id = message_id.or(id.take());
 923            content.append(chunk.clone(), &language_registry, cx);
 924            chunks.push(chunk);
 925            let idx = entries_len - 1;
 926            cx.emit(AcpThreadEvent::EntryUpdated(idx));
 927        } else {
 928            let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
 929            self.push_entry(
 930                AgentThreadEntry::UserMessage(UserMessage {
 931                    id: message_id,
 932                    content,
 933                    chunks: vec![chunk],
 934                    checkpoint: None,
 935                }),
 936                cx,
 937            );
 938        }
 939    }
 940
 941    pub fn push_assistant_content_block(
 942        &mut self,
 943        chunk: acp::ContentBlock,
 944        is_thought: bool,
 945        cx: &mut Context<Self>,
 946    ) {
 947        let language_registry = self.project.read(cx).languages().clone();
 948        let entries_len = self.entries.len();
 949        if let Some(last_entry) = self.entries.last_mut()
 950            && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
 951        {
 952            let idx = entries_len - 1;
 953            cx.emit(AcpThreadEvent::EntryUpdated(idx));
 954            match (chunks.last_mut(), is_thought) {
 955                (Some(AssistantMessageChunk::Message { block }), false)
 956                | (Some(AssistantMessageChunk::Thought { block }), true) => {
 957                    block.append(chunk, &language_registry, cx)
 958                }
 959                _ => {
 960                    let block = ContentBlock::new(chunk, &language_registry, cx);
 961                    if is_thought {
 962                        chunks.push(AssistantMessageChunk::Thought { block })
 963                    } else {
 964                        chunks.push(AssistantMessageChunk::Message { block })
 965                    }
 966                }
 967            }
 968        } else {
 969            let block = ContentBlock::new(chunk, &language_registry, cx);
 970            let chunk = if is_thought {
 971                AssistantMessageChunk::Thought { block }
 972            } else {
 973                AssistantMessageChunk::Message { block }
 974            };
 975
 976            self.push_entry(
 977                AgentThreadEntry::AssistantMessage(AssistantMessage {
 978                    chunks: vec![chunk],
 979                }),
 980                cx,
 981            );
 982        }
 983    }
 984
 985    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
 986        self.entries.push(entry);
 987        cx.emit(AcpThreadEvent::NewEntry);
 988    }
 989
 990    pub fn update_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Result<()> {
 991        self.title = title;
 992        cx.emit(AcpThreadEvent::TitleUpdated);
 993        Ok(())
 994    }
 995
 996    pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
 997        self.token_usage = usage;
 998        cx.emit(AcpThreadEvent::TokenUsageUpdated);
 999    }
1000
1001    pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1002        cx.emit(AcpThreadEvent::Retry(status));
1003    }
1004
1005    pub fn update_tool_call(
1006        &mut self,
1007        update: impl Into<ToolCallUpdate>,
1008        cx: &mut Context<Self>,
1009    ) -> Result<()> {
1010        let update = update.into();
1011        let languages = self.project.read(cx).languages().clone();
1012
1013        let (ix, current_call) = self
1014            .tool_call_mut(update.id())
1015            .context("Tool call not found")?;
1016        match update {
1017            ToolCallUpdate::UpdateFields(update) => {
1018                let location_updated = update.fields.locations.is_some();
1019                current_call.update_fields(update.fields, languages, cx);
1020                if location_updated {
1021                    self.resolve_locations(update.id, cx);
1022                }
1023            }
1024            ToolCallUpdate::UpdateDiff(update) => {
1025                current_call.content.clear();
1026                current_call
1027                    .content
1028                    .push(ToolCallContent::Diff(update.diff));
1029            }
1030            ToolCallUpdate::UpdateTerminal(update) => {
1031                current_call.content.clear();
1032                current_call
1033                    .content
1034                    .push(ToolCallContent::Terminal(update.terminal));
1035            }
1036        }
1037
1038        cx.emit(AcpThreadEvent::EntryUpdated(ix));
1039
1040        Ok(())
1041    }
1042
1043    /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1044    pub fn upsert_tool_call(
1045        &mut self,
1046        tool_call: acp::ToolCall,
1047        cx: &mut Context<Self>,
1048    ) -> Result<(), acp::Error> {
1049        let status = tool_call.status.into();
1050        self.upsert_tool_call_inner(tool_call.into(), status, cx)
1051    }
1052
1053    /// Fails if id does not match an existing entry.
1054    pub fn upsert_tool_call_inner(
1055        &mut self,
1056        tool_call_update: acp::ToolCallUpdate,
1057        status: ToolCallStatus,
1058        cx: &mut Context<Self>,
1059    ) -> Result<(), acp::Error> {
1060        let language_registry = self.project.read(cx).languages().clone();
1061        let id = tool_call_update.id.clone();
1062
1063        if let Some((ix, current_call)) = self.tool_call_mut(&id) {
1064            current_call.update_fields(tool_call_update.fields, language_registry, cx);
1065            current_call.status = status;
1066
1067            cx.emit(AcpThreadEvent::EntryUpdated(ix));
1068        } else {
1069            let call =
1070                ToolCall::from_acp(tool_call_update.try_into()?, status, language_registry, cx);
1071            self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1072        };
1073
1074        self.resolve_locations(id, cx);
1075        Ok(())
1076    }
1077
1078    fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1079        // The tool call we are looking for is typically the last one, or very close to the end.
1080        // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1081        self.entries
1082            .iter_mut()
1083            .enumerate()
1084            .rev()
1085            .find_map(|(index, tool_call)| {
1086                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1087                    && &tool_call.id == id
1088                {
1089                    Some((index, tool_call))
1090                } else {
1091                    None
1092                }
1093            })
1094    }
1095
1096    pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1097        self.entries
1098            .iter()
1099            .enumerate()
1100            .rev()
1101            .find_map(|(index, tool_call)| {
1102                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1103                    && &tool_call.id == id
1104                {
1105                    Some((index, tool_call))
1106                } else {
1107                    None
1108                }
1109            })
1110    }
1111
1112    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1113        let project = self.project.clone();
1114        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1115            return;
1116        };
1117        let task = tool_call.resolve_locations(project, cx);
1118        cx.spawn(async move |this, cx| {
1119            let resolved_locations = task.await;
1120            this.update(cx, |this, cx| {
1121                let project = this.project.clone();
1122                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1123                    return;
1124                };
1125                if let Some(Some(location)) = resolved_locations.last() {
1126                    project.update(cx, |project, cx| {
1127                        if let Some(agent_location) = project.agent_location() {
1128                            let should_ignore = agent_location.buffer == location.buffer
1129                                && location
1130                                    .buffer
1131                                    .update(cx, |buffer, _| {
1132                                        let snapshot = buffer.snapshot();
1133                                        let old_position =
1134                                            agent_location.position.to_point(&snapshot);
1135                                        let new_position = location.position.to_point(&snapshot);
1136                                        // ignore this so that when we get updates from the edit tool
1137                                        // the position doesn't reset to the startof line
1138                                        old_position.row == new_position.row
1139                                            && old_position.column > new_position.column
1140                                    })
1141                                    .ok()
1142                                    .unwrap_or_default();
1143                            if !should_ignore {
1144                                project.set_agent_location(Some(location.clone()), cx);
1145                            }
1146                        }
1147                    });
1148                }
1149                if tool_call.resolved_locations != resolved_locations {
1150                    tool_call.resolved_locations = resolved_locations;
1151                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
1152                }
1153            })
1154        })
1155        .detach();
1156    }
1157
1158    pub fn request_tool_call_authorization(
1159        &mut self,
1160        tool_call: acp::ToolCallUpdate,
1161        options: Vec<acp::PermissionOption>,
1162        cx: &mut Context<Self>,
1163    ) -> Result<oneshot::Receiver<acp::PermissionOptionId>, acp::Error> {
1164        let (tx, rx) = oneshot::channel();
1165
1166        let status = ToolCallStatus::WaitingForConfirmation {
1167            options,
1168            respond_tx: tx,
1169        };
1170
1171        self.upsert_tool_call_inner(tool_call, status, cx)?;
1172        cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1173        Ok(rx)
1174    }
1175
1176    pub fn authorize_tool_call(
1177        &mut self,
1178        id: acp::ToolCallId,
1179        option_id: acp::PermissionOptionId,
1180        option_kind: acp::PermissionOptionKind,
1181        cx: &mut Context<Self>,
1182    ) {
1183        let Some((ix, call)) = self.tool_call_mut(&id) else {
1184            return;
1185        };
1186
1187        let new_status = match option_kind {
1188            acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1189                ToolCallStatus::Rejected
1190            }
1191            acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1192                ToolCallStatus::InProgress
1193            }
1194        };
1195
1196        let curr_status = mem::replace(&mut call.status, new_status);
1197
1198        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1199            respond_tx.send(option_id).log_err();
1200        } else if cfg!(debug_assertions) {
1201            panic!("tried to authorize an already authorized tool call");
1202        }
1203
1204        cx.emit(AcpThreadEvent::EntryUpdated(ix));
1205    }
1206
1207    /// Returns true if the last turn is awaiting tool authorization
1208    pub fn waiting_for_tool_confirmation(&self) -> bool {
1209        for entry in self.entries.iter().rev() {
1210            match &entry {
1211                AgentThreadEntry::ToolCall(call) => match call.status {
1212                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
1213                    ToolCallStatus::Pending
1214                    | ToolCallStatus::InProgress
1215                    | ToolCallStatus::Completed
1216                    | ToolCallStatus::Failed
1217                    | ToolCallStatus::Rejected
1218                    | ToolCallStatus::Canceled => continue,
1219                },
1220                AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1221                    // Reached the beginning of the turn
1222                    return false;
1223                }
1224            }
1225        }
1226        false
1227    }
1228
1229    pub fn plan(&self) -> &Plan {
1230        &self.plan
1231    }
1232
1233    pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1234        let new_entries_len = request.entries.len();
1235        let mut new_entries = request.entries.into_iter();
1236
1237        // Reuse existing markdown to prevent flickering
1238        for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1239            let PlanEntry {
1240                content,
1241                priority,
1242                status,
1243            } = old;
1244            content.update(cx, |old, cx| {
1245                old.replace(new.content, cx);
1246            });
1247            *priority = new.priority;
1248            *status = new.status;
1249        }
1250        for new in new_entries {
1251            self.plan.entries.push(PlanEntry::from_acp(new, cx))
1252        }
1253        self.plan.entries.truncate(new_entries_len);
1254
1255        cx.notify();
1256    }
1257
1258    fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1259        self.plan
1260            .entries
1261            .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1262        cx.notify();
1263    }
1264
1265    #[cfg(any(test, feature = "test-support"))]
1266    pub fn send_raw(
1267        &mut self,
1268        message: &str,
1269        cx: &mut Context<Self>,
1270    ) -> BoxFuture<'static, Result<()>> {
1271        self.send(
1272            vec![acp::ContentBlock::Text(acp::TextContent {
1273                text: message.to_string(),
1274                annotations: None,
1275            })],
1276            cx,
1277        )
1278    }
1279
1280    pub fn send(
1281        &mut self,
1282        message: Vec<acp::ContentBlock>,
1283        cx: &mut Context<Self>,
1284    ) -> BoxFuture<'static, Result<()>> {
1285        let block = ContentBlock::new_combined(
1286            message.clone(),
1287            self.project.read(cx).languages().clone(),
1288            cx,
1289        );
1290        let request = acp::PromptRequest {
1291            prompt: message.clone(),
1292            session_id: self.session_id.clone(),
1293        };
1294        let git_store = self.project.read(cx).git_store().clone();
1295
1296        let message_id = if self
1297            .connection
1298            .session_editor(&self.session_id, cx)
1299            .is_some()
1300        {
1301            Some(UserMessageId::new())
1302        } else {
1303            None
1304        };
1305
1306        self.run_turn(cx, async move |this, cx| {
1307            this.update(cx, |this, cx| {
1308                this.push_entry(
1309                    AgentThreadEntry::UserMessage(UserMessage {
1310                        id: message_id.clone(),
1311                        content: block,
1312                        chunks: message,
1313                        checkpoint: None,
1314                    }),
1315                    cx,
1316                );
1317            })
1318            .ok();
1319
1320            let old_checkpoint = git_store
1321                .update(cx, |git, cx| git.checkpoint(cx))?
1322                .await
1323                .context("failed to get old checkpoint")
1324                .log_err();
1325            this.update(cx, |this, cx| {
1326                if let Some((_ix, message)) = this.last_user_message() {
1327                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
1328                        git_checkpoint,
1329                        show: false,
1330                    });
1331                }
1332                this.connection.prompt(message_id, request, cx)
1333            })?
1334            .await
1335        })
1336    }
1337
1338    pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1339        self.run_turn(cx, async move |this, cx| {
1340            this.update(cx, |this, cx| {
1341                this.connection
1342                    .resume(&this.session_id, cx)
1343                    .map(|resume| resume.run(cx))
1344            })?
1345            .context("resuming a session is not supported")?
1346            .await
1347        })
1348    }
1349
    /// Runs one prompt turn driven by `f`, after canceling any turn already in flight.
    ///
    /// Setup:
    /// - Completed plan entries are cleared first.
    /// - `f` runs inside the task stored in `send_task`, so a later `cancel` can wait for it.
    /// - `f`'s result is forwarded over a oneshot channel to the returned future.
    ///
    /// Once the result arrives, the returned future:
    /// - Refreshes the checkpoint on the last user message (`update_last_checkpoint`).
    /// - Clears the project's agent location.
    /// - If `f` returned an error: drops `send_task`, emits `Error`, and returns the error.
    /// - Otherwise:
    ///   - drops `send_task` unless the turn was canceled;
    ///   - on a `Refusal` stop reason, removes the last user message and everything after it;
    ///   - emits `Stopped` and returns `Ok(())`.
    ///   A oneshot sender dropped without sending also takes this path.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<()>> {
        self.clear_completed_plan_entries(cx);

        let (tx, rx) = oneshot::channel();
        // Takes the previous `send_task` (if any) and resolves once it has finished.
        let cancel_task = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            // Don't start the new turn until the previous one has fully wound down.
            cancel_task.await;
            tx.send(f(this, cx).await).ok();
        }));

        cx.spawn(async move |this, cx| {
            let response = rx.await;

            // NOTE(review): if this fails, the early return skips clearing `send_task`,
            // resetting the agent location, and emitting `Stopped`/`Error`. Confirm intended.
            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));
                match response {
                    Ok(Err(e)) => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        Err(e)
                    }
                    result => {
                        let canceled = matches!(
                            result,
                            Ok(Ok(acp::PromptResponse {
                                stop_reason: acp::StopReason::Canceled
                            }))
                        );

                        // We only take the task if the current prompt wasn't canceled.
                        //
                        // This prompt may have been canceled because another one was sent
                        // while it was still generating. In these cases, dropping `send_task`
                        // would cause the next generation to be canceled.
                        if !canceled {
                            this.send_task.take();
                        }

                        // Truncate entries if the last prompt was refused.
                        if let Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                        })) = result
                            && let Some((ix, _)) = this.last_user_message()
                        {
                            let range = ix..this.entries.len();
                            this.entries.truncate(ix);
                            cx.emit(AcpThreadEvent::EntriesRemoved(range));
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(())
                    }
                }
            })?
        })
        .boxed()
    }
1416
1417    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1418        let Some(send_task) = self.send_task.take() else {
1419            return Task::ready(());
1420        };
1421
1422        for entry in self.entries.iter_mut() {
1423            if let AgentThreadEntry::ToolCall(call) = entry {
1424                let cancel = matches!(
1425                    call.status,
1426                    ToolCallStatus::Pending
1427                        | ToolCallStatus::WaitingForConfirmation { .. }
1428                        | ToolCallStatus::InProgress
1429                );
1430
1431                if cancel {
1432                    call.status = ToolCallStatus::Canceled;
1433                }
1434            }
1435        }
1436
1437        self.connection.cancel(&self.session_id, cx);
1438
1439        // Wait for the send task to complete
1440        cx.foreground_executor().spawn(send_task)
1441    }
1442
1443    /// Rewinds this thread to before the entry at `index`, removing it and all
1444    /// subsequent entries while reverting any changes made from that point.
1445    pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1446        let Some(session_editor) = self.connection.session_editor(&self.session_id, cx) else {
1447            return Task::ready(Err(anyhow!("not supported")));
1448        };
1449        let Some(message) = self.user_message(&id) else {
1450            return Task::ready(Err(anyhow!("message not found")));
1451        };
1452
1453        let checkpoint = message
1454            .checkpoint
1455            .as_ref()
1456            .map(|c| c.git_checkpoint.clone());
1457
1458        let git_store = self.project.read(cx).git_store().clone();
1459        cx.spawn(async move |this, cx| {
1460            if let Some(checkpoint) = checkpoint {
1461                git_store
1462                    .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1463                    .await?;
1464            }
1465
1466            cx.update(|cx| session_editor.truncate(id.clone(), cx))?
1467                .await?;
1468            this.update(cx, |this, cx| {
1469                if let Some((ix, _)) = this.user_message_mut(&id) {
1470                    let range = ix..this.entries.len();
1471                    this.entries.truncate(ix);
1472                    cx.emit(AcpThreadEvent::EntriesRemoved(range));
1473                }
1474            })
1475        })
1476    }
1477
1478    fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1479        let git_store = self.project.read(cx).git_store().clone();
1480
1481        let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1482            if let Some(checkpoint) = message.checkpoint.as_ref() {
1483                checkpoint.git_checkpoint.clone()
1484            } else {
1485                return Task::ready(Ok(()));
1486            }
1487        } else {
1488            return Task::ready(Ok(()));
1489        };
1490
1491        let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1492        cx.spawn(async move |this, cx| {
1493            let new_checkpoint = new_checkpoint
1494                .await
1495                .context("failed to get new checkpoint")
1496                .log_err();
1497            if let Some(new_checkpoint) = new_checkpoint {
1498                let equal = git_store
1499                    .update(cx, |git, cx| {
1500                        git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1501                    })?
1502                    .await
1503                    .unwrap_or(true);
1504                this.update(cx, |this, cx| {
1505                    let (ix, message) = this.last_user_message().context("no user message")?;
1506                    let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1507                    checkpoint.show = !equal;
1508                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
1509                    anyhow::Ok(())
1510                })??;
1511            }
1512
1513            Ok(())
1514        })
1515    }
1516
1517    fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1518        self.entries
1519            .iter_mut()
1520            .enumerate()
1521            .rev()
1522            .find_map(|(ix, entry)| {
1523                if let AgentThreadEntry::UserMessage(message) = entry {
1524                    Some((ix, message))
1525                } else {
1526                    None
1527                }
1528            })
1529    }
1530
1531    fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
1532        self.entries.iter().find_map(|entry| {
1533            if let AgentThreadEntry::UserMessage(message) = entry {
1534                if message.id.as_ref() == Some(id) {
1535                    Some(message)
1536                } else {
1537                    None
1538                }
1539            } else {
1540                None
1541            }
1542        })
1543    }
1544
1545    fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1546        self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1547            if let AgentThreadEntry::UserMessage(message) = entry {
1548                if message.id.as_ref() == Some(id) {
1549                    Some((ix, message))
1550                } else {
1551                    None
1552                }
1553            } else {
1554                None
1555            }
1556        })
1557    }
1558
1559    pub fn read_text_file(
1560        &self,
1561        path: PathBuf,
1562        line: Option<u32>,
1563        limit: Option<u32>,
1564        reuse_shared_snapshot: bool,
1565        cx: &mut Context<Self>,
1566    ) -> Task<Result<String>> {
1567        let project = self.project.clone();
1568        let action_log = self.action_log.clone();
1569        cx.spawn(async move |this, cx| {
1570            let load = project.update(cx, |project, cx| {
1571                let path = project
1572                    .project_path_for_absolute_path(&path, cx)
1573                    .context("invalid path")?;
1574                anyhow::Ok(project.open_buffer(path, cx))
1575            });
1576            let buffer = load??.await?;
1577
1578            let snapshot = if reuse_shared_snapshot {
1579                this.read_with(cx, |this, _| {
1580                    this.shared_buffers.get(&buffer.clone()).cloned()
1581                })
1582                .log_err()
1583                .flatten()
1584            } else {
1585                None
1586            };
1587
1588            let snapshot = if let Some(snapshot) = snapshot {
1589                snapshot
1590            } else {
1591                action_log.update(cx, |action_log, cx| {
1592                    action_log.buffer_read(buffer.clone(), cx);
1593                })?;
1594                project.update(cx, |project, cx| {
1595                    let position = buffer
1596                        .read(cx)
1597                        .snapshot()
1598                        .anchor_before(Point::new(line.unwrap_or_default(), 0));
1599                    project.set_agent_location(
1600                        Some(AgentLocation {
1601                            buffer: buffer.downgrade(),
1602                            position,
1603                        }),
1604                        cx,
1605                    );
1606                })?;
1607
1608                buffer.update(cx, |buffer, _| buffer.snapshot())?
1609            };
1610
1611            this.update(cx, |this, _| {
1612                let text = snapshot.text();
1613                this.shared_buffers.insert(buffer.clone(), snapshot);
1614                if line.is_none() && limit.is_none() {
1615                    return Ok(text);
1616                }
1617                let limit = limit.unwrap_or(u32::MAX) as usize;
1618                let Some(line) = line else {
1619                    return Ok(text.lines().take(limit).collect::<String>());
1620                };
1621
1622                let count = text.lines().count();
1623                if count < line as usize {
1624                    anyhow::bail!("There are only {} lines", count);
1625                }
1626                Ok(text
1627                    .lines()
1628                    .skip(line as usize + 1)
1629                    .take(limit)
1630                    .collect::<String>())
1631            })?
1632        })
1633    }
1634
    /// Replaces the contents of the file at `path` with `content` and saves it.
    ///
    /// Steps:
    /// 1. Diff `content` on the background executor against the snapshot last shared with
    ///    the agent for this buffer, or the buffer's current snapshot if none was shared.
    /// 2. Apply the diff to the buffer as edits, recording the read and the edit in the
    ///    action log.
    /// 3. Move the agent location to the end of the last edit (`Anchor::MIN` if there were
    ///    no edits).
    /// 4. Format the buffer before saving when its language settings enable format-on-save.
    ///    Formatting failures are logged and otherwise ignored.
    ///
    /// Fails if the path is not in the project, the buffer cannot be opened, or saving fails.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Prefer the snapshot the agent last saw, so the diff is relative to what the agent read.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Edits are expressed as anchors into `snapshot`, which may be older than the
            // buffer's current state.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::MIN),
                    }),
                    cx,
                );
            })?;

            // Record the read, apply the edits, and record the edit in one synchronous update.
            // Also report whether format-on-save is enabled for this buffer's language.
            let format_on_save = cx.update(|cx| {
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            })?;

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                })?;
                // Best-effort: a formatting failure must not prevent the write from being saved.
                format_task.await.log_err();

                // Formatting may have changed the buffer again, so record another edit.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                })?;
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1731
1732    pub fn to_markdown(&self, cx: &App) -> String {
1733        self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1734    }
1735
1736    pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
1737        cx.emit(AcpThreadEvent::LoadError(error));
1738    }
1739}
1740
1741fn markdown_for_raw_output(
1742    raw_output: &serde_json::Value,
1743    language_registry: &Arc<LanguageRegistry>,
1744    cx: &mut App,
1745) -> Option<Entity<Markdown>> {
1746    match raw_output {
1747        serde_json::Value::Null => None,
1748        serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1749            Markdown::new(
1750                value.to_string().into(),
1751                Some(language_registry.clone()),
1752                None,
1753                cx,
1754            )
1755        })),
1756        serde_json::Value::Number(value) => Some(cx.new(|cx| {
1757            Markdown::new(
1758                value.to_string().into(),
1759                Some(language_registry.clone()),
1760                None,
1761                cx,
1762            )
1763        })),
1764        serde_json::Value::String(value) => Some(cx.new(|cx| {
1765            Markdown::new(
1766                value.clone().into(),
1767                Some(language_registry.clone()),
1768                None,
1769                cx,
1770            )
1771        })),
1772        value => Some(cx.new(|cx| {
1773            Markdown::new(
1774                format!("```json\n{}\n```", value).into(),
1775                Some(language_registry.clone()),
1776                None,
1777                cx,
1778            )
1779        })),
1780    }
1781}
1782
1783#[cfg(test)]
1784mod tests {
1785    use super::*;
1786    use anyhow::anyhow;
1787    use futures::{channel::mpsc, future::LocalBoxFuture, select};
1788    use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
1789    use indoc::indoc;
1790    use project::{FakeFs, Fs};
1791    use rand::Rng as _;
1792    use serde_json::json;
1793    use settings::SettingsStore;
1794    use smol::stream::StreamExt as _;
1795    use std::{
1796        any::Any,
1797        cell::RefCell,
1798        path::Path,
1799        rc::Rc,
1800        sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
1801        time::Duration,
1802    };
1803    use util::path;
1804
1805    fn init_test(cx: &mut TestAppContext) {
1806        env_logger::try_init().ok();
1807        cx.update(|cx| {
1808            let settings_store = SettingsStore::test(cx);
1809            cx.set_global(settings_store);
1810            Project::init_settings(cx);
1811            language::init(cx);
1812        });
1813    }
1814
1815    #[gpui::test]
1816    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1817        init_test(cx);
1818
1819        let fs = FakeFs::new(cx.executor());
1820        let project = Project::test(fs, [], cx).await;
1821        let connection = Rc::new(FakeAgentConnection::new());
1822        let thread = cx
1823            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
1824            .await
1825            .unwrap();
1826
1827        // Test creating a new user message
1828        thread.update(cx, |thread, cx| {
1829            thread.push_user_content_block(
1830                None,
1831                acp::ContentBlock::Text(acp::TextContent {
1832                    annotations: None,
1833                    text: "Hello, ".to_string(),
1834                }),
1835                cx,
1836            );
1837        });
1838
1839        thread.update(cx, |thread, cx| {
1840            assert_eq!(thread.entries.len(), 1);
1841            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1842                assert_eq!(user_msg.id, None);
1843                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1844            } else {
1845                panic!("Expected UserMessage");
1846            }
1847        });
1848
1849        // Test appending to existing user message
1850        let message_1_id = UserMessageId::new();
1851        thread.update(cx, |thread, cx| {
1852            thread.push_user_content_block(
1853                Some(message_1_id.clone()),
1854                acp::ContentBlock::Text(acp::TextContent {
1855                    annotations: None,
1856                    text: "world!".to_string(),
1857                }),
1858                cx,
1859            );
1860        });
1861
1862        thread.update(cx, |thread, cx| {
1863            assert_eq!(thread.entries.len(), 1);
1864            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1865                assert_eq!(user_msg.id, Some(message_1_id));
1866                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1867            } else {
1868                panic!("Expected UserMessage");
1869            }
1870        });
1871
1872        // Test creating new user message after assistant message
1873        thread.update(cx, |thread, cx| {
1874            thread.push_assistant_content_block(
1875                acp::ContentBlock::Text(acp::TextContent {
1876                    annotations: None,
1877                    text: "Assistant response".to_string(),
1878                }),
1879                false,
1880                cx,
1881            );
1882        });
1883
1884        let message_2_id = UserMessageId::new();
1885        thread.update(cx, |thread, cx| {
1886            thread.push_user_content_block(
1887                Some(message_2_id.clone()),
1888                acp::ContentBlock::Text(acp::TextContent {
1889                    annotations: None,
1890                    text: "New user message".to_string(),
1891                }),
1892                cx,
1893            );
1894        });
1895
1896        thread.update(cx, |thread, cx| {
1897            assert_eq!(thread.entries.len(), 3);
1898            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1899                assert_eq!(user_msg.id, Some(message_2_id));
1900                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1901            } else {
1902                panic!("Expected UserMessage at index 2");
1903            }
1904        });
1905    }
1906
1907    #[gpui::test]
1908    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
1909        init_test(cx);
1910
1911        let fs = FakeFs::new(cx.executor());
1912        let project = Project::test(fs, [], cx).await;
1913        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1914            |_, thread, mut cx| {
1915                async move {
1916                    thread.update(&mut cx, |thread, cx| {
1917                        thread
1918                            .handle_session_update(
1919                                acp::SessionUpdate::AgentThoughtChunk {
1920                                    content: "Thinking ".into(),
1921                                },
1922                                cx,
1923                            )
1924                            .unwrap();
1925                        thread
1926                            .handle_session_update(
1927                                acp::SessionUpdate::AgentThoughtChunk {
1928                                    content: "hard!".into(),
1929                                },
1930                                cx,
1931                            )
1932                            .unwrap();
1933                    })?;
1934                    Ok(acp::PromptResponse {
1935                        stop_reason: acp::StopReason::EndTurn,
1936                    })
1937                }
1938                .boxed_local()
1939            },
1940        ));
1941
1942        let thread = cx
1943            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
1944            .await
1945            .unwrap();
1946
1947        thread
1948            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
1949            .await
1950            .unwrap();
1951
1952        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
1953        assert_eq!(
1954            output,
1955            indoc! {r#"
1956            ## User
1957
1958            Hello from Zed!
1959
1960            ## Assistant
1961
1962            <thinking>
1963            Thinking hard!
1964            </thinking>
1965
1966            "#}
1967        );
1968    }
1969
    /// An agent write and a user edit to the same buffer must both survive.
    ///
    /// The fake agent reads `/tmp/foo`, signals the test, then writes the file
    /// with two extra lines. After the signal, the test inserts `zero` at the
    /// top of the open buffer. Both the buffer and the file on disk must end up
    /// containing the user's line and the agent's lines.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Lets the fake agent tell the test it has finished reading the file.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        // The handler is `Fn`, so the one-shot sender lives in a shared cell and
        // is taken the single time it is used.
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    // Overwrite the file with the original contents plus two lines.
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        // Open the file in a buffer so the test can edit it as the user would.
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        // Only edit once the agent has read the original contents.
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        // The buffer holds both the user's and the agent's changes...
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        // ...and so does the file on disk.
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
2048
    /// A tool call marked `Canceled` by `cancel` can still be moved to
    /// `Completed` by a later `ToolCallUpdate` from the agent.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        // The fake agent reports one in-progress Fetch tool call, then ends its turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // entries[1] is expected to be the tool call (entries[0] presumably
        // holds the user's prompt).
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::InProgress,
                    ..
                })
            ));
        });

        // Cancelling the turn marks the in-progress tool call as canceled.
        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // A late status update from the agent is still applied to the canceled call.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Completed,
                    ..
                })
            ));
        });
    }
2148
2149    #[gpui::test]
2150    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
2151        init_test(cx);
2152        let fs = FakeFs::new(cx.background_executor.clone());
2153        fs.insert_tree(path!("/test"), json!({})).await;
2154        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
2155
2156        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2157            move |_, thread, mut cx| {
2158                async move {
2159                    thread
2160                        .update(&mut cx, |thread, cx| {
2161                            thread.handle_session_update(
2162                                acp::SessionUpdate::ToolCall(acp::ToolCall {
2163                                    id: acp::ToolCallId("test".into()),
2164                                    title: "Label".into(),
2165                                    kind: acp::ToolKind::Edit,
2166                                    status: acp::ToolCallStatus::Completed,
2167                                    content: vec![acp::ToolCallContent::Diff {
2168                                        diff: acp::Diff {
2169                                            path: "/test/test.txt".into(),
2170                                            old_text: None,
2171                                            new_text: "foo".into(),
2172                                        },
2173                                    }],
2174                                    locations: vec![],
2175                                    raw_input: None,
2176                                    raw_output: None,
2177                                }),
2178                                cx,
2179                            )
2180                        })
2181                        .unwrap()
2182                        .unwrap();
2183                    Ok(acp::PromptResponse {
2184                        stop_reason: acp::StopReason::EndTurn,
2185                    })
2186                }
2187                .boxed_local()
2188            }
2189        }));
2190
2191        let thread = cx
2192            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2193            .await
2194            .unwrap();
2195
2196        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
2197            .await
2198            .unwrap();
2199
2200        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
2201    }
2202
    /// Checkpoint behavior across a conversation:
    /// - a user message whose turn changed files is shown as `(checkpoint)`;
    /// - a turn that changed nothing stores no checkpoint;
    /// - `rewind` truncates the thread at a message and restores the files
    ///   from that message's checkpoint.
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        // An empty `.git` directory; presumably required so the git-backed
        // checkpoints can be taken for `/test`.
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        // While `simulate_changes` is set, each prompt creates a fresh empty
        // file `/test/file-N`, with N taken from `next_filename`.
        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    if simulate_changes.load(SeqCst) {
                        // NOTE(review): built without `path!`, unlike the
                        // `fs.files()` assertions below; may not match on
                        // non-Unix platforms — verify.
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    // The agent's reply is the prompt text upper-cased.
                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk {
                                    content: content.text.to_uppercase().into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // First turn creates file-0, so the message gets a checkpoint.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);

        // Second turn creates file-1 and also gets a checkpoint.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                    ## User

                    dolor

                    ## Assistant

                    DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        // entries[2] is the "ipsum" message: it and everything after it are
        // dropped, and file-1 (created during its turn) disappears.
        thread
            .update(cx, |thread, cx| {
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.rewind(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
    }
2382
2383    #[gpui::test]
2384    async fn test_refusal(cx: &mut TestAppContext) {
2385        init_test(cx);
2386        let fs = FakeFs::new(cx.background_executor.clone());
2387        fs.insert_tree(path!("/"), json!({})).await;
2388        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
2389
2390        let refuse_next = Arc::new(AtomicBool::new(false));
2391        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2392            let refuse_next = refuse_next.clone();
2393            move |request, thread, mut cx| {
2394                let refuse_next = refuse_next.clone();
2395                async move {
2396                    if refuse_next.load(SeqCst) {
2397                        return Ok(acp::PromptResponse {
2398                            stop_reason: acp::StopReason::Refusal,
2399                        });
2400                    }
2401
2402                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2403                        panic!("expected text content block");
2404                    };
2405                    thread.update(&mut cx, |thread, cx| {
2406                        thread
2407                            .handle_session_update(
2408                                acp::SessionUpdate::AgentMessageChunk {
2409                                    content: content.text.to_uppercase().into(),
2410                                },
2411                                cx,
2412                            )
2413                            .unwrap();
2414                    })?;
2415                    Ok(acp::PromptResponse {
2416                        stop_reason: acp::StopReason::EndTurn,
2417                    })
2418                }
2419                .boxed_local()
2420            }
2421        }));
2422        let thread = cx
2423            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2424            .await
2425            .unwrap();
2426
2427        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
2428            .await
2429            .unwrap();
2430        thread.read_with(cx, |thread, cx| {
2431            assert_eq!(
2432                thread.to_markdown(cx),
2433                indoc! {"
2434                    ## User
2435
2436                    hello
2437
2438                    ## Assistant
2439
2440                    HELLO
2441
2442                "}
2443            );
2444        });
2445
2446        // Simulate refusing the second message, ensuring the conversation gets
2447        // truncated to before sending it.
2448        refuse_next.store(true, SeqCst);
2449        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
2450            .await
2451            .unwrap();
2452        thread.read_with(cx, |thread, cx| {
2453            assert_eq!(
2454                thread.to_markdown(cx),
2455                indoc! {"
2456                    ## User
2457
2458                    hello
2459
2460                    ## Assistant
2461
2462                    HELLO
2463
2464                "}
2465            );
2466        });
2467    }
2468
2469    async fn run_until_first_tool_call(
2470        thread: &Entity<AcpThread>,
2471        cx: &mut TestAppContext,
2472    ) -> usize {
2473        let (mut tx, mut rx) = mpsc::channel::<usize>(1);
2474
2475        let subscription = cx.update(|cx| {
2476            cx.subscribe(thread, move |thread, _, cx| {
2477                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
2478                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
2479                        return tx.try_send(ix).unwrap();
2480                    }
2481                }
2482            })
2483        });
2484
2485        select! {
2486            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
2487                panic!("Timeout waiting for tool call")
2488            }
2489            ix = rx.next().fuse() => {
2490                drop(subscription);
2491                ix.unwrap()
2492            }
2493        }
2494    }
2495
2496    #[derive(Clone, Default)]
2497    struct FakeAgentConnection {
2498        auth_methods: Vec<acp::AuthMethod>,
2499        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
2500        on_user_message: Option<
2501            Rc<
2502                dyn Fn(
2503                        acp::PromptRequest,
2504                        WeakEntity<AcpThread>,
2505                        AsyncApp,
2506                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2507                    + 'static,
2508            >,
2509        >,
2510    }
2511
2512    impl FakeAgentConnection {
2513        fn new() -> Self {
2514            Self {
2515                auth_methods: Vec::new(),
2516                on_user_message: None,
2517                sessions: Arc::default(),
2518            }
2519        }
2520
2521        #[expect(unused)]
2522        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
2523            self.auth_methods = auth_methods;
2524            self
2525        }
2526
2527        fn on_user_message(
2528            mut self,
2529            handler: impl Fn(
2530                acp::PromptRequest,
2531                WeakEntity<AcpThread>,
2532                AsyncApp,
2533            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2534            + 'static,
2535        ) -> Self {
2536            self.on_user_message.replace(Rc::new(handler));
2537            self
2538        }
2539    }
2540
2541    impl AgentConnection for FakeAgentConnection {
2542        fn auth_methods(&self) -> &[acp::AuthMethod] {
2543            &self.auth_methods
2544        }
2545
2546        fn new_thread(
2547            self: Rc<Self>,
2548            project: Entity<Project>,
2549            _cwd: &Path,
2550            cx: &mut App,
2551        ) -> Task<gpui::Result<Entity<AcpThread>>> {
2552            let session_id = acp::SessionId(
2553                rand::thread_rng()
2554                    .sample_iter(&rand::distributions::Alphanumeric)
2555                    .take(7)
2556                    .map(char::from)
2557                    .collect::<String>()
2558                    .into(),
2559            );
2560            let action_log = cx.new(|_| ActionLog::new(project.clone()));
2561            let thread = cx.new(|_cx| {
2562                AcpThread::new(
2563                    "Test",
2564                    self.clone(),
2565                    project,
2566                    action_log,
2567                    session_id.clone(),
2568                )
2569            });
2570            self.sessions.lock().insert(session_id, thread.downgrade());
2571            Task::ready(Ok(thread))
2572        }
2573
2574        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
2575            if self.auth_methods().iter().any(|m| m.id == method) {
2576                Task::ready(Ok(()))
2577            } else {
2578                Task::ready(Err(anyhow!("Invalid Auth Method")))
2579            }
2580        }
2581
2582        fn prompt(
2583            &self,
2584            _id: Option<UserMessageId>,
2585            params: acp::PromptRequest,
2586            cx: &mut App,
2587        ) -> Task<gpui::Result<acp::PromptResponse>> {
2588            let sessions = self.sessions.lock();
2589            let thread = sessions.get(&params.session_id).unwrap();
2590            if let Some(handler) = &self.on_user_message {
2591                let handler = handler.clone();
2592                let thread = thread.clone();
2593                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
2594            } else {
2595                Task::ready(Ok(acp::PromptResponse {
2596                    stop_reason: acp::StopReason::EndTurn,
2597                }))
2598            }
2599        }
2600
2601        fn prompt_capabilities(&self) -> acp::PromptCapabilities {
2602            acp::PromptCapabilities {
2603                image: true,
2604                audio: true,
2605                embedded_context: true,
2606            }
2607        }
2608
2609        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
2610            let sessions = self.sessions.lock();
2611            let thread = sessions.get(session_id).unwrap().clone();
2612
2613            cx.spawn(async move |cx| {
2614                thread
2615                    .update(cx, |thread, cx| thread.cancel(cx))
2616                    .unwrap()
2617                    .await
2618            })
2619            .detach();
2620        }
2621
2622        fn session_editor(
2623            &self,
2624            session_id: &acp::SessionId,
2625            _cx: &mut App,
2626        ) -> Option<Rc<dyn AgentSessionEditor>> {
2627            Some(Rc::new(FakeAgentSessionEditor {
2628                _session_id: session_id.clone(),
2629            }))
2630        }
2631
2632        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
2633            self
2634        }
2635    }
2636
2637    struct FakeAgentSessionEditor {
2638        _session_id: acp::SessionId,
2639    }
2640
2641    impl AgentSessionEditor for FakeAgentSessionEditor {
2642        fn truncate(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
2643            Task::ready(Ok(()))
2644        }
2645    }
2646}