acp_thread.rs

   1mod connection;
   2mod diff;
   3mod mention;
   4mod terminal;
   5
   6pub use connection::*;
   7pub use diff::*;
   8pub use mention::*;
   9use serde::{Deserialize, Serialize};
  10pub use terminal::*;
  11
  12use action_log::ActionLog;
  13use agent_client_protocol as acp;
  14use anyhow::{Context as _, Result, anyhow};
  15use chrono::{DateTime, Utc};
  16use editor::Bias;
  17use futures::{FutureExt, channel::oneshot, future::BoxFuture};
  18use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  19use itertools::Itertools;
  20use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
  21use markdown::Markdown;
  22use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
  23use std::collections::HashMap;
  24use std::error::Error;
  25use std::fmt::{Formatter, Write};
  26use std::ops::Range;
  27use std::process::ExitStatus;
  28use std::rc::Rc;
  29use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
  30use ui::App;
  31use util::ResultExt;
  32
/// A message authored by the user, accumulated from one or more ACP content chunks.
#[derive(Debug)]
pub struct UserMessage {
    /// Identifier of the message; a later chunk carrying an id overrides it.
    pub id: Option<UserMessageId>,
    /// Renderable content built by appending every received chunk in order.
    pub content: ContentBlock,
    /// The raw ACP content blocks, in the order they were received.
    pub chunks: Vec<acp::ContentBlock>,
    /// Git checkpoint attached to this message, if any.
    pub checkpoint: Option<Checkpoint>,
}
  40
/// A git checkpoint associated with a user message.
#[derive(Debug)]
pub struct Checkpoint {
    /// The underlying git-store checkpoint.
    git_checkpoint: GitStoreCheckpoint,
    /// When true, the message's markdown export is headed `## User (checkpoint)`.
    pub show: bool,
}
  46
  47impl UserMessage {
  48    fn to_markdown(&self, cx: &App) -> String {
  49        let mut markdown = String::new();
  50        if self
  51            .checkpoint
  52            .as_ref()
  53            .map_or(false, |checkpoint| checkpoint.show)
  54        {
  55            writeln!(markdown, "## User (checkpoint)").unwrap();
  56        } else {
  57            writeln!(markdown, "## User").unwrap();
  58        }
  59        writeln!(markdown).unwrap();
  60        writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
  61        writeln!(markdown).unwrap();
  62        markdown
  63    }
  64}
  65
/// One assistant turn, made of consecutive message and thought chunks.
#[derive(Debug, PartialEq)]
pub struct AssistantMessage {
    /// Chunks in arrival order; adjacent chunks of the same kind are merged.
    pub chunks: Vec<AssistantMessageChunk>,
}
  70
  71impl AssistantMessage {
  72    pub fn to_markdown(&self, cx: &App) -> String {
  73        format!(
  74            "## Assistant\n\n{}\n\n",
  75            self.chunks
  76                .iter()
  77                .map(|chunk| chunk.to_markdown(cx))
  78                .join("\n\n")
  79        )
  80    }
  81}
  82
/// A contiguous piece of assistant output.
#[derive(Debug, PartialEq)]
pub enum AssistantMessageChunk {
    /// Visible response text.
    Message { block: ContentBlock },
    /// Agent "thought" text; exported to markdown wrapped in `<thinking>` tags.
    Thought { block: ContentBlock },
}
  88
  89impl AssistantMessageChunk {
  90    pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
  91        Self::Message {
  92            block: ContentBlock::new(chunk.into(), language_registry, cx),
  93        }
  94    }
  95
  96    fn to_markdown(&self, cx: &App) -> String {
  97        match self {
  98            Self::Message { block } => block.to_markdown(cx).to_string(),
  99            Self::Thought { block } => {
 100                format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
 101            }
 102        }
 103    }
 104}
 105
/// A single entry in an agent thread's transcript.
#[derive(Debug)]
pub enum AgentThreadEntry {
    /// Content sent by the user.
    UserMessage(UserMessage),
    /// Content (messages and thoughts) produced by the agent.
    AssistantMessage(AssistantMessage),
    /// A tool invocation reported by the agent.
    ToolCall(ToolCall),
}
 112
 113impl AgentThreadEntry {
 114    pub fn to_markdown(&self, cx: &App) -> String {
 115        match self {
 116            Self::UserMessage(message) => message.to_markdown(cx),
 117            Self::AssistantMessage(message) => message.to_markdown(cx),
 118            Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
 119        }
 120    }
 121
 122    pub fn user_message(&self) -> Option<&UserMessage> {
 123        if let AgentThreadEntry::UserMessage(message) = self {
 124            Some(message)
 125        } else {
 126            None
 127        }
 128    }
 129
 130    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 131        if let AgentThreadEntry::ToolCall(call) = self {
 132            itertools::Either::Left(call.diffs())
 133        } else {
 134            itertools::Either::Right(std::iter::empty())
 135        }
 136    }
 137
 138    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 139        if let AgentThreadEntry::ToolCall(call) = self {
 140            itertools::Either::Left(call.terminals())
 141        } else {
 142            itertools::Either::Right(std::iter::empty())
 143        }
 144    }
 145
 146    pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
 147        if let AgentThreadEntry::ToolCall(ToolCall {
 148            locations,
 149            resolved_locations,
 150            ..
 151        }) = self
 152        {
 153            Some((
 154                locations.get(ix)?.clone(),
 155                resolved_locations.get(ix)?.clone()?,
 156            ))
 157        } else {
 158            None
 159        }
 160    }
 161}
 162
/// Local state for a tool call reported by the agent.
#[derive(Debug)]
pub struct ToolCall {
    /// ACP identifier used to match later updates to this call.
    pub id: acp::ToolCallId,
    /// The tool call's title, rendered as markdown.
    pub label: Entity<Markdown>,
    pub kind: acp::ToolKind,
    /// Rendered content items (content blocks, diffs, terminals).
    pub content: Vec<ToolCallContent>,
    pub status: ToolCallStatus,
    /// Locations as reported by the agent.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Buffer locations resolved from `locations`, index-aligned with them once
    /// resolved (`None` where resolution failed); starts out empty.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    /// Raw JSON input as reported by the agent.
    pub raw_input: Option<serde_json::Value>,
    /// Raw JSON output as reported by the agent.
    pub raw_output: Option<serde_json::Value>,
}
 175
 176impl ToolCall {
 177    fn from_acp(
 178        tool_call: acp::ToolCall,
 179        status: ToolCallStatus,
 180        language_registry: Arc<LanguageRegistry>,
 181        cx: &mut App,
 182    ) -> Self {
 183        Self {
 184            id: tool_call.id,
 185            label: cx.new(|cx| {
 186                Markdown::new(
 187                    tool_call.title.into(),
 188                    Some(language_registry.clone()),
 189                    None,
 190                    cx,
 191                )
 192            }),
 193            kind: tool_call.kind,
 194            content: tool_call
 195                .content
 196                .into_iter()
 197                .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
 198                .collect(),
 199            locations: tool_call.locations,
 200            resolved_locations: Vec::default(),
 201            status,
 202            raw_input: tool_call.raw_input,
 203            raw_output: tool_call.raw_output,
 204        }
 205    }
 206
 207    fn update_fields(
 208        &mut self,
 209        fields: acp::ToolCallUpdateFields,
 210        language_registry: Arc<LanguageRegistry>,
 211        cx: &mut App,
 212    ) {
 213        let acp::ToolCallUpdateFields {
 214            kind,
 215            status,
 216            title,
 217            content,
 218            locations,
 219            raw_input,
 220            raw_output,
 221        } = fields;
 222
 223        if let Some(kind) = kind {
 224            self.kind = kind;
 225        }
 226
 227        if let Some(status) = status {
 228            self.status = status.into();
 229        }
 230
 231        if let Some(title) = title {
 232            self.label.update(cx, |label, cx| {
 233                label.replace(title, cx);
 234            });
 235        }
 236
 237        if let Some(content) = content {
 238            self.content = content
 239                .into_iter()
 240                .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
 241                .collect();
 242        }
 243
 244        if let Some(locations) = locations {
 245            self.locations = locations;
 246        }
 247
 248        if let Some(raw_input) = raw_input {
 249            self.raw_input = Some(raw_input);
 250        }
 251
 252        if let Some(raw_output) = raw_output {
 253            if self.content.is_empty() {
 254                if let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
 255                {
 256                    self.content
 257                        .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
 258                            markdown,
 259                        }));
 260                }
 261            }
 262            self.raw_output = Some(raw_output);
 263        }
 264    }
 265
 266    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 267        self.content.iter().filter_map(|content| match content {
 268            ToolCallContent::Diff(diff) => Some(diff),
 269            ToolCallContent::ContentBlock(_) => None,
 270            ToolCallContent::Terminal(_) => None,
 271        })
 272    }
 273
 274    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 275        self.content.iter().filter_map(|content| match content {
 276            ToolCallContent::Terminal(terminal) => Some(terminal),
 277            ToolCallContent::ContentBlock(_) => None,
 278            ToolCallContent::Diff(_) => None,
 279        })
 280    }
 281
 282    fn to_markdown(&self, cx: &App) -> String {
 283        let mut markdown = format!(
 284            "**Tool Call: {}**\nStatus: {}\n\n",
 285            self.label.read(cx).source(),
 286            self.status
 287        );
 288        for content in &self.content {
 289            markdown.push_str(content.to_markdown(cx).as_str());
 290            markdown.push_str("\n\n");
 291        }
 292        markdown
 293    }
 294
 295    async fn resolve_location(
 296        location: acp::ToolCallLocation,
 297        project: WeakEntity<Project>,
 298        cx: &mut AsyncApp,
 299    ) -> Option<AgentLocation> {
 300        let buffer = project
 301            .update(cx, |project, cx| {
 302                if let Some(path) = project.project_path_for_absolute_path(&location.path, cx) {
 303                    Some(project.open_buffer(path, cx))
 304                } else {
 305                    None
 306                }
 307            })
 308            .ok()??;
 309        let buffer = buffer.await.log_err()?;
 310        let position = buffer
 311            .update(cx, |buffer, _| {
 312                if let Some(row) = location.line {
 313                    let snapshot = buffer.snapshot();
 314                    let column = snapshot.indent_size_for_line(row).len;
 315                    let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
 316                    snapshot.anchor_before(point)
 317                } else {
 318                    Anchor::MIN
 319                }
 320            })
 321            .ok()?;
 322
 323        Some(AgentLocation {
 324            buffer: buffer.downgrade(),
 325            position,
 326        })
 327    }
 328
 329    fn resolve_locations(
 330        &self,
 331        project: Entity<Project>,
 332        cx: &mut App,
 333    ) -> Task<Vec<Option<AgentLocation>>> {
 334        let locations = self.locations.clone();
 335        project.update(cx, |_, cx| {
 336            cx.spawn(async move |project, cx| {
 337                let mut new_locations = Vec::new();
 338                for location in locations {
 339                    new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
 340                }
 341                new_locations
 342            })
 343        })
 344    }
 345}
 346
