acp_thread.rs

   1mod connection;
   2mod diff;
   3mod mention;
   4mod terminal;
   5
   6pub use connection::*;
   7pub use diff::*;
   8pub use mention::*;
   9pub use terminal::*;
  10
  11use action_log::ActionLog;
  12use agent_client_protocol::{self as acp};
  13use anyhow::{Context as _, Result};
  14use editor::Bias;
  15use futures::{FutureExt, channel::oneshot, future::BoxFuture};
  16use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  17use itertools::Itertools;
  18use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
  19use markdown::Markdown;
  20use project::{AgentLocation, Project};
  21use std::collections::HashMap;
  22use std::error::Error;
  23use std::fmt::Formatter;
  24use std::process::ExitStatus;
  25use std::rc::Rc;
  26use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
  27use ui::App;
  28use util::ResultExt;
  29
  30#[derive(Debug)]
  31pub struct UserMessage {
  32    pub content: ContentBlock,
  33}
  34
  35impl UserMessage {
  36    pub fn from_acp(
  37        message: impl IntoIterator<Item = acp::ContentBlock>,
  38        language_registry: Arc<LanguageRegistry>,
  39        cx: &mut App,
  40    ) -> Self {
  41        let mut content = ContentBlock::Empty;
  42        for chunk in message {
  43            content.append(chunk, &language_registry, cx)
  44        }
  45        Self { content: content }
  46    }
  47
  48    fn to_markdown(&self, cx: &App) -> String {
  49        format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
  50    }
  51}
  52
  53#[derive(Debug, PartialEq)]
  54pub struct AssistantMessage {
  55    pub chunks: Vec<AssistantMessageChunk>,
  56}
  57
  58impl AssistantMessage {
  59    pub fn to_markdown(&self, cx: &App) -> String {
  60        format!(
  61            "## Assistant\n\n{}\n\n",
  62            self.chunks
  63                .iter()
  64                .map(|chunk| chunk.to_markdown(cx))
  65                .join("\n\n")
  66        )
  67    }
  68}
  69
  70#[derive(Debug, PartialEq)]
  71pub enum AssistantMessageChunk {
  72    Message { block: ContentBlock },
  73    Thought { block: ContentBlock },
  74}
  75
  76impl AssistantMessageChunk {
  77    pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
  78        Self::Message {
  79            block: ContentBlock::new(chunk.into(), language_registry, cx),
  80        }
  81    }
  82
  83    fn to_markdown(&self, cx: &App) -> String {
  84        match self {
  85            Self::Message { block } => block.to_markdown(cx).to_string(),
  86            Self::Thought { block } => {
  87                format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
  88            }
  89        }
  90    }
  91}
  92
  93#[derive(Debug)]
  94pub enum AgentThreadEntry {
  95    UserMessage(UserMessage),
  96    AssistantMessage(AssistantMessage),
  97    ToolCall(ToolCall),
  98}
  99
 100impl AgentThreadEntry {
 101    fn to_markdown(&self, cx: &App) -> String {
 102        match self {
 103            Self::UserMessage(message) => message.to_markdown(cx),
 104            Self::AssistantMessage(message) => message.to_markdown(cx),
 105            Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
 106        }
 107    }
 108
 109    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 110        if let AgentThreadEntry::ToolCall(call) = self {
 111            itertools::Either::Left(call.diffs())
 112        } else {
 113            itertools::Either::Right(std::iter::empty())
 114        }
 115    }
 116
 117    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 118        if let AgentThreadEntry::ToolCall(call) = self {
 119            itertools::Either::Left(call.terminals())
 120        } else {
 121            itertools::Either::Right(std::iter::empty())
 122        }
 123    }
 124
 125    pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
 126        if let AgentThreadEntry::ToolCall(ToolCall {
 127            locations,
 128            resolved_locations,
 129            ..
 130        }) = self
 131        {
 132            Some((
 133                locations.get(ix)?.clone(),
 134                resolved_locations.get(ix)?.clone()?,
 135            ))
 136        } else {
 137            None
 138        }
 139    }
 140}
 141
