acp_thread.rs

   1mod connection;
   2mod diff;
   3mod mention;
   4mod terminal;
   5
   6use collections::HashSet;
   7pub use connection::*;
   8pub use diff::*;
   9use language::language_settings::FormatOnSave;
  10pub use mention::*;
  11use project::lsp_store::{FormatTrigger, LspFormatTarget};
  12use serde::{Deserialize, Serialize};
  13pub use terminal::*;
  14
  15use action_log::ActionLog;
  16use agent_client_protocol as acp;
  17use anyhow::{Context as _, Result, anyhow};
  18use editor::Bias;
  19use futures::{FutureExt, channel::oneshot, future::BoxFuture};
  20use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  21use itertools::Itertools;
  22use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
  23use markdown::Markdown;
  24use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
  25use std::collections::HashMap;
  26use std::error::Error;
  27use std::fmt::{Formatter, Write};
  28use std::ops::Range;
  29use std::process::ExitStatus;
  30use std::rc::Rc;
  31use std::time::{Duration, Instant};
  32use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
  33use ui::App;
  34use util::ResultExt;
  35
  36#[derive(Debug)]
  37pub struct UserMessage {
  38    pub id: Option<UserMessageId>,
  39    pub content: ContentBlock,
  40    pub chunks: Vec<acp::ContentBlock>,
  41    pub checkpoint: Option<Checkpoint>,
  42}
  43
  44#[derive(Debug)]
  45pub struct Checkpoint {
  46    git_checkpoint: GitStoreCheckpoint,
  47    pub show: bool,
  48}
  49
  50impl UserMessage {
  51    fn to_markdown(&self, cx: &App) -> String {
  52        let mut markdown = String::new();
  53        if self
  54            .checkpoint
  55            .as_ref()
  56            .is_some_and(|checkpoint| checkpoint.show)
  57        {
  58            writeln!(markdown, "## User (checkpoint)").unwrap();
  59        } else {
  60            writeln!(markdown, "## User").unwrap();
  61        }
  62        writeln!(markdown).unwrap();
  63        writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
  64        writeln!(markdown).unwrap();
  65        markdown
  66    }
  67}
  68
  69#[derive(Debug, PartialEq)]
  70pub struct AssistantMessage {
  71    pub chunks: Vec<AssistantMessageChunk>,
  72}
  73
  74impl AssistantMessage {
  75    pub fn to_markdown(&self, cx: &App) -> String {
  76        format!(
  77            "## Assistant\n\n{}\n\n",
  78            self.chunks
  79                .iter()
  80                .map(|chunk| chunk.to_markdown(cx))
  81                .join("\n\n")
  82        )
  83    }
  84}
  85
  86#[derive(Debug, PartialEq)]
  87pub enum AssistantMessageChunk {
  88    Message { block: ContentBlock },
  89    Thought { block: ContentBlock },
  90}
  91
  92impl AssistantMessageChunk {
  93    pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
  94        Self::Message {
  95            block: ContentBlock::new(chunk.into(), language_registry, cx),
  96        }
  97    }
  98
  99    fn to_markdown(&self, cx: &App) -> String {
 100        match self {
 101            Self::Message { block } => block.to_markdown(cx).to_string(),
 102            Self::Thought { block } => {
 103                format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
 104            }
 105        }
 106    }
 107}
 108
 109#[derive(Debug)]
 110pub enum AgentThreadEntry {
 111    UserMessage(UserMessage),
 112    AssistantMessage(AssistantMessage),
 113    ToolCall(ToolCall),
 114}
 115
 116impl AgentThreadEntry {
 117    pub fn to_markdown(&self, cx: &App) -> String {
 118        match self {
 119            Self::UserMessage(message) => message.to_markdown(cx),
 120            Self::AssistantMessage(message) => message.to_markdown(cx),
 121            Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
 122        }
 123    }
 124
 125    pub fn user_message(&self) -> Option<&UserMessage> {
 126        if let AgentThreadEntry::UserMessage(message) = self {
 127            Some(message)
 128        } else {
 129            None
 130        }
 131    }
 132
 133    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 134        if let AgentThreadEntry::ToolCall(call) = self {
 135            itertools::Either::Left(call.diffs())
 136        } else {
 137            itertools::Either::Right(std::iter::empty())
 138        }
 139    }
 140
 141    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 142        if let AgentThreadEntry::ToolCall(call) = self {
 143            itertools::Either::Left(call.terminals())
 144        } else {
 145            itertools::Either::Right(std::iter::empty())
 146        }
 147    }
 148
 149    pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
 150        if let AgentThreadEntry::ToolCall(ToolCall {
 151            locations,
 152            resolved_locations,
 153            ..
 154        }) = self
 155        {
 156            Some((
 157                locations.get(ix)?.clone(),
 158                resolved_locations.get(ix)?.clone()?,
 159            ))
 160        } else {
 161            None
 162        }
 163    }
 164}
 165
 166#[derive(Debug)]
 167pub struct ToolCall {
 168    pub id: acp::ToolCallId,
 169    pub label: Entity<Markdown>,
 170    pub kind: acp::ToolKind,
 171    pub content: Vec<ToolCallContent>,
 172    pub status: ToolCallStatus,
 173    pub locations: Vec<acp::ToolCallLocation>,
 174    pub resolved_locations: Vec<Option<AgentLocation>>,
 175    pub raw_input: Option<serde_json::Value>,
 176    pub raw_output: Option<serde_json::Value>,
 177}
 178
 179impl ToolCall {
 180    fn from_acp(
 181        tool_call: acp::ToolCall,
 182        status: ToolCallStatus,
 183        language_registry: Arc<LanguageRegistry>,
 184        cx: &mut App,
 185    ) -> Self {
 186        Self {
 187            id: tool_call.id,
 188            label: cx.new(|cx| {
 189                Markdown::new(
 190                    tool_call.title.into(),
 191                    Some(language_registry.clone()),
 192                    None,
 193                    cx,
 194                )
 195            }),
 196            kind: tool_call.kind,
 197            content: tool_call
 198                .content
 199                .into_iter()
 200                .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
 201                .collect(),
 202            locations: tool_call.locations,
 203            resolved_locations: Vec::default(),
 204            status,
 205            raw_input: tool_call.raw_input,
 206            raw_output: tool_call.raw_output,
 207        }
 208    }
 209
 210    fn update_fields(
 211        &mut self,
 212        fields: acp::ToolCallUpdateFields,
 213        language_registry: Arc<LanguageRegistry>,
 214        cx: &mut App,
 215    ) {
 216        let acp::ToolCallUpdateFields {
 217            kind,
 218            status,
 219            title,
 220            content,
 221            locations,
 222            raw_input,
 223            raw_output,
 224        } = fields;
 225
 226        if let Some(kind) = kind {
 227            self.kind = kind;
 228        }
 229
 230        if let Some(status) = status {
 231            self.status = status.into();
 232        }
 233
 234        if let Some(title) = title {
 235            self.label.update(cx, |label, cx| {
 236                label.replace(title, cx);
 237            });
 238        }
 239
 240        if let Some(content) = content {
 241            self.content = content
 242                .into_iter()
 243                .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
 244                .collect();
 245        }
 246
 247        if let Some(locations) = locations {
 248            self.locations = locations;
 249        }
 250
 251        if let Some(raw_input) = raw_input {
 252            self.raw_input = Some(raw_input);
 253        }
 254
 255        if let Some(raw_output) = raw_output {
 256            if self.content.is_empty()
 257                && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
 258            {
 259                self.content
 260                    .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
 261                        markdown,
 262                    }));
 263            }
 264            self.raw_output = Some(raw_output);
 265        }
 266    }
 267
 268    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 269        self.content.iter().filter_map(|content| match content {
 270            ToolCallContent::Diff(diff) => Some(diff),
 271            ToolCallContent::ContentBlock(_) => None,
 272            ToolCallContent::Terminal(_) => None,
 273        })
 274    }
 275
 276    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 277        self.content.iter().filter_map(|content| match content {
 278            ToolCallContent::Terminal(terminal) => Some(terminal),
 279            ToolCallContent::ContentBlock(_) => None,
 280            ToolCallContent::Diff(_) => None,
 281        })
 282    }
 283
 284    fn to_markdown(&self, cx: &App) -> String {
 285        let mut markdown = format!(
 286            "**Tool Call: {}**\nStatus: {}\n\n",
 287            self.label.read(cx).source(),
 288            self.status
 289        );
 290        for content in &self.content {
 291            markdown.push_str(content.to_markdown(cx).as_str());
 292            markdown.push_str("\n\n");
 293        }
 294        markdown
 295    }
 296
 297    async fn resolve_location(
 298        location: acp::ToolCallLocation,
 299        project: WeakEntity<Project>,
 300        cx: &mut AsyncApp,
 301    ) -> Option<AgentLocation> {
 302        let buffer = project
 303            .update(cx, |project, cx| {
 304                if let Some(path) = project.project_path_for_absolute_path(&location.path, cx) {
 305                    Some(project.open_buffer(path, cx))
 306                } else {
 307                    None
 308                }
 309            })
 310            .ok()??;
 311        let buffer = buffer.await.log_err()?;
 312        let position = buffer
 313            .update(cx, |buffer, _| {
 314                if let Some(row) = location.line {
 315                    let snapshot = buffer.snapshot();
 316                    let column = snapshot.indent_size_for_line(row).len;
 317                    let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
 318                    snapshot.anchor_before(point)
 319                } else {
 320                    Anchor::MIN
 321                }
 322            })
 323            .ok()?;
 324
 325        Some(AgentLocation {
 326            buffer: buffer.downgrade(),
 327            position,
 328        })
 329    }
 330
 331    fn resolve_locations(
 332        &self,
 333        project: Entity<Project>,
 334        cx: &mut App,
 335    ) -> Task<Vec<Option<AgentLocation>>> {
 336        let locations = self.locations.clone();
 337        project.update(cx, |_, cx| {
 338            cx.spawn(async move |project, cx| {
 339                let mut new_locations = Vec::new();
 340                for location in locations {
 341                    new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
 342                }
 343                new_locations
 344            })
 345        })
 346    }
 347}
 348
 349#[derive(Debug)]
 350pub enum ToolCallStatus {
 351    /// The tool call hasn't started running yet, but we start showing it to
 352    /// the user.
 353    Pending,
 354    /// The tool call is waiting for confirmation from the user.
 355    WaitingForConfirmation {
 356        options: Vec<acp::PermissionOption>,
 357        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
 358    },
 359    /// The tool call is currently running.
 360    InProgress,
 361    /// The tool call completed successfully.
 362    Completed,
 363    /// The tool call failed.
 364    Failed,
 365    /// The user rejected the tool call.
 366    Rejected,
 367    /// The user canceled generation so the tool call was canceled.
 368    Canceled,
 369}
 370
 371impl From<acp::ToolCallStatus> for ToolCallStatus {
 372    fn from(status: acp::ToolCallStatus) -> Self {
 373        match status {
 374            acp::ToolCallStatus::Pending => Self::Pending,
 375            acp::ToolCallStatus::InProgress => Self::InProgress,
 376            acp::ToolCallStatus::Completed => Self::Completed,
 377            acp::ToolCallStatus::Failed => Self::Failed,
 378        }
 379    }
 380}
 381
 382impl Display for ToolCallStatus {
 383    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 384        write!(
 385            f,
 386            "{}",
 387            match self {
 388                ToolCallStatus::Pending => "Pending",
 389                ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
 390                ToolCallStatus::InProgress => "In Progress",
 391                ToolCallStatus::Completed => "Completed",
 392                ToolCallStatus::Failed => "Failed",
 393                ToolCallStatus::Rejected => "Rejected",
 394                ToolCallStatus::Canceled => "Canceled",
 395            }
 396        )
 397    }
 398}
 399
 400#[derive(Debug, PartialEq, Clone)]
 401pub enum ContentBlock {
 402    Empty,
 403    Markdown { markdown: Entity<Markdown> },
 404    ResourceLink { resource_link: acp::ResourceLink },
 405}
 406
 407impl ContentBlock {
 408    pub fn new(
 409        block: acp::ContentBlock,
 410        language_registry: &Arc<LanguageRegistry>,
 411        cx: &mut App,
 412    ) -> Self {
 413        let mut this = Self::Empty;
 414        this.append(block, language_registry, cx);
 415        this
 416    }
 417
 418    pub fn new_combined(
 419        blocks: impl IntoIterator<Item = acp::ContentBlock>,
 420        language_registry: Arc<LanguageRegistry>,
 421        cx: &mut App,
 422    ) -> Self {
 423        let mut this = Self::Empty;
 424        for block in blocks {
 425            this.append(block, &language_registry, cx);
 426        }
 427        this
 428    }
 429
 430    pub fn append(
 431        &mut self,
 432        block: acp::ContentBlock,
 433        language_registry: &Arc<LanguageRegistry>,
 434        cx: &mut App,
 435    ) {
 436        if matches!(self, ContentBlock::Empty)
 437            && let acp::ContentBlock::ResourceLink(resource_link) = block
 438        {
 439            *self = ContentBlock::ResourceLink { resource_link };
 440            return;
 441        }
 442
 443        let new_content = self.block_string_contents(block);
 444
 445        match self {
 446            ContentBlock::Empty => {
 447                *self = Self::create_markdown_block(new_content, language_registry, cx);
 448            }
 449            ContentBlock::Markdown { markdown } => {
 450                markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
 451            }
 452            ContentBlock::ResourceLink { resource_link } => {
 453                let existing_content = Self::resource_link_md(&resource_link.uri);
 454                let combined = format!("{}\n{}", existing_content, new_content);
 455
 456                *self = Self::create_markdown_block(combined, language_registry, cx);
 457            }
 458        }
 459    }
 460
 461    fn create_markdown_block(
 462        content: String,
 463        language_registry: &Arc<LanguageRegistry>,
 464        cx: &mut App,
 465    ) -> ContentBlock {
 466        ContentBlock::Markdown {
 467            markdown: cx
 468                .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
 469        }
 470    }
 471
 472    fn block_string_contents(&self, block: acp::ContentBlock) -> String {
 473        match block {
 474            acp::ContentBlock::Text(text_content) => text_content.text.clone(),
 475            acp::ContentBlock::ResourceLink(resource_link) => {
 476                Self::resource_link_md(&resource_link.uri)
 477            }
 478            acp::ContentBlock::Resource(acp::EmbeddedResource {
 479                resource:
 480                    acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
 481                        uri,
 482                        ..
 483                    }),
 484                ..
 485            }) => Self::resource_link_md(&uri),
 486            acp::ContentBlock::Image(image) => Self::image_md(&image),
 487            acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
 488        }
 489    }
 490
 491    fn resource_link_md(uri: &str) -> String {
 492        if let Some(uri) = MentionUri::parse(uri).log_err() {
 493            uri.as_link().to_string()
 494        } else {
 495            uri.to_string()
 496        }
 497    }
 498
 499    fn image_md(_image: &acp::ImageContent) -> String {
 500        "`Image`".into()
 501    }
 502
 503    fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
 504        match self {
 505            ContentBlock::Empty => "",
 506            ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
 507            ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
 508        }
 509    }
 510
 511    pub fn markdown(&self) -> Option<&Entity<Markdown>> {
 512        match self {
 513            ContentBlock::Empty => None,
 514            ContentBlock::Markdown { markdown } => Some(markdown),
 515            ContentBlock::ResourceLink { .. } => None,
 516        }
 517    }
 518
 519    pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
 520        match self {
 521            ContentBlock::ResourceLink { resource_link } => Some(resource_link),
 522            _ => None,
 523        }
 524    }
 525}
 526
 527#[derive(Debug)]
 528pub enum ToolCallContent {
 529    ContentBlock(ContentBlock),
 530    Diff(Entity<Diff>),
 531    Terminal(Entity<Terminal>),
 532}
 533
 534impl ToolCallContent {
 535    pub fn from_acp(
 536        content: acp::ToolCallContent,
 537        language_registry: Arc<LanguageRegistry>,
 538        cx: &mut App,
 539    ) -> Self {
 540        match content {
 541            acp::ToolCallContent::Content { content } => {
 542                Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
 543            }
 544            acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
 545                Diff::finalized(
 546                    diff.path,
 547                    diff.old_text,
 548                    diff.new_text,
 549                    language_registry,
 550                    cx,
 551                )
 552            })),
 553        }
 554    }
 555
 556    pub fn to_markdown(&self, cx: &App) -> String {
 557        match self {
 558            Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
 559            Self::Diff(diff) => diff.read(cx).to_markdown(cx),
 560            Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
 561        }
 562    }
 563}
 564
 565#[derive(Debug, PartialEq)]
 566pub enum ToolCallUpdate {
 567    UpdateFields(acp::ToolCallUpdate),
 568    UpdateDiff(ToolCallUpdateDiff),
 569    UpdateTerminal(ToolCallUpdateTerminal),
 570}
 571
 572impl ToolCallUpdate {
 573    fn id(&self) -> &acp::ToolCallId {
 574        match self {
 575            Self::UpdateFields(update) => &update.id,
 576            Self::UpdateDiff(diff) => &diff.id,
 577            Self::UpdateTerminal(terminal) => &terminal.id,
 578        }
 579    }
 580}
 581
 582impl From<acp::ToolCallUpdate> for ToolCallUpdate {
 583    fn from(update: acp::ToolCallUpdate) -> Self {
 584        Self::UpdateFields(update)
 585    }
 586}
 587
 588impl From<ToolCallUpdateDiff> for ToolCallUpdate {
 589    fn from(diff: ToolCallUpdateDiff) -> Self {
 590        Self::UpdateDiff(diff)
 591    }
 592}
 593
 594#[derive(Debug, PartialEq)]
 595pub struct ToolCallUpdateDiff {
 596    pub id: acp::ToolCallId,
 597    pub diff: Entity<Diff>,
 598}
 599
 600impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
 601    fn from(terminal: ToolCallUpdateTerminal) -> Self {
 602        Self::UpdateTerminal(terminal)
 603    }
 604}
 605
 606#[derive(Debug, PartialEq)]
 607pub struct ToolCallUpdateTerminal {
 608    pub id: acp::ToolCallId,
 609    pub terminal: Entity<Terminal>,
 610}
 611
 612#[derive(Debug, Default)]
 613pub struct Plan {
 614    pub entries: Vec<PlanEntry>,
 615}
 616
 617#[derive(Debug)]
 618pub struct PlanStats<'a> {
 619    pub in_progress_entry: Option<&'a PlanEntry>,
 620    pub pending: u32,
 621    pub completed: u32,
 622}
 623
 624impl Plan {
 625    pub fn is_empty(&self) -> bool {
 626        self.entries.is_empty()
 627    }
 628
 629    pub fn stats(&self) -> PlanStats<'_> {
 630        let mut stats = PlanStats {
 631            in_progress_entry: None,
 632            pending: 0,
 633            completed: 0,
 634        };
 635
 636        for entry in &self.entries {
 637            match &entry.status {
 638                acp::PlanEntryStatus::Pending => {
 639                    stats.pending += 1;
 640                }
 641                acp::PlanEntryStatus::InProgress => {
 642                    stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
 643                }
 644                acp::PlanEntryStatus::Completed => {
 645                    stats.completed += 1;
 646                }
 647            }
 648        }
 649
 650        stats
 651    }
 652}
 653
 654#[derive(Debug)]
 655pub struct PlanEntry {
 656    pub content: Entity<Markdown>,
 657    pub priority: acp::PlanEntryPriority,
 658    pub status: acp::PlanEntryStatus,
 659}
 660
 661impl PlanEntry {
 662    pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
 663        Self {
 664            content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
 665            priority: entry.priority,
 666            status: entry.status,
 667        }
 668    }
 669}
 670
 671#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 672pub struct TokenUsage {
 673    pub max_tokens: u64,
 674    pub used_tokens: u64,
 675}
 676
 677impl TokenUsage {
 678    pub fn ratio(&self) -> TokenUsageRatio {
 679        #[cfg(debug_assertions)]
 680        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 681            .unwrap_or("0.8".to_string())
 682            .parse()
 683            .unwrap();
 684        #[cfg(not(debug_assertions))]
 685        let warning_threshold: f32 = 0.8;
 686
 687        // When the maximum is unknown because there is no selected model,
 688        // avoid showing the token limit warning.
 689        if self.max_tokens == 0 {
 690            TokenUsageRatio::Normal
 691        } else if self.used_tokens >= self.max_tokens {
 692            TokenUsageRatio::Exceeded
 693        } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
 694            TokenUsageRatio::Warning
 695        } else {
 696            TokenUsageRatio::Normal
 697        }
 698    }
 699}
 700
 701#[derive(Debug, Clone, PartialEq, Eq)]
 702pub enum TokenUsageRatio {
 703    Normal,
 704    Warning,
 705    Exceeded,
 706}
 707
 708#[derive(Debug, Clone)]
 709pub struct RetryStatus {
 710    pub last_error: SharedString,
 711    pub attempt: usize,
 712    pub max_attempts: usize,
 713    pub started_at: Instant,
 714    pub duration: Duration,
 715}
 716
 717pub struct AcpThread {
 718    title: SharedString,
 719    entries: Vec<AgentThreadEntry>,
 720    plan: Plan,
 721    project: Entity<Project>,
 722    action_log: Entity<ActionLog>,
 723    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
 724    send_task: Option<Task<()>>,
 725    connection: Rc<dyn AgentConnection>,
 726    session_id: acp::SessionId,
 727    token_usage: Option<TokenUsage>,
 728}
 729
 730#[derive(Debug)]
 731pub enum AcpThreadEvent {
 732    NewEntry,
 733    TitleUpdated,
 734    TokenUsageUpdated,
 735    EntryUpdated(usize),
 736    EntriesRemoved(Range<usize>),
 737    ToolAuthorizationRequired,
 738    Retry(RetryStatus),
 739    Stopped,
 740    Error,
 741    LoadError(LoadError),
 742}
 743
 744impl EventEmitter<AcpThreadEvent> for AcpThread {}
 745
 746#[derive(PartialEq, Eq)]
 747pub enum ThreadStatus {
 748    Idle,
 749    WaitingForToolConfirmation,
 750    Generating,
 751}
 752
 753#[derive(Debug, Clone)]
 754pub enum LoadError {
 755    NotInstalled {
 756        error_message: SharedString,
 757        install_message: SharedString,
 758        install_command: String,
 759    },
 760    Unsupported {
 761        error_message: SharedString,
 762        upgrade_message: SharedString,
 763        upgrade_command: String,
 764    },
 765    Exited {
 766        status: ExitStatus,
 767    },
 768    Other(SharedString),
 769}
 770
 771impl Display for LoadError {
 772    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 773        match self {
 774            LoadError::NotInstalled { error_message, .. }
 775            | LoadError::Unsupported { error_message, .. } => {
 776                write!(f, "{error_message}")
 777            }
 778            LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
 779            LoadError::Other(msg) => write!(f, "{}", msg),
 780        }
 781    }
 782}
 783
 784impl Error for LoadError {}
 785
 786impl AcpThread {
 787    pub fn new(
 788        title: impl Into<SharedString>,
 789        connection: Rc<dyn AgentConnection>,
 790        project: Entity<Project>,
 791        action_log: Entity<ActionLog>,
 792        session_id: acp::SessionId,
 793    ) -> Self {
 794        Self {
 795            action_log,
 796            shared_buffers: Default::default(),
 797            entries: Default::default(),
 798            plan: Default::default(),
 799            title: title.into(),
 800            project,
 801            send_task: None,
 802            connection,
 803            session_id,
 804            token_usage: None,
 805        }
 806    }
 807
 808    pub fn connection(&self) -> &Rc<dyn AgentConnection> {
 809        &self.connection
 810    }
 811
 812    pub fn action_log(&self) -> &Entity<ActionLog> {
 813        &self.action_log
 814    }
 815
 816    pub fn project(&self) -> &Entity<Project> {
 817        &self.project
 818    }
 819
 820    pub fn title(&self) -> SharedString {
 821        self.title.clone()
 822    }
 823
 824    pub fn entries(&self) -> &[AgentThreadEntry] {
 825        &self.entries
 826    }
 827
 828    pub fn session_id(&self) -> &acp::SessionId {
 829        &self.session_id
 830    }
 831
 832    pub fn status(&self) -> ThreadStatus {
 833        if self.send_task.is_some() {
 834            if self.waiting_for_tool_confirmation() {
 835                ThreadStatus::WaitingForToolConfirmation
 836            } else {
 837                ThreadStatus::Generating
 838            }
 839        } else {
 840            ThreadStatus::Idle
 841        }
 842    }
 843
 844    pub fn token_usage(&self) -> Option<&TokenUsage> {
 845        self.token_usage.as_ref()
 846    }
 847
 848    pub fn has_pending_edit_tool_calls(&self) -> bool {
 849        for entry in self.entries.iter().rev() {
 850            match entry {
 851                AgentThreadEntry::UserMessage(_) => return false,
 852                AgentThreadEntry::ToolCall(
 853                    call @ ToolCall {
 854                        status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
 855                        ..
 856                    },
 857                ) if call.diffs().next().is_some() => {
 858                    return true;
 859                }
 860                AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
 861            }
 862        }
 863
 864        false
 865    }
 866
 867    pub fn used_tools_since_last_user_message(&self) -> bool {
 868        for entry in self.entries.iter().rev() {
 869            match entry {
 870                AgentThreadEntry::UserMessage(..) => return false,
 871                AgentThreadEntry::AssistantMessage(..) => continue,
 872                AgentThreadEntry::ToolCall(..) => return true,
 873            }
 874        }
 875
 876        false
 877    }
 878
 879    pub fn handle_session_update(
 880        &mut self,
 881        update: acp::SessionUpdate,
 882        cx: &mut Context<Self>,
 883    ) -> Result<(), acp::Error> {
 884        match update {
 885            acp::SessionUpdate::UserMessageChunk { content } => {
 886                self.push_user_content_block(None, content, cx);
 887            }
 888            acp::SessionUpdate::AgentMessageChunk { content } => {
 889                self.push_assistant_content_block(content, false, cx);
 890            }
 891            acp::SessionUpdate::AgentThoughtChunk { content } => {
 892                self.push_assistant_content_block(content, true, cx);
 893            }
 894            acp::SessionUpdate::ToolCall(tool_call) => {
 895                self.upsert_tool_call(tool_call, cx)?;
 896            }
 897            acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
 898                self.update_tool_call(tool_call_update, cx)?;
 899            }
 900            acp::SessionUpdate::Plan(plan) => {
 901                self.update_plan(plan, cx);
 902            }
 903        }
 904        Ok(())
 905    }
 906
 907    pub fn push_user_content_block(
 908        &mut self,
 909        message_id: Option<UserMessageId>,
 910        chunk: acp::ContentBlock,
 911        cx: &mut Context<Self>,
 912    ) {
 913        let language_registry = self.project.read(cx).languages().clone();
 914        let entries_len = self.entries.len();
 915
 916        if let Some(last_entry) = self.entries.last_mut()
 917            && let AgentThreadEntry::UserMessage(UserMessage {
 918                id,
 919                content,
 920                chunks,
 921                ..
 922            }) = last_entry
 923        {
 924            *id = message_id.or(id.take());
 925            content.append(chunk.clone(), &language_registry, cx);
 926            chunks.push(chunk);
 927            let idx = entries_len - 1;
 928            cx.emit(AcpThreadEvent::EntryUpdated(idx));
 929        } else {
 930            let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
 931            self.push_entry(
 932                AgentThreadEntry::UserMessage(UserMessage {
 933                    id: message_id,
 934                    content,
 935                    chunks: vec![chunk],
 936                    checkpoint: None,
 937                }),
 938                cx,
 939            );
 940        }
 941    }
 942
 943    pub fn push_assistant_content_block(
 944        &mut self,
 945        chunk: acp::ContentBlock,
 946        is_thought: bool,
 947        cx: &mut Context<Self>,
 948    ) {
 949        let language_registry = self.project.read(cx).languages().clone();
 950        let entries_len = self.entries.len();
 951        if let Some(last_entry) = self.entries.last_mut()
 952            && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
 953        {
 954            let idx = entries_len - 1;
 955            cx.emit(AcpThreadEvent::EntryUpdated(idx));
 956            match (chunks.last_mut(), is_thought) {
 957                (Some(AssistantMessageChunk::Message { block }), false)
 958                | (Some(AssistantMessageChunk::Thought { block }), true) => {
 959                    block.append(chunk, &language_registry, cx)
 960                }
 961                _ => {
 962                    let block = ContentBlock::new(chunk, &language_registry, cx);
 963                    if is_thought {
 964                        chunks.push(AssistantMessageChunk::Thought { block })
 965                    } else {
 966                        chunks.push(AssistantMessageChunk::Message { block })
 967                    }
 968                }
 969            }
 970        } else {
 971            let block = ContentBlock::new(chunk, &language_registry, cx);
 972            let chunk = if is_thought {
 973                AssistantMessageChunk::Thought { block }
 974            } else {
 975                AssistantMessageChunk::Message { block }
 976            };
 977
 978            self.push_entry(
 979                AgentThreadEntry::AssistantMessage(AssistantMessage {
 980                    chunks: vec![chunk],
 981                }),
 982                cx,
 983            );
 984        }
 985    }
 986
 987    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
 988        self.entries.push(entry);
 989        cx.emit(AcpThreadEvent::NewEntry);
 990    }
 991
 992    pub fn update_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Result<()> {
 993        self.title = title;
 994        cx.emit(AcpThreadEvent::TitleUpdated);
 995        Ok(())
 996    }
 997
 998    pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
 999        self.token_usage = usage;
