acp_thread.rs

   1mod connection;
   2mod diff;
   3mod mention;
   4mod terminal;
   5
   6pub use connection::*;
   7pub use diff::*;
   8pub use mention::*;
   9pub use terminal::*;
  10
  11use action_log::ActionLog;
  12use agent_client_protocol::{self as acp};
  13use anyhow::{Context as _, Result};
  14use editor::Bias;
  15use futures::{FutureExt, channel::oneshot, future::BoxFuture};
  16use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
  17use itertools::Itertools;
  18use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, text_diff};
  19use markdown::Markdown;
  20use project::{AgentLocation, Project};
  21use std::collections::HashMap;
  22use std::error::Error;
  23use std::fmt::Formatter;
  24use std::process::ExitStatus;
  25use std::rc::Rc;
  26use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
  27use ui::App;
  28use util::ResultExt;
  29
/// A message authored by the user, stored as a single (possibly combined) content block.
#[derive(Debug)]
pub struct UserMessage {
    /// The message content; incoming user chunks are appended into this block.
    pub content: ContentBlock,
}
  34
  35impl UserMessage {
  36    pub fn from_acp(
  37        message: impl IntoIterator<Item = acp::ContentBlock>,
  38        language_registry: Arc<LanguageRegistry>,
  39        cx: &mut App,
  40    ) -> Self {
  41        let mut content = ContentBlock::Empty;
  42        for chunk in message {
  43            content.append(chunk, &language_registry, cx)
  44        }
  45        Self { content: content }
  46    }
  47
  48    fn to_markdown(&self, cx: &App) -> String {
  49        format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
  50    }
  51}
  52
/// A single assistant turn, made of an ordered list of message and thought chunks.
#[derive(Debug, PartialEq)]
pub struct AssistantMessage {
    /// Chunks in arrival order; consecutive chunks of the same kind are merged.
    pub chunks: Vec<AssistantMessageChunk>,
}
  57
  58impl AssistantMessage {
  59    pub fn to_markdown(&self, cx: &App) -> String {
  60        format!(
  61            "## Assistant\n\n{}\n\n",
  62            self.chunks
  63                .iter()
  64                .map(|chunk| chunk.to_markdown(cx))
  65                .join("\n\n")
  66        )
  67    }
  68}
  69
/// One piece of an assistant turn: either visible message text or a "thought".
#[derive(Debug, PartialEq)]
pub enum AssistantMessageChunk {
    /// Regular assistant output.
    Message { block: ContentBlock },
    /// Reasoning output; wrapped in `<thinking>` tags when exported to markdown.
    Thought { block: ContentBlock },
}
  75
  76impl AssistantMessageChunk {
  77    pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
  78        Self::Message {
  79            block: ContentBlock::new(chunk.into(), language_registry, cx),
  80        }
  81    }
  82
  83    fn to_markdown(&self, cx: &App) -> String {
  84        match self {
  85            Self::Message { block } => block.to_markdown(cx).to_string(),
  86            Self::Thought { block } => {
  87                format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
  88            }
  89        }
  90    }
  91}
  92
/// An entry in an [`AcpThread`]'s transcript.
#[derive(Debug)]
pub enum AgentThreadEntry {
    /// Content sent by the user.
    UserMessage(UserMessage),
    /// Message and thought chunks produced by the agent.
    AssistantMessage(AssistantMessage),
    /// A tool invocation reported by the agent.
    ToolCall(ToolCall),
}
  99
 100impl AgentThreadEntry {
 101    fn to_markdown(&self, cx: &App) -> String {
 102        match self {
 103            Self::UserMessage(message) => message.to_markdown(cx),
 104            Self::AssistantMessage(message) => message.to_markdown(cx),
 105            Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
 106        }
 107    }
 108
 109    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 110        if let AgentThreadEntry::ToolCall(call) = self {
 111            itertools::Either::Left(call.diffs())
 112        } else {
 113            itertools::Either::Right(std::iter::empty())
 114        }
 115    }
 116
 117    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 118        if let AgentThreadEntry::ToolCall(call) = self {
 119            itertools::Either::Left(call.terminals())
 120        } else {
 121            itertools::Either::Right(std::iter::empty())
 122        }
 123    }
 124
 125    pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
 126        if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
 127            Some(locations)
 128        } else {
 129            None
 130        }
 131    }
 132}
 133
/// A tool invocation tracked by the thread.
#[derive(Debug)]
pub struct ToolCall {
    /// ACP identifier used to match later updates to this call.
    pub id: acp::ToolCallId,
    /// The tool call's title, rendered as markdown.
    pub label: Entity<Markdown>,
    /// The ACP tool kind.
    pub kind: acp::ToolKind,
    /// Displayable output: content blocks, diffs, or terminals.
    pub content: Vec<ToolCallContent>,
    /// Combined authorization / execution state.
    pub status: ToolCallStatus,
    /// File locations reported by the agent for this call.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Raw JSON input, if the agent provided one.
    pub raw_input: Option<serde_json::Value>,
    /// Raw JSON output, if the agent provided one.
    pub raw_output: Option<serde_json::Value>,
}
 145
 146impl ToolCall {
 147    fn from_acp(
 148        tool_call: acp::ToolCall,
 149        status: ToolCallStatus,
 150        language_registry: Arc<LanguageRegistry>,
 151        cx: &mut App,
 152    ) -> Self {
 153        Self {
 154            id: tool_call.id,
 155            label: cx.new(|cx| {
 156                Markdown::new(
 157                    tool_call.title.into(),
 158                    Some(language_registry.clone()),
 159                    None,
 160                    cx,
 161                )
 162            }),
 163            kind: tool_call.kind,
 164            content: tool_call
 165                .content
 166                .into_iter()
 167                .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
 168                .collect(),
 169            locations: tool_call.locations,
 170            status,
 171            raw_input: tool_call.raw_input,
 172            raw_output: tool_call.raw_output,
 173        }
 174    }
 175
 176    fn update_fields(
 177        &mut self,
 178        fields: acp::ToolCallUpdateFields,
 179        language_registry: Arc<LanguageRegistry>,
 180        cx: &mut App,
 181    ) {
 182        let acp::ToolCallUpdateFields {
 183            kind,
 184            status,
 185            title,
 186            content,
 187            locations,
 188            raw_input,
 189            raw_output,
 190        } = fields;
 191
 192        if let Some(kind) = kind {
 193            self.kind = kind;
 194        }
 195
 196        if let Some(status) = status {
 197            self.status = ToolCallStatus::Allowed { status };
 198        }
 199
 200        if let Some(title) = title {
 201            self.label.update(cx, |label, cx| {
 202                label.replace(title, cx);
 203            });
 204        }
 205
 206        if let Some(content) = content {
 207            self.content = content
 208                .into_iter()
 209                .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
 210                .collect();
 211        }
 212
 213        if let Some(locations) = locations {
 214            self.locations = locations;
 215        }
 216
 217        if let Some(raw_input) = raw_input {
 218            self.raw_input = Some(raw_input);
 219        }
 220
 221        if let Some(raw_output) = raw_output {
 222            if self.content.is_empty() {
 223                if let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
 224                {
 225                    self.content
 226                        .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
 227                            markdown,
 228                        }));
 229                }
 230            }
 231            self.raw_output = Some(raw_output);
 232        }
 233    }
 234
 235    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 236        self.content.iter().filter_map(|content| match content {
 237            ToolCallContent::Diff(diff) => Some(diff),
 238            ToolCallContent::ContentBlock(_) => None,
 239            ToolCallContent::Terminal(_) => None,
 240        })
 241    }
 242
 243    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 244        self.content.iter().filter_map(|content| match content {
 245            ToolCallContent::Terminal(terminal) => Some(terminal),
 246            ToolCallContent::ContentBlock(_) => None,
 247            ToolCallContent::Diff(_) => None,
 248        })
 249    }
 250
 251    fn to_markdown(&self, cx: &App) -> String {
 252        let mut markdown = format!(
 253            "**Tool Call: {}**\nStatus: {}\n\n",
 254            self.label.read(cx).source(),
 255            self.status
 256        );
 257        for content in &self.content {
 258            markdown.push_str(content.to_markdown(cx).as_str());
 259            markdown.push_str("\n\n");
 260        }
 261        markdown
 262    }
 263}
 264