/// A tool invocation reported by the agent, as tracked by the thread.
#[derive(Debug)]
pub struct ToolCall {
    /// Identifier used to match later updates to this call.
    pub id: acp::ToolCallId,
    /// The call's title, held as a markdown entity.
    pub label: Entity<Markdown>,
    /// The tool kind reported by the agent.
    pub kind: acp::ToolKind,
    /// Content attached to the call (content blocks, diffs, terminals).
    pub content: Vec<ToolCallContent>,
    /// Permission / progress state of the call.
    pub status: ToolCallStatus,
    /// Locations as reported by the agent over ACP.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Buffer positions for `locations`, looked up by the same index; an entry is `None`
    /// when that location could not be resolved. Starts empty in `from_acp`; presumably
    /// filled in asynchronously by `AcpThread::resolve_locations` — confirm.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    /// Raw JSON input reported by the agent, if any.
    pub raw_input: Option<serde_json::Value>,
    /// Raw JSON output reported by the agent, if any.
    pub raw_output: Option<serde_json::Value>,
}
 154
 155impl ToolCall {
 156    fn from_acp(
 157        tool_call: acp::ToolCall,
 158        status: ToolCallStatus,
 159        language_registry: Arc<LanguageRegistry>,
 160        cx: &mut App,
 161    ) -> Self {
 162        Self {
 163            id: tool_call.id,
 164            label: cx.new(|cx| {
 165                Markdown::new(
 166                    tool_call.title.into(),
 167                    Some(language_registry.clone()),
 168                    None,
 169                    cx,
 170                )
 171            }),
 172            kind: tool_call.kind,
 173            content: tool_call
 174                .content
 175                .into_iter()
 176                .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
 177                .collect(),
 178            locations: tool_call.locations,
 179            resolved_locations: Vec::default(),
 180            status,
 181            raw_input: tool_call.raw_input,
 182            raw_output: tool_call.raw_output,
 183        }
 184    }
 185
 186    fn update_fields(
 187        &mut self,
 188        fields: acp::ToolCallUpdateFields,
 189        language_registry: Arc<LanguageRegistry>,
 190        cx: &mut App,
 191    ) {
 192        let acp::ToolCallUpdateFields {
 193            kind,
 194            status,
 195            title,
 196            content,
 197            locations,
 198            raw_input,
 199            raw_output,
 200        } = fields;
 201
 202        if let Some(kind) = kind {
 203            self.kind = kind;
 204        }
 205
 206        if let Some(status) = status {
 207            self.status = ToolCallStatus::Allowed { status };
 208        }
 209
 210        if let Some(title) = title {
 211            self.label.update(cx, |label, cx| {
 212                label.replace(title, cx);
 213            });
 214        }
 215
 216        if let Some(content) = content {
 217            self.content = content
 218                .into_iter()
 219                .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
 220                .collect();
 221        }
 222
 223        if let Some(locations) = locations {
 224            self.locations = locations;
 225        }
 226
 227        if let Some(raw_input) = raw_input {
 228            self.raw_input = Some(raw_input);
 229        }
 230
 231        if let Some(raw_output) = raw_output {
 232            if self.content.is_empty() {
 233                if let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
 234                {
 235                    self.content
 236                        .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
 237                            markdown,
 238                        }));
 239                }
 240            }
 241            self.raw_output = Some(raw_output);
 242        }
 243    }
 244
 245    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 246        self.content.iter().filter_map(|content| match content {
 247            ToolCallContent::Diff(diff) => Some(diff),
 248            ToolCallContent::ContentBlock(_) => None,
 249            ToolCallContent::Terminal(_) => None,
 250        })
 251    }
 252
 253    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 254        self.content.iter().filter_map(|content| match content {
 255            ToolCallContent::Terminal(terminal) => Some(terminal),
 256            ToolCallContent::ContentBlock(_) => None,
 257            ToolCallContent::Diff(_) => None,
 258        })
 259    }
 260
 261    fn to_markdown(&self, cx: &App) -> String {
 262        let mut markdown = format!(
 263            "**Tool Call: {}**\nStatus: {}\n\n",
 264            self.label.read(cx).source(),
 265            self.status
 266        );
 267        for content in &self.content {
 268            markdown.push_str(content.to_markdown(cx).as_str());
 269            markdown.push_str("\n\n");
 270        }
 271        markdown
 272    }
 273
 274    async fn resolve_location(
 275        location: acp::ToolCallLocation,
 276        project: WeakEntity<Project>,
 277        cx: &mut AsyncApp,
 278    ) -> Option<AgentLocation> {
 279        let buffer = project
 280            .update(cx, |project, cx| {
 281                if let Some(path) = project.project_path_for_absolute_path(&location.path, cx) {
 282                    Some(project.open_buffer(path, cx))
 283                } else {
 284                    None
 285                }
 286            })
 287            .ok()??;
 288        let buffer = buffer.await.log_err()?;
 289        let position = buffer
 290            .update(cx, |buffer, _| {
 291                if let Some(row) = location.line {
 292                    let snapshot = buffer.snapshot();
 293                    let column = snapshot.indent_size_for_line(row).len;
 294                    let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
 295                    snapshot.anchor_before(point)
 296                } else {
 297                    Anchor::MIN
 298                }
 299            })
 300            .ok()?;
 301
 302        Some(AgentLocation {
 303            buffer: buffer.downgrade(),
 304            position,
 305        })
 306    }
 307
 308    fn resolve_locations(
 309        &self,
 310        project: Entity<Project>,
 311        cx: &mut App,
 312    ) -> Task<Vec<Option<AgentLocation>>> {
 313        let locations = self.locations.clone();
 314        project.update(cx, |_, cx| {
 315            cx.spawn(async move |project, cx| {
 316                let mut new_locations = Vec::new();
 317                for location in locations {
 318                    new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
 319                }
 320                new_locations
 321            })
 322        })
 323    }
 324}
 325
/// Client-side lifecycle state of a [`ToolCall`].
#[derive(Debug)]
pub enum ToolCallStatus {
    /// Displayed as "Waiting for confirmation"; presumably awaiting the user's choice
    /// among `options`.
    WaitingForConfirmation {
        /// Permission options that can be chosen.
        options: Vec<acp::PermissionOption>,
        /// Channel for sending back the chosen option id (the receiving side is not
        /// visible here — confirm at call sites).
        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
    },
    /// The call is allowed; `status` is the progress reported by the agent.
    Allowed {
        status: acp::ToolCallStatus,
    },
    /// Displayed as "Rejected"; presumably the permission request was declined.
    Rejected,
    /// Displayed as "Canceled"; presumably the call was canceled before finishing.
    Canceled,
}
 338
 339impl Display for ToolCallStatus {
 340    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 341        write!(
 342            f,
 343            "{}",
 344            match self {
 345                ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
 346                ToolCallStatus::Allowed { status } => match status {
 347                    acp::ToolCallStatus::Pending => "Pending",
 348                    acp::ToolCallStatus::InProgress => "In Progress",
 349                    acp::ToolCallStatus::Completed => "Completed",
 350                    acp::ToolCallStatus::Failed => "Failed",
 351                },
 352                ToolCallStatus::Rejected => "Rejected",
 353                ToolCallStatus::Canceled => "Canceled",
 354            }
 355        )
 356    }
 357}
 358
