acp_thread.rs

   1mod connection;
   2pub use connection::*;
   3
   4use agent_client_protocol as acp;
   5use anyhow::{Context as _, Result};
   6use assistant_tool::ActionLog;
   7use buffer_diff::BufferDiff;
   8use editor::{Bias, MultiBuffer, PathKey};
   9use futures::future::{Fuse, FusedFuture};
  10use futures::{FutureExt, channel::oneshot, future::BoxFuture};
  11use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
  12use itertools::Itertools;
  13use language::{
  14    Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point,
  15    text_diff,
  16};
  17use markdown::Markdown;
  18use project::{AgentLocation, Project};
  19use std::collections::HashMap;
  20use std::error::Error;
  21use std::fmt::Formatter;
  22use std::process::ExitStatus;
  23use std::rc::Rc;
  24use std::{
  25    fmt::Display,
  26    mem,
  27    path::{Path, PathBuf},
  28    sync::Arc,
  29};
  30use ui::App;
  31use util::ResultExt;
  32
  33#[derive(Debug)]
  34pub struct UserMessage {
  35    pub content: ContentBlock,
  36}
  37
  38impl UserMessage {
  39    pub fn from_acp(
  40        message: impl IntoIterator<Item = acp::ContentBlock>,
  41        language_registry: Arc<LanguageRegistry>,
  42        cx: &mut App,
  43    ) -> Self {
  44        let mut content = ContentBlock::Empty;
  45        for chunk in message {
  46            content.append(chunk, &language_registry, cx)
  47        }
  48        Self { content: content }
  49    }
  50
  51    fn to_markdown(&self, cx: &App) -> String {
  52        format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
  53    }
  54}
  55
  56#[derive(Debug)]
  57pub struct MentionPath<'a>(&'a Path);
  58
  59impl<'a> MentionPath<'a> {
  60    const PREFIX: &'static str = "@file:";
  61
  62    pub fn new(path: &'a Path) -> Self {
  63        MentionPath(path)
  64    }
  65
  66    pub fn try_parse(url: &'a str) -> Option<Self> {
  67        let path = url.strip_prefix(Self::PREFIX)?;
  68        Some(MentionPath(Path::new(path)))
  69    }
  70
  71    pub fn path(&self) -> &Path {
  72        self.0
  73    }
  74}
  75
  76impl Display for MentionPath<'_> {
  77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  78        write!(
  79            f,
  80            "[@{}]({}{})",
  81            self.0.file_name().unwrap_or_default().display(),
  82            Self::PREFIX,
  83            self.0.display()
  84        )
  85    }
  86}
  87
  88#[derive(Debug, PartialEq)]
  89pub struct AssistantMessage {
  90    pub chunks: Vec<AssistantMessageChunk>,
  91}
  92
  93impl AssistantMessage {
  94    pub fn to_markdown(&self, cx: &App) -> String {
  95        format!(
  96            "## Assistant\n\n{}\n\n",
  97            self.chunks
  98                .iter()
  99                .map(|chunk| chunk.to_markdown(cx))
 100                .join("\n\n")
 101        )
 102    }
 103}
 104
 105#[derive(Debug, PartialEq)]
 106pub enum AssistantMessageChunk {
 107    Message { block: ContentBlock },
 108    Thought { block: ContentBlock },
 109}
 110
 111impl AssistantMessageChunk {
 112    pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
 113        Self::Message {
 114            block: ContentBlock::new(chunk.into(), language_registry, cx),
 115        }
 116    }
 117
 118    fn to_markdown(&self, cx: &App) -> String {
 119        match self {
 120            Self::Message { block } => block.to_markdown(cx).to_string(),
 121            Self::Thought { block } => {
 122                format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
 123            }
 124        }
 125    }
 126}
 127
 128#[derive(Debug)]
 129pub enum AgentThreadEntry {
 130    UserMessage(UserMessage),
 131    AssistantMessage(AssistantMessage),
 132    ToolCall(ToolCall),
 133}
 134
 135impl AgentThreadEntry {
 136    fn to_markdown(&self, cx: &App) -> String {
 137        match self {
 138            Self::UserMessage(message) => message.to_markdown(cx),
 139            Self::AssistantMessage(message) => message.to_markdown(cx),
 140            Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
 141        }
 142    }
 143
 144    pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
 145        if let AgentThreadEntry::ToolCall(call) = self {
 146            itertools::Either::Left(call.diffs())
 147        } else {
 148            itertools::Either::Right(std::iter::empty())
 149        }
 150    }
 151
 152    pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
 153        if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
 154            Some(locations)
 155        } else {
 156            None
 157        }
 158    }
 159}
 160
 161#[derive(Debug)]
 162pub struct ToolCall {
 163    pub id: acp::ToolCallId,
 164    pub label: Entity<Markdown>,
 165    pub kind: acp::ToolKind,
 166    pub content: Vec<ToolCallContent>,
 167    pub status: ToolCallStatus,
 168    pub locations: Vec<acp::ToolCallLocation>,
 169    pub raw_input: Option<serde_json::Value>,
 170}
 171
 172impl ToolCall {
 173    fn from_acp(
 174        tool_call: acp::ToolCall,
 175        status: ToolCallStatus,
 176        language_registry: Arc<LanguageRegistry>,
 177        cx: &mut App,
 178    ) -> Self {
 179        Self {
 180            id: tool_call.id,
 181            label: cx.new(|cx| {
 182                Markdown::new(
 183                    tool_call.title.into(),
 184                    Some(language_registry.clone()),
 185                    None,
 186                    cx,
 187                )
 188            }),
 189            kind: tool_call.kind,
 190            content: tool_call
 191                .content
 192                .into_iter()
 193                .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
 194                .collect(),
 195            locations: tool_call.locations,
 196            status,
 197            raw_input: tool_call.raw_input,
 198        }
 199    }
 200
 201    fn update(
 202        &mut self,
 203        fields: acp::ToolCallUpdateFields,
 204        language_registry: Arc<LanguageRegistry>,
 205        cx: &mut App,
 206    ) {
 207        let acp::ToolCallUpdateFields {
 208            kind,
 209            status,
 210            title,
 211            content,
 212            locations,
 213            raw_input,
 214        } = fields;
 215
 216        if let Some(kind) = kind {
 217            self.kind = kind;
 218        }
 219
 220        if let Some(status) = status {
 221            self.status = ToolCallStatus::Allowed { status };
 222        }
 223
 224        if let Some(title) = title {
 225            self.label.update(cx, |label, cx| {
 226                label.replace(title, cx);
 227            });
 228        }
 229
 230        if let Some(content) = content {
 231            self.content = content
 232                .into_iter()
 233                .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
 234                .collect();
 235        }
 236
 237        if let Some(locations) = locations {
 238            self.locations = locations;
 239        }
 240
 241        if let Some(raw_input) = raw_input {
 242            self.raw_input = Some(raw_input);
 243        }
 244    }
 245
 246    pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
 247        self.content.iter().filter_map(|content| match content {
 248            ToolCallContent::ContentBlock { .. } => None,
 249            ToolCallContent::Diff { diff } => Some(diff),
 250        })
 251    }
 252
 253    fn to_markdown(&self, cx: &App) -> String {
 254        let mut markdown = format!(
 255            "**Tool Call: {}**\nStatus: {}\n\n",
 256            self.label.read(cx).source(),
 257            self.status
 258        );
 259        for content in &self.content {
 260            markdown.push_str(content.to_markdown(cx).as_str());
 261            markdown.push_str("\n\n");
 262        }
 263        markdown
 264    }
 265}
 266
 267#[derive(Debug)]
 268pub enum ToolCallStatus {
 269    WaitingForConfirmation {
 270        options: Vec<acp::PermissionOption>,
 271        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
 272    },
 273    Allowed {
 274        status: acp::ToolCallStatus,
 275    },
 276    Rejected,
 277    Canceled,
 278}
 279
 280impl Display for ToolCallStatus {
 281    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 282        write!(
 283            f,
 284            "{}",
 285            match self {
 286                ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
 287                ToolCallStatus::Allowed { status } => match status {
 288                    acp::ToolCallStatus::Pending => "Pending",
 289                    acp::ToolCallStatus::InProgress => "In Progress",
 290                    acp::ToolCallStatus::Completed => "Completed",
 291                    acp::ToolCallStatus::Failed => "Failed",
 292                },
 293                ToolCallStatus::Rejected => "Rejected",
 294                ToolCallStatus::Canceled => "Canceled",
 295            }
 296        )
 297    }
 298}
 299
 300#[derive(Debug, PartialEq, Clone)]
 301pub enum ContentBlock {
 302    Empty,
 303    Markdown { markdown: Entity<Markdown> },
 304}
 305
 306impl ContentBlock {
 307    pub fn new(
 308        block: acp::ContentBlock,
 309        language_registry: &Arc<LanguageRegistry>,
 310        cx: &mut App,
 311    ) -> Self {
 312        let mut this = Self::Empty;
 313        this.append(block, language_registry, cx);
 314        this
 315    }
 316
 317    pub fn new_combined(
 318        blocks: impl IntoIterator<Item = acp::ContentBlock>,
 319        language_registry: Arc<LanguageRegistry>,
 320        cx: &mut App,
 321    ) -> Self {
 322        let mut this = Self::Empty;
 323        for block in blocks {
 324            this.append(block, &language_registry, cx);
 325        }
 326        this
 327    }
 328
 329    pub fn append(
 330        &mut self,
 331        block: acp::ContentBlock,
 332        language_registry: &Arc<LanguageRegistry>,
 333        cx: &mut App,
 334    ) {
 335        let new_content = match block {
 336            acp::ContentBlock::Text(text_content) => text_content.text.clone(),
 337            acp::ContentBlock::ResourceLink(resource_link) => {
 338                if let Some(path) = resource_link.uri.strip_prefix("file://") {
 339                    format!("{}", MentionPath(path.as_ref()))
 340                } else {
 341                    resource_link.uri.clone()
 342                }
 343            }
 344            acp::ContentBlock::Image(_)
 345            | acp::ContentBlock::Audio(_)
 346            | acp::ContentBlock::Resource(_) => String::new(),
 347        };
 348
 349        match self {
 350            ContentBlock::Empty => {
 351                *self = ContentBlock::Markdown {
 352                    markdown: cx.new(|cx| {
 353                        Markdown::new(
 354                            new_content.into(),
 355                            Some(language_registry.clone()),
 356                            None,
 357                            cx,
 358                        )
 359                    }),
 360                };
 361            }
 362            ContentBlock::Markdown { markdown } => {
 363                markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
 364            }
 365        }
 366    }
 367
 368    fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
 369        match self {
 370            ContentBlock::Empty => "",
 371            ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
 372        }
 373    }
 374
 375    pub fn markdown(&self) -> Option<&Entity<Markdown>> {
 376        match self {
 377            ContentBlock::Empty => None,
 378            ContentBlock::Markdown { markdown } => Some(markdown),
 379        }
 380    }
 381}
 382
 383#[derive(Debug)]
 384pub enum ToolCallContent {
 385    ContentBlock { content: ContentBlock },
 386    Diff { diff: Diff },
 387}
 388
 389impl ToolCallContent {
 390    pub fn from_acp(
 391        content: acp::ToolCallContent,
 392        language_registry: Arc<LanguageRegistry>,
 393        cx: &mut App,
 394    ) -> Self {
 395        match content {
 396            acp::ToolCallContent::Content { content } => Self::ContentBlock {
 397                content: ContentBlock::new(content, &language_registry, cx),
 398            },
 399            acp::ToolCallContent::Diff { diff } => Self::Diff {
 400                diff: Diff::from_acp(diff, language_registry, cx),
 401            },
 402        }
 403    }
 404
 405    pub fn to_markdown(&self, cx: &App) -> String {
 406        match self {
 407            Self::ContentBlock { content } => content.to_markdown(cx).to_string(),
 408            Self::Diff { diff } => diff.to_markdown(cx),
 409        }
 410    }
 411}
 412
 413#[derive(Debug)]
 414pub struct Diff {
 415    pub multibuffer: Entity<MultiBuffer>,
 416    pub path: PathBuf,
 417    _task: Task<Result<()>>,
 418}
 419
 420impl Diff {
 421    pub fn from_acp(
 422        diff: acp::Diff,
 423        language_registry: Arc<LanguageRegistry>,
 424        cx: &mut App,
 425    ) -> Self {
 426        let acp::Diff {
 427            path,
 428            old_text,
 429            new_text,
 430        } = diff;
 431
 432        let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));
 433
 434        let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
 435        let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
 436        let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
 437        let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
 438
 439        let task = cx.spawn({
 440            let multibuffer = multibuffer.clone();
 441            let path = path.clone();
 442            async move |cx| {
 443                let language = language_registry
 444                    .language_for_file_path(&path)
 445                    .await
 446                    .log_err();
 447
 448                new_buffer.update(cx, |buffer, cx| buffer.set_language(language.clone(), cx))?;
 449
 450                let old_buffer_snapshot = old_buffer.update(cx, |buffer, cx| {
 451                    buffer.set_language(language, cx);
 452                    buffer.snapshot()
 453                })?;
 454
 455                buffer_diff
 456                    .update(cx, |diff, cx| {
 457                        diff.set_base_text(
 458                            old_buffer_snapshot,
 459                            Some(language_registry),
 460                            new_buffer_snapshot,
 461                            cx,
 462                        )
 463                    })?
 464                    .await?;
 465
 466                multibuffer
 467                    .update(cx, |multibuffer, cx| {
 468                        let hunk_ranges = {
 469                            let buffer = new_buffer.read(cx);
 470                            let diff = buffer_diff.read(cx);
 471                            diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
 472                                .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
 473                                .collect::<Vec<_>>()
 474                        };
 475
 476                        multibuffer.set_excerpts_for_path(
 477                            PathKey::for_buffer(&new_buffer, cx),
 478                            new_buffer.clone(),
 479                            hunk_ranges,
 480                            editor::DEFAULT_MULTIBUFFER_CONTEXT,
 481                            cx,
 482                        );
 483                        multibuffer.add_diff(buffer_diff, cx);
 484                    })
 485                    .log_err();
 486
 487                anyhow::Ok(())
 488            }
 489        });
 490
 491        Self {
 492            multibuffer,
 493            path,
 494            _task: task,
 495        }
 496    }
 497
 498    fn to_markdown(&self, cx: &App) -> String {
 499        let buffer_text = self
 500            .multibuffer
 501            .read(cx)
 502            .all_buffers()
 503            .iter()
 504            .map(|buffer| buffer.read(cx).text())
 505            .join("\n");
 506        format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text)
 507    }
 508}
 509
 510#[derive(Debug, Default)]
 511pub struct Plan {
 512    pub entries: Vec<PlanEntry>,
 513}
 514
 515#[derive(Debug)]
 516pub struct PlanStats<'a> {
 517    pub in_progress_entry: Option<&'a PlanEntry>,
 518    pub pending: u32,
 519    pub completed: u32,
 520}
 521
 522impl Plan {
 523    pub fn is_empty(&self) -> bool {
 524        self.entries.is_empty()
 525    }
 526
 527    pub fn stats(&self) -> PlanStats<'_> {
 528        let mut stats = PlanStats {
 529            in_progress_entry: None,
 530            pending: 0,
 531            completed: 0,
 532        };
 533
 534        for entry in &self.entries {
 535            match &entry.status {
 536                acp::PlanEntryStatus::Pending => {
 537                    stats.pending += 1;
 538                }
 539                acp::PlanEntryStatus::InProgress => {
 540                    stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
 541                }
 542                acp::PlanEntryStatus::Completed => {
 543                    stats.completed += 1;
 544                }
 545            }
 546        }
 547
 548        stats
 549    }
 550}
 551
 552#[derive(Debug)]
 553pub struct PlanEntry {
 554    pub content: Entity<Markdown>,
 555    pub priority: acp::PlanEntryPriority,
 556    pub status: acp::PlanEntryStatus,
 557}
 558
 559impl PlanEntry {
 560    pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
 561        Self {
 562            content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
 563            priority: entry.priority,
 564            status: entry.status,
 565        }
 566    }
 567}
 568
 569pub struct AcpThread {
 570    title: SharedString,
 571    entries: Vec<AgentThreadEntry>,
 572    plan: Plan,
 573    project: Entity<Project>,
 574    action_log: Entity<ActionLog>,
 575    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
 576    send_task: Option<Fuse<Task<()>>>,
 577    connection: Rc<dyn AgentConnection>,
 578    session_id: acp::SessionId,
 579}
 580
 581pub enum AcpThreadEvent {
 582    NewEntry,
 583    EntryUpdated(usize),
 584    ToolAuthorizationRequired,
 585    Stopped,
 586    Error,
 587    ServerExited(ExitStatus),
 588}
 589
 590impl EventEmitter<AcpThreadEvent> for AcpThread {}
 591
 592#[derive(PartialEq, Eq)]
 593pub enum ThreadStatus {
 594    Idle,
 595    WaitingForToolConfirmation,
 596    Generating,
 597}
 598
 599#[derive(Debug, Clone)]
 600pub enum LoadError {
 601    Unsupported {
 602        error_message: SharedString,
 603        upgrade_message: SharedString,
 604        upgrade_command: String,
 605    },
 606    Exited(i32),
 607    Other(SharedString),
 608}
 609
 610impl Display for LoadError {
 611    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 612        match self {
 613            LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
 614            LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
 615            LoadError::Other(msg) => write!(f, "{}", msg),
 616        }
 617    }
 618}
 619
 620impl Error for LoadError {}
 621
 622impl AcpThread {
 623    pub fn new(
 624        title: impl Into<SharedString>,
 625        connection: Rc<dyn AgentConnection>,
 626        project: Entity<Project>,
 627        session_id: acp::SessionId,
 628        cx: &mut Context<Self>,
 629    ) -> Self {
 630        let action_log = cx.new(|_| ActionLog::new(project.clone()));
 631
 632        Self {
 633            action_log,
 634            shared_buffers: Default::default(),
 635            entries: Default::default(),
 636            plan: Default::default(),
 637            title: title.into(),
 638            project,
 639            send_task: None,
 640            connection,
 641            session_id,
 642        }
 643    }
 644
 645    pub fn action_log(&self) -> &Entity<ActionLog> {
 646        &self.action_log
 647    }
 648
 649    pub fn project(&self) -> &Entity<Project> {
 650        &self.project
 651    }
 652
 653    pub fn title(&self) -> SharedString {
 654        self.title.clone()
 655    }
 656
 657    pub fn entries(&self) -> &[AgentThreadEntry] {
 658        &self.entries
 659    }
 660
 661    pub fn session_id(&self) -> &acp::SessionId {
 662        &self.session_id
 663    }
 664
 665    pub fn status(&self) -> ThreadStatus {
 666        if self
 667            .send_task
 668            .as_ref()
 669            .map_or(false, |t| !t.is_terminated())
 670        {
 671            if self.waiting_for_tool_confirmation() {
 672                ThreadStatus::WaitingForToolConfirmation
 673            } else {
 674                ThreadStatus::Generating
 675            }
 676        } else {
 677            ThreadStatus::Idle
 678        }
 679    }
 680
 681    pub fn has_pending_edit_tool_calls(&self) -> bool {
 682        for entry in self.entries.iter().rev() {
 683            match entry {
 684                AgentThreadEntry::UserMessage(_) => return false,
 685                AgentThreadEntry::ToolCall(
 686                    call @ ToolCall {
 687                        status:
 688                            ToolCallStatus::Allowed {
 689                                status:
 690                                    acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending,
 691                            },
 692                        ..
 693                    },
 694                ) if call.diffs().next().is_some() => {
 695                    return true;
 696                }
 697                AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
 698            }
 699        }
 700
 701        false
 702    }
 703
 704    pub fn used_tools_since_last_user_message(&self) -> bool {
 705        for entry in self.entries.iter().rev() {
 706            match entry {
 707                AgentThreadEntry::UserMessage(..) => return false,
 708                AgentThreadEntry::AssistantMessage(..) => continue,
 709                AgentThreadEntry::ToolCall(..) => return true,
 710            }
 711        }
 712
 713        false
 714    }
 715
 716    pub fn handle_session_update(
 717        &mut self,
 718        update: acp::SessionUpdate,
 719        cx: &mut Context<Self>,
 720    ) -> Result<()> {
 721        match update {
 722            acp::SessionUpdate::UserMessageChunk { content } => {
 723                self.push_user_content_block(content, cx);
 724            }
 725            acp::SessionUpdate::AgentMessageChunk { content } => {
 726                self.push_assistant_content_block(content, false, cx);
 727            }
 728            acp::SessionUpdate::AgentThoughtChunk { content } => {
 729                self.push_assistant_content_block(content, true, cx);
 730            }
 731            acp::SessionUpdate::ToolCall(tool_call) => {
 732                self.upsert_tool_call(tool_call, cx);
 733            }
 734            acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
 735                self.update_tool_call(tool_call_update, cx)?;
 736            }
 737            acp::SessionUpdate::Plan(plan) => {
 738                self.update_plan(plan, cx);
 739            }
 740        }
 741        Ok(())
 742    }
 743
 744    pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
 745        let language_registry = self.project.read(cx).languages().clone();
 746        let entries_len = self.entries.len();
 747
 748        if let Some(last_entry) = self.entries.last_mut()
 749            && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
 750        {
 751            content.append(chunk, &language_registry, cx);
 752            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
 753        } else {
 754            let content = ContentBlock::new(chunk, &language_registry, cx);
 755            self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
 756        }
 757    }
 758
 759    pub fn push_assistant_content_block(
 760        &mut self,
 761        chunk: acp::ContentBlock,
 762        is_thought: bool,
 763        cx: &mut Context<Self>,
 764    ) {
 765        let language_registry = self.project.read(cx).languages().clone();
 766        let entries_len = self.entries.len();
 767        if let Some(last_entry) = self.entries.last_mut()
 768            && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
 769        {
 770            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
 771            match (chunks.last_mut(), is_thought) {
 772                (Some(AssistantMessageChunk::Message { block }), false)
 773                | (Some(AssistantMessageChunk::Thought { block }), true) => {
 774                    block.append(chunk, &language_registry, cx)
 775                }
 776                _ => {
 777                    let block = ContentBlock::new(chunk, &language_registry, cx);
 778                    if is_thought {
 779                        chunks.push(AssistantMessageChunk::Thought { block })
 780                    } else {
 781                        chunks.push(AssistantMessageChunk::Message { block })
 782                    }
 783                }
 784            }
 785        } else {
 786            let block = ContentBlock::new(chunk, &language_registry, cx);
 787            let chunk = if is_thought {
 788                AssistantMessageChunk::Thought { block }
 789            } else {
 790                AssistantMessageChunk::Message { block }
 791            };
 792
 793            self.push_entry(
 794                AgentThreadEntry::AssistantMessage(AssistantMessage {
 795                    chunks: vec![chunk],
 796                }),
 797                cx,
 798            );
 799        }
 800    }
 801
 802    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
 803        self.entries.push(entry);
 804        cx.emit(AcpThreadEvent::NewEntry);
 805    }
 806
 807    pub fn update_tool_call(
 808        &mut self,
 809        update: acp::ToolCallUpdate,
 810        cx: &mut Context<Self>,
 811    ) -> Result<()> {
 812        let languages = self.project.read(cx).languages().clone();
 813
 814        let (ix, current_call) = self
 815            .tool_call_mut(&update.id)
 816            .context("Tool call not found")?;
 817        current_call.update(update.fields, languages, cx);
 818
 819        cx.emit(AcpThreadEvent::EntryUpdated(ix));
 820
 821        Ok(())
 822    }
 823
 824    /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
 825    pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
 826        let status = ToolCallStatus::Allowed {
 827            status: tool_call.status,
 828        };
 829        self.upsert_tool_call_inner(tool_call, status, cx)
 830    }
 831
 832    pub fn upsert_tool_call_inner(
 833        &mut self,
 834        tool_call: acp::ToolCall,
 835        status: ToolCallStatus,
 836        cx: &mut Context<Self>,
 837    ) {
 838        let language_registry = self.project.read(cx).languages().clone();
 839        let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
 840
 841        let location = call.locations.last().cloned();
 842
 843        if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
 844            *current_call = call;
 845
 846            cx.emit(AcpThreadEvent::EntryUpdated(ix));
 847        } else {
 848            self.push_entry(AgentThreadEntry::ToolCall(call), cx);
 849        }
 850
 851        if let Some(location) = location {
 852            self.set_project_location(location, cx)
 853        }
 854    }
 855
 856    fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
 857        // The tool call we are looking for is typically the last one, or very close to the end.
 858        // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
 859        self.entries
 860            .iter_mut()
 861            .enumerate()
 862            .rev()
 863            .find_map(|(index, tool_call)| {
 864                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
 865                    && &tool_call.id == id
 866                {
 867                    Some((index, tool_call))
 868                } else {
 869                    None
 870                }
 871            })
 872    }
 873
 874    pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context<Self>) {
 875        self.project.update(cx, |project, cx| {
 876            let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
 877                return;
 878            };
 879            let buffer = project.open_buffer(path, cx);
 880            cx.spawn(async move |project, cx| {
 881                let buffer = buffer.await?;
 882
 883                project.update(cx, |project, cx| {
 884                    let position = if let Some(line) = location.line {
 885                        let snapshot = buffer.read(cx).snapshot();
 886                        let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
 887                        snapshot.anchor_before(point)
 888                    } else {
 889                        Anchor::MIN
 890                    };
 891
 892                    project.set_agent_location(
 893                        Some(AgentLocation {
 894                            buffer: buffer.downgrade(),
 895                            position,
 896                        }),
 897                        cx,
 898                    );
 899                })
 900            })
 901            .detach_and_log_err(cx);
 902        });
 903    }
 904
 905    pub fn request_tool_call_authorization(
 906        &mut self,
 907        tool_call: acp::ToolCall,
 908        options: Vec<acp::PermissionOption>,
 909        cx: &mut Context<Self>,
 910    ) -> oneshot::Receiver<acp::PermissionOptionId> {
 911        let (tx, rx) = oneshot::channel();
 912
 913        let status = ToolCallStatus::WaitingForConfirmation {
 914            options,
 915            respond_tx: tx,
 916        };
 917
 918        self.upsert_tool_call_inner(tool_call, status, cx);
 919        cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
 920        rx
 921    }
 922
 923    pub fn authorize_tool_call(
 924        &mut self,
 925        id: acp::ToolCallId,
 926        option_id: acp::PermissionOptionId,
 927        option_kind: acp::PermissionOptionKind,
 928        cx: &mut Context<Self>,
 929    ) {
 930        let Some((ix, call)) = self.tool_call_mut(&id) else {
 931            return;
 932        };
 933
 934        let new_status = match option_kind {
 935            acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
 936                ToolCallStatus::Rejected
 937            }
 938            acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
 939                ToolCallStatus::Allowed {
 940                    status: acp::ToolCallStatus::InProgress,
 941                }
 942            }
 943        };
 944
 945        let curr_status = mem::replace(&mut call.status, new_status);
 946
 947        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
 948            respond_tx.send(option_id).log_err();
 949        } else if cfg!(debug_assertions) {
 950            panic!("tried to authorize an already authorized tool call");
 951        }
 952
 953        cx.emit(AcpThreadEvent::EntryUpdated(ix));
 954    }
 955
 956    /// Returns true if the last turn is awaiting tool authorization
 957    pub fn waiting_for_tool_confirmation(&self) -> bool {
 958        for entry in self.entries.iter().rev() {
 959            match &entry {
 960                AgentThreadEntry::ToolCall(call) => match call.status {
 961                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
 962                    ToolCallStatus::Allowed { .. }
 963                    | ToolCallStatus::Rejected
 964                    | ToolCallStatus::Canceled => continue,
 965                },
 966                AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
 967                    // Reached the beginning of the turn
 968                    return false;
 969                }
 970            }
 971        }
 972        false
 973    }
 974
 975    pub fn plan(&self) -> &Plan {
 976        &self.plan
 977    }
 978
 979    pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
 980        let new_entries_len = request.entries.len();
 981        let mut new_entries = request.entries.into_iter();
 982
 983        // Reuse existing markdown to prevent flickering
 984        for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
 985            let PlanEntry {
 986                content,
 987                priority,
 988                status,
 989            } = old;
 990            content.update(cx, |old, cx| {
 991                old.replace(new.content, cx);
 992            });
 993            *priority = new.priority;
 994            *status = new.status;
 995        }
 996        for new in new_entries {
 997            self.plan.entries.push(PlanEntry::from_acp(new, cx))
 998        }
 999        self.plan.entries.truncate(new_entries_len);
