acp_thread.rs

   1mod connection;
   2mod diff;
   3mod mention;
   4mod terminal;
   5
   6use agent_settings::AgentSettings;
   7use collections::HashSet;
   8pub use connection::*;
   9pub use diff::*;
  10use futures::future::Shared;
  11use language::language_settings::FormatOnSave;
  12pub use mention::*;
  13use project::lsp_store::{FormatTrigger, LspFormatTarget};
  14use serde::{Deserialize, Serialize};
  15use settings::Settings as _;
  16pub use terminal::*;
  17
  18use action_log::ActionLog;
  19use agent_client_protocol::{self as acp};
  20use anyhow::{Context as _, Result, anyhow};
  21use editor::Bias;
  22use futures::{FutureExt, channel::oneshot, future::BoxFuture};
  23use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  24use itertools::Itertools;
  25use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
  26use markdown::Markdown;
  27use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
  28use std::collections::HashMap;
  29use std::error::Error;
  30use std::fmt::{Formatter, Write};
  31use std::ops::Range;
  32use std::process::ExitStatus;
  33use std::rc::Rc;
  34use std::time::{Duration, Instant};
  35use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
  36use ui::App;
  37use util::{ResultExt, get_system_shell};
  38use uuid::Uuid;
  39
  40#[derive(Debug)]
  41pub struct UserMessage {
  42    pub id: Option<UserMessageId>,
  43    pub content: ContentBlock,
  44    pub chunks: Vec<acp::ContentBlock>,
  45    pub checkpoint: Option<Checkpoint>,
  46}
  47
  48#[derive(Debug)]
  49pub struct Checkpoint {
  50    git_checkpoint: GitStoreCheckpoint,
  51    pub show: bool,
  52}
  53
  54impl UserMessage {
  55    fn to_markdown(&self, cx: &App) -> String {
  56        let mut markdown = String::new();
  57        if self
  58            .checkpoint
  59            .as_ref()
  60            .is_some_and(|checkpoint| checkpoint.show)
  61        {
  62            writeln!(markdown, "## User (checkpoint)").unwrap();
  63        } else {
  64            writeln!(markdown, "## User").unwrap();
  65        }
  66        writeln!(markdown).unwrap();
  67        writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
  68        writeln!(markdown).unwrap();
  69        markdown
  70    }
  71}
  72
  73#[derive(Debug, PartialEq)]
  74pub struct AssistantMessage {
  75    pub chunks: Vec<AssistantMessageChunk>,
  76}
  77
  78impl AssistantMessage {
  79    pub fn to_markdown(&self, cx: &App) -> String {
  80        format!(
  81            "## Assistant\n\n{}\n\n",
  82            self.chunks
  83                .iter()
  84                .map(|chunk| chunk.to_markdown(cx))
  85                .join("\n\n")
  86        )
  87    }
  88}
  89
  90#[derive(Debug, PartialEq)]
  91pub enum AssistantMessageChunk {
  92    Message { block: ContentBlock },
  93    Thought { block: ContentBlock },
  94}
  95
  96impl AssistantMessageChunk {
  97    pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
  98        Self::Message {
  99            block: ContentBlock::new(chunk.into(), language_registry, cx),
 100        }
 101    }
 102
 103    fn to_markdown(&self, cx: &App) -> String {
 104        match self {
 105            Self::Message { block } => block.to_markdown(cx).to_string(),
 106            Self::Thought { block } => {
 107                format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
 108            }
 109        }
 110    }
 111}
 112
 113#[derive(Debug)]
 114pub enum AgentThreadEntry {
 115    UserMessage(UserMessage),
 116    AssistantMessage(AssistantMessage),
 117    ToolCall(ToolCall),
 118}
 119
 120impl AgentThreadEntry {
 121    pub fn to_markdown(&self, cx: &App) -> String {
 122        match self {
 123            Self::UserMessage(message) => message.to_markdown(cx),
 124            Self::AssistantMessage(message) => message.to_markdown(cx),
 125            Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
 126        }
 127    }
 128
 129    pub fn user_message(&self) -> Option<&UserMessage> {
 130        if let AgentThreadEntry::UserMessage(message) = self {
 131            Some(message)
 132        } else {
 133            None
 134        }
 135    }
 136
 137    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 138        if let AgentThreadEntry::ToolCall(call) = self {
 139            itertools::Either::Left(call.diffs())
 140        } else {
 141            itertools::Either::Right(std::iter::empty())
 142        }
 143    }
 144
 145    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 146        if let AgentThreadEntry::ToolCall(call) = self {
 147            itertools::Either::Left(call.terminals())
 148        } else {
 149            itertools::Either::Right(std::iter::empty())
 150        }
 151    }
 152
 153    pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
 154        if let AgentThreadEntry::ToolCall(ToolCall {
 155            locations,
 156            resolved_locations,
 157            ..
 158        }) = self
 159        {
 160            Some((
 161                locations.get(ix)?.clone(),
 162                resolved_locations.get(ix)?.clone()?,
 163            ))
 164        } else {
 165            None
 166        }
 167    }
 168}
 169
 170#[derive(Debug)]
 171pub struct ToolCall {
 172    pub id: acp::ToolCallId,
 173    pub label: Entity<Markdown>,
 174    pub kind: acp::ToolKind,
 175    pub content: Vec<ToolCallContent>,
 176    pub status: ToolCallStatus,
 177    pub locations: Vec<acp::ToolCallLocation>,
 178    pub resolved_locations: Vec<Option<AgentLocation>>,
 179    pub raw_input: Option<serde_json::Value>,
 180    pub raw_output: Option<serde_json::Value>,
 181}
 182
 183impl ToolCall {
 184    fn from_acp(
 185        tool_call: acp::ToolCall,
 186        status: ToolCallStatus,
 187        language_registry: Arc<LanguageRegistry>,
 188        terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
 189        cx: &mut App,
 190    ) -> Result<Self> {
 191        let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
 192            first_line.to_owned() + ""
 193        } else {
 194            tool_call.title
 195        };
 196        let mut content = Vec::with_capacity(tool_call.content.len());
 197        for item in tool_call.content {
 198            content.push(ToolCallContent::from_acp(
 199                item,
 200                language_registry.clone(),
 201                terminals,
 202                cx,
 203            )?);
 204        }
 205
 206        let result = Self {
 207            id: tool_call.id,
 208            label: cx
 209                .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
 210            kind: tool_call.kind,
 211            content,
 212            locations: tool_call.locations,
 213            resolved_locations: Vec::default(),
 214            status,
 215            raw_input: tool_call.raw_input,
 216            raw_output: tool_call.raw_output,
 217        };
 218        Ok(result)
 219    }
 220
 221    fn update_fields(
 222        &mut self,
 223        fields: acp::ToolCallUpdateFields,
 224        language_registry: Arc<LanguageRegistry>,
 225        terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
 226        cx: &mut App,
 227    ) -> Result<()> {
 228        let acp::ToolCallUpdateFields {
 229            kind,
 230            status,
 231            title,
 232            content,
 233            locations,
 234            raw_input,
 235            raw_output,
 236        } = fields;
 237
 238        if let Some(kind) = kind {
 239            self.kind = kind;
 240        }
 241
 242        if let Some(status) = status {
 243            self.status = status.into();
 244        }
 245
 246        if let Some(title) = title {
 247            self.label.update(cx, |label, cx| {
 248                if let Some((first_line, _)) = title.split_once("\n") {
 249                    label.replace(first_line.to_owned() + "", cx)
 250                } else {
 251                    label.replace(title, cx);
 252                }
 253            });
 254        }
 255
 256        if let Some(content) = content {
 257            let new_content_len = content.len();
 258            let mut content = content.into_iter();
 259
 260            // Reuse existing content if we can
 261            for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
 262                old.update_from_acp(new, language_registry.clone(), terminals, cx)?;
 263            }
 264            for new in content {
 265                self.content.push(ToolCallContent::from_acp(
 266                    new,
 267                    language_registry.clone(),
 268                    terminals,
 269                    cx,
 270                )?)
 271            }
 272            self.content.truncate(new_content_len);
 273        }
 274
 275        if let Some(locations) = locations {
 276            self.locations = locations;
 277        }
 278
 279        if let Some(raw_input) = raw_input {
 280            self.raw_input = Some(raw_input);
 281        }
 282
 283        if let Some(raw_output) = raw_output {
 284            if self.content.is_empty()
 285                && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
 286            {
 287                self.content
 288                    .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
 289                        markdown,
 290                    }));
 291            }
 292            self.raw_output = Some(raw_output);
 293        }
 294        Ok(())
 295    }
 296
 297    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 298        self.content.iter().filter_map(|content| match content {
 299            ToolCallContent::Diff(diff) => Some(diff),
 300            ToolCallContent::ContentBlock(_) => None,
 301            ToolCallContent::Terminal(_) => None,
 302        })
 303    }
 304
 305    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 306        self.content.iter().filter_map(|content| match content {
 307            ToolCallContent::Terminal(terminal) => Some(terminal),
 308            ToolCallContent::ContentBlock(_) => None,
 309            ToolCallContent::Diff(_) => None,
 310        })
 311    }
 312
 313    fn to_markdown(&self, cx: &App) -> String {
 314        let mut markdown = format!(
 315            "**Tool Call: {}**\nStatus: {}\n\n",
 316            self.label.read(cx).source(),
 317            self.status
 318        );
 319        for content in &self.content {
 320            markdown.push_str(content.to_markdown(cx).as_str());
 321            markdown.push_str("\n\n");
 322        }
 323        markdown
 324    }
 325
 326    async fn resolve_location(
 327        location: acp::ToolCallLocation,
 328        project: WeakEntity<Project>,
 329        cx: &mut AsyncApp,
 330    ) -> Option<AgentLocation> {
 331        let buffer = project
 332            .update(cx, |project, cx| {
 333                project
 334                    .project_path_for_absolute_path(&location.path, cx)
 335                    .map(|path| project.open_buffer(path, cx))
 336            })
 337            .ok()??;
 338        let buffer = buffer.await.log_err()?;
 339        let position = buffer
 340            .update(cx, |buffer, _| {
 341                if let Some(row) = location.line {
 342                    let snapshot = buffer.snapshot();
 343                    let column = snapshot.indent_size_for_line(row).len;
 344                    let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
 345                    snapshot.anchor_before(point)
 346                } else {
 347                    Anchor::MIN
 348                }
 349            })
 350            .ok()?;
 351
 352        Some(AgentLocation {
 353            buffer: buffer.downgrade(),
 354            position,
 355        })
 356    }
 357
 358    fn resolve_locations(
 359        &self,
 360        project: Entity<Project>,
 361        cx: &mut App,
 362    ) -> Task<Vec<Option<AgentLocation>>> {
 363        let locations = self.locations.clone();
 364        project.update(cx, |_, cx| {
 365            cx.spawn(async move |project, cx| {
 366                let mut new_locations = Vec::new();
 367                for location in locations {
 368                    new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
 369                }
 370                new_locations
 371            })
 372        })
 373    }
 374}
 375
 376#[derive(Debug)]
 377pub enum ToolCallStatus {
 378    /// The tool call hasn't started running yet, but we start showing it to
 379    /// the user.
 380    Pending,
 381    /// The tool call is waiting for confirmation from the user.
 382    WaitingForConfirmation {
 383        options: Vec<acp::PermissionOption>,
 384        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
 385    },
 386    /// The tool call is currently running.
 387    InProgress,
 388    /// The tool call completed successfully.
 389    Completed,
 390    /// The tool call failed.
 391    Failed,
 392    /// The user rejected the tool call.
 393    Rejected,
 394    /// The user canceled generation so the tool call was canceled.
 395    Canceled,
 396}
 397
 398impl From<acp::ToolCallStatus> for ToolCallStatus {
 399    fn from(status: acp::ToolCallStatus) -> Self {
 400        match status {
 401            acp::ToolCallStatus::Pending => Self::Pending,
 402            acp::ToolCallStatus::InProgress => Self::InProgress,
 403            acp::ToolCallStatus::Completed => Self::Completed,
 404            acp::ToolCallStatus::Failed => Self::Failed,
 405        }
 406    }
 407}
 408
 409impl Display for ToolCallStatus {
 410    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 411        write!(
 412            f,
 413            "{}",
 414            match self {
 415                ToolCallStatus::Pending => "Pending",
 416                ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
 417                ToolCallStatus::InProgress => "In Progress",
 418                ToolCallStatus::Completed => "Completed",
 419                ToolCallStatus::Failed => "Failed",
 420                ToolCallStatus::Rejected => "Rejected",
 421                ToolCallStatus::Canceled => "Canceled",
 422            }
 423        )
 424    }
 425}
 426
 427#[derive(Debug, PartialEq, Clone)]
 428pub enum ContentBlock {
 429    Empty,
 430    Markdown { markdown: Entity<Markdown> },
 431    ResourceLink { resource_link: acp::ResourceLink },
 432}
 433
 434impl ContentBlock {
 435    pub fn new(
 436        block: acp::ContentBlock,
 437        language_registry: &Arc<LanguageRegistry>,
 438        cx: &mut App,
 439    ) -> Self {
 440        let mut this = Self::Empty;
 441        this.append(block, language_registry, cx);
 442        this
 443    }
 444
 445    pub fn new_combined(
 446        blocks: impl IntoIterator<Item = acp::ContentBlock>,
 447        language_registry: Arc<LanguageRegistry>,
 448        cx: &mut App,
 449    ) -> Self {
 450        let mut this = Self::Empty;
 451        for block in blocks {
 452            this.append(block, &language_registry, cx);
 453        }
 454        this
 455    }
 456
 457    pub fn append(
 458        &mut self,
 459        block: acp::ContentBlock,
 460        language_registry: &Arc<LanguageRegistry>,
 461        cx: &mut App,
 462    ) {
 463        if matches!(self, ContentBlock::Empty)
 464            && let acp::ContentBlock::ResourceLink(resource_link) = block
 465        {
 466            *self = ContentBlock::ResourceLink { resource_link };
 467            return;
 468        }
 469
 470        let new_content = self.block_string_contents(block);
 471
 472        match self {
 473            ContentBlock::Empty => {
 474                *self = Self::create_markdown_block(new_content, language_registry, cx);
 475            }
 476            ContentBlock::Markdown { markdown } => {
 477                markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
 478            }
 479            ContentBlock::ResourceLink { resource_link } => {
 480                let existing_content = Self::resource_link_md(&resource_link.uri);
 481                let combined = format!("{}\n{}", existing_content, new_content);
 482
 483                *self = Self::create_markdown_block(combined, language_registry, cx);
 484            }
 485        }
 486    }
 487
 488    fn create_markdown_block(
 489        content: String,
 490        language_registry: &Arc<LanguageRegistry>,
 491        cx: &mut App,
 492    ) -> ContentBlock {
 493        ContentBlock::Markdown {
 494            markdown: cx
 495                .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
 496        }
 497    }
 498
 499    fn block_string_contents(&self, block: acp::ContentBlock) -> String {
 500        match block {
 501            acp::ContentBlock::Text(text_content) => text_content.text,
 502            acp::ContentBlock::ResourceLink(resource_link) => {
 503                Self::resource_link_md(&resource_link.uri)
 504            }
 505            acp::ContentBlock::Resource(acp::EmbeddedResource {
 506                resource:
 507                    acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
 508                        uri,
 509                        ..
 510                    }),
 511                ..
 512            }) => Self::resource_link_md(&uri),
 513            acp::ContentBlock::Image(image) => Self::image_md(&image),
 514            acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
 515        }
 516    }
 517
 518    fn resource_link_md(uri: &str) -> String {
 519        if let Some(uri) = MentionUri::parse(uri).log_err() {
 520            uri.as_link().to_string()
 521        } else {
 522            uri.to_string()
 523        }
 524    }
 525
 526    fn image_md(_image: &acp::ImageContent) -> String {
 527        "`Image`".into()
 528    }
 529
 530    pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
 531        match self {
 532            ContentBlock::Empty => "",
 533            ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
 534            ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
 535        }
 536    }
 537
 538    pub fn markdown(&self) -> Option<&Entity<Markdown>> {
 539        match self {
 540            ContentBlock::Empty => None,
 541            ContentBlock::Markdown { markdown } => Some(markdown),
 542            ContentBlock::ResourceLink { .. } => None,
 543        }
 544    }
 545
 546    pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
 547        match self {
 548            ContentBlock::ResourceLink { resource_link } => Some(resource_link),
 549            _ => None,
 550        }
 551    }
 552}
 553
 554#[derive(Debug)]
 555pub enum ToolCallContent {
 556    ContentBlock(ContentBlock),
 557    Diff(Entity<Diff>),
 558    Terminal(Entity<Terminal>),
 559}
 560
 561impl ToolCallContent {
 562    pub fn from_acp(
 563        content: acp::ToolCallContent,
 564        language_registry: Arc<LanguageRegistry>,
 565        terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
 566        cx: &mut App,
 567    ) -> Result<Self> {
 568        match content {
 569            acp::ToolCallContent::Content { content } => Ok(Self::ContentBlock(ContentBlock::new(
 570                content,
 571                &language_registry,
 572                cx,
 573            ))),
 574            acp::ToolCallContent::Diff { diff } => Ok(Self::Diff(cx.new(|cx| {
 575                Diff::finalized(
 576                    diff.path,
 577                    diff.old_text,
 578                    diff.new_text,
 579                    language_registry,
 580                    cx,
 581                )
 582            }))),
 583            acp::ToolCallContent::Terminal { terminal_id } => terminals
 584                .get(&terminal_id)
 585                .cloned()
 586                .map(Self::Terminal)
 587                .ok_or_else(|| anyhow::anyhow!("Terminal with id `{}` not found", terminal_id)),
 588        }
 589    }
 590
 591    pub fn update_from_acp(
 592        &mut self,
 593        new: acp::ToolCallContent,
 594        language_registry: Arc<LanguageRegistry>,
 595        terminals: &HashMap<acp::TerminalId, Entity<Terminal>>,
 596        cx: &mut App,
 597    ) -> Result<()> {
 598        let needs_update = match (&self, &new) {
 599            (Self::Diff(old_diff), acp::ToolCallContent::Diff { diff: new_diff }) => {
 600                old_diff.read(cx).needs_update(
 601                    new_diff.old_text.as_deref().unwrap_or(""),
 602                    &new_diff.new_text,
 603                    cx,
 604                )
 605            }
 606            _ => true,
 607        };
 608
 609        if needs_update {
 610            *self = Self::from_acp(new, language_registry, terminals, cx)?;
 611        }
 612        Ok(())
 613    }
 614
 615    pub fn to_markdown(&self, cx: &App) -> String {
 616        match self {
 617            Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
 618            Self::Diff(diff) => diff.read(cx).to_markdown(cx),
 619            Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
 620        }
 621    }
 622}
 623
 624#[derive(Debug, PartialEq)]
 625pub enum ToolCallUpdate {
 626    UpdateFields(acp::ToolCallUpdate),
 627    UpdateDiff(ToolCallUpdateDiff),
 628    UpdateTerminal(ToolCallUpdateTerminal),
 629}
 630
 631impl ToolCallUpdate {
 632    fn id(&self) -> &acp::ToolCallId {
 633        match self {
 634            Self::UpdateFields(update) => &update.id,
 635            Self::UpdateDiff(diff) => &diff.id,
 636            Self::UpdateTerminal(terminal) => &terminal.id,
 637        }
 638    }
 639}
 640
 641impl From<acp::ToolCallUpdate> for ToolCallUpdate {
 642    fn from(update: acp::ToolCallUpdate) -> Self {
 643        Self::UpdateFields(update)
 644    }
 645}
 646
 647impl From<ToolCallUpdateDiff> for ToolCallUpdate {
 648    fn from(diff: ToolCallUpdateDiff) -> Self {
 649        Self::UpdateDiff(diff)
 650    }
 651}
 652
 653#[derive(Debug, PartialEq)]
 654pub struct ToolCallUpdateDiff {
 655    pub id: acp::ToolCallId,
 656    pub diff: Entity<Diff>,
 657}
 658
 659impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
 660    fn from(terminal: ToolCallUpdateTerminal) -> Self {
 661        Self::UpdateTerminal(terminal)
 662    }
 663}
 664
 665#[derive(Debug, PartialEq)]
 666pub struct ToolCallUpdateTerminal {
 667    pub id: acp::ToolCallId,
 668    pub terminal: Entity<Terminal>,
 669}
 670
 671#[derive(Debug, Default)]
 672pub struct Plan {
 673    pub entries: Vec<PlanEntry>,
 674}
 675
 676#[derive(Debug)]
 677pub struct PlanStats<'a> {
 678    pub in_progress_entry: Option<&'a PlanEntry>,
 679    pub pending: u32,
 680    pub completed: u32,
 681}
 682
 683impl Plan {
 684    pub fn is_empty(&self) -> bool {
 685        self.entries.is_empty()
 686    }
 687
 688    pub fn stats(&self) -> PlanStats<'_> {
 689        let mut stats = PlanStats {
 690            in_progress_entry: None,
 691            pending: 0,
 692            completed: 0,
 693        };
 694
 695        for entry in &self.entries {
 696            match &entry.status {
 697                acp::PlanEntryStatus::Pending => {
 698                    stats.pending += 1;
 699                }
 700                acp::PlanEntryStatus::InProgress => {
 701                    stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
 702                }
 703                acp::PlanEntryStatus::Completed => {
 704                    stats.completed += 1;
 705                }
 706            }
 707        }
 708
 709        stats
 710    }
 711}
 712
 713#[derive(Debug)]
 714pub struct PlanEntry {
 715    pub content: Entity<Markdown>,
 716    pub priority: acp::PlanEntryPriority,
 717    pub status: acp::PlanEntryStatus,
 718}
 719
 720impl PlanEntry {
 721    pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
 722        Self {
 723            content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
 724            priority: entry.priority,
 725            status: entry.status,
 726        }
 727    }
 728}
 729
 730#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 731pub struct TokenUsage {
 732    pub max_tokens: u64,
 733    pub used_tokens: u64,
 734}
 735
 736impl TokenUsage {
 737    pub fn ratio(&self) -> TokenUsageRatio {
 738        #[cfg(debug_assertions)]
 739        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 740            .unwrap_or("0.8".to_string())
 741            .parse()
 742            .unwrap();
 743        #[cfg(not(debug_assertions))]
 744        let warning_threshold: f32 = 0.8;
 745
 746        // When the maximum is unknown because there is no selected model,
 747        // avoid showing the token limit warning.
 748        if self.max_tokens == 0 {
 749            TokenUsageRatio::Normal
 750        } else if self.used_tokens >= self.max_tokens {
 751            TokenUsageRatio::Exceeded
 752        } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
 753            TokenUsageRatio::Warning
 754        } else {
 755            TokenUsageRatio::Normal
 756        }
 757    }
 758}
 759
 760#[derive(Debug, Clone, PartialEq, Eq)]
 761pub enum TokenUsageRatio {
 762    Normal,
 763    Warning,
 764    Exceeded,
 765}
 766
 767#[derive(Debug, Clone)]
 768pub struct RetryStatus {
 769    pub last_error: SharedString,
 770    pub attempt: usize,
 771    pub max_attempts: usize,
 772    pub started_at: Instant,
 773    pub duration: Duration,
 774}
 775
 776pub struct AcpThread {
 777    title: SharedString,
 778    entries: Vec<AgentThreadEntry>,
 779    plan: Plan,
 780    project: Entity<Project>,
 781    action_log: Entity<ActionLog>,
 782    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
 783    send_task: Option<Task<()>>,
 784    connection: Rc<dyn AgentConnection>,
 785    session_id: acp::SessionId,
 786    token_usage: Option<TokenUsage>,
 787    prompt_capabilities: acp::PromptCapabilities,
 788    available_commands: Vec<acp::AvailableCommand>,
 789    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
 790    determine_shell: Shared<Task<String>>,
 791    terminals: HashMap<acp::TerminalId, Entity<Terminal>>,
 792}
 793
 794#[derive(Debug)]
 795pub enum AcpThreadEvent {
 796    NewEntry,
 797    TitleUpdated,
 798    TokenUsageUpdated,
 799    EntryUpdated(usize),
 800    EntriesRemoved(Range<usize>),
 801    ToolAuthorizationRequired,
 802    Retry(RetryStatus),
 803    Stopped,
 804    Error,
 805    LoadError(LoadError),
 806    PromptCapabilitiesUpdated,
 807}
 808
 809impl EventEmitter<AcpThreadEvent> for AcpThread {}
 810
 811#[derive(PartialEq, Eq, Debug)]
 812pub enum ThreadStatus {
 813    Idle,
 814    WaitingForToolConfirmation,
 815    Generating,
 816}
 817
 818#[derive(Debug, Clone)]
 819pub enum LoadError {
 820    Unsupported {
 821        command: SharedString,
 822        current_version: SharedString,
 823        minimum_version: SharedString,
 824    },
 825    FailedToInstall(SharedString),
 826    Exited {
 827        status: ExitStatus,
 828    },
 829    Other(SharedString),
 830}
 831
 832impl Display for LoadError {
 833    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 834        match self {
 835            LoadError::Unsupported {
 836                command: path,
 837                current_version,
 838                minimum_version,
 839            } => {
 840                write!(
 841                    f,
 842                    "version {current_version} from {path} is not supported (need at least {minimum_version})"
 843                )
 844            }
 845            LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
 846            LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
 847            LoadError::Other(msg) => write!(f, "{msg}"),
 848        }
 849    }
 850}
 851
 852impl Error for LoadError {}
 853
 854impl AcpThread {
 855    pub fn new(
 856        title: impl Into<SharedString>,
 857        connection: Rc<dyn AgentConnection>,
 858        project: Entity<Project>,
 859        action_log: Entity<ActionLog>,
 860        session_id: acp::SessionId,
 861        mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
 862        available_commands: Vec<acp::AvailableCommand>,
 863        cx: &mut Context<Self>,
 864    ) -> Self {
 865        let prompt_capabilities = *prompt_capabilities_rx.borrow();
 866        let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
 867            loop {
 868                let caps = prompt_capabilities_rx.recv().await?;
 869                this.update(cx, |this, cx| {
 870                    this.prompt_capabilities = caps;
 871                    cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
 872                })?;
 873            }
 874        });
 875
 876        let determine_shell = cx
 877            .background_spawn(async move {
 878                if cfg!(windows) {
 879                    return get_system_shell();
 880                }
 881
 882                if which::which("bash").is_ok() {
 883                    "bash".into()
 884                } else {
 885                    get_system_shell()
 886                }
 887            })
 888            .shared();
 889
 890        Self {
 891            action_log,
 892            shared_buffers: Default::default(),
 893            entries: Default::default(),
 894            plan: Default::default(),
 895            title: title.into(),
 896            project,
 897            send_task: None,
 898            connection,
 899            session_id,
 900            token_usage: None,
 901            prompt_capabilities,
 902            available_commands,
 903            _observe_prompt_capabilities: task,
 904            terminals: HashMap::default(),
 905            determine_shell,
 906        }
 907    }
 908
 909    pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
 910        self.prompt_capabilities
 911    }
 912
 913    pub fn available_commands(&self) -> Vec<acp::AvailableCommand> {
 914        self.available_commands.clone()
 915    }
 916
 917    pub fn connection(&self) -> &Rc<dyn AgentConnection> {
 918        &self.connection
 919    }
 920
 921    pub fn action_log(&self) -> &Entity<ActionLog> {
 922        &self.action_log
 923    }
 924
 925    pub fn project(&self) -> &Entity<Project> {
 926        &self.project
 927    }
 928
 929    pub fn title(&self) -> SharedString {
 930        self.title.clone()
 931    }
 932
 933    pub fn entries(&self) -> &[AgentThreadEntry] {
 934        &self.entries
 935    }
 936
 937    pub fn session_id(&self) -> &acp::SessionId {
 938        &self.session_id
 939    }
 940
 941    pub fn status(&self) -> ThreadStatus {
 942        if self.send_task.is_some() {
 943            if self.waiting_for_tool_confirmation() {
 944                ThreadStatus::WaitingForToolConfirmation
 945            } else {
 946                ThreadStatus::Generating
 947            }
 948        } else {
 949            ThreadStatus::Idle
 950        }
 951    }
 952
 953    pub fn token_usage(&self) -> Option<&TokenUsage> {
 954        self.token_usage.as_ref()
 955    }
 956
 957    pub fn has_pending_edit_tool_calls(&self) -> bool {
 958        for entry in self.entries.iter().rev() {
 959            match entry {
 960                AgentThreadEntry::UserMessage(_) => return false,
 961                AgentThreadEntry::ToolCall(
 962                    call @ ToolCall {
 963                        status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
 964                        ..
 965                    },
 966                ) if call.diffs().next().is_some() => {
 967                    return true;
 968                }
 969                AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
 970            }
 971        }
 972
 973        false
 974    }
 975
 976    pub fn used_tools_since_last_user_message(&self) -> bool {
 977        for entry in self.entries.iter().rev() {
 978            match entry {
 979                AgentThreadEntry::UserMessage(..) => return false,
 980                AgentThreadEntry::AssistantMessage(..) => continue,
 981                AgentThreadEntry::ToolCall(..) => return true,
 982            }
 983        }
 984
 985        false
 986    }
 987
 988    pub fn handle_session_update(
 989        &mut self,
 990        update: acp::SessionUpdate,
 991        cx: &mut Context<Self>,
 992    ) -> Result<(), acp::Error> {
 993        match update {
 994            acp::SessionUpdate::UserMessageChunk { content } => {
 995                self.push_user_content_block(None, content, cx);
 996            }
 997            acp::SessionUpdate::AgentMessageChunk { content } => {
 998                self.push_assistant_content_block(content, false, cx);
 999            }
1000            acp::SessionUpdate::AgentThoughtChunk { content } => {
1001                self.push_assistant_content_block(content, true, cx);
1002            }
1003            acp::SessionUpdate::ToolCall(tool_call) => {
1004                self.upsert_tool_call(tool_call, cx)?;
1005            }
1006            acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
1007                self.update_tool_call(tool_call_update, cx)?;
1008            }
1009            acp::SessionUpdate::Plan(plan) => {
1010                self.update_plan(plan, cx);
1011            }
1012        }
1013        Ok(())
1014    }
1015
1016    pub fn push_user_content_block(
1017        &mut self,
1018        message_id: Option<UserMessageId>,
1019        chunk: acp::ContentBlock,
1020        cx: &mut Context<Self>,
1021    ) {
1022        let language_registry = self.project.read(cx).languages().clone();
1023        let entries_len = self.entries.len();
1024
1025        if let Some(last_entry) = self.entries.last_mut()
1026            && let AgentThreadEntry::UserMessage(UserMessage {
1027                id,
1028                content,
1029                chunks,
1030                ..
1031            }) = last_entry
1032        {
1033            *id = message_id.or(id.take());
1034            content.append(chunk.clone(), &language_registry, cx);
1035            chunks.push(chunk);
1036            let idx = entries_len - 1;
1037            cx.emit(AcpThreadEvent::EntryUpdated(idx));
1038        } else {
1039            let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
1040            self.push_entry(
1041                AgentThreadEntry::UserMessage(UserMessage {
1042                    id: message_id,
1043                    content,
1044                    chunks: vec![chunk],
1045                    checkpoint: None,
1046                }),
1047                cx,
1048            );
1049        }
1050    }
1051
1052    pub fn push_assistant_content_block(
1053        &mut self,
1054        chunk: acp::ContentBlock,
1055        is_thought: bool,
1056        cx: &mut Context<Self>,
1057    ) {
1058        let language_registry = self.project.read(cx).languages().clone();
1059        let entries_len = self.entries.len();
1060        if let Some(last_entry) = self.entries.last_mut()
1061            && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
1062        {
1063            let idx = entries_len - 1;
1064            cx.emit(AcpThreadEvent::EntryUpdated(idx));
1065            match (chunks.last_mut(), is_thought) {
1066                (Some(AssistantMessageChunk::Message { block }), false)
1067                | (Some(AssistantMessageChunk::Thought { block }), true) => {
1068                    block.append(chunk, &language_registry, cx)
1069                }
1070                _ => {
1071                    let block = ContentBlock::new(chunk, &language_registry, cx);
1072                    if is_thought {
1073                        chunks.push(AssistantMessageChunk::Thought { block })
1074                    } else {
1075                        chunks.push(AssistantMessageChunk::Message { block })
1076                    }
1077                }
1078            }
1079        } else {
1080            let block = ContentBlock::new(chunk, &language_registry, cx);
1081            let chunk = if is_thought {
1082                AssistantMessageChunk::Thought { block }
1083            } else {
1084                AssistantMessageChunk::Message { block }
1085            };
1086
1087            self.push_entry(
1088                AgentThreadEntry::AssistantMessage(AssistantMessage {
1089                    chunks: vec![chunk],
1090                }),
1091                cx,
1092            );
1093        }
1094    }
1095
1096    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1097        self.entries.push(entry);
1098        cx.emit(AcpThreadEvent::NewEntry);
1099    }
1100
1101    pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1102        self.connection.set_title(&self.session_id, cx).is_some()
1103    }
1104
1105    pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1106        if title != self.title {
1107            self.title = title.clone();
1108            cx.emit(AcpThreadEvent::TitleUpdated);
1109            if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1110                return set_title.run(title, cx);
1111            }
1112        }
1113        Task::ready(Ok(()))
1114    }
1115
1116    pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1117        self.token_usage = usage;
1118        cx.emit(AcpThreadEvent::TokenUsageUpdated);
1119    }
1120
1121    pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1122        cx.emit(AcpThreadEvent::Retry(status));
1123    }
1124
1125    pub fn update_tool_call(
1126        &mut self,
1127        update: impl Into<ToolCallUpdate>,
1128        cx: &mut Context<Self>,
1129    ) -> Result<()> {
1130        let update = update.into();
1131        let languages = self.project.read(cx).languages().clone();
1132
1133        let ix = self
1134            .index_for_tool_call(update.id())
1135            .context("Tool call not found")?;
1136        let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1137            unreachable!()
1138        };
1139
1140        match update {
1141            ToolCallUpdate::UpdateFields(update) => {
1142                let location_updated = update.fields.locations.is_some();
1143                call.update_fields(update.fields, languages, &self.terminals, cx)?;
1144                if location_updated {
1145                    self.resolve_locations(update.id, cx);
1146                }
1147            }
1148            ToolCallUpdate::UpdateDiff(update) => {
1149                call.content.clear();
1150                call.content.push(ToolCallContent::Diff(update.diff));
1151            }
1152            ToolCallUpdate::UpdateTerminal(update) => {
1153                call.content.clear();
1154                call.content
1155                    .push(ToolCallContent::Terminal(update.terminal));
1156            }
1157        }
1158
1159        cx.emit(AcpThreadEvent::EntryUpdated(ix));
1160
1161        Ok(())
1162    }
1163
1164    /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1165    pub fn upsert_tool_call(
1166        &mut self,
1167        tool_call: acp::ToolCall,
1168        cx: &mut Context<Self>,
1169    ) -> Result<(), acp::Error> {
1170        let status = tool_call.status.into();
1171        self.upsert_tool_call_inner(tool_call.into(), status, cx)
1172    }
1173
1174    /// Fails if id does not match an existing entry.
1175    pub fn upsert_tool_call_inner(
1176        &mut self,
1177        update: acp::ToolCallUpdate,
1178        status: ToolCallStatus,
1179        cx: &mut Context<Self>,
1180    ) -> Result<(), acp::Error> {
1181        let language_registry = self.project.read(cx).languages().clone();
1182        let id = update.id.clone();
1183
1184        if let Some(ix) = self.index_for_tool_call(&id) {
1185            let AgentThreadEntry::ToolCall(call) = &mut self.entries[ix] else {
1186                unreachable!()
1187            };
1188
1189            call.update_fields(update.fields, language_registry, &self.terminals, cx)?;
1190            call.status = status;
1191
1192            cx.emit(AcpThreadEvent::EntryUpdated(ix));
1193        } else {
1194            let call = ToolCall::from_acp(
1195                update.try_into()?,
1196                status,
1197                language_registry,
1198                &self.terminals,
1199                cx,
1200            )?;
1201            self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1202        };
1203
1204        self.resolve_locations(id, cx);
1205        Ok(())
1206    }
1207
1208    fn index_for_tool_call(&self, id: &acp::ToolCallId) -> Option<usize> {
1209        self.entries
1210            .iter()
1211            .enumerate()
1212            .rev()
1213            .find_map(|(index, entry)| {
1214                if let AgentThreadEntry::ToolCall(tool_call) = entry
1215                    && &tool_call.id == id
1216                {
1217                    Some(index)
1218                } else {
1219                    None
1220                }
1221            })
1222    }
1223
1224    fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1225        // The tool call we are looking for is typically the last one, or very close to the end.
1226        // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1227        self.entries
1228            .iter_mut()
1229            .enumerate()
1230            .rev()
1231            .find_map(|(index, tool_call)| {
1232                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1233                    && &tool_call.id == id
1234                {
1235                    Some((index, tool_call))
1236                } else {
1237                    None
1238                }
1239            })
1240    }
1241
1242    pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1243        self.entries
1244            .iter()
1245            .enumerate()
1246            .rev()
1247            .find_map(|(index, tool_call)| {
1248                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1249                    && &tool_call.id == id
1250                {
1251                    Some((index, tool_call))
1252                } else {
1253                    None
1254                }
1255            })
1256    }
1257
1258    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1259        let project = self.project.clone();
1260        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1261            return;
1262        };
1263        let task = tool_call.resolve_locations(project, cx);
1264        cx.spawn(async move |this, cx| {
1265            let resolved_locations = task.await;
1266            this.update(cx, |this, cx| {
1267                let project = this.project.clone();
1268                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1269                    return;
1270                };
1271                if let Some(Some(location)) = resolved_locations.last() {
1272                    project.update(cx, |project, cx| {
1273                        if let Some(agent_location) = project.agent_location() {
1274                            let should_ignore = agent_location.buffer == location.buffer
1275                                && location
1276                                    .buffer
1277                                    .update(cx, |buffer, _| {
1278                                        let snapshot = buffer.snapshot();
1279                                        let old_position =
1280                                            agent_location.position.to_point(&snapshot);
1281                                        let new_position = location.position.to_point(&snapshot);
1282                                        // ignore this so that when we get updates from the edit tool
1283                                        // the position doesn't reset to the startof line
1284                                        old_position.row == new_position.row
1285                                            && old_position.column > new_position.column
1286                                    })
1287                                    .ok()
1288                                    .unwrap_or_default();
1289                            if !should_ignore {
1290                                project.set_agent_location(Some(location.clone()), cx);
1291                            }
1292                        }
1293                    });
1294                }
1295                if tool_call.resolved_locations != resolved_locations {
1296                    tool_call.resolved_locations = resolved_locations;
1297                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
1298                }
1299            })
1300        })
1301        .detach();
1302    }
1303
1304    pub fn request_tool_call_authorization(
1305        &mut self,
1306        tool_call: acp::ToolCallUpdate,
1307        options: Vec<acp::PermissionOption>,
1308        cx: &mut Context<Self>,
1309    ) -> Result<BoxFuture<'static, acp::RequestPermissionOutcome>> {
1310        let (tx, rx) = oneshot::channel();
1311
1312        if AgentSettings::get_global(cx).always_allow_tool_actions {
1313            // Don't use AllowAlways, because then if you were to turn off always_allow_tool_actions,
1314            // some tools would (incorrectly) continue to auto-accept.
1315            if let Some(allow_once_option) = options.iter().find_map(|option| {
1316                if matches!(option.kind, acp::PermissionOptionKind::AllowOnce) {
1317                    Some(option.id.clone())
1318                } else {
1319                    None
1320                }
1321            }) {
1322                self.upsert_tool_call_inner(tool_call, ToolCallStatus::Pending, cx)?;
1323                return Ok(async {
1324                    acp::RequestPermissionOutcome::Selected {
1325                        option_id: allow_once_option,
1326                    }
1327                }
1328                .boxed());
1329            }
1330        }
1331
1332        let status = ToolCallStatus::WaitingForConfirmation {
1333            options,
1334            respond_tx: tx,
1335        };
1336
1337        self.upsert_tool_call_inner(tool_call, status, cx)?;
1338        cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1339
1340        let fut = async {
1341            match rx.await {
1342                Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
1343                Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1344            }
1345        }
1346        .boxed();
1347
1348        Ok(fut)
1349    }
1350
1351    pub fn authorize_tool_call(
1352        &mut self,
1353        id: acp::ToolCallId,
1354        option_id: acp::PermissionOptionId,
1355        option_kind: acp::PermissionOptionKind,
1356        cx: &mut Context<Self>,
1357    ) {
1358        let Some((ix, call)) = self.tool_call_mut(&id) else {
1359            return;
1360        };
1361
1362        let new_status = match option_kind {
1363            acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1364                ToolCallStatus::Rejected
1365            }
1366            acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1367                ToolCallStatus::InProgress
1368            }
1369        };
1370
1371        let curr_status = mem::replace(&mut call.status, new_status);
1372
1373        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1374            respond_tx.send(option_id).log_err();
1375        } else if cfg!(debug_assertions) {
1376            panic!("tried to authorize an already authorized tool call");
1377        }
1378
1379        cx.emit(AcpThreadEvent::EntryUpdated(ix));
1380    }
1381
1382    /// Returns true if the last turn is awaiting tool authorization
1383    pub fn waiting_for_tool_confirmation(&self) -> bool {
1384        for entry in self.entries.iter().rev() {
1385            match &entry {
1386                AgentThreadEntry::ToolCall(call) => match call.status {
1387                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
1388                    ToolCallStatus::Pending
1389                    | ToolCallStatus::InProgress
1390                    | ToolCallStatus::Completed
1391                    | ToolCallStatus::Failed
1392                    | ToolCallStatus::Rejected
1393                    | ToolCallStatus::Canceled => continue,
1394                },
1395                AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1396                    // Reached the beginning of the turn
1397                    return false;
1398                }
1399            }
1400        }
1401        false
1402    }
1403
1404    pub fn plan(&self) -> &Plan {
1405        &self.plan
1406    }
1407
1408    pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1409        let new_entries_len = request.entries.len();
1410        let mut new_entries = request.entries.into_iter();
1411
1412        // Reuse existing markdown to prevent flickering
1413        for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1414            let PlanEntry {
1415                content,
1416                priority,
1417                status,
1418            } = old;
1419            content.update(cx, |old, cx| {
1420                old.replace(new.content, cx);
1421            });
1422            *priority = new.priority;
1423            *status = new.status;
1424        }
1425        for new in new_entries {
1426            self.plan.entries.push(PlanEntry::from_acp(new, cx))
1427        }
1428        self.plan.entries.truncate(new_entries_len);
1429
1430        cx.notify();
1431    }
1432
1433    fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1434        self.plan
1435            .entries
1436            .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1437        cx.notify();
1438    }
1439
1440    #[cfg(any(test, feature = "test-support"))]
1441    pub fn send_raw(
1442        &mut self,
1443        message: &str,
1444        cx: &mut Context<Self>,
1445    ) -> BoxFuture<'static, Result<()>> {
1446        self.send(
1447            vec![acp::ContentBlock::Text(acp::TextContent {
1448                text: message.to_string(),
1449                annotations: None,
1450            })],
1451            cx,
1452        )
1453    }
1454
1455    pub fn send(
1456        &mut self,
1457        message: Vec<acp::ContentBlock>,
1458        cx: &mut Context<Self>,
1459    ) -> BoxFuture<'static, Result<()>> {
1460        let block = ContentBlock::new_combined(
1461            message.clone(),
1462            self.project.read(cx).languages().clone(),
1463            cx,
1464        );
1465        let request = acp::PromptRequest {
1466            prompt: message.clone(),
1467            session_id: self.session_id.clone(),
1468        };
1469        let git_store = self.project.read(cx).git_store().clone();
1470
1471        let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
1472            Some(UserMessageId::new())
1473        } else {
1474            None
1475        };
1476
1477        self.run_turn(cx, async move |this, cx| {
1478            this.update(cx, |this, cx| {
1479                this.push_entry(
1480                    AgentThreadEntry::UserMessage(UserMessage {
1481                        id: message_id.clone(),
1482                        content: block,
1483                        chunks: message,
1484                        checkpoint: None,
1485                    }),
1486                    cx,
1487                );
1488            })
1489            .ok();
1490
1491            let old_checkpoint = git_store
1492                .update(cx, |git, cx| git.checkpoint(cx))?
1493                .await
1494                .context("failed to get old checkpoint")
1495                .log_err();
1496            this.update(cx, |this, cx| {
1497                if let Some((_ix, message)) = this.last_user_message() {
1498                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
1499                        git_checkpoint,
1500                        show: false,
1501                    });
1502                }
1503                this.connection.prompt(message_id, request, cx)
1504            })?
1505            .await
1506        })
1507    }
1508
1509    pub fn can_resume(&self, cx: &App) -> bool {
1510        self.connection.resume(&self.session_id, cx).is_some()
1511    }
1512
1513    pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1514        self.run_turn(cx, async move |this, cx| {
1515            this.update(cx, |this, cx| {
1516                this.connection
1517                    .resume(&this.session_id, cx)
1518                    .map(|resume| resume.run(cx))
1519            })?
1520            .context("resuming a session is not supported")?
1521            .await
1522        })
1523    }
1524
    /// Runs one agent turn driven by `f` and returns a future for its outcome.
    ///
    /// Completed plan entries are cleared and any in-flight turn is canceled
    /// first; `f` only starts after that cancellation task finishes. The turn
    /// runs inside `send_task`, and its result is delivered over a oneshot
    /// channel to the returned future. That future then refreshes the last user
    /// message's checkpoint and clears the project's agent location. Finally it
    /// emits `AcpThreadEvent::Error` for `Ok(Err(_))` or `AcpThreadEvent::Stopped`
    /// otherwise. On a `Refusal` stop reason, the last user message and
    /// everything after it are removed.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<()>> {
        self.clear_completed_plan_entries(cx);

        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            // Let the previous turn finish canceling before starting this one.
            cancel_task.await;
            tx.send(f(this, cx).await).ok();
        }));

        cx.spawn(async move |this, cx| {
            // `Err(Canceled)` here means `tx` was dropped (with the send task)
            // before a result was sent.
            let response = rx.await;

            // NOTE(review): if this fails we return early without clearing
            // `send_task` or emitting `Stopped`/`Error` — confirm that's intended.
            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));
                match response {
                    Ok(Err(e)) => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        Err(e)
                    }
                    // Covers both a completed response (`Ok(Ok(_))`) and a
                    // dropped sender (`Err(Canceled)`).
                    result => {
                        let canceled = matches!(
                            result,
                            Ok(Ok(acp::PromptResponse {
                                stop_reason: acp::StopReason::Cancelled
                            }))
                        );

                        // We only take the task if the current prompt wasn't canceled.
                        //
                        // This prompt may have been canceled because another one was sent
                        // while it was still generating. In these cases, dropping `send_task`
                        // would cause the next generation to be canceled.
                        if !canceled {
                            this.send_task.take();
                        }

                        // Truncate entries if the last prompt was refused.
                        if let Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                        })) = result
                            && let Some((ix, _)) = this.last_user_message()
                        {
                            let range = ix..this.entries.len();
                            this.entries.truncate(ix);
                            cx.emit(AcpThreadEvent::EntriesRemoved(range));
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(())
                    }
                }
            })?
        })
        .boxed()
    }