/// Lifecycle state of a tool call as tracked locally.
///
/// Extends the ACP status with states only this side knows about: waiting for
/// user confirmation, rejected, and canceled.
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The tool call hasn't started running yet, but we start showing it to
    /// the user.
    Pending,
    /// The tool call is waiting for confirmation from the user.
    WaitingForConfirmation {
        /// Permission options offered to the user.
        options: Vec<acp::PermissionOption>,
        /// Channel on which the chosen option's id is sent back.
        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
    },
    /// The tool call is currently running.
    InProgress,
    /// The tool call completed successfully.
    Completed,
    /// The tool call failed.
    Failed,
    /// The user rejected the tool call.
    Rejected,
    /// The user canceled generation so the tool call was canceled.
    Canceled,
}
 368
 369impl From<acp::ToolCallStatus> for ToolCallStatus {
 370    fn from(status: acp::ToolCallStatus) -> Self {
 371        match status {
 372            acp::ToolCallStatus::Pending => Self::Pending,
 373            acp::ToolCallStatus::InProgress => Self::InProgress,
 374            acp::ToolCallStatus::Completed => Self::Completed,
 375            acp::ToolCallStatus::Failed => Self::Failed,
 376        }
 377    }
 378}
 379
 380impl Display for ToolCallStatus {
 381    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 382        write!(
 383            f,
 384            "{}",
 385            match self {
 386                ToolCallStatus::Pending => "Pending",
 387                ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
 388                ToolCallStatus::InProgress => "In Progress",
 389                ToolCallStatus::Completed => "Completed",
 390                ToolCallStatus::Failed => "Failed",
 391                ToolCallStatus::Rejected => "Rejected",
 392                ToolCallStatus::Canceled => "Canceled",
 393            }
 394        )
 395    }
 396}
 397
/// Renderable content assembled from one or more ACP content blocks.
#[derive(Debug, PartialEq, Clone)]
pub enum ContentBlock {
    /// Nothing has been appended yet.
    Empty,
    /// Accumulated markdown text.
    Markdown { markdown: Entity<Markdown> },
    /// A lone resource link that was the first block appended to an `Empty` block.
    ResourceLink { resource_link: acp::ResourceLink },
}
 404
 405impl ContentBlock {
 406    pub fn new(
 407        block: acp::ContentBlock,
 408        language_registry: &Arc<LanguageRegistry>,
 409        cx: &mut App,
 410    ) -> Self {
 411        let mut this = Self::Empty;
 412        this.append(block, language_registry, cx);
 413        this
 414    }
 415
 416    pub fn new_combined(
 417        blocks: impl IntoIterator<Item = acp::ContentBlock>,
 418        language_registry: Arc<LanguageRegistry>,
 419        cx: &mut App,
 420    ) -> Self {
 421        let mut this = Self::Empty;
 422        for block in blocks {
 423            this.append(block, &language_registry, cx);
 424        }
 425        this
 426    }
 427
 428    pub fn append(
 429        &mut self,
 430        block: acp::ContentBlock,
 431        language_registry: &Arc<LanguageRegistry>,
 432        cx: &mut App,
 433    ) {
 434        if matches!(self, ContentBlock::Empty) {
 435            if let acp::ContentBlock::ResourceLink(resource_link) = block {
 436                *self = ContentBlock::ResourceLink { resource_link };
 437                return;
 438            }
 439        }
 440
 441        let new_content = self.block_string_contents(block);
 442
 443        match self {
 444            ContentBlock::Empty => {
 445                *self = Self::create_markdown_block(new_content, language_registry, cx);
 446            }
 447            ContentBlock::Markdown { markdown } => {
 448                markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
 449            }
 450            ContentBlock::ResourceLink { resource_link } => {
 451                let existing_content = Self::resource_link_md(&resource_link.uri);
 452                let combined = format!("{}\n{}", existing_content, new_content);
 453
 454                *self = Self::create_markdown_block(combined, language_registry, cx);
 455            }
 456        }
 457    }
 458
 459    fn create_markdown_block(
 460        content: String,
 461        language_registry: &Arc<LanguageRegistry>,
 462        cx: &mut App,
 463    ) -> ContentBlock {
 464        ContentBlock::Markdown {
 465            markdown: cx
 466                .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
 467        }
 468    }
 469
 470    fn block_string_contents(&self, block: acp::ContentBlock) -> String {
 471        match block {
 472            acp::ContentBlock::Text(text_content) => text_content.text.clone(),
 473            acp::ContentBlock::ResourceLink(resource_link) => {
 474                Self::resource_link_md(&resource_link.uri)
 475            }
 476            acp::ContentBlock::Resource(acp::EmbeddedResource {
 477                resource:
 478                    acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
 479                        uri,
 480                        ..
 481                    }),
 482                ..
 483            }) => Self::resource_link_md(&uri),
 484            acp::ContentBlock::Image(image) => Self::image_md(&image),
 485            acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
 486        }
 487    }
 488
 489    fn resource_link_md(uri: &str) -> String {
 490        if let Some(uri) = MentionUri::parse(&uri).log_err() {
 491            uri.as_link().to_string()
 492        } else {
 493            uri.to_string()
 494        }
 495    }
 496
 497    fn image_md(_image: &acp::ImageContent) -> String {
 498        "`Image`".into()
 499    }
 500
 501    fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
 502        match self {
 503            ContentBlock::Empty => "",
 504            ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
 505            ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
 506        }
 507    }
 508
 509    pub fn markdown(&self) -> Option<&Entity<Markdown>> {
 510        match self {
 511            ContentBlock::Empty => None,
 512            ContentBlock::Markdown { markdown } => Some(markdown),
 513            ContentBlock::ResourceLink { .. } => None,
 514        }
 515    }
 516
 517    pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
 518        match self {
 519            ContentBlock::ResourceLink { resource_link } => Some(resource_link),
 520            _ => None,
 521        }
 522    }
 523}
 524
