acp_thread.rs

   1mod connection;
   2pub use connection::*;
   3
   4use agent_client_protocol as acp;
   5use anyhow::{Context as _, Result};
   6use assistant_tool::ActionLog;
   7use buffer_diff::BufferDiff;
   8use editor::{Bias, MultiBuffer, PathKey};
   9use futures::{FutureExt, channel::oneshot, future::BoxFuture};
  10use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
  11use itertools::Itertools;
  12use language::{
  13    Anchor, Buffer, BufferSnapshot, Capability, LanguageRegistry, OffsetRangeExt as _, Point,
  14    text_diff,
  15};
  16use markdown::Markdown;
  17use project::{AgentLocation, Project};
  18use std::collections::HashMap;
  19use std::error::Error;
  20use std::fmt::Formatter;
  21use std::process::ExitStatus;
  22use std::rc::Rc;
  23use std::{
  24    fmt::Display,
  25    mem,
  26    path::{Path, PathBuf},
  27    sync::Arc,
  28};
  29use ui::App;
  30use util::ResultExt;
  31
  32#[derive(Debug)]
  33pub struct UserMessage {
  34    pub content: ContentBlock,
  35}
  36
  37impl UserMessage {
  38    pub fn from_acp(
  39        message: impl IntoIterator<Item = acp::ContentBlock>,
  40        language_registry: Arc<LanguageRegistry>,
  41        cx: &mut App,
  42    ) -> Self {
  43        let mut content = ContentBlock::Empty;
  44        for chunk in message {
  45            content.append(chunk, &language_registry, cx)
  46        }
  47        Self { content: content }
  48    }
  49
  50    fn to_markdown(&self, cx: &App) -> String {
  51        format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
  52    }
  53}
  54
  55#[derive(Debug)]
  56pub struct MentionPath<'a>(&'a Path);
  57
  58impl<'a> MentionPath<'a> {
  59    const PREFIX: &'static str = "@file:";
  60
  61    pub fn new(path: &'a Path) -> Self {
  62        MentionPath(path)
  63    }
  64
  65    pub fn try_parse(url: &'a str) -> Option<Self> {
  66        let path = url.strip_prefix(Self::PREFIX)?;
  67        Some(MentionPath(Path::new(path)))
  68    }
  69
  70    pub fn path(&self) -> &Path {
  71        self.0
  72    }
  73}
  74
  75impl Display for MentionPath<'_> {
  76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  77        write!(
  78            f,
  79            "[@{}]({}{})",
  80            self.0.file_name().unwrap_or_default().display(),
  81            Self::PREFIX,
  82            self.0.display()
  83        )
  84    }
  85}
  86
  87#[derive(Debug, PartialEq)]
  88pub struct AssistantMessage {
  89    pub chunks: Vec<AssistantMessageChunk>,
  90}
  91
  92impl AssistantMessage {
  93    pub fn to_markdown(&self, cx: &App) -> String {
  94        format!(
  95            "## Assistant\n\n{}\n\n",
  96            self.chunks
  97                .iter()
  98                .map(|chunk| chunk.to_markdown(cx))
  99                .join("\n\n")
 100        )
 101    }
 102}
 103
 104#[derive(Debug, PartialEq)]
 105pub enum AssistantMessageChunk {
 106    Message { block: ContentBlock },
 107    Thought { block: ContentBlock },
 108}
 109
 110impl AssistantMessageChunk {
 111    pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
 112        Self::Message {
 113            block: ContentBlock::new(chunk.into(), language_registry, cx),
 114        }
 115    }
 116
 117    fn to_markdown(&self, cx: &App) -> String {
 118        match self {
 119            Self::Message { block } => block.to_markdown(cx).to_string(),
 120            Self::Thought { block } => {
 121                format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
 122            }
 123        }
 124    }
 125}
 126
 127#[derive(Debug)]
 128pub enum AgentThreadEntry {
 129    UserMessage(UserMessage),
 130    AssistantMessage(AssistantMessage),
 131    ToolCall(ToolCall),
 132}
 133
 134impl AgentThreadEntry {
 135    fn to_markdown(&self, cx: &App) -> String {
 136        match self {
 137            Self::UserMessage(message) => message.to_markdown(cx),
 138            Self::AssistantMessage(message) => message.to_markdown(cx),
 139            Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
 140        }
 141    }
 142
 143    pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
 144        if let AgentThreadEntry::ToolCall(call) = self {
 145            itertools::Either::Left(call.diffs())
 146        } else {
 147            itertools::Either::Right(std::iter::empty())
 148        }
 149    }
 150
 151    pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
 152        if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
 153            Some(locations)
 154        } else {
 155            None
 156        }
 157    }
 158}
 159
 160#[derive(Debug)]
 161pub struct ToolCall {
 162    pub id: acp::ToolCallId,
 163    pub label: Entity<Markdown>,
 164    pub kind: acp::ToolKind,
 165    pub content: Vec<ToolCallContent>,
 166    pub status: ToolCallStatus,
 167    pub locations: Vec<acp::ToolCallLocation>,
 168    pub raw_input: Option<serde_json::Value>,
 169    pub raw_output: Option<serde_json::Value>,
 170}
 171
 172impl ToolCall {
 173    fn from_acp(
 174        tool_call: acp::ToolCall,
 175        status: ToolCallStatus,
 176        language_registry: Arc<LanguageRegistry>,
 177        cx: &mut App,
 178    ) -> Self {
 179        Self {
 180            id: tool_call.id,
 181            label: cx.new(|cx| {
 182                Markdown::new(
 183                    tool_call.title.into(),
 184                    Some(language_registry.clone()),
 185                    None,
 186                    cx,
 187                )
 188            }),
 189            kind: tool_call.kind,
 190            content: tool_call
 191                .content
 192                .into_iter()
 193                .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
 194                .collect(),
 195            locations: tool_call.locations,
 196            status,
 197            raw_input: tool_call.raw_input,
 198            raw_output: tool_call.raw_output,
 199        }
 200    }
 201
 202    fn update(
 203        &mut self,
 204        fields: acp::ToolCallUpdateFields,
 205        language_registry: Arc<LanguageRegistry>,
 206        cx: &mut App,
 207    ) {
 208        let acp::ToolCallUpdateFields {
 209            kind,
 210            status,
 211            title,
 212            content,
 213            locations,
 214            raw_input,
 215            raw_output,
 216        } = fields;
 217
 218        if let Some(kind) = kind {
 219            self.kind = kind;
 220        }
 221
 222        if let Some(status) = status {
 223            self.status = ToolCallStatus::Allowed { status };
 224        }
 225
 226        if let Some(title) = title {
 227            self.label.update(cx, |label, cx| {
 228                label.replace(title, cx);
 229            });
 230        }
 231
 232        if let Some(content) = content {
 233            self.content = content
 234                .into_iter()
 235                .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
 236                .collect();
 237        }
 238
 239        if let Some(locations) = locations {
 240            self.locations = locations;
 241        }
 242
 243        if let Some(raw_input) = raw_input {
 244            self.raw_input = Some(raw_input);
 245        }
 246
 247        if let Some(raw_output) = raw_output {
 248            self.raw_output = Some(raw_output);
 249        }
 250    }
 251
 252    pub fn diffs(&self) -> impl Iterator<Item = &Diff> {
 253        self.content.iter().filter_map(|content| match content {
 254            ToolCallContent::ContentBlock { .. } => None,
 255            ToolCallContent::Diff { diff } => Some(diff),
 256        })
 257    }
 258
 259    fn to_markdown(&self, cx: &App) -> String {
 260        let mut markdown = format!(
 261            "**Tool Call: {}**\nStatus: {}\n\n",
 262            self.label.read(cx).source(),
 263            self.status
 264        );
 265        for content in &self.content {
 266            markdown.push_str(content.to_markdown(cx).as_str());
 267            markdown.push_str("\n\n");
 268        }
 269        markdown
 270    }
 271}
 272
 273#[derive(Debug)]
 274pub enum ToolCallStatus {
 275    WaitingForConfirmation {
 276        options: Vec<acp::PermissionOption>,
 277        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
 278    },
 279    Allowed {
 280        status: acp::ToolCallStatus,
 281    },
 282    Rejected,
 283    Canceled,
 284}
 285
 286impl Display for ToolCallStatus {
 287    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 288        write!(
 289            f,
 290            "{}",
 291            match self {
 292                ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
 293                ToolCallStatus::Allowed { status } => match status {
 294                    acp::ToolCallStatus::Pending => "Pending",
 295                    acp::ToolCallStatus::InProgress => "In Progress",
 296                    acp::ToolCallStatus::Completed => "Completed",
 297                    acp::ToolCallStatus::Failed => "Failed",
 298                },
 299                ToolCallStatus::Rejected => "Rejected",
 300                ToolCallStatus::Canceled => "Canceled",
 301            }
 302        )
 303    }
 304}
 305
 306#[derive(Debug, PartialEq, Clone)]
 307pub enum ContentBlock {
 308    Empty,
 309    Markdown { markdown: Entity<Markdown> },
 310}
 311
 312impl ContentBlock {
 313    pub fn new(
 314        block: acp::ContentBlock,
 315        language_registry: &Arc<LanguageRegistry>,
 316        cx: &mut App,
 317    ) -> Self {
 318        let mut this = Self::Empty;
 319        this.append(block, language_registry, cx);
 320        this
 321    }
 322
 323    pub fn new_combined(
 324        blocks: impl IntoIterator<Item = acp::ContentBlock>,
 325        language_registry: Arc<LanguageRegistry>,
 326        cx: &mut App,
 327    ) -> Self {
 328        let mut this = Self::Empty;
 329        for block in blocks {
 330            this.append(block, &language_registry, cx);
 331        }
 332        this
 333    }
 334
 335    pub fn append(
 336        &mut self,
 337        block: acp::ContentBlock,
 338        language_registry: &Arc<LanguageRegistry>,
 339        cx: &mut App,
 340    ) {
 341        let new_content = match block {
 342            acp::ContentBlock::Text(text_content) => text_content.text.clone(),
 343            acp::ContentBlock::ResourceLink(resource_link) => {
 344                if let Some(path) = resource_link.uri.strip_prefix("file://") {
 345                    format!("{}", MentionPath(path.as_ref()))
 346                } else {
 347                    resource_link.uri.clone()
 348                }
 349            }
 350            acp::ContentBlock::Image(_)
 351            | acp::ContentBlock::Audio(_)
 352            | acp::ContentBlock::Resource(_) => String::new(),
 353        };
 354
 355        match self {
 356            ContentBlock::Empty => {
 357                *self = ContentBlock::Markdown {
 358                    markdown: cx.new(|cx| {
 359                        Markdown::new(
 360                            new_content.into(),
 361                            Some(language_registry.clone()),
 362                            None,
 363                            cx,
 364                        )
 365                    }),
 366                };
 367            }
 368            ContentBlock::Markdown { markdown } => {
 369                markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
 370            }
 371        }
 372    }
 373
 374    fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
 375        match self {
 376            ContentBlock::Empty => "",
 377            ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
 378        }
 379    }
 380
 381    pub fn markdown(&self) -> Option<&Entity<Markdown>> {
 382        match self {
 383            ContentBlock::Empty => None,
 384            ContentBlock::Markdown { markdown } => Some(markdown),
 385        }
 386    }
 387}
 388
 389#[derive(Debug)]
 390pub enum ToolCallContent {
 391    ContentBlock { content: ContentBlock },
 392    Diff { diff: Diff },
 393}
 394
 395impl ToolCallContent {
 396    pub fn from_acp(
 397        content: acp::ToolCallContent,
 398        language_registry: Arc<LanguageRegistry>,
 399        cx: &mut App,
 400    ) -> Self {
 401        match content {
 402            acp::ToolCallContent::Content { content } => Self::ContentBlock {
 403                content: ContentBlock::new(content, &language_registry, cx),
 404            },
 405            acp::ToolCallContent::Diff { diff } => Self::Diff {
 406                diff: Diff::from_acp(diff, language_registry, cx),
 407            },
 408        }
 409    }
 410
 411    pub fn to_markdown(&self, cx: &App) -> String {
 412        match self {
 413            Self::ContentBlock { content } => content.to_markdown(cx).to_string(),
 414            Self::Diff { diff } => diff.to_markdown(cx),
 415        }
 416    }
 417}
 418
