acp_thread.rs

   1mod connection;
   2mod diff;
   3mod terminal;
   4
   5pub use connection::*;
   6pub use diff::*;
   7pub use terminal::*;
   8
   9use action_log::ActionLog;
  10use agent_client_protocol as acp;
  11use anyhow::{Context as _, Result};
  12use editor::Bias;
  13use futures::{FutureExt, channel::oneshot, future::BoxFuture};
  14use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
  15use itertools::Itertools;
  16use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, text_diff};
  17use markdown::Markdown;
  18use project::{AgentLocation, Project};
  19use std::collections::HashMap;
  20use std::error::Error;
  21use std::fmt::Formatter;
  22use std::process::ExitStatus;
  23use std::rc::Rc;
  24use std::{
  25    fmt::Display,
  26    mem,
  27    path::{Path, PathBuf},
  28    sync::Arc,
  29};
  30use ui::App;
  31use util::ResultExt;
  32
  33#[derive(Debug)]
  34pub struct UserMessage {
  35    pub content: ContentBlock,
  36}
  37
  38impl UserMessage {
  39    pub fn from_acp(
  40        message: impl IntoIterator<Item = acp::ContentBlock>,
  41        language_registry: Arc<LanguageRegistry>,
  42        cx: &mut App,
  43    ) -> Self {
  44        let mut content = ContentBlock::Empty;
  45        for chunk in message {
  46            content.append(chunk, &language_registry, cx)
  47        }
  48        Self { content: content }
  49    }
  50
  51    fn to_markdown(&self, cx: &App) -> String {
  52        format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
  53    }
  54}
  55
  56#[derive(Debug)]
  57pub struct MentionPath<'a>(&'a Path);
  58
  59impl<'a> MentionPath<'a> {
  60    const PREFIX: &'static str = "@file:";
  61
  62    pub fn new(path: &'a Path) -> Self {
  63        MentionPath(path)
  64    }
  65
  66    pub fn try_parse(url: &'a str) -> Option<Self> {
  67        let path = url.strip_prefix(Self::PREFIX)?;
  68        Some(MentionPath(Path::new(path)))
  69    }
  70
  71    pub fn path(&self) -> &Path {
  72        self.0
  73    }
  74}
  75
  76impl Display for MentionPath<'_> {
  77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  78        write!(
  79            f,
  80            "[@{}]({}{})",
  81            self.0.file_name().unwrap_or_default().display(),
  82            Self::PREFIX,
  83            self.0.display()
  84        )
  85    }
  86}
  87
  88#[derive(Debug, PartialEq)]
  89pub struct AssistantMessage {
  90    pub chunks: Vec<AssistantMessageChunk>,
  91}
  92
  93impl AssistantMessage {
  94    pub fn to_markdown(&self, cx: &App) -> String {
  95        format!(
  96            "## Assistant\n\n{}\n\n",
  97            self.chunks
  98                .iter()
  99                .map(|chunk| chunk.to_markdown(cx))
 100                .join("\n\n")
 101        )
 102    }
 103}
 104
 105#[derive(Debug, PartialEq)]
 106pub enum AssistantMessageChunk {
 107    Message { block: ContentBlock },
 108    Thought { block: ContentBlock },
 109}
 110
 111impl AssistantMessageChunk {
 112    pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
 113        Self::Message {
 114            block: ContentBlock::new(chunk.into(), language_registry, cx),
 115        }
 116    }
 117
 118    fn to_markdown(&self, cx: &App) -> String {
 119        match self {
 120            Self::Message { block } => block.to_markdown(cx).to_string(),
 121            Self::Thought { block } => {
 122                format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
 123            }
 124        }
 125    }
 126}
 127
 128#[derive(Debug)]
 129pub enum AgentThreadEntry {
 130    UserMessage(UserMessage),
 131    AssistantMessage(AssistantMessage),
 132    ToolCall(ToolCall),
 133}
 134
 135impl AgentThreadEntry {
 136    fn to_markdown(&self, cx: &App) -> String {
 137        match self {
 138            Self::UserMessage(message) => message.to_markdown(cx),
 139            Self::AssistantMessage(message) => message.to_markdown(cx),
 140            Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
 141        }
 142    }
 143
 144    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 145        if let AgentThreadEntry::ToolCall(call) = self {
 146            itertools::Either::Left(call.diffs())
 147        } else {
 148            itertools::Either::Right(std::iter::empty())
 149        }
 150    }
 151
 152    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 153        if let AgentThreadEntry::ToolCall(call) = self {
 154            itertools::Either::Left(call.terminals())
 155        } else {
 156            itertools::Either::Right(std::iter::empty())
 157        }
 158    }
 159
 160    pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
 161        if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
 162            Some(locations)
 163        } else {
 164            None
 165        }
 166    }
 167}
 168
 169#[derive(Debug)]
 170pub struct ToolCall {
 171    pub id: acp::ToolCallId,
 172    pub label: Entity<Markdown>,
 173    pub kind: acp::ToolKind,
 174    pub content: Vec<ToolCallContent>,
 175    pub status: ToolCallStatus,
 176    pub locations: Vec<acp::ToolCallLocation>,
 177    pub raw_input: Option<serde_json::Value>,
 178    pub raw_output: Option<serde_json::Value>,
 179}
 180
 181impl ToolCall {
 182    fn from_acp(
 183        tool_call: acp::ToolCall,
 184        status: ToolCallStatus,
 185        language_registry: Arc<LanguageRegistry>,
 186        cx: &mut App,
 187    ) -> Self {
 188        Self {
 189            id: tool_call.id,
 190            label: cx.new(|cx| {
 191                Markdown::new(
 192                    tool_call.title.into(),
 193                    Some(language_registry.clone()),
 194                    None,
 195                    cx,
 196                )
 197            }),
 198            kind: tool_call.kind,
 199            content: tool_call
 200                .content
 201                .into_iter()
 202                .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
 203                .collect(),
 204            locations: tool_call.locations,
 205            status,
 206            raw_input: tool_call.raw_input,
 207            raw_output: tool_call.raw_output,
 208        }
 209    }
 210
 211    fn update_fields(
 212        &mut self,
 213        fields: acp::ToolCallUpdateFields,
 214        language_registry: Arc<LanguageRegistry>,
 215        cx: &mut App,
 216    ) {
 217        let acp::ToolCallUpdateFields {
 218            kind,
 219            status,
 220            title,
 221            content,
 222            locations,
 223            raw_input,
 224            raw_output,
 225        } = fields;
 226
 227        if let Some(kind) = kind {
 228            self.kind = kind;
 229        }
 230
 231        if let Some(status) = status {
 232            self.status = ToolCallStatus::Allowed { status };
 233        }
 234
 235        if let Some(title) = title {
 236            self.label.update(cx, |label, cx| {
 237                label.replace(title, cx);
 238            });
 239        }
 240
 241        if let Some(content) = content {
 242            self.content = content
 243                .into_iter()
 244                .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
 245                .collect();
 246        }
 247
 248        if let Some(locations) = locations {
 249            self.locations = locations;
 250        }
 251
 252        if let Some(raw_input) = raw_input {
 253            self.raw_input = Some(raw_input);
 254        }
 255
 256        if let Some(raw_output) = raw_output {
 257            self.raw_output = Some(raw_output);
 258        }
 259    }
 260
 261    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 262        self.content.iter().filter_map(|content| match content {
 263            ToolCallContent::Diff(diff) => Some(diff),
 264            ToolCallContent::ContentBlock(_) => None,
 265            ToolCallContent::Terminal(_) => None,
 266        })
 267    }
 268
 269    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 270        self.content.iter().filter_map(|content| match content {
 271            ToolCallContent::Terminal(terminal) => Some(terminal),
 272            ToolCallContent::ContentBlock(_) => None,
 273            ToolCallContent::Diff(_) => None,
 274        })
 275    }
 276
 277    fn to_markdown(&self, cx: &App) -> String {
 278        let mut markdown = format!(
 279            "**Tool Call: {}**\nStatus: {}\n\n",
 280            self.label.read(cx).source(),
 281            self.status
 282        );
 283        for content in &self.content {
 284            markdown.push_str(content.to_markdown(cx).as_str());
 285            markdown.push_str("\n\n");
 286        }
 287        markdown
 288    }
 289}
 290
 291#[derive(Debug)]
 292pub enum ToolCallStatus {
 293    WaitingForConfirmation {
 294        options: Vec<acp::PermissionOption>,
 295        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
 296    },
 297    Allowed {
 298        status: acp::ToolCallStatus,
 299    },
 300    Rejected,
 301    Canceled,
 302}
 303
 304impl Display for ToolCallStatus {
 305    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 306        write!(
 307            f,
 308            "{}",
 309            match self {
 310                ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
 311                ToolCallStatus::Allowed { status } => match status {
 312                    acp::ToolCallStatus::Pending => "Pending",
 313                    acp::ToolCallStatus::InProgress => "In Progress",
 314                    acp::ToolCallStatus::Completed => "Completed",
 315                    acp::ToolCallStatus::Failed => "Failed",
 316                },
 317                ToolCallStatus::Rejected => "Rejected",
 318                ToolCallStatus::Canceled => "Canceled",
 319            }
 320        )
 321    }
 322}
 323
 324#[derive(Debug, PartialEq, Clone)]
 325pub enum ContentBlock {
 326    Empty,
 327    Markdown { markdown: Entity<Markdown> },
 328}
 329
 330impl ContentBlock {
 331    pub fn new(
 332        block: acp::ContentBlock,
 333        language_registry: &Arc<LanguageRegistry>,
 334        cx: &mut App,
 335    ) -> Self {
 336        let mut this = Self::Empty;
 337        this.append(block, language_registry, cx);
 338        this
 339    }
 340
 341    pub fn new_combined(
 342        blocks: impl IntoIterator<Item = acp::ContentBlock>,
 343        language_registry: Arc<LanguageRegistry>,
 344        cx: &mut App,
 345    ) -> Self {
 346        let mut this = Self::Empty;
 347        for block in blocks {
 348            this.append(block, &language_registry, cx);
 349        }
 350        this
 351    }
 352
 353    pub fn append(
 354        &mut self,
 355        block: acp::ContentBlock,
 356        language_registry: &Arc<LanguageRegistry>,
 357        cx: &mut App,
 358    ) {
 359        let new_content = match block {
 360            acp::ContentBlock::Text(text_content) => text_content.text.clone(),
 361            acp::ContentBlock::ResourceLink(resource_link) => {
 362                if let Some(path) = resource_link.uri.strip_prefix("file://") {
 363                    format!("{}", MentionPath(path.as_ref()))
 364                } else {
 365                    resource_link.uri.clone()
 366                }
 367            }
 368            acp::ContentBlock::Image(_)
 369            | acp::ContentBlock::Audio(_)
 370            | acp::ContentBlock::Resource(_) => String::new(),
 371        };
 372
 373        match self {
 374            ContentBlock::Empty => {
 375                *self = ContentBlock::Markdown {
 376                    markdown: cx.new(|cx| {
 377                        Markdown::new(
 378                            new_content.into(),
 379                            Some(language_registry.clone()),
 380                            None,
 381                            cx,
 382                        )
 383                    }),
 384                };
 385            }
 386            ContentBlock::Markdown { markdown } => {
 387                markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
 388            }
 389        }
 390    }
 391
 392    fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
 393        match self {
 394            ContentBlock::Empty => "",
 395            ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
 396        }
 397    }
 398
 399    pub fn markdown(&self) -> Option<&Entity<Markdown>> {
 400        match self {
 401            ContentBlock::Empty => None,
 402            ContentBlock::Markdown { markdown } => Some(markdown),
 403        }
 404    }
 405}
 406
 407#[derive(Debug)]
 408pub enum ToolCallContent {
 409    ContentBlock(ContentBlock),
 410    Diff(Entity<Diff>),
 411    Terminal(Entity<Terminal>),
 412}
 413
 414impl ToolCallContent {
 415    pub fn from_acp(
 416        content: acp::ToolCallContent,
 417        language_registry: Arc<LanguageRegistry>,
 418        cx: &mut App,
 419    ) -> Self {
 420        match content {
 421            acp::ToolCallContent::Content { content } => {
 422                Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
 423            }
 424            acp::ToolCallContent::Diff { diff } => {
 425                Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx)))
 426            }
 427        }
 428    }
 429
 430    pub fn to_markdown(&self, cx: &App) -> String {
 431        match self {
 432            Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
 433            Self::Diff(diff) => diff.read(cx).to_markdown(cx),
 434            Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
 435        }
 436    }
 437}
 438
 439#[derive(Debug, PartialEq)]
 440pub enum ToolCallUpdate {
 441    UpdateFields(acp::ToolCallUpdate),
 442    UpdateDiff(ToolCallUpdateDiff),
 443    UpdateTerminal(ToolCallUpdateTerminal),
 444}
 445
 446impl ToolCallUpdate {
 447    fn id(&self) -> &acp::ToolCallId {
 448        match self {
 449            Self::UpdateFields(update) => &update.id,
 450            Self::UpdateDiff(diff) => &diff.id,
 451            Self::UpdateTerminal(terminal) => &terminal.id,
 452        }
 453    }
 454}
 455
 456impl From<acp::ToolCallUpdate> for ToolCallUpdate {
 457    fn from(update: acp::ToolCallUpdate) -> Self {
 458        Self::UpdateFields(update)
 459    }
 460}
 461
 462impl From<ToolCallUpdateDiff> for ToolCallUpdate {
 463    fn from(diff: ToolCallUpdateDiff) -> Self {
 464        Self::UpdateDiff(diff)
 465    }
 466}
 467
 468#[derive(Debug, PartialEq)]
 469pub struct ToolCallUpdateDiff {
 470    pub id: acp::ToolCallId,
 471    pub diff: Entity<Diff>,
 472}
 473
 474impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
 475    fn from(terminal: ToolCallUpdateTerminal) -> Self {
 476        Self::UpdateTerminal(terminal)
 477    }
 478}
 479
 480#[derive(Debug, PartialEq)]
 481pub struct ToolCallUpdateTerminal {
 482    pub id: acp::ToolCallId,
 483    pub terminal: Entity<Terminal>,
 484}
 485
 486#[derive(Debug, Default)]
 487pub struct Plan {
 488    pub entries: Vec<PlanEntry>,
 489}
 490
 491#[derive(Debug)]
 492pub struct PlanStats<'a> {
 493    pub in_progress_entry: Option<&'a PlanEntry>,
 494    pub pending: u32,
 495    pub completed: u32,
 496}
 497
 498impl Plan {
 499    pub fn is_empty(&self) -> bool {
 500        self.entries.is_empty()
 501    }
 502
 503    pub fn stats(&self) -> PlanStats<'_> {
 504        let mut stats = PlanStats {
 505            in_progress_entry: None,
 506            pending: 0,
 507            completed: 0,
 508        };
 509
 510        for entry in &self.entries {
 511            match &entry.status {
 512                acp::PlanEntryStatus::Pending => {
 513                    stats.pending += 1;
 514                }
 515                acp::PlanEntryStatus::InProgress => {
 516                    stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
 517                }
 518                acp::PlanEntryStatus::Completed => {
 519                    stats.completed += 1;
 520                }
 521            }
 522        }
 523
 524        stats
 525    }
 526}
 527
 528#[derive(Debug)]
 529pub struct PlanEntry {
 530    pub content: Entity<Markdown>,
 531    pub priority: acp::PlanEntryPriority,
 532    pub status: acp::PlanEntryStatus,
 533}
 534
 535impl PlanEntry {
 536    pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
 537        Self {
 538            content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
 539            priority: entry.priority,
 540            status: entry.status,
 541        }
 542    }
 543}
 544
 545pub struct AcpThread {
 546    title: SharedString,
 547    entries: Vec<AgentThreadEntry>,
 548    plan: Plan,
 549    project: Entity<Project>,
 550    action_log: Entity<ActionLog>,
 551    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
 552    send_task: Option<Task<()>>,
 553    connection: Rc<dyn AgentConnection>,
 554    session_id: acp::SessionId,
 555}
 556
 557pub enum AcpThreadEvent {
 558    NewEntry,
 559    EntryUpdated(usize),
 560    ToolAuthorizationRequired,
 561    Stopped,
 562    Error,
 563    ServerExited(ExitStatus),
 564}
 565
 566impl EventEmitter<AcpThreadEvent> for AcpThread {}
 567
 568#[derive(PartialEq, Eq)]
 569pub enum ThreadStatus {
 570    Idle,
 571    WaitingForToolConfirmation,
 572    Generating,
 573}
 574
 575#[derive(Debug, Clone)]
 576pub enum LoadError {
 577    Unsupported {
 578        error_message: SharedString,
 579        upgrade_message: SharedString,
 580        upgrade_command: String,
 581    },
 582    Exited(i32),
 583    Other(SharedString),
 584}
 585
 586impl Display for LoadError {
 587    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 588        match self {
 589            LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
 590            LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
 591            LoadError::Other(msg) => write!(f, "{}", msg),
 592        }
 593    }
 594}
 595
 596impl Error for LoadError {}
 597
 598impl AcpThread {
 599    pub fn new(
 600        title: impl Into<SharedString>,
 601        connection: Rc<dyn AgentConnection>,
 602        project: Entity<Project>,
 603        session_id: acp::SessionId,
 604        cx: &mut Context<Self>,
 605    ) -> Self {
 606        let action_log = cx.new(|_| ActionLog::new(project.clone()));
 607
 608        Self {
 609            action_log,
 610            shared_buffers: Default::default(),
 611            entries: Default::default(),
 612            plan: Default::default(),
 613            title: title.into(),
 614            project,
 615            send_task: None,
 616            connection,
 617            session_id,
 618        }
 619    }
 620
 621    pub fn action_log(&self) -> &Entity<ActionLog> {
 622        &self.action_log
 623    }
 624
 625    pub fn project(&self) -> &Entity<Project> {
 626        &self.project
 627    }
 628
 629    pub fn title(&self) -> SharedString {
 630        self.title.clone()
 631    }
 632
 633    pub fn entries(&self) -> &[AgentThreadEntry] {
 634        &self.entries
 635    }
 636
 637    pub fn session_id(&self) -> &acp::SessionId {
 638        &self.session_id
 639    }
 640
 641    pub fn status(&self) -> ThreadStatus {
 642        if self.send_task.is_some() {
 643            if self.waiting_for_tool_confirmation() {
 644                ThreadStatus::WaitingForToolConfirmation
 645            } else {
 646                ThreadStatus::Generating
 647            }
 648        } else {
 649            ThreadStatus::Idle
 650        }
 651    }
 652
 653    pub fn has_pending_edit_tool_calls(&self) -> bool {
 654        for entry in self.entries.iter().rev() {
 655            match entry {
 656                AgentThreadEntry::UserMessage(_) => return false,
 657                AgentThreadEntry::ToolCall(
 658                    call @ ToolCall {
 659                        status:
 660                            ToolCallStatus::Allowed {
 661                                status:
 662                                    acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending,
 663                            },
 664                        ..
 665                    },
 666                ) if call.diffs().next().is_some() => {
 667                    return true;
 668                }
 669                AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
 670            }
 671        }
 672
 673        false
 674    }
 675
 676    pub fn used_tools_since_last_user_message(&self) -> bool {
 677        for entry in self.entries.iter().rev() {
 678            match entry {
 679                AgentThreadEntry::UserMessage(..) => return false,
 680                AgentThreadEntry::AssistantMessage(..) => continue,
 681                AgentThreadEntry::ToolCall(..) => return true,
 682            }
 683        }
 684
 685        false
 686    }
 687
 688    pub fn handle_session_update(
 689        &mut self,
 690        update: acp::SessionUpdate,
 691        cx: &mut Context<Self>,
 692    ) -> Result<()> {
 693        match update {
 694            acp::SessionUpdate::UserMessageChunk { content } => {
 695                self.push_user_content_block(content, cx);
 696            }
 697            acp::SessionUpdate::AgentMessageChunk { content } => {
 698                self.push_assistant_content_block(content, false, cx);
 699            }
 700            acp::SessionUpdate::AgentThoughtChunk { content } => {
 701                self.push_assistant_content_block(content, true, cx);
 702            }
 703            acp::SessionUpdate::ToolCall(tool_call) => {
 704                self.upsert_tool_call(tool_call, cx);
 705            }
 706            acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
 707                self.update_tool_call(tool_call_update, cx)?;
 708            }
 709            acp::SessionUpdate::Plan(plan) => {
 710                self.update_plan(plan, cx);
 711            }
 712        }
 713        Ok(())
 714    }
 715
 716    pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
 717        let language_registry = self.project.read(cx).languages().clone();
 718        let entries_len = self.entries.len();
 719
 720        if let Some(last_entry) = self.entries.last_mut()
 721            && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
 722        {
 723            content.append(chunk, &language_registry, cx);
 724            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
 725        } else {
 726            let content = ContentBlock::new(chunk, &language_registry, cx);
 727            self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
 728        }
 729    }
 730
 731    pub fn push_assistant_content_block(
 732        &mut self,
 733        chunk: acp::ContentBlock,
 734        is_thought: bool,
 735        cx: &mut Context<Self>,
 736    ) {
 737        let language_registry = self.project.read(cx).languages().clone();
 738        let entries_len = self.entries.len();
 739        if let Some(last_entry) = self.entries.last_mut()
 740            && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
 741        {
 742            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
 743            match (chunks.last_mut(), is_thought) {
 744                (Some(AssistantMessageChunk::Message { block }), false)
 745                | (Some(AssistantMessageChunk::Thought { block }), true) => {
 746                    block.append(chunk, &language_registry, cx)
 747                }
 748                _ => {
 749                    let block = ContentBlock::new(chunk, &language_registry, cx);
 750                    if is_thought {
 751                        chunks.push(AssistantMessageChunk::Thought { block })
 752                    } else {
 753                        chunks.push(AssistantMessageChunk::Message { block })
 754                    }
 755                }
 756            }
 757        } else {
 758            let block = ContentBlock::new(chunk, &language_registry, cx);
 759            let chunk = if is_thought {
 760                AssistantMessageChunk::Thought { block }
 761            } else {
 762                AssistantMessageChunk::Message { block }
 763            };
 764
 765            self.push_entry(
 766                AgentThreadEntry::AssistantMessage(AssistantMessage {
 767                    chunks: vec![chunk],
 768                }),
 769                cx,
 770            );
 771        }
 772    }
 773
 774    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
 775        self.entries.push(entry);
 776        cx.emit(AcpThreadEvent::NewEntry);
 777    }
 778
 779    pub fn update_tool_call(
 780        &mut self,
 781        update: impl Into<ToolCallUpdate>,
 782        cx: &mut Context<Self>,
 783    ) -> Result<()> {
 784        let update = update.into();
 785        let languages = self.project.read(cx).languages().clone();
 786
 787        let (ix, current_call) = self
 788            .tool_call_mut(update.id())
 789            .context("Tool call not found")?;
 790        match update {
 791            ToolCallUpdate::UpdateFields(update) => {
 792                current_call.update_fields(update.fields, languages, cx);
 793            }
 794            ToolCallUpdate::UpdateDiff(update) => {
 795                current_call.content.clear();
 796                current_call
 797                    .content
 798                    .push(ToolCallContent::Diff(update.diff));
 799            }
 800            ToolCallUpdate::UpdateTerminal(update) => {
 801                current_call.content.clear();
 802                current_call
 803                    .content
 804                    .push(ToolCallContent::Terminal(update.terminal));
 805            }
 806        }
 807
 808        cx.emit(AcpThreadEvent::EntryUpdated(ix));
 809
 810        Ok(())
 811    }
 812
 813    /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
 814    pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
 815        let status = ToolCallStatus::Allowed {
 816            status: tool_call.status,
 817        };
 818        self.upsert_tool_call_inner(tool_call, status, cx)
 819    }
 820
 821    pub fn upsert_tool_call_inner(
 822        &mut self,
 823        tool_call: acp::ToolCall,
 824        status: ToolCallStatus,
 825        cx: &mut Context<Self>,
 826    ) {
 827        let language_registry = self.project.read(cx).languages().clone();
 828        let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
 829
 830        let location = call.locations.last().cloned();
 831
 832        if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
 833            *current_call = call;
 834
 835            cx.emit(AcpThreadEvent::EntryUpdated(ix));
 836        } else {
 837            self.push_entry(AgentThreadEntry::ToolCall(call), cx);
 838        }
 839
 840        if let Some(location) = location {
 841            self.set_project_location(location, cx)
 842        }
 843    }
 844
 845    fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
 846        // The tool call we are looking for is typically the last one, or very close to the end.
 847        // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
 848        self.entries
 849            .iter_mut()
 850            .enumerate()
 851            .rev()
 852            .find_map(|(index, tool_call)| {
 853                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
 854                    && &tool_call.id == id
 855                {
 856                    Some((index, tool_call))
 857                } else {
 858                    None
 859                }
 860            })
 861    }
 862
 863    pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context<Self>) {
 864        self.project.update(cx, |project, cx| {
 865            let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
 866                return;
 867            };
 868            let buffer = project.open_buffer(path, cx);
 869            cx.spawn(async move |project, cx| {
 870                let buffer = buffer.await?;
 871
 872                project.update(cx, |project, cx| {
 873                    let position = if let Some(line) = location.line {
 874                        let snapshot = buffer.read(cx).snapshot();
 875                        let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
 876                        snapshot.anchor_before(point)
 877                    } else {
 878                        Anchor::MIN
 879                    };
 880
 881                    project.set_agent_location(
 882                        Some(AgentLocation {
 883                            buffer: buffer.downgrade(),
 884                            position,
 885                        }),
 886                        cx,
 887                    );
 888                })
 889            })
 890            .detach_and_log_err(cx);
 891        });
 892    }
 893
 894    pub fn request_tool_call_authorization(
 895        &mut self,
 896        tool_call: acp::ToolCall,
 897        options: Vec<acp::PermissionOption>,
 898        cx: &mut Context<Self>,
 899    ) -> oneshot::Receiver<acp::PermissionOptionId> {
 900        let (tx, rx) = oneshot::channel();
 901
 902        let status = ToolCallStatus::WaitingForConfirmation {
 903            options,
 904            respond_tx: tx,
 905        };
 906
 907        self.upsert_tool_call_inner(tool_call, status, cx);
 908        cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
 909        rx
 910    }
 911
 912    pub fn authorize_tool_call(
 913        &mut self,
 914        id: acp::ToolCallId,
 915        option_id: acp::PermissionOptionId,
 916        option_kind: acp::PermissionOptionKind,
 917        cx: &mut Context<Self>,
 918    ) {
 919        let Some((ix, call)) = self.tool_call_mut(&id) else {
 920            return;
 921        };
 922
 923        let new_status = match option_kind {
 924            acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
 925                ToolCallStatus::Rejected
 926            }
 927            acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
 928                ToolCallStatus::Allowed {
 929                    status: acp::ToolCallStatus::InProgress,
 930                }
 931            }
 932        };
 933
 934        let curr_status = mem::replace(&mut call.status, new_status);
 935
 936        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
 937            respond_tx.send(option_id).log_err();
 938        } else if cfg!(debug_assertions) {
 939            panic!("tried to authorize an already authorized tool call");
 940        }
 941
 942        cx.emit(AcpThreadEvent::EntryUpdated(ix));
 943    }
 944
 945    /// Returns true if the last turn is awaiting tool authorization
 946    pub fn waiting_for_tool_confirmation(&self) -> bool {
 947        for entry in self.entries.iter().rev() {
 948            match &entry {
 949                AgentThreadEntry::ToolCall(call) => match call.status {
 950                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
 951                    ToolCallStatus::Allowed { .. }
 952                    | ToolCallStatus::Rejected
 953                    | ToolCallStatus::Canceled => continue,
 954                },
 955                AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
 956                    // Reached the beginning of the turn
 957                    return false;
 958                }
 959            }
 960        }
 961        false
 962    }
 963
 964    pub fn plan(&self) -> &Plan {
 965        &self.plan
 966    }
 967
 968    pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
 969        let new_entries_len = request.entries.len();
 970        let mut new_entries = request.entries.into_iter();
 971
 972        // Reuse existing markdown to prevent flickering
 973        for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
 974            let PlanEntry {
 975                content,
 976                priority,
 977                status,
 978            } = old;
 979            content.update(cx, |old, cx| {
 980                old.replace(new.content, cx);
 981            });
 982            *priority = new.priority;
 983            *status = new.status;
 984        }
 985        for new in new_entries {
 986            self.plan.entries.push(PlanEntry::from_acp(new, cx))
 987        }
 988        self.plan.entries.truncate(new_entries_len);
 989
 990        cx.notify();
 991    }
 992
 993    fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
 994        self.plan
 995            .entries
 996            .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
 997        cx.notify();
 998    }
 999
