acp_thread.rs

   1mod connection;
   2mod diff;
   3mod mention;
   4mod terminal;
   5
   6pub use connection::*;
   7pub use diff::*;
   8pub use mention::*;
   9pub use terminal::*;
  10
  11use action_log::ActionLog;
  12use agent_client_protocol::{self as acp};
  13use anyhow::{Context as _, Result};
  14use editor::Bias;
  15use futures::{FutureExt, channel::oneshot, future::BoxFuture};
  16use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
  17use itertools::Itertools;
  18use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, text_diff};
  19use markdown::Markdown;
  20use project::{AgentLocation, Project};
  21use std::collections::HashMap;
  22use std::error::Error;
  23use std::fmt::Formatter;
  24use std::process::ExitStatus;
  25use std::rc::Rc;
  26use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
  27use ui::App;
  28use util::ResultExt;
  29
  30#[derive(Debug)]
  31pub struct UserMessage {
  32    pub content: ContentBlock,
  33}
  34
  35impl UserMessage {
  36    pub fn from_acp(
  37        message: impl IntoIterator<Item = acp::ContentBlock>,
  38        language_registry: Arc<LanguageRegistry>,
  39        cx: &mut App,
  40    ) -> Self {
  41        let mut content = ContentBlock::Empty;
  42        for chunk in message {
  43            content.append(chunk, &language_registry, cx)
  44        }
  45        Self { content: content }
  46    }
  47
  48    fn to_markdown(&self, cx: &App) -> String {
  49        format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
  50    }
  51}
  52
  53#[derive(Debug, PartialEq)]
  54pub struct AssistantMessage {
  55    pub chunks: Vec<AssistantMessageChunk>,
  56}
  57
  58impl AssistantMessage {
  59    pub fn to_markdown(&self, cx: &App) -> String {
  60        format!(
  61            "## Assistant\n\n{}\n\n",
  62            self.chunks
  63                .iter()
  64                .map(|chunk| chunk.to_markdown(cx))
  65                .join("\n\n")
  66        )
  67    }
  68}
  69
  70#[derive(Debug, PartialEq)]
  71pub enum AssistantMessageChunk {
  72    Message { block: ContentBlock },
  73    Thought { block: ContentBlock },
  74}
  75
  76impl AssistantMessageChunk {
  77    pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
  78        Self::Message {
  79            block: ContentBlock::new(chunk.into(), language_registry, cx),
  80        }
  81    }
  82
  83    fn to_markdown(&self, cx: &App) -> String {
  84        match self {
  85            Self::Message { block } => block.to_markdown(cx).to_string(),
  86            Self::Thought { block } => {
  87                format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
  88            }
  89        }
  90    }
  91}
  92
  93#[derive(Debug)]
  94pub enum AgentThreadEntry {
  95    UserMessage(UserMessage),
  96    AssistantMessage(AssistantMessage),
  97    ToolCall(ToolCall),
  98}
  99
 100impl AgentThreadEntry {
 101    fn to_markdown(&self, cx: &App) -> String {
 102        match self {
 103            Self::UserMessage(message) => message.to_markdown(cx),
 104            Self::AssistantMessage(message) => message.to_markdown(cx),
 105            Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
 106        }
 107    }
 108
 109    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 110        if let AgentThreadEntry::ToolCall(call) = self {
 111            itertools::Either::Left(call.diffs())
 112        } else {
 113            itertools::Either::Right(std::iter::empty())
 114        }
 115    }
 116
 117    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 118        if let AgentThreadEntry::ToolCall(call) = self {
 119            itertools::Either::Left(call.terminals())
 120        } else {
 121            itertools::Either::Right(std::iter::empty())
 122        }
 123    }
 124
 125    pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
 126        if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
 127            Some(locations)
 128        } else {
 129            None
 130        }
 131    }
 132}
 133
 134#[derive(Debug)]
 135pub struct ToolCall {
 136    pub id: acp::ToolCallId,
 137    pub label: Entity<Markdown>,
 138    pub kind: acp::ToolKind,
 139    pub content: Vec<ToolCallContent>,
 140    pub status: ToolCallStatus,
 141    pub locations: Vec<acp::ToolCallLocation>,
 142    pub raw_input: Option<serde_json::Value>,
 143    pub raw_output: Option<serde_json::Value>,
 144}
 145
 146impl ToolCall {
 147    fn from_acp(
 148        tool_call: acp::ToolCall,
 149        status: ToolCallStatus,
 150        language_registry: Arc<LanguageRegistry>,
 151        cx: &mut App,
 152    ) -> Self {
 153        Self {
 154            id: tool_call.id,
 155            label: cx.new(|cx| {
 156                Markdown::new(
 157                    tool_call.title.into(),
 158                    Some(language_registry.clone()),
 159                    None,
 160                    cx,
 161                )
 162            }),
 163            kind: tool_call.kind,
 164            content: tool_call
 165                .content
 166                .into_iter()
 167                .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
 168                .collect(),
 169            locations: tool_call.locations,
 170            status,
 171            raw_input: tool_call.raw_input,
 172            raw_output: tool_call.raw_output,
 173        }
 174    }
 175
 176    fn update_fields(
 177        &mut self,
 178        fields: acp::ToolCallUpdateFields,
 179        language_registry: Arc<LanguageRegistry>,
 180        cx: &mut App,
 181    ) {
 182        let acp::ToolCallUpdateFields {
 183            kind,
 184            status,
 185            title,
 186            content,
 187            locations,
 188            raw_input,
 189            raw_output,
 190        } = fields;
 191
 192        if let Some(kind) = kind {
 193            self.kind = kind;
 194        }
 195
 196        if let Some(status) = status {
 197            self.status = ToolCallStatus::Allowed { status };
 198        }
 199
 200        if let Some(title) = title {
 201            self.label.update(cx, |label, cx| {
 202                label.replace(title, cx);
 203            });
 204        }
 205
 206        if let Some(content) = content {
 207            self.content = content
 208                .into_iter()
 209                .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
 210                .collect();
 211        }
 212
 213        if let Some(locations) = locations {
 214            self.locations = locations;
 215        }
 216
 217        if let Some(raw_input) = raw_input {
 218            self.raw_input = Some(raw_input);
 219        }
 220
 221        if let Some(raw_output) = raw_output {
 222            if self.content.is_empty() {
 223                if let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
 224                {
 225                    self.content
 226                        .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
 227                            markdown,
 228                        }));
 229                }
 230            }
 231            self.raw_output = Some(raw_output);
 232        }
 233    }
 234
 235    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 236        self.content.iter().filter_map(|content| match content {
 237            ToolCallContent::Diff(diff) => Some(diff),
 238            ToolCallContent::ContentBlock(_) => None,
 239            ToolCallContent::Terminal(_) => None,
 240        })
 241    }
 242
 243    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 244        self.content.iter().filter_map(|content| match content {
 245            ToolCallContent::Terminal(terminal) => Some(terminal),
 246            ToolCallContent::ContentBlock(_) => None,
 247            ToolCallContent::Diff(_) => None,
 248        })
 249    }
 250
 251    fn to_markdown(&self, cx: &App) -> String {
 252        let mut markdown = format!(
 253            "**Tool Call: {}**\nStatus: {}\n\n",
 254            self.label.read(cx).source(),
 255            self.status
 256        );
 257        for content in &self.content {
 258            markdown.push_str(content.to_markdown(cx).as_str());
 259            markdown.push_str("\n\n");
 260        }
 261        markdown
 262    }
 263}
 264
 265#[derive(Debug)]
 266pub enum ToolCallStatus {
 267    WaitingForConfirmation {
 268        options: Vec<acp::PermissionOption>,
 269        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
 270    },
 271    Allowed {
 272        status: acp::ToolCallStatus,
 273    },
 274    Rejected,
 275    Canceled,
 276}
 277
 278impl Display for ToolCallStatus {
 279    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 280        write!(
 281            f,
 282            "{}",
 283            match self {
 284                ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
 285                ToolCallStatus::Allowed { status } => match status {
 286                    acp::ToolCallStatus::Pending => "Pending",
 287                    acp::ToolCallStatus::InProgress => "In Progress",
 288                    acp::ToolCallStatus::Completed => "Completed",
 289                    acp::ToolCallStatus::Failed => "Failed",
 290                },
 291                ToolCallStatus::Rejected => "Rejected",
 292                ToolCallStatus::Canceled => "Canceled",
 293            }
 294        )
 295    }
 296}
 297
 298#[derive(Debug, PartialEq, Clone)]
 299pub enum ContentBlock {
 300    Empty,
 301    Markdown { markdown: Entity<Markdown> },
 302    ResourceLink { resource_link: acp::ResourceLink },
 303}
 304
 305impl ContentBlock {
 306    pub fn new(
 307        block: acp::ContentBlock,
 308        language_registry: &Arc<LanguageRegistry>,
 309        cx: &mut App,
 310    ) -> Self {
 311        let mut this = Self::Empty;
 312        this.append(block, language_registry, cx);
 313        this
 314    }
 315
 316    pub fn new_combined(
 317        blocks: impl IntoIterator<Item = acp::ContentBlock>,
 318        language_registry: Arc<LanguageRegistry>,
 319        cx: &mut App,
 320    ) -> Self {
 321        let mut this = Self::Empty;
 322        for block in blocks {
 323            this.append(block, &language_registry, cx);
 324        }
 325        this
 326    }
 327
 328    pub fn append(
 329        &mut self,
 330        block: acp::ContentBlock,
 331        language_registry: &Arc<LanguageRegistry>,
 332        cx: &mut App,
 333    ) {
 334        if matches!(self, ContentBlock::Empty) {
 335            if let acp::ContentBlock::ResourceLink(resource_link) = block {
 336                *self = ContentBlock::ResourceLink { resource_link };
 337                return;
 338            }
 339        }
 340
 341        let new_content = self.extract_content_from_block(block);
 342
 343        match self {
 344            ContentBlock::Empty => {
 345                *self = Self::create_markdown_block(new_content, language_registry, cx);
 346            }
 347            ContentBlock::Markdown { markdown } => {
 348                markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
 349            }
 350            ContentBlock::ResourceLink { resource_link } => {
 351                let existing_content = Self::resource_link_to_content(&resource_link.uri);
 352                let combined = format!("{}\n{}", existing_content, new_content);
 353
 354                *self = Self::create_markdown_block(combined, language_registry, cx);
 355            }
 356        }
 357    }
 358
 359    fn resource_link_to_content(uri: &str) -> String {
 360        if let Some(uri) = MentionUri::parse(&uri).log_err() {
 361            uri.to_link()
 362        } else {
 363            uri.to_string().clone()
 364        }
 365    }
 366
 367    fn create_markdown_block(
 368        content: String,
 369        language_registry: &Arc<LanguageRegistry>,
 370        cx: &mut App,
 371    ) -> ContentBlock {
 372        ContentBlock::Markdown {
 373            markdown: cx
 374                .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
 375        }
 376    }
 377
 378    fn extract_content_from_block(&self, block: acp::ContentBlock) -> String {
 379        match block {
 380            acp::ContentBlock::Text(text_content) => text_content.text.clone(),
 381            acp::ContentBlock::ResourceLink(resource_link) => {
 382                Self::resource_link_to_content(&resource_link.uri)
 383            }
 384            acp::ContentBlock::Resource(acp::EmbeddedResource {
 385                resource:
 386                    acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
 387                        uri,
 388                        ..
 389                    }),
 390                ..
 391            }) => Self::resource_link_to_content(&uri),
 392            acp::ContentBlock::Image(_)
 393            | acp::ContentBlock::Audio(_)
 394            | acp::ContentBlock::Resource(_) => String::new(),
 395        }
 396    }
 397
 398    fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
 399        match self {
 400            ContentBlock::Empty => "",
 401            ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
 402            ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
 403        }
 404    }
 405
 406    pub fn markdown(&self) -> Option<&Entity<Markdown>> {
 407        match self {
 408            ContentBlock::Empty => None,
 409            ContentBlock::Markdown { markdown } => Some(markdown),
 410            ContentBlock::ResourceLink { .. } => None,
 411        }
 412    }
 413
 414    pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
 415        match self {
 416            ContentBlock::ResourceLink { resource_link } => Some(resource_link),
 417            _ => None,
 418        }
 419    }
 420}
 421
 422#[derive(Debug)]
 423pub enum ToolCallContent {
 424    ContentBlock(ContentBlock),
 425    Diff(Entity<Diff>),
 426    Terminal(Entity<Terminal>),
 427}
 428
 429impl ToolCallContent {
 430    pub fn from_acp(
 431        content: acp::ToolCallContent,
 432        language_registry: Arc<LanguageRegistry>,
 433        cx: &mut App,
 434    ) -> Self {
 435        match content {
 436            acp::ToolCallContent::Content { content } => {
 437                Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
 438            }
 439            acp::ToolCallContent::Diff { diff } => {
 440                Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx)))
 441            }
 442        }
 443    }
 444
 445    pub fn to_markdown(&self, cx: &App) -> String {
 446        match self {
 447            Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
 448            Self::Diff(diff) => diff.read(cx).to_markdown(cx),
 449            Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
 450        }
 451    }
 452}
 453
 454#[derive(Debug, PartialEq)]
 455pub enum ToolCallUpdate {
 456    UpdateFields(acp::ToolCallUpdate),
 457    UpdateDiff(ToolCallUpdateDiff),
 458    UpdateTerminal(ToolCallUpdateTerminal),
 459}
 460
 461impl ToolCallUpdate {
 462    fn id(&self) -> &acp::ToolCallId {
 463        match self {
 464            Self::UpdateFields(update) => &update.id,
 465            Self::UpdateDiff(diff) => &diff.id,
 466            Self::UpdateTerminal(terminal) => &terminal.id,
 467        }
 468    }
 469}
 470
 471impl From<acp::ToolCallUpdate> for ToolCallUpdate {
 472    fn from(update: acp::ToolCallUpdate) -> Self {
 473        Self::UpdateFields(update)
 474    }
 475}
 476
 477impl From<ToolCallUpdateDiff> for ToolCallUpdate {
 478    fn from(diff: ToolCallUpdateDiff) -> Self {
 479        Self::UpdateDiff(diff)
 480    }
 481}
 482
 483#[derive(Debug, PartialEq)]
 484pub struct ToolCallUpdateDiff {
 485    pub id: acp::ToolCallId,
 486    pub diff: Entity<Diff>,
 487}
 488
 489impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
 490    fn from(terminal: ToolCallUpdateTerminal) -> Self {
 491        Self::UpdateTerminal(terminal)
 492    }
 493}
 494
 495#[derive(Debug, PartialEq)]
 496pub struct ToolCallUpdateTerminal {
 497    pub id: acp::ToolCallId,
 498    pub terminal: Entity<Terminal>,
 499}
 500
 501#[derive(Debug, Default)]
 502pub struct Plan {
 503    pub entries: Vec<PlanEntry>,
 504}
 505
 506#[derive(Debug)]
 507pub struct PlanStats<'a> {
 508    pub in_progress_entry: Option<&'a PlanEntry>,
 509    pub pending: u32,
 510    pub completed: u32,
 511}
 512
 513impl Plan {
 514    pub fn is_empty(&self) -> bool {
 515        self.entries.is_empty()
 516    }
 517
 518    pub fn stats(&self) -> PlanStats<'_> {
 519        let mut stats = PlanStats {
 520            in_progress_entry: None,
 521            pending: 0,
 522            completed: 0,
 523        };
 524
 525        for entry in &self.entries {
 526            match &entry.status {
 527                acp::PlanEntryStatus::Pending => {
 528                    stats.pending += 1;
 529                }
 530                acp::PlanEntryStatus::InProgress => {
 531                    stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
 532                }
 533                acp::PlanEntryStatus::Completed => {
 534                    stats.completed += 1;
 535                }
 536            }
 537        }
 538
 539        stats
 540    }
 541}
 542
 543#[derive(Debug)]
 544pub struct PlanEntry {
 545    pub content: Entity<Markdown>,
 546    pub priority: acp::PlanEntryPriority,
 547    pub status: acp::PlanEntryStatus,
 548}
 549
 550impl PlanEntry {
 551    pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
 552        Self {
 553            content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
 554            priority: entry.priority,
 555            status: entry.status,
 556        }
 557    }
 558}
 559