/// One item of a tool call's displayed content.
#[derive(Debug)]
pub enum ToolCallContent {
    /// Text or resource content.
    ContentBlock(ContentBlock),
    /// A file diff produced by the tool.
    Diff(Entity<Diff>),
    /// A terminal associated with the tool call.
    Terminal(Entity<Terminal>),
}
 531
 532impl ToolCallContent {
 533    pub fn from_acp(
 534        content: acp::ToolCallContent,
 535        language_registry: Arc<LanguageRegistry>,
 536        cx: &mut App,
 537    ) -> Self {
 538        match content {
 539            acp::ToolCallContent::Content { content } => {
 540                Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
 541            }
 542            acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
 543                Diff::finalized(
 544                    diff.path,
 545                    diff.old_text,
 546                    diff.new_text,
 547                    language_registry,
 548                    cx,
 549                )
 550            })),
 551        }
 552    }
 553
 554    pub fn to_markdown(&self, cx: &App) -> String {
 555        match self {
 556            Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
 557            Self::Diff(diff) => diff.read(cx).to_markdown(cx),
 558            Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
 559        }
 560    }
 561}
 562
/// An update targeting an existing tool call, matched by id.
#[derive(Debug, PartialEq)]
pub enum ToolCallUpdate {
    /// Merge the ACP-provided fields into the call.
    UpdateFields(acp::ToolCallUpdate),
    /// Replace the call's content with a single diff.
    UpdateDiff(ToolCallUpdateDiff),
    /// Replace the call's content with a single terminal.
    UpdateTerminal(ToolCallUpdateTerminal),
}
 569
 570impl ToolCallUpdate {
 571    fn id(&self) -> &acp::ToolCallId {
 572        match self {
 573            Self::UpdateFields(update) => &update.id,
 574            Self::UpdateDiff(diff) => &diff.id,
 575            Self::UpdateTerminal(terminal) => &terminal.id,
 576        }
 577    }
 578}
 579
 580impl From<acp::ToolCallUpdate> for ToolCallUpdate {
 581    fn from(update: acp::ToolCallUpdate) -> Self {
 582        Self::UpdateFields(update)
 583    }
 584}
 585
 586impl From<ToolCallUpdateDiff> for ToolCallUpdate {
 587    fn from(diff: ToolCallUpdateDiff) -> Self {
 588        Self::UpdateDiff(diff)
 589    }
 590}
 591
/// Replaces a tool call's content with `diff`.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateDiff {
    /// Id of the tool call to update.
    pub id: acp::ToolCallId,
    /// The diff that becomes the call's only content item.
    pub diff: Entity<Diff>,
}
 597
 598impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
 599    fn from(terminal: ToolCallUpdateTerminal) -> Self {
 600        Self::UpdateTerminal(terminal)
 601    }
 602}
 603
/// Replaces a tool call's content with `terminal`.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateTerminal {
    /// Id of the tool call to update.
    pub id: acp::ToolCallId,
    /// The terminal that becomes the call's only content item.
    pub terminal: Entity<Terminal>,
}
 609
/// The agent's plan as a list of entries.
#[derive(Debug, Default)]
pub struct Plan {
    /// Plan entries in the order they were reported.
    pub entries: Vec<PlanEntry>,
}
 614
/// Summary of a [`Plan`], produced by `Plan::stats`.
#[derive(Debug)]
pub struct PlanStats<'a> {
    /// The first entry (in plan order) whose status is in-progress, if any.
    pub in_progress_entry: Option<&'a PlanEntry>,
    /// Number of pending entries.
    pub pending: u32,
    /// Number of completed entries.
    pub completed: u32,
}
 621
 622impl Plan {
 623    pub fn is_empty(&self) -> bool {
 624        self.entries.is_empty()
 625    }
 626
 627    pub fn stats(&self) -> PlanStats<'_> {
 628        let mut stats = PlanStats {
 629            in_progress_entry: None,
 630            pending: 0,
 631            completed: 0,
 632        };
 633
 634        for entry in &self.entries {
 635            match &entry.status {
 636                acp::PlanEntryStatus::Pending => {
 637                    stats.pending += 1;
 638                }
 639                acp::PlanEntryStatus::InProgress => {
 640                    stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
 641                }
 642                acp::PlanEntryStatus::Completed => {
 643                    stats.completed += 1;
 644                }
 645            }
 646        }
 647
 648        stats
 649    }
 650}
 651
/// A single step in the agent's plan.
#[derive(Debug)]
pub struct PlanEntry {
    /// Entry text rendered as markdown (created without a language registry).
    pub content: Entity<Markdown>,
    pub priority: acp::PlanEntryPriority,
    pub status: acp::PlanEntryStatus,
}
 658
 659impl PlanEntry {
 660    pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
 661        Self {
 662            content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
 663            priority: entry.priority,
 664            status: entry.status,
 665        }
 666    }
 667}
 668
/// Newtype wrapper around an agent server's name.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct AgentServerName(pub SharedString);
 671
/// Serializable metadata describing an ACP thread.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AcpThreadMetadata {
    /// The agent server the thread belongs to.
    pub agent: AgentServerName,
    /// ACP session id of the thread.
    pub id: acp::SessionId,
    pub title: SharedString,
    /// Last-updated timestamp in UTC.
    pub updated_at: DateTime<Utc>,
}
 679
/// A conversation with an agent over an ACP connection.
pub struct AcpThread {
    /// Thread title; replaced by `update_title`.
    title: SharedString,
    /// Transcript entries in order.
    entries: Vec<AgentThreadEntry>,
    /// Current plan reported by the agent.
    plan: Plan,
    project: Entity<Project>,
    /// Action log created for `project` in `new`.
    action_log: Entity<ActionLog>,
    /// NOTE(review): buffer snapshots keyed by buffer; populated and used outside
    /// this chunk — confirm purpose.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// In-flight send task; `status()` reports `Idle` when this is `None`.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
}
 691
/// Events emitted by [`AcpThread`].
#[derive(Debug)]
pub enum AcpThreadEvent {
    /// An entry was appended (emitted by `push_entry`).
    NewEntry,
    /// The title changed (emitted by `update_title`).
    TitleUpdated,
    /// The entry at this index was modified in place.
    EntryUpdated(usize),
    EntriesRemoved(Range<usize>),
    ToolAuthorizationRequired,
    Stopped,
    Error,
    ServerExited(ExitStatus),
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}
 705
/// Coarse thread state, as reported by `AcpThread::status`.
#[derive(PartialEq, Eq)]
pub enum ThreadStatus {
    /// No send task is in flight.
    Idle,
    /// A send task is in flight and a tool call is waiting for confirmation.
    WaitingForToolConfirmation,
    /// A send task is in flight and no tool confirmation is pending.
    Generating,
}
 712
/// Error describing why an agent could not be loaded.
///
/// `Display` renders `error_message` for `Unsupported`, the exit code for
/// `Exited`, and the message for `Other`.
#[derive(Debug, Clone)]
pub enum LoadError {
    Unsupported {
        error_message: SharedString,
        upgrade_message: SharedString,
        upgrade_command: String,
    },
    Exited(i32),
    Other(SharedString),
}
 723
 724impl Display for LoadError {
 725    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 726        match self {
 727            LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
 728            LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
 729            LoadError::Other(msg) => write!(f, "{}", msg),
 730        }
 731    }
 732}
 733