1000
1001        cx.notify();
1002    }
1003
1004    fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1005        self.plan
1006            .entries
1007            .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1008        cx.notify();
1009    }
1010
1011    #[cfg(any(test, feature = "test-support"))]
1012    pub fn send_raw(
1013        &mut self,
1014        message: &str,
1015        cx: &mut Context<Self>,
1016    ) -> BoxFuture<'static, Result<()>> {
1017        self.send(
1018            vec![acp::ContentBlock::Text(acp::TextContent {
1019                text: message.to_string(),
1020                annotations: None,
1021            })],
1022            cx,
1023        )
1024    }
1025
1026    pub fn send(
1027        &mut self,
1028        message: Vec<acp::ContentBlock>,
1029        cx: &mut Context<Self>,
1030    ) -> BoxFuture<'static, Result<()>> {
1031        let block = ContentBlock::new_combined(
1032            message.clone(),
1033            self.project.read(cx).languages().clone(),
1034            cx,
1035        );
1036        self.push_entry(
1037            AgentThreadEntry::UserMessage(UserMessage { content: block }),
1038            cx,
1039        );
1040        self.clear_completed_plan_entries(cx);
1041
1042        let (tx, rx) = oneshot::channel();
1043        let cancel_task = self.cancel(cx);
1044
1045        self.send_task = Some(
1046            cx.spawn(async move |this, cx| {
1047                async {
1048                    cancel_task.await;
1049
1050                    let result = this
1051                        .update(cx, |this, cx| {
1052                            this.connection.prompt(
1053                                acp::PromptRequest {
1054                                    prompt: message,
1055                                    session_id: this.session_id.clone(),
1056                                },
1057                                cx,
1058                            )
1059                        })?
1060                        .await;
1061
1062                    tx.send(result).log_err();
1063                    anyhow::Ok(())
1064                }
1065                .await
1066                .log_err();
1067            })
1068            .fuse(),
1069        );
1070
1071        cx.spawn(async move |this, cx| match rx.await {
1072            Ok(Err(e)) => {
1073                this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error))
1074                    .log_err();
1075                Err(e)?
1076            }
1077            _ => {
1078                this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
1079                    .log_err();
1080                Ok(())
1081            }
1082        })
1083        .boxed()
1084    }
1085
1086    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1087        let Some(send_task) = self.send_task.take() else {
1088            return Task::ready(());
1089        };
1090
1091        for entry in self.entries.iter_mut() {
1092            if let AgentThreadEntry::ToolCall(call) = entry {
1093                let cancel = matches!(
1094                    call.status,
1095                    ToolCallStatus::WaitingForConfirmation { .. }
1096                        | ToolCallStatus::Allowed {
1097                            status: acp::ToolCallStatus::InProgress
1098                        }
1099                );
1100
1101                if cancel {
1102                    call.status = ToolCallStatus::Canceled;
1103                }
1104            }
1105        }
1106
1107        self.connection.cancel(&self.session_id, cx);
1108
1109        // Wait for the send task to complete
1110        cx.foreground_executor().spawn(send_task)
1111    }
1112
1113    pub fn read_text_file(
1114        &self,
1115        path: PathBuf,
1116        line: Option<u32>,
1117        limit: Option<u32>,
1118        reuse_shared_snapshot: bool,
1119        cx: &mut Context<Self>,
1120    ) -> Task<Result<String>> {
1121        let project = self.project.clone();
1122        let action_log = self.action_log.clone();
1123        cx.spawn(async move |this, cx| {
1124            let load = project.update(cx, |project, cx| {
1125                let path = project
1126                    .project_path_for_absolute_path(&path, cx)
1127                    .context("invalid path")?;
1128                anyhow::Ok(project.open_buffer(path, cx))
1129            });
1130            let buffer = load??.await?;
1131
1132            let snapshot = if reuse_shared_snapshot {
1133                this.read_with(cx, |this, _| {
1134                    this.shared_buffers.get(&buffer.clone()).cloned()
1135                })
1136                .log_err()
1137                .flatten()
1138            } else {
1139                None
1140            };
1141
1142            let snapshot = if let Some(snapshot) = snapshot {
1143                snapshot
1144            } else {
1145                action_log.update(cx, |action_log, cx| {
1146                    action_log.buffer_read(buffer.clone(), cx);
1147                })?;
1148                project.update(cx, |project, cx| {
1149                    let position = buffer
1150                        .read(cx)
1151                        .snapshot()
1152                        .anchor_before(Point::new(line.unwrap_or_default(), 0));
1153                    project.set_agent_location(
1154                        Some(AgentLocation {
1155                            buffer: buffer.downgrade(),
1156                            position,
1157                        }),
1158                        cx,
1159                    );
1160                })?;
1161
1162                buffer.update(cx, |buffer, _| buffer.snapshot())?
1163            };
1164
1165            this.update(cx, |this, _| {
1166                let text = snapshot.text();
1167                this.shared_buffers.insert(buffer.clone(), snapshot);
1168                if line.is_none() && limit.is_none() {
1169                    return Ok(text);
1170                }
1171                let limit = limit.unwrap_or(u32::MAX) as usize;
1172                let Some(line) = line else {
1173                    return Ok(text.lines().take(limit).collect::<String>());
1174                };
1175
1176                let count = text.lines().count();
1177                if count < line as usize {
1178                    anyhow::bail!("There are only {} lines", count);
1179                }
1180                Ok(text
1181                    .lines()
1182                    .skip(line as usize + 1)
1183                    .take(limit)
1184                    .collect::<String>())
1185            })?
1186        })
1187    }
1188
1189    pub fn write_text_file(
1190        &self,
1191        path: PathBuf,
1192        content: String,
1193        cx: &mut Context<Self>,
1194    ) -> Task<Result<()>> {
1195        let project = self.project.clone();
1196        let action_log = self.action_log.clone();
1197        cx.spawn(async move |this, cx| {
1198            let load = project.update(cx, |project, cx| {
1199                let path = project
1200                    .project_path_for_absolute_path(&path, cx)
1201                    .context("invalid path")?;
1202                anyhow::Ok(project.open_buffer(path, cx))
1203            });
1204            let buffer = load??.await?;
1205            let snapshot = this.update(cx, |this, cx| {
1206                this.shared_buffers
1207                    .get(&buffer)
1208                    .cloned()
1209                    .unwrap_or_else(|| buffer.read(cx).snapshot())
1210            })?;
1211            let edits = cx
1212                .background_executor()
1213                .spawn(async move {
1214                    let old_text = snapshot.text();
1215                    text_diff(old_text.as_str(), &content)
1216                        .into_iter()
1217                        .map(|(range, replacement)| {
1218                            (
1219                                snapshot.anchor_after(range.start)
1220                                    ..snapshot.anchor_before(range.end),
1221                                replacement,
1222                            )
1223                        })
1224                        .collect::<Vec<_>>()
1225                })
1226                .await;
1227            cx.update(|cx| {
1228                project.update(cx, |project, cx| {
1229                    project.set_agent_location(
1230                        Some(AgentLocation {
1231                            buffer: buffer.downgrade(),
1232                            position: edits
1233                                .last()
1234                                .map(|(range, _)| range.end)
1235                                .unwrap_or(Anchor::MIN),
1236                        }),
1237                        cx,
1238                    );
1239                });
1240
1241                action_log.update(cx, |action_log, cx| {
1242                    action_log.buffer_read(buffer.clone(), cx);
1243                });
1244                buffer.update(cx, |buffer, cx| {
1245                    buffer.edit(edits, None, cx);
1246                });
1247                action_log.update(cx, |action_log, cx| {
1248                    action_log.buffer_edited(buffer.clone(), cx);
1249                });
1250            })?;
1251            project
1252                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1253                .await
1254        })
1255    }
1256
1257    pub fn to_markdown(&self, cx: &App) -> String {
1258        self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1259    }
1260
1261    pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
1262        cx.emit(AcpThreadEvent::ServerExited(status));
1263    }
1264}
1265
1266#[cfg(test)]
1267mod tests {
1268    use super::*;
1269    use anyhow::anyhow;
1270    use futures::{channel::mpsc, future::LocalBoxFuture, select};
1271    use gpui::{AsyncApp, TestAppContext, WeakEntity};
1272    use indoc::indoc;
1273    use project::FakeFs;
1274    use rand::Rng as _;
1275    use serde_json::json;
1276    use settings::SettingsStore;
1277    use smol::stream::StreamExt as _;
1278    use std::{cell::RefCell, rc::Rc, time::Duration};
1279
1280    use util::path;
1281
1282    fn init_test(cx: &mut TestAppContext) {
1283        env_logger::try_init().ok();
1284        cx.update(|cx| {
1285            let settings_store = SettingsStore::test(cx);
1286            cx.set_global(settings_store);
1287            Project::init_settings(cx);
1288            language::init(cx);
1289        });
1290    }
1291
    // Checks three things about `push_user_content_block`:
    // - consecutive user blocks are merged into a single `UserMessage` entry;
    // - the merged text is concatenated;
    // - a user block arriving after an assistant message starts a new entry.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Hello, ".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "world!".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Assistant response".to_string(),
                }),
                false,
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "New user message".to_string(),
                }),
                cx,
            );
        });

        // Expect: [user, assistant, user].
        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }
1379
    // Checks that consecutive `AgentThoughtChunk` updates are concatenated into
    // one <thinking> section of the assistant message, as rendered by
    // `AcpThread::to_markdown`.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        // The fake agent answers every prompt with two thought chunks.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "Thinking ".into(),
                                },
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "hard!".into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
1446
    // Checks that `write_text_file` diffs against the snapshot captured by
    // `read_text_file`. A user edit made in between (inserting "zero\n" at the
    // top) must survive in both the buffer and the file saved to disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test once the fake agent has read the file.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        // After the agent's read, simulate the user editing the same buffer.
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
1525
    // Checks the lifecycle of a tool call across a cancellation:
    // - canceling marks an in-progress tool call `Canceled`;
    // - a later `Completed` update from the agent for the same id still
    //   moves the call to `Allowed { Completed }`.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        // The fake agent reports one in-progress Fetch tool call per prompt.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // Entry 0 is the user message; entry 1 is the tool call.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::InProgress,
                        ..
                    },
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // The agent finishes the tool call after the cancellation.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::Completed,
                        ..
                    },
                    ..
                })
            ));
        });
    }
1634
1635    #[gpui::test]
1636    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
1637        init_test(cx);
1638        let fs = FakeFs::new(cx.background_executor.clone());
1639        fs.insert_tree(path!("/test"), json!({})).await;
1640        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
1641
1642        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
1643            move |_, thread, mut cx| {
1644                async move {
1645                    thread
1646                        .update(&mut cx, |thread, cx| {
1647                            thread.handle_session_update(
1648                                acp::SessionUpdate::ToolCall(acp::ToolCall {
1649                                    id: acp::ToolCallId("test".into()),
1650                                    title: "Label".into(),
1651                                    kind: acp::ToolKind::Edit,
1652                                    status: acp::ToolCallStatus::Completed,
1653                                    content: vec![acp::ToolCallContent::Diff {
1654                                        diff: acp::Diff {
1655                                            path: "/test/test.txt".into(),
1656                                            old_text: None,
1657                                            new_text: "foo".into(),
1658                                        },
1659                                    }],
1660                                    locations: vec![],
1661                                    raw_input: None,
1662                                }),
1663                                cx,
1664                            )
1665                        })
1666                        .unwrap()
1667                        .unwrap();
1668                    Ok(acp::PromptResponse {
1669                        stop_reason: acp::StopReason::EndTurn,
1670                    })
1671                }
1672                .boxed_local()
1673            }
1674        }));
1675
1676        let thread = connection
1677            .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
1678            .await
1679            .unwrap();
1680        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
1681            .await
1682            .unwrap();
1683
1684        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
1685    }
1686
1687    async fn run_until_first_tool_call(
1688        thread: &Entity<AcpThread>,
1689        cx: &mut TestAppContext,
1690    ) -> usize {
1691        let (mut tx, mut rx) = mpsc::channel::<usize>(1);
1692
1693        let subscription = cx.update(|cx| {
1694            cx.subscribe(thread, move |thread, _, cx| {
1695                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
1696                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
1697                        return tx.try_send(ix).unwrap();
1698                    }
1699                }
1700            })
1701        });
1702
1703        select! {
1704            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
1705                panic!("Timeout waiting for tool call")
1706            }
1707            ix = rx.next().fuse() => {
1708                drop(subscription);
1709                ix.unwrap()
1710            }
1711        }
1712    }
1713
1714    #[derive(Clone, Default)]
1715    struct FakeAgentConnection {
1716        auth_methods: Vec<acp::AuthMethod>,
1717        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
1718        on_user_message: Option<
1719            Rc<
1720                dyn Fn(
1721                        acp::PromptRequest,
1722                        WeakEntity<AcpThread>,
1723                        AsyncApp,
1724                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1725                    + 'static,
1726            >,
1727        >,
1728    }
1729
1730    impl FakeAgentConnection {
1731        fn new() -> Self {
1732            Self {
1733                auth_methods: Vec::new(),
1734                on_user_message: None,
1735                sessions: Arc::default(),
1736            }
1737        }
1738
1739        #[expect(unused)]
1740        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
1741            self.auth_methods = auth_methods;
1742            self
1743        }
1744
1745        fn on_user_message(
1746            mut self,
1747            handler: impl Fn(
1748                acp::PromptRequest,
1749                WeakEntity<AcpThread>,
1750                AsyncApp,
1751            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1752            + 'static,
1753        ) -> Self {
1754            self.on_user_message.replace(Rc::new(handler));
1755            self
1756        }
1757    }
1758
1759    impl AgentConnection for FakeAgentConnection {
1760        fn auth_methods(&self) -> &[acp::AuthMethod] {
1761            &self.auth_methods
1762        }
1763
1764        fn new_thread(
1765            self: Rc<Self>,
1766            project: Entity<Project>,
1767            _cwd: &Path,
1768            cx: &mut gpui::AsyncApp,
1769        ) -> Task<gpui::Result<Entity<AcpThread>>> {
1770            let session_id = acp::SessionId(
1771                rand::thread_rng()
1772                    .sample_iter(&rand::distributions::Alphanumeric)
1773                    .take(7)
1774                    .map(char::from)
1775                    .collect::<String>()
1776                    .into(),
1777            );
1778            let thread = cx
1779                .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
1780                .unwrap();
1781            self.sessions.lock().insert(session_id, thread.downgrade());
1782            Task::ready(Ok(thread))
1783        }
1784
1785        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
1786            if self.auth_methods().iter().any(|m| m.id == method) {
1787                Task::ready(Ok(()))
1788            } else {
1789                Task::ready(Err(anyhow!("Invalid Auth Method")))
1790            }
1791        }
1792
1793        fn prompt(
1794            &self,
1795            params: acp::PromptRequest,
1796            cx: &mut App,
1797        ) -> Task<gpui::Result<acp::PromptResponse>> {
1798            let sessions = self.sessions.lock();
1799            let thread = sessions.get(&params.session_id).unwrap();
1800            if let Some(handler) = &self.on_user_message {
1801                let handler = handler.clone();
1802                let thread = thread.clone();
1803                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
1804            } else {
1805                Task::ready(Ok(acp::PromptResponse {
1806                    stop_reason: acp::StopReason::EndTurn,
1807                }))
1808            }
1809        }
1810
1811        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1812            let sessions = self.sessions.lock();
1813            let thread = sessions.get(&session_id).unwrap().clone();
1814
1815            cx.spawn(async move |cx| {
1816                thread
1817                    .update(cx, |thread, cx| thread.cancel(cx))
1818                    .unwrap()
1819                    .await
1820            })
1821            .detach();
1822        }
1823    }
1824}