/// A conversation with an agent over an [`AgentConnection`].
///
/// Holds the transcript entries, the agent's plan, and the task for the in-flight send.
pub struct AcpThread {
    /// Display title, returned by `AcpThread::title`.
    title: SharedString,
    /// Transcript entries, appended in order by `push_entry`.
    entries: Vec<AgentThreadEntry>,
    plan: Plan,
    project: Entity<Project>,
    /// Created for `project` in `AcpThread::new`.
    action_log: Entity<ActionLog>,
    // NOTE(review): not read by the code visible in this section; presumably snapshots of
    // buffers shared with the agent. Confirm against its uses.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// While `Some`, `AcpThread::status` reports the thread as generating (or waiting for
    /// tool confirmation); `None` means idle.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
}
 571
 572pub enum AcpThreadEvent {
 573    NewEntry,
 574    EntryUpdated(usize),
 575    ToolAuthorizationRequired,
 576    Stopped,
 577    Error,
 578    ServerExited(ExitStatus),
 579}
 580
// Lets `AcpThread` entities emit `AcpThreadEvent`s through `cx.emit`, as `push_entry`,
// `update_tool_call`, and the tool-authorization methods do.
impl EventEmitter<AcpThreadEvent> for AcpThread {}
 582
 583#[derive(PartialEq, Eq)]
 584pub enum ThreadStatus {
 585    Idle,
 586    WaitingForToolConfirmation,
 587    Generating,
 588}
 589
 590#[derive(Debug, Clone)]
 591pub enum LoadError {
 592    Unsupported {
 593        error_message: SharedString,
 594        upgrade_message: SharedString,
 595        upgrade_command: String,
 596    },
 597    Exited(i32),
 598    Other(SharedString),
 599}
 600
 601impl Display for LoadError {
 602    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 603        match self {
 604            LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
 605            LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
 606            LoadError::Other(msg) => write!(f, "{}", msg),
 607        }
 608    }
 609}
 610
 611impl Error for LoadError {}
 612
 613impl AcpThread {
 614    pub fn new(
 615        title: impl Into<SharedString>,
 616        connection: Rc<dyn AgentConnection>,
 617        project: Entity<Project>,
 618        session_id: acp::SessionId,
 619        cx: &mut Context<Self>,
 620    ) -> Self {
 621        let action_log = cx.new(|_| ActionLog::new(project.clone()));
 622
 623        Self {
 624            action_log,
 625            shared_buffers: Default::default(),
 626            entries: Default::default(),
 627            plan: Default::default(),
 628            title: title.into(),
 629            project,
 630            send_task: None,
 631            connection,
 632            session_id,
 633        }
 634    }
 635
 636    pub fn action_log(&self) -> &Entity<ActionLog> {
 637        &self.action_log
 638    }
 639
 640    pub fn project(&self) -> &Entity<Project> {
 641        &self.project
 642    }
 643
 644    pub fn title(&self) -> SharedString {
 645        self.title.clone()
 646    }
 647
 648    pub fn entries(&self) -> &[AgentThreadEntry] {
 649        &self.entries
 650    }
 651
 652    pub fn session_id(&self) -> &acp::SessionId {
 653        &self.session_id
 654    }
 655
 656    pub fn status(&self) -> ThreadStatus {
 657        if self.send_task.is_some() {
 658            if self.waiting_for_tool_confirmation() {
 659                ThreadStatus::WaitingForToolConfirmation
 660            } else {
 661                ThreadStatus::Generating
 662            }
 663        } else {
 664            ThreadStatus::Idle
 665        }
 666    }
 667
 668    pub fn has_pending_edit_tool_calls(&self) -> bool {
 669        for entry in self.entries.iter().rev() {
 670            match entry {
 671                AgentThreadEntry::UserMessage(_) => return false,
 672                AgentThreadEntry::ToolCall(
 673                    call @ ToolCall {
 674                        status:
 675                            ToolCallStatus::Allowed {
 676                                status:
 677                                    acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending,
 678                            },
 679                        ..
 680                    },
 681                ) if call.diffs().next().is_some() => {
 682                    return true;
 683                }
 684                AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
 685            }
 686        }
 687
 688        false
 689    }
 690
 691    pub fn used_tools_since_last_user_message(&self) -> bool {
 692        for entry in self.entries.iter().rev() {
 693            match entry {
 694                AgentThreadEntry::UserMessage(..) => return false,
 695                AgentThreadEntry::AssistantMessage(..) => continue,
 696                AgentThreadEntry::ToolCall(..) => return true,
 697            }
 698        }
 699
 700        false
 701    }
 702
 703    pub fn handle_session_update(
 704        &mut self,
 705        update: acp::SessionUpdate,
 706        cx: &mut Context<Self>,
 707    ) -> Result<()> {
 708        match update {
 709            acp::SessionUpdate::UserMessageChunk { content } => {
 710                self.push_user_content_block(content, cx);
 711            }
 712            acp::SessionUpdate::AgentMessageChunk { content } => {
 713                self.push_assistant_content_block(content, false, cx);
 714            }
 715            acp::SessionUpdate::AgentThoughtChunk { content } => {
 716                self.push_assistant_content_block(content, true, cx);
 717            }
 718            acp::SessionUpdate::ToolCall(tool_call) => {
 719                self.upsert_tool_call(tool_call, cx);
 720            }
 721            acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
 722                self.update_tool_call(tool_call_update, cx)?;
 723            }
 724            acp::SessionUpdate::Plan(plan) => {
 725                self.update_plan(plan, cx);
 726            }
 727        }
 728        Ok(())
 729    }
 730
 731    pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
 732        let language_registry = self.project.read(cx).languages().clone();
 733        let entries_len = self.entries.len();
 734
 735        if let Some(last_entry) = self.entries.last_mut()
 736            && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
 737        {
 738            content.append(chunk, &language_registry, cx);
 739            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
 740        } else {
 741            let content = ContentBlock::new(chunk, &language_registry, cx);
 742            self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
 743        }
 744    }
 745
 746    pub fn push_assistant_content_block(
 747        &mut self,
 748        chunk: acp::ContentBlock,
 749        is_thought: bool,
 750        cx: &mut Context<Self>,
 751    ) {
 752        let language_registry = self.project.read(cx).languages().clone();
 753        let entries_len = self.entries.len();
 754        if let Some(last_entry) = self.entries.last_mut()
 755            && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
 756        {
 757            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
 758            match (chunks.last_mut(), is_thought) {
 759                (Some(AssistantMessageChunk::Message { block }), false)
 760                | (Some(AssistantMessageChunk::Thought { block }), true) => {
 761                    block.append(chunk, &language_registry, cx)
 762                }
 763                _ => {
 764                    let block = ContentBlock::new(chunk, &language_registry, cx);
 765                    if is_thought {
 766                        chunks.push(AssistantMessageChunk::Thought { block })
 767                    } else {
 768                        chunks.push(AssistantMessageChunk::Message { block })
 769                    }
 770                }
 771            }
 772        } else {
 773            let block = ContentBlock::new(chunk, &language_registry, cx);
 774            let chunk = if is_thought {
 775                AssistantMessageChunk::Thought { block }
 776            } else {
 777                AssistantMessageChunk::Message { block }
 778            };
 779
 780            self.push_entry(
 781                AgentThreadEntry::AssistantMessage(AssistantMessage {
 782                    chunks: vec![chunk],
 783                }),
 784                cx,
 785            );
 786        }
 787    }
 788
 789    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
 790        self.entries.push(entry);
 791        cx.emit(AcpThreadEvent::NewEntry);
 792    }
 793
 794    pub fn update_tool_call(
 795        &mut self,
 796        update: impl Into<ToolCallUpdate>,
 797        cx: &mut Context<Self>,
 798    ) -> Result<()> {
 799        let update = update.into();
 800        let languages = self.project.read(cx).languages().clone();
 801
 802        let (ix, current_call) = self
 803            .tool_call_mut(update.id())
 804            .context("Tool call not found")?;
 805        match update {
 806            ToolCallUpdate::UpdateFields(update) => {
 807                current_call.update_fields(update.fields, languages, cx);
 808            }
 809            ToolCallUpdate::UpdateDiff(update) => {
 810                current_call.content.clear();
 811                current_call
 812                    .content
 813                    .push(ToolCallContent::Diff(update.diff));
 814            }
 815            ToolCallUpdate::UpdateTerminal(update) => {
 816                current_call.content.clear();
 817                current_call
 818                    .content
 819                    .push(ToolCallContent::Terminal(update.terminal));
 820            }
 821        }
 822
 823        cx.emit(AcpThreadEvent::EntryUpdated(ix));
 824
 825        Ok(())
 826    }
 827
 828    /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
 829    pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
 830        let status = ToolCallStatus::Allowed {
 831            status: tool_call.status,
 832        };
 833        self.upsert_tool_call_inner(tool_call, status, cx)
 834    }
 835
 836    pub fn upsert_tool_call_inner(
 837        &mut self,
 838        tool_call: acp::ToolCall,
 839        status: ToolCallStatus,
 840        cx: &mut Context<Self>,
 841    ) {
 842        let language_registry = self.project.read(cx).languages().clone();
 843        let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
 844
 845        let location = call.locations.last().cloned();
 846
 847        if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
 848            *current_call = call;
 849
 850            cx.emit(AcpThreadEvent::EntryUpdated(ix));
 851        } else {
 852            self.push_entry(AgentThreadEntry::ToolCall(call), cx);
 853        }
 854
 855        if let Some(location) = location {
 856            self.set_project_location(location, cx)
 857        }
 858    }
 859
 860    fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
 861        // The tool call we are looking for is typically the last one, or very close to the end.
 862        // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
 863        self.entries
 864            .iter_mut()
 865            .enumerate()
 866            .rev()
 867            .find_map(|(index, tool_call)| {
 868                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
 869                    && &tool_call.id == id
 870                {
 871                    Some((index, tool_call))
 872                } else {
 873                    None
 874                }
 875            })
 876    }
 877
    /// Points the project's agent location at `location`.
    ///
    /// If `location.path` does not resolve to a project path, nothing happens. Otherwise the
    /// buffer is opened and, once loaded, an `AgentLocation` is set. It is anchored at the
    /// start of `location.line` (clipped to the buffer with `Bias::Left`), or at
    /// `Anchor::MIN` when no line is given. The work runs in a detached task, and its errors
    /// are logged rather than returned.
    pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context<Self>) {
        self.project.update(cx, |project, cx| {
            let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
                return;
            };
            let buffer = project.open_buffer(path, cx);
            cx.spawn(async move |project, cx| {
                let buffer = buffer.await?;

                project.update(cx, |project, cx| {
                    // NOTE(review): `line` is used directly as a `Point` row, which assumes
                    // ACP line numbers share the same base as `Point` rows; confirm.
                    let position = if let Some(line) = location.line {
                        let snapshot = buffer.read(cx).snapshot();
                        let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
                        snapshot.anchor_before(point)
                    } else {
                        Anchor::MIN
                    };

                    project.set_agent_location(
                        Some(AgentLocation {
                            buffer: buffer.downgrade(),
                            position,
                        }),
                        cx,
                    );
                })
            })
            .detach_and_log_err(cx);
        });
    }
 908
 909    pub fn request_tool_call_authorization(
 910        &mut self,
 911        tool_call: acp::ToolCall,
 912        options: Vec<acp::PermissionOption>,
 913        cx: &mut Context<Self>,
 914    ) -> oneshot::Receiver<acp::PermissionOptionId> {
 915        let (tx, rx) = oneshot::channel();
 916
 917        let status = ToolCallStatus::WaitingForConfirmation {
 918            options,
 919            respond_tx: tx,
 920        };
 921
 922        self.upsert_tool_call_inner(tool_call, status, cx);
 923        cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
 924        rx
 925    }
 926
 927    pub fn authorize_tool_call(
 928        &mut self,
 929        id: acp::ToolCallId,
 930        option_id: acp::PermissionOptionId,
 931        option_kind: acp::PermissionOptionKind,
 932        cx: &mut Context<Self>,
 933    ) {
 934        let Some((ix, call)) = self.tool_call_mut(&id) else {
 935            return;
 936        };
 937
 938        let new_status = match option_kind {
 939            acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
 940                ToolCallStatus::Rejected
 941            }
 942            acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
 943                ToolCallStatus::Allowed {
 944                    status: acp::ToolCallStatus::InProgress,
 945                }
 946            }
 947        };
 948
 949        let curr_status = mem::replace(&mut call.status, new_status);
 950
 951        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
 952            respond_tx.send(option_id).log_err();
 953        } else if cfg!(debug_assertions) {
 954            panic!("tried to authorize an already authorized tool call");
 955        }
 956
 957        cx.emit(AcpThreadEvent::EntryUpdated(ix));
 958    }
 959
 960    /// Returns true if the last turn is awaiting tool authorization
 961    pub fn waiting_for_tool_confirmation(&self) -> bool {
 962        for entry in self.entries.iter().rev() {
 963            match &entry {
 964                AgentThreadEntry::ToolCall(call) => match call.status {
 965                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
 966                    ToolCallStatus::Allowed { .. }
 967                    | ToolCallStatus::Rejected
 968                    | ToolCallStatus::Canceled => continue,
 969                },
 970                AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
 971                    // Reached the beginning of the turn
 972                    return false;
 973                }
 974            }
 975        }
 976        false
 977    }
 978
    /// Returns the plan tracked for this thread.
    ///
    /// It is replaced by `update_plan`, and completed entries are dropped whenever a
    /// new user message is sent.
    pub fn plan(&self) -> &Plan {
        &self.plan
    }
 982
 983    pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
 984        let new_entries_len = request.entries.len();
 985        let mut new_entries = request.entries.into_iter();
 986
 987        // Reuse existing markdown to prevent flickering
 988        for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
 989            let PlanEntry {
 990                content,
 991                priority,
 992                status,
 993            } = old;
 994            content.update(cx, |old, cx| {
 995                old.replace(new.content, cx);
 996            });
 997            *priority = new.priority;
 998            *status = new.status;
 999        }