// `LoadError` has no underlying source error; the default trait methods suffice.
impl Error for LoadError {}
 735
 736impl AcpThread {
 737    pub fn new(
 738        title: impl Into<SharedString>,
 739        connection: Rc<dyn AgentConnection>,
 740        project: Entity<Project>,
 741        session_id: acp::SessionId,
 742        cx: &mut Context<Self>,
 743    ) -> Self {
 744        let action_log = cx.new(|_| ActionLog::new(project.clone()));
 745
 746        Self {
 747            action_log,
 748            shared_buffers: Default::default(),
 749            entries: Default::default(),
 750            plan: Default::default(),
 751            title: title.into(),
 752            project,
 753            send_task: None,
 754            connection,
 755            session_id,
 756        }
 757    }
 758
 759    pub fn connection(&self) -> &Rc<dyn AgentConnection> {
 760        &self.connection
 761    }
 762
 763    pub fn action_log(&self) -> &Entity<ActionLog> {
 764        &self.action_log
 765    }
 766
 767    pub fn project(&self) -> &Entity<Project> {
 768        &self.project
 769    }
 770
 771    pub fn title(&self) -> SharedString {
 772        self.title.clone()
 773    }
 774
 775    pub fn entries(&self) -> &[AgentThreadEntry] {
 776        &self.entries
 777    }
 778
 779    pub fn session_id(&self) -> &acp::SessionId {
 780        &self.session_id
 781    }
 782
 783    pub fn status(&self) -> ThreadStatus {
 784        if self.send_task.is_some() {
 785            if self.waiting_for_tool_confirmation() {
 786                ThreadStatus::WaitingForToolConfirmation
 787            } else {
 788                ThreadStatus::Generating
 789            }
 790        } else {
 791            ThreadStatus::Idle
 792        }
 793    }
 794
 795    pub fn has_pending_edit_tool_calls(&self) -> bool {
 796        for entry in self.entries.iter().rev() {
 797            match entry {
 798                AgentThreadEntry::UserMessage(_) => return false,
 799                AgentThreadEntry::ToolCall(
 800                    call @ ToolCall {
 801                        status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
 802                        ..
 803                    },
 804                ) if call.diffs().next().is_some() => {
 805                    return true;
 806                }
 807                AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
 808            }
 809        }
 810
 811        false
 812    }
 813
 814    pub fn used_tools_since_last_user_message(&self) -> bool {
 815        for entry in self.entries.iter().rev() {
 816            match entry {
 817                AgentThreadEntry::UserMessage(..) => return false,
 818                AgentThreadEntry::AssistantMessage(..) => continue,
 819                AgentThreadEntry::ToolCall(..) => return true,
 820            }
 821        }
 822
 823        false
 824    }
 825
 826    pub fn handle_session_update(
 827        &mut self,
 828        update: acp::SessionUpdate,
 829        cx: &mut Context<Self>,
 830    ) -> Result<(), acp::Error> {
 831        match update {
 832            acp::SessionUpdate::UserMessageChunk { content } => {
 833                self.push_user_content_block(None, content, cx);
 834            }
 835            acp::SessionUpdate::AgentMessageChunk { content } => {
 836                self.push_assistant_content_block(content, false, cx);
 837            }
 838            acp::SessionUpdate::AgentThoughtChunk { content } => {
 839                self.push_assistant_content_block(content, true, cx);
 840            }
 841            acp::SessionUpdate::ToolCall(tool_call) => {
 842                self.upsert_tool_call(tool_call, cx)?;
 843            }
 844            acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
 845                self.update_tool_call(tool_call_update, cx)?;
 846            }
 847            acp::SessionUpdate::Plan(plan) => {
 848                self.update_plan(plan, cx);
 849            }
 850        }
 851        Ok(())
 852    }
 853
 854    pub fn push_user_content_block(
 855        &mut self,
 856        message_id: Option<UserMessageId>,
 857        chunk: acp::ContentBlock,
 858        cx: &mut Context<Self>,
 859    ) {
 860        let language_registry = self.project.read(cx).languages().clone();
 861        let entries_len = self.entries.len();
 862
 863        if let Some(last_entry) = self.entries.last_mut()
 864            && let AgentThreadEntry::UserMessage(UserMessage {
 865                id,
 866                content,
 867                chunks,
 868                ..
 869            }) = last_entry
 870        {
 871            *id = message_id.or(id.take());
 872            content.append(chunk.clone(), &language_registry, cx);
 873            chunks.push(chunk);
 874            let idx = entries_len - 1;
 875            cx.emit(AcpThreadEvent::EntryUpdated(idx));
 876        } else {
 877            let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
 878            self.push_entry(
 879                AgentThreadEntry::UserMessage(UserMessage {
 880                    id: message_id,
 881                    content,
 882                    chunks: vec![chunk],
 883                    checkpoint: None,
 884                }),
 885                cx,
 886            );
 887        }
 888    }
 889
 890    pub fn push_assistant_content_block(
 891        &mut self,
 892        chunk: acp::ContentBlock,
 893        is_thought: bool,
 894        cx: &mut Context<Self>,
 895    ) {
 896        let language_registry = self.project.read(cx).languages().clone();
 897        let entries_len = self.entries.len();
 898        if let Some(last_entry) = self.entries.last_mut()
 899            && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
 900        {
 901            let idx = entries_len - 1;
 902            cx.emit(AcpThreadEvent::EntryUpdated(idx));
 903            match (chunks.last_mut(), is_thought) {
 904                (Some(AssistantMessageChunk::Message { block }), false)
 905                | (Some(AssistantMessageChunk::Thought { block }), true) => {
 906                    block.append(chunk, &language_registry, cx)
 907                }
 908                _ => {
 909                    let block = ContentBlock::new(chunk, &language_registry, cx);
 910                    if is_thought {
 911                        chunks.push(AssistantMessageChunk::Thought { block })
 912                    } else {
 913                        chunks.push(AssistantMessageChunk::Message { block })
 914                    }
 915                }
 916            }
 917        } else {
 918            let block = ContentBlock::new(chunk, &language_registry, cx);
 919            let chunk = if is_thought {
 920                AssistantMessageChunk::Thought { block }
 921            } else {
 922                AssistantMessageChunk::Message { block }
 923            };
 924
 925            self.push_entry(
 926                AgentThreadEntry::AssistantMessage(AssistantMessage {
 927                    chunks: vec![chunk],
 928                }),
 929                cx,
 930            );
 931        }
 932    }
 933
 934    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
 935        self.entries.push(entry);
 936        cx.emit(AcpThreadEvent::NewEntry);
 937    }
 938
 939    pub fn update_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Result<()> {
 940        self.title = title;
 941        cx.emit(AcpThreadEvent::TitleUpdated);
 942        Ok(())
 943    }
 944
 945    pub fn update_tool_call(
 946        &mut self,
 947        update: impl Into<ToolCallUpdate>,
 948        cx: &mut Context<Self>,
 949    ) -> Result<()> {
 950        let update = update.into();
 951        let languages = self.project.read(cx).languages().clone();
 952
 953        let (ix, current_call) = self
 954            .tool_call_mut(update.id())
 955            .context("Tool call not found")?;
 956        match update {
 957            ToolCallUpdate::UpdateFields(update) => {
 958                let location_updated = update.fields.locations.is_some();
 959                current_call.update_fields(update.fields, languages, cx);
 960                if location_updated {
 961                    self.resolve_locations(update.id.clone(), cx);
 962                }
 963            }
 964            ToolCallUpdate::UpdateDiff(update) => {
 965                current_call.content.clear();
 966                current_call
 967                    .content
 968                    .push(ToolCallContent::Diff(update.diff));
 969            }
 970            ToolCallUpdate::UpdateTerminal(update) => {
 971                current_call.content.clear();
 972                current_call
 973                    .content
 974                    .push(ToolCallContent::Terminal(update.terminal));
 975            }
 976        }
 977
 978        cx.emit(AcpThreadEvent::EntryUpdated(ix));
 979
 980        Ok(())
 981    }
 982
 983    /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
 984    pub fn upsert_tool_call(
 985        &mut self,
 986        tool_call: acp::ToolCall,
 987        cx: &mut Context<Self>,
 988    ) -> Result<(), acp::Error> {
 989        let status = tool_call.status.into();
 990        self.upsert_tool_call_inner(tool_call.into(), status, cx)
 991    }
 992
 993    /// Fails if id does not match an existing entry.
 994    pub fn upsert_tool_call_inner(
 995        &mut self,
 996        tool_call_update: acp::ToolCallUpdate,
 997        status: ToolCallStatus,
 998        cx: &mut Context<Self>,
 999    ) -> Result<(), acp::Error> {
1000        let language_registry = self.project.read(cx).languages().clone();
1001        let id = tool_call_update.id.clone();
1002
1003        if let Some((ix, current_call)) = self.tool_call_mut(&id) {
1004            current_call.update_fields(tool_call_update.fields, language_registry, cx);
1005            current_call.status = status;
1006
1007            cx.emit(AcpThreadEvent::EntryUpdated(ix));
1008        } else {
1009            let call =
1010                ToolCall::from_acp(tool_call_update.try_into()?, status, language_registry, cx);
1011            self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1012        };
1013
1014        self.resolve_locations(id, cx);
1015        Ok(())
1016    }
1017
1018    fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1019        // The tool call we are looking for is typically the last one, or very close to the end.
1020        // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1021        self.entries
1022            .iter_mut()
1023            .enumerate()
1024            .rev()
1025            .find_map(|(index, tool_call)| {
1026                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1027                    && &tool_call.id == id
1028                {
1029                    Some((index, tool_call))
1030                } else {
1031                    None
1032                }
1033            })
1034    }
1035
1036    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1037        let project = self.project.clone();
1038        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1039            return;
1040        };
1041        let task = tool_call.resolve_locations(project, cx);
1042        cx.spawn(async move |this, cx| {
1043            let resolved_locations = task.await;
1044            this.update(cx, |this, cx| {
1045                let project = this.project.clone();
1046                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1047                    return;
1048                };
1049                if let Some(Some(location)) = resolved_locations.last() {
1050                    project.update(cx, |project, cx| {
1051                        if let Some(agent_location) = project.agent_location() {
1052                            let should_ignore = agent_location.buffer == location.buffer
1053                                && location
1054                                    .buffer
1055                                    .update(cx, |buffer, _| {
1056                                        let snapshot = buffer.snapshot();
1057                                        let old_position =
1058                                            agent_location.position.to_point(&snapshot);
1059                                        let new_position = location.position.to_point(&snapshot);
1060                                        // ignore this so that when we get updates from the edit tool
1061                                        // the position doesn't reset to the startof line
1062                                        old_position.row == new_position.row
1063                                            && old_position.column > new_position.column
1064                                    })
1065                                    .ok()
1066                                    .unwrap_or_default();
1067                            if !should_ignore {
1068                                project.set_agent_location(Some(location.clone()), cx);
1069                            }
1070                        }
1071                    });
1072                }
1073                if tool_call.resolved_locations != resolved_locations {
1074                    tool_call.resolved_locations = resolved_locations;
1075                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
1076                }
1077            })
1078        })
1079        .detach();
1080    }
1081
1082    pub fn request_tool_call_authorization(
1083        &mut self,
1084        tool_call: acp::ToolCallUpdate,
1085        options: Vec<acp::PermissionOption>,
1086        cx: &mut Context<Self>,
1087    ) -> Result<oneshot::Receiver<acp::PermissionOptionId>, acp::Error> {
1088        let (tx, rx) = oneshot::channel();
1089
1090        let status = ToolCallStatus::WaitingForConfirmation {
1091            options,
1092            respond_tx: tx,
1093        };
1094
1095        self.upsert_tool_call_inner(tool_call, status, cx)?;
1096        cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1097        Ok(rx)
1098    }
1099
1100    pub fn authorize_tool_call(
1101        &mut self,
1102        id: acp::ToolCallId,
1103        option_id: acp::PermissionOptionId,
1104        option_kind: acp::PermissionOptionKind,
1105        cx: &mut Context<Self>,
1106    ) {
1107        let Some((ix, call)) = self.tool_call_mut(&id) else {
1108            return;
1109        };
1110
1111        let new_status = match option_kind {
1112            acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1113                ToolCallStatus::Rejected
1114            }
1115            acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1116                ToolCallStatus::InProgress
1117            }
1118        };
1119
1120        let curr_status = mem::replace(&mut call.status, new_status);
1121
1122        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1123            respond_tx.send(option_id).log_err();
1124        } else if cfg!(debug_assertions) {
1125            panic!("tried to authorize an already authorized tool call");
1126        }
1127
1128        cx.emit(AcpThreadEvent::EntryUpdated(ix));
1129    }
1130
1131    /// Returns true if the last turn is awaiting tool authorization
1132    pub fn waiting_for_tool_confirmation(&self) -> bool {
1133        for entry in self.entries.iter().rev() {
1134            match &entry {
1135                AgentThreadEntry::ToolCall(call) => match call.status {
1136                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
1137                    ToolCallStatus::Pending
1138                    | ToolCallStatus::InProgress
1139                    | ToolCallStatus::Completed
1140                    | ToolCallStatus::Failed
1141                    | ToolCallStatus::Rejected
1142                    | ToolCallStatus::Canceled => continue,
1143                },
1144                AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1145                    // Reached the beginning of the turn
1146                    return false;
1147                }
1148            }
1149        }
1150        false
1151    }
1152
1153    pub fn plan(&self) -> &Plan {
1154        &self.plan
1155    }
1156
1157    pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1158        let new_entries_len = request.entries.len();
1159        let mut new_entries = request.entries.into_iter();
1160
1161        // Reuse existing markdown to prevent flickering
1162        for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1163            let PlanEntry {
1164                content,
1165                priority,
1166                status,
1167            } = old;
1168            content.update(cx, |old, cx| {
1169                old.replace(new.content, cx);
1170            });
1171            *priority = new.priority;
1172            *status = new.status;
1173        }
1174        for new in new_entries {
1175            self.plan.entries.push(PlanEntry::from_acp(new, cx))
1176        }
1177        self.plan.entries.truncate(new_entries_len);
1178
1179        cx.notify();
1180    }
1181
1182    fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1183        self.plan
1184            .entries
1185            .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1186        cx.notify();
1187    }
1188
1189    #[cfg(any(test, feature = "test-support"))]
1190    pub fn send_raw(
1191        &mut self,
1192        message: &str,
1193        cx: &mut Context<Self>,
1194    ) -> BoxFuture<'static, Result<()>> {
1195        self.send(
1196            vec![acp::ContentBlock::Text(acp::TextContent {
1197                text: message.to_string(),
1198                annotations: None,
1199            })],
1200            cx,
1201        )
1202    }
1203
    /// Sends `message` to the agent as a new user turn.
    ///
    /// The turn is executed through `run_turn`, which first cancels any turn
    /// that is still running. Inside the turn:
    /// 1. the message is appended to the thread as a `UserMessage`;
    /// 2. a git checkpoint of the project is taken and attached to the latest
    ///    user message (with `show: false`);
    /// 3. the prompt is forwarded to the connection.
    ///
    /// If taking the checkpoint fails, the error is logged and the message is
    /// left without a checkpoint. The returned future resolves when the turn
    /// finishes.
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            cx,
        );
        let request = acp::PromptRequest {
            prompt: message.clone(),
            session_id: self.session_id.clone(),
        };
        let git_store = self.project.read(cx).git_store().clone();

        // Messages only get an id when the connection can edit (truncate) its
        // session history; `rewind` relies on that id.
        let message_id = if self
            .connection
            .session_editor(&self.session_id, cx)
            .is_some()
        {
            Some(UserMessageId::new())
        } else {
            None
        };

        self.run_turn(cx, async move |this, cx| {
            this.update(cx, |this, cx| {
                this.push_entry(
                    AgentThreadEntry::UserMessage(UserMessage {
                        id: message_id.clone(),
                        content: block,
                        chunks: message,
                        checkpoint: None,
                    }),
                    cx,
                );
            })
            .ok();

            // Snapshot the project state before the agent acts on this message,
            // so a later rewind can restore it.
            let old_checkpoint = git_store
                .update(cx, |git, cx| git.checkpoint(cx))?
                .await
                .context("failed to get old checkpoint")
                .log_err();
            this.update(cx, |this, cx| {
                if let Some((_ix, message)) = this.last_user_message() {
                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
                        git_checkpoint,
                        show: false,
                    });
                }
                this.connection.prompt(message_id, request, cx)
            })?
            .await
        })
    }