/// A read-only diff view of one file, built from an ACP diff payload.
///
/// The multibuffer is created empty. A background task (stored in `_task`) looks up a
/// language for the path, diffs `old_text` against `new_text`, and then fills the
/// multibuffer with excerpts of the new text around each changed hunk.
#[derive(Debug)]
pub struct Diff {
    /// Read-only multibuffer over the new text, with the computed diff attached.
    pub multibuffer: Entity<MultiBuffer>,
    /// The file path reported by the agent for this diff.
    pub path: PathBuf,
    // Holds the diff-building task. NOTE(review): presumably dropping a gpui `Task`
    // cancels it, so this field keeps the work alive as long as the Diff — confirm.
    _task: Task<Result<()>>,
}

impl Diff {
    /// Builds a diff view from `diff`.
    ///
    /// Returns immediately; the multibuffer is populated asynchronously. A missing
    /// `old_text` is treated as empty text. `language_registry` is used to look up a
    /// language for `path` and is passed to the diff computation.
    pub fn from_acp(
        diff: acp::Diff,
        language_registry: Arc<LanguageRegistry>,
        cx: &mut App,
    ) -> Self {
        let acp::Diff {
            path,
            old_text,
            new_text,
        } = diff;

        let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));

        let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
        // No old text: diff against an empty base.
        let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
        let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
        let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));

        let task = cx.spawn({
            let multibuffer = multibuffer.clone();
            let path = path.clone();
            async move |cx| {
                // A failed lookup is logged and `None` is passed to `set_language` below.
                let language = language_registry
                    .language_for_file_path(&path)
                    .await
                    .log_err();

                new_buffer.update(cx, |buffer, cx| buffer.set_language(language.clone(), cx))?;

                let old_buffer_snapshot = old_buffer.update(cx, |buffer, cx| {
                    buffer.set_language(language, cx);
                    buffer.snapshot()
                })?;

                // Use the old text as the diff base for the new buffer's snapshot, and wait
                // for the diff to be computed before building excerpts.
                buffer_diff
                    .update(cx, |diff, cx| {
                        diff.set_base_text(
                            old_buffer_snapshot,
                            Some(language_registry),
                            new_buffer_snapshot,
                            cx,
                        )
                    })?
                    .await?;

                // Failures here are logged rather than returned from the task.
                multibuffer
                    .update(cx, |multibuffer, cx| {
                        // Point ranges, in the new buffer, of every hunk in the whole buffer.
                        let hunk_ranges = {
                            let buffer = new_buffer.read(cx);
                            let diff = buffer_diff.read(cx);
                            diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
                                .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
                                .collect::<Vec<_>>()
                        };

                        // Excerpt only the hunk ranges, expanded by
                        // `DEFAULT_MULTIBUFFER_CONTEXT` (presumably lines of context).
                        multibuffer.set_excerpts_for_path(
                            PathKey::for_buffer(&new_buffer, cx),
                            new_buffer.clone(),
                            hunk_ranges,
                            editor::DEFAULT_MULTIBUFFER_CONTEXT,
                            cx,
                        );
                        multibuffer.add_diff(buffer_diff, cx);
                    })
                    .log_err();

                anyhow::Ok(())
            }
        });

        Self {
            multibuffer,
            path,
            _task: task,
        }
    }

    /// Renders `Diff: <path>` followed by a fenced block containing the full text of every
    /// buffer in the multibuffer, joined by newlines.
    ///
    /// NOTE(review): this emits file contents, not a unified diff. The multibuffer presumably
    /// has no buffers until the task spawned in `from_acp` finishes, so the fenced block can
    /// be empty — confirm this is acceptable for export.
    fn to_markdown(&self, cx: &App) -> String {
        let buffer_text = self
            .multibuffer
            .read(cx)
            .all_buffers()
            .iter()
            .map(|buffer| buffer.read(cx).text())
            .join("\n");
        format!("Diff: {}\n```\n{}\n```\n", self.path.display(), buffer_text)
    }
}
 515
 516#[derive(Debug, Default)]
 517pub struct Plan {
 518    pub entries: Vec<PlanEntry>,
 519}
 520
 521#[derive(Debug)]
 522pub struct PlanStats<'a> {
 523    pub in_progress_entry: Option<&'a PlanEntry>,
 524    pub pending: u32,
 525    pub completed: u32,
 526}
 527
 528impl Plan {
 529    pub fn is_empty(&self) -> bool {
 530        self.entries.is_empty()
 531    }
 532
 533    pub fn stats(&self) -> PlanStats<'_> {
 534        let mut stats = PlanStats {
 535            in_progress_entry: None,
 536            pending: 0,
 537            completed: 0,
 538        };
 539
 540        for entry in &self.entries {
 541            match &entry.status {
 542                acp::PlanEntryStatus::Pending => {
 543                    stats.pending += 1;
 544                }
 545                acp::PlanEntryStatus::InProgress => {
 546                    stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
 547                }
 548                acp::PlanEntryStatus::Completed => {
 549                    stats.completed += 1;
 550                }
 551            }
 552        }
 553
 554        stats
 555    }
 556}
 557
 558#[derive(Debug)]
 559pub struct PlanEntry {
 560    pub content: Entity<Markdown>,
 561    pub priority: acp::PlanEntryPriority,
 562    pub status: acp::PlanEntryStatus,
 563}
 564
 565impl PlanEntry {
 566    pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
 567        Self {
 568            content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
 569            priority: entry.priority,
 570            status: entry.status,
 571        }
 572    }
 573}
 574