1591
1592    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1593        let Some(send_task) = self.send_task.take() else {
1594            return Task::ready(());
1595        };
1596
1597        for entry in self.entries.iter_mut() {
1598            if let AgentThreadEntry::ToolCall(call) = entry {
1599                let cancel = matches!(
1600                    call.status,
1601                    ToolCallStatus::Pending
1602                        | ToolCallStatus::WaitingForConfirmation { .. }
1603                        | ToolCallStatus::InProgress
1604                );
1605
1606                if cancel {
1607                    call.status = ToolCallStatus::Canceled;
1608                }
1609            }
1610        }
1611
1612        self.connection.cancel(&self.session_id, cx);
1613
1614        // Wait for the send task to complete
1615        cx.foreground_executor().spawn(send_task)
1616    }
1617
1618    /// Rewinds this thread to before the entry at `index`, removing it and all
1619    /// subsequent entries while reverting any changes made from that point.
1620    pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1621        let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
1622            return Task::ready(Err(anyhow!("not supported")));
1623        };
1624        let Some(message) = self.user_message(&id) else {
1625            return Task::ready(Err(anyhow!("message not found")));
1626        };
1627
1628        let checkpoint = message
1629            .checkpoint
1630            .as_ref()
1631            .map(|c| c.git_checkpoint.clone());
1632
1633        let git_store = self.project.read(cx).git_store().clone();
1634        cx.spawn(async move |this, cx| {
1635            if let Some(checkpoint) = checkpoint {
1636                git_store
1637                    .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1638                    .await?;
1639            }
1640
1641            cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
1642            this.update(cx, |this, cx| {
1643                if let Some((ix, _)) = this.user_message_mut(&id) {
1644                    let range = ix..this.entries.len();
1645                    this.entries.truncate(ix);
1646                    cx.emit(AcpThreadEvent::EntriesRemoved(range));
1647                }
1648            })
1649        })
1650    }
1651
1652    fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1653        let git_store = self.project.read(cx).git_store().clone();
1654
1655        let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1656            if let Some(checkpoint) = message.checkpoint.as_ref() {
1657                checkpoint.git_checkpoint.clone()
1658            } else {
1659                return Task::ready(Ok(()));
1660            }
1661        } else {
1662            return Task::ready(Ok(()));
1663        };
1664
1665        let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1666        cx.spawn(async move |this, cx| {
1667            let new_checkpoint = new_checkpoint
1668                .await
1669                .context("failed to get new checkpoint")
1670                .log_err();
1671            if let Some(new_checkpoint) = new_checkpoint {
1672                let equal = git_store
1673                    .update(cx, |git, cx| {
1674                        git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1675                    })?
1676                    .await
1677                    .unwrap_or(true);
1678                this.update(cx, |this, cx| {
1679                    let (ix, message) = this.last_user_message().context("no user message")?;
1680                    let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1681                    checkpoint.show = !equal;
1682                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
1683                    anyhow::Ok(())
1684                })??;
1685            }
1686
1687            Ok(())
1688        })
1689    }
1690
1691    fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1692        self.entries
1693            .iter_mut()
1694            .enumerate()
1695            .rev()
1696            .find_map(|(ix, entry)| {
1697                if let AgentThreadEntry::UserMessage(message) = entry {
1698                    Some((ix, message))
1699                } else {
1700                    None
1701                }
1702            })
1703    }
1704
1705    fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
1706        self.entries.iter().find_map(|entry| {
1707            if let AgentThreadEntry::UserMessage(message) = entry {
1708                if message.id.as_ref() == Some(id) {
1709                    Some(message)
1710                } else {
1711                    None
1712                }
1713            } else {
1714                None
1715            }
1716        })
1717    }
1718
1719    fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1720        self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1721            if let AgentThreadEntry::UserMessage(message) = entry {
1722                if message.id.as_ref() == Some(id) {
1723                    Some((ix, message))
1724                } else {
1725                    None
1726                }
1727            } else {
1728                None
1729            }
1730        })
1731    }
1732
1733    pub fn read_text_file(
1734        &self,
1735        path: PathBuf,
1736        line: Option<u32>,
1737        limit: Option<u32>,
1738        reuse_shared_snapshot: bool,
1739        cx: &mut Context<Self>,
1740    ) -> Task<Result<String>> {
1741        let project = self.project.clone();
1742        let action_log = self.action_log.clone();
1743        cx.spawn(async move |this, cx| {
1744            let load = project.update(cx, |project, cx| {
1745                let path = project
1746                    .project_path_for_absolute_path(&path, cx)
1747                    .context("invalid path")?;
1748                anyhow::Ok(project.open_buffer(path, cx))
1749            });
1750            let buffer = load??.await?;
1751
1752            let snapshot = if reuse_shared_snapshot {
1753                this.read_with(cx, |this, _| {
1754                    this.shared_buffers.get(&buffer.clone()).cloned()
1755                })
1756                .log_err()
1757                .flatten()
1758            } else {
1759                None
1760            };
1761
1762            let snapshot = if let Some(snapshot) = snapshot {
1763                snapshot
1764            } else {
1765                action_log.update(cx, |action_log, cx| {
1766                    action_log.buffer_read(buffer.clone(), cx);
1767                })?;
1768                project.update(cx, |project, cx| {
1769                    let position = buffer
1770                        .read(cx)
1771                        .snapshot()
1772                        .anchor_before(Point::new(line.unwrap_or_default(), 0));
1773                    project.set_agent_location(
1774                        Some(AgentLocation {
1775                            buffer: buffer.downgrade(),
1776                            position,
1777                        }),
1778                        cx,
1779                    );
1780                })?;
1781
1782                buffer.update(cx, |buffer, _| buffer.snapshot())?
1783            };
1784
1785            this.update(cx, |this, _| {
1786                let text = snapshot.text();
1787                this.shared_buffers.insert(buffer.clone(), snapshot);
1788                if line.is_none() && limit.is_none() {
1789                    return Ok(text);
1790                }
1791                let limit = limit.unwrap_or(u32::MAX) as usize;
1792                let Some(line) = line else {
1793                    return Ok(text.lines().take(limit).collect::<String>());
1794                };
1795
1796                let count = text.lines().count();
1797                if count < line as usize {
1798                    anyhow::bail!("There are only {} lines", count);
1799                }
1800                Ok(text
1801                    .lines()
1802                    .skip(line as usize + 1)
1803                    .take(limit)
1804                    .collect::<String>())
1805            })?
1806        })
1807    }
1808
    /// Replaces the contents of the project file at the absolute `path` with
    /// `content`, then saves it.
    ///
    /// The new content is diffed against the snapshot this thread stored in
    /// `shared_buffers` (the last agent read), or against the current buffer
    /// snapshot if none is stored. The diff is applied to the live buffer as
    /// anchored edits, so changes made to the buffer since that snapshot are
    /// kept rather than overwritten.
    ///
    /// Side effects, in order:
    /// 1. Moves the agent location to the end of the last edit, or to
    ///    `Anchor::MIN` if there are no edits.
    /// 2. Records a read, applies the edits, and records an edit in the action log.
    /// 3. If the buffer's `format_on_save` setting is not `Off`, formats it with
    ///    `FormatTrigger::Save`. Formatting errors are logged and ignored. The
    ///    buffer is then marked as edited again.
    /// 4. Saves the buffer through the project.
    ///
    /// # Errors
    /// Fails with "invalid path" when `path` is not inside the project. Also
    /// fails when the buffer cannot be opened, when an entity update fails, or
    /// when saving fails.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Diff base: the snapshot the agent last read, if any.
            // NOTE(review): `shared_buffers` is not updated after this write, so a
            // second write without an intervening read diffs against the same older
            // snapshot. Confirm this is intended.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Compute the text diff off the main thread. Each replaced range is
            // anchored in the base snapshot so it still resolves after other edits.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            // Point the agent location at the end of the last edit.
            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::MIN),
                    }),
                    cx,
                );
            })?;

            // Apply the edits between a `buffer_read` and a `buffer_edited` call,
            // and read the format-on-save setting while the buffer is borrowed.
            let format_on_save = cx.update(|cx| {
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            })?;

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                })?;
                // Best effort: a formatting failure must not prevent the save.
                format_task.await.log_err();

                // Record the formatter's changes in the action log as well.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                })?;
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1905
1906    pub fn create_terminal(
1907        &self,
1908        mut command: String,
1909        args: Vec<String>,
1910        extra_env: Vec<acp::EnvVariable>,
1911        cwd: Option<PathBuf>,
1912        output_byte_limit: Option<u64>,
1913        cx: &mut Context<Self>,
1914    ) -> Task<Result<Entity<Terminal>>> {
1915        for arg in args {
1916            command.push(' ');
1917            command.push_str(&arg);
1918        }
1919
1920        let shell_command = if cfg!(windows) {
1921            format!("$null | & {{{}}}", command.replace("\"", "'"))
1922        } else if let Some(cwd) = cwd.as_ref().and_then(|cwd| cwd.as_os_str().to_str()) {
1923            // Make sure once we're *inside* the shell, we cd into `cwd`
1924            format!("(cd {cwd}; {}) </dev/null", command)
1925        } else {
1926            format!("({}) </dev/null", command)
1927        };
1928        let args = vec!["-c".into(), shell_command];
1929
1930        let env = match &cwd {
1931            Some(dir) => self.project.update(cx, |project, cx| {
1932                project.directory_environment(dir.as_path().into(), cx)
1933            }),
1934            None => Task::ready(None).shared(),
1935        };
1936
1937        let env = cx.spawn(async move |_, _| {
1938            let mut env = env.await.unwrap_or_default();
1939            if cfg!(unix) {
1940                env.insert("PAGER".into(), "cat".into());
1941            }
1942            for var in extra_env {
1943                env.insert(var.name, var.value);
1944            }
1945            env
1946        });
1947
1948        let project = self.project.clone();
1949        let language_registry = project.read(cx).languages().clone();
1950        let determine_shell = self.determine_shell.clone();
1951
1952        let terminal_id = acp::TerminalId(Uuid::new_v4().to_string().into());
1953        let terminal_task = cx.spawn({
1954            let terminal_id = terminal_id.clone();
1955            async move |_this, cx| {
1956                let program = determine_shell.await;
1957                let env = env.await;
1958                let terminal = project
1959                    .update(cx, |project, cx| {
1960                        project.create_terminal_task(
1961                            task::SpawnInTerminal {
1962                                command: Some(program),
1963                                args,
1964                                cwd: cwd.clone(),
1965                                env,
1966                                ..Default::default()
1967                            },
1968                            cx,
1969                        )
1970                    })?
1971                    .await?;
1972
1973                cx.new(|cx| {
1974                    Terminal::new(
1975                        terminal_id,
1976                        command,
1977                        cwd,
1978                        output_byte_limit.map(|l| l as usize),
1979                        terminal,
1980                        language_registry,
1981                        cx,
1982                    )
1983                })
1984            }
1985        });
1986
1987        cx.spawn(async move |this, cx| {
1988            let terminal = terminal_task.await?;
1989            this.update(cx, |this, _cx| {
1990                this.terminals.insert(terminal_id, terminal.clone());
1991                terminal
1992            })
1993        })
1994    }
1995
1996    pub fn kill_terminal(
1997        &mut self,
1998        terminal_id: acp::TerminalId,
1999        cx: &mut Context<Self>,
2000    ) -> Result<()> {
2001        self.terminals
2002            .get(&terminal_id)
2003            .context("Terminal not found")?
2004            .update(cx, |terminal, cx| {
2005                terminal.kill(cx);
2006            });
2007
2008        Ok(())
2009    }
2010
2011    pub fn release_terminal(
2012        &mut self,
2013        terminal_id: acp::TerminalId,
2014        cx: &mut Context<Self>,
2015    ) -> Result<()> {
2016        self.terminals
2017            .remove(&terminal_id)
2018            .context("Terminal not found")?
2019            .update(cx, |terminal, cx| {
2020                terminal.kill(cx);
2021            });
2022
2023        Ok(())
2024    }
2025
2026    pub fn terminal(&self, terminal_id: acp::TerminalId) -> Result<Entity<Terminal>> {
2027        self.terminals
2028            .get(&terminal_id)
2029            .context("Terminal not found")
2030            .cloned()
2031    }
2032
2033    pub fn to_markdown(&self, cx: &App) -> String {
2034        self.entries.iter().map(|e| e.to_markdown(cx)).collect()
2035    }
2036
2037    pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
2038        cx.emit(AcpThreadEvent::LoadError(error));
2039    }
2040}
2041
2042fn markdown_for_raw_output(
2043    raw_output: &serde_json::Value,
2044    language_registry: &Arc<LanguageRegistry>,
2045    cx: &mut App,
2046) -> Option<Entity<Markdown>> {
2047    match raw_output {
2048        serde_json::Value::Null => None,
2049        serde_json::Value::Bool(value) => Some(cx.new(|cx| {
2050            Markdown::new(
2051                value.to_string().into(),
2052                Some(language_registry.clone()),
2053                None,
2054                cx,
2055            )
2056        })),
2057        serde_json::Value::Number(value) => Some(cx.new(|cx| {
2058            Markdown::new(
2059                value.to_string().into(),
2060                Some(language_registry.clone()),
2061                None,
2062                cx,
2063            )
2064        })),
2065        serde_json::Value::String(value) => Some(cx.new(|cx| {
2066            Markdown::new(
2067                value.clone().into(),
2068                Some(language_registry.clone()),
2069                None,
2070                cx,
2071            )
2072        })),
2073        value => Some(cx.new(|cx| {
2074            Markdown::new(
2075                format!("```json\n{}\n```", value).into(),
2076                Some(language_registry.clone()),
2077                None,
2078                cx,
2079            )
2080        })),
2081    }
2082}
2083
2084#[cfg(test)]
2085mod tests {
2086    use super::*;
2087    use anyhow::anyhow;
2088    use futures::{channel::mpsc, future::LocalBoxFuture, select};
2089    use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
2090    use indoc::indoc;
2091    use project::{FakeFs, Fs};
2092    use rand::Rng as _;
2093    use serde_json::json;
2094    use settings::SettingsStore;
2095    use smol::stream::StreamExt as _;
2096    use std::{
2097        any::Any,
2098        cell::RefCell,
2099        path::Path,
2100        rc::Rc,
2101        sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
2102        time::Duration,
2103    };
2104    use util::path;
2105
2106    fn init_test(cx: &mut TestAppContext) {
2107        env_logger::try_init().ok();
2108        cx.update(|cx| {
2109            let settings_store = SettingsStore::test(cx);
2110            cx.set_global(settings_store);
2111            Project::init_settings(cx);
2112            language::init(cx);
2113        });
2114    }
2115
2116    #[gpui::test]
2117    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
2118        init_test(cx);
2119
2120        let fs = FakeFs::new(cx.executor());
2121        let project = Project::test(fs, [], cx).await;
2122        let connection = Rc::new(FakeAgentConnection::new());
2123        let thread = cx
2124            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2125            .await
2126            .unwrap();
2127
2128        // Test creating a new user message
2129        thread.update(cx, |thread, cx| {
2130            thread.push_user_content_block(
2131                None,
2132                acp::ContentBlock::Text(acp::TextContent {
2133                    annotations: None,
2134                    text: "Hello, ".to_string(),
2135                }),
2136                cx,
2137            );
2138        });
2139
2140        thread.update(cx, |thread, cx| {
2141            assert_eq!(thread.entries.len(), 1);
2142            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2143                assert_eq!(user_msg.id, None);
2144                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
2145            } else {
2146                panic!("Expected UserMessage");
2147            }
2148        });
2149
2150        // Test appending to existing user message
2151        let message_1_id = UserMessageId::new();
2152        thread.update(cx, |thread, cx| {
2153            thread.push_user_content_block(
2154                Some(message_1_id.clone()),
2155                acp::ContentBlock::Text(acp::TextContent {
2156                    annotations: None,
2157                    text: "world!".to_string(),
2158                }),
2159                cx,
2160            );
2161        });
2162
2163        thread.update(cx, |thread, cx| {
2164            assert_eq!(thread.entries.len(), 1);
2165            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
2166                assert_eq!(user_msg.id, Some(message_1_id));
2167                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
2168            } else {
2169                panic!("Expected UserMessage");
2170            }
2171        });
2172
2173        // Test creating new user message after assistant message
2174        thread.update(cx, |thread, cx| {
2175            thread.push_assistant_content_block(
2176                acp::ContentBlock::Text(acp::TextContent {
2177                    annotations: None,
2178                    text: "Assistant response".to_string(),
2179                }),
2180                false,
2181                cx,
2182            );
2183        });
2184
2185        let message_2_id = UserMessageId::new();
2186        thread.update(cx, |thread, cx| {
2187            thread.push_user_content_block(
2188                Some(message_2_id.clone()),
2189                acp::ContentBlock::Text(acp::TextContent {
2190                    annotations: None,
2191                    text: "New user message".to_string(),
2192                }),
2193                cx,
2194            );
2195        });
2196
2197        thread.update(cx, |thread, cx| {
2198            assert_eq!(thread.entries.len(), 3);
2199            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
2200                assert_eq!(user_msg.id, Some(message_2_id));
2201                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
2202            } else {
2203                panic!("Expected UserMessage at index 2");
2204            }
2205        });
2206    }
2207
2208    #[gpui::test]
2209    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
2210        init_test(cx);
2211
2212        let fs = FakeFs::new(cx.executor());
2213        let project = Project::test(fs, [], cx).await;
2214        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2215            |_, thread, mut cx| {
2216                async move {
2217                    thread.update(&mut cx, |thread, cx| {
2218                        thread
2219                            .handle_session_update(
2220                                acp::SessionUpdate::AgentThoughtChunk {
2221                                    content: "Thinking ".into(),
2222                                },
2223                                cx,
2224                            )
2225                            .unwrap();
2226                        thread
2227                            .handle_session_update(
2228                                acp::SessionUpdate::AgentThoughtChunk {
2229                                    content: "hard!".into(),
2230                                },
2231                                cx,
2232                            )
2233                            .unwrap();
2234                    })?;
2235                    Ok(acp::PromptResponse {
2236                        stop_reason: acp::StopReason::EndTurn,
2237                    })
2238                }
2239                .boxed_local()
2240            },
2241        ));
2242
2243        let thread = cx
2244            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2245            .await
2246            .unwrap();
2247
2248        thread
2249            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
2250            .await
2251            .unwrap();
2252
2253        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
2254        assert_eq!(
2255            output,
2256            indoc! {r#"
2257            ## User
2258
2259            Hello from Zed!
2260
2261            ## Assistant
2262
2263            <thinking>
2264            Thinking hard!
2265            </thinking>
2266
2267            "#}
2268        );
2269    }
2270
2271    #[gpui::test]
2272    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
2273        init_test(cx);
2274
2275        let fs = FakeFs::new(cx.executor());
2276        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
2277            .await;
2278        let project = Project::test(fs.clone(), [], cx).await;
2279        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
2280        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
2281        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2282            move |_, thread, mut cx| {
2283                let read_file_tx = read_file_tx.clone();
2284                async move {
2285                    let content = thread
2286                        .update(&mut cx, |thread, cx| {
2287                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2288                        })
2289                        .unwrap()
2290                        .await
2291                        .unwrap();
2292                    assert_eq!(content, "one\ntwo\nthree\n");
2293                    read_file_tx.take().unwrap().send(()).unwrap();
2294                    thread
2295                        .update(&mut cx, |thread, cx| {
2296                            thread.write_text_file(
2297                                path!("/tmp/foo").into(),
2298                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
2299                                cx,
2300                            )
2301                        })
2302                        .unwrap()
2303                        .await
2304                        .unwrap();
2305                    Ok(acp::PromptResponse {
2306                        stop_reason: acp::StopReason::EndTurn,
2307                    })
2308                }
2309                .boxed_local()
2310            },
2311        ));
2312
2313        let (worktree, pathbuf) = project
2314            .update(cx, |project, cx| {
2315                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2316            })
2317            .await
2318            .unwrap();
2319        let buffer = project
2320            .update(cx, |project, cx| {
2321                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
2322            })
2323            .await
2324            .unwrap();
2325
2326        let thread = cx
2327            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2328            .await
2329            .unwrap();
2330
2331        let request = thread.update(cx, |thread, cx| {
2332            thread.send_raw("Extend the count in /tmp/foo", cx)
2333        });
2334        read_file_rx.await.ok();
2335        buffer.update(cx, |buffer, cx| {
2336            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
2337        });
2338        cx.run_until_parked();
2339        assert_eq!(
2340            buffer.read_with(cx, |buffer, _| buffer.text()),
2341            "zero\none\ntwo\nthree\nfour\nfive\n"
2342        );
2343        assert_eq!(
2344            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
2345            "zero\none\ntwo\nthree\nfour\nfive\n"
2346        );
2347        request.await.unwrap();
2348    }
2349
2350    #[gpui::test]
2351    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
2352        init_test(cx);
2353
2354        let fs = FakeFs::new(cx.executor());
2355        let project = Project::test(fs, [], cx).await;
2356        let id = acp::ToolCallId("test".into());
2357
2358        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2359            let id = id.clone();
2360            move |_, thread, mut cx| {
2361                let id = id.clone();
2362                async move {
2363                    thread
2364                        .update(&mut cx, |thread, cx| {
2365                            thread.handle_session_update(
2366                                acp::SessionUpdate::ToolCall(acp::ToolCall {
2367                                    id: id.clone(),
2368                                    title: "Label".into(),
2369                                    kind: acp::ToolKind::Fetch,
2370                                    status: acp::ToolCallStatus::InProgress,
2371                                    content: vec![],
2372                                    locations: vec![],
2373                                    raw_input: None,
2374                                    raw_output: None,
2375                                }),
2376                                cx,
2377                            )
2378                        })
2379                        .unwrap()
2380                        .unwrap();
2381                    Ok(acp::PromptResponse {
2382                        stop_reason: acp::StopReason::EndTurn,
2383                    })
2384                }
2385                .boxed_local()
2386            }
2387        }));
2388
2389        let thread = cx
2390            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2391            .await
2392            .unwrap();
2393
2394        let request = thread.update(cx, |thread, cx| {
2395            thread.send_raw("Fetch https://example.com", cx)
2396        });
2397
2398        run_until_first_tool_call(&thread, cx).await;
2399
2400        thread.read_with(cx, |thread, _| {
2401            assert!(matches!(
2402                thread.entries[1],
2403                AgentThreadEntry::ToolCall(ToolCall {
2404                    status: ToolCallStatus::InProgress,
2405                    ..
2406                })
2407            ));
2408        });
2409
2410        thread.update(cx, |thread, cx| thread.cancel(cx)).await;
2411
2412        thread.read_with(cx, |thread, _| {
2413            assert!(matches!(
2414                &thread.entries[1],
2415                AgentThreadEntry::ToolCall(ToolCall {
2416                    status: ToolCallStatus::Canceled,
2417                    ..
2418                })
2419            ));
2420        });
2421
2422        thread
2423            .update(cx, |thread, cx| {
2424                thread.handle_session_update(
2425                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
2426                        id,
2427                        fields: acp::ToolCallUpdateFields {
2428                            status: Some(acp::ToolCallStatus::Completed),
2429                            ..Default::default()
2430                        },
2431                    }),
2432                    cx,
2433                )
2434            })
2435            .unwrap();
2436
2437        request.await.unwrap();
2438
2439        thread.read_with(cx, |thread, _| {
2440            assert!(matches!(
2441                thread.entries[1],
2442                AgentThreadEntry::ToolCall(ToolCall {
2443                    status: ToolCallStatus::Completed,
2444                    ..
2445                })
2446            ));
2447        });
2448    }
2449
2450    #[gpui::test]
2451    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
2452        init_test(cx);
2453        let fs = FakeFs::new(cx.background_executor.clone());
2454        fs.insert_tree(path!("/test"), json!({})).await;
2455        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
2456
2457        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2458            move |_, thread, mut cx| {
2459                async move {
2460                    thread
2461                        .update(&mut cx, |thread, cx| {
2462                            thread.handle_session_update(
2463                                acp::SessionUpdate::ToolCall(acp::ToolCall {
2464                                    id: acp::ToolCallId("test".into()),
2465                                    title: "Label".into(),
2466                                    kind: acp::ToolKind::Edit,
2467                                    status: acp::ToolCallStatus::Completed,
2468                                    content: vec![acp::ToolCallContent::Diff {
2469                                        diff: acp::Diff {
2470                                            path: "/test/test.txt".into(),
2471                                            old_text: None,
2472                                            new_text: "foo".into(),
2473                                        },
2474                                    }],
2475                                    locations: vec![],
2476                                    raw_input: None,
2477                                    raw_output: None,
2478                                }),
2479                                cx,
2480                            )
2481                        })
2482                        .unwrap()
2483                        .unwrap();
2484                    Ok(acp::PromptResponse {
2485                        stop_reason: acp::StopReason::EndTurn,
2486                    })
2487                }
2488                .boxed_local()
2489            }
2490        }));
2491
2492        let thread = cx
2493            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2494            .await
2495            .unwrap();
2496
2497        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
2498            .await
2499            .unwrap();
2500
2501        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
2502    }
2503
2504    #[gpui::test(iterations = 10)]
2505    async fn test_checkpoints(cx: &mut TestAppContext) {
2506        init_test(cx);
2507        let fs = FakeFs::new(cx.background_executor.clone());
2508        fs.insert_tree(
2509            path!("/test"),
2510            json!({
2511                ".git": {}
2512            }),
2513        )
2514        .await;
2515        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
2516
2517        let simulate_changes = Arc::new(AtomicBool::new(true));
2518        let next_filename = Arc::new(AtomicUsize::new(0));
2519        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2520            let simulate_changes = simulate_changes.clone();
2521            let next_filename = next_filename.clone();
2522            let fs = fs.clone();
2523            move |request, thread, mut cx| {
2524                let fs = fs.clone();
2525                let simulate_changes = simulate_changes.clone();
2526                let next_filename = next_filename.clone();
2527                async move {
2528                    if simulate_changes.load(SeqCst) {
2529                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
2530                        fs.write(Path::new(&filename), b"").await?;
2531                    }
2532
2533                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2534                        panic!("expected text content block");
2535                    };
2536                    thread.update(&mut cx, |thread, cx| {
2537                        thread
2538                            .handle_session_update(
2539                                acp::SessionUpdate::AgentMessageChunk {
2540                                    content: content.text.to_uppercase().into(),
2541                                },
2542                                cx,
2543                            )
2544                            .unwrap();
2545                    })?;
2546                    Ok(acp::PromptResponse {
2547                        stop_reason: acp::StopReason::EndTurn,
2548                    })
2549                }
2550                .boxed_local()
2551            }
2552        }));
2553        let thread = cx
2554            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2555            .await
2556            .unwrap();
2557
2558        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
2559            .await
2560            .unwrap();
2561        thread.read_with(cx, |thread, cx| {
2562            assert_eq!(
2563                thread.to_markdown(cx),
2564                indoc! {"
2565                    ## User (checkpoint)
2566
2567                    Lorem
2568
2569                    ## Assistant
2570
2571                    LOREM
2572
2573                "}
2574            );
2575        });
2576        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2577
2578        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
2579            .await
2580            .unwrap();
2581        thread.read_with(cx, |thread, cx| {
2582            assert_eq!(
2583                thread.to_markdown(cx),
2584                indoc! {"
2585                    ## User (checkpoint)
2586
2587                    Lorem
2588
2589                    ## Assistant
2590
2591                    LOREM
2592
2593                    ## User (checkpoint)
2594
2595                    ipsum
2596
2597                    ## Assistant
2598
2599                    IPSUM
2600
2601                "}
2602            );
2603        });
2604        assert_eq!(
2605            fs.files(),
2606            vec![
2607                Path::new(path!("/test/file-0")),
2608                Path::new(path!("/test/file-1"))
2609            ]
2610        );
2611
2612        // Checkpoint isn't stored when there are no changes.
2613        simulate_changes.store(false, SeqCst);
2614        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
2615            .await
2616            .unwrap();
2617        thread.read_with(cx, |thread, cx| {
2618            assert_eq!(
2619                thread.to_markdown(cx),
2620                indoc! {"
2621                    ## User (checkpoint)
2622
2623                    Lorem
2624
2625                    ## Assistant
2626
2627                    LOREM
2628
2629                    ## User (checkpoint)
2630
2631                    ipsum
2632
2633                    ## Assistant
2634
2635                    IPSUM
2636
2637                    ## User
2638
2639                    dolor
2640
2641                    ## Assistant
2642
2643                    DOLOR
2644
2645                "}
2646            );
2647        });
2648        assert_eq!(
2649            fs.files(),
2650            vec![
2651                Path::new(path!("/test/file-0")),
2652                Path::new(path!("/test/file-1"))
2653            ]
2654        );
2655
2656        // Rewinding the conversation truncates the history and restores the checkpoint.
2657        thread
2658            .update(cx, |thread, cx| {
2659                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
2660                    panic!("unexpected entries {:?}", thread.entries)
2661                };
2662                thread.rewind(message.id.clone().unwrap(), cx)
2663            })
2664            .await
2665            .unwrap();
2666        thread.read_with(cx, |thread, cx| {
2667            assert_eq!(
2668                thread.to_markdown(cx),
2669                indoc! {"
2670                    ## User (checkpoint)
2671
2672                    Lorem
2673
2674                    ## Assistant
2675
2676                    LOREM
2677
2678                "}
2679            );
2680        });
2681        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2682    }
2683
2684    #[gpui::test]
2685    async fn test_refusal(cx: &mut TestAppContext) {
2686        init_test(cx);
2687        let fs = FakeFs::new(cx.background_executor.clone());
2688        fs.insert_tree(path!("/"), json!({})).await;
2689        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
2690
2691        let refuse_next = Arc::new(AtomicBool::new(false));
2692        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2693            let refuse_next = refuse_next.clone();
2694            move |request, thread, mut cx| {
2695                let refuse_next = refuse_next.clone();
2696                async move {
2697                    if refuse_next.load(SeqCst) {
2698                        return Ok(acp::PromptResponse {
2699                            stop_reason: acp::StopReason::Refusal,
2700                        });
2701                    }
2702
2703                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2704                        panic!("expected text content block");
2705                    };
2706                    thread.update(&mut cx, |thread, cx| {
2707                        thread
2708                            .handle_session_update(
2709                                acp::SessionUpdate::AgentMessageChunk {
2710                                    content: content.text.to_uppercase().into(),
2711                                },
2712                                cx,
2713                            )
2714                            .unwrap();
2715                    })?;
2716                    Ok(acp::PromptResponse {
2717                        stop_reason: acp::StopReason::EndTurn,
2718                    })
2719                }
2720                .boxed_local()
2721            }
2722        }));
2723        let thread = cx
2724            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2725            .await
2726            .unwrap();
2727
2728        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
2729            .await
2730            .unwrap();
2731        thread.read_with(cx, |thread, cx| {
2732            assert_eq!(
2733                thread.to_markdown(cx),
2734                indoc! {"
2735                    ## User
2736
2737                    hello
2738
2739                    ## Assistant
2740
2741                    HELLO
2742
2743                "}
2744            );
2745        });
2746
2747        // Simulate refusing the second message, ensuring the conversation gets
2748        // truncated to before sending it.
2749        refuse_next.store(true, SeqCst);
2750        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
2751            .await
2752            .unwrap();
2753        thread.read_with(cx, |thread, cx| {
2754            assert_eq!(
2755                thread.to_markdown(cx),
2756                indoc! {"
2757                    ## User
2758
2759                    hello
2760
2761                    ## Assistant
2762
2763                    HELLO
2764
2765                "}
2766            );
2767        });
2768    }
2769
2770    async fn run_until_first_tool_call(
2771        thread: &Entity<AcpThread>,
2772        cx: &mut TestAppContext,
2773    ) -> usize {
2774        let (mut tx, mut rx) = mpsc::channel::<usize>(1);
2775
2776        let subscription = cx.update(|cx| {
2777            cx.subscribe(thread, move |thread, _, cx| {
2778                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
2779                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
2780                        return tx.try_send(ix).unwrap();
2781                    }
2782                }
2783            })
2784        });
2785
2786        select! {
2787            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
2788                panic!("Timeout waiting for tool call")
2789            }
2790            ix = rx.next().fuse() => {
2791                drop(subscription);
2792                ix.unwrap()
2793            }
2794        }
2795    }
2796
2797    #[derive(Clone, Default)]
2798    struct FakeAgentConnection {
2799        auth_methods: Vec<acp::AuthMethod>,
2800        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
2801        on_user_message: Option<
2802            Rc<
2803                dyn Fn(
2804                        acp::PromptRequest,
2805                        WeakEntity<AcpThread>,
2806                        AsyncApp,
2807                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2808                    + 'static,
2809            >,
2810        >,
2811    }
2812
2813    impl FakeAgentConnection {
2814        fn new() -> Self {
2815            Self {
2816                auth_methods: Vec::new(),
2817                on_user_message: None,
2818                sessions: Arc::default(),
2819            }
2820        }
2821
2822        #[expect(unused)]
2823        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
2824            self.auth_methods = auth_methods;
2825            self
2826        }
2827
2828        fn on_user_message(
2829            mut self,
2830            handler: impl Fn(
2831                acp::PromptRequest,
2832                WeakEntity<AcpThread>,
2833                AsyncApp,
2834            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2835            + 'static,
2836        ) -> Self {
2837            self.on_user_message.replace(Rc::new(handler));
2838            self
2839        }
2840    }
2841
2842    impl AgentConnection for FakeAgentConnection {
2843        fn auth_methods(&self) -> &[acp::AuthMethod] {
2844            &self.auth_methods
2845        }
2846
2847        fn new_thread(
2848            self: Rc<Self>,
2849            project: Entity<Project>,
2850            _cwd: &Path,
2851            cx: &mut App,
2852        ) -> Task<gpui::Result<Entity<AcpThread>>> {
2853            let session_id = acp::SessionId(
2854                rand::thread_rng()
2855                    .sample_iter(&rand::distributions::Alphanumeric)
2856                    .take(7)
2857                    .map(char::from)
2858                    .collect::<String>()
2859                    .into(),
2860            );
2861            let action_log = cx.new(|_| ActionLog::new(project.clone()));
2862            let thread = cx.new(|cx| {
2863                AcpThread::new(
2864                    "Test",
2865                    self.clone(),
2866                    project,
2867                    action_log,
2868                    session_id.clone(),
2869                    watch::Receiver::constant(acp::PromptCapabilities {
2870                        image: true,
2871                        audio: true,
2872                        embedded_context: true,
2873                    }),
2874                    vec![],
2875                    cx,
2876                )
2877            });
2878            self.sessions.lock().insert(session_id, thread.downgrade());
2879            Task::ready(Ok(thread))
2880        }
2881
2882        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
2883            if self.auth_methods().iter().any(|m| m.id == method) {
2884                Task::ready(Ok(()))
2885            } else {
2886                Task::ready(Err(anyhow!("Invalid Auth Method")))
2887            }
2888        }
2889
2890        fn prompt(
2891            &self,
2892            _id: Option<UserMessageId>,
2893            params: acp::PromptRequest,
2894            cx: &mut App,
2895        ) -> Task<gpui::Result<acp::PromptResponse>> {
2896            let sessions = self.sessions.lock();
2897            let thread = sessions.get(&params.session_id).unwrap();
2898            if let Some(handler) = &self.on_user_message {
2899                let handler = handler.clone();
2900                let thread = thread.clone();
2901                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
2902            } else {
2903                Task::ready(Ok(acp::PromptResponse {
2904                    stop_reason: acp::StopReason::EndTurn,
2905                }))
2906            }
2907        }
2908
2909        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
2910            let sessions = self.sessions.lock();
2911            let thread = sessions.get(session_id).unwrap().clone();
2912
2913            cx.spawn(async move |cx| {
2914                thread
2915                    .update(cx, |thread, cx| thread.cancel(cx))
2916                    .unwrap()
2917                    .await
2918            })
2919            .detach();
2920        }
2921
2922        fn truncate(
2923            &self,
2924            session_id: &acp::SessionId,
2925            _cx: &App,
2926        ) -> Option<Rc<dyn AgentSessionTruncate>> {
2927            Some(Rc::new(FakeAgentSessionEditor {
2928                _session_id: session_id.clone(),
2929            }))
2930        }
2931
2932        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
2933            self
2934        }
2935    }
2936
    /// Truncation handle handed out by `FakeAgentConnection::truncate`.
    struct FakeAgentSessionEditor {
        // The session this handle was created for; stored but never read
        // (hence the leading underscore).
        _session_id: acp::SessionId,
    }

    impl AgentSessionTruncate for FakeAgentSessionEditor {
        /// Accepts any truncation request and completes immediately without
        /// doing any work.
        fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
            Task::ready(Ok(()))
        }
    }
2946}