1261
1262    pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1263        self.run_turn(cx, async move |this, cx| {
1264            this.update(cx, |this, cx| {
1265                this.connection
1266                    .resume(&this.session_id, cx)
1267                    .map(|resume| resume.run(cx))
1268            })?
1269            .context("resuming a session is not supported")?
1270            .await
1271        })
1272    }
1273
    /// Runs one agent turn, produced by `f`, in place of any turn already in
    /// flight.
    ///
    /// Setup:
    /// - Completed plan entries are cleared.
    /// - The current turn (if any) is canceled.
    /// - The new turn's task is stored in `send_task`, so a later `cancel` can
    ///   stop it. That task waits for the canceled turn to finish before
    ///   invoking `f`.
    ///
    /// The returned future resolves after the turn completes:
    /// 1. the visibility of the last user message's checkpoint is refreshed;
    /// 2. if `f` returned an error, `AcpThreadEvent::Error` is emitted and the
    ///    error is returned;
    /// 3. otherwise `AcpThreadEvent::Stopped` is emitted.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<()>> {
        self.clear_completed_plan_entries(cx);

        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            // Don't start the new turn until the previous one has wound down.
            cancel_task.await;
            tx.send(f(this, cx).await).ok();
        }));

        cx.spawn(async move |this, cx| {
            let response = rx.await;

            // NOTE(review): if refreshing the checkpoint fails, this returns early
            // without emitting Stopped/Error or clearing `send_task` — confirm that
            // is intended.
            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                match response {
                    // NOTE(review): unlike the success path below, this always takes
                    // `send_task`, even when this turn was canceled because a newer
                    // one started. If an agent reports cancellation as an error, this
                    // could drop the newer turn's task — confirm.
                    Ok(Err(e)) => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        Err(e)
                    }
                    result => {
                        let canceled = matches!(
                            result,
                            Ok(Ok(acp::PromptResponse {
                                stop_reason: acp::StopReason::Canceled
                            }))
                        );

                        // We only take the task if the current prompt wasn't canceled.
                        //
                        // This prompt may have been canceled because another one was sent
                        // while it was still generating. In these cases, dropping `send_task`
                        // would cause the next generation to be canceled.
                        if !canceled {
                            this.send_task.take();
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(())
                    }
                }
            })?
        })
        .boxed()
    }