/// Lifecycle state of a [`ToolCall`], combining local authorization with the ACP status.
#[derive(Debug)]
pub enum ToolCallStatus {
    /// Awaiting user permission; `respond_tx` delivers the chosen option id back to
    /// the requester.
    WaitingForConfirmation {
        options: Vec<acp::PermissionOption>,
        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
    },
    /// Permitted, or never required permission; `status` is the ACP execution status.
    Allowed {
        status: acp::ToolCallStatus,
    },
    /// The user rejected the tool call.
    Rejected,
    /// The tool call was canceled.
    Canceled,
}
 277
 278impl Display for ToolCallStatus {
 279    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 280        write!(
 281            f,
 282            "{}",
 283            match self {
 284                ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
 285                ToolCallStatus::Allowed { status } => match status {
 286                    acp::ToolCallStatus::Pending => "Pending",
 287                    acp::ToolCallStatus::InProgress => "In Progress",
 288                    acp::ToolCallStatus::Completed => "Completed",
 289                    acp::ToolCallStatus::Failed => "Failed",
 290                },
 291                ToolCallStatus::Rejected => "Rejected",
 292                ToolCallStatus::Canceled => "Canceled",
 293            }
 294        )
 295    }
 296}
 297
/// Displayable content built from one or more ACP content blocks.
#[derive(Debug, PartialEq, Clone)]
pub enum ContentBlock {
    /// Nothing has been appended yet.
    Empty,
    /// Text content rendered as markdown.
    Markdown { markdown: Entity<Markdown> },
    /// A lone resource link, kept structured until more content is appended.
    ResourceLink { resource_link: acp::ResourceLink },
}
 304
 305impl ContentBlock {
 306    pub fn new(
 307        block: acp::ContentBlock,
 308        language_registry: &Arc<LanguageRegistry>,
 309        cx: &mut App,
 310    ) -> Self {
 311        let mut this = Self::Empty;
 312        this.append(block, language_registry, cx);
 313        this
 314    }
 315
 316    pub fn new_combined(
 317        blocks: impl IntoIterator<Item = acp::ContentBlock>,
 318        language_registry: Arc<LanguageRegistry>,
 319        cx: &mut App,
 320    ) -> Self {
 321        let mut this = Self::Empty;
 322        for block in blocks {
 323            this.append(block, &language_registry, cx);
 324        }
 325        this
 326    }
 327
 328    pub fn append(
 329        &mut self,
 330        block: acp::ContentBlock,
 331        language_registry: &Arc<LanguageRegistry>,
 332        cx: &mut App,
 333    ) {
 334        if matches!(self, ContentBlock::Empty) {
 335            if let acp::ContentBlock::ResourceLink(resource_link) = block {
 336                *self = ContentBlock::ResourceLink { resource_link };
 337                return;
 338            }
 339        }
 340
 341        let new_content = self.extract_content_from_block(block);
 342
 343        match self {
 344            ContentBlock::Empty => {
 345                *self = Self::create_markdown_block(new_content, language_registry, cx);
 346            }
 347            ContentBlock::Markdown { markdown } => {
 348                markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
 349            }
 350            ContentBlock::ResourceLink { resource_link } => {
 351                let existing_content = Self::resource_link_to_content(&resource_link.uri);
 352                let combined = format!("{}\n{}", existing_content, new_content);
 353
 354                *self = Self::create_markdown_block(combined, language_registry, cx);
 355            }
 356        }
 357    }
 358
 359    fn resource_link_to_content(uri: &str) -> String {
 360        if let Some(uri) = MentionUri::parse(&uri).log_err() {
 361            uri.to_link()
 362        } else {
 363            uri.to_string().clone()
 364        }
 365    }
 366
 367    fn create_markdown_block(
 368        content: String,
 369        language_registry: &Arc<LanguageRegistry>,
 370        cx: &mut App,
 371    ) -> ContentBlock {
 372        ContentBlock::Markdown {
 373            markdown: cx
 374                .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
 375        }
 376    }
 377
 378    fn extract_content_from_block(&self, block: acp::ContentBlock) -> String {
 379        match block {
 380            acp::ContentBlock::Text(text_content) => text_content.text.clone(),
 381            acp::ContentBlock::ResourceLink(resource_link) => {
 382                Self::resource_link_to_content(&resource_link.uri)
 383            }
 384            acp::ContentBlock::Resource(acp::EmbeddedResource {
 385                resource:
 386                    acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
 387                        uri,
 388                        ..
 389                    }),
 390                ..
 391            }) => Self::resource_link_to_content(&uri),
 392            acp::ContentBlock::Image(_)
 393            | acp::ContentBlock::Audio(_)
 394            | acp::ContentBlock::Resource(_) => String::new(),
 395        }
 396    }
 397
 398    fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
 399        match self {
 400            ContentBlock::Empty => "",
 401            ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
 402            ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
 403        }
 404    }
 405
 406    pub fn markdown(&self) -> Option<&Entity<Markdown>> {
 407        match self {
 408            ContentBlock::Empty => None,
 409            ContentBlock::Markdown { markdown } => Some(markdown),
 410            ContentBlock::ResourceLink { .. } => None,
 411        }
 412    }
 413
 414    pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
 415        match self {
 416            ContentBlock::ResourceLink { resource_link } => Some(resource_link),
 417            _ => None,
 418        }
 419    }
 420}
 421
/// One item of a tool call's displayable output.
#[derive(Debug)]
pub enum ToolCallContent {
    /// Text or resource content.
    ContentBlock(ContentBlock),
    /// A file diff entity.
    Diff(Entity<Diff>),
    /// A terminal entity.
    Terminal(Entity<Terminal>),
}
 428
 429impl ToolCallContent {
 430    pub fn from_acp(
 431        content: acp::ToolCallContent,
 432        language_registry: Arc<LanguageRegistry>,
 433        cx: &mut App,
 434    ) -> Self {
 435        match content {
 436            acp::ToolCallContent::Content { content } => {
 437                Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
 438            }
 439            acp::ToolCallContent::Diff { diff } => {
 440                Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx)))
 441            }
 442        }
 443    }
 444
 445    pub fn to_markdown(&self, cx: &App) -> String {
 446        match self {
 447            Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
 448            Self::Diff(diff) => diff.read(cx).to_markdown(cx),
 449            Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
 450        }
 451    }
 452}
 453