1000        for new in new_entries {
1001            self.plan.entries.push(PlanEntry::from_acp(new, cx))
1002        }
1003        self.plan.entries.truncate(new_entries_len);
1004
1005        cx.notify();
1006    }
1007
1008    fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1009        self.plan
1010            .entries
1011            .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1012        cx.notify();
1013    }
1014
1015    #[cfg(any(test, feature = "test-support"))]
1016    pub fn send_raw(
1017        &mut self,
1018        message: &str,
1019        cx: &mut Context<Self>,
1020    ) -> BoxFuture<'static, Result<()>> {
1021        self.send(
1022            vec![acp::ContentBlock::Text(acp::TextContent {
1023                text: message.to_string(),
1024                annotations: None,
1025            })],
1026            cx,
1027        )
1028    }
1029
    /// Sends `message` to the agent as a new user turn.
    ///
    /// Pushes a `UserMessage` entry and clears completed plan entries. It then cancels
    /// any in-flight prompt and stores a new `send_task` that waits for that cancellation
    /// before issuing a `PromptRequest` through the connection.
    ///
    /// The returned future resolves once the prompt result is delivered:
    /// - If the prompt failed, it emits `AcpThreadEvent::Error` and returns the error.
    /// - Otherwise it emits `AcpThreadEvent::Stopped` and returns `Ok(())`. This also
    ///   covers the case where the result channel was dropped before a result was sent.
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            cx,
        );
        self.push_entry(
            AgentThreadEntry::UserMessage(UserMessage { content: block }),
            cx,
        );
        self.clear_completed_plan_entries(cx);

        // `tx` carries the prompt result from `send_task` to the returned future.
        let (tx, rx) = oneshot::channel();
        // Stop any previous generation; the new prompt waits for it to wind down first.
        let cancel_task = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            async {
                cancel_task.await;

                let result = this
                    .update(cx, |this, cx| {
                        this.connection.prompt(
                            acp::PromptRequest {
                                prompt: message,
                                session_id: this.session_id.clone(),
                            },
                            cx,
                        )
                    })?
                    .await;

                // If the thread is gone, the `?` above returns early and `tx` is dropped
                // instead, which the receiver below treats like a normal stop.
                tx.send(result).log_err();

                anyhow::Ok(())
            }
            .await
            .log_err();
        }));

        cx.spawn(async move |this, cx| match rx.await {
            // The prompt itself failed: clear the task, report, and propagate the error.
            // NOTE(review): unlike the branch below, this takes `send_task` unconditionally
            // even if a newer prompt has already replaced it. Confirm this is intended.
            Ok(Err(e)) => {
                this.update(cx, |this, cx| {
                    this.send_task.take();
                    cx.emit(AcpThreadEvent::Error)
                })
                .log_err();
                Err(e)?
            }
            result => {
                let cancelled = matches!(
                    result,
                    Ok(Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::Cancelled
                    }))
                );

                // We only take the task if the current prompt wasn't cancelled.
                //
                // This prompt may have been cancelled because another one was sent
                // while it was still generating. In these cases, dropping `send_task`
                // would cause the next generation to be cancelled.
                if !cancelled {
                    this.update(cx, |this, _cx| this.send_task.take()).ok();
                }

                this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
                    .log_err();
                Ok(())
            }
        })
        .boxed()
    }