1327
1328    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1329        let Some(send_task) = self.send_task.take() else {
1330            return Task::ready(());
1331        };
1332
1333        for entry in self.entries.iter_mut() {
1334            if let AgentThreadEntry::ToolCall(call) = entry {
1335                let cancel = matches!(
1336                    call.status,
1337                    ToolCallStatus::Pending
1338                        | ToolCallStatus::WaitingForConfirmation { .. }
1339                        | ToolCallStatus::InProgress
1340                );
1341
1342                if cancel {
1343                    call.status = ToolCallStatus::Canceled;
1344                }
1345            }
1346        }
1347
1348        self.connection.cancel(&self.session_id, cx);
1349
1350        // Wait for the send task to complete
1351        cx.foreground_executor().spawn(send_task)
1352    }
1353
    /// Rewinds this thread to just before the user message identified by `id`,
    /// removing that message and every entry after it.
    ///
    /// Steps:
    /// 1. If the message recorded a git checkpoint, the project is restored to
    ///    it, reverting changes made from that point on.
    /// 2. The connection's session editor truncates its history at `id`.
    /// 3. The local entries are truncated and
    ///    `AcpThreadEvent::EntriesRemoved` is emitted.
    ///
    /// # Errors
    /// Returns an error immediately in two cases:
    /// - the connection has no session editor ("not supported");
    /// - no user message has `id` ("message not found").
    ///
    /// Failures from restoring the checkpoint or from truncating the session
    /// are propagated through the returned task.
    pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
        let Some(session_editor) = self.connection.session_editor(&self.session_id, cx) else {
            return Task::ready(Err(anyhow!("not supported")));
        };
        let Some(message) = self.user_message(&id) else {
            return Task::ready(Err(anyhow!("message not found")));
        };

        let checkpoint = message
            .checkpoint
            .as_ref()
            .map(|c| c.git_checkpoint.clone());

        let git_store = self.project.read(cx).git_store().clone();
        cx.spawn(async move |this, cx| {
            if let Some(checkpoint) = checkpoint {
                git_store
                    .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
                    .await?;
            }

            cx.update(|cx| session_editor.truncate(id.clone(), cx))?
                .await?;
            this.update(cx, |this, cx| {
                // Look the message up again: entries may have changed while the
                // async steps above ran.
                if let Some((ix, _)) = this.user_message_mut(&id) {
                    let range = ix..this.entries.len();
                    this.entries.truncate(ix);
                    cx.emit(AcpThreadEvent::EntriesRemoved(range));
                }
            })
        })
    }
1388
1389    fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1390        let git_store = self.project.read(cx).git_store().clone();
1391
1392        let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1393            if let Some(checkpoint) = message.checkpoint.as_ref() {
1394                checkpoint.git_checkpoint.clone()
1395            } else {
1396                return Task::ready(Ok(()));
1397            }
1398        } else {
1399            return Task::ready(Ok(()));
1400        };
1401
1402        let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1403        cx.spawn(async move |this, cx| {
1404            let new_checkpoint = new_checkpoint
1405                .await
1406                .context("failed to get new checkpoint")
1407                .log_err();
1408            if let Some(new_checkpoint) = new_checkpoint {
1409                let equal = git_store
1410                    .update(cx, |git, cx| {
1411                        git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1412                    })?
1413                    .await
1414                    .unwrap_or(true);
1415                this.update(cx, |this, cx| {
1416                    let (ix, message) = this.last_user_message().context("no user message")?;
1417                    let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1418                    checkpoint.show = !equal;
1419                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
1420                    anyhow::Ok(())
1421                })??;
1422            }
1423
1424            Ok(())
1425        })
1426    }
1427
1428    fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1429        self.entries
1430            .iter_mut()
1431            .enumerate()
1432            .rev()
1433            .find_map(|(ix, entry)| {
1434                if let AgentThreadEntry::UserMessage(message) = entry {
1435                    Some((ix, message))
1436                } else {
1437                    None
1438                }
1439            })
1440    }
1441
1442    fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
1443        self.entries.iter().find_map(|entry| {
1444            if let AgentThreadEntry::UserMessage(message) = entry {
1445                if message.id.as_ref() == Some(&id) {
1446                    Some(message)
1447                } else {
1448                    None
1449                }
1450            } else {
1451                None
1452            }
1453        })
1454    }
1455
1456    fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1457        self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1458            if let AgentThreadEntry::UserMessage(message) = entry {
1459                if message.id.as_ref() == Some(&id) {
1460                    Some((ix, message))
1461                } else {
1462                    None
1463                }
1464            } else {
1465                None
1466            }
1467        })
1468    }
1469
1470    pub fn read_text_file(
1471        &self,
1472        path: PathBuf,
1473        line: Option<u32>,
1474        limit: Option<u32>,
1475        reuse_shared_snapshot: bool,
1476        cx: &mut Context<Self>,
1477    ) -> Task<Result<String>> {
1478        let project = self.project.clone();
1479        let action_log = self.action_log.clone();
1480        cx.spawn(async move |this, cx| {
1481            let load = project.update(cx, |project, cx| {
1482                let path = project
1483                    .project_path_for_absolute_path(&path, cx)
1484                    .context("invalid path")?;
1485                anyhow::Ok(project.open_buffer(path, cx))
1486            });
1487            let buffer = load??.await?;
1488
1489            let snapshot = if reuse_shared_snapshot {
1490                this.read_with(cx, |this, _| {
1491                    this.shared_buffers.get(&buffer.clone()).cloned()
1492                })
1493                .log_err()
1494                .flatten()
1495            } else {
1496                None
1497            };
1498
1499            let snapshot = if let Some(snapshot) = snapshot {
1500                snapshot
1501            } else {
1502                action_log.update(cx, |action_log, cx| {
1503                    action_log.buffer_read(buffer.clone(), cx);
1504                })?;
1505                project.update(cx, |project, cx| {
1506                    let position = buffer
1507                        .read(cx)
1508                        .snapshot()
1509                        .anchor_before(Point::new(line.unwrap_or_default(), 0));
1510                    project.set_agent_location(
1511                        Some(AgentLocation {
1512                            buffer: buffer.downgrade(),
1513                            position,
1514                        }),
1515                        cx,
1516                    );
1517                })?;
1518
1519                buffer.update(cx, |buffer, _| buffer.snapshot())?
1520            };
1521
1522            this.update(cx, |this, _| {
1523                let text = snapshot.text();
1524                this.shared_buffers.insert(buffer.clone(), snapshot);
1525                if line.is_none() && limit.is_none() {
1526                    return Ok(text);
1527                }
1528                let limit = limit.unwrap_or(u32::MAX) as usize;
1529                let Some(line) = line else {
1530                    return Ok(text.lines().take(limit).collect::<String>());
1531                };
1532
1533                let count = text.lines().count();
1534                if count < line as usize {
1535                    anyhow::bail!("There are only {} lines", count);
1536                }
1537                Ok(text
1538                    .lines()
1539                    .skip(line as usize + 1)
1540                    .take(limit)
1541                    .collect::<String>())
1542            })?
1543        })
1544    }
1545
    /// Replaces the contents of the file at `path` with `content` on behalf of
    /// the agent, then saves the buffer.
    ///
    /// The change is applied as a minimal diff, computed on the background
    /// executor, against one of two baselines:
    /// - the snapshot stored in `shared_buffers` (by `read_text_file`), if any;
    /// - otherwise the buffer's current contents.
    ///
    /// Before editing, the agent location is moved to the end of the last edit
    /// (or to the start of the buffer when there are no edits). The action log
    /// is told the buffer was read before the edit and edited after it.
    ///
    /// # Errors
    /// - `path` is not inside the project.
    /// - The buffer can't be opened or saved.
    /// - The thread entity is no longer alive.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Prefer the snapshot the agent last read, so the diff reflects what
            // the agent intended to change.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    // Anchors are created against `snapshot`, so the edits can be
                    // applied to the buffer even if it changed since then.
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;
            cx.update(|cx| {
                project.update(cx, |project, cx| {
                    project.set_agent_location(
                        Some(AgentLocation {
                            buffer: buffer.downgrade(),
                            position: edits
                                .last()
                                .map(|(range, _)| range.end)
                                .unwrap_or(Anchor::MIN),
                        }),
                        cx,
                    );
                });

                // NOTE(review): presumably the read is registered first so the
                // action log has a baseline to attribute the edit against — confirm.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });
                buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
            })?;
            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1613