/// An update targeting an existing tool call, identified by its tool call id.
#[derive(Debug, PartialEq)]
pub enum ToolCallUpdate {
    /// Field-level ACP update (status, title, content, ...).
    UpdateFields(acp::ToolCallUpdate),
    /// Replaces the tool call's content with a single diff.
    UpdateDiff(ToolCallUpdateDiff),
    /// Replaces the tool call's content with a single terminal.
    UpdateTerminal(ToolCallUpdateTerminal),
}
 460
 461impl ToolCallUpdate {
 462    fn id(&self) -> &acp::ToolCallId {
 463        match self {
 464            Self::UpdateFields(update) => &update.id,
 465            Self::UpdateDiff(diff) => &diff.id,
 466            Self::UpdateTerminal(terminal) => &terminal.id,
 467        }
 468    }
 469}
 470
 471impl From<acp::ToolCallUpdate> for ToolCallUpdate {
 472    fn from(update: acp::ToolCallUpdate) -> Self {
 473        Self::UpdateFields(update)
 474    }
 475}
 476
 477impl From<ToolCallUpdateDiff> for ToolCallUpdate {
 478    fn from(diff: ToolCallUpdateDiff) -> Self {
 479        Self::UpdateDiff(diff)
 480    }
 481}
 482
/// Replaces the content of tool call `id` with `diff`.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateDiff {
    /// The tool call to update.
    pub id: acp::ToolCallId,
    /// The diff that becomes the tool call's only content.
    pub diff: Entity<Diff>,
}
 488
 489impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
 490    fn from(terminal: ToolCallUpdateTerminal) -> Self {
 491        Self::UpdateTerminal(terminal)
 492    }
 493}
 494
/// Replaces the content of tool call `id` with `terminal`.
#[derive(Debug, PartialEq)]
pub struct ToolCallUpdateTerminal {
    /// The tool call to update.
    pub id: acp::ToolCallId,
    /// The terminal that becomes the tool call's only content.
    pub terminal: Entity<Terminal>,
}
 500
/// The agent's plan: an ordered list of entries.
#[derive(Debug, Default)]
pub struct Plan {
    pub entries: Vec<PlanEntry>,
}
 505
/// Summary of a [`Plan`], as computed by [`Plan::stats`].
#[derive(Debug)]
pub struct PlanStats<'a> {
    /// The first entry whose status is `InProgress`, if any.
    pub in_progress_entry: Option<&'a PlanEntry>,
    /// Number of `Pending` entries.
    pub pending: u32,
    /// Number of `Completed` entries.
    pub completed: u32,
}
 512
 513impl Plan {
 514    pub fn is_empty(&self) -> bool {
 515        self.entries.is_empty()
 516    }
 517
 518    pub fn stats(&self) -> PlanStats<'_> {
 519        let mut stats = PlanStats {
 520            in_progress_entry: None,
 521            pending: 0,
 522            completed: 0,
 523        };
 524
 525        for entry in &self.entries {
 526            match &entry.status {
 527                acp::PlanEntryStatus::Pending => {
 528                    stats.pending += 1;
 529                }
 530                acp::PlanEntryStatus::InProgress => {
 531                    stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
 532                }
 533                acp::PlanEntryStatus::Completed => {
 534                    stats.completed += 1;
 535                }
 536            }
 537        }
 538
 539        stats
 540    }
 541}
 542
/// A single step of the agent's plan.
#[derive(Debug)]
pub struct PlanEntry {
    /// The step's description, rendered as markdown.
    pub content: Entity<Markdown>,
    /// ACP priority of the step.
    pub priority: acp::PlanEntryPriority,
    /// ACP status of the step (pending / in progress / completed).
    pub status: acp::PlanEntryStatus,
}
 549
 550impl PlanEntry {
 551    pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
 552        Self {
 553            content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
 554            priority: entry.priority,
 555            status: entry.status,
 556        }
 557    }
 558}
 559
/// A conversation with an ACP agent.
///
/// Holds the transcript entries, the agent's plan, and the connection and session
/// used to talk to the agent.
pub struct AcpThread {
    /// Display title of the thread.
    title: SharedString,
    /// Transcript entries; new entries are appended at the end.
    entries: Vec<AgentThreadEntry>,
    /// The agent's current plan.
    plan: Plan,
    /// Project the thread operates on.
    project: Entity<Project>,
    /// Action log created for `project` when the thread is constructed.
    action_log: Entity<ActionLog>,
    // NOTE(review): not used in the visible code; presumably buffer snapshots shared
    // with the agent — confirm.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// In-flight send task; `status()` reports `Idle` while this is `None`.
    send_task: Option<Task<()>>,
    /// Connection to the agent.
    connection: Rc<dyn AgentConnection>,
    /// ACP session this thread belongs to.
    session_id: acp::SessionId,
}
 571
/// Events emitted by an [`AcpThread`].
pub enum AcpThreadEvent {
    /// An entry was appended to the transcript.
    NewEntry,
    /// The entry at this index changed.
    EntryUpdated(usize),
    /// A tool call is waiting for the user's permission.
    ToolAuthorizationRequired,
    Stopped,
    Error,
    /// NOTE(review): presumably the agent server exited with this status — confirm.
    ServerExited(ExitStatus),
}
 580
/// Lets an `AcpThread` emit [`AcpThreadEvent`]s through `cx.emit`.
impl EventEmitter<AcpThreadEvent> for AcpThread {}
 582
/// Coarse state of a thread, as computed by `AcpThread::status`.
#[derive(PartialEq, Eq)]
pub enum ThreadStatus {
    /// No send task is running.
    Idle,
    /// A send task is running and the current turn has a tool call awaiting permission.
    WaitingForToolConfirmation,
    /// A send task is running and no tool call is awaiting permission.
    Generating,
}
 589
/// Error describing why loading failed; its `Display` output is the user-facing message.
#[derive(Debug, Clone)]
pub enum LoadError {
    /// Displays `error_message`.
    // NOTE(review): the upgrade fields presumably tell the user how to get a supported
    // version — confirm with callers.
    Unsupported {
        error_message: SharedString,
        upgrade_message: SharedString,
        upgrade_command: String,
    },
    /// Displayed as "Server exited with status {code}".
    Exited(i32),
    /// Any other failure, displayed verbatim.
    Other(SharedString),
}
 600
 601impl Display for LoadError {
 602    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 603        match self {
 604            LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
 605            LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
 606            LoadError::Other(msg) => write!(f, "{}", msg),
 607        }
 608    }
 609}
 610