/// Renderable content accumulated from one or more ACP content blocks.
#[derive(Debug, PartialEq, Clone)]
pub enum ContentBlock {
    /// Nothing has been appended yet.
    Empty,
    /// Text content (including rendered links) held in a markdown entity.
    Markdown { markdown: Entity<Markdown> },
    /// A lone resource link that was the first thing appended; it is converted to
    /// markdown as soon as more content is appended.
    ResourceLink { resource_link: acp::ResourceLink },
}
 365
 366impl ContentBlock {
 367    pub fn new(
 368        block: acp::ContentBlock,
 369        language_registry: &Arc<LanguageRegistry>,
 370        cx: &mut App,
 371    ) -> Self {
 372        let mut this = Self::Empty;
 373        this.append(block, language_registry, cx);
 374        this
 375    }
 376
 377    pub fn new_combined(
 378        blocks: impl IntoIterator<Item = acp::ContentBlock>,
 379        language_registry: Arc<LanguageRegistry>,
 380        cx: &mut App,
 381    ) -> Self {
 382        let mut this = Self::Empty;
 383        for block in blocks {
 384            this.append(block, &language_registry, cx);
 385        }
 386        this
 387    }
 388
 389    pub fn append(
 390        &mut self,
 391        block: acp::ContentBlock,
 392        language_registry: &Arc<LanguageRegistry>,
 393        cx: &mut App,
 394    ) {
 395        if matches!(self, ContentBlock::Empty) {
 396            if let acp::ContentBlock::ResourceLink(resource_link) = block {
 397                *self = ContentBlock::ResourceLink { resource_link };
 398                return;
 399            }
 400        }
 401
 402        let new_content = self.extract_content_from_block(block);
 403
 404        match self {
 405            ContentBlock::Empty => {
 406                *self = Self::create_markdown_block(new_content, language_registry, cx);
 407            }
 408            ContentBlock::Markdown { markdown } => {
 409                markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
 410            }
 411            ContentBlock::ResourceLink { resource_link } => {
 412                let existing_content = Self::resource_link_to_content(&resource_link.uri);
 413                let combined = format!("{}\n{}", existing_content, new_content);
 414
 415                *self = Self::create_markdown_block(combined, language_registry, cx);
 416            }
 417        }
 418    }
 419
 420    fn resource_link_to_content(uri: &str) -> String {
 421        if let Some(uri) = MentionUri::parse(&uri).log_err() {
 422            uri.to_link()
 423        } else {
 424            uri.to_string().clone()
 425        }
 426    }
 427
 428    fn create_markdown_block(
 429        content: String,
 430        language_registry: &Arc<LanguageRegistry>,
 431        cx: &mut App,
 432    ) -> ContentBlock {
 433        ContentBlock::Markdown {
 434            markdown: cx
 435                .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
 436        }
 437    }
 438
 439    fn extract_content_from_block(&self, block: acp::ContentBlock) -> String {
 440        match block {
 441            acp::ContentBlock::Text(text_content) => text_content.text.clone(),
 442            acp::ContentBlock::ResourceLink(resource_link) => {
 443                Self::resource_link_to_content(&resource_link.uri)
 444            }
 445            acp::ContentBlock::Resource(acp::EmbeddedResource {
 446                resource:
 447                    acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
 448                        uri,
 449                        ..
 450                    }),
 451                ..
 452            }) => Self::resource_link_to_content(&uri),
 453            acp::ContentBlock::Image(_)
 454            | acp::ContentBlock::Audio(_)
 455            | acp::ContentBlock::Resource(_) => String::new(),
 456        }
 457    }
 458
 459    fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
 460        match self {
 461            ContentBlock::Empty => "",
 462            ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
 463            ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
 464        }
 465    }
 466
 467    pub fn markdown(&self) -> Option<&Entity<Markdown>> {
 468        match self {
 469            ContentBlock::Empty => None,
 470            ContentBlock::Markdown { markdown } => Some(markdown),
 471            ContentBlock::ResourceLink { .. } => None,
 472        }
 473    }
 474
 475    pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
 476        match self {
 477            ContentBlock::ResourceLink { resource_link } => Some(resource_link),
 478            _ => None,
 479        }
 480    }
 481}
 482
/// One item of a tool call's content.
#[derive(Debug)]
pub enum ToolCallContent {
    /// Markdown text or a resource link.
    ContentBlock(ContentBlock),
    /// A diff entity.
    Diff(Entity<Diff>),
    /// A terminal entity (attached through `ToolCallUpdate::UpdateTerminal`).
    Terminal(Entity<Terminal>),
}
 489
 490impl ToolCallContent {
 491    pub fn from_acp(
 492        content: acp::ToolCallContent,
 493        language_registry: Arc<LanguageRegistry>,
 494        cx: &mut App,
 495    ) -> Self {
 496        match content {
 497            acp::ToolCallContent::Content { content } => {
 498                Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
 499            }
 500            acp::ToolCallContent::Diff { diff } => {
 501                Self::Diff(cx.new(|cx| Diff::from_acp(diff, language_registry, cx)))
 502            }
 503        }
 504    }
 505
 506    pub fn to_markdown(&self, cx: &App) -> String {
 507        match self {
 508            Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
 509            Self::Diff(diff) => diff.read(cx).to_markdown(cx),
 510            Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
 511        }
 512    }
 513}
 514
 515#[derive(Debug, PartialEq)]
 516pub enum ToolCallUpdate {
 517    UpdateFields(acp::ToolCallUpdate),
 518    UpdateDiff(ToolCallUpdateDiff),
 519    UpdateTerminal(ToolCallUpdateTerminal),
 520}
 521
 522impl ToolCallUpdate {
 523    fn id(&self) -> &acp::ToolCallId {
 524        match self {
 525            Self::UpdateFields(update) => &update.id,
 526            Self::UpdateDiff(diff) => &diff.id,
 527            Self::UpdateTerminal(terminal) => &terminal.id,
 528        }
 529    }
 530}
 531
 532impl From<acp::ToolCallUpdate> for ToolCallUpdate {
 533    fn from(update: acp::ToolCallUpdate) -> Self {
 534        Self::UpdateFields(update)
 535    }
 536}
 537
 538impl From<ToolCallUpdateDiff> for ToolCallUpdate {
 539    fn from(diff: ToolCallUpdateDiff) -> Self {
 540        Self::UpdateDiff(diff)
 541    }
 542}
 543
 544#[derive(Debug, PartialEq)]
 545pub struct ToolCallUpdateDiff {
 546    pub id: acp::ToolCallId,
 547    pub diff: Entity<Diff>,
 548}
 549
 550impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
 551    fn from(terminal: ToolCallUpdateTerminal) -> Self {
 552        Self::UpdateTerminal(terminal)
 553    }
 554}
 555
 556#[derive(Debug, PartialEq)]
 557pub struct ToolCallUpdateTerminal {
 558    pub id: acp::ToolCallId,
 559    pub terminal: Entity<Terminal>,
 560}
 561
 562#[derive(Debug, Default)]
 563pub struct Plan {
 564    pub entries: Vec<PlanEntry>,
 565}
 566
 567#[derive(Debug)]
 568pub struct PlanStats<'a> {
 569    pub in_progress_entry: Option<&'a PlanEntry>,
 570    pub pending: u32,
 571    pub completed: u32,
 572}
 573
 574impl Plan {
 575    pub fn is_empty(&self) -> bool {
 576        self.entries.is_empty()
 577    }
 578
 579    pub fn stats(&self) -> PlanStats<'_> {
 580        let mut stats = PlanStats {
 581            in_progress_entry: None,
 582            pending: 0,
 583            completed: 0,
 584        };
 585
 586        for entry in &self.entries {
 587            match &entry.status {
 588                acp::PlanEntryStatus::Pending => {
 589                    stats.pending += 1;
 590                }
 591                acp::PlanEntryStatus::InProgress => {
 592                    stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
 593                }
 594                acp::PlanEntryStatus::Completed => {
 595                    stats.completed += 1;
 596                }
 597            }
 598        }
 599
 600        stats
 601    }
 602}
 603
 604#[derive(Debug)]
 605pub struct PlanEntry {
 606    pub content: Entity<Markdown>,
 607    pub priority: acp::PlanEntryPriority,
 608    pub status: acp::PlanEntryStatus,
 609}
 610
 611impl PlanEntry {
 612    pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
 613        Self {
 614            content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
 615            priority: entry.priority,
 616            status: entry.status,
 617        }
 618    }
 619}
 620