1000    #[cfg(any(test, feature = "test-support"))]
1001    pub fn send_raw(
1002        &mut self,
1003        message: &str,
1004        cx: &mut Context<Self>,
1005    ) -> BoxFuture<'static, Result<()>> {
1006        self.send(
1007            vec![acp::ContentBlock::Text(acp::TextContent {
1008                text: message.to_string(),
1009                annotations: None,
1010            })],
1011            cx,
1012        )
1013    }
1014
    /// Sends `message` to the agent as a new user turn.
    ///
    /// The message is appended to the thread as a `UserMessage` entry, and completed
    /// plan entries are cleared. Any in-flight prompt is cancelled through
    /// [`Self::cancel`]; that cancellation is awaited before the new prompt is issued
    /// on the connection.
    ///
    /// The returned future resolves once the prompt finishes:
    /// - If the connection reports an error, `AcpThreadEvent::Error` is emitted and the
    ///   error is returned.
    /// - Otherwise `AcpThreadEvent::Stopped` is emitted and `Ok(())` is returned.
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            cx,
        );
        self.push_entry(
            AgentThreadEntry::UserMessage(UserMessage { content: block }),
            cx,
        );
        self.clear_completed_plan_entries(cx);

        // `tx` hands the connection's prompt result from `send_task` to the future
        // returned at the end of this function.
        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            async {
                // Let any previously running prompt finish cancelling first.
                cancel_task.await;

                let result = this
                    .update(cx, |this, cx| {
                        this.connection.prompt(
                            acp::PromptRequest {
                                prompt: message,
                                session_id: this.session_id.clone(),
                            },
                            cx,
                        )
                    })?
                    .await;

                tx.send(result).log_err();

                anyhow::Ok(())
            }
            .await
            .log_err();
        }));

        cx.spawn(async move |this, cx| match rx.await {
            Ok(Err(e)) => {
                this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error))
                    .log_err();
                Err(e)?
            }
            // A prompt response (cancelled or not), or `Canceled` when `tx` was dropped
            // before a result was sent.
            result => {
                let cancelled = matches!(
                    result,
                    Ok(Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::Cancelled
                    }))
                );

                // We only take the task if the current prompt wasn't cancelled.
                //
                // This prompt may have been cancelled because another one was sent
                // while it was still generating. In these cases, dropping `send_task`
                // would cause the next generation to be cancelled.
                if !cancelled {
                    this.update(cx, |this, _cx| this.send_task.take()).ok();
                }

                this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
                    .log_err();
                Ok(())
            }
        })
        .boxed()
    }
