acp_thread.rs

   1mod connection;
   2mod diff;
   3
   4pub use connection::*;
   5pub use diff::*;
   6
   7use agent_client_protocol as acp;
   8use anyhow::{Context as _, Result};
   9use assistant_tool::ActionLog;
  10use editor::Bias;
  11use futures::{FutureExt, channel::oneshot, future::BoxFuture};
  12use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
  13use itertools::Itertools;
  14use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, text_diff};
  15use markdown::Markdown;
  16use project::{AgentLocation, Project};
  17use std::collections::HashMap;
  18use std::error::Error;
  19use std::fmt::Formatter;
  20use std::process::ExitStatus;
  21use std::rc::Rc;
  22use std::{
  23    fmt::Display,
  24    mem,
  25    path::{Path, PathBuf},
  26    sync::Arc,
  27};
  28use ui::App;
  29use util::ResultExt;
  30
  31#[derive(Debug)]
  32pub struct UserMessage {
  33    pub content: ContentBlock,
  34}
  35
  36impl UserMessage {
  37    pub fn from_acp(
  38        message: impl IntoIterator<Item = acp::ContentBlock>,
  39        language_registry: Arc<LanguageRegistry>,
  40        cx: &mut App,
  41    ) -> Self {
  42        let mut content = ContentBlock::Empty;
  43        for chunk in message {
  44            content.append(chunk, &language_registry, cx)
  45        }
  46        Self { content: content }
  47    }
  48
  49    fn to_markdown(&self, cx: &App) -> String {
  50        format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
  51    }
  52}
  53
  54#[derive(Debug)]
  55pub struct MentionPath<'a>(&'a Path);
  56
  57impl<'a> MentionPath<'a> {
  58    const PREFIX: &'static str = "@file:";
  59
  60    pub fn new(path: &'a Path) -> Self {
  61        MentionPath(path)
  62    }
  63
  64    pub fn try_parse(url: &'a str) -> Option<Self> {
  65        let path = url.strip_prefix(Self::PREFIX)?;
  66        Some(MentionPath(Path::new(path)))
  67    }
  68
  69    pub fn path(&self) -> &Path {
  70        self.0
  71    }
  72}
  73
  74impl Display for MentionPath<'_> {
  75    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  76        write!(
  77            f,
  78            "[@{}]({}{})",
  79            self.0.file_name().unwrap_or_default().display(),
  80            Self::PREFIX,
  81            self.0.display()
  82        )
  83    }
  84}
  85
  86#[derive(Debug, PartialEq)]
  87pub struct AssistantMessage {
  88    pub chunks: Vec<AssistantMessageChunk>,
  89}
  90
  91impl AssistantMessage {
  92    pub fn to_markdown(&self, cx: &App) -> String {
  93        format!(
  94            "## Assistant\n\n{}\n\n",
  95            self.chunks
  96                .iter()
  97                .map(|chunk| chunk.to_markdown(cx))
  98                .join("\n\n")
  99        )
 100    }
 101}
 102
 103#[derive(Debug, PartialEq)]
 104pub enum AssistantMessageChunk {
 105    Message { block: ContentBlock },
 106    Thought { block: ContentBlock },
 107}
 108
 109impl AssistantMessageChunk {
 110    pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
 111        Self::Message {
 112            block: ContentBlock::new(chunk.into(), language_registry, cx),
 113        }
 114    }
 115
 116    fn to_markdown(&self, cx: &App) -> String {
 117        match self {
 118            Self::Message { block } => block.to_markdown(cx).to_string(),
 119            Self::Thought { block } => {
 120                format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
 121            }
 122        }
 123    }
 124}
 125
 126#[derive(Debug)]
 127pub enum AgentThreadEntry {
 128    UserMessage(UserMessage),
 129    AssistantMessage(AssistantMessage),
 130    ToolCall(ToolCall),
 131}
 132
 133impl AgentThreadEntry {
 134    fn to_markdown(&self, cx: &App) -> String {
 135        match self {
 136            Self::UserMessage(message) => message.to_markdown(cx),
 137            Self::AssistantMessage(message) => message.to_markdown(cx),
 138            Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
 139        }
 140    }
 141
 142    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 143        if let AgentThreadEntry::ToolCall(call) = self {
 144            itertools::Either::Left(call.diffs())
 145        } else {
 146            itertools::Either::Right(std::iter::empty())
 147        }
 148    }
 149
 150    pub fn locations(&self) -> Option<&[acp::ToolCallLocation]> {
 151        if let AgentThreadEntry::ToolCall(ToolCall { locations, .. }) = self {
 152            Some(locations)
 153        } else {
 154            None
 155        }
 156    }
 157}
 158
 159#[derive(Debug)]
 160pub struct ToolCall {
 161    pub id: acp::ToolCallId,
 162    pub label: Entity<Markdown>,
 163    pub kind: acp::ToolKind,
 164    pub content: Vec<ToolCallContent>,
 165    pub status: ToolCallStatus,
 166    pub locations: Vec<acp::ToolCallLocation>,
 167    pub raw_input: Option<serde_json::Value>,
 168    pub raw_output: Option<serde_json::Value>,
 169}
 170
 171impl ToolCall {
 172    fn from_acp(
 173        tool_call: acp::ToolCall,
 174        status: ToolCallStatus,
 175        language_registry: Arc<LanguageRegistry>,
 176        cx: &mut App,
 177    ) -> Self {
 178        Self {
 179            id: tool_call.id,
 180            label: cx.new(|cx| {
 181                Markdown::new(
 182                    tool_call.title.into(),
 183                    Some(language_registry.clone()),
 184                    None,
 185                    cx,
 186                )
 187            }),
 188            kind: tool_call.kind,
 189            content: tool_call
 190                .content
 191                .into_iter()
 192                .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
 193                .collect(),
 194            locations: tool_call.locations,
 195            status,
 196            raw_input: tool_call.raw_input,
 197            raw_output: tool_call.raw_output,
 198        }
 199    }
 200
 201    fn update(
 202        &mut self,
 203        fields: acp::ToolCallUpdateFields,
 204        language_registry: Arc<LanguageRegistry>,
 205        cx: &mut App,
 206    ) {
 207        let acp::ToolCallUpdateFields {
 208            kind,
 209            status,
 210            title,
 211            content,
 212            locations,
 213            raw_input,
 214            raw_output,
 215        } = fields;
 216
 217        if let Some(kind) = kind {
 218            self.kind = kind;
 219        }
 220
 221        if let Some(status) = status {
 222            self.status = ToolCallStatus::Allowed { status };
 223        }
 224
 225        if let Some(title) = title {
 226            self.label.update(cx, |label, cx| {
 227                label.replace(title, cx);
 228            });
 229        }
 230
 231        if let Some(content) = content {
 232            self.content = content
 233                .into_iter()
 234                .map(|chunk| ToolCallContent::from_acp(chunk, language_registry.clone(), cx))
 235                .collect();
 236        }
 237
 238        if let Some(locations) = locations {
 239            self.locations = locations;
 240        }
 241
 242        if let Some(raw_input) = raw_input {
 243            self.raw_input = Some(raw_input);
 244        }
 245
 246        if let Some(raw_output) = raw_output {
 247            self.raw_output = Some(raw_output);
 248        }
 249    }
 250
 251    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 252        self.content.iter().filter_map(|content| match content {
 253            ToolCallContent::ContentBlock { .. } => None,
 254            ToolCallContent::Diff { diff } => Some(diff),
 255        })
 256    }
 257
 258    fn to_markdown(&self, cx: &App) -> String {
 259        let mut markdown = format!(
 260            "**Tool Call: {}**\nStatus: {}\n\n",
 261            self.label.read(cx).source(),
 262            self.status
 263        );
 264        for content in &self.content {
 265            markdown.push_str(content.to_markdown(cx).as_str());
 266            markdown.push_str("\n\n");
 267        }
 268        markdown
 269    }
 270}
 271
 272#[derive(Debug)]
 273pub enum ToolCallStatus {
 274    WaitingForConfirmation {
 275        options: Vec<acp::PermissionOption>,
 276        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
 277    },
 278    Allowed {
 279        status: acp::ToolCallStatus,
 280    },
 281    Rejected,
 282    Canceled,
 283}
 284
 285impl Display for ToolCallStatus {
 286    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 287        write!(
 288            f,
 289            "{}",
 290            match self {
 291                ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
 292                ToolCallStatus::Allowed { status } => match status {
 293                    acp::ToolCallStatus::Pending => "Pending",
 294                    acp::ToolCallStatus::InProgress => "In Progress",
 295                    acp::ToolCallStatus::Completed => "Completed",
 296                    acp::ToolCallStatus::Failed => "Failed",
 297                },
 298                ToolCallStatus::Rejected => "Rejected",
 299                ToolCallStatus::Canceled => "Canceled",
 300            }
 301        )
 302    }
 303}
 304
 305#[derive(Debug, PartialEq, Clone)]
 306pub enum ContentBlock {
 307    Empty,
 308    Markdown { markdown: Entity<Markdown> },
 309}
 310
 311impl ContentBlock {
 312    pub fn new(
 313        block: acp::ContentBlock,
 314        language_registry: &Arc<LanguageRegistry>,
 315        cx: &mut App,
 316    ) -> Self {
 317        let mut this = Self::Empty;
 318        this.append(block, language_registry, cx);
 319        this
 320    }
 321
 322    pub fn new_combined(
 323        blocks: impl IntoIterator<Item = acp::ContentBlock>,
 324        language_registry: Arc<LanguageRegistry>,
 325        cx: &mut App,
 326    ) -> Self {
 327        let mut this = Self::Empty;
 328        for block in blocks {
 329            this.append(block, &language_registry, cx);
 330        }
 331        this
 332    }
 333
 334    pub fn append(
 335        &mut self,
 336        block: acp::ContentBlock,
 337        language_registry: &Arc<LanguageRegistry>,
 338        cx: &mut App,
 339    ) {
 340        let new_content = match block {
 341            acp::ContentBlock::Text(text_content) => text_content.text.clone(),
 342            acp::ContentBlock::ResourceLink(resource_link) => {
 343                if let Some(path) = resource_link.uri.strip_prefix("file://") {
 344                    format!("{}", MentionPath(path.as_ref()))
 345                } else {
 346                    resource_link.uri.clone()
 347                }
 348            }
 349            acp::ContentBlock::Image(_)
 350            | acp::ContentBlock::Audio(_)
 351            | acp::ContentBlock::Resource(_) => String::new(),
 352        };
 353
 354        match self {
 355            ContentBlock::Empty => {
 356                *self = ContentBlock::Markdown {
 357                    markdown: cx.new(|cx| {
 358                        Markdown::new(
 359                            new_content.into(),
 360                            Some(language_registry.clone()),
 361                            None,
 362                            cx,
 363                        )
 364                    }),
 365                };
 366            }
 367            ContentBlock::Markdown { markdown } => {
 368                markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
 369            }
 370        }
 371    }
 372
 373    fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
 374        match self {
 375            ContentBlock::Empty => "",
 376            ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
 377        }
 378    }
 379
 380    pub fn markdown(&self) -> Option<&Entity<Markdown>> {
 381        match self {
 382            ContentBlock::Empty => None,
 383            ContentBlock::Markdown { markdown } => Some(markdown),
 384        }
 385    }
 386}
 387
 388#[derive(Debug)]
 389pub enum ToolCallContent {
 390    ContentBlock { content: ContentBlock },
 391    Diff { diff: Entity<Diff> },
 392}
 393
 394impl ToolCallContent {
 395    pub fn from_acp(
 396        content: acp::ToolCallContent,
 397        language_registry: Arc<LanguageRegistry>,
 398        cx: &mut App,
 399    ) -> Self {
 400        match content {
 401            acp::ToolCallContent::Content { content } => Self::ContentBlock {
 402                content: ContentBlock::new(content, &language_registry, cx),
 403            },
 404            acp::ToolCallContent::Diff { diff } => Self::Diff {
 405                diff: cx.new(|cx| Diff::from_acp(diff, language_registry, cx)),
 406            },
 407        }
 408    }
 409
 410    pub fn to_markdown(&self, cx: &App) -> String {
 411        match self {
 412            Self::ContentBlock { content } => content.to_markdown(cx).to_string(),
 413            Self::Diff { diff } => diff.read(cx).to_markdown(cx),
 414        }
 415    }
 416}
 417
 418#[derive(Debug, Default)]
 419pub struct Plan {
 420    pub entries: Vec<PlanEntry>,
 421}
 422
 423#[derive(Debug)]
 424pub struct PlanStats<'a> {
 425    pub in_progress_entry: Option<&'a PlanEntry>,
 426    pub pending: u32,
 427    pub completed: u32,
 428}
 429
 430impl Plan {
 431    pub fn is_empty(&self) -> bool {
 432        self.entries.is_empty()
 433    }
 434
 435    pub fn stats(&self) -> PlanStats<'_> {
 436        let mut stats = PlanStats {
 437            in_progress_entry: None,
 438            pending: 0,
 439            completed: 0,
 440        };
 441
 442        for entry in &self.entries {
 443            match &entry.status {
 444                acp::PlanEntryStatus::Pending => {
 445                    stats.pending += 1;
 446                }
 447                acp::PlanEntryStatus::InProgress => {
 448                    stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
 449                }
 450                acp::PlanEntryStatus::Completed => {
 451                    stats.completed += 1;
 452                }
 453            }
 454        }
 455
 456        stats
 457    }
 458}
 459
 460#[derive(Debug)]
 461pub struct PlanEntry {
 462    pub content: Entity<Markdown>,
 463    pub priority: acp::PlanEntryPriority,
 464    pub status: acp::PlanEntryStatus,
 465}
 466
 467impl PlanEntry {
 468    pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
 469        Self {
 470            content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
 471            priority: entry.priority,
 472            status: entry.status,
 473        }
 474    }
 475}
 476
 477pub struct AcpThread {
 478    title: SharedString,
 479    entries: Vec<AgentThreadEntry>,
 480    plan: Plan,
 481    project: Entity<Project>,
 482    action_log: Entity<ActionLog>,
 483    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
 484    send_task: Option<Task<()>>,
 485    connection: Rc<dyn AgentConnection>,
 486    session_id: acp::SessionId,
 487}
 488
 489pub enum AcpThreadEvent {
 490    NewEntry,
 491    EntryUpdated(usize),
 492    ToolAuthorizationRequired,
 493    Stopped,
 494    Error,
 495    ServerExited(ExitStatus),
 496}
 497
 498impl EventEmitter<AcpThreadEvent> for AcpThread {}
 499
 500#[derive(PartialEq, Eq)]
 501pub enum ThreadStatus {
 502    Idle,
 503    WaitingForToolConfirmation,
 504    Generating,
 505}
 506
 507#[derive(Debug, Clone)]
 508pub enum LoadError {
 509    Unsupported {
 510        error_message: SharedString,
 511        upgrade_message: SharedString,
 512        upgrade_command: String,
 513    },
 514    Exited(i32),
 515    Other(SharedString),
 516}
 517
 518impl Display for LoadError {
 519    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 520        match self {
 521            LoadError::Unsupported { error_message, .. } => write!(f, "{}", error_message),
 522            LoadError::Exited(status) => write!(f, "Server exited with status {}", status),
 523            LoadError::Other(msg) => write!(f, "{}", msg),
 524        }
 525    }
 526}
 527
 528impl Error for LoadError {}
 529
 530impl AcpThread {
 531    pub fn new(
 532        title: impl Into<SharedString>,
 533        connection: Rc<dyn AgentConnection>,
 534        project: Entity<Project>,
 535        session_id: acp::SessionId,
 536        cx: &mut Context<Self>,
 537    ) -> Self {
 538        let action_log = cx.new(|_| ActionLog::new(project.clone()));
 539
 540        Self {
 541            action_log,
 542            shared_buffers: Default::default(),
 543            entries: Default::default(),
 544            plan: Default::default(),
 545            title: title.into(),
 546            project,
 547            send_task: None,
 548            connection,
 549            session_id,
 550        }
 551    }
 552
 553    pub fn action_log(&self) -> &Entity<ActionLog> {
 554        &self.action_log
 555    }
 556
 557    pub fn project(&self) -> &Entity<Project> {
 558        &self.project
 559    }
 560
 561    pub fn title(&self) -> SharedString {
 562        self.title.clone()
 563    }
 564
 565    pub fn entries(&self) -> &[AgentThreadEntry] {
 566        &self.entries
 567    }
 568
 569    pub fn session_id(&self) -> &acp::SessionId {
 570        &self.session_id
 571    }
 572
 573    pub fn status(&self) -> ThreadStatus {
 574        if self.send_task.is_some() {
 575            if self.waiting_for_tool_confirmation() {
 576                ThreadStatus::WaitingForToolConfirmation
 577            } else {
 578                ThreadStatus::Generating
 579            }
 580        } else {
 581            ThreadStatus::Idle
 582        }
 583    }
 584
 585    pub fn has_pending_edit_tool_calls(&self) -> bool {
 586        for entry in self.entries.iter().rev() {
 587            match entry {
 588                AgentThreadEntry::UserMessage(_) => return false,
 589                AgentThreadEntry::ToolCall(
 590                    call @ ToolCall {
 591                        status:
 592                            ToolCallStatus::Allowed {
 593                                status:
 594                                    acp::ToolCallStatus::InProgress | acp::ToolCallStatus::Pending,
 595                            },
 596                        ..
 597                    },
 598                ) if call.diffs().next().is_some() => {
 599                    return true;
 600                }
 601                AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
 602            }
 603        }
 604
 605        false
 606    }
 607
 608    pub fn used_tools_since_last_user_message(&self) -> bool {
 609        for entry in self.entries.iter().rev() {
 610            match entry {
 611                AgentThreadEntry::UserMessage(..) => return false,
 612                AgentThreadEntry::AssistantMessage(..) => continue,
 613                AgentThreadEntry::ToolCall(..) => return true,
 614            }
 615        }
 616
 617        false
 618    }
 619
 620    pub fn handle_session_update(
 621        &mut self,
 622        update: acp::SessionUpdate,
 623        cx: &mut Context<Self>,
 624    ) -> Result<()> {
 625        match update {
 626            acp::SessionUpdate::UserMessageChunk { content } => {
 627                self.push_user_content_block(content, cx);
 628            }
 629            acp::SessionUpdate::AgentMessageChunk { content } => {
 630                self.push_assistant_content_block(content, false, cx);
 631            }
 632            acp::SessionUpdate::AgentThoughtChunk { content } => {
 633                self.push_assistant_content_block(content, true, cx);
 634            }
 635            acp::SessionUpdate::ToolCall(tool_call) => {
 636                self.upsert_tool_call(tool_call, cx);
 637            }
 638            acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
 639                self.update_tool_call(tool_call_update, cx)?;
 640            }
 641            acp::SessionUpdate::Plan(plan) => {
 642                self.update_plan(plan, cx);
 643            }
 644        }
 645        Ok(())
 646    }
 647
 648    pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
 649        let language_registry = self.project.read(cx).languages().clone();
 650        let entries_len = self.entries.len();
 651
 652        if let Some(last_entry) = self.entries.last_mut()
 653            && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
 654        {
 655            content.append(chunk, &language_registry, cx);
 656            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
 657        } else {
 658            let content = ContentBlock::new(chunk, &language_registry, cx);
 659            self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
 660        }
 661    }
 662
 663    pub fn push_assistant_content_block(
 664        &mut self,
 665        chunk: acp::ContentBlock,
 666        is_thought: bool,
 667        cx: &mut Context<Self>,
 668    ) {
 669        let language_registry = self.project.read(cx).languages().clone();
 670        let entries_len = self.entries.len();
 671        if let Some(last_entry) = self.entries.last_mut()
 672            && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
 673        {
 674            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
 675            match (chunks.last_mut(), is_thought) {
 676                (Some(AssistantMessageChunk::Message { block }), false)
 677                | (Some(AssistantMessageChunk::Thought { block }), true) => {
 678                    block.append(chunk, &language_registry, cx)
 679                }
 680                _ => {
 681                    let block = ContentBlock::new(chunk, &language_registry, cx);
 682                    if is_thought {
 683                        chunks.push(AssistantMessageChunk::Thought { block })
 684                    } else {
 685                        chunks.push(AssistantMessageChunk::Message { block })
 686                    }
 687                }
 688            }
 689        } else {
 690            let block = ContentBlock::new(chunk, &language_registry, cx);
 691            let chunk = if is_thought {
 692                AssistantMessageChunk::Thought { block }
 693            } else {
 694                AssistantMessageChunk::Message { block }
 695            };
 696
 697            self.push_entry(
 698                AgentThreadEntry::AssistantMessage(AssistantMessage {
 699                    chunks: vec![chunk],
 700                }),
 701                cx,
 702            );
 703        }
 704    }
 705
 706    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
 707        self.entries.push(entry);
 708        cx.emit(AcpThreadEvent::NewEntry);
 709    }
 710
 711    pub fn update_tool_call(
 712        &mut self,
 713        update: acp::ToolCallUpdate,
 714        cx: &mut Context<Self>,
 715    ) -> Result<()> {
 716        let languages = self.project.read(cx).languages().clone();
 717
 718        let (ix, current_call) = self
 719            .tool_call_mut(&update.id)
 720            .context("Tool call not found")?;
 721        current_call.update(update.fields, languages, cx);
 722
 723        cx.emit(AcpThreadEvent::EntryUpdated(ix));
 724
 725        Ok(())
 726    }
 727
 728    pub fn set_tool_call_diff(
 729        &mut self,
 730        tool_call_id: &acp::ToolCallId,
 731        diff: Entity<Diff>,
 732        cx: &mut Context<Self>,
 733    ) -> Result<()> {
 734        let (ix, current_call) = self
 735            .tool_call_mut(tool_call_id)
 736            .context("Tool call not found")?;
 737        current_call.content.clear();
 738        current_call.content.push(ToolCallContent::Diff { diff });
 739        cx.emit(AcpThreadEvent::EntryUpdated(ix));
 740        Ok(())
 741    }
 742
 743    /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
 744    pub fn upsert_tool_call(&mut self, tool_call: acp::ToolCall, cx: &mut Context<Self>) {
 745        let status = ToolCallStatus::Allowed {
 746            status: tool_call.status,
 747        };
 748        self.upsert_tool_call_inner(tool_call, status, cx)
 749    }
 750
 751    pub fn upsert_tool_call_inner(
 752        &mut self,
 753        tool_call: acp::ToolCall,
 754        status: ToolCallStatus,
 755        cx: &mut Context<Self>,
 756    ) {
 757        let language_registry = self.project.read(cx).languages().clone();
 758        let call = ToolCall::from_acp(tool_call, status, language_registry, cx);
 759
 760        let location = call.locations.last().cloned();
 761
 762        if let Some((ix, current_call)) = self.tool_call_mut(&call.id) {
 763            *current_call = call;
 764
 765            cx.emit(AcpThreadEvent::EntryUpdated(ix));
 766        } else {
 767            self.push_entry(AgentThreadEntry::ToolCall(call), cx);
 768        }
 769
 770        if let Some(location) = location {
 771            self.set_project_location(location, cx)
 772        }
 773    }
 774
 775    fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
 776        // The tool call we are looking for is typically the last one, or very close to the end.
 777        // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
 778        self.entries
 779            .iter_mut()
 780            .enumerate()
 781            .rev()
 782            .find_map(|(index, tool_call)| {
 783                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
 784                    && &tool_call.id == id
 785                {
 786                    Some((index, tool_call))
 787                } else {
 788                    None
 789                }
 790            })
 791    }
 792
 793    pub fn set_project_location(&self, location: acp::ToolCallLocation, cx: &mut Context<Self>) {
 794        self.project.update(cx, |project, cx| {
 795            let Some(path) = project.project_path_for_absolute_path(&location.path, cx) else {
 796                return;
 797            };
 798            let buffer = project.open_buffer(path, cx);
 799            cx.spawn(async move |project, cx| {
 800                let buffer = buffer.await?;
 801
 802                project.update(cx, |project, cx| {
 803                    let position = if let Some(line) = location.line {
 804                        let snapshot = buffer.read(cx).snapshot();
 805                        let point = snapshot.clip_point(Point::new(line, 0), Bias::Left);
 806                        snapshot.anchor_before(point)
 807                    } else {
 808                        Anchor::MIN
 809                    };
 810
 811                    project.set_agent_location(
 812                        Some(AgentLocation {
 813                            buffer: buffer.downgrade(),
 814                            position,
 815                        }),
 816                        cx,
 817                    );
 818                })
 819            })
 820            .detach_and_log_err(cx);
 821        });
 822    }
 823
 824    pub fn request_tool_call_authorization(
 825        &mut self,
 826        tool_call: acp::ToolCall,
 827        options: Vec<acp::PermissionOption>,
 828        cx: &mut Context<Self>,
 829    ) -> oneshot::Receiver<acp::PermissionOptionId> {
 830        let (tx, rx) = oneshot::channel();
 831
 832        let status = ToolCallStatus::WaitingForConfirmation {
 833            options,
 834            respond_tx: tx,
 835        };
 836
 837        self.upsert_tool_call_inner(tool_call, status, cx);
 838        cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
 839        rx
 840    }
 841
 842    pub fn authorize_tool_call(
 843        &mut self,
 844        id: acp::ToolCallId,
 845        option_id: acp::PermissionOptionId,
 846        option_kind: acp::PermissionOptionKind,
 847        cx: &mut Context<Self>,
 848    ) {
 849        let Some((ix, call)) = self.tool_call_mut(&id) else {
 850            return;
 851        };
 852
 853        let new_status = match option_kind {
 854            acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
 855                ToolCallStatus::Rejected
 856            }
 857            acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
 858                ToolCallStatus::Allowed {
 859                    status: acp::ToolCallStatus::InProgress,
 860                }
 861            }
 862        };
 863
 864        let curr_status = mem::replace(&mut call.status, new_status);
 865
 866        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
 867            respond_tx.send(option_id).log_err();
 868        } else if cfg!(debug_assertions) {
 869            panic!("tried to authorize an already authorized tool call");
 870        }
 871
 872        cx.emit(AcpThreadEvent::EntryUpdated(ix));
 873    }
 874
 875    /// Returns true if the last turn is awaiting tool authorization
 876    pub fn waiting_for_tool_confirmation(&self) -> bool {
 877        for entry in self.entries.iter().rev() {
 878            match &entry {
 879                AgentThreadEntry::ToolCall(call) => match call.status {
 880                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
 881                    ToolCallStatus::Allowed { .. }
 882                    | ToolCallStatus::Rejected
 883                    | ToolCallStatus::Canceled => continue,
 884                },
 885                AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
 886                    // Reached the beginning of the turn
 887                    return false;
 888                }
 889            }
 890        }
 891        false
 892    }
 893
 894    pub fn plan(&self) -> &Plan {
 895        &self.plan
 896    }
 897
 898    pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
 899        let new_entries_len = request.entries.len();
 900        let mut new_entries = request.entries.into_iter();
 901
 902        // Reuse existing markdown to prevent flickering
 903        for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
 904            let PlanEntry {
 905                content,
 906                priority,
 907                status,
 908            } = old;
 909            content.update(cx, |old, cx| {
 910                old.replace(new.content, cx);
 911            });
 912            *priority = new.priority;
 913            *status = new.status;
 914        }
 915        for new in new_entries {
 916            self.plan.entries.push(PlanEntry::from_acp(new, cx))
 917        }
 918        self.plan.entries.truncate(new_entries_len);
 919
 920        cx.notify();
 921    }
 922
 923    fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
 924        self.plan
 925            .entries
 926            .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
 927        cx.notify();
 928    }
 929
 930    #[cfg(any(test, feature = "test-support"))]
 931    pub fn send_raw(
 932        &mut self,
 933        message: &str,
 934        cx: &mut Context<Self>,
 935    ) -> BoxFuture<'static, Result<()>> {
 936        self.send(
 937            vec![acp::ContentBlock::Text(acp::TextContent {
 938                text: message.to_string(),
 939                annotations: None,
 940            })],
 941            cx,
 942        )
 943    }
 944
    /// Sends `message` to the agent as a new user turn.
    ///
    /// The message is appended to the thread as a user entry, and completed plan
    /// entries are cleared. Any prompt that is still running is cancelled first
    /// (via `cancel`). The new prompt is only issued after that cancellation task
    /// has resolved.
    ///
    /// The returned future resolves when the prompt finishes:
    /// - If the connection returns an error, `AcpThreadEvent::Error` is emitted
    ///   and the error is returned.
    /// - Otherwise `AcpThreadEvent::Stopped` is emitted and `Ok(())` is returned.
    ///   This also covers the result channel being dropped before a response
    ///   was sent, e.g. when the `this.update` inside the send task fails.
    pub fn send(
        &mut self,
        message: Vec<acp::ContentBlock>,
        cx: &mut Context<Self>,
    ) -> BoxFuture<'static, Result<()>> {
        // `message` itself is moved into the prompt request below, so the
        // rendered user entry is built from a clone.
        let block = ContentBlock::new_combined(
            message.clone(),
            self.project.read(cx).languages().clone(),
            cx,
        );
        self.push_entry(
            AgentThreadEntry::UserMessage(UserMessage { content: block }),
            cx,
        );
        self.clear_completed_plan_entries(cx);

        // `tx` / `rx` hand the prompt result from the send task to the future
        // returned to the caller.
        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            async {
                // Let the previous prompt (if any) finish winding down first.
                cancel_task.await;

                let result = this
                    .update(cx, |this, cx| {
                        this.connection.prompt(
                            acp::PromptRequest {
                                prompt: message,
                                session_id: this.session_id.clone(),
                            },
                            cx,
                        )
                    })?
                    .await;

                tx.send(result).log_err();

                anyhow::Ok(())
            }
            .await
            .log_err();
        }));

        cx.spawn(async move |this, cx| match rx.await {
            Ok(Err(e)) => {
                this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Error))
                    .log_err();
                Err(e)?
            }
            // Successful responses, plus a dropped sender (`Err(Canceled)`).
            result => {
                let cancelled = matches!(
                    result,
                    Ok(Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::Cancelled
                    }))
                );

                // We only take the task if the current prompt wasn't cancelled.
                //
                // This prompt may have been cancelled because another one was sent
                // while it was still generating. In these cases, dropping `send_task`
                // would cause the next generation to be cancelled.
                //
                // NOTE(review): this relies on the agent reporting
                // `StopReason::Cancelled` for a cancelled prompt. If a cancelled
                // prompt ends with any other stop reason, this takes (and drops)
                // the newer prompt's `send_task`. Confirm that all agents comply.
                if !cancelled {
                    this.update(cx, |this, _cx| this.send_task.take()).ok();
                }

                this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
                    .log_err();
                Ok(())
            }
        })
        .boxed()
    }