1614    pub fn to_markdown(&self, cx: &App) -> String {
1615        self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1616    }
1617
1618    pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
1619        cx.emit(AcpThreadEvent::ServerExited(status));
1620    }
1621}
1622
1623fn markdown_for_raw_output(
1624    raw_output: &serde_json::Value,
1625    language_registry: &Arc<LanguageRegistry>,
1626    cx: &mut App,
1627) -> Option<Entity<Markdown>> {
1628    match raw_output {
1629        serde_json::Value::Null => None,
1630        serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1631            Markdown::new(
1632                value.to_string().into(),
1633                Some(language_registry.clone()),
1634                None,
1635                cx,
1636            )
1637        })),
1638        serde_json::Value::Number(value) => Some(cx.new(|cx| {
1639            Markdown::new(
1640                value.to_string().into(),
1641                Some(language_registry.clone()),
1642                None,
1643                cx,
1644            )
1645        })),
1646        serde_json::Value::String(value) => Some(cx.new(|cx| {
1647            Markdown::new(
1648                value.clone().into(),
1649                Some(language_registry.clone()),
1650                None,
1651                cx,
1652            )
1653        })),
1654        value => Some(cx.new(|cx| {
1655            Markdown::new(
1656                format!("```json\n{}\n```", value).into(),
1657                Some(language_registry.clone()),
1658                None,
1659                cx,
1660            )
1661        })),
1662    }
1663}
1664
1665#[cfg(test)]
1666mod tests {
1667    use super::*;
1668    use anyhow::anyhow;
1669    use futures::{channel::mpsc, future::LocalBoxFuture, select};
1670    use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
1671    use indoc::indoc;
1672    use project::{FakeFs, Fs};
1673    use rand::Rng as _;
1674    use serde_json::json;
1675    use settings::SettingsStore;
1676    use smol::stream::StreamExt as _;
1677    use std::{
1678        any::Any,
1679        cell::RefCell,
1680        path::Path,
1681        rc::Rc,
1682        sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
1683        time::Duration,
1684    };
1685    use util::path;
1686
1687    fn init_test(cx: &mut TestAppContext) {
1688        env_logger::try_init().ok();
1689        cx.update(|cx| {
1690            let settings_store = SettingsStore::test(cx);
1691            cx.set_global(settings_store);
1692            Project::init_settings(cx);
1693            language::init(cx);
1694        });
1695    }
1696
1697    #[gpui::test]
1698    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1699        init_test(cx);
1700
1701        let fs = FakeFs::new(cx.executor());
1702        let project = Project::test(fs, [], cx).await;
1703        let connection = Rc::new(FakeAgentConnection::new());
1704        let thread = cx
1705            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
1706            .await
1707            .unwrap();
1708
1709        // Test creating a new user message
1710        thread.update(cx, |thread, cx| {
1711            thread.push_user_content_block(
1712                None,
1713                acp::ContentBlock::Text(acp::TextContent {
1714                    annotations: None,
1715                    text: "Hello, ".to_string(),
1716                }),
1717                cx,
1718            );
1719        });
1720
1721        thread.update(cx, |thread, cx| {
1722            assert_eq!(thread.entries.len(), 1);
1723            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1724                assert_eq!(user_msg.id, None);
1725                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1726            } else {
1727                panic!("Expected UserMessage");
1728            }
1729        });
1730
1731        // Test appending to existing user message
1732        let message_1_id = UserMessageId::new();
1733        thread.update(cx, |thread, cx| {
1734            thread.push_user_content_block(
1735                Some(message_1_id.clone()),
1736                acp::ContentBlock::Text(acp::TextContent {
1737                    annotations: None,
1738                    text: "world!".to_string(),
1739                }),
1740                cx,
1741            );
1742        });
1743
1744        thread.update(cx, |thread, cx| {
1745            assert_eq!(thread.entries.len(), 1);
1746            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1747                assert_eq!(user_msg.id, Some(message_1_id));
1748                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1749            } else {
1750                panic!("Expected UserMessage");
1751            }
1752        });
1753
1754        // Test creating new user message after assistant message
1755        thread.update(cx, |thread, cx| {
1756            thread.push_assistant_content_block(
1757                acp::ContentBlock::Text(acp::TextContent {
1758                    annotations: None,
1759                    text: "Assistant response".to_string(),
1760                }),
1761                false,
1762                cx,
1763            );
1764        });
1765
1766        let message_2_id = UserMessageId::new();
1767        thread.update(cx, |thread, cx| {
1768            thread.push_user_content_block(
1769                Some(message_2_id.clone()),
1770                acp::ContentBlock::Text(acp::TextContent {
1771                    annotations: None,
1772                    text: "New user message".to_string(),
1773                }),
1774                cx,
1775            );
1776        });
1777
1778        thread.update(cx, |thread, cx| {
1779            assert_eq!(thread.entries.len(), 3);
1780            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1781                assert_eq!(user_msg.id, Some(message_2_id));
1782                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1783            } else {
1784                panic!("Expected UserMessage at index 2");
1785            }
1786        });
1787    }
1788
1789    #[gpui::test]
1790    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
1791        init_test(cx);
1792
1793        let fs = FakeFs::new(cx.executor());
1794        let project = Project::test(fs, [], cx).await;
1795        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1796            |_, thread, mut cx| {
1797                async move {
1798                    thread.update(&mut cx, |thread, cx| {
1799                        thread
1800                            .handle_session_update(
1801                                acp::SessionUpdate::AgentThoughtChunk {
1802                                    content: "Thinking ".into(),
1803                                },
1804                                cx,
1805                            )
1806                            .unwrap();
1807                        thread
1808                            .handle_session_update(
1809                                acp::SessionUpdate::AgentThoughtChunk {
1810                                    content: "hard!".into(),
1811                                },
1812                                cx,
1813                            )
1814                            .unwrap();
1815                    })?;
1816                    Ok(acp::PromptResponse {
1817                        stop_reason: acp::StopReason::EndTurn,
1818                    })
1819                }
1820                .boxed_local()
1821            },
1822        ));
1823
1824        let thread = cx
1825            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
1826            .await
1827            .unwrap();
1828
1829        thread
1830            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
1831            .await
1832            .unwrap();
1833
1834        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
1835        assert_eq!(
1836            output,
1837            indoc! {r#"
1838            ## User
1839
1840            Hello from Zed!
1841
1842            ## Assistant
1843
1844            <thinking>
1845            Thinking hard!
1846            </thinking>
1847
1848            "#}
1849        );
1850    }
1851
    // The agent reads `/tmp/foo` and then writes an extended version of it,
    // while the test edits the same file through an open buffer after the
    // agent's read. Both the user's inserted line and the agent's appended
    // lines must end up in the buffer and on disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test once the agent has read the file. The oneshot sender
        // is consumed by `send`, but the prompt handler must be an `Fn` closure,
        // so the sender lives in a shared `Option` and is taken on first use.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    // Let the test make its edit now that the read has happened.
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        // Open `/tmp/foo` as a buffer so the test can edit it the way a user would.
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        // Wait until the agent has read the original contents, then insert a
        // line at the top of the buffer as the user.
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        // Run all pending tasks before checking buffer and disk contents.
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
1930
    // A tool call that `cancel` marked as `Canceled` can still be moved to
    // `Completed` by a later `ToolCallUpdate` from the agent.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        // Shared by the initial tool call and the late completion update below.
        let id = acp::ToolCallId("test".into());

        // The agent reports an in-progress fetch tool call and then ends its
        // turn without completing it.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // Entry 1 is the tool call (entry 0 presumably holds the user's prompt).
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::InProgress,
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // A completion update arriving after the cancellation is still applied.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Completed,
                    ..
                })
            ));
        });
    }