/// `LoadError` is a standard error; its message comes from the `Display` impl.
impl Error for LoadError {}
 612
 613impl AcpThread {
 614    pub fn new(
 615        title: impl Into<SharedString>,
 616        connection: Rc<dyn AgentConnection>,
 617        project: Entity<Project>,
 618        session_id: acp::SessionId,
 619        cx: &mut Context<Self>,
 620    ) -> Self {
 621        let action_log = cx.new(|_| ActionLog::new(project.clone()));
 622
 623        Self {
 624            action_log,
 625            shared_buffers: Default::default(),
 626            entries: Default::default(),
 627            plan: Default::default(),
 628            title: title.into(),
 629            project,
 630            send_task: None,
 631            connection,
 632            session_id,
 633        }
 634    }
 635
 636    pub fn action_log(&self) -> &Entity<ActionLog> {
 637        &self.action_log
 638    }
 639
 640    pub fn project(&self) -> &Entity<Project> {
 641        &self.project
 642    }
 643
 644    pub fn title(&self) -> SharedString {
 645        self.title.clone()
 646    }
 647
 648    pub fn entries(&self) -> &[AgentThreadEntry] {
 649        &self.entries
 650    }
 651
 652    pub fn session_id(&self) -> &acp::SessionId {
 653        &self.session_id
 654    }
 655
 656    pub fn status(&self) -> ThreadStatus {
 657        if self.send_task.is_some() {
 658            if self.waiting_for_tool_confirmation() {
 659                ThreadStatus::WaitingForToolConfirmation
 660            } else {
 661                ThreadStatus::Generating
 662            }
 663        } else {
 664            ThreadStatus::Idle
 665        }
 666    }
 667
 668    pub fn has_pending_edit_tool_calls(&self) -> bool {
 669        for entry in self.entries.iter().rev() {
 670            match entry {
 671                AgentThreadEntry::UserMessage(_) => return false,
 672                AgentThreadEntry::ToolCall(
 673                    call @ ToolCall {
 674                        status:
 675                            ToolCallStatus::Allowed {
 676                                status:
 677                                    acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending,
 678                            },
 679                        ..
 680                    },
 681                ) if call.diffs().next().is_some() => {
 682                    return true;
 683                }
 684                AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
 685            }
 686        }
 687
 688        false
 689    }
 690
 691    pub fn used_tools_since_last_user_message(&self) -> bool {
 692        for entry in self.entries.iter().rev() {
 693            match entry {
 694                AgentThreadEntry::UserMessage(..) => return false,
 695                AgentThreadEntry::AssistantMessage(..) => continue,
 696                AgentThreadEntry::ToolCall(..) => return true,
 697            }
 698        }
 699
 700        false
 701    }
 702
 703    pub fn handle_session_update(
 704        &mut self,
 705        update: acp::SessionUpdate,
 706        cx: &mut Context<Self>,
 707    ) -> Result<()> {
 708        match update {
 709            acp::SessionUpdate::UserMessageChunk { content } => {
 710                self.push_user_content_block(content, cx);
 711            }
 712            acp::SessionUpdate::AgentMessageChunk { content } => {
 713                self.push_assistant_content_block(content, false, cx);
 714            }
 715            acp::SessionUpdate::AgentThoughtChunk { content } => {
 716                self.push_assistant_content_block(content, true, cx);
 717            }
 718            acp::SessionUpdate::ToolCall(tool_call) => {
 719                self.upsert_tool_call(tool_call, cx);
 720            }
 721            acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
 722                self.update_tool_call(tool_call_update, cx)?;
 723            }
 724            acp::SessionUpdate::Plan(plan) => {
 725                self.update_plan(plan, cx);
 726            }
 727        }
 728        Ok(())
 729    }
 730
 731    pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
 732        let language_registry = self.project.read(cx).languages().clone();
 733        let entries_len = self.entries.len();
 734
 735        if let Some(last_entry) = self.entries.last_mut()
 736            && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
 737        {
 738            content.append(chunk, &language_registry, cx);
 739            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
 740        } else {
 741            let content = ContentBlock::new(chunk, &language_registry, cx);
 742            self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
 743        }
 744    }
 745
 746    pub fn push_assistant_content_block(
 747        &mut self,
 748        chunk: acp::ContentBlock,
 749        is_thought: bool,
 750        cx: &mut Context<Self>,
 751    ) {
 752        let language_registry = self.project.read(cx).languages().clone();
 753        let entries_len = self.entries.len();
 754        if let Some(last_entry) = self.entries.last_mut()
 755            && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
 756        {
 757            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
 758            match (chunks.last_mut(), is_thought) {
 759                (Some(AssistantMessageChunk::Message { block }), false)
 760                | (Some(AssistantMessageChunk::Thought { block }), true) => {
 761                    block.append(chunk, &language_registry, cx)
 762                }
 763                _ => {
 764                    let block = ContentBlock::new(chunk, &language_registry, cx);
 765                    if is_thought {
 766                        chunks.push(AssistantMessageChunk::Thought { block })
 767                    } else {
 768                        chunks.push(AssistantMessageChunk::Message { block })
 769                    }
 770                }
 771            }
 772        } else {
 773            let block = ContentBlock::new(chunk, &language_registry, cx);
 774            let chunk = if is_thought {
 775                AssistantMessageChunk::Thought { block }
 776            } else {
 777                AssistantMessageChunk::Message { block }
 778            };
 779
 780            self.push_entry(
 781                AgentThreadEntry::AssistantMessage(AssistantMessage {
 782                    chunks: vec![chunk],
 783                }),
 784                cx,
 785            );
 786        }
 787    }
 788
 789    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
 790        self.entries.push(entry);
 791        cx.emit(AcpThreadEvent::NewEntry);
 792    }
 793
 794    pub fn update_tool_call(
 795        &mut self,
 796        update: impl Into<ToolCallUpdate>,
 797        cx: &mut Context<Self>,
 798    ) -> Result<()> {
 799        let update = update.into();
 800        let languages = self.project.read(cx).languages().clone();
 801
 802        let (ix, current_call) = self
 803            .tool_call_mut(update.id())
 804            .context("Tool call not found")?;
 805        match update {
 806            ToolCallUpdate::UpdateFields(update) => {
 807                current_call.update_fields(update.fields, languages, cx);
 808            }
 809            ToolCallUpdate::UpdateDiff(update) => {
 810                current_call.content.clear();
 811                current_call
 812                    .content
 813                    .push(ToolCallContent::Diff(update.diff));
 814            }
 815            ToolCallUpdate::UpdateTerminal(update) => {
 816                current_call.content.clear();
 817                current_call
 818                    .content
 819                    .push(ToolCallContent::Terminal(update.terminal));
 820            }
 821        }
 822
 823        cx.emit(AcpThreadEvent::EntryUpdated(ix));
 824
 825        Ok(())
 826    }
 827
 828    /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
 829    pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
 830        let status = ToolCallStatus::Allowed {
 831            status: tool_call.status,
 832        };
 833        self.upsert_tool_call_inner(tool_call, status, cx)
 834    }
 835
 836    pub fn upsert_tool_call_inner(
 837        &mut self,
 838        tool_call: acp::ToolCall,
 839        status: ToolCallStatus,
 840        cx: &mut Context<Self>,
 841    ) {
 842        let language_registry = self.project.read(cx).languages().clone();
 843        let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
 844
 845        let location = call.locations.last().cloned();
 846
 847        if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
 848            *current_call = call;
 849
 850            cx.emit(AcpThreadEvent::EntryUpdated(ix));
 851        } else {
 852            self.push_entry(AgentThreadEntry::ToolCall(call), cx);
 853        }
 854
 855        if let Some(location) = location {
 856            self.set_project_location(location, cx)
 857        }
 858    }
 859
 860    fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
 861        // The tool call we are looking for is typically the last one, or very close to the end.
 862        // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
 863        self.entries
 864            .iter_mut()
 865            .enumerate()
 866            .rev()
 867            .find_map(|(index, tool_call)| {
 868                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
 869                    && &tool_call.id == id
 870                {
 871                    Some((index, tool_call))
 872                } else {
 873                    None
 874                }
 875            })
 876    }
 877
    /// Points the project's agent location at `location`.
    ///
    /// The absolute path is resolved to a project path and its buffer is opened in the
    /// background. The agent location is then set to the requested line, clipped with
    /// `Bias::Left`, or to `Anchor::MIN` when no line is given. Paths that don't
    /// resolve to a project path are ignored. Errors from the spawned task go to
    /// `detach_and_log_err`.
    pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context<Self>) {
        self.project.update(cx, |project, cx| {
            // Locations outside the project are silently skipped.
            let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
                return;
            };
            let buffer = project.open_buffer(path, cx);
            cx.spawn(async move |project, cx| {
                let buffer = buffer.await?;

                project.update(cx, |project, cx| {
                    let position = if let Some(line) = location.line {
                        let snapshot = buffer.read(cx).snapshot();
                        // Clip so an out-of-range line still yields a valid anchor.
                        let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
                        snapshot.anchor_before(point)
                    } else {
                        Anchor::MIN
                    };

                    project.set_agent_location(
                        Some(AgentLocation {
                            buffer: buffer.downgrade(),
                            position,
                        }),
                        cx,
                    );
                })
            })
            .detach_and_log_err(cx);
        });
    }
 908
 909    pub fn request_tool_call_authorization(
 910        &mut self,
 911        tool_call: acp::ToolCall,
 912        options: Vec<acp::PermissionOption>,
 913        cx: &mut Context<Self>,
 914    ) -> oneshot::Receiver<acp::PermissionOptionId> {
 915        let (tx, rx) = oneshot::channel();
 916
 917        let status = ToolCallStatus::WaitingForConfirmation {
 918            options,
 919            respond_tx: tx,
 920        };
 921
 922        self.upsert_tool_call_inner(tool_call, status, cx);
 923        cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
 924        rx
 925    }
 926
 927    pub fn authorize_tool_call(
 928        &mut self,
 929        id: acp::ToolCallId,
 930        option_id: acp::PermissionOptionId,
 931        option_kind: acp::PermissionOptionKind,
 932        cx: &mut Context<Self>,
 933    ) {
 934        let Some((ix, call)) = self.tool_call_mut(&id) else {
 935            return;
 936        };
 937
 938        let new_status = match option_kind {
 939            acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
 940                ToolCallStatus::Rejected
 941            }
 942            acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
 943                ToolCallStatus::Allowed {
 944                    status: acp::ToolCallStatus::InProgress,
 945                }
 946            }
 947        };
 948
 949        let curr_status = mem::replace(&mut call.status, new_status);
 950
 951        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
 952            respond_tx.send(option_id).log_err();
 953        } else if cfg!(debug_assertions) {
 954            panic!("tried to authorize an already authorized tool call");
 955        }
 956
 957        cx.emit(AcpThreadEvent::EntryUpdated(ix));
 958    }
 959
 960    /// Returns true if the last turn is awaiting tool authorization
 961    pub fn waiting_for_tool_confirmation(&self) -> bool {
 962        for entry in self.entries.iter().rev() {
 963            match &entry {
 964                AgentThreadEntry::ToolCall(call) => match call.status {
 965                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
 966                    ToolCallStatus::Allowed { .. }
 967                    | ToolCallStatus::Rejected
 968                    | ToolCallStatus::Canceled => continue,
 969                },
 970                AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
 971                    // Reached the beginning of the turn
 972                    return false;
 973                }
 974            }
 975        }
 976        false
 977    }
 978
 979    pub fn plan(&self) -> &Plan {
 980        &self.plan
 981    }
 982
 983    pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
 984        let new_entries_len = request.entries.len();
 985        let mut new_entries = request.entries.into_iter();
 986
 987        // Reuse existing markdown to prevent flickering
 988        for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
 989            let PlanEntry {
 990                content,
 991                priority,
 992                status,
 993            } = old;
 994            content.update(cx, |old, cx| {
 995                old.replace(new.content, cx);
 996            });
 997            *priority = new.priority;
 998            *status = new.status;
 999        }