/// A conversation with an ACP agent: the transcript, the agent's plan, and the
/// project and session it is bound to.
pub struct AcpThread {
    /// Human-readable title.
    title: SharedString,
    /// Transcript entries, oldest first (new entries are pushed to the end).
    entries: Vec<AgentThreadEntry>,
    /// The agent's plan, fed by ACP `Plan` session updates.
    plan: Plan,
    /// The project the thread operates in.
    project: Entity<Project>,
    /// Action log created for `project` when the thread is constructed.
    action_log: Entity<ActionLog>,
    // NOTE(review): not used in the visible part of this file; presumably tracks buffers
    // together with the snapshot last shared with the agent — confirm.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// While `Some`, `status()` reports the thread as generating (or waiting for tool
    /// confirmation); `None` means idle.
    send_task: Option<Task<()>>,
    /// Connection used to talk to the agent.
    connection: Rc<dyn AgentConnection>,
    /// ACP session this thread belongs to.
    session_id: acp::SessionId,
}

/// Events emitted by [`AcpThread`].
pub enum AcpThreadEvent {
    /// An entry was appended to the transcript.
    NewEntry,
    /// The entry at this index was modified in place.
    EntryUpdated(usize),
    /// Emitter not visible here; presumably a tool call needs the user's permission.
    ToolAuthorizationRequired,
    /// Emitter not visible here; presumably generation stopped.
    Stopped,
    /// Emitter not visible here; presumably the turn ended with an error.
    Error,
    /// Emitter not visible here; presumably the agent server exited with this status.
    ServerExited(ExitStatus),
}

impl EventEmitter<AcpThreadEvent> for AcpThread {}