/// A conversation with an agent over ACP: an ordered transcript of entries plus the
/// agent's plan, tied to one project and one ACP session.
pub struct AcpThread {
    title: SharedString,
    /// Transcript entries; new entries are appended at the end.
    entries: Vec<AgentThreadEntry>,
    /// The plan most recently applied via `update_plan`.
    plan: Plan,
    project: Entity<Project>,
    /// Action log created for this thread in `new`.
    action_log: Entity<ActionLog>,
    // NOTE(review): not read in the visible part of this file; presumably tracks buffer
    // snapshots shared with the agent — confirm.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// In-flight send task; `status()` reports `Idle` whenever this is `None`.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
}

/// Events emitted by [`AcpThread`].
pub enum AcpThreadEvent {
    /// An entry was appended to the transcript.
    NewEntry,
    /// The entry at this index was modified in place.
    EntryUpdated(usize),
    /// A tool call is now waiting for a permission decision.
    ToolAuthorizationRequired,
    Stopped,
    Error,
    /// The agent server process exited with this status.
    ServerExited(ExitStatus),
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}
 597
 598#[derive(PartialEq, Eq)]
 599pub enum ThreadStatus {
 600    Idle,
 601    WaitingForToolConfirmation,
 602    Generating,
 603}
 604
 605#[derive(Debug, Clone)]
 606pub enum LoadError {
 607    Unsupported {
 608        error_message: SharedString,
 609        upgrade_message: SharedString,
 610        upgrade_command: String,
 611    },
 612    Exited(i32),
 613    Other(SharedString),
 614}
 615
 616impl Display for LoadError {
 617    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 618        match self {
 619            LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
 620            LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
 621            LoadError::Other(msg) => write!(f, "{}", msg),
 622        }
 623    }
 624}
 625
 626impl Error for LoadError {}
 627
 628impl AcpThread {
 629    pub fn new(
 630        title: impl Into<SharedString>,
 631        connection: Rc<dyn AgentConnection>,
 632        project: Entity<Project>,
 633        session_id: acp::SessionId,
 634        cx: &mut Context<Self>,
 635    ) -> Self {
 636        let action_log = cx.new(|_| ActionLog::new(project.clone()));
 637
 638        Self {
 639            action_log,
 640            shared_buffers: Default::default(),
 641            entries: Default::default(),
 642            plan: Default::default(),
 643            title: title.into(),
 644            project,
 645            send_task: None,
 646            connection,
 647            session_id,
 648        }
 649    }
 650
 651    pub fn action_log(&self) -> &Entity<ActionLog> {
 652        &self.action_log
 653    }
 654
 655    pub fn project(&self) -> &Entity<Project> {
 656        &self.project
 657    }
 658
 659    pub fn title(&self) -> SharedString {
 660        self.title.clone()
 661    }
 662
 663    pub fn entries(&self) -> &[AgentThreadEntry] {
 664        &self.entries
 665    }
 666
 667    pub fn session_id(&self) -> &acp::SessionId {
 668        &self.session_id
 669    }
 670
 671    pub fn status(&self) -> ThreadStatus {
 672        if self.send_task.is_some() {
 673            if self.waiting_for_tool_confirmation() {
 674                ThreadStatus::WaitingForToolConfirmation
 675            } else {
 676                ThreadStatus::Generating
 677            }
 678        } else {
 679            ThreadStatus::Idle
 680        }
 681    }
 682
 683    pub fn has_pending_edit_tool_calls(&self) -> bool {
 684        for entry in self.entries.iter().rev() {
 685            match entry {
 686                AgentThreadEntry::UserMessage(_) => return false,
 687                AgentThreadEntry::ToolCall(
 688                    call @ ToolCall {
 689                        status:
 690                            ToolCallStatus::Allowed {
 691                                status:
 692                                    acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending,
 693                            },
 694                        ..
 695                    },
 696                ) if call.diffs().next().is_some() => {
 697                    return true;
 698                }
 699                AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
 700            }
 701        }
 702
 703        false
 704    }
 705
 706    pub fn used_tools_since_last_user_message(&self) -> bool {
 707        for entry in self.entries.iter().rev() {
 708            match entry {
 709                AgentThreadEntry::UserMessage(..) => return false,
 710                AgentThreadEntry::AssistantMessage(..) => continue,
 711                AgentThreadEntry::ToolCall(..) => return true,
 712            }
 713        }
 714
 715        false
 716    }
 717
 718    pub fn handle_session_update(
 719        &mut self,
 720        update: acp::SessionUpdate,
 721        cx: &mut Context<Self>,
 722    ) -> Result<()> {
 723        match update {
 724            acp::SessionUpdate::UserMessageChunk { content } => {
 725                self.push_user_content_block(content, cx);
 726            }
 727            acp::SessionUpdate::AgentMessageChunk { content } => {
 728                self.push_assistant_content_block(content, false, cx);
 729            }
 730            acp::SessionUpdate::AgentThoughtChunk { content } => {
 731                self.push_assistant_content_block(content, true, cx);
 732            }
 733            acp::SessionUpdate::ToolCall(tool_call) => {
 734                self.upsert_tool_call(tool_call, cx);
 735            }
 736            acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
 737                self.update_tool_call(tool_call_update, cx)?;
 738            }
 739            acp::SessionUpdate::Plan(plan) => {
 740                self.update_plan(plan, cx);
 741            }
 742        }
 743        Ok(())
 744    }
 745
 746    pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
 747        let language_registry = self.project.read(cx).languages().clone();
 748        let entries_len = self.entries.len();
 749
 750        if let Some(last_entry) = self.entries.last_mut()
 751            && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
 752        {
 753            content.append(chunk, &language_registry, cx);
 754            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
 755        } else {
 756            let content = ContentBlock::new(chunk, &language_registry, cx);
 757            self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
 758        }
 759    }
 760
 761    pub fn push_assistant_content_block(
 762        &mut self,
 763        chunk: acp::ContentBlock,
 764        is_thought: bool,
 765        cx: &mut Context<Self>,
 766    ) {
 767        let language_registry = self.project.read(cx).languages().clone();
 768        let entries_len = self.entries.len();
 769        if let Some(last_entry) = self.entries.last_mut()
 770            && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
 771        {
 772            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
 773            match (chunks.last_mut(), is_thought) {
 774                (Some(AssistantMessageChunk::Message { block }), false)
 775                | (Some(AssistantMessageChunk::Thought { block }), true) => {
 776                    block.append(chunk, &language_registry, cx)
 777                }
 778                _ => {
 779                    let block = ContentBlock::new(chunk, &language_registry, cx);
 780                    if is_thought {
 781                        chunks.push(AssistantMessageChunk::Thought { block })
 782                    } else {
 783                        chunks.push(AssistantMessageChunk::Message { block })
 784                    }
 785                }
 786            }
 787        } else {
 788            let block = ContentBlock::new(chunk, &language_registry, cx);
 789            let chunk = if is_thought {
 790                AssistantMessageChunk::Thought { block }
 791            } else {
 792                AssistantMessageChunk::Message { block }
 793            };
 794
 795            self.push_entry(
 796                AgentThreadEntry::AssistantMessage(AssistantMessage {
 797                    chunks: vec![chunk],
 798                }),
 799                cx,
 800            );
 801        }
 802    }
 803
 804    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
 805        self.entries.push(entry);
 806        cx.emit(AcpThreadEvent::NewEntry);
 807    }
 808
 809    pub fn update_tool_call(
 810        &mut self,
 811        update: acp::ToolCallUpdate,
 812        cx: &mut Context<Self>,
 813    ) -> Result<()> {
 814        let languages = self.project.read(cx).languages().clone();
 815
 816        let (ix, current_call) = self
 817            .tool_call_mut(&update.id)
 818            .context("Tool call not found")?;
 819        current_call.update(update.fields, languages, cx);
 820
 821        cx.emit(AcpThreadEvent::EntryUpdated(ix));
 822
 823        Ok(())
 824    }
 825
 826    /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
 827    pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
 828        let status = ToolCallStatus::Allowed {
 829            status: tool_call.status,
 830        };
 831        self.upsert_tool_call_inner(tool_call, status, cx)
 832    }
 833
 834    pub fn upsert_tool_call_inner(
 835        &mut self,
 836        tool_call: acp::ToolCall,
 837        status: ToolCallStatus,
 838        cx: &mut Context<Self>,
 839    ) {
 840        let language_registry = self.project.read(cx).languages().clone();
 841        let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
 842
 843        let location = call.locations.last().cloned();
 844
 845        if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
 846            *current_call = call;
 847
 848            cx.emit(AcpThreadEvent::EntryUpdated(ix));
 849        } else {
 850            self.push_entry(AgentThreadEntry::ToolCall(call), cx);
 851        }
 852
 853        if let Some(location) = location {
 854            self.set_project_location(location, cx)
 855        }
 856    }
 857
 858    fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
 859        // The tool call we are looking for is typically the last one, or very close to the end.
 860        // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
 861        self.entries
 862            .iter_mut()
 863            .enumerate()
 864            .rev()
 865            .find_map(|(index, tool_call)| {
 866                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
 867                    && &tool_call.id == id
 868                {
 869                    Some((index, tool_call))
 870                } else {
 871                    None
 872                }
 873            })
 874    }
 875
    /// Best-effort: points the project's agent location at `location`.
    ///
    /// Does nothing if the project cannot resolve `location.path` to a project path.
    /// Otherwise it opens the buffer asynchronously and sets the agent location to the start
    /// of the requested line (clipped to the buffer), or to the buffer start when no line is
    /// given. Errors from opening the buffer or updating the project are logged, not returned.
    pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context<Self>) {
        self.project.update(cx, |project, cx| {
            // Paths the project can't resolve (presumably outside its worktrees) are ignored.
            let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
                return;
            };
            let buffer = project.open_buffer(path, cx);
            cx.spawn(async move |project, cx| {
                let buffer = buffer.await?;

                project.update(cx, |project, cx| {
                    // NOTE(review): assumes the ACP `line` uses the same (0-based) row numbering
                    // as `Point` — confirm against the protocol.
                    let position = if let Some(line) = location.line {
                        let snapshot = buffer.read(cx).snapshot();
                        let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
                        snapshot.anchor_before(point)
                    } else {
                        Anchor::MIN
                    };

                    project.set_agent_location(
                        Some(AgentLocation {
                            buffer: buffer.downgrade(),
                            position,
                        }),
                        cx,
                    );
                })
            })
            .detach_and_log_err(cx);
        });
    }
 906
 907    pub fn request_tool_call_authorization(
 908        &mut self,
 909        tool_call: acp::ToolCall,
 910        options: Vec<acp::PermissionOption>,
 911        cx: &mut Context<Self>,
 912    ) -> oneshot::Receiver<acp::PermissionOptionId> {
 913        let (tx, rx) = oneshot::channel();
 914
 915        let status = ToolCallStatus::WaitingForConfirmation {
 916            options,
 917            respond_tx: tx,
 918        };
 919
 920        self.upsert_tool_call_inner(tool_call, status, cx);
 921        cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
 922        rx
 923    }
 924
 925    pub fn authorize_tool_call(
 926        &mut self,
 927        id: acp::ToolCallId,
 928        option_id: acp::PermissionOptionId,
 929        option_kind: acp::PermissionOptionKind,
 930        cx: &mut Context<Self>,
 931    ) {
 932        let Some((ix, call)) = self.tool_call_mut(&id) else {
 933            return;
 934        };
 935
 936        let new_status = match option_kind {
 937            acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
 938                ToolCallStatus::Rejected
 939            }
 940            acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
 941                ToolCallStatus::Allowed {
 942                    status: acp::ToolCallStatus::InProgress,
 943                }
 944            }
 945        };
 946
 947        let curr_status = mem::replace(&mut call.status, new_status);
 948
 949        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
 950            respond_tx.send(option_id).log_err();
 951        } else if cfg!(debug_assertions) {
 952            panic!("tried to authorize an already authorized tool call");
 953        }
 954
 955        cx.emit(AcpThreadEvent::EntryUpdated(ix));
 956    }
 957
 958    /// Returns true if the last turn is awaiting tool authorization
 959    pub fn waiting_for_tool_confirmation(&self) -> bool {
 960        for entry in self.entries.iter().rev() {
 961            match &entry {
 962                AgentThreadEntry::ToolCall(call) => match call.status {
 963                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
 964                    ToolCallStatus::Allowed { .. }
 965                    | ToolCallStatus::Rejected
 966                    | ToolCallStatus::Canceled => continue,
 967                },
 968                AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
 969                    // Reached the beginning of the turn
 970                    return false;
 971                }
 972            }
 973        }
 974        false
 975    }
 976
 977    pub fn plan(&self) -> &Plan {
 978        &self.plan
 979    }
 980
 981    pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
 982        let new_entries_len = request.entries.len();
 983        let mut new_entries = request.entries.into_iter();
 984
 985        // Reuse existing markdown to prevent flickering
 986        for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
 987            let PlanEntry {
 988                content,
 989                priority,
 990                status,
 991            } = old;
 992            content.update(cx, |old, cx| {
 993                old.replace(new.content, cx);
 994            });
 995            *priority = new.priority;
 996            *status = new.status;
 997        }
 998        for new in new_entries {
 999            self.plan.entries.push(PlanEntry::from_acp(new, cx))
1000        }
1001        self.plan.entries.truncate(new_entries_len);
1002
1003        cx.notify();
1004    }
1005
1006    fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1007        self.plan
1008            .entries
1009            .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1010        cx.notify();
1011    }
1012
1013    #[cfg(any(test, feature = "test-support"))]
1014    pub fn send_raw(
1015        &mut self,
1016        message: &str,
1017        cx: &mut Context<Self>,
1018    ) -> BoxFuture<'static, Result<()>> {
1019        self.send(
1020            vec![acp::ContentBlock::Text(acp::TextContent {
1021                text: message.to_string(),
1022                annotations: None,
1023            })],
1024            cx,
1025        )
1026    }
1027
    /// Sends `message` to the agent as a new user turn.
    ///
    /// Steps:
    /// 1. Pushes a `UserMessage` entry built from the content blocks.
    /// 2. Removes completed plan entries.
    /// 3. Cancels any prompt still in flight and waits for that to finish.
    /// 4. Issues a `PromptRequest` for this session.
    ///
    /// The returned future resolves when the prompt finishes. If the connection
    /// returned an error, it emits `AcpThreadEvent::Error` and returns that error.
    /// Otherwise it emits `AcpThreadEvent::Stopped` and returns `Ok(())`.
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            cx,
        );
        self.push_entry(
            AgentThreadEntry::UserMessage(UserMessage { content: block }),
            cx,
        );
        self.clear_completed_plan_entries(cx);

        // `tx` hands the prompt result from the long-lived send task to the
        // future returned to the caller.
        let (tx, rx) = oneshot::channel();
        // Cancel the previous prompt (if any). The new request is only issued
        // once that cancellation task has completed.
        let cancel_task = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            async {
                cancel_task.await;

                let result = this
                    .update(cx, |this, cx| {
                        this.connection.prompt(
                            acp::PromptRequest {
                                prompt: message,
                                session_id: this.session_id.clone(),
                            },
                            cx,
                        )
                    })?
                    .await;

                tx.send(result).log_err();

                anyhow::Ok(())
            }
            .await
            .log_err();
        }));

        cx.spawn(async move |this, cx| match rx.await {
            Ok(Err(e)) => {
                this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error))
                    .log_err();
                Err(e)?
            }
            // Covers a successful response. It also covers `Err(Canceled)` from
            // `rx` (the send task was dropped before sending a result); that case
            // is treated like a normal stop.
            result => {
                let cancelled = matches!(
                    result,
                    Ok(Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::Cancelled
                    }))
                );

                // We only take the task if the current prompt wasn't cancelled.
                //
                // This prompt may have been cancelled because another one was sent
                // while it was still generating. In these cases, dropping `send_task`
                // would cause the next generation to be cancelled.
                if !cancelled {
                    this.update(cx, |this, _cx| this.send_task.take()).ok();
                }

                this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
                    .log_err();
                Ok(())
            }
        })
        .boxed()
    }