1000        for new in new_entries {
1001            self.plan.entries.push(PlanEntry::from_acp(new, cx))
1002        }
1003        self.plan.entries.truncate(new_entries_len);
1004
1005        cx.notify();
1006    }
1007
1008    fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1009        self.plan
1010            .entries
1011            .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1012        cx.notify();
1013    }
1014
1015    #[cfg(any(test, feature = "test-support"))]
1016    pub fn send_raw(
1017        &mut self,
1018        message: &str,
1019        cx: &mut Context<Self>,
1020    ) -> BoxFuture<'static, Result<()>> {
1021        self.send(
1022            vec![acp::ContentBlock::Text(acp::TextContent {
1023                text: message.to_string(),
1024                annotations: None,
1025            })],
1026            cx,
1027        )
1028    }
1029
    /// Sends `message` to the agent as a new user turn.
    ///
    /// The message is appended to the thread as a `UserMessage` entry, and
    /// completed plan entries are cleared. Any prompt still running is first
    /// cancelled through `cancel`. The new prompt is only issued to the
    /// connection after that cancellation task has finished.
    ///
    /// The returned future resolves when the prompt finishes:
    /// - If the agent returns an error, it emits `AcpThreadEvent::Error` and
    ///   returns that error.
    /// - Otherwise it emits `AcpThreadEvent::Stopped` and returns `Ok(())`.
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            cx,
        );
        self.push_entry(
            AgentThreadEntry::UserMessage(UserMessage { content: block }),
            cx,
        );
        self.clear_completed_plan_entries(cx);

        // `tx` carries the prompt result from the send task to the returned future.
        let (tx, rx) = oneshot::channel();
        // Cancel any in-flight prompt; the new prompt awaits this before starting.
        let cancel_task = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            async {
                cancel_task.await;

                let result = this
                    .update(cx, |this, cx| {
                        this.connection.prompt(
                            acp::PromptRequest {
                                prompt: message,
                                session_id: this.session_id.clone(),
                            },
                            cx,
                        )
                    })?
                    .await;

                tx.send(result).log_err();

                anyhow::Ok(())
            }
            .await
            .log_err();
        }));

        // NOTE(review): if the send task exits without sending on `tx` (e.g. the
        // `this.update` above fails), `rx` resolves to `Err(Canceled)` and is
        // handled by the non-error arm below, which also takes `send_task`.
        cx.spawn(async move |this, cx| match rx.await {
            Ok(Err(e)) => {
                this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error))
                    .log_err();
                Err(e)?
            }
            result => {
                let cancelled = matches!(
                    result,
                    Ok(Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::Cancelled
                    }))
                );

                // We only take the task if the current prompt wasn't cancelled.
                //
                // This prompt may have been cancelled because another one was sent
                // while it was still generating. In these cases, dropping `send_task`
                // would cause the next generation to be cancelled.
                if !cancelled {
                    this.update(cx, |this, _cx| this.send_task.take()).ok();
                }

                this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
                    .log_err();
                Ok(())
            }
        })
        .boxed()
    }