/// Coarse thread state reported by `AcpThread::status`.
#[derive(PartialEq, Eq)]
pub enum ThreadStatus {
    /// No send task is running.
    Idle,
    /// A send task is running and a tool call is awaiting confirmation.
    WaitingForToolConfirmation,
    /// A send task is running and nothing is awaiting confirmation.
    Generating,
}
 650
 651#[derive(Debug, Clone)]
 652pub enum LoadError {
 653    Unsupported {
 654        error_message: SharedString,
 655        upgrade_message: SharedString,
 656        upgrade_command: String,
 657    },
 658    Exited(i32),
 659    Other(SharedString),
 660}
 661
 662impl Display for LoadError {
 663    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 664        match self {
 665            LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
 666            LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
 667            LoadError::Other(msg) => write!(f, "{}", msg),
 668        }
 669    }
 670}
 671
 672impl Error for LoadError {}
 673
 674impl AcpThread {
 675    pub fn new(
 676        title: impl Into<SharedString>,
 677        connection: Rc<dyn AgentConnection>,
 678        project: Entity<Project>,
 679        session_id: acp::SessionId,
 680        cx: &mut Context<Self>,
 681    ) -> Self {
 682        let action_log = cx.new(|_| ActionLog::new(project.clone()));
 683
 684        Self {
 685            action_log,
 686            shared_buffers: Default::default(),
 687            entries: Default::default(),
 688            plan: Default::default(),
 689            title: title.into(),
 690            project,
 691            send_task: None,
 692            connection,
 693            session_id,
 694        }
 695    }
 696
 697    pub fn connection(&self) -> &Rc<dyn AgentConnection> {
 698        &self.connection
 699    }
 700
 701    pub fn action_log(&self) -> &Entity<ActionLog> {
 702        &self.action_log
 703    }
 704
 705    pub fn project(&self) -> &Entity<Project> {
 706        &self.project
 707    }
 708
 709    pub fn title(&self) -> SharedString {
 710        self.title.clone()
 711    }
 712
 713    pub fn entries(&self) -> &[AgentThreadEntry] {
 714        &self.entries
 715    }
 716
 717    pub fn session_id(&self) -> &acp::SessionId {
 718        &self.session_id
 719    }
 720
 721    pub fn status(&self) -> ThreadStatus {
 722        if self.send_task.is_some() {
 723            if self.waiting_for_tool_confirmation() {
 724                ThreadStatus::WaitingForToolConfirmation
 725            } else {
 726                ThreadStatus::Generating
 727            }
 728        } else {
 729            ThreadStatus::Idle
 730        }
 731    }
 732
 733    pub fn has_pending_edit_tool_calls(&self) -> bool {
 734        for entry in self.entries.iter().rev() {
 735            match entry {
 736                AgentThreadEntry::UserMessage(_) => return false,
 737                AgentThreadEntry::ToolCall(
 738                    call @ ToolCall {
 739                        status:
 740                            ToolCallStatus::Allowed {
 741                                status:
 742                                    acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending,
 743                            },
 744                        ..
 745                    },
 746                ) if call.diffs().next().is_some() => {
 747                    return true;
 748                }
 749                AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
 750            }
 751        }
 752
 753        false
 754    }
 755
 756    pub fn used_tools_since_last_user_message(&self) -> bool {
 757        for entry in self.entries.iter().rev() {
 758            match entry {
 759                AgentThreadEntry::UserMessage(..) => return false,
 760                AgentThreadEntry::AssistantMessage(..) => continue,
 761                AgentThreadEntry::ToolCall(..) => return true,
 762            }
 763        }
 764
 765        false
 766    }
 767
 768    pub fn handle_session_update(
 769        &mut self,
 770        update: acp::SessionUpdate,
 771        cx: &mut Context<Self>,
 772    ) -> Result<()> {
 773        match update {
 774            acp::SessionUpdate::UserMessageChunk { content } => {
 775                self.push_user_content_block(content, cx);
 776            }
 777            acp::SessionUpdate::AgentMessageChunk { content } => {
 778                self.push_assistant_content_block(content, false, cx);
 779            }
 780            acp::SessionUpdate::AgentThoughtChunk { content } => {
 781                self.push_assistant_content_block(content, true, cx);
 782            }
 783            acp::SessionUpdate::ToolCall(tool_call) => {
 784                self.upsert_tool_call(tool_call, cx);
 785            }
 786            acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
 787                self.update_tool_call(tool_call_update, cx)?;
 788            }
 789            acp::SessionUpdate::Plan(plan) => {
 790                self.update_plan(plan, cx);
 791            }
 792        }
 793        Ok(())
 794    }
 795
 796    pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
 797        let language_registry = self.project.read(cx).languages().clone();
 798        let entries_len = self.entries.len();
 799
 800        if let Some(last_entry) = self.entries.last_mut()
 801            && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
 802        {
 803            content.append(chunk, &language_registry, cx);
 804            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
 805        } else {
 806            let content = ContentBlock::new(chunk, &language_registry, cx);
 807            self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
 808        }
 809    }
 810
 811    pub fn push_assistant_content_block(
 812        &mut self,
 813        chunk: acp::ContentBlock,
 814        is_thought: bool,
 815        cx: &mut Context<Self>,
 816    ) {
 817        let language_registry = self.project.read(cx).languages().clone();
 818        let entries_len = self.entries.len();
 819        if let Some(last_entry) = self.entries.last_mut()
 820            && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
 821        {
 822            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
 823            match (chunks.last_mut(), is_thought) {
 824                (Some(AssistantMessageChunk::Message { block }), false)
 825                | (Some(AssistantMessageChunk::Thought { block }), true) => {
 826                    block.append(chunk, &language_registry, cx)
 827                }
 828                _ => {
 829                    let block = ContentBlock::new(chunk, &language_registry, cx);
 830                    if is_thought {
 831                        chunks.push(AssistantMessageChunk::Thought { block })
 832                    } else {
 833                        chunks.push(AssistantMessageChunk::Message { block })
 834                    }
 835                }
 836            }
 837        } else {
 838            let block = ContentBlock::new(chunk, &language_registry, cx);
 839            let chunk = if is_thought {
 840                AssistantMessageChunk::Thought { block }
 841            } else {
 842                AssistantMessageChunk::Message { block }
 843            };
 844
 845            self.push_entry(
 846                AgentThreadEntry::AssistantMessage(AssistantMessage {
 847                    chunks: vec![chunk],
 848                }),
 849                cx,
 850            );
 851        }
 852    }
 853
 854    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
 855        self.entries.push(entry);
 856        cx.emit(AcpThreadEvent::NewEntry);
 857    }
 858
 859    pub fn update_tool_call(
 860        &mut self,
 861        update: impl Into<ToolCallUpdate>,
 862        cx: &mut Context<Self>,
 863    ) -> Result<()> {
 864        let update = update.into();
 865        let languages = self.project.read(cx).languages().clone();
 866
 867        let (ix, current_call) = self
 868            .tool_call_mut(update.id())
 869            .context("Tool call not found")?;
 870        match update {
 871            ToolCallUpdate::UpdateFields(update) => {
 872                let location_updated = update.fields.locations.is_some();
 873                current_call.update_fields(update.fields, languages, cx);
 874                if location_updated {
 875                    self.resolve_locations(update.id.clone(), cx);
 876                }
 877            }
 878            ToolCallUpdate::UpdateDiff(update) => {
 879                current_call.content.clear();
 880                current_call
 881                    .content
 882                    .push(ToolCallContent::Diff(update.diff));
 883            }
 884            ToolCallUpdate::UpdateTerminal(update) => {
 885                current_call.content.clear();
 886                current_call
 887                    .content
 888                    .push(ToolCallContent::Terminal(update.terminal));
 889            }
 890        }
 891
 892        cx.emit(AcpThreadEvent::EntryUpdated(ix));
 893
 894        Ok(())
 895    }
 896
 897    /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
 898    pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
 899        let status = ToolCallStatus::Allowed {
 900            status: tool_call.status,
 901        };
 902        self.upsert_tool_call_inner(tool_call, status, cx)
 903    }
 904
 905    pub fn upsert_tool_call_inner(
 906        &mut self,
 907        tool_call: acp::ToolCall,
 908        status: ToolCallStatus,
 909        cx: &mut Context<Self>,
 910    ) {
 911        let language_registry = self.project.read(cx).languages().clone();
 912        let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
 913        let id = call.id.clone();
 914
 915        if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
 916            *current_call = call;
 917
 918            cx.emit(AcpThreadEvent::EntryUpdated(ix));
 919        } else {
 920            self.push_entry(AgentThreadEntry::ToolCall(call), cx);
 921        };
 922
 923        self.resolve_locations(id, cx);
 924    }
 925
 926    fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
 927        // The tool call we are looking for is typically the last one, or very close to the end.
 928        // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
 929        self.entries
 930            .iter_mut()
 931            .enumerate()
 932            .rev()
 933            .find_map(|(index, tool_call)| {
 934                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
 935                    && &tool_call.id == id
 936                {
 937                    Some((index, tool_call))
 938                } else {
 939                    None
 940                }
 941            })
 942    }
 943
 944    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
 945        let project = self.project.clone();
 946        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
 947            return;
 948        };
 949        let task = tool_call.resolve_locations(project, cx);
 950        cx.spawn(async move |this, cx| {
 951            let resolved_locations = task.await;
 952            this.update(cx, |this, cx| {
 953                let project = this.project.clone();
 954                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
 955                    return;
 956                };
 957                if let Some(Some(location)) = resolved_locations.last() {
 958                    project.update(cx, |project, cx| {
 959                        if let Some(agent_location) = project.agent_location() {
 960                            let should_ignore = agent_location.buffer == location.buffer
 961                                && location
 962                                    .buffer
 963                                    .update(cx, |buffer, _| {
 964                                        let snapshot = buffer.snapshot();
 965                                        let old_position =
 966                                            agent_location.position.to_point(&snapshot);
 967                                        let new_position = location.position.to_point(&snapshot);
 968                                        // ignore this so that when we get updates from the edit tool
 969                                        // the position doesn't reset to the startof line
 970                                        old_position.row == new_position.row
 971                                            && old_position.column > new_position.column
 972                                    })
 973                                    .ok()
 974                                    .unwrap_or_default();
 975                            if !should_ignore {
 976                                project.set_agent_location(Some(location.clone()), cx);
 977                            }
 978                        }
 979                    });
 980                }
 981                if tool_call.resolved_locations != resolved_locations {
 982                    tool_call.resolved_locations = resolved_locations;
 983                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
 984                }
 985            })
 986        })
 987        .detach();
 988    }
 989
 990    pub fn request_tool_call_authorization(
 991        &mut self,
 992        tool_call: acp::ToolCall,
 993        options: Vec<acp::PermissionOption>,
 994        cx: &mut Context<Self>,
 995    ) -> oneshot::Receiver<acp::PermissionOptionId> {
 996        let (tx, rx) = oneshot::channel();
 997
 998        let status = ToolCallStatus::WaitingForConfirmation {
 999            options,
1000            respond_tx: tx,
1001        };
1002
1003        self.upsert_tool_call_inner(tool_call, status, cx);
1004        cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1005        rx
1006    }
1007
1008    pub fn authorize_tool_call(
1009        &mut self,
1010        id: acp::ToolCallId,
1011        option_id: acp::PermissionOptionId,
1012        option_kind: acp::PermissionOptionKind,
1013        cx: &mut Context<Self>,
1014    ) {
1015        let Some((ix, call)) = self.tool_call_mut(&id) else {
1016            return;
1017        };
1018
1019        let new_status = match option_kind {
1020            acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1021                ToolCallStatus::Rejected
1022            }
1023            acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1024                ToolCallStatus::Allowed {
1025                    status: acp::ToolCallStatus::InProgress,
1026                }
1027            }
1028        };
1029
1030        let curr_status = mem::replace(&mut call.status, new_status);
1031
1032        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1033            respond_tx.send(option_id).log_err();
1034        } else if cfg!(debug_assertions) {
1035            panic!("tried to authorize an already authorized tool call");
1036        }
1037
1038        cx.emit(AcpThreadEvent::EntryUpdated(ix));
1039    }
1040
1041    /// Returns true if the last turn is awaiting tool authorization
1042    pub fn waiting_for_tool_confirmation(&self) -> bool {
1043        for entry in self.entries.iter().rev() {
1044            match &entry {
1045                AgentThreadEntry::ToolCall(call) => match call.status {
1046                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
1047                    ToolCallStatus::Allowed { .. }
1048                    | ToolCallStatus::Rejected
1049                    | ToolCallStatus::Canceled => continue,
1050                },
1051                AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1052                    // Reached the beginning of the turn
1053                    return false;
1054                }
1055            }
1056        }
1057        false
1058    }
1059
    /// Returns the thread's current plan.
    pub fn plan(&self) -> &Plan {
        &self.plan
    }