1101
1102    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1103        let Some(send_task) = self.send_task.take() else {
1104            return Task::ready(());
1105        };
1106
1107        for entry in self.entries.iter_mut() {
1108            if let AgentThreadEntry::ToolCall(call) = entry {
1109                let cancel = matches!(
1110                    call.status,
1111                    ToolCallStatus::WaitingForConfirmation { .. }
1112                        | ToolCallStatus::Allowed {
1113                            status: acp::ToolCallStatus::InProgress
1114                        }
1115                );
1116
1117                if cancel {
1118                    call.status = ToolCallStatus::Canceled;
1119                }
1120            }
1121        }
1122
1123        self.connection.cancel(&self.session_id, cx);
1124
1125        // Wait for the send task to complete
1126        cx.foreground_executor().spawn(send_task)
1127    }
1128
1129    pub fn read_text_file(
1130        &self,
1131        path: PathBuf,
1132        line: Option<u32>,
1133        limit: Option<u32>,
1134        reuse_shared_snapshot: bool,
1135        cx: &mut Context<Self>,
1136    ) -> Task<Result<String>> {
1137        let project = self.project.clone();
1138        let action_log = self.action_log.clone();
1139        cx.spawn(async move |this, cx| {
1140            let load = project.update(cx, |project, cx| {
1141                let path = project
1142                    .project_path_for_absolute_path(&path, cx)
1143                    .context("invalid path")?;
1144                anyhow::Ok(project.open_buffer(path, cx))
1145            });
1146            let buffer = load??.await?;
1147
1148            let snapshot = if reuse_shared_snapshot {
1149                this.read_with(cx, |this, _| {
1150                    this.shared_buffers.get(&buffer.clone()).cloned()
1151                })
1152                .log_err()
1153                .flatten()
1154            } else {
1155                None
1156            };
1157
1158            let snapshot = if let Some(snapshot) = snapshot {
1159                snapshot
1160            } else {
1161                action_log.update(cx, |action_log, cx| {
1162                    action_log.buffer_read(buffer.clone(), cx);
1163                })?;
1164                project.update(cx, |project, cx| {
1165                    let position = buffer
1166                        .read(cx)
1167                        .snapshot()
1168                        .anchor_before(Point::new(line.unwrap_or_default(), 0));
1169                    project.set_agent_location(
1170                        Some(AgentLocation {
1171                            buffer: buffer.downgrade(),
1172                            position,
1173                        }),
1174                        cx,
1175                    );
1176                })?;
1177
1178                buffer.update(cx, |buffer, _| buffer.snapshot())?
1179            };
1180
1181            this.update(cx, |this, _| {
1182                let text = snapshot.text();
1183                this.shared_buffers.insert(buffer.clone(), snapshot);
1184                if line.is_none() && limit.is_none() {
1185                    return Ok(text);
1186                }
1187                let limit = limit.unwrap_or(u32::MAX) as usize;
1188                let Some(line) = line else {
1189                    return Ok(text.lines().take(limit).collect::<String>());
1190                };
1191
1192                let count = text.lines().count();
1193                if count < line as usize {
1194                    anyhow::bail!("There are only {} lines", count);
1195                }
1196                Ok(text
1197                    .lines()
1198                    .skip(line as usize + 1)
1199                    .take(limit)
1200                    .collect::<String>())
1201            })?
1202        })
1203    }
1204
    /// Writes `content` to the file at `path` (an absolute path inside the
    /// project) and saves the buffer.
    ///
    /// The new content is diffed against the snapshot previously shared with the
    /// agent (from `shared_buffers`), or against the live buffer if none was
    /// shared. The resulting edits are then applied to the buffer and recorded
    /// in the action log.
    ///
    /// # Errors
    ///
    /// Fails when any of the following holds:
    /// * `path` is not in the project.
    /// * The buffer cannot be opened.
    /// * An entity update fails.
    /// * Saving fails.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Prefer the text the agent last read, so the diff reflects only the
            // agent's intended changes.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Compute the diff off the main thread. Ranges are turned into anchors
            // on `snapshot`. Presumably this lets them land correctly in the live
            // buffer even if the user edited it after the agent's read
            // (see `test_edits_concurrently_to_user`).
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;
            cx.update(|cx| {
                // Point the agent location at the end of the last edit, or at the
                // start of the buffer when there are no edits.
                project.update(cx, |project, cx| {
                    project.set_agent_location(
                        Some(AgentLocation {
                            buffer: buffer.downgrade(),
                            position: edits
                                .last()
                                .map(|(range, _)| range.end)
                                .unwrap_or(Anchor::MIN),
                        }),
                        cx,
                    );
                });

                // Mark the buffer as read before editing, then report the edit.
                // Presumably this gives the action log a baseline to diff against.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });
                buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
            })?;
            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1272