1088
1089    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1090        let Some(send_task) = self.send_task.take() else {
1091            return Task::ready(());
1092        };
1093
1094        for entry in self.entries.iter_mut() {
1095            if let AgentThreadEntry::ToolCall(call) = entry {
1096                let cancel = matches!(
1097                    call.status,
1098                    ToolCallStatus::WaitingForConfirmation { .. }
1099                        | ToolCallStatus::Allowed {
1100                            status: acp::ToolCallStatus::InProgress
1101                        }
1102                );
1103
1104                if cancel {
1105                    call.status = ToolCallStatus::Canceled;
1106                }
1107            }
1108        }
1109
1110        self.connection.cancel(&self.session_id, cx);
1111
1112        // Wait for the send task to complete
1113        cx.foreground_executor().spawn(send_task)
1114    }
1115
1116    pub fn read_text_file(
1117        &self,
1118        path: PathBuf,
1119        line: Option<u32>,
1120        limit: Option<u32>,
1121        reuse_shared_snapshot: bool,
1122        cx: &mut Context<Self>,
1123    ) -> Task<Result<String>> {
1124        let project = self.project.clone();
1125        let action_log = self.action_log.clone();
1126        cx.spawn(async move |this, cx| {
1127            let load = project.update(cx, |project, cx| {
1128                let path = project
1129                    .project_path_for_absolute_path(&path, cx)
1130                    .context("invalid path")?;
1131                anyhow::Ok(project.open_buffer(path, cx))
1132            });
1133            let buffer = load??.await?;
1134
1135            let snapshot = if reuse_shared_snapshot {
1136                this.read_with(cx, |this, _| {
1137                    this.shared_buffers.get(&buffer.clone()).cloned()
1138                })
1139                .log_err()
1140                .flatten()
1141            } else {
1142                None
1143            };
1144
1145            let snapshot = if let Some(snapshot) = snapshot {
1146                snapshot
1147            } else {
1148                action_log.update(cx, |action_log, cx| {
1149                    action_log.buffer_read(buffer.clone(), cx);
1150                })?;
1151                project.update(cx, |project, cx| {
1152                    let position = buffer
1153                        .read(cx)
1154                        .snapshot()
1155                        .anchor_before(Point::new(line.unwrap_or_default(), 0));
1156                    project.set_agent_location(
1157                        Some(AgentLocation {
1158                            buffer: buffer.downgrade(),
1159                            position,
1160                        }),
1161                        cx,
1162                    );
1163                })?;
1164
1165                buffer.update(cx, |buffer, _| buffer.snapshot())?
1166            };
1167
1168            this.update(cx, |this, _| {
1169                let text = snapshot.text();
1170                this.shared_buffers.insert(buffer.clone(), snapshot);
1171                if line.is_none() && limit.is_none() {
1172                    return Ok(text);
1173                }
1174                let limit = limit.unwrap_or(u32::MAX) as usize;
1175                let Some(line) = line else {
1176                    return Ok(text.lines().take(limit).collect::<String>());
1177                };
1178
1179                let count = text.lines().count();
1180                if count < line as usize {
1181                    anyhow::bail!("There are only {} lines", count);
1182                }
1183                Ok(text
1184                    .lines()
1185                    .skip(line as usize + 1)
1186                    .take(limit)
1187                    .collect::<String>())
1188            })?
1189        })
1190    }
1191
    /// Replaces the contents of the file at `path` with `content` on behalf of the agent.
    ///
    /// Steps:
    /// 1. Compute minimal edits with `text_diff`, against the snapshot previously shared
    ///    with the agent (see `shared_buffers`), or the buffer's current snapshot when
    ///    none was shared. Edits are anchored in that snapshot.
    /// 2. Move the agent location to the end of the last edit.
    /// 3. Record the read and the edit in the action log, apply the edits, and save
    ///    the buffer.
    ///
    /// # Errors
    ///
    /// Fails if the path does not resolve to a project path, the buffer cannot be
    /// opened, the thread or app has been released, or saving fails.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Diff against what the agent last read. Edits the user made to the buffer
            // since then survive, because the anchors below are created in that older
            // snapshot.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Compute the edits on the background executor.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;
            cx.update(|cx| {
                // Point the agent at the end of the last edit, or at the start of the
                // buffer when nothing changed.
                project.update(cx, |project, cx| {
                    project.set_agent_location(
                        Some(AgentLocation {
                            buffer: buffer.downgrade(),
                            position: edits
                                .last()
                                .map(|(range, _)| range.end)
                                .unwrap_or(Anchor::MIN),
                        }),
                        cx,
                    );
                });

                // Record the read before the edits are applied, then record the edit.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });
                buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
            })?;
            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1259