1018
1019    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1020        let Some(send_task) = self.send_task.take() else {
1021            return Task::ready(());
1022        };
1023
1024        for entry in self.entries.iter_mut() {
1025            if let AgentThreadEntry::ToolCall(call) = entry {
1026                let cancel = matches!(
1027                    call.status,
1028                    ToolCallStatus::WaitingForConfirmation { .. }
1029                        | ToolCallStatus::Allowed {
1030                            status: acp::ToolCallStatus::InProgress
1031                        }
1032                );
1033
1034                if cancel {
1035                    call.status = ToolCallStatus::Canceled;
1036                }
1037            }
1038        }
1039
1040        self.connection.cancel(&self.session_id, cx);
1041
1042        // Wait for the send task to complete
1043        cx.foreground_executor().spawn(send_task)
1044    }
1045
1046    pub fn read_text_file(
1047        &self,
1048        path: PathBuf,
1049        line: Option<u32>,
1050        limit: Option<u32>,
1051        reuse_shared_snapshot: bool,
1052        cx: &mut Context<Self>,
1053    ) -> Task<Result<String>> {
1054        let project = self.project.clone();
1055        let action_log = self.action_log.clone();
1056        cx.spawn(async move |this, cx| {
1057            let load = project.update(cx, |project, cx| {
1058                let path = project
1059                    .project_path_for_absolute_path(&path, cx)
1060                    .context("invalid path")?;
1061                anyhow::Ok(project.open_buffer(path, cx))
1062            });
1063            let buffer = load??.await?;
1064
1065            let snapshot = if reuse_shared_snapshot {
1066                this.read_with(cx, |this, _| {
1067                    this.shared_buffers.get(&buffer.clone()).cloned()
1068                })
1069                .log_err()
1070                .flatten()
1071            } else {
1072                None
1073            };
1074
1075            let snapshot = if let Some(snapshot) = snapshot {
1076                snapshot
1077            } else {
1078                action_log.update(cx, |action_log, cx| {
1079                    action_log.buffer_read(buffer.clone(), cx);
1080                })?;
1081                project.update(cx, |project, cx| {
1082                    let position = buffer
1083                        .read(cx)
1084                        .snapshot()
1085                        .anchor_before(Point::new(line.unwrap_or_default(), 0));
1086                    project.set_agent_location(
1087                        Some(AgentLocation {
1088                            buffer: buffer.downgrade(),
1089                            position,
1090                        }),
1091                        cx,
1092                    );
1093                })?;
1094
1095                buffer.update(cx, |buffer, _| buffer.snapshot())?
1096            };
1097
1098            this.update(cx, |this, _| {
1099                let text = snapshot.text();
1100                this.shared_buffers.insert(buffer.clone(), snapshot);
1101                if line.is_none() && limit.is_none() {
1102                    return Ok(text);
1103                }
1104                let limit = limit.unwrap_or(u32::MAX) as usize;
1105                let Some(line) = line else {
1106                    return Ok(text.lines().take(limit).collect::<String>());
1107                };
1108
1109                let count = text.lines().count();
1110                if count < line as usize {
1111                    anyhow::bail!("There are only {} lines", count);
1112                }
1113                Ok(text
1114                    .lines()
1115                    .skip(line as usize + 1)
1116                    .take(limit)
1117                    .collect::<String>())
1118            })?
1119        })
1120    }
1121
    /// Replaces the contents of the file at `path` (an absolute path inside the
    /// project) with `content`, then saves the buffer.
    ///
    /// The buffer is not overwritten wholesale. Instead, `content` is diffed
    /// against the snapshot this thread last stored in `shared_buffers` for the
    /// buffer, or the buffer's current snapshot if there is none. The resulting
    /// edits are applied through anchors, so edits made to the buffer after that
    /// snapshot are kept (see `test_edits_concurrently_to_user`).
    ///
    /// The agent location moves to the end of the last edit, and the read and
    /// the edit are both reported to the action log.
    ///
    /// # Errors
    ///
    /// Fails if the path is not inside the project, the buffer cannot be opened,
    /// updating the thread/app fails, or saving the buffer fails.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Diff against what the agent last read, when available.
            //
            // NOTE(review): nothing here refreshes `shared_buffers` after the
            // write, so a second write without an intervening read diffs against
            // the pre-write snapshot again. Confirm this cannot re-apply the first
            // write's insertions.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Compute the diff off the main thread. Each range is expressed as
            // anchors into `snapshot`, which are resolved against the live buffer
            // when the edit is applied.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            // Presumably `anchor_after` / `anchor_before` keep
                            // text inserted at either boundary since the snapshot
                            // outside the replaced range. TODO confirm.
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;
            cx.update(|cx| {
                // Point the agent location at the end of the last edit, or at
                // `Anchor::MIN` when the diff produced no edits.
                project.update(cx, |project, cx| {
                    project.set_agent_location(
                        Some(AgentLocation {
                            buffer: buffer.downgrade(),
                            position: edits
                                .last()
                                .map(|(range, _)| range.end)
                                .unwrap_or(Anchor::MIN),
                        }),
                        cx,
                    );
                });

                // Report a read, apply the edits, then report the edit.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });
                buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
            })?;
            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1189