1273    pub fn to_markdown(&self, cx: &App) -> String {
1274        self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1275    }
1276
1277    pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
1278        cx.emit(AcpThreadEvent::ServerExited(status));
1279    }
1280}
1281
1282#[cfg(test)]
1283mod tests {
1284    use super::*;
1285    use anyhow::anyhow;
1286    use futures::{channel::mpsc, future::LocalBoxFuture, select};
1287    use gpui::{AsyncApp, TestAppContext, WeakEntity};
1288    use indoc::indoc;
1289    use project::FakeFs;
1290    use rand::Rng as _;
1291    use serde_json::json;
1292    use settings::SettingsStore;
1293    use smol::stream::StreamExt as _;
1294    use std::{cell::RefCell, rc::Rc, time::Duration};
1295
1296    use util::path;
1297
1298    fn init_test(cx: &mut TestAppContext) {
1299        env_logger::try_init().ok();
1300        cx.update(|cx| {
1301            let settings_store = SettingsStore::test(cx);
1302            cx.set_global(settings_store);
1303            Project::init_settings(cx);
1304            language::init(cx);
1305        });
1306    }
1307
1308    #[gpui::test]
1309    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1310        init_test(cx);
1311
1312        let fs = FakeFs::new(cx.executor());
1313        let project = Project::test(fs, [], cx).await;
1314        let connection = Rc::new(FakeAgentConnection::new());
1315        let thread = cx
1316            .spawn(async move |mut cx| {
1317                connection
1318                    .new_thread(project, Path::new(path!("/test")), &mut cx)
1319                    .await
1320            })
1321            .await
1322            .unwrap();
1323
1324        // Test creating a new user message
1325        thread.update(cx, |thread, cx| {
1326            thread.push_user_content_block(
1327                acp::ContentBlock::Text(acp::TextContent {
1328                    annotations: None,
1329                    text: "Hello, ".to_string(),
1330                }),
1331                cx,
1332            );
1333        });
1334
1335        thread.update(cx, |thread, cx| {
1336            assert_eq!(thread.entries.len(), 1);
1337            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1338                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1339            } else {
1340                panic!("Expected UserMessage");
1341            }
1342        });
1343
1344        // Test appending to existing user message
1345        thread.update(cx, |thread, cx| {
1346            thread.push_user_content_block(
1347                acp::ContentBlock::Text(acp::TextContent {
1348                    annotations: None,
1349                    text: "world!".to_string(),
1350                }),
1351                cx,
1352            );
1353        });
1354
1355        thread.update(cx, |thread, cx| {
1356            assert_eq!(thread.entries.len(), 1);
1357            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1358                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1359            } else {
1360                panic!("Expected UserMessage");
1361            }
1362        });
1363
1364        // Test creating new user message after assistant message
1365        thread.update(cx, |thread, cx| {
1366            thread.push_assistant_content_block(
1367                acp::ContentBlock::Text(acp::TextContent {
1368                    annotations: None,
1369                    text: "Assistant response".to_string(),
1370                }),
1371                false,
1372                cx,
1373            );
1374        });
1375
1376        thread.update(cx, |thread, cx| {
1377            thread.push_user_content_block(
1378                acp::ContentBlock::Text(acp::TextContent {
1379                    annotations: None,
1380                    text: "New user message".to_string(),
1381                }),
1382                cx,
1383            );
1384        });
1385
1386        thread.update(cx, |thread, cx| {
1387            assert_eq!(thread.entries.len(), 3);
1388            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1389                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1390            } else {
1391                panic!("Expected UserMessage at index 2");
1392            }
1393        });
1394    }
1395
1396    #[gpui::test]
1397    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
1398        init_test(cx);
1399
1400        let fs = FakeFs::new(cx.executor());
1401        let project = Project::test(fs, [], cx).await;
1402        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1403            |_, thread, mut cx| {
1404                async move {
1405                    thread.update(&mut cx, |thread, cx| {
1406                        thread
1407                            .handle_session_update(
1408                                acp::SessionUpdate::AgentThoughtChunk {
1409                                    content: "Thinking ".into(),
1410                                },
1411                                cx,
1412                            )
1413                            .unwrap();
1414                        thread
1415                            .handle_session_update(
1416                                acp::SessionUpdate::AgentThoughtChunk {
1417                                    content: "hard!".into(),
1418                                },
1419                                cx,
1420                            )
1421                            .unwrap();
1422                    })?;
1423                    Ok(acp::PromptResponse {
1424                        stop_reason: acp::StopReason::EndTurn,
1425                    })
1426                }
1427                .boxed_local()
1428            },
1429        ));
1430
1431        let thread = cx
1432            .spawn(async move |mut cx| {
1433                connection
1434                    .new_thread(project, Path::new(path!("/test")), &mut cx)
1435                    .await
1436            })
1437            .await
1438            .unwrap();
1439
1440        thread
1441            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
1442            .await
1443            .unwrap();
1444
1445        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
1446        assert_eq!(
1447            output,
1448            indoc! {r#"
1449            ## User
1450
1451            Hello from Zed!
1452
1453            ## Assistant
1454
1455            <thinking>
1456            Thinking hard!
1457            </thinking>
1458
1459            "#}
1460        );
1461    }
1462
1463    #[gpui::test]
1464    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
1465        init_test(cx);
1466
1467        let fs = FakeFs::new(cx.executor());
1468        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
1469            .await;
1470        let project = Project::test(fs.clone(), [], cx).await;
1471        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
1472        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
1473        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1474            move |_, thread, mut cx| {
1475                let read_file_tx = read_file_tx.clone();
1476                async move {
1477                    let content = thread
1478                        .update(&mut cx, |thread, cx| {
1479                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
1480                        })
1481                        .unwrap()
1482                        .await
1483                        .unwrap();
1484                    assert_eq!(content, "one\ntwo\nthree\n");
1485                    read_file_tx.take().unwrap().send(()).unwrap();
1486                    thread
1487                        .update(&mut cx, |thread, cx| {
1488                            thread.write_text_file(
1489                                path!("/tmp/foo").into(),
1490                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
1491                                cx,
1492                            )
1493                        })
1494                        .unwrap()
1495                        .await
1496                        .unwrap();
1497                    Ok(acp::PromptResponse {
1498                        stop_reason: acp::StopReason::EndTurn,
1499                    })
1500                }
1501                .boxed_local()
1502            },
1503        ));
1504
1505        let (worktree, pathbuf) = project
1506            .update(cx, |project, cx| {
1507                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
1508            })
1509            .await
1510            .unwrap();
1511        let buffer = project
1512            .update(cx, |project, cx| {
1513                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
1514            })
1515            .await
1516            .unwrap();
1517
1518        let thread = cx
1519            .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
1520            .await
1521            .unwrap();
1522
1523        let request = thread.update(cx, |thread, cx| {
1524            thread.send_raw("Extend the count in /tmp/foo", cx)
1525        });
1526        read_file_rx.await.ok();
1527        buffer.update(cx, |buffer, cx| {
1528            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
1529        });
1530        cx.run_until_parked();
1531        assert_eq!(
1532            buffer.read_with(cx, |buffer, _| buffer.text()),
1533            "zero\none\ntwo\nthree\nfour\nfive\n"
1534        );
1535        assert_eq!(
1536            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
1537            "zero\none\ntwo\nthree\nfour\nfive\n"
1538        );
1539        request.await.unwrap();
1540    }
1541
1542    #[gpui::test]
1543    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
1544        init_test(cx);
1545
1546        let fs = FakeFs::new(cx.executor());
1547        let project = Project::test(fs, [], cx).await;
1548        let id = acp::ToolCallId("test".into());
1549
1550        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
1551            let id = id.clone();
1552            move |_, thread, mut cx| {
1553                let id = id.clone();
1554                async move {
1555                    thread
1556                        .update(&mut cx, |thread, cx| {
1557                            thread.handle_session_update(
1558                                acp::SessionUpdate::ToolCall(acp::ToolCall {
1559                                    id: id.clone(),
1560                                    title: "Label".into(),
1561                                    kind: acp::ToolKind::Fetch,
1562                                    status: acp::ToolCallStatus::InProgress,
1563                                    content: vec![],
1564                                    locations: vec![],
1565                                    raw_input: None,
1566                                    raw_output: None,
1567                                }),
1568                                cx,
1569                            )
1570                        })
1571                        .unwrap()
1572                        .unwrap();
1573                    Ok(acp::PromptResponse {
1574                        stop_reason: acp::StopReason::EndTurn,
1575                    })
1576                }
1577                .boxed_local()
1578            }
1579        }));
1580
1581        let thread = cx
1582            .spawn(async move |mut cx| {
1583                connection
1584                    .new_thread(project, Path::new(path!("/test")), &mut cx)
1585                    .await
1586            })
1587            .await
1588            .unwrap();
1589
1590        let request = thread.update(cx, |thread, cx| {
1591            thread.send_raw("Fetch https://example.com", cx)
1592        });
1593
1594        run_until_first_tool_call(&thread, cx).await;
1595
1596        thread.read_with(cx, |thread, _| {
1597            assert!(matches!(
1598                thread.entries[1],
1599                AgentThreadEntry::ToolCall(ToolCall {
1600                    status: ToolCallStatus::Allowed {
1601                        status: acp::ToolCallStatus::InProgress,
1602                        ..
1603                    },
1604                    ..
1605                })
1606            ));
1607        });
1608
1609        thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1610
1611        thread.read_with(cx, |thread, _| {
1612            assert!(matches!(
1613                &thread.entries[1],
1614                AgentThreadEntry::ToolCall(ToolCall {
1615                    status: ToolCallStatus::Canceled,
1616                    ..
1617                })
1618            ));
1619        });
1620
1621        thread
1622            .update(cx, |thread, cx| {
1623                thread.handle_session_update(
1624                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
1625                        id,
1626                        fields: acp::ToolCallUpdateFields {
1627                            status: Some(acp::ToolCallStatus::Completed),
1628                            ..Default::default()
1629                        },
1630                    }),
1631                    cx,
1632                )
1633            })
1634            .unwrap();
1635
1636        request.await.unwrap();
1637
1638        thread.read_with(cx, |thread, _| {
1639            assert!(matches!(
1640                thread.entries[1],
1641                AgentThreadEntry::ToolCall(ToolCall {
1642                    status: ToolCallStatus::Allowed {
1643                        status: acp::ToolCallStatus::Completed,
1644                        ..
1645                    },
1646                    ..
1647                })
1648            ));
1649        });
1650    }
1651
1652    #[gpui::test]
1653    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
1654        init_test(cx);
1655        let fs = FakeFs::new(cx.background_executor.clone());
1656        fs.insert_tree(path!("/test"), json!({})).await;
1657        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
1658
1659        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
1660            move |_, thread, mut cx| {
1661                async move {
1662                    thread
1663                        .update(&mut cx, |thread, cx| {
1664                            thread.handle_session_update(
1665                                acp::SessionUpdate::ToolCall(acp::ToolCall {
1666                                    id: acp::ToolCallId("test".into()),
1667                                    title: "Label".into(),
1668                                    kind: acp::ToolKind::Edit,
1669                                    status: acp::ToolCallStatus::Completed,
1670                                    content: vec![acp::ToolCallContent::Diff {
1671                                        diff: acp::Diff {
1672                                            path: "/test/test.txt".into(),
1673                                            old_text: None,
1674                                            new_text: "foo".into(),
1675                                        },
1676                                    }],
1677                                    locations: vec![],
1678                                    raw_input: None,
1679                                    raw_output: None,
1680                                }),
1681                                cx,
1682                            )
1683                        })
1684                        .unwrap()
1685                        .unwrap();
1686                    Ok(acp::PromptResponse {
1687                        stop_reason: acp::StopReason::EndTurn,
1688                    })
1689                }
1690                .boxed_local()
1691            }
1692        }));
1693
1694        let thread = connection
1695            .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
1696            .await
1697            .unwrap();
1698        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
1699            .await
1700            .unwrap();
1701
1702        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
1703    }
1704
1705    async fn run_until_first_tool_call(
1706        thread: &Entity<AcpThread>,
1707        cx: &mut TestAppContext,
1708    ) -> usize {
1709        let (mut tx, mut rx) = mpsc::channel::<usize>(1);
1710
1711        let subscription = cx.update(|cx| {
1712            cx.subscribe(thread, move |thread, _, cx| {
1713                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
1714                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
1715                        return tx.try_send(ix).unwrap();
1716                    }
1717                }
1718            })
1719        });
1720
1721        select! {
1722            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
1723                panic!("Timeout waiting for tool call")
1724            }
1725            ix = rx.next().fuse() => {
1726                drop(subscription);
1727                ix.unwrap()
1728            }
1729        }
1730    }
1731
1732    #[derive(Clone, Default)]
1733    struct FakeAgentConnection {
1734        auth_methods: Vec<acp::AuthMethod>,
1735        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
1736        on_user_message: Option<
1737            Rc<
1738                dyn Fn(
1739                        acp::PromptRequest,
1740                        WeakEntity<AcpThread>,
1741                        AsyncApp,
1742                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1743                    + 'static,
1744            >,
1745        >,
1746    }
1747
1748    impl FakeAgentConnection {
1749        fn new() -> Self {
1750            Self {
1751                auth_methods: Vec::new(),
1752                on_user_message: None,
1753                sessions: Arc::default(),
1754            }
1755        }
1756
1757        #[expect(unused)]
1758        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
1759            self.auth_methods = auth_methods;
1760            self
1761        }
1762
1763        fn on_user_message(
1764            mut self,
1765            handler: impl Fn(
1766                acp::PromptRequest,
1767                WeakEntity<AcpThread>,
1768                AsyncApp,
1769            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1770            + 'static,
1771        ) -> Self {
1772            self.on_user_message.replace(Rc::new(handler));
1773            self
1774        }
1775    }
1776
1777    impl AgentConnection for FakeAgentConnection {
1778        fn auth_methods(&self) -> &[acp::AuthMethod] {
1779            &self.auth_methods
1780        }
1781
1782        fn new_thread(
1783            self: Rc<Self>,
1784            project: Entity<Project>,
1785            _cwd: &Path,
1786            cx: &mut gpui::AsyncApp,
1787        ) -> Task<gpui::Result<Entity<AcpThread>>> {
1788            let session_id = acp::SessionId(
1789                rand::thread_rng()
1790                    .sample_iter(&rand::distributions::Alphanumeric)
1791                    .take(7)
1792                    .map(char::from)
1793                    .collect::<String>()
1794                    .into(),
1795            );
1796            let thread = cx
1797                .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
1798                .unwrap();
1799            self.sessions.lock().insert(session_id, thread.downgrade());
1800            Task::ready(Ok(thread))
1801        }
1802
1803        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
1804            if self.auth_methods().iter().any(|m| m.id == method) {
1805                Task::ready(Ok(()))
1806            } else {
1807                Task::ready(Err(anyhow!("Invalid Auth Method")))
1808            }
1809        }
1810
1811        fn prompt(
1812            &self,
1813            params: acp::PromptRequest,
1814            cx: &mut App,
1815        ) -> Task<gpui::Result<acp::PromptResponse>> {
1816            let sessions = self.sessions.lock();
1817            let thread = sessions.get(&params.session_id).unwrap();
1818            if let Some(handler) = &self.on_user_message {
1819                let handler = handler.clone();
1820                let thread = thread.clone();
1821                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
1822            } else {
1823                Task::ready(Ok(acp::PromptResponse {
1824                    stop_reason: acp::StopReason::EndTurn,
1825                }))
1826            }
1827        }
1828
1829        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1830            let sessions = self.sessions.lock();
1831            let thread = sessions.get(&session_id).unwrap().clone();
1832
1833            cx.spawn(async move |cx| {
1834                thread
1835                    .update(cx, |thread, cx| thread.cancel(cx))
1836                    .unwrap()
1837                    .await
1838            })
1839            .detach();
1840        }
1841    }
1842}