1103
1104    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1105        let Some(send_task) = self.send_task.take() else {
1106            return Task::ready(());
1107        };
1108
1109        for entry in self.entries.iter_mut() {
1110            if let AgentThreadEntry::ToolCall(call) = entry {
1111                let cancel = matches!(
1112                    call.status,
1113                    ToolCallStatus::WaitingForConfirmation { .. }
1114                        | ToolCallStatus::Allowed {
1115                            status: acp::ToolCallStatus::InProgress
1116                        }
1117                );
1118
1119                if cancel {
1120                    call.status = ToolCallStatus::Canceled;
1121                }
1122            }
1123        }
1124
1125        self.connection.cancel(&self.session_id, cx);
1126
1127        // Wait for the send task to complete
1128        cx.foreground_executor().spawn(send_task)
1129    }
1130
1131    pub fn read_text_file(
1132        &self,
1133        path: PathBuf,
1134        line: Option<u32>,
1135        limit: Option<u32>,
1136        reuse_shared_snapshot: bool,
1137        cx: &mut Context<Self>,
1138    ) -> Task<Result<String>> {
1139        let project = self.project.clone();
1140        let action_log = self.action_log.clone();
1141        cx.spawn(async move |this, cx| {
1142            let load = project.update(cx, |project, cx| {
1143                let path = project
1144                    .project_path_for_absolute_path(&path, cx)
1145                    .context("invalid path")?;
1146                anyhow::Ok(project.open_buffer(path, cx))
1147            });
1148            let buffer = load??.await?;
1149
1150            let snapshot = if reuse_shared_snapshot {
1151                this.read_with(cx, |this, _| {
1152                    this.shared_buffers.get(&buffer.clone()).cloned()
1153                })
1154                .log_err()
1155                .flatten()
1156            } else {
1157                None
1158            };
1159
1160            let snapshot = if let Some(snapshot) = snapshot {
1161                snapshot
1162            } else {
1163                action_log.update(cx, |action_log, cx| {
1164                    action_log.buffer_read(buffer.clone(), cx);
1165                })?;
1166                project.update(cx, |project, cx| {
1167                    let position = buffer
1168                        .read(cx)
1169                        .snapshot()
1170                        .anchor_before(Point::new(line.unwrap_or_default(), 0));
1171                    project.set_agent_location(
1172                        Some(AgentLocation {
1173                            buffer: buffer.downgrade(),
1174                            position,
1175                        }),
1176                        cx,
1177                    );
1178                })?;
1179
1180                buffer.update(cx, |buffer, _| buffer.snapshot())?
1181            };
1182
1183            this.update(cx, |this, _| {
1184                let text = snapshot.text();
1185                this.shared_buffers.insert(buffer.clone(), snapshot);
1186                if line.is_none() && limit.is_none() {
1187                    return Ok(text);
1188                }
1189                let limit = limit.unwrap_or(u32::MAX) as usize;
1190                let Some(line) = line else {
1191                    return Ok(text.lines().take(limit).collect::<String>());
1192                };
1193
1194                let count = text.lines().count();
1195                if count < line as usize {
1196                    anyhow::bail!("There are only {} lines", count);
1197                }
1198                Ok(text
1199                    .lines()
1200                    .skip(line as usize + 1)
1201                    .take(limit)
1202                    .collect::<String>())
1203            })?
1204        })
1205    }
1206
    /// Replaces the contents of the file at `path` with `content` on the agent's
    /// behalf, then saves the buffer.
    ///
    /// The edits are computed by diffing `content` against the snapshot held in
    /// `shared_buffers` for this buffer (falling back to the buffer's current
    /// snapshot). The diff runs on the background executor. Each edit is turned
    /// into anchors on that snapshot, so changes made to the buffer after the
    /// snapshot was taken stay where they are.
    ///
    /// Before editing, the agent location is moved to the end of the last edit
    /// (or to the buffer start if there are no edits), and the action log is told
    /// about the read and the edit.
    ///
    /// # Errors
    /// Fails if the path is not in the project, the buffer cannot be opened or
    /// saved, or the thread/app has been dropped.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Prefer the snapshot the agent last read so the diff reflects what the
            // agent saw, not edits made since then.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Diff off the main thread; ranges become anchors on `snapshot`.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;
            cx.update(|cx| {
                project.update(cx, |project, cx| {
                    project.set_agent_location(
                        Some(AgentLocation {
                            buffer: buffer.downgrade(),
                            position: edits
                                .last()
                                .map(|(range, _)| range.end)
                                .unwrap_or(Anchor::MIN),
                        }),
                        cx,
                    );
                });

                // Record a read before editing, then the edit itself.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });
                buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
            })?;
            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1274