1190    pub fn to_markdown(&self, cx: &App) -> String {
1191        self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1192    }
1193
1194    pub fn emit_server_exited(&mut self, status: ExitStatus, cx: &mut Context<Self>) {
1195        cx.emit(AcpThreadEvent::ServerExited(status));
1196    }
1197}
1198
1199#[cfg(test)]
1200mod tests {
1201    use super::*;
1202    use anyhow::anyhow;
1203    use futures::{channel::mpsc, future::LocalBoxFuture, select};
1204    use gpui::{AsyncApp, TestAppContext, WeakEntity};
1205    use indoc::indoc;
1206    use project::FakeFs;
1207    use rand::Rng as _;
1208    use serde_json::json;
1209    use settings::SettingsStore;
1210    use smol::stream::StreamExt as _;
1211    use std::{cell::RefCell, rc::Rc, time::Duration};
1212
1213    use util::path;
1214
1215    fn init_test(cx: &mut TestAppContext) {
1216        env_logger::try_init().ok();
1217        cx.update(|cx| {
1218            let settings_store = SettingsStore::test(cx);
1219            cx.set_global(settings_store);
1220            Project::init_settings(cx);
1221            language::init(cx);
1222        });
1223    }
1224
1225    #[gpui::test]
1226    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1227        init_test(cx);
1228
1229        let fs = FakeFs::new(cx.executor());
1230        let project = Project::test(fs, [], cx).await;
1231        let connection = Rc::new(FakeAgentConnection::new());
1232        let thread = cx
1233            .spawn(async move |mut cx| {
1234                connection
1235                    .new_thread(project, Path::new(path!("/test")), &mut cx)
1236                    .await
1237            })
1238            .await
1239            .unwrap();
1240
1241        // Test creating a new user message
1242        thread.update(cx, |thread, cx| {
1243            thread.push_user_content_block(
1244                acp::ContentBlock::Text(acp::TextContent {
1245                    annotations: None,
1246                    text: "Hello, ".to_string(),
1247                }),
1248                cx,
1249            );
1250        });
1251
1252        thread.update(cx, |thread, cx| {
1253            assert_eq!(thread.entries.len(), 1);
1254            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1255                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1256            } else {
1257                panic!("Expected UserMessage");
1258            }
1259        });
1260
1261        // Test appending to existing user message
1262        thread.update(cx, |thread, cx| {
1263            thread.push_user_content_block(
1264                acp::ContentBlock::Text(acp::TextContent {
1265                    annotations: None,
1266                    text: "world!".to_string(),
1267                }),
1268                cx,
1269            );
1270        });
1271
1272        thread.update(cx, |thread, cx| {
1273            assert_eq!(thread.entries.len(), 1);
1274            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1275                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1276            } else {
1277                panic!("Expected UserMessage");
1278            }
1279        });
1280
1281        // Test creating new user message after assistant message
1282        thread.update(cx, |thread, cx| {
1283            thread.push_assistant_content_block(
1284                acp::ContentBlock::Text(acp::TextContent {
1285                    annotations: None,
1286                    text: "Assistant response".to_string(),
1287                }),
1288                false,
1289                cx,
1290            );
1291        });
1292
1293        thread.update(cx, |thread, cx| {
1294            thread.push_user_content_block(
1295                acp::ContentBlock::Text(acp::TextContent {
1296                    annotations: None,
1297                    text: "New user message".to_string(),
1298                }),
1299                cx,
1300            );
1301        });
1302
1303        thread.update(cx, |thread, cx| {
1304            assert_eq!(thread.entries.len(), 3);
1305            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1306                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1307            } else {
1308                panic!("Expected UserMessage at index 2");
1309            }
1310        });
1311    }
1312
    /// Two consecutive `AgentThoughtChunk` updates render as one `<thinking>`
    /// block inside a single assistant message.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        // The fake agent answers every prompt with two thought chunks, then ends the turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "Thinking ".into(),
                                },
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "hard!".into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        // Both chunks must appear inside one `<thinking>` block.
        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