1260    pub fn to_markdown(&self, cx: &App) -> String {
1261        self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1262    }
1263
1264    pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
1265        cx.emit(AcpThreadEvent::ServerExited(status));
1266    }
1267}
1268
1269#[cfg(test)]
1270mod tests {
1271    use super::*;
1272    use anyhow::anyhow;
1273    use futures::{channel::mpsc, future::LocalBoxFuture, select};
1274    use gpui::{AsyncApp, TestAppContext, WeakEntity};
1275    use indoc::indoc;
1276    use project::FakeFs;
1277    use rand::Rng as _;
1278    use serde_json::json;
1279    use settings::SettingsStore;
1280    use smol::stream::StreamExt as _;
1281    use std::{cell::RefCell, rc::Rc, time::Duration};
1282
1283    use util::path;
1284
1285    fn init_test(cx: &mut TestAppContext) {
1286        env_logger::try_init().ok();
1287        cx.update(|cx| {
1288            let settings_store = SettingsStore::test(cx);
1289            cx.set_global(settings_store);
1290            Project::init_settings(cx);
1291            language::init(cx);
1292        });
1293    }
1294
    /// Verifies how user content blocks are grouped:
    /// - consecutive user blocks merge into a single `UserMessage` entry;
    /// - a user block pushed after an assistant message starts a new entry.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Hello, ".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "world!".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Assistant response".to_string(),
                }),
                false,
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "New user message".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }
1382
    /// Verifies that consecutive `AgentThoughtChunk` updates are concatenated into a
    /// single `<thinking>` section of the thread's Markdown export.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        // The fake agent answers every prompt with two thought chunks, then ends the turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "Thinking ".into(),
                                },
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "hard!".into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
1449
    /// Verifies that an agent write based on an earlier read preserves a user edit made
    /// in between. The sequence is:
    /// 1. The agent reads `/tmp/foo`.
    /// 2. The user prepends `zero`.
    /// 3. The agent writes the extended count.
    ///
    /// Both the buffer and the file on disk must end up containing both changes.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test body once the agent's read has completed.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        // The user edit happens after the agent's read but before its write.
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
1528
    /// Verifies that a tool call marked `Canceled` by `AcpThread::cancel` can still be
    /// moved to `Allowed { Completed }` by a later `ToolCallUpdate` from the agent.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        // The fake agent reports one in-progress fetch tool call per prompt.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::InProgress,
                        ..
                    },
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // The agent later reports that the canceled call actually completed.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::Completed,
                        ..
                    },
                    ..
                })
            ));
        });
    }