1063
1064    pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1065        let new_entries_len = request.entries.len();
1066        let mut new_entries = request.entries.into_iter();
1067
1068        // Reuse existing markdown to prevent flickering
1069        for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1070            let PlanEntry {
1071                content,
1072                priority,
1073                status,
1074            } = old;
1075            content.update(cx, |old, cx| {
1076                old.replace(new.content, cx);
1077            });
1078            *priority = new.priority;
1079            *status = new.status;
1080        }
1081        for new in new_entries {
1082            self.plan.entries.push(PlanEntry::from_acp(new, cx))
1083        }
1084        self.plan.entries.truncate(new_entries_len);
1085
1086        cx.notify();
1087    }
1088
1089    fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1090        self.plan
1091            .entries
1092            .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1093        cx.notify();
1094    }
1095
1096    #[cfg(any(test, feature = "test-support"))]
1097    pub fn send_raw(
1098        &mut self,
1099        message: &str,
1100        cx: &mut Context<Self>,
1101    ) -> BoxFuture<'static, Result<()>> {
1102        self.send(
1103            vec![acp::ContentBlock::Text(acp::TextContent {
1104                text: message.to_string(),
1105                annotations: None,
1106            })],
1107            cx,
1108        )
1109    }
1110
    /// Sends `message` to the agent as a new user turn.
    ///
    /// The message is first appended to the thread as a user entry, and completed plan entries
    /// are cleared. Any in-flight prompt is canceled (see [`Self::cancel`]). The new prompt is
    /// only issued after that cancellation task completes.
    ///
    /// The returned future resolves once the prompt has been answered:
    /// - If the agent returns an error, it emits [`AcpThreadEvent::Error`] and returns the error.
    /// - Otherwise it emits [`AcpThreadEvent::Stopped`] and returns `Ok(())`.
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            cx,
        );
        self.push_entry(
            AgentThreadEntry::UserMessage(UserMessage { content: block }),
            cx,
        );
        self.clear_completed_plan_entries(cx);

        // `tx` hands the prompt's outcome from `send_task` to the future returned below.
        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            async {
                // Let the previous prompt (if any) wind down before issuing this one.
                cancel_task.await;

                let result = this
                    .update(cx, |this, cx| {
                        this.connection.prompt(
                            acp::PromptRequest {
                                prompt: message,
                                session_id: this.session_id.clone(),
                            },
                            cx,
                        )
                    })?
                    .await;

                // If the returned future was dropped, sending fails and is only logged.
                tx.send(result).log_err();

                anyhow::Ok(())
            }
            .await
            .log_err();
        }));

        cx.spawn(async move |this, cx| match rx.await {
            Ok(Err(e)) => {
                // NOTE(review): unlike the arm below, this takes `send_task` unconditionally,
                // even if a newer prompt has already replaced it. Confirm that's intended.
                this.update(cx, |this, cx| {
                    this.send_task.take();
                    cx.emit(AcpThreadEvent::Error)
                })
                .log_err();
                Err(e)?
            }
            // Also reached when `tx` was dropped without sending (`Err(Canceled)`).
            result => {
                let cancelled = matches!(
                    result,
                    Ok(Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::Cancelled
                    }))
                );

                // We only take the task if the current prompt wasn't cancelled.
                //
                // This prompt may have been cancelled because another one was sent
                // while it was still generating. In these cases, dropping `send_task`
                // would cause the next generation to be cancelled.
                if !cancelled {
                    this.update(cx, |this, _cx| this.send_task.take()).ok();
                }

                this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
                    .log_err();
                Ok(())
            }
        })
        .boxed()
    }