1106
1107    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1108        let Some(send_task) = self.send_task.take() else {
1109            return Task::ready(());
1110        };
1111
1112        for entry in self.entries.iter_mut() {
1113            if let AgentThreadEntry::ToolCall(call) = entry {
1114                let cancel = matches!(
1115                    call.status,
1116                    ToolCallStatus::WaitingForConfirmation { .. }
1117                        | ToolCallStatus::Allowed {
1118                            status: acp::ToolCallStatus::InProgress
1119                        }
1120                );
1121
1122                if cancel {
1123                    call.status = ToolCallStatus::Canceled;
1124                }
1125            }
1126        }
1127
1128        self.connection.cancel(&self.session_id, cx);
1129
1130        // Wait for the send task to complete
1131        cx.foreground_executor().spawn(send_task)
1132    }
1133
1134    pub fn read_text_file(
1135        &self,
1136        path: PathBuf,
1137        line: Option<u32>,
1138        limit: Option<u32>,
1139        reuse_shared_snapshot: bool,
1140        cx: &mut Context<Self>,
1141    ) -> Task<Result<String>> {
1142        let project = self.project.clone();
1143        let action_log = self.action_log.clone();
1144        cx.spawn(async move |this, cx| {
1145            let load = project.update(cx, |project, cx| {
1146                let path = project
1147                    .project_path_for_absolute_path(&path, cx)
1148                    .context("invalid path")?;
1149                anyhow::Ok(project.open_buffer(path, cx))
1150            });
1151            let buffer = load??.await?;
1152
1153            let snapshot = if reuse_shared_snapshot {
1154                this.read_with(cx, |this, _| {
1155                    this.shared_buffers.get(&buffer.clone()).cloned()
1156                })
1157                .log_err()
1158                .flatten()
1159            } else {
1160                None
1161            };
1162
1163            let snapshot = if let Some(snapshot) = snapshot {
1164                snapshot
1165            } else {
1166                action_log.update(cx, |action_log, cx| {
1167                    action_log.buffer_read(buffer.clone(), cx);
1168                })?;
1169                project.update(cx, |project, cx| {
1170                    let position = buffer
1171                        .read(cx)
1172                        .snapshot()
1173                        .anchor_before(Point::new(line.unwrap_or_default(), 0));
1174                    project.set_agent_location(
1175                        Some(AgentLocation {
1176                            buffer: buffer.downgrade(),
1177                            position,
1178                        }),
1179                        cx,
1180                    );
1181                })?;
1182
1183                buffer.update(cx, |buffer, _| buffer.snapshot())?
1184            };
1185
1186            this.update(cx, |this, _| {
1187                let text = snapshot.text();
1188                this.shared_buffers.insert(buffer.clone(), snapshot);
1189                if line.is_none() && limit.is_none() {
1190                    return Ok(text);
1191                }
1192                let limit = limit.unwrap_or(u32::MAX) as usize;
1193                let Some(line) = line else {
1194                    return Ok(text.lines().take(limit).collect::<String>());
1195                };
1196
1197                let count = text.lines().count();
1198                if count < line as usize {
1199                    anyhow::bail!("There are only {} lines", count);
1200                }
1201                Ok(text
1202                    .lines()
1203                    .skip(line as usize + 1)
1204                    .take(limit)
1205                    .collect::<String>())
1206            })?
1207        })
1208    }
1209
    /// Replaces the contents of the file at the absolute `path` with `content` and saves it.
    ///
    /// The change is computed as a text diff against the snapshot previously shared with
    /// the agent for this buffer (see `read_text_file`). If there is none, the buffer's
    /// current snapshot is used. The diff ranges are anchored in that base snapshot, so
    /// edits the user made elsewhere in the buffer after the agent read it are kept.
    ///
    /// The agent location is moved to the end of the last edit, or the start of the
    /// buffer when the diff is empty. The action log is told about a read and then the
    /// edit.
    ///
    /// # Errors
    /// Fails if `path` cannot be resolved to a project path, the buffer cannot be opened
    /// or saved, or the thread or app is gone.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Prefer the snapshot the agent last read, so the diff reflects only the
            // agent's intended change relative to what it saw.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Diff on the background executor. The anchors are created against the base
            // snapshot so they still resolve after concurrent edits to the buffer.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;
            cx.update(|cx| {
                project.update(cx, |project, cx| {
                    project.set_agent_location(
                        Some(AgentLocation {
                            buffer: buffer.downgrade(),
                            position: edits
                                .last()
                                .map(|(range, _)| range.end)
                                .unwrap_or(Anchor::MIN),
                        }),
                        cx,
                    );
                });

                // Record a read before editing. Presumably the action log needs a
                // baseline for the buffer before it can attribute the edit; TODO confirm.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });
                buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
            })?;
            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1277