1638
1639    #[gpui::test]
1640    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
1641        init_test(cx);
1642        let fs = FakeFs::new(cx.background_executor.clone());
1643        fs.insert_tree(path!("/test"), json!({})).await;
1644        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
1645
1646        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
1647            move |_, thread, mut cx| {
1648                async move {
1649                    thread
1650                        .update(&mut cx, |thread, cx| {
1651                            thread.handle_session_update(
1652                                acp::SessionUpdate::ToolCall(acp::ToolCall {
1653                                    id: acp::ToolCallId("test".into()),
1654                                    title: "Label".into(),
1655                                    kind: acp::ToolKind::Edit,
1656                                    status: acp::ToolCallStatus::Completed,
1657                                    content: vec![acp::ToolCallContent::Diff {
1658                                        diff: acp::Diff {
1659                                            path: "/test/test.txt".into(),
1660                                            old_text: None,
1661                                            new_text: "foo".into(),
1662                                        },
1663                                    }],
1664                                    locations: vec![],
1665                                    raw_input: None,
1666                                    raw_output: None,
1667                                }),
1668                                cx,
1669                            )
1670                        })
1671                        .unwrap()
1672                        .unwrap();
1673                    Ok(acp::PromptResponse {
1674                        stop_reason: acp::StopReason::EndTurn,
1675                    })
1676                }
1677                .boxed_local()
1678            }
1679        }));
1680
1681        let thread = connection
1682            .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
1683            .await
1684            .unwrap();
1685        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
1686            .await
1687            .unwrap();
1688
1689        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
1690    }
1691
1692    async fn run_until_first_tool_call(
1693        thread: &Entity<AcpThread>,
1694        cx: &mut TestAppContext,
1695    ) -> usize {
1696        let (mut tx, mut rx) = mpsc::channel::<usize>(1);
1697
1698        let subscription = cx.update(|cx| {
1699            cx.subscribe(thread, move |thread, _, cx| {
1700                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
1701                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
1702                        return tx.try_send(ix).unwrap();
1703                    }
1704                }
1705            })
1706        });
1707
1708        select! {
1709            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
1710                panic!("Timeout waiting for tool call")
1711            }
1712            ix = rx.next().fuse() => {
1713                drop(subscription);
1714                ix.unwrap()
1715            }
1716        }
1717    }
1718
1719    #[derive(Clone, Default)]
1720    struct FakeAgentConnection {
1721        auth_methods: Vec<acp::AuthMethod>,
1722        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
1723        on_user_message: Option<
1724            Rc<
1725                dyn Fn(
1726                        acp::PromptRequest,
1727                        WeakEntity<AcpThread>,
1728                        AsyncApp,
1729                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1730                    + 'static,
1731            >,
1732        >,
1733    }
1734
1735    impl FakeAgentConnection {
1736        fn new() -> Self {
1737            Self {
1738                auth_methods: Vec::new(),
1739                on_user_message: None,
1740                sessions: Arc::default(),
1741            }
1742        }
1743
1744        #[expect(unused)]
1745        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
1746            self.auth_methods = auth_methods;
1747            self
1748        }
1749
1750        fn on_user_message(
1751            mut self,
1752            handler: impl Fn(
1753                acp::PromptRequest,
1754                WeakEntity<AcpThread>,
1755                AsyncApp,
1756            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1757            + 'static,
1758        ) -> Self {
1759            self.on_user_message.replace(Rc::new(handler));
1760            self
1761        }
1762    }
1763
1764    impl AgentConnection for FakeAgentConnection {
1765        fn auth_methods(&self) -> &[acp::AuthMethod] {
1766            &self.auth_methods
1767        }
1768
1769        fn new_thread(
1770            self: Rc<Self>,
1771            project: Entity<Project>,
1772            _cwd: &Path,
1773            cx: &mut gpui::AsyncApp,
1774        ) -> Task<gpui::Result<Entity<AcpThread>>> {
1775            let session_id = acp::SessionId(
1776                rand::thread_rng()
1777                    .sample_iter(&rand::distributions::Alphanumeric)
1778                    .take(7)
1779                    .map(char::from)
1780                    .collect::<String>()
1781                    .into(),
1782            );
1783            let thread = cx
1784                .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
1785                .unwrap();
1786            self.sessions.lock().insert(session_id, thread.downgrade());
1787            Task::ready(Ok(thread))
1788        }
1789
1790        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
1791            if self.auth_methods().iter().any(|m| m.id == method) {
1792                Task::ready(Ok(()))
1793            } else {
1794                Task::ready(Err(anyhow!("Invalid Auth Method")))
1795            }
1796        }
1797
1798        fn prompt(
1799            &self,
1800            params: acp::PromptRequest,
1801            cx: &mut App,
1802        ) -> Task<gpui::Result<acp::PromptResponse>> {
1803            let sessions = self.sessions.lock();
1804            let thread = sessions.get(&params.session_id).unwrap();
1805            if let Some(handler) = &self.on_user_message {
1806                let handler = handler.clone();
1807                let thread = thread.clone();
1808                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
1809            } else {
1810                Task::ready(Ok(acp::PromptResponse {
1811                    stop_reason: acp::StopReason::EndTurn,
1812                }))
1813            }
1814        }
1815
1816        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1817            let sessions = self.sessions.lock();
1818            let thread = sessions.get(&session_id).unwrap().clone();
1819
1820            cx.spawn(async move |cx| {
1821                thread
1822                    .update(cx, |thread, cx| thread.cancel(cx))
1823                    .unwrap()
1824                    .await
1825            })
1826            .detach();
1827        }
1828    }
1829}