1379
    /// A user edit made after the agent read a file, but before it writes the
    /// file, survives the agent's write, both in the buffer and on disk.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // Signals the test body once the agent has finished reading the file.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        // The fake agent reads /tmp/foo, signals the test, then writes an
        // extended version of the content it read.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        // Open the same buffer the agent will use, so the test can edit it as the user.
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .spawn(|mut cx| connection.new_thread(project, Path::new(path!("/tmp")), &mut cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        // Insert a user edit between the agent's read and its write.
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        // Both the user's line and the agent's lines must be present.
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
1458
    /// A tool call that was marked `Canceled` by `AcpThread::cancel` still moves
    /// to `Allowed { Completed }` if the agent later reports that it finished.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        // The fake agent starts one in-progress fetch tool call, then ends its turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .spawn(async move |mut cx| {
                connection
                    .new_thread(project, Path::new(path!("/test")), &mut cx)
                    .await
            })
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // Entry 0 is the user message; entry 1 is the tool call.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::InProgress,
                        ..
                    },
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // The agent reports completion after the local cancellation.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Allowed {
                        status: acp::ToolCallStatus::Completed,
                        ..
                    },
                    ..
                })
            ));
        });
    }
1568
1569    #[gpui::test]
1570    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
1571        init_test(cx);
1572        let fs = FakeFs::new(cx.background_executor.clone());
1573        fs.insert_tree(path!("/test"), json!({})).await;
1574        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
1575
1576        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
1577            move |_, thread, mut cx| {
1578                async move {
1579                    thread
1580                        .update(&mut cx, |thread, cx| {
1581                            thread.handle_session_update(
1582                                acp::SessionUpdate::ToolCall(acp::ToolCall {
1583                                    id: acp::ToolCallId("test".into()),
1584                                    title: "Label".into(),
1585                                    kind: acp::ToolKind::Edit,
1586                                    status: acp::ToolCallStatus::Completed,
1587                                    content: vec![acp::ToolCallContent::Diff {
1588                                        diff: acp::Diff {
1589                                            path: "/test/test.txt".into(),
1590                                            old_text: None,
1591                                            new_text: "foo".into(),
1592                                        },
1593                                    }],
1594                                    locations: vec![],
1595                                    raw_input: None,
1596                                    raw_output: None,
1597                                }),
1598                                cx,
1599                            )
1600                        })
1601                        .unwrap()
1602                        .unwrap();
1603                    Ok(acp::PromptResponse {
1604                        stop_reason: acp::StopReason::EndTurn,
1605                    })
1606                }
1607                .boxed_local()
1608            }
1609        }));
1610
1611        let thread = connection
1612            .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
1613            .await
1614            .unwrap();
1615        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
1616            .await
1617            .unwrap();
1618
1619        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
1620    }
1621
1622    async fn run_until_first_tool_call(
1623        thread: &Entity<AcpThread>,
1624        cx: &mut TestAppContext,
1625    ) -> usize {
1626        let (mut tx, mut rx) = mpsc::channel::<usize>(1);
1627
1628        let subscription = cx.update(|cx| {
1629            cx.subscribe(thread, move |thread, _, cx| {
1630                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
1631                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
1632                        return tx.try_send(ix).unwrap();
1633                    }
1634                }
1635            })
1636        });
1637
1638        select! {
1639            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
1640                panic!("Timeout waiting for tool call")
1641            }
1642            ix = rx.next().fuse() => {
1643                drop(subscription);
1644                ix.unwrap()
1645            }
1646        }
1647    }
1648
1649    #[derive(Clone, Default)]
1650    struct FakeAgentConnection {
1651        auth_methods: Vec<acp::AuthMethod>,
1652        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
1653        on_user_message: Option<
1654            Rc<
1655                dyn Fn(
1656                        acp::PromptRequest,
1657                        WeakEntity<AcpThread>,
1658                        AsyncApp,
1659                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1660                    + 'static,
1661            >,
1662        >,
1663    }
1664
1665    impl FakeAgentConnection {
1666        fn new() -> Self {
1667            Self {
1668                auth_methods: Vec::new(),
1669                on_user_message: None,
1670                sessions: Arc::default(),
1671            }
1672        }
1673
1674        #[expect(unused)]
1675        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
1676            self.auth_methods = auth_methods;
1677            self
1678        }
1679
1680        fn on_user_message(
1681            mut self,
1682            handler: impl Fn(
1683                acp::PromptRequest,
1684                WeakEntity<AcpThread>,
1685                AsyncApp,
1686            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
1687            + 'static,
1688        ) -> Self {
1689            self.on_user_message.replace(Rc::new(handler));
1690            self
1691        }
1692    }
1693
1694    impl AgentConnection for FakeAgentConnection {
1695        fn auth_methods(&self) -> &[acp::AuthMethod] {
1696            &self.auth_methods
1697        }
1698
1699        fn new_thread(
1700            self: Rc<Self>,
1701            project: Entity<Project>,
1702            _cwd: &Path,
1703            cx: &mut gpui::AsyncApp,
1704        ) -> Task<gpui::Result<Entity<AcpThread>>> {
1705            let session_id = acp::SessionId(
1706                rand::thread_rng()
1707                    .sample_iter(&rand::distributions::Alphanumeric)
1708                    .take(7)
1709                    .map(char::from)
1710                    .collect::<String>()
1711                    .into(),
1712            );
1713            let thread = cx
1714                .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
1715                .unwrap();
1716            self.sessions.lock().insert(session_id, thread.downgrade());
1717            Task::ready(Ok(thread))
1718        }
1719
1720        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
1721            if self.auth_methods().iter().any(|m| m.id == method) {
1722                Task::ready(Ok(()))
1723            } else {
1724                Task::ready(Err(anyhow!("Invalid Auth Method")))
1725            }
1726        }
1727
1728        fn prompt(
1729            &self,
1730            params: acp::PromptRequest,
1731            cx: &mut App,
1732        ) -> Task<gpui::Result<acp::PromptResponse>> {
1733            let sessions = self.sessions.lock();
1734            let thread = sessions.get(&params.session_id).unwrap();
1735            if let Some(handler) = &self.on_user_message {
1736                let handler = handler.clone();
1737                let thread = thread.clone();
1738                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
1739            } else {
1740                Task::ready(Ok(acp::PromptResponse {
1741                    stop_reason: acp::StopReason::EndTurn,
1742                }))
1743            }
1744        }
1745
1746        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
1747            let sessions = self.sessions.lock();
1748            let thread = sessions.get(&session_id).unwrap().clone();
1749
1750            cx.spawn(async move |cx| {
1751                thread
1752                    .update(cx, |thread, cx| thread.cancel(cx))
1753                    .unwrap()
1754                    .await
1755            })
1756            .detach();
1757        }
1758    }
1759}