1187
1188    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1189        let Some(send_task) = self.send_task.take() else {
1190            return Task::ready(());
1191        };
1192
1193        for entry in self.entries.iter_mut() {
1194            if let AgentThreadEntry::ToolCall(call) = entry {
1195                let cancel = matches!(
1196                    call.status,
1197                    ToolCallStatus::WaitingForConfirmation { .. }
1198                        | ToolCallStatus::Allowed {
1199                            status: acp::ToolCallStatus::InProgress
1200                        }
1201                );
1202
1203                if cancel {
1204                    call.status = ToolCallStatus::Canceled;
1205                }
1206            }
1207        }
1208
1209        self.connection.cancel(&self.session_id, cx);
1210
1211        // Wait for the send task to complete
1212        cx.foreground_executor().spawn(send_task)
1213    }
1214
1215    pub fn read_text_file(
1216        &self,
1217        path: PathBuf,
1218        line: Option<u32>,
1219        limit: Option<u32>,
1220        reuse_shared_snapshot: bool,
1221        cx: &mut Context<Self>,
1222    ) -> Task<Result<String>> {
1223        let project = self.project.clone();
1224        let action_log = self.action_log.clone();
1225        cx.spawn(async move |this, cx| {
1226            let load = project.update(cx, |project, cx| {
1227                let path = project
1228                    .project_path_for_absolute_path(&path, cx)
1229                    .context("invalid path")?;
1230                anyhow::Ok(project.open_buffer(path, cx))
1231            });
1232            let buffer = load??.await?;
1233
1234            let snapshot = if reuse_shared_snapshot {
1235                this.read_with(cx, |this, _| {
1236                    this.shared_buffers.get(&buffer.clone()).cloned()
1237                })
1238                .log_err()
1239                .flatten()
1240            } else {
1241                None
1242            };
1243
1244            let snapshot = if let Some(snapshot) = snapshot {
1245                snapshot
1246            } else {
1247                action_log.update(cx, |action_log, cx| {
1248                    action_log.buffer_read(buffer.clone(), cx);
1249                })?;
1250                project.update(cx, |project, cx| {
1251                    let position = buffer
1252                        .read(cx)
1253                        .snapshot()
1254                        .anchor_before(Point::new(line.unwrap_or_default(), 0));
1255                    project.set_agent_location(
1256                        Some(AgentLocation {
1257                            buffer: buffer.downgrade(),
1258                            position,
1259                        }),
1260                        cx,
1261                    );
1262                })?;
1263
1264                buffer.update(cx, |buffer, _| buffer.snapshot())?
1265            };
1266
1267            this.update(cx, |this, _| {
1268                let text = snapshot.text();
1269                this.shared_buffers.insert(buffer.clone(), snapshot);
1270                if line.is_none() && limit.is_none() {
1271                    return Ok(text);
1272                }
1273                let limit = limit.unwrap_or(u32::MAX) as usize;
1274                let Some(line) = line else {
1275                    return Ok(text.lines().take(limit).collect::<String>());
1276                };
1277
1278                let count = text.lines().count();
1279                if count < line as usize {
1280                    anyhow::bail!("There are only {} lines", count);
1281                }
1282                Ok(text
1283                    .lines()
1284                    .skip(line as usize + 1)
1285                    .take(limit)
1286                    .collect::<String>())
1287            })?
1288        })
1289    }
1290
1291    pub fn write_text_file(
1292        &self,
1293        path: PathBuf,
1294        content: String,
1295        cx: &mut Context<Self>,
1296    ) -> Task<Result<()>> {
1297        let project = self.project.clone();
1298        let action_log = self.action_log.clone();
1299        cx.spawn(async move |this, cx| {
1300            let load = project.update(cx, |project, cx| {
1301                let path = project
1302                    .project_path_for_absolute_path(&path, cx)
1303                    .context("invalid path")?;
1304                anyhow::Ok(project.open_buffer(path, cx))
1305            });
1306            let buffer = load??.await?;
1307            let snapshot = this.update(cx, |this, cx| {
1308                this.shared_buffers
1309                    .get(&buffer)
1310                    .cloned()
1311                    .unwrap_or_else(|| buffer.read(cx).snapshot())
1312            })?;
1313            let edits = cx
1314                .background_executor()
1315                .spawn(async move {
1316                    let old_text = snapshot.text();
1317                    text_diff(old_text.as_str(), &content)
1318                        .into_iter()
1319                        .map(|(range, replacement)| {
1320                            (
1321                                snapshot.anchor_after(range.start)
1322                                    ..snapshot.anchor_before(range.end),
1323                                replacement,
1324                            )
1325                        })
1326                        .collect::<Vec<_>>()
1327                })
1328                .await;
1329            cx.update(|cx| {
1330                project.update(cx, |project, cx| {
1331                    project.set_agent_location(
1332                        Some(AgentLocation {
1333                            buffer: buffer.downgrade(),
1334                            position: edits
1335                                .last()
1336                                .map(|(range, _)| range.end)
1337                                .unwrap_or(Anchor::MIN),
1338                        }),
1339                        cx,
1340                    );
1341                });
1342
1343                action_log.update(cx, |action_log, cx| {
1344                    action_log.buffer_read(buffer.clone(), cx);
1345                });
1346                buffer.update(cx, |buffer, cx| {
1347                    buffer.edit(edits, None, cx);
1348                });
1349                action_log.update(cx, |action_log, cx| {
1350                    action_log.buffer_edited(buffer.clone(), cx);
1351                });
1352            })?;
1353            project
1354                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1355                .await
1356        })
1357    }
1358
1359    pub fn to_markdown(&self, cx: &App) -> String {
1360        self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1361    }
1362
1363    pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
1364        cx.emit(AcpThreadEvent::ServerExited(status));
1365    }
1366}
1367
1368fn markdown_for_raw_output(
1369    raw_output: &serde_json::Value,
1370    language_registry: &Arc<LanguageRegistry>,
1371    cx: &mut App,
1372) -> Option<Entity<Markdown>> {
1373    match raw_output {
1374        serde_json::Value::Null => None,
1375        serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1376            Markdown::new(
1377                value.to_string().into(),
1378                Some(language_registry.clone()),
1379                None,
1380                cx,
1381            )
1382        })),
1383        serde_json::Value::Number(value) => Some(cx.new(|cx| {
1384            Markdown::new(
1385                value.to_string().into(),
1386                Some(language_registry.clone()),
1387                None,
1388                cx,
1389            )
1390        })),
1391        serde_json::Value::String(value) => Some(cx.new(|cx| {
1392            Markdown::new(
1393                value.clone().into(),
1394                Some(language_registry.clone()),
1395                None,
1396                cx,
1397            )
1398        })),
1399        value => Some(cx.new(|cx| {
1400            Markdown::new(
1401                format!("```json\n{}\n```", value).into(),
1402                Some(language_registry.clone()),
1403                None,
1404                cx,
1405            )
1406        })),
1407    }
1408}
1409
1410#[cfg(test)]
1411mod tests {
1412    use super::*;
1413    use anyhow::anyhow;
1414    use futures::{channel::mpsc, future::LocalBoxFuture, select};
1415    use gpui::{AsyncApp, TestAppContext, WeakEntity};
1416    use indoc::indoc;
1417    use project::FakeFs;
1418    use rand::Rng as _;
1419    use serde_json::json;
1420    use settings::SettingsStore;
1421    use smol::stream::StreamExt as _;
1422    use std::{cell::RefCell, path::Path, rc::Rc, time::Duration};
1423
1424    use util::path;
1425
1426    fn init_test(cx: &mut TestAppContext) {
1427        env_logger::try_init().ok();
1428        cx.update(|cx| {
1429            let settings_store = SettingsStore::test(cx);
1430            cx.set_global(settings_store);
1431            Project::init_settings(cx);
1432            language::init(cx);
1433        });
1434    }
1435
    /// Checks how pushed user content blocks are grouped into entries:
    /// - consecutive user blocks are appended to the same user message;
    /// - a user block pushed after an assistant message starts a new user entry.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Hello, ".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "world!".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Assistant response".to_string(),
                }),
                false,
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "New user message".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }
1523
    /// Two consecutive `AgentThoughtChunk` updates are merged into a single `<thinking>`
    /// section of the assistant message in the markdown transcript.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "Thinking ".into(),
                                },
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "hard!".into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
1590
    /// The agent reads `/tmp/foo`. The user then inserts a line at the top of the buffer, and
    /// the agent writes back an extended version of what it read. Both edits must survive,
    /// in memory and on disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test body that the agent's read has completed, so the user edit lands
        // between the agent's read and its write.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