1000        cx.emit(AcpThreadEvent::TokenUsageUpdated);
1001    }
1002
1003    pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1004        cx.emit(AcpThreadEvent::Retry(status));
1005    }
1006
1007    pub fn update_tool_call(
1008        &mut self,
1009        update: impl Into<ToolCallUpdate>,
1010        cx: &mut Context<Self>,
1011    ) -> Result<()> {
1012        let update = update.into();
1013        let languages = self.project.read(cx).languages().clone();
1014
1015        let (ix, current_call) = self
1016            .tool_call_mut(update.id())
1017            .context("Tool call not found")?;
1018        match update {
1019            ToolCallUpdate::UpdateFields(update) => {
1020                let location_updated = update.fields.locations.is_some();
1021                current_call.update_fields(update.fields, languages, cx);
1022                if location_updated {
1023                    self.resolve_locations(update.id.clone(), cx);
1024                }
1025            }
1026            ToolCallUpdate::UpdateDiff(update) => {
1027                current_call.content.clear();
1028                current_call
1029                    .content
1030                    .push(ToolCallContent::Diff(update.diff));
1031            }
1032            ToolCallUpdate::UpdateTerminal(update) => {
1033                current_call.content.clear();
1034                current_call
1035                    .content
1036                    .push(ToolCallContent::Terminal(update.terminal));
1037            }
1038        }
1039
1040        cx.emit(AcpThreadEvent::EntryUpdated(ix));
1041
1042        Ok(())
1043    }
1044
1045    /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1046    pub fn upsert_tool_call(
1047        &mut self,
1048        tool_call: acp::ToolCall,
1049        cx: &mut Context<Self>,
1050    ) -> Result<(), acp::Error> {
1051        let status = tool_call.status.into();
1052        self.upsert_tool_call_inner(tool_call.into(), status, cx)
1053    }
1054
1055    /// Fails if id does not match an existing entry.
1056    pub fn upsert_tool_call_inner(
1057        &mut self,
1058        tool_call_update: acp::ToolCallUpdate,
1059        status: ToolCallStatus,
1060        cx: &mut Context<Self>,
1061    ) -> Result<(), acp::Error> {
1062        let language_registry = self.project.read(cx).languages().clone();
1063        let id = tool_call_update.id.clone();
1064
1065        if let Some((ix, current_call)) = self.tool_call_mut(&id) {
1066            current_call.update_fields(tool_call_update.fields, language_registry, cx);
1067            current_call.status = status;
1068
1069            cx.emit(AcpThreadEvent::EntryUpdated(ix));
1070        } else {
1071            let call =
1072                ToolCall::from_acp(tool_call_update.try_into()?, status, language_registry, cx);
1073            self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1074        };
1075
1076        self.resolve_locations(id, cx);
1077        Ok(())
1078    }
1079
1080    fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1081        // The tool call we are looking for is typically the last one, or very close to the end.
1082        // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1083        self.entries
1084            .iter_mut()
1085            .enumerate()
1086            .rev()
1087            .find_map(|(index, tool_call)| {
1088                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1089                    && &tool_call.id == id
1090                {
1091                    Some((index, tool_call))
1092                } else {
1093                    None
1094                }
1095            })
1096    }
1097
1098    pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1099        self.entries
1100            .iter()
1101            .enumerate()
1102            .rev()
1103            .find_map(|(index, tool_call)| {
1104                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1105                    && &tool_call.id == id
1106                {
1107                    Some((index, tool_call))
1108                } else {
1109                    None
1110                }
1111            })
1112    }
1113
1114    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1115        let project = self.project.clone();
1116        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1117            return;
1118        };
1119        let task = tool_call.resolve_locations(project, cx);
1120        cx.spawn(async move |this, cx| {
1121            let resolved_locations = task.await;
1122            this.update(cx, |this, cx| {
1123                let project = this.project.clone();
1124                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1125                    return;
1126                };
1127                if let Some(Some(location)) = resolved_locations.last() {
1128                    project.update(cx, |project, cx| {
1129                        if let Some(agent_location) = project.agent_location() {
1130                            let should_ignore = agent_location.buffer == location.buffer
1131                                && location
1132                                    .buffer
1133                                    .update(cx, |buffer, _| {
1134                                        let snapshot = buffer.snapshot();
1135                                        let old_position =
1136                                            agent_location.position.to_point(&snapshot);
1137                                        let new_position = location.position.to_point(&snapshot);
1138                                        // ignore this so that when we get updates from the edit tool
1139                                        // the position doesn't reset to the startof line
1140                                        old_position.row == new_position.row
1141                                            && old_position.column > new_position.column
1142                                    })
1143                                    .ok()
1144                                    .unwrap_or_default();
1145                            if !should_ignore {
1146                                project.set_agent_location(Some(location.clone()), cx);
1147                            }
1148                        }
1149                    });
1150                }
1151                if tool_call.resolved_locations != resolved_locations {
1152                    tool_call.resolved_locations = resolved_locations;
1153                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
1154                }
1155            })
1156        })
1157        .detach();
1158    }
1159
1160    pub fn request_tool_call_authorization(
1161        &mut self,
1162        tool_call: acp::ToolCallUpdate,
1163        options: Vec<acp::PermissionOption>,
1164        cx: &mut Context<Self>,
1165    ) -> Result<oneshot::Receiver<acp::PermissionOptionId>, acp::Error> {
1166        let (tx, rx) = oneshot::channel();
1167
1168        let status = ToolCallStatus::WaitingForConfirmation {
1169            options,
1170            respond_tx: tx,
1171        };
1172
1173        self.upsert_tool_call_inner(tool_call, status, cx)?;
1174        cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1175        Ok(rx)
1176    }
1177
1178    pub fn authorize_tool_call(
1179        &mut self,
1180        id: acp::ToolCallId,
1181        option_id: acp::PermissionOptionId,
1182        option_kind: acp::PermissionOptionKind,
1183        cx: &mut Context<Self>,
1184    ) {
1185        let Some((ix, call)) = self.tool_call_mut(&id) else {
1186            return;
1187        };
1188
1189        let new_status = match option_kind {
1190            acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1191                ToolCallStatus::Rejected
1192            }
1193            acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1194                ToolCallStatus::InProgress
1195            }
1196        };
1197
1198        let curr_status = mem::replace(&mut call.status, new_status);
1199
1200        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1201            respond_tx.send(option_id).log_err();
1202        } else if cfg!(debug_assertions) {
1203            panic!("tried to authorize an already authorized tool call");
1204        }
1205
1206        cx.emit(AcpThreadEvent::EntryUpdated(ix));
1207    }
1208
1209    /// Returns true if the last turn is awaiting tool authorization
1210    pub fn waiting_for_tool_confirmation(&self) -> bool {
1211        for entry in self.entries.iter().rev() {
1212            match &entry {
1213                AgentThreadEntry::ToolCall(call) => match call.status {
1214                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
1215                    ToolCallStatus::Pending
1216                    | ToolCallStatus::InProgress
1217                    | ToolCallStatus::Completed
1218                    | ToolCallStatus::Failed
1219                    | ToolCallStatus::Rejected
1220                    | ToolCallStatus::Canceled => continue,
1221                },
1222                AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1223                    // Reached the beginning of the turn
1224                    return false;
1225                }
1226            }
1227        }
1228        false
1229    }
1230
1231    pub fn plan(&self) -> &Plan {
1232        &self.plan
1233    }
1234
1235    pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1236        let new_entries_len = request.entries.len();
1237        let mut new_entries = request.entries.into_iter();
1238
1239        // Reuse existing markdown to prevent flickering
1240        for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1241            let PlanEntry {
1242                content,
1243                priority,
1244                status,
1245            } = old;
1246            content.update(cx, |old, cx| {
1247                old.replace(new.content, cx);
1248            });
1249            *priority = new.priority;
1250            *status = new.status;
1251        }
1252        for new in new_entries {
1253            self.plan.entries.push(PlanEntry::from_acp(new, cx))
1254        }
1255        self.plan.entries.truncate(new_entries_len);
1256
1257        cx.notify();
1258    }
1259
1260    fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1261        self.plan
1262            .entries
1263            .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1264        cx.notify();
1265    }
1266
1267    #[cfg(any(test, feature = "test-support"))]
1268    pub fn send_raw(
1269        &mut self,
1270        message: &str,
1271        cx: &mut Context<Self>,
1272    ) -> BoxFuture<'static, Result<()>> {
1273        self.send(
1274            vec![acp::ContentBlock::Text(acp::TextContent {
1275                text: message.to_string(),
1276                annotations: None,
1277            })],
1278            cx,
1279        )
1280    }
1281
1282    pub fn send(
1283        &mut self,
1284        message: Vec<acp::ContentBlock>,
1285        cx: &mut Context<Self>,
1286    ) -> BoxFuture<'static, Result<()>> {
1287        let block = ContentBlock::new_combined(
1288            message.clone(),
1289            self.project.read(cx).languages().clone(),
1290            cx,
1291        );
1292        let request = acp::PromptRequest {
1293            prompt: message.clone(),
1294            session_id: self.session_id.clone(),
1295        };
1296        let git_store = self.project.read(cx).git_store().clone();
1297
1298        let message_id = if self
1299            .connection
1300            .session_editor(&self.session_id, cx)
1301            .is_some()
1302        {
1303            Some(UserMessageId::new())
1304        } else {
1305            None
1306        };
1307
1308        self.run_turn(cx, async move |this, cx| {
1309            this.update(cx, |this, cx| {
1310                this.push_entry(
1311                    AgentThreadEntry::UserMessage(UserMessage {
1312                        id: message_id.clone(),
1313                        content: block,
1314                        chunks: message,
1315                        checkpoint: None,
1316                    }),
1317                    cx,
1318                );
1319            })
1320            .ok();
1321
1322            let old_checkpoint = git_store
1323                .update(cx, |git, cx| git.checkpoint(cx))?
1324                .await
1325                .context("failed to get old checkpoint")
1326                .log_err();
1327            this.update(cx, |this, cx| {
1328                if let Some((_ix, message)) = this.last_user_message() {
1329                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
1330                        git_checkpoint,
1331                        show: false,
1332                    });
1333                }
1334                this.connection.prompt(message_id, request, cx)
1335            })?
1336            .await
1337        })
1338    }
1339
1340    pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1341        self.run_turn(cx, async move |this, cx| {
1342            this.update(cx, |this, cx| {
1343                this.connection
1344                    .resume(&this.session_id, cx)
1345                    .map(|resume| resume.run(cx))
1346            })?
1347            .context("resuming a session is not supported")?
1348            .await
1349        })
1350    }
1351
    /// Runs a single prompt turn, driving `f` inside a freshly spawned `send_task`.
    ///
    /// Completed plan entries are cleared first. Any in-flight turn is cancelled, and awaited,
    /// before `f` starts. The returned future waits for `f`'s result and then:
    /// - refreshes the last user message's checkpoint visibility,
    /// - clears the project's agent location,
    /// - emits [`AcpThreadEvent::Error`] (returning the error) or [`AcpThreadEvent::Stopped`].
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<()>> {
        self.clear_completed_plan_entries(cx);

        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        // The turn's work lives in `send_task` so a later `cancel` can take it over and await
        // it; the result of `f` is forwarded to the future below through `tx`.
        self.send_task = Some(cx.spawn(async move |this, cx| {
            cancel_task.await;
            tx.send(f(this, cx).await).ok();
        }));

        cx.spawn(async move |this, cx| {
            // An `Err` here means `tx` was dropped (e.g. `send_task` was dropped) before `f`
            // produced a result.
            let response = rx.await;

            // NOTE(review): if the checkpoint update fails, this `?` returns early without
            // clearing the agent location, taking `send_task`, or emitting Stopped/Error —
            // confirm that is intended.
            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));
                match response {
                    Ok(Err(e)) => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        Err(e)
                    }
                    // Also covers `Err(_)` from `rx` (sender dropped), which is treated as a
                    // normal stop.
                    result => {
                        let canceled = matches!(
                            result,
                            Ok(Ok(acp::PromptResponse {
                                stop_reason: acp::StopReason::Canceled
                            }))
                        );

                        // We only take the task if the current prompt wasn't canceled.
                        //
                        // This prompt may have been canceled because another one was sent
                        // while it was still generating. In these cases, dropping `send_task`
                        // would cause the next generation to be canceled.
                        if !canceled {
                            this.send_task.take();
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(())
                    }
                }
            })?
        })
        .boxed()
    }