1275    pub fn to_markdown(&self, cx: &App) -> String {
1276        self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1277    }
1278
1279    pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
1280        cx.emit(AcpThreadEvent::ServerExited(status));
1281    }
1282}
1283
1284fn markdown_for_raw_output(
1285    raw_output: &serde_json::Value,
1286    language_registry: &Arc<LanguageRegistry>,
1287    cx: &mut App,
1288) -> Option<Entity<Markdown>> {
1289    match raw_output {
1290        serde_json::Value::Null => None,
1291        serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1292            Markdown::new(
1293                value.to_string().into(),
1294                Some(language_registry.clone()),
1295                None,
1296                cx,
1297            )
1298        })),
1299        serde_json::Value::Number(value) => Some(cx.new(|cx| {
1300            Markdown::new(
1301                value.to_string().into(),
1302                Some(language_registry.clone()),
1303                None,
1304                cx,
1305            )
1306        })),
1307        serde_json::Value::String(value) => Some(cx.new(|cx| {
1308            Markdown::new(
1309                value.clone().into(),
1310                Some(language_registry.clone()),
1311                None,
1312                cx,
1313            )
1314        })),
1315        value => Some(cx.new(|cx| {
1316            Markdown::new(
1317                format!("```json\n{}\n```", value).into(),
1318                Some(language_registry.clone()),
1319                None,
1320                cx,
1321            )
1322        })),
1323    }
1324}
1325
1326#[cfg(test)]
1327mod tests {
1328    use super::*;
1329    use anyhow::anyhow;
1330    use futures::{channel::mpsc, future::LocalBoxFuture, select};
1331    use gpui::{AsyncApp, TestAppContext, WeakEntity};
1332    use indoc::indoc;
1333    use project::FakeFs;
1334    use rand::Rng as _;
1335    use serde_json::json;
1336    use settings::SettingsStore;
1337    use smol::stream::StreamExt as _;
1338    use std::{cell::RefCell, path::Path, rc::Rc, time::Duration};
1339
1340    use util::path;
1341
1342    fn init_test(cx: &mut TestAppContext) {
1343        env_logger::try_init().ok();
1344        cx.update(|cx| {
1345            let settings_store = SettingsStore::test(cx);
1346            cx.set_global(settings_store);
1347            Project::init_settings(cx);
1348            language::init(cx);
1349        });
1350    }
1351
1352    #[gpui::test]
1353    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1354        init_test(cx);
1355
1356        let fs = FakeFs::new(cx.executor());
1357        let project = Project::test(fs, [], cx).await;
1358        let connection = Rc::new(FakeAgentConnection::new());
1359        let thread = cx
1360            .spawn(async move |mut cx| {
1361                connection
1362                    .new_thread(project, Path::new(path!("/test")), &mut cx)
1363                    .await
1364            })
1365            .await
1366            .unwrap();
1367
1368        // Test creating a new user message
1369        thread.update(cx, |thread, cx| {
1370            thread.push_user_content_block(
1371                acp::ContentBlock::Text(acp::TextContent {
1372                    annotations: None,
1373                    text: "Hello, ".to_string(),
1374                }),
1375                cx,
1376            );
1377        });
1378
1379        thread.update(cx, |thread, cx| {
1380            assert_eq!(thread.entries.len(), 1);
1381            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1382                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1383            } else {
1384                panic!("Expected UserMessage");
1385            }
1386        });
1387
1388        // Test appending to existing user message
1389        thread.update(cx, |thread, cx| {
1390            thread.push_user_content_block(
1391                acp::ContentBlock::Text(acp::TextContent {
1392                    annotations: None,
1393                    text: "world!".to_string(),
1394                }),
1395                cx,
1396            );
1397        });
1398
1399        thread.update(cx, |thread, cx| {
1400            assert_eq!(thread.entries.len(), 1);
1401            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1402                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1403            } else {
1404                panic!("Expected UserMessage");
1405            }
1406        });
1407
1408        // Test creating new user message after assistant message
1409        thread.update(cx, |thread, cx| {
1410            thread.push_assistant_content_block(
1411                acp::ContentBlock::Text(acp::TextContent {
1412                    annotations: None,
1413                    text: "Assistant response".to_string(),
1414                }),
1415                false,
1416                cx,
1417            );
1418        });
1419
1420        thread.update(cx, |thread, cx| {
1421            thread.push_user_content_block(
1422                acp::ContentBlock::Text(acp::TextContent {
1423                    annotations: None,
1424                    text: "New user message".to_string(),
1425                }),
1426                cx,
1427            );
1428        });
1429
1430        thread.update(cx, |thread, cx| {
1431            assert_eq!(thread.entries.len(), 3);
1432            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1433                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1434            } else {
1435                panic!("Expected UserMessage at index 2");
1436            }
1437        });
1438    }
1439
1440    #[gpui::test]
1441    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
1442        init_test(cx);
1443
1444        let fs = FakeFs::new(cx.executor());
1445        let project = Project::test(fs, [], cx).await;
1446        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1447            |_, thread, mut cx| {
1448                async move {
1449                    thread.update(&mut cx, |thread, cx| {
1450                        thread
1451                            .handle_session_update(
1452                                acp::SessionUpdate::AgentThoughtChunk {
1453                                    content: "Thinking ".into(),
1454                                },
1455                                cx,
1456                            )
1457                            .unwrap();
1458                        thread
1459                            .handle_session_update(
1460                                acp::SessionUpdate::AgentThoughtChunk {
1461                                    content: "hard!".into(),
1462                                },
1463                                cx,
1464                            )
1465                            .unwrap();
1466                    })?;
1467                    Ok(acp::PromptResponse {
1468                        stop_reason: acp::StopReason::EndTurn,
1469                    })
1470                }
1471                .boxed_local()
1472            },
1473        ));
1474
1475        let thread = cx
1476            .spawn(async move |mut cx| {
1477                connection
1478                    .new_thread(project, Path::new(path!("/test")), &mut cx)
1479                    .await
1480            })
1481            .await
1482            .unwrap();
1483
1484        thread
1485            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
1486            .await
1487            .unwrap();
1488
1489        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
1490        assert_eq!(
1491            output,
1492            indoc! {r#"
1493            ## User
1494
1495            Hello from Zed!
1496
1497            ## Assistant
1498
1499            <thinking>
1500            Thinking hard!
1501            </thinking>
1502
1503            "#}
1504        );
1505    }
1506
1507    #[gpui::test]
1508    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
1509        init_test(cx);
1510
1511        let fs = FakeFs::new(cx.executor());
1512        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
1513            .await;
1514        let project = Project::test(fs.clone(), [], cx).await;
1515        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
1516        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
1517        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1518            move |_, thread, mut cx| {
1519                let read_file_tx = read_file_tx.clone();
1520                async move {
1521                    let content = thread
1522                        .update(&mut cx, |thread, cx| {
1523                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
1524                        })
1525                        .unwrap()
1526                        .await
1527                        .unwrap();
1528                    assert_eq!(content, "one\ntwo\nthree\n");
1529                    read_file_tx.take().unwrap().send(()).unwrap();
1530                    thread
1531                        .update(&mut cx, |thread, cx| {
1532                            thread.write_text_file(
1533                                path!("/tmp/foo").into(),
1534                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
1535                                cx,
1536                            )
1537                        })
1538                        .unwrap()
1539                        .await
1540                        .unwrap();
1541                    Ok(acp::PromptResponse {
1542                        stop_reason: acp::StopReason::EndTurn,
1543                    })
1544                }
1545                .boxed_local()
1546            },
1547        ));
1548
1549        let (worktree, pathbuf) = project
1550            .update(cx, |project, cx| {
1551                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
1552            })
1553            .await
1554            .unwrap();
1555        let buffer = project
1556            .update(cx, |project, cx| {
1557                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
1558            })
1559            .await
1560            .unwrap();
1561
1562        let thread = cx
1563            .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
1564            .await
1565            .unwrap();
1566
1567        let request = thread.update(cx, |thread, cx| {
1568            thread.send_raw("Extend the count in /tmp/foo", cx)
1569        });
1570        read_file_rx.await.ok();
1571        buffer.update(cx, |buffer, cx| {
1572            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
1573        });
1574        cx.run_until_parked();
1575        assert_eq!(
1576            buffer.read_with(cx, |buffer, _| buffer.text()),
1577            "zero\none\ntwo\nthree\nfour\nfive\n"
1578        );
1579        assert_eq!(
1580            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
1581            "zero\none\ntwo\nthree\nfour\nfive\n"
1582        );
1583        request.await.unwrap();
1584    }
1585
1586    #[gpui::test]
1587    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
1588        init_test(cx);
1589
1590        let fs = FakeFs::new(cx.executor());
1591        let project = Project::test(fs, [], cx).await;
1592        let id = acp::ToolCallId("test".into());
1593
1594        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
1595            let id = id.clone();
1596            move |_, thread, mut cx| {
1597                let id = id.clone();
1598                async move {
1599                    thread
1600                        .update(&mut cx, |thread, cx| {
1601                            thread.handle_session_update(
1602                                acp::SessionUpdate::ToolCall(acp::ToolCall {
1603                                    id: id.clone(),
1604                                    title: "Label".into(),
1605                                    kind: acp::ToolKind::Fetch,
1606                                    status: acp::ToolCallStatus::InProgress,
1607                                    content: vec![],
1608                                    locations: vec![],
1609                                    raw_input: None,
1610                                    raw_output: None,
1611                                }),
1612                                cx,
1613                            )
1614                        })
1615                        .unwrap()
1616                        .unwrap();
1617                    Ok(acp::PromptResponse {
1618                        stop_reason: acp::StopReason::EndTurn,
1619                    })
1620                }
1621                .boxed_local()
1622            }
1623        }));
1624
1625        let thread = cx
1626            .spawn(async move |mut cx| {
1627                connection
1628                    .new_thread(project, Path::new(path!("/test")), &mut cx)
1629                    .await
1630            })
1631            .await
1632            .unwrap();
1633
1634        let request = thread.update(cx, |thread, cx| {
1635            thread.send_raw("Fetch https://example.com", cx)
1636        });
1637
1638        run_until_first_tool_call(&thread, cx).await;
1639
1640        thread.read_with(cx, |thread, _| {
1641            assert!(matches!(
1642                thread.entries[1],
1643                AgentThreadEntry::ToolCall(ToolCall {
1644                    status: ToolCallStatus::Allowed {
1645                        status: acp::ToolCallStatus::InProgress,
1646                        ..
1647                    },
1648                    ..
1649                })
1650            ));
1651        });
1652
1653        thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1654
1655        thread.read_with(cx, |thread, _| {
1656            assert!(matches!(
1657                &thread.entries[1],
1658                AgentThreadEntry::ToolCall(ToolCall {
1659                    status: ToolCallStatus::Canceled,
1660                    ..
1661                })
1662            ));
1663        });
1664
1665        thread
1666            .update(cx, |thread, cx| {
1667                thread.handle_session_update(
1668                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
1669                        id,
1670                        fields: acp::ToolCallUpdateFields {
1671                            status: Some(acp::ToolCallStatus::Completed),
1672                            ..Default::default()
1673                        },
1674                    }),
1675                    cx,
1676                )
1677            })
1678            .unwrap();
1679
1680        request.await.unwrap();
1681
1682        thread.read_with(cx, |thread, _| {
1683            assert!(matches!(
1684                thread.entries[1],
1685                AgentThreadEntry::ToolCall(ToolCall {
1686                    status: ToolCallStatus::Allowed {
1687                        status: acp::ToolCallStatus::Completed,
1688                        ..
1689                    },
1690                    ..
1691                })
1692            ));
1693        });
1694    }
1695
1696    #[gpui::test]
1697    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
1698        init_test(cx);
1699        let fs = FakeFs::new(cx.background_executor.clone());
1700        fs.insert_tree(path!("/test"), json!({})).await;
1701        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
1702
1703        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
1704            move |_, thread, mut cx| {
1705                async move {
1706                    thread
1707                        .update(&mut cx, |thread, cx| {
1708                            thread.handle_session_update(
1709                                acp::SessionUpdate::ToolCall(acp::ToolCall {
1710                                    id: acp::ToolCallId("test".into()),
1711                                    title: "Label".into(),
1712                                    kind: acp::ToolKind::Edit,
1713                                    status: acp::ToolCallStatus::Completed,
1714                                    content: vec![acp::ToolCallContent::Diff {
1715                                        diff: acp::Diff {
1716                                            path: "/test/test.txt".into(),
1717                                            old_text: None,
1718                                            new_text: "foo".into(),
1719                                        },
1720                                    }],
1721                                    locations: vec![],
1722                                    raw_input: None,
1723                                    raw_output: None,
1724                                }),
1725                                cx,
1726                            )
1727                        })
1728                        .unwrap()
1729                        .unwrap();
1730                    Ok(acp::PromptResponse {
1731                        stop_reason: acp::StopReason::EndTurn,
1732                    })
1733                }
1734                .boxed_local()
1735            }
1736        }));
1737
1738        let thread = connection
1739            .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
1740            .await
1741            .unwrap();
1742        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
1743            .await
1744            .unwrap();
1745
1746        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
1747    }
1748
1749    async fn run_until_first_tool_call(
1750        thread: &Entity<AcpThread>,
1751        cx: &mut TestAppContext,
1752    ) -> usize {
1753        let (mut tx, mut rx) = mpsc::channel::<usize>(1);
1754
1755        let subscription = cx.update(|cx| {
1756            cx.subscribe(thread, move |thread, _, cx| {
1757                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
1758                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
1759                        return tx.try_send(ix).unwrap();
1760                    }
1761                }
1762            })
1763        });
1764
1765        select! {
1766            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
1767                panic!("Timeout waiting for tool call")
1768            }
1769            ix = rx.next().fuse() => {
1770                drop(subscription);
1771                ix.unwrap()
1772            }
1773        }
1774    }
1775
1776    #[derive(Clone, Default)]
1777    struct FakeAgentConnection {
1778        auth_methods: Vec<acp::AuthMethod>,
1779        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
1780        on_user_message: Option<
1781            Rc<
1782                dyn Fn(
1783                        acp::PromptRequest,
1784                        WeakEntity<AcpThread>,
1785                        AsyncApp,
1786                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1787                    + 'static,
1788            >,
1789        >,
1790    }
1791
1792    impl FakeAgentConnection {
1793        fn new() -> Self {
1794            Self {
1795                auth_methods: Vec::new(),
1796                on_user_message: None,
1797                sessions: Arc::default(),
1798            }
1799        }
1800
1801        #[expect(unused)]
1802        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
1803            self.auth_methods = auth_methods;
1804            self
1805        }
1806
1807        fn on_user_message(
1808            mut self,
1809            handler: impl Fn(
1810                acp::PromptRequest,
1811                WeakEntity<AcpThread>,
1812                AsyncApp,
1813            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1814            + 'static,
1815        ) -> Self {
1816            self.on_user_message.replace(Rc::new(handler));
1817            self
1818        }
1819    }
1820
1821    impl AgentConnection for FakeAgentConnection {
1822        fn auth_methods(&self) -> &[acp::AuthMethod] {
1823            &self.auth_methods
1824        }
1825
1826        fn new_thread(
1827            self: Rc<Self>,
1828            project: Entity<Project>,
1829            _cwd: &Path,
1830            cx: &mut gpui::AsyncApp,
1831        ) -> Task<gpui::Result<Entity<AcpThread>>> {
1832            let session_id = acp::SessionId(
1833                rand::thread_rng()
1834                    .sample_iter(&rand::distributions::Alphanumeric)
1835                    .take(7)
1836                    .map(char::from)
1837                    .collect::<String>()
1838                    .into(),
1839            );
1840            let thread = cx
1841                .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
1842                .unwrap();
1843            self.sessions.lock().insert(session_id, thread.downgrade());
1844            Task::ready(Ok(thread))
1845        }
1846
1847        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
1848            if self.auth_methods().iter().any(|m| m.id == method) {
1849                Task::ready(Ok(()))
1850            } else {
1851                Task::ready(Err(anyhow!("Invalid Auth Method")))
1852            }
1853        }
1854
1855        fn prompt(
1856            &self,
1857            params: acp::PromptRequest,
1858            cx: &mut App,
1859        ) -> Task<gpui::Result<acp::PromptResponse>> {
1860            let sessions = self.sessions.lock();
1861            let thread = sessions.get(&params.session_id).unwrap();
1862            if let Some(handler) = &self.on_user_message {
1863                let handler = handler.clone();
1864                let thread = thread.clone();
1865                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
1866            } else {
1867                Task::ready(Ok(acp::PromptResponse {
1868                    stop_reason: acp::StopReason::EndTurn,
1869                }))
1870            }
1871        }
1872
1873        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1874            let sessions = self.sessions.lock();
1875            let thread = sessions.get(&session_id).unwrap().clone();
1876
1877            cx.spawn(async move |cx| {
1878                thread
1879                    .update(cx, |thread, cx| thread.cancel(cx))
1880                    .unwrap()
1881                    .await
1882            })
1883            .detach();
1884        }
1885    }
1886}