1669
    /// Covers a tool call that is canceled and later completed:
    /// - the in-progress call is marked `Canceled` when the turn is canceled;
    /// - a later `ToolCallUpdate` from the agent can still move it to `Allowed { Completed }`.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::InProgress,
                        ..
                    },
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::Completed,
                        ..
                    },
                    ..
                })
            ));
        });
    }
1779
1780    #[gpui::test]
1781    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
1782        init_test(cx);
1783        let fs = FakeFs::new(cx.background_executor.clone());
1784        fs.insert_tree(path!("/test"), json!({})).await;
1785        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
1786
1787        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
1788            move |_, thread, mut cx| {
1789                async move {
1790                    thread
1791                        .update(&mut cx, |thread, cx| {
1792                            thread.handle_session_update(
1793                                acp::SessionUpdate::ToolCall(acp::ToolCall {
1794                                    id: acp::ToolCallId("test".into()),
1795                                    title: "Label".into(),
1796                                    kind: acp::ToolKind::Edit,
1797                                    status: acp::ToolCallStatus::Completed,
1798                                    content: vec![acp::ToolCallContent::Diff {
1799                                        diff: acp::Diff {
1800                                            path: "/test/test.txt".into(),
1801                                            old_text: None,
1802                                            new_text: "foo".into(),
1803                                        },
1804                                    }],
1805                                    locations: vec![],
1806                                    raw_input: None,
1807                                    raw_output: None,
1808                                }),
1809                                cx,
1810                            )
1811                        })
1812                        .unwrap()
1813                        .unwrap();
1814                    Ok(acp::PromptResponse {
1815                        stop_reason: acp::StopReason::EndTurn,
1816                    })
1817                }
1818                .boxed_local()
1819            }
1820        }));
1821
1822        let thread = connection
1823            .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
1824            .await
1825            .unwrap();
1826        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
1827            .await
1828            .unwrap();
1829
1830        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
1831    }
1832
1833    async fn run_until_first_tool_call(
1834        thread: &Entity<AcpThread>,
1835        cx: &mut TestAppContext,
1836    ) -> usize {
1837        let (mut tx, mut rx) = mpsc::channel::<usize>(1);
1838
1839        let subscription = cx.update(|cx| {
1840            cx.subscribe(thread, move |thread, _, cx| {
1841                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
1842                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
1843                        return tx.try_send(ix).unwrap();
1844                    }
1845                }
1846            })
1847        });
1848
1849        select! {
1850            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
1851                panic!("Timeout waiting for tool call")
1852            }
1853            ix = rx.next().fuse() => {
1854                drop(subscription);
1855                ix.unwrap()
1856            }
1857        }
1858    }
1859
1860    #[derive(Clone, Default)]
1861    struct FakeAgentConnection {
1862        auth_methods: Vec<acp::AuthMethod>,
1863        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
1864        on_user_message: Option<
1865            Rc<
1866                dyn Fn(
1867                        acp::PromptRequest,
1868                        WeakEntity<AcpThread>,
1869                        AsyncApp,
1870                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1871                    + 'static,
1872            >,
1873        >,
1874    }
1875
1876    impl FakeAgentConnection {
1877        fn new() -> Self {
1878            Self {
1879                auth_methods: Vec::new(),
1880                on_user_message: None,
1881                sessions: Arc::default(),
1882            }
1883        }
1884
1885        #[expect(unused)]
1886        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
1887            self.auth_methods = auth_methods;
1888            self
1889        }
1890
1891        fn on_user_message(
1892            mut self,
1893            handler: impl Fn(
1894                acp::PromptRequest,
1895                WeakEntity<AcpThread>,
1896                AsyncApp,
1897            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1898            + 'static,
1899        ) -> Self {
1900            self.on_user_message.replace(Rc::new(handler));
1901            self
1902        }
1903    }
1904
1905    impl AgentConnection for FakeAgentConnection {
1906        fn auth_methods(&self) -> &[acp::AuthMethod] {
1907            &self.auth_methods
1908        }
1909
1910        fn new_thread(
1911            self: Rc<Self>,
1912            project: Entity<Project>,
1913            _cwd: &Path,
1914            cx: &mut gpui::AsyncApp,
1915        ) -> Task<gpui::Result<Entity<AcpThread>>> {
1916            let session_id = acp::SessionId(
1917                rand::thread_rng()
1918                    .sample_iter(&rand::distributions::Alphanumeric)
1919                    .take(7)
1920                    .map(char::from)
1921                    .collect::<String>()
1922                    .into(),
1923            );
1924            let thread = cx
1925                .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
1926                .unwrap();
1927            self.sessions.lock().insert(session_id, thread.downgrade());
1928            Task::ready(Ok(thread))
1929        }
1930
1931        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
1932            if self.auth_methods().iter().any(|m| m.id == method) {
1933                Task::ready(Ok(()))
1934            } else {
1935                Task::ready(Err(anyhow!("Invalid Auth Method")))
1936            }
1937        }
1938
1939        fn prompt(
1940            &self,
1941            params: acp::PromptRequest,
1942            cx: &mut App,
1943        ) -> Task<gpui::Result<acp::PromptResponse>> {
1944            let sessions = self.sessions.lock();
1945            let thread = sessions.get(&params.session_id).unwrap();
1946            if let Some(handler) = &self.on_user_message {
1947                let handler = handler.clone();
1948                let thread = thread.clone();
1949                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
1950            } else {
1951                Task::ready(Ok(acp::PromptResponse {
1952                    stop_reason: acp::StopReason::EndTurn,
1953                }))
1954            }
1955        }
1956
1957        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1958            let sessions = self.sessions.lock();
1959            let thread = sessions.get(&session_id).unwrap().clone();
1960
1961            cx.spawn(async move |cx| {
1962                thread
1963                    .update(cx, |thread, cx| thread.cancel(cx))
1964                    .unwrap()
1965                    .await
1966            })
1967            .detach();
1968        }
1969    }
1970}