2030
2031    #[gpui::test]
2032    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
2033        init_test(cx);
2034        let fs = FakeFs::new(cx.background_executor.clone());
2035        fs.insert_tree(path!("/test"), json!({})).await;
2036        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
2037
2038        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2039            move |_, thread, mut cx| {
2040                async move {
2041                    thread
2042                        .update(&mut cx, |thread, cx| {
2043                            thread.handle_session_update(
2044                                acp::SessionUpdate::ToolCall(acp::ToolCall {
2045                                    id: acp::ToolCallId("test".into()),
2046                                    title: "Label".into(),
2047                                    kind: acp::ToolKind::Edit,
2048                                    status: acp::ToolCallStatus::Completed,
2049                                    content: vec![acp::ToolCallContent::Diff {
2050                                        diff: acp::Diff {
2051                                            path: "/test/test.txt".into(),
2052                                            old_text: None,
2053                                            new_text: "foo".into(),
2054                                        },
2055                                    }],
2056                                    locations: vec![],
2057                                    raw_input: None,
2058                                    raw_output: None,
2059                                }),
2060                                cx,
2061                            )
2062                        })
2063                        .unwrap()
2064                        .unwrap();
2065                    Ok(acp::PromptResponse {
2066                        stop_reason: acp::StopReason::EndTurn,
2067                    })
2068                }
2069                .boxed_local()
2070            }
2071        }));
2072
2073        let thread = cx
2074            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2075            .await
2076            .unwrap();
2077
2078        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
2079            .await
2080            .unwrap();
2081
2082        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
2083    }
2084
    // Exercises checkpoint bookkeeping across turns:
    // - a user message whose turn changed files renders as `## User (checkpoint)`;
    // - a turn that changes no files renders a plain `## User` header;
    // - rewinding to a user message removes it and everything after it, and
    //   restores the files on disk to their state before that message.
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        // The `.git` directory presumably gives the project a repository to
        // checkpoint — TODO confirm against the git store.
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        // While `simulate_changes` is set, every turn creates a new file
        // `file-<n>`, with `n` taken from the `next_filename` counter.
        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    if simulate_changes.load(SeqCst) {
                        // NOTE(review): unlike the tree root above, this path is not
                        // wrapped in `path!` — confirm the test behaves on Windows.
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    // Reply with the prompt's text upper-cased.
                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk {
                                    content: content.text.to_uppercase().into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // First turn creates file-0, so its user message shows a checkpoint.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new("/test/file-0")]);

        // Second turn creates file-1 and also shows a checkpoint.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![Path::new("/test/file-0"), Path::new("/test/file-1")]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                    ## User

                    dolor

                    ## Assistant

                    DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![Path::new("/test/file-0"), Path::new("/test/file-1")]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        // Per the markdown above, entries[2] is the "ipsum" user message.
        thread
            .update(cx, |thread, cx| {
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.rewind(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        // file-1 was created during the rewound turn, so it is gone again.
        assert_eq!(fs.files(), vec![Path::new("/test/file-0")]);
    }
2258
2259    async fn run_until_first_tool_call(
2260        thread: &Entity<AcpThread>,
2261        cx: &mut TestAppContext,
2262    ) -> usize {
2263        let (mut tx, mut rx) = mpsc::channel::<usize>(1);
2264
2265        let subscription = cx.update(|cx| {
2266            cx.subscribe(thread, move |thread, _, cx| {
2267                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
2268                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
2269                        return tx.try_send(ix).unwrap();
2270                    }
2271                }
2272            })
2273        });
2274
2275        select! {
2276            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
2277                panic!("Timeout waiting for tool call")
2278            }
2279            ix = rx.next().fuse() => {
2280                drop(subscription);
2281                ix.unwrap()
2282            }
2283        }
2284    }
2285
2286    #[derive(Clone, Default)]
2287    struct FakeAgentConnection {
2288        auth_methods: Vec<acp::AuthMethod>,
2289        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
2290        on_user_message: Option<
2291            Rc<
2292                dyn Fn(
2293                        acp::PromptRequest,
2294                        WeakEntity<AcpThread>,
2295                        AsyncApp,
2296                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2297                    + 'static,
2298            >,
2299        >,
2300    }
2301
2302    impl FakeAgentConnection {
2303        fn new() -> Self {
2304            Self {
2305                auth_methods: Vec::new(),
2306                on_user_message: None,
2307                sessions: Arc::default(),
2308            }
2309        }
2310
2311        #[expect(unused)]
2312        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
2313            self.auth_methods = auth_methods;
2314            self
2315        }
2316
2317        fn on_user_message(
2318            mut self,
2319            handler: impl Fn(
2320                acp::PromptRequest,
2321                WeakEntity<AcpThread>,
2322                AsyncApp,
2323            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2324            + 'static,
2325        ) -> Self {
2326            self.on_user_message.replace(Rc::new(handler));
2327            self
2328        }
2329    }
2330
2331    impl AgentConnection for FakeAgentConnection {
2332        fn auth_methods(&self) -> &[acp::AuthMethod] {
2333            &self.auth_methods
2334        }
2335
2336        fn new_thread(
2337            self: Rc<Self>,
2338            project: Entity<Project>,
2339            _cwd: &Path,
2340            cx: &mut App,
2341        ) -> Task<gpui::Result<Entity<AcpThread>>> {
2342            let session_id = acp::SessionId(
2343                rand::thread_rng()
2344                    .sample_iter(&rand::distributions::Alphanumeric)
2345                    .take(7)
2346                    .map(char::from)
2347                    .collect::<String>()
2348                    .into(),
2349            );
2350            let thread =
2351                cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
2352            self.sessions.lock().insert(session_id, thread.downgrade());
2353            Task::ready(Ok(thread))
2354        }
2355
2356        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
2357            if self.auth_methods().iter().any(|m| m.id == method) {
2358                Task::ready(Ok(()))
2359            } else {
2360                Task::ready(Err(anyhow!("Invalid Auth Method")))
2361            }
2362        }
2363
2364        fn prompt(
2365            &self,
2366            _id: Option<UserMessageId>,
2367            params: acp::PromptRequest,
2368            cx: &mut App,
2369        ) -> Task<gpui::Result<acp::PromptResponse>> {
2370            let sessions = self.sessions.lock();
2371            let thread = sessions.get(&params.session_id).unwrap();
2372            if let Some(handler) = &self.on_user_message {
2373                let handler = handler.clone();
2374                let thread = thread.clone();
2375                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
2376            } else {
2377                Task::ready(Ok(acp::PromptResponse {
2378                    stop_reason: acp::StopReason::EndTurn,
2379                }))
2380            }
2381        }
2382
2383        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
2384            let sessions = self.sessions.lock();
2385            let thread = sessions.get(&session_id).unwrap().clone();
2386
2387            cx.spawn(async move |cx| {
2388                thread
2389                    .update(cx, |thread, cx| thread.cancel(cx))
2390                    .unwrap()
2391                    .await
2392            })
2393            .detach();
2394        }
2395
2396        fn session_editor(
2397            &self,
2398            session_id: &acp::SessionId,
2399            _cx: &mut App,
2400        ) -> Option<Rc<dyn AgentSessionEditor>> {
2401            Some(Rc::new(FakeAgentSessionEditor {
2402                _session_id: session_id.clone(),
2403            }))
2404        }
2405
2406        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
2407            self
2408        }
2409    }
2410
    /// Session editor handed out by `FakeAgentConnection::session_editor`.
    struct FakeAgentSessionEditor {
        // Not read by any code visible here; the `_` prefix keeps the
        // dead-code lint quiet.
        _session_id: acp::SessionId,
    }

    impl AgentSessionEditor for FakeAgentSessionEditor {
        /// Accepts every truncation request and immediately reports success
        /// without doing anything.
        fn truncate(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
            Task::ready(Ok(()))
        }
    }
2420}