1407
1408    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1409        let Some(send_task) = self.send_task.take() else {
1410            return Task::ready(());
1411        };
1412
1413        for entry in self.entries.iter_mut() {
1414            if let AgentThreadEntry::ToolCall(call) = entry {
1415                let cancel = matches!(
1416                    call.status,
1417                    ToolCallStatus::Pending
1418                        | ToolCallStatus::WaitingForConfirmation { .. }
1419                        | ToolCallStatus::InProgress
1420                );
1421
1422                if cancel {
1423                    call.status = ToolCallStatus::Canceled;
1424                }
1425            }
1426        }
1427
1428        self.connection.cancel(&self.session_id, cx);
1429
1430        // Wait for the send task to complete
1431        cx.foreground_executor().spawn(send_task)
1432    }
1433
1434    /// Rewinds this thread to before the entry at `index`, removing it and all
1435    /// subsequent entries while reverting any changes made from that point.
1436    pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1437        let Some(session_editor) = self.connection.session_editor(&self.session_id, cx) else {
1438            return Task::ready(Err(anyhow!("not supported")));
1439        };
1440        let Some(message) = self.user_message(&id) else {
1441            return Task::ready(Err(anyhow!("message not found")));
1442        };
1443
1444        let checkpoint = message
1445            .checkpoint
1446            .as_ref()
1447            .map(|c| c.git_checkpoint.clone());
1448
1449        let git_store = self.project.read(cx).git_store().clone();
1450        cx.spawn(async move |this, cx| {
1451            if let Some(checkpoint) = checkpoint {
1452                git_store
1453                    .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1454                    .await?;
1455            }
1456
1457            cx.update(|cx| session_editor.truncate(id.clone(), cx))?
1458                .await?;
1459            this.update(cx, |this, cx| {
1460                if let Some((ix, _)) = this.user_message_mut(&id) {
1461                    let range = ix..this.entries.len();
1462                    this.entries.truncate(ix);
1463                    cx.emit(AcpThreadEvent::EntriesRemoved(range));
1464                }
1465            })
1466        })
1467    }
1468
1469    fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1470        let git_store = self.project.read(cx).git_store().clone();
1471
1472        let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1473            if let Some(checkpoint) = message.checkpoint.as_ref() {
1474                checkpoint.git_checkpoint.clone()
1475            } else {
1476                return Task::ready(Ok(()));
1477            }
1478        } else {
1479            return Task::ready(Ok(()));
1480        };
1481
1482        let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1483        cx.spawn(async move |this, cx| {
1484            let new_checkpoint = new_checkpoint
1485                .await
1486                .context("failed to get new checkpoint")
1487                .log_err();
1488            if let Some(new_checkpoint) = new_checkpoint {
1489                let equal = git_store
1490                    .update(cx, |git, cx| {
1491                        git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1492                    })?
1493                    .await
1494                    .unwrap_or(true);
1495                this.update(cx, |this, cx| {
1496                    let (ix, message) = this.last_user_message().context("no user message")?;
1497                    let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1498                    checkpoint.show = !equal;
1499                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
1500                    anyhow::Ok(())
1501                })??;
1502            }
1503
1504            Ok(())
1505        })
1506    }
1507
1508    fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1509        self.entries
1510            .iter_mut()
1511            .enumerate()
1512            .rev()
1513            .find_map(|(ix, entry)| {
1514                if let AgentThreadEntry::UserMessage(message) = entry {
1515                    Some((ix, message))
1516                } else {
1517                    None
1518                }
1519            })
1520    }
1521
1522    fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
1523        self.entries.iter().find_map(|entry| {
1524            if let AgentThreadEntry::UserMessage(message) = entry {
1525                if message.id.as_ref() == Some(id) {
1526                    Some(message)
1527                } else {
1528                    None
1529                }
1530            } else {
1531                None
1532            }
1533        })
1534    }
1535
1536    fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1537        self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1538            if let AgentThreadEntry::UserMessage(message) = entry {
1539                if message.id.as_ref() == Some(id) {
1540                    Some((ix, message))
1541                } else {
1542                    None
1543                }
1544            } else {
1545                None
1546            }
1547        })
1548    }
1549
1550    pub fn read_text_file(
1551        &self,
1552        path: PathBuf,
1553        line: Option<u32>,
1554        limit: Option<u32>,
1555        reuse_shared_snapshot: bool,
1556        cx: &mut Context<Self>,
1557    ) -> Task<Result<String>> {
1558        let project = self.project.clone();
1559        let action_log = self.action_log.clone();
1560        cx.spawn(async move |this, cx| {
1561            let load = project.update(cx, |project, cx| {
1562                let path = project
1563                    .project_path_for_absolute_path(&path, cx)
1564                    .context("invalid path")?;
1565                anyhow::Ok(project.open_buffer(path, cx))
1566            });
1567            let buffer = load??.await?;
1568
1569            let snapshot = if reuse_shared_snapshot {
1570                this.read_with(cx, |this, _| {
1571                    this.shared_buffers.get(&buffer.clone()).cloned()
1572                })
1573                .log_err()
1574                .flatten()
1575            } else {
1576                None
1577            };
1578
1579            let snapshot = if let Some(snapshot) = snapshot {
1580                snapshot
1581            } else {
1582                action_log.update(cx, |action_log, cx| {
1583                    action_log.buffer_read(buffer.clone(), cx);
1584                })?;
1585                project.update(cx, |project, cx| {
1586                    let position = buffer
1587                        .read(cx)
1588                        .snapshot()
1589                        .anchor_before(Point::new(line.unwrap_or_default(), 0));
1590                    project.set_agent_location(
1591                        Some(AgentLocation {
1592                            buffer: buffer.downgrade(),
1593                            position,
1594                        }),
1595                        cx,
1596                    );
1597                })?;
1598
1599                buffer.update(cx, |buffer, _| buffer.snapshot())?
1600            };
1601
1602            this.update(cx, |this, _| {
1603                let text = snapshot.text();
1604                this.shared_buffers.insert(buffer.clone(), snapshot);
1605                if line.is_none() && limit.is_none() {
1606                    return Ok(text);
1607                }
1608                let limit = limit.unwrap_or(u32::MAX) as usize;
1609                let Some(line) = line else {
1610                    return Ok(text.lines().take(limit).collect::<String>());
1611                };
1612
1613                let count = text.lines().count();
1614                if count < line as usize {
1615                    anyhow::bail!("There are only {} lines", count);
1616                }
1617                Ok(text
1618                    .lines()
1619                    .skip(line as usize + 1)
1620                    .take(limit)
1621                    .collect::<String>())
1622            })?
1623        })
1624    }
1625
    /// Replaces the contents of the file at `path` with `content` on behalf of the agent, then
    /// saves it.
    ///
    /// The edits are a diff of `content` against the snapshot previously shared with the agent
    /// for this buffer (falling back to the buffer's current snapshot).
    ///
    /// Side effects, in order:
    /// - The project's agent location moves to the end of the last edit, or to the start of
    ///   the buffer if there are no edits.
    /// - The action log records a read and an edit of the buffer.
    /// - If the buffer's language settings enable format-on-save, the buffer is formatted.
    /// - The buffer is saved.
    ///
    /// # Errors
    ///
    /// Fails if `path` is not inside the project, the buffer cannot be opened, or saving
    /// fails. Formatting failures are only logged.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Diff against what the agent last saw (if anything) rather than the live buffer.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Compute the diff on the background executor. Ranges are converted to anchors in
            // `snapshot`, presumably so they can be applied to the live buffer even if it
            // changed after the snapshot was taken.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::MIN),
                    }),
                    cx,
                );
            })?;

            // Apply the edits and report read + edit to the action log; also decide whether
            // the buffer's language settings ask for format-on-save.
            let format_on_save = cx.update(|cx| {
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            })?;

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                })?;
                // Formatting is best-effort: a failure is logged and the write still proceeds.
                format_task.await.log_err();

                // Formatting may have changed the buffer again, so report another edit.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                })?;
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1722
1723    pub fn to_markdown(&self, cx: &App) -> String {
1724        self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1725    }
1726
1727    pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
1728        cx.emit(AcpThreadEvent::LoadError(error));
1729    }
1730}
1731
1732fn markdown_for_raw_output(
1733    raw_output: &serde_json::Value,
1734    language_registry: &Arc<LanguageRegistry>,
1735    cx: &mut App,
1736) -> Option<Entity<Markdown>> {
1737    match raw_output {
1738        serde_json::Value::Null => None,
1739        serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1740            Markdown::new(
1741                value.to_string().into(),
1742                Some(language_registry.clone()),
1743                None,
1744                cx,
1745            )
1746        })),
1747        serde_json::Value::Number(value) => Some(cx.new(|cx| {
1748            Markdown::new(
1749                value.to_string().into(),
1750                Some(language_registry.clone()),
1751                None,
1752                cx,
1753            )
1754        })),
1755        serde_json::Value::String(value) => Some(cx.new(|cx| {
1756            Markdown::new(
1757                value.clone().into(),
1758                Some(language_registry.clone()),
1759                None,
1760                cx,
1761            )
1762        })),
1763        value => Some(cx.new(|cx| {
1764            Markdown::new(
1765                format!("```json\n{}\n```", value).into(),
1766                Some(language_registry.clone()),
1767                None,
1768                cx,
1769            )
1770        })),
1771    }
1772}
1773
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use futures::{channel::mpsc, future::LocalBoxFuture, select};
    use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
    use indoc::indoc;
    use project::{FakeFs, Fs};
    use rand::Rng as _;
    use serde_json::json;
    use settings::SettingsStore;
    use smol::stream::StreamExt as _;
    use std::{
        any::Any,
        cell::RefCell,
        path::Path,
        rc::Rc,
        sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
        time::Duration,
    };
    use util::path;

    /// Shared setup for every test: initializes logging (ignoring the error
    /// when a logger is already installed), installs a test `SettingsStore`
    /// global, and registers project and language settings.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }

    // Consecutive user content blocks merge into the trailing user message
    // (which adopts a later-supplied id), while a user block pushed after an
    // assistant message starts a new entry.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                None,
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Hello, ".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, None);
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        let message_1_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_1_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "world!".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, Some(message_1_id));
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Assistant response".to_string(),
                }),
                false,
                cx,
            );
        });

        let message_2_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_2_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "New user message".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.id, Some(message_2_id));
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }

    // Consecutive agent thought chunks are concatenated into a single
    // `<thinking>` section of the assistant message.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "Thinking ".into(),
                                },
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "hard!".into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }

    // The agent reads a file, the user then edits the open buffer, and the
    // agent writes an extended version; both the user's edit and the agent's
    // write must survive, in the buffer and on disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }

    // A tool call marked Canceled by `cancel` can still transition to
    // Completed when the agent later reports it as completed.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::InProgress,
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Completed,
                    ..
                })
            ));
        });
    }

    // An edit tool call that arrives already Completed (with a diff) must not
    // leave the thread reporting pending edit tool calls.
    #[gpui::test]
    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/test"), json!({})).await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |_, thread, mut cx| {
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: acp::ToolCallId("test".into()),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Edit,
                                    status: acp::ToolCallStatus::Completed,
                                    content: vec![acp::ToolCallContent::Diff {
                                        diff: acp::Diff {
                                            // Use `path!` so the diff targets the same
                                            // platform-specific root as the worktree.
                                            path: path!("/test/test.txt").into(),
                                            old_text: None,
                                            new_text: "foo".into(),
                                        },
                                    }],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
            .await
            .unwrap();

        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
    }

    // A user message keeps a visible git checkpoint only when its turn changed
    // files, and rewinding to a message truncates the history and restores
    // the filesystem to that checkpoint.
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    if simulate_changes.load(SeqCst) {
                        // Build the path from the same `path!` root the assertions
                        // below use, so it also matches on Windows.
                        let file_path = Path::new(path!("/test"))
                            .join(format!("file-{}", next_filename.fetch_add(1, SeqCst)));
                        fs.write(&file_path, b"").await?;
                    }

                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk {
                                    content: content.text.to_uppercase().into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                    ## User

                    dolor

                    ## Assistant

                    DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        thread
            .update(cx, |thread, cx| {
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.rewind(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
    }

    /// Waits until `thread` contains a tool call entry and returns the index of
    /// the first one. Panics if no tool call shows up within 10 seconds.
    async fn run_until_first_tool_call(
        thread: &Entity<AcpThread>,
        cx: &mut TestAppContext,
    ) -> usize {
        let (mut tx, mut rx) = mpsc::channel::<usize>(1);

        let subscription = cx.update(|cx| {
            cx.subscribe(thread, move |thread, _, cx| {
                if let Some(ix) = thread
                    .read(cx)
                    .entries
                    .iter()
                    .position(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
                {
                    // The subscription stays alive until the receiver is polled, so
                    // later events can fill the bounded channel. The index is already
                    // queued by then; dropping extra sends avoids a spurious panic.
                    tx.try_send(ix).ok();
                }
            })
        });

        select! {
            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
                panic!("Timeout waiting for tool call")
            }
            ix = rx.next().fuse() => {
                drop(subscription);
                ix.unwrap()
            }
        }
    }

    /// In-memory `AgentConnection` for tests.
    ///
    /// Threads created through `new_thread` are tracked by session id. Each
    /// prompt is forwarded to the optional `on_user_message` handler, or
    /// completes immediately with `EndTurn` when no handler is set.
    #[derive(Clone, Default)]
    struct FakeAgentConnection {
        auth_methods: Vec<acp::AuthMethod>,
        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
        on_user_message: Option<
            Rc<
                dyn Fn(
                        acp::PromptRequest,
                        WeakEntity<AcpThread>,
                        AsyncApp,
                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
                    + 'static,
            >,
        >,
    }

    impl FakeAgentConnection {
        /// Creates a connection with no auth methods, no sessions, and no
        /// prompt handler.
        fn new() -> Self {
            Self {
                auth_methods: Vec::new(),
                on_user_message: None,
                sessions: Arc::default(),
            }
        }

        /// Sets the auth methods that `authenticate` accepts.
        #[expect(unused)]
        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
            self.auth_methods = auth_methods;
            self
        }

        /// Installs the handler invoked for every prompt, replacing any
        /// previously installed one.
        fn on_user_message(
            mut self,
            handler: impl Fn(
                acp::PromptRequest,
                WeakEntity<AcpThread>,
                AsyncApp,
            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
            + 'static,
        ) -> Self {
            self.on_user_message.replace(Rc::new(handler));
            self
        }
    }

    impl AgentConnection for FakeAgentConnection {
        fn auth_methods(&self) -> &[acp::AuthMethod] {
            &self.auth_methods
        }

        /// Creates a thread under a random 7-character alphanumeric session id
        /// and records it so later prompts and cancellations can find it.
        fn new_thread(
            self: Rc<Self>,
            project: Entity<Project>,
            _cwd: &Path,
            cx: &mut App,
        ) -> Task<gpui::Result<Entity<AcpThread>>> {
            let session_id = acp::SessionId(
                rand::thread_rng()
                    .sample_iter(&rand::distributions::Alphanumeric)
                    .take(7)
                    .map(char::from)
                    .collect::<String>()
                    .into(),
            );
            let action_log = cx.new(|_| ActionLog::new(project.clone()));
            let thread = cx.new(|_cx| {
                AcpThread::new(
                    "Test",
                    self.clone(),
                    project,
                    action_log,
                    session_id.clone(),
                )
            });
            self.sessions.lock().insert(session_id, thread.downgrade());
            Task::ready(Ok(thread))
        }

        /// Succeeds only for one of the configured auth method ids.
        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
            if self.auth_methods().iter().any(|m| m.id == method) {
                Task::ready(Ok(()))
            } else {
                Task::ready(Err(anyhow!("Invalid Auth Method")))
            }
        }

        /// Runs the installed handler for the session's thread, or ends the
        /// turn immediately. Panics if the session id is unknown.
        fn prompt(
            &self,
            _id: Option<UserMessageId>,
            params: acp::PromptRequest,
            cx: &mut App,
        ) -> Task<gpui::Result<acp::PromptResponse>> {
            let sessions = self.sessions.lock();
            let thread = sessions.get(&params.session_id).unwrap();
            if let Some(handler) = &self.on_user_message {
                let handler = handler.clone();
                let thread = thread.clone();
                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
            } else {
                Task::ready(Ok(acp::PromptResponse {
                    stop_reason: acp::StopReason::EndTurn,
                }))
            }
        }

        /// Spawns a detached task that calls `cancel` on the session's thread.
        /// Panics if the session id is unknown.
        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
            let sessions = self.sessions.lock();
            let thread = sessions.get(session_id).unwrap().clone();

            cx.spawn(async move |cx| {
                thread
                    .update(cx, |thread, cx| thread.cancel(cx))
                    .unwrap()
                    .await
            })
            .detach();
        }

        fn session_editor(
            &self,
            session_id: &acp::SessionId,
            _cx: &mut App,
        ) -> Option<Rc<dyn AgentSessionEditor>> {
            Some(Rc::new(FakeAgentSessionEditor {
                _session_id: session_id.clone(),
            }))
        }

        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
            self
        }
    }

    /// Session editor whose `truncate` always succeeds without doing anything.
    struct FakeAgentSessionEditor {
        _session_id: acp::SessionId,
    }

    impl AgentSessionEditor for FakeAgentSessionEditor {
        fn truncate(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
            Task::ready(Ok(()))
        }
    }
}
2542    }
2543}