1278    pub fn to_markdown(&self, cx: &App) -> String {
1279        self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1280    }
1281
    /// Emits `AcpThreadEvent::ServerExited` carrying the given exit `status`.
    pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
        cx.emit(AcpThreadEvent::ServerExited(status));
    }
1285}
1286
1287fn markdown_for_raw_output(
1288    raw_output: &serde_json::Value,
1289    language_registry: &Arc<LanguageRegistry>,
1290    cx: &mut App,
1291) -> Option<Entity<Markdown>> {
1292    match raw_output {
1293        serde_json::Value::Null => None,
1294        serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1295            Markdown::new(
1296                value.to_string().into(),
1297                Some(language_registry.clone()),
1298                None,
1299                cx,
1300            )
1301        })),
1302        serde_json::Value::Number(value) => Some(cx.new(|cx| {
1303            Markdown::new(
1304                value.to_string().into(),
1305                Some(language_registry.clone()),
1306                None,
1307                cx,
1308            )
1309        })),
1310        serde_json::Value::String(value) => Some(cx.new(|cx| {
1311            Markdown::new(
1312                value.clone().into(),
1313                Some(language_registry.clone()),
1314                None,
1315                cx,
1316            )
1317        })),
1318        value => Some(cx.new(|cx| {
1319            Markdown::new(
1320                format!("```json\n{}\n```", value).into(),
1321                Some(language_registry.clone()),
1322                None,
1323                cx,
1324            )
1325        })),
1326    }
1327}
1328
1329#[cfg(test)]
1330mod tests {
1331    use super::*;
1332    use anyhow::anyhow;
1333    use futures::{channel::mpsc, future::LocalBoxFuture, select};
1334    use gpui::{AsyncApp, TestAppContext, WeakEntity};
1335    use indoc::indoc;
1336    use project::FakeFs;
1337    use rand::Rng as _;
1338    use serde_json::json;
1339    use settings::SettingsStore;
1340    use smol::stream::StreamExt as _;
1341    use std::{cell::RefCell, path::Path, rc::Rc, time::Duration};
1342
1343    use util::path;
1344
1345    fn init_test(cx: &mut TestAppContext) {
1346        env_logger::try_init().ok();
1347        cx.update(|cx| {
1348            let settings_store = SettingsStore::test(cx);
1349            cx.set_global(settings_store);
1350            Project::init_settings(cx);
1351            language::init(cx);
1352        });
1353    }
1354
1355    #[gpui::test]
1356    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1357        init_test(cx);
1358
1359        let fs = FakeFs::new(cx.executor());
1360        let project = Project::test(fs, [], cx).await;
1361        let connection = Rc::new(FakeAgentConnection::new());
1362        let thread = cx
1363            .spawn(async move |mut cx| {
1364                connection
1365                    .new_thread(project, Path::new(path!("/test")), &mut cx)
1366                    .await
1367            })
1368            .await
1369            .unwrap();
1370
1371        // Test creating a new user message
1372        thread.update(cx, |thread, cx| {
1373            thread.push_user_content_block(
1374                acp::ContentBlock::Text(acp::TextContent {
1375                    annotations: None,
1376                    text: "Hello, ".to_string(),
1377                }),
1378                cx,
1379            );
1380        });
1381
1382        thread.update(cx, |thread, cx| {
1383            assert_eq!(thread.entries.len(), 1);
1384            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1385                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1386            } else {
1387                panic!("Expected UserMessage");
1388            }
1389        });
1390
1391        // Test appending to existing user message
1392        thread.update(cx, |thread, cx| {
1393            thread.push_user_content_block(
1394                acp::ContentBlock::Text(acp::TextContent {
1395                    annotations: None,
1396                    text: "world!".to_string(),
1397                }),
1398                cx,
1399            );
1400        });
1401
1402        thread.update(cx, |thread, cx| {
1403            assert_eq!(thread.entries.len(), 1);
1404            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1405                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1406            } else {
1407                panic!("Expected UserMessage");
1408            }
1409        });
1410
1411        // Test creating new user message after assistant message
1412        thread.update(cx, |thread, cx| {
1413            thread.push_assistant_content_block(
1414                acp::ContentBlock::Text(acp::TextContent {
1415                    annotations: None,
1416                    text: "Assistant response".to_string(),
1417                }),
1418                false,
1419                cx,
1420            );
1421        });
1422
1423        thread.update(cx, |thread, cx| {
1424            thread.push_user_content_block(
1425                acp::ContentBlock::Text(acp::TextContent {
1426                    annotations: None,
1427                    text: "New user message".to_string(),
1428                }),
1429                cx,
1430            );
1431        });
1432
1433        thread.update(cx, |thread, cx| {
1434            assert_eq!(thread.entries.len(), 3);
1435            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1436                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1437            } else {
1438                panic!("Expected UserMessage at index 2");
1439            }
1440        });
1441    }
1442
1443    #[gpui::test]
1444    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
1445        init_test(cx);
1446
1447        let fs = FakeFs::new(cx.executor());
1448        let project = Project::test(fs, [], cx).await;
1449        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1450            |_, thread, mut cx| {
1451                async move {
1452                    thread.update(&mut cx, |thread, cx| {
1453                        thread
1454                            .handle_session_update(
1455                                acp::SessionUpdate::AgentThoughtChunk {
1456                                    content: "Thinking ".into(),
1457                                },
1458                                cx,
1459                            )
1460                            .unwrap();
1461                        thread
1462                            .handle_session_update(
1463                                acp::SessionUpdate::AgentThoughtChunk {
1464                                    content: "hard!".into(),
1465                                },
1466                                cx,
1467                            )
1468                            .unwrap();
1469                    })?;
1470                    Ok(acp::PromptResponse {
1471                        stop_reason: acp::StopReason::EndTurn,
1472                    })
1473                }
1474                .boxed_local()
1475            },
1476        ));
1477
1478        let thread = cx
1479            .spawn(async move |mut cx| {
1480                connection
1481                    .new_thread(project, Path::new(path!("/test")), &mut cx)
1482                    .await
1483            })
1484            .await
1485            .unwrap();
1486
1487        thread
1488            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
1489            .await
1490            .unwrap();
1491
1492        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
1493        assert_eq!(
1494            output,
1495            indoc! {r#"
1496            ## User
1497
1498            Hello from Zed!
1499
1500            ## Assistant
1501
1502            <thinking>
1503            Thinking hard!
1504            </thinking>
1505
1506            "#}
1507        );
1508    }
1509
1510    #[gpui::test]
1511    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
1512        init_test(cx);
1513
1514        let fs = FakeFs::new(cx.executor());
1515        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
1516            .await;
1517        let project = Project::test(fs.clone(), [], cx).await;
1518        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
1519        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
1520        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1521            move |_, thread, mut cx| {
1522                let read_file_tx = read_file_tx.clone();
1523                async move {
1524                    let content = thread
1525                        .update(&mut cx, |thread, cx| {
1526                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
1527                        })
1528                        .unwrap()
1529                        .await
1530                        .unwrap();
1531                    assert_eq!(content, "one\ntwo\nthree\n");
1532                    read_file_tx.take().unwrap().send(()).unwrap();
1533                    thread
1534                        .update(&mut cx, |thread, cx| {
1535                            thread.write_text_file(
1536                                path!("/tmp/foo").into(),
1537                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
1538                                cx,
1539                            )
1540                        })
1541                        .unwrap()
1542                        .await
1543                        .unwrap();
1544                    Ok(acp::PromptResponse {
1545                        stop_reason: acp::StopReason::EndTurn,
1546                    })
1547                }
1548                .boxed_local()
1549            },
1550        ));
1551
1552        let (worktree, pathbuf) = project
1553            .update(cx, |project, cx| {
1554                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
1555            })
1556            .await
1557            .unwrap();
1558        let buffer = project
1559            .update(cx, |project, cx| {
1560                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
1561            })
1562            .await
1563            .unwrap();
1564
1565        let thread = cx
1566            .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
1567            .await
1568            .unwrap();
1569
1570        let request = thread.update(cx, |thread, cx| {
1571            thread.send_raw("Extend the count in /tmp/foo", cx)
1572        });
1573        read_file_rx.await.ok();
1574        buffer.update(cx, |buffer, cx| {
1575            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
1576        });
1577        cx.run_until_parked();
1578        assert_eq!(
1579            buffer.read_with(cx, |buffer, _| buffer.text()),
1580            "zero\none\ntwo\nthree\nfour\nfive\n"
1581        );
1582        assert_eq!(
1583            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
1584            "zero\none\ntwo\nthree\nfour\nfive\n"
1585        );
1586        request.await.unwrap();
1587    }
1588
1589    #[gpui::test]
1590    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
1591        init_test(cx);
1592
1593        let fs = FakeFs::new(cx.executor());
1594        let project = Project::test(fs, [], cx).await;
1595        let id = acp::ToolCallId("test".into());
1596
1597        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
1598            let id = id.clone();
1599            move |_, thread, mut cx| {
1600                let id = id.clone();
1601                async move {
1602                    thread
1603                        .update(&mut cx, |thread, cx| {
1604                            thread.handle_session_update(
1605                                acp::SessionUpdate::ToolCall(acp::ToolCall {
1606                                    id: id.clone(),
1607                                    title: "Label".into(),
1608                                    kind: acp::ToolKind::Fetch,
1609                                    status: acp::ToolCallStatus::InProgress,
1610                                    content: vec![],
1611                                    locations: vec![],
1612                                    raw_input: None,
1613                                    raw_output: None,
1614                                }),
1615                                cx,
1616                            )
1617                        })
1618                        .unwrap()
1619                        .unwrap();
1620                    Ok(acp::PromptResponse {
1621                        stop_reason: acp::StopReason::EndTurn,
1622                    })
1623                }
1624                .boxed_local()
1625            }
1626        }));
1627
1628        let thread = cx
1629            .spawn(async move |mut cx| {
1630                connection
1631                    .new_thread(project, Path::new(path!("/test")), &mut cx)
1632                    .await
1633            })
1634            .await
1635            .unwrap();
1636
1637        let request = thread.update(cx, |thread, cx| {
1638            thread.send_raw("Fetch https://example.com", cx)
1639        });
1640
1641        run_until_first_tool_call(&thread, cx).await;
1642
1643        thread.read_with(cx, |thread, _| {
1644            assert!(matches!(
1645                thread.entries[1],
1646                AgentThreadEntry::ToolCall(ToolCall {
1647                    status: ToolCallStatus::Allowed {
1648                        status: acp::ToolCallStatus::InProgress,
1649                        ..
1650                    },
1651                    ..
1652                })
1653            ));
1654        });
1655
1656        thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1657
1658        thread.read_with(cx, |thread, _| {
1659            assert!(matches!(
1660                &thread.entries[1],
1661                AgentThreadEntry::ToolCall(ToolCall {
1662                    status: ToolCallStatus::Canceled,
1663                    ..
1664                })
1665            ));
1666        });
1667
1668        thread
1669            .update(cx, |thread, cx| {
1670                thread.handle_session_update(
1671                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
1672                        id,
1673                        fields: acp::ToolCallUpdateFields {
1674                            status: Some(acp::ToolCallStatus::Completed),
1675                            ..Default::default()
1676                        },
1677                    }),
1678                    cx,
1679                )
1680            })
1681            .unwrap();
1682
1683        request.await.unwrap();
1684
1685        thread.read_with(cx, |thread, _| {
1686            assert!(matches!(
1687                thread.entries[1],
1688                AgentThreadEntry::ToolCall(ToolCall {
1689                    status: ToolCallStatus::Allowed {
1690                        status: acp::ToolCallStatus::Completed,
1691                        ..
1692                    },
1693                    ..
1694                })
1695            ));
1696        });
1697    }
1698
1699    #[gpui::test]
1700    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
1701        init_test(cx);
1702        let fs = FakeFs::new(cx.background_executor.clone());
1703        fs.insert_tree(path!("/test"), json!({})).await;
1704        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
1705
1706        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
1707            move |_, thread, mut cx| {
1708                async move {
1709                    thread
1710                        .update(&mut cx, |thread, cx| {
1711                            thread.handle_session_update(
1712                                acp::SessionUpdate::ToolCall(acp::ToolCall {
1713                                    id: acp::ToolCallId("test".into()),
1714                                    title: "Label".into(),
1715                                    kind: acp::ToolKind::Edit,
1716                                    status: acp::ToolCallStatus::Completed,
1717                                    content: vec![acp::ToolCallContent::Diff {
1718                                        diff: acp::Diff {
1719                                            path: "/test/test.txt".into(),
1720                                            old_text: None,
1721                                            new_text: "foo".into(),
1722                                        },
1723                                    }],
1724                                    locations: vec![],
1725                                    raw_input: None,
1726                                    raw_output: None,
1727                                }),
1728                                cx,
1729                            )
1730                        })
1731                        .unwrap()
1732                        .unwrap();
1733                    Ok(acp::PromptResponse {
1734                        stop_reason: acp::StopReason::EndTurn,
1735                    })
1736                }
1737                .boxed_local()
1738            }
1739        }));
1740
1741        let thread = connection
1742            .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
1743            .await
1744            .unwrap();
1745        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
1746            .await
1747            .unwrap();
1748
1749        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
1750    }
1751
1752    async fn run_until_first_tool_call(
1753        thread: &Entity<AcpThread>,
1754        cx: &mut TestAppContext,
1755    ) -> usize {
1756        let (mut tx, mut rx) = mpsc::channel::<usize>(1);
1757
1758        let subscription = cx.update(|cx| {
1759            cx.subscribe(thread, move |thread, _, cx| {
1760                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
1761                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
1762                        return tx.try_send(ix).unwrap();
1763                    }
1764                }
1765            })
1766        });
1767
1768        select! {
1769            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
1770                panic!("Timeout waiting for tool call")
1771            }
1772            ix = rx.next().fuse() => {
1773                drop(subscription);
1774                ix.unwrap()
1775            }
1776        }
1777    }
1778
1779    #[derive(Clone, Default)]
1780    struct FakeAgentConnection {
1781        auth_methods: Vec<acp::AuthMethod>,
1782        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
1783        on_user_message: Option<
1784            Rc<
1785                dyn Fn(
1786                        acp::PromptRequest,
1787                        WeakEntity<AcpThread>,
1788                        AsyncApp,
1789                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1790                    + 'static,
1791            >,
1792        >,
1793    }
1794
1795    impl FakeAgentConnection {
1796        fn new() -> Self {
1797            Self {
1798                auth_methods: Vec::new(),
1799                on_user_message: None,
1800                sessions: Arc::default(),
1801            }
1802        }
1803
1804        #[expect(unused)]
1805        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
1806            self.auth_methods = auth_methods;
1807            self
1808        }
1809
1810        fn on_user_message(
1811            mut self,
1812            handler: impl Fn(
1813                acp::PromptRequest,
1814                WeakEntity<AcpThread>,
1815                AsyncApp,
1816            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1817            + 'static,
1818        ) -> Self {
1819            self.on_user_message.replace(Rc::new(handler));
1820            self
1821        }
1822    }
1823
1824    impl AgentConnection for FakeAgentConnection {
1825        fn auth_methods(&self) -> &[acp::AuthMethod] {
1826            &self.auth_methods
1827        }
1828
1829        fn new_thread(
1830            self: Rc<Self>,
1831            project: Entity<Project>,
1832            _cwd: &Path,
1833            cx: &mut gpui::AsyncApp,
1834        ) -> Task<gpui::Result<Entity<AcpThread>>> {
1835            let session_id = acp::SessionId(
1836                rand::thread_rng()
1837                    .sample_iter(&rand::distributions::Alphanumeric)
1838                    .take(7)
1839                    .map(char::from)
1840                    .collect::<String>()
1841                    .into(),
1842            );
1843            let thread = cx
1844                .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
1845                .unwrap();
1846            self.sessions.lock().insert(session_id, thread.downgrade());
1847            Task::ready(Ok(thread))
1848        }
1849
1850        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
1851            if self.auth_methods().iter().any(|m| m.id == method) {
1852                Task::ready(Ok(()))
1853            } else {
1854                Task::ready(Err(anyhow!("Invalid Auth Method")))
1855            }
1856        }
1857
1858        fn prompt(
1859            &self,
1860            params: acp::PromptRequest,
1861            cx: &mut App,
1862        ) -> Task<gpui::Result<acp::PromptResponse>> {
1863            let sessions = self.sessions.lock();
1864            let thread = sessions.get(&params.session_id).unwrap();
1865            if let Some(handler) = &self.on_user_message {
1866                let handler = handler.clone();
1867                let thread = thread.clone();
1868                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
1869            } else {
1870                Task::ready(Ok(acp::PromptResponse {
1871                    stop_reason: acp::StopReason::EndTurn,
1872                }))
1873            }
1874        }
1875
1876        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1877            let sessions = self.sessions.lock();
1878            let thread = sessions.get(&session_id).unwrap().clone();
1879
1880            cx.spawn(async move |cx| {
1881                thread
1882                    .update(cx, |thread, cx| thread.cancel(cx))
1883                    .unwrap()
1884                    .await
1885            })
1886            .detach();
1887        }
1888    }
1889}