acp_thread.rs

   1mod connection;
   2mod diff;
   3mod mention;
   4mod terminal;
   5
   6use agent_settings::AgentSettings;
   7use collections::HashSet;
   8pub use connection::*;
   9pub use diff::*;
  10use language::language_settings::FormatOnSave;
  11pub use mention::*;
  12use project::lsp_store::{FormatTrigger, LspFormatTarget};
  13use serde::{Deserialize, Serialize};
  14use settings::Settings as _;
  15pub use terminal::*;
  16
  17use action_log::ActionLog;
  18use agent_client_protocol as acp;
  19use anyhow::{Context as _, Result, anyhow};
  20use editor::Bias;
  21use futures::{FutureExt, channel::oneshot, future::BoxFuture};
  22use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  23use itertools::Itertools;
  24use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
  25use markdown::Markdown;
  26use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
  27use std::collections::HashMap;
  28use std::error::Error;
  29use std::fmt::{Formatter, Write};
  30use std::ops::Range;
  31use std::process::ExitStatus;
  32use std::rc::Rc;
  33use std::time::{Duration, Instant};
  34use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
  35use ui::App;
  36use util::ResultExt;
  37
  38#[derive(Debug)]
  39pub struct UserMessage {
  40    pub id: Option<UserMessageId>,
  41    pub content: ContentBlock,
  42    pub chunks: Vec<acp::ContentBlock>,
  43    pub checkpoint: Option<Checkpoint>,
  44}
  45
  46#[derive(Debug)]
  47pub struct Checkpoint {
  48    git_checkpoint: GitStoreCheckpoint,
  49    pub show: bool,
  50}
  51
  52impl UserMessage {
  53    fn to_markdown(&self, cx: &App) -> String {
  54        let mut markdown = String::new();
  55        if self
  56            .checkpoint
  57            .as_ref()
  58            .is_some_and(|checkpoint| checkpoint.show)
  59        {
  60            writeln!(markdown, "## User (checkpoint)").unwrap();
  61        } else {
  62            writeln!(markdown, "## User").unwrap();
  63        }
  64        writeln!(markdown).unwrap();
  65        writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
  66        writeln!(markdown).unwrap();
  67        markdown
  68    }
  69}
  70
  71#[derive(Debug, PartialEq)]
  72pub struct AssistantMessage {
  73    pub chunks: Vec<AssistantMessageChunk>,
  74}
  75
  76impl AssistantMessage {
  77    pub fn to_markdown(&self, cx: &App) -> String {
  78        format!(
  79            "## Assistant\n\n{}\n\n",
  80            self.chunks
  81                .iter()
  82                .map(|chunk| chunk.to_markdown(cx))
  83                .join("\n\n")
  84        )
  85    }
  86}
  87
  88#[derive(Debug, PartialEq)]
  89pub enum AssistantMessageChunk {
  90    Message { block: ContentBlock },
  91    Thought { block: ContentBlock },
  92}
  93
  94impl AssistantMessageChunk {
  95    pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
  96        Self::Message {
  97            block: ContentBlock::new(chunk.into(), language_registry, cx),
  98        }
  99    }
 100
 101    fn to_markdown(&self, cx: &App) -> String {
 102        match self {
 103            Self::Message { block } => block.to_markdown(cx).to_string(),
 104            Self::Thought { block } => {
 105                format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
 106            }
 107        }
 108    }
 109}
 110
 111#[derive(Debug)]
 112pub enum AgentThreadEntry {
 113    UserMessage(UserMessage),
 114    AssistantMessage(AssistantMessage),
 115    ToolCall(ToolCall),
 116}
 117
 118impl AgentThreadEntry {
 119    pub fn to_markdown(&self, cx: &App) -> String {
 120        match self {
 121            Self::UserMessage(message) => message.to_markdown(cx),
 122            Self::AssistantMessage(message) => message.to_markdown(cx),
 123            Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
 124        }
 125    }
 126
 127    pub fn user_message(&self) -> Option<&UserMessage> {
 128        if let AgentThreadEntry::UserMessage(message) = self {
 129            Some(message)
 130        } else {
 131            None
 132        }
 133    }
 134
 135    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 136        if let AgentThreadEntry::ToolCall(call) = self {
 137            itertools::Either::Left(call.diffs())
 138        } else {
 139            itertools::Either::Right(std::iter::empty())
 140        }
 141    }
 142
 143    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 144        if let AgentThreadEntry::ToolCall(call) = self {
 145            itertools::Either::Left(call.terminals())
 146        } else {
 147            itertools::Either::Right(std::iter::empty())
 148        }
 149    }
 150
 151    pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
 152        if let AgentThreadEntry::ToolCall(ToolCall {
 153            locations,
 154            resolved_locations,
 155            ..
 156        }) = self
 157        {
 158            Some((
 159                locations.get(ix)?.clone(),
 160                resolved_locations.get(ix)?.clone()?,
 161            ))
 162        } else {
 163            None
 164        }
 165    }
 166}
 167
/// A tool invocation requested by the agent, as shown in the thread.
#[derive(Debug)]
pub struct ToolCall {
    /// ACP identifier used to match later updates to this call.
    pub id: acp::ToolCallId,
    /// Single-line title rendered as markdown.
    pub label: Entity<Markdown>,
    pub kind: acp::ToolKind,
    /// Output items (content blocks, diffs, terminals), in order.
    pub content: Vec<ToolCallContent>,
    pub status: ToolCallStatus,
    /// Locations as reported over ACP.
    pub locations: Vec<acp::ToolCallLocation>,
    /// Project-resolved counterparts of `locations`, looked up by the same
    /// index; `None` where a location could not be resolved.
    pub resolved_locations: Vec<Option<AgentLocation>>,
    /// Raw JSON input of the call, if provided.
    pub raw_input: Option<serde_json::Value>,
    /// Raw JSON output of the call, if provided.
    pub raw_output: Option<serde_json::Value>,
}
 180
 181impl ToolCall {
 182    fn from_acp(
 183        tool_call: acp::ToolCall,
 184        status: ToolCallStatus,
 185        language_registry: Arc<LanguageRegistry>,
 186        cx: &mut App,
 187    ) -> Self {
 188        let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
 189            first_line.to_owned() + ""
 190        } else {
 191            tool_call.title
 192        };
 193        Self {
 194            id: tool_call.id,
 195            label: cx
 196                .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
 197            kind: tool_call.kind,
 198            content: tool_call
 199                .content
 200                .into_iter()
 201                .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
 202                .collect(),
 203            locations: tool_call.locations,
 204            resolved_locations: Vec::default(),
 205            status,
 206            raw_input: tool_call.raw_input,
 207            raw_output: tool_call.raw_output,
 208        }
 209    }
 210
 211    fn update_fields(
 212        &mut self,
 213        fields: acp::ToolCallUpdateFields,
 214        language_registry: Arc<LanguageRegistry>,
 215        cx: &mut App,
 216    ) {
 217        let acp::ToolCallUpdateFields {
 218            kind,
 219            status,
 220            title,
 221            content,
 222            locations,
 223            raw_input,
 224            raw_output,
 225        } = fields;
 226
 227        if let Some(kind) = kind {
 228            self.kind = kind;
 229        }
 230
 231        if let Some(status) = status {
 232            self.status = status.into();
 233        }
 234
 235        if let Some(title) = title {
 236            self.label.update(cx, |label, cx| {
 237                if let Some((first_line, _)) = title.split_once("\n") {
 238                    label.replace(first_line.to_owned() + "", cx)
 239                } else {
 240                    label.replace(title, cx);
 241                }
 242            });
 243        }
 244
 245        if let Some(content) = content {
 246            let new_content_len = content.len();
 247            let mut content = content.into_iter();
 248
 249            // Reuse existing content if we can
 250            for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
 251                old.update_from_acp(new, language_registry.clone(), cx);
 252            }
 253            for new in content {
 254                self.content.push(ToolCallContent::from_acp(
 255                    new,
 256                    language_registry.clone(),
 257                    cx,
 258                ))
 259            }
 260            self.content.truncate(new_content_len);
 261        }
 262
 263        if let Some(locations) = locations {
 264            self.locations = locations;
 265        }
 266
 267        if let Some(raw_input) = raw_input {
 268            self.raw_input = Some(raw_input);
 269        }
 270
 271        if let Some(raw_output) = raw_output {
 272            if self.content.is_empty()
 273                && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
 274            {
 275                self.content
 276                    .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
 277                        markdown,
 278                    }));
 279            }
 280            self.raw_output = Some(raw_output);
 281        }
 282    }
 283
 284    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 285        self.content.iter().filter_map(|content| match content {
 286            ToolCallContent::Diff(diff) => Some(diff),
 287            ToolCallContent::ContentBlock(_) => None,
 288            ToolCallContent::Terminal(_) => None,
 289        })
 290    }
 291
 292    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 293        self.content.iter().filter_map(|content| match content {
 294            ToolCallContent::Terminal(terminal) => Some(terminal),
 295            ToolCallContent::ContentBlock(_) => None,
 296            ToolCallContent::Diff(_) => None,
 297        })
 298    }
 299
 300    fn to_markdown(&self, cx: &App) -> String {
 301        let mut markdown = format!(
 302            "**Tool Call: {}**\nStatus: {}\n\n",
 303            self.label.read(cx).source(),
 304            self.status
 305        );
 306        for content in &self.content {
 307            markdown.push_str(content.to_markdown(cx).as_str());
 308            markdown.push_str("\n\n");
 309        }
 310        markdown
 311    }
 312
 313    async fn resolve_location(
 314        location: acp::ToolCallLocation,
 315        project: WeakEntity<Project>,
 316        cx: &mut AsyncApp,
 317    ) -> Option<AgentLocation> {
 318        let buffer = project
 319            .update(cx, |project, cx| {
 320                project
 321                    .project_path_for_absolute_path(&location.path, cx)
 322                    .map(|path| project.open_buffer(path, cx))
 323            })
 324            .ok()??;
 325        let buffer = buffer.await.log_err()?;
 326        let position = buffer
 327            .update(cx, |buffer, _| {
 328                if let Some(row) = location.line {
 329                    let snapshot = buffer.snapshot();
 330                    let column = snapshot.indent_size_for_line(row).len;
 331                    let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
 332                    snapshot.anchor_before(point)
 333                } else {
 334                    Anchor::MIN
 335                }
 336            })
 337            .ok()?;
 338
 339        Some(AgentLocation {
 340            buffer: buffer.downgrade(),
 341            position,
 342        })
 343    }
 344
 345    fn resolve_locations(
 346        &self,
 347        project: Entity<Project>,
 348        cx: &mut App,
 349    ) -> Task<Vec<Option<AgentLocation>>> {
 350        let locations = self.locations.clone();
 351        project.update(cx, |_, cx| {
 352            cx.spawn(async move |project, cx| {
 353                let mut new_locations = Vec::new();
 354                for location in locations {
 355                    new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
 356                }
 357                new_locations
 358            })
 359        })
 360    }
 361}
 362
 363#[derive(Debug)]
 364pub enum ToolCallStatus {
 365    /// The tool call hasn't started running yet, but we start showing it to
 366    /// the user.
 367    Pending,
 368    /// The tool call is waiting for confirmation from the user.
 369    WaitingForConfirmation {
 370        options: Vec<acp::PermissionOption>,
 371        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
 372    },
 373    /// The tool call is currently running.
 374    InProgress,
 375    /// The tool call completed successfully.
 376    Completed,
 377    /// The tool call failed.
 378    Failed,
 379    /// The user rejected the tool call.
 380    Rejected,
 381    /// The user canceled generation so the tool call was canceled.
 382    Canceled,
 383}
 384
 385impl From<acp::ToolCallStatus> for ToolCallStatus {
 386    fn from(status: acp::ToolCallStatus) -> Self {
 387        match status {
 388            acp::ToolCallStatus::Pending => Self::Pending,
 389            acp::ToolCallStatus::InProgress => Self::InProgress,
 390            acp::ToolCallStatus::Completed => Self::Completed,
 391            acp::ToolCallStatus::Failed => Self::Failed,
 392        }
 393    }
 394}
 395
 396impl Display for ToolCallStatus {
 397    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 398        write!(
 399            f,
 400            "{}",
 401            match self {
 402                ToolCallStatus::Pending => "Pending",
 403                ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
 404                ToolCallStatus::InProgress => "In Progress",
 405                ToolCallStatus::Completed => "Completed",
 406                ToolCallStatus::Failed => "Failed",
 407                ToolCallStatus::Rejected => "Rejected",
 408                ToolCallStatus::Canceled => "Canceled",
 409            }
 410        )
 411    }
 412}
 413
 414#[derive(Debug, PartialEq, Clone)]
 415pub enum ContentBlock {
 416    Empty,
 417    Markdown { markdown: Entity<Markdown> },
 418    ResourceLink { resource_link: acp::ResourceLink },
 419}
 420
 421impl ContentBlock {
 422    pub fn new(
 423        block: acp::ContentBlock,
 424        language_registry: &Arc<LanguageRegistry>,
 425        cx: &mut App,
 426    ) -> Self {
 427        let mut this = Self::Empty;
 428        this.append(block, language_registry, cx);
 429        this
 430    }
 431
 432    pub fn new_combined(
 433        blocks: impl IntoIterator<Item = acp::ContentBlock>,
 434        language_registry: Arc<LanguageRegistry>,
 435        cx: &mut App,
 436    ) -> Self {
 437        let mut this = Self::Empty;
 438        for block in blocks {
 439            this.append(block, &language_registry, cx);
 440        }
 441        this
 442    }
 443
 444    pub fn append(
 445        &mut self,
 446        block: acp::ContentBlock,
 447        language_registry: &Arc<LanguageRegistry>,
 448        cx: &mut App,
 449    ) {
 450        if matches!(self, ContentBlock::Empty)
 451            && let acp::ContentBlock::ResourceLink(resource_link) = block
 452        {
 453            *self = ContentBlock::ResourceLink { resource_link };
 454            return;
 455        }
 456
 457        let new_content = self.block_string_contents(block);
 458
 459        match self {
 460            ContentBlock::Empty => {
 461                *self = Self::create_markdown_block(new_content, language_registry, cx);
 462            }
 463            ContentBlock::Markdown { markdown } => {
 464                markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
 465            }
 466            ContentBlock::ResourceLink { resource_link } => {
 467                let existing_content = Self::resource_link_md(&resource_link.uri);
 468                let combined = format!("{}\n{}", existing_content, new_content);
 469
 470                *self = Self::create_markdown_block(combined, language_registry, cx);
 471            }
 472        }
 473    }
 474
 475    fn create_markdown_block(
 476        content: String,
 477        language_registry: &Arc<LanguageRegistry>,
 478        cx: &mut App,
 479    ) -> ContentBlock {
 480        ContentBlock::Markdown {
 481            markdown: cx
 482                .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
 483        }
 484    }
 485
 486    fn block_string_contents(&self, block: acp::ContentBlock) -> String {
 487        match block {
 488            acp::ContentBlock::Text(text_content) => text_content.text,
 489            acp::ContentBlock::ResourceLink(resource_link) => {
 490                Self::resource_link_md(&resource_link.uri)
 491            }
 492            acp::ContentBlock::Resource(acp::EmbeddedResource {
 493                resource:
 494                    acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
 495                        uri,
 496                        ..
 497                    }),
 498                ..
 499            }) => Self::resource_link_md(&uri),
 500            acp::ContentBlock::Image(image) => Self::image_md(&image),
 501            acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
 502        }
 503    }
 504
 505    fn resource_link_md(uri: &str) -> String {
 506        if let Some(uri) = MentionUri::parse(uri).log_err() {
 507            uri.as_link().to_string()
 508        } else {
 509            uri.to_string()
 510        }
 511    }
 512
 513    fn image_md(_image: &acp::ImageContent) -> String {
 514        "`Image`".into()
 515    }
 516
 517    pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
 518        match self {
 519            ContentBlock::Empty => "",
 520            ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
 521            ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
 522        }
 523    }
 524
 525    pub fn markdown(&self) -> Option<&Entity<Markdown>> {
 526        match self {
 527            ContentBlock::Empty => None,
 528            ContentBlock::Markdown { markdown } => Some(markdown),
 529            ContentBlock::ResourceLink { .. } => None,
 530        }
 531    }
 532
 533    pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
 534        match self {
 535            ContentBlock::ResourceLink { resource_link } => Some(resource_link),
 536            _ => None,
 537        }
 538    }
 539}
 540
 541#[derive(Debug)]
 542pub enum ToolCallContent {
 543    ContentBlock(ContentBlock),
 544    Diff(Entity<Diff>),
 545    Terminal(Entity<Terminal>),
 546}
 547
 548impl ToolCallContent {
 549    pub fn from_acp(
 550        content: acp::ToolCallContent,
 551        language_registry: Arc<LanguageRegistry>,
 552        cx: &mut App,
 553    ) -> Self {
 554        match content {
 555            acp::ToolCallContent::Content { content } => {
 556                Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
 557            }
 558            acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
 559                Diff::finalized(
 560                    diff.path,
 561                    diff.old_text,
 562                    diff.new_text,
 563                    language_registry,
 564                    cx,
 565                )
 566            })),
 567        }
 568    }
 569
 570    pub fn update_from_acp(
 571        &mut self,
 572        new: acp::ToolCallContent,
 573        language_registry: Arc<LanguageRegistry>,
 574        cx: &mut App,
 575    ) {
 576        let needs_update = match (&self, &new) {
 577            (Self::Diff(old_diff), acp::ToolCallContent::Diff { diff: new_diff }) => {
 578                old_diff.read(cx).needs_update(
 579                    new_diff.old_text.as_deref().unwrap_or(""),
 580                    &new_diff.new_text,
 581                    cx,
 582                )
 583            }
 584            _ => true,
 585        };
 586
 587        if needs_update {
 588            *self = Self::from_acp(new, language_registry, cx);
 589        }
 590    }
 591
 592    pub fn to_markdown(&self, cx: &App) -> String {
 593        match self {
 594            Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
 595            Self::Diff(diff) => diff.read(cx).to_markdown(cx),
 596            Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
 597        }
 598    }
 599}
 600
 601#[derive(Debug, PartialEq)]
 602pub enum ToolCallUpdate {
 603    UpdateFields(acp::ToolCallUpdate),
 604    UpdateDiff(ToolCallUpdateDiff),
 605    UpdateTerminal(ToolCallUpdateTerminal),
 606}
 607
 608impl ToolCallUpdate {
 609    fn id(&self) -> &acp::ToolCallId {
 610        match self {
 611            Self::UpdateFields(update) => &update.id,
 612            Self::UpdateDiff(diff) => &diff.id,
 613            Self::UpdateTerminal(terminal) => &terminal.id,
 614        }
 615    }
 616}
 617
 618impl From<acp::ToolCallUpdate> for ToolCallUpdate {
 619    fn from(update: acp::ToolCallUpdate) -> Self {
 620        Self::UpdateFields(update)
 621    }
 622}
 623
 624impl From<ToolCallUpdateDiff> for ToolCallUpdate {
 625    fn from(diff: ToolCallUpdateDiff) -> Self {
 626        Self::UpdateDiff(diff)
 627    }
 628}
 629
 630#[derive(Debug, PartialEq)]
 631pub struct ToolCallUpdateDiff {
 632    pub id: acp::ToolCallId,
 633    pub diff: Entity<Diff>,
 634}
 635
 636impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
 637    fn from(terminal: ToolCallUpdateTerminal) -> Self {
 638        Self::UpdateTerminal(terminal)
 639    }
 640}
 641
 642#[derive(Debug, PartialEq)]
 643pub struct ToolCallUpdateTerminal {
 644    pub id: acp::ToolCallId,
 645    pub terminal: Entity<Terminal>,
 646}
 647
 648#[derive(Debug, Default)]
 649pub struct Plan {
 650    pub entries: Vec<PlanEntry>,
 651}
 652
 653#[derive(Debug)]
 654pub struct PlanStats<'a> {
 655    pub in_progress_entry: Option<&'a PlanEntry>,
 656    pub pending: u32,
 657    pub completed: u32,
 658}
 659
 660impl Plan {
 661    pub fn is_empty(&self) -> bool {
 662        self.entries.is_empty()
 663    }
 664
 665    pub fn stats(&self) -> PlanStats<'_> {
 666        let mut stats = PlanStats {
 667            in_progress_entry: None,
 668            pending: 0,
 669            completed: 0,
 670        };
 671
 672        for entry in &self.entries {
 673            match &entry.status {
 674                acp::PlanEntryStatus::Pending => {
 675                    stats.pending += 1;
 676                }
 677                acp::PlanEntryStatus::InProgress => {
 678                    stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
 679                }
 680                acp::PlanEntryStatus::Completed => {
 681                    stats.completed += 1;
 682                }
 683            }
 684        }
 685
 686        stats
 687    }
 688}
 689
 690#[derive(Debug)]
 691pub struct PlanEntry {
 692    pub content: Entity<Markdown>,
 693    pub priority: acp::PlanEntryPriority,
 694    pub status: acp::PlanEntryStatus,
 695}
 696
 697impl PlanEntry {
 698    pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
 699        Self {
 700            content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
 701            priority: entry.priority,
 702            status: entry.status,
 703        }
 704    }
 705}
 706
/// Token accounting for a thread, classified by [`TokenUsage::ratio`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Context-window size; `0` means the maximum is unknown (no model selected).
    pub max_tokens: u64,
    /// Tokens used so far.
    pub used_tokens: u64,
}
 712
 713impl TokenUsage {
 714    pub fn ratio(&self) -> TokenUsageRatio {
 715        #[cfg(debug_assertions)]
 716        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 717            .unwrap_or("0.8".to_string())
 718            .parse()
 719            .unwrap();
 720        #[cfg(not(debug_assertions))]
 721        let warning_threshold: f32 = 0.8;
 722
 723        // When the maximum is unknown because there is no selected model,
 724        // avoid showing the token limit warning.
 725        if self.max_tokens == 0 {
 726            TokenUsageRatio::Normal
 727        } else if self.used_tokens >= self.max_tokens {
 728            TokenUsageRatio::Exceeded
 729        } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
 730            TokenUsageRatio::Warning
 731        } else {
 732            TokenUsageRatio::Normal
 733        }
 734    }
 735}
 736
/// Classification of token usage returned by `TokenUsage::ratio`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenUsageRatio {
    /// Below the warning threshold, or the maximum is unknown (`max_tokens == 0`).
    Normal,
    /// At or above the warning threshold but still below the maximum.
    Warning,
    /// `used_tokens >= max_tokens`.
    Exceeded,
}

/// Details of a retry, carried by `AcpThreadEvent::Retry`.
// NOTE(review): field meanings below are inferred from names; confirm where Retry is emitted.
#[derive(Debug, Clone)]
pub struct RetryStatus {
    /// Presumably the message of the error that triggered the retry.
    pub last_error: SharedString,
    /// Presumably the current attempt number — confirm whether 0- or 1-based.
    pub attempt: usize,
    pub max_attempts: usize,
    pub started_at: Instant,
    pub duration: Duration,
}
 752
/// A conversation thread with an agent over an ACP connection.
///
/// Holds the ordered thread entries together with the plan, token usage and
/// prompt capabilities reported for the session.
pub struct AcpThread {
    /// Thread title, returned by `title()`.
    title: SharedString,
    /// User messages, assistant messages and tool calls, in order.
    entries: Vec<AgentThreadEntry>,
    /// Updated from `SessionUpdate::Plan` via `update_plan`.
    plan: Plan,
    project: Entity<Project>,
    action_log: Entity<ActionLog>,
    // NOTE(review): presumably snapshots of buffers shared with the agent; confirm against usage outside this chunk.
    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
    /// In-flight turn; `status()` reports `ThreadStatus::Idle` whenever this is `None`.
    send_task: Option<Task<()>>,
    connection: Rc<dyn AgentConnection>,
    session_id: acp::SessionId,
    /// Latest token usage, if any has been recorded.
    token_usage: Option<TokenUsage>,
    /// Current prompt capabilities, refreshed by `_observe_prompt_capabilities`.
    prompt_capabilities: acp::PromptCapabilities,
    /// Observer task spawned in `new` that copies each received capabilities
    /// value into `prompt_capabilities`; owned by the thread (presumably
    /// dropping it stops the observer — confirm gpui `Task` semantics).
    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
}

/// Events emitted by `AcpThread`.
#[derive(Debug)]
pub enum AcpThreadEvent {
    NewEntry,
    TitleUpdated,
    TokenUsageUpdated,
    /// The entry at this index was changed in place.
    EntryUpdated(usize),
    /// Presumably a range of entry indices that were removed — confirm at emit site.
    EntriesRemoved(Range<usize>),
    ToolAuthorizationRequired,
    Retry(RetryStatus),
    Stopped,
    Error,
    LoadError(LoadError),
    /// Emitted by the observer task in `AcpThread::new` after it stores new
    /// prompt capabilities.
    PromptCapabilitiesUpdated,
}

// Enables `cx.emit(AcpThreadEvent::...)` from `AcpThread` contexts.
impl EventEmitter<AcpThreadEvent> for AcpThread {}

/// Coarse thread state computed by `AcpThread::status`.
#[derive(PartialEq, Eq, Debug)]
pub enum ThreadStatus {
    /// No turn is in flight (`send_task` is `None`).
    Idle,
    /// A turn is in flight and a tool call awaits the user's confirmation.
    WaitingForToolConfirmation,
    /// A turn is in flight and nothing is awaiting confirmation.
    Generating,
}
 791
 792#[derive(Debug, Clone)]
 793pub enum LoadError {
 794    Unsupported {
 795        command: SharedString,
 796        current_version: SharedString,
 797        minimum_version: SharedString,
 798    },
 799    FailedToInstall(SharedString),
 800    Exited {
 801        status: ExitStatus,
 802    },
 803    Other(SharedString),
 804}
 805
 806impl Display for LoadError {
 807    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 808        match self {
 809            LoadError::Unsupported {
 810                command: path,
 811                current_version,
 812                minimum_version,
 813            } => {
 814                write!(
 815                    f,
 816                    "version {current_version} from {path} is not supported (need at least {minimum_version})"
 817                )
 818            }
 819            LoadError::FailedToInstall(msg) => write!(f, "Failed to install: {msg}"),
 820            LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
 821            LoadError::Other(msg) => write!(f, "{msg}"),
 822        }
 823    }
 824}
 825
 826impl Error for LoadError {}
 827
 828impl AcpThread {
 829    pub fn new(
 830        title: impl Into<SharedString>,
 831        connection: Rc<dyn AgentConnection>,
 832        project: Entity<Project>,
 833        action_log: Entity<ActionLog>,
 834        session_id: acp::SessionId,
 835        mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
 836        cx: &mut Context<Self>,
 837    ) -> Self {
 838        let prompt_capabilities = *prompt_capabilities_rx.borrow();
 839        let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
 840            loop {
 841                let caps = prompt_capabilities_rx.recv().await?;
 842                this.update(cx, |this, cx| {
 843                    this.prompt_capabilities = caps;
 844                    cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
 845                })?;
 846            }
 847        });
 848
 849        Self {
 850            action_log,
 851            shared_buffers: Default::default(),
 852            entries: Default::default(),
 853            plan: Default::default(),
 854            title: title.into(),
 855            project,
 856            send_task: None,
 857            connection,
 858            session_id,
 859            token_usage: None,
 860            prompt_capabilities,
 861            _observe_prompt_capabilities: task,
 862        }
 863    }
 864
    /// Returns the most recently received prompt capabilities for this session.
    pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
        self.prompt_capabilities
    }
 868
    /// Returns the agent connection this thread was created with.
    pub fn connection(&self) -> &Rc<dyn AgentConnection> {
        &self.connection
    }
 872
    /// Returns the action log entity passed to `new`.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
 876
    /// Returns the project this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
 880
    /// Returns a clone of the thread title.
    pub fn title(&self) -> SharedString {
        self.title.clone()
    }
 884
    /// Returns the thread entries in order.
    pub fn entries(&self) -> &[AgentThreadEntry] {
        &self.entries
    }
 888
    /// Returns the ACP session id this thread belongs to.
    pub fn session_id(&self) -> &acp::SessionId {
        &self.session_id
    }
 892
 893    pub fn status(&self) -> ThreadStatus {
 894        if self.send_task.is_some() {
 895            if self.waiting_for_tool_confirmation() {
 896                ThreadStatus::WaitingForToolConfirmation
 897            } else {
 898                ThreadStatus::Generating
 899            }
 900        } else {
 901            ThreadStatus::Idle
 902        }
 903    }
 904
    /// Returns the latest recorded token usage, if any.
    pub fn token_usage(&self) -> Option<&TokenUsage> {
        self.token_usage.as_ref()
    }
 908
 909    pub fn has_pending_edit_tool_calls(&self) -> bool {
 910        for entry in self.entries.iter().rev() {
 911            match entry {
 912                AgentThreadEntry::UserMessage(_) => return false,
 913                AgentThreadEntry::ToolCall(
 914                    call @ ToolCall {
 915                        status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
 916                        ..
 917                    },
 918                ) if call.diffs().next().is_some() => {
 919                    return true;
 920                }
 921                AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
 922            }
 923        }
 924
 925        false
 926    }
 927
 928    pub fn used_tools_since_last_user_message(&self) -> bool {
 929        for entry in self.entries.iter().rev() {
 930            match entry {
 931                AgentThreadEntry::UserMessage(..) => return false,
 932                AgentThreadEntry::AssistantMessage(..) => continue,
 933                AgentThreadEntry::ToolCall(..) => return true,
 934            }
 935        }
 936
 937        false
 938    }
 939
 940    pub fn handle_session_update(
 941        &mut self,
 942        update: acp::SessionUpdate,
 943        cx: &mut Context<Self>,
 944    ) -> Result<(), acp::Error> {
 945        match update {
 946            acp::SessionUpdate::UserMessageChunk { content } => {
 947                self.push_user_content_block(None, content, cx);
 948            }
 949            acp::SessionUpdate::AgentMessageChunk { content } => {
 950                self.push_assistant_content_block(content, false, cx);
 951            }
 952            acp::SessionUpdate::AgentThoughtChunk { content } => {
 953                self.push_assistant_content_block(content, true, cx);
 954            }
 955            acp::SessionUpdate::ToolCall(tool_call) => {
 956                self.upsert_tool_call(tool_call, cx)?;
 957            }
 958            acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
 959                self.update_tool_call(tool_call_update, cx)?;
 960            }
 961            acp::SessionUpdate::Plan(plan) => {
 962                self.update_plan(plan, cx);
 963            }
 964        }
 965        Ok(())
 966    }
 967
 968    pub fn push_user_content_block(
 969        &mut self,
 970        message_id: Option<UserMessageId>,
 971        chunk: acp::ContentBlock,
 972        cx: &mut Context<Self>,
 973    ) {
 974        let language_registry = self.project.read(cx).languages().clone();
 975        let entries_len = self.entries.len();
 976
 977        if let Some(last_entry) = self.entries.last_mut()
 978            && let AgentThreadEntry::UserMessage(UserMessage {
 979                id,
 980                content,
 981                chunks,
 982                ..
 983            }) = last_entry
 984        {
 985            *id = message_id.or(id.take());
 986            content.append(chunk.clone(), &language_registry, cx);
 987            chunks.push(chunk);
 988            let idx = entries_len - 1;
 989            cx.emit(AcpThreadEvent::EntryUpdated(idx));
 990        } else {
 991            let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
 992            self.push_entry(
 993                AgentThreadEntry::UserMessage(UserMessage {
 994                    id: message_id,
 995                    content,
 996                    chunks: vec![chunk],
 997                    checkpoint: None,
 998                }),
 999                cx,
1000            );
1001        }
1002    }
1003
1004    pub fn push_assistant_content_block(
1005        &mut self,
1006        chunk: acp::ContentBlock,
1007        is_thought: bool,
1008        cx: &mut Context<Self>,
1009    ) {
1010        let language_registry = self.project.read(cx).languages().clone();
1011        let entries_len = self.entries.len();
1012        if let Some(last_entry) = self.entries.last_mut()
1013            && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
1014        {
1015            let idx = entries_len - 1;
1016            cx.emit(AcpThreadEvent::EntryUpdated(idx));
1017            match (chunks.last_mut(), is_thought) {
1018                (Some(AssistantMessageChunk::Message { block }), false)
1019                | (Some(AssistantMessageChunk::Thought { block }), true) => {
1020                    block.append(chunk, &language_registry, cx)
1021                }
1022                _ => {
1023                    let block = ContentBlock::new(chunk, &language_registry, cx);
1024                    if is_thought {
1025                        chunks.push(AssistantMessageChunk::Thought { block })
1026                    } else {
1027                        chunks.push(AssistantMessageChunk::Message { block })
1028                    }
1029                }
1030            }
1031        } else {
1032            let block = ContentBlock::new(chunk, &language_registry, cx);
1033            let chunk = if is_thought {
1034                AssistantMessageChunk::Thought { block }
1035            } else {
1036                AssistantMessageChunk::Message { block }
1037            };
1038
1039            self.push_entry(
1040                AgentThreadEntry::AssistantMessage(AssistantMessage {
1041                    chunks: vec![chunk],
1042                }),
1043                cx,
1044            );
1045        }
1046    }
1047
1048    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
1049        self.entries.push(entry);
1050        cx.emit(AcpThreadEvent::NewEntry);
1051    }
1052
1053    pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1054        self.connection.set_title(&self.session_id, cx).is_some()
1055    }
1056
1057    pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1058        if title != self.title {
1059            self.title = title.clone();
1060            cx.emit(AcpThreadEvent::TitleUpdated);
1061            if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1062                return set_title.run(title, cx);
1063            }
1064        }
1065        Task::ready(Ok(()))
1066    }
1067
1068    pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
1069        self.token_usage = usage;
1070        cx.emit(AcpThreadEvent::TokenUsageUpdated);
1071    }
1072
1073    pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
1074        cx.emit(AcpThreadEvent::Retry(status));
1075    }
1076
1077    pub fn update_tool_call(
1078        &mut self,
1079        update: impl Into<ToolCallUpdate>,
1080        cx: &mut Context<Self>,
1081    ) -> Result<()> {
1082        let update = update.into();
1083        let languages = self.project.read(cx).languages().clone();
1084
1085        let (ix, current_call) = self
1086            .tool_call_mut(update.id())
1087            .context("Tool call not found")?;
1088        match update {
1089            ToolCallUpdate::UpdateFields(update) => {
1090                let location_updated = update.fields.locations.is_some();
1091                current_call.update_fields(update.fields, languages, cx);
1092                if location_updated {
1093                    self.resolve_locations(update.id, cx);
1094                }
1095            }
1096            ToolCallUpdate::UpdateDiff(update) => {
1097                current_call.content.clear();
1098                current_call
1099                    .content
1100                    .push(ToolCallContent::Diff(update.diff));
1101            }
1102            ToolCallUpdate::UpdateTerminal(update) => {
1103                current_call.content.clear();
1104                current_call
1105                    .content
1106                    .push(ToolCallContent::Terminal(update.terminal));
1107            }
1108        }
1109
1110        cx.emit(AcpThreadEvent::EntryUpdated(ix));
1111
1112        Ok(())
1113    }
1114
1115    /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1116    pub fn upsert_tool_call(
1117        &mut self,
1118        tool_call: acp::ToolCall,
1119        cx: &mut Context<Self>,
1120    ) -> Result<(), acp::Error> {
1121        let status = tool_call.status.into();
1122        self.upsert_tool_call_inner(tool_call.into(), status, cx)
1123    }
1124
1125    /// Fails if id does not match an existing entry.
1126    pub fn upsert_tool_call_inner(
1127        &mut self,
1128        tool_call_update: acp::ToolCallUpdate,
1129        status: ToolCallStatus,
1130        cx: &mut Context<Self>,
1131    ) -> Result<(), acp::Error> {
1132        let language_registry = self.project.read(cx).languages().clone();
1133        let id = tool_call_update.id.clone();
1134
1135        if let Some((ix, current_call)) = self.tool_call_mut(&id) {
1136            current_call.update_fields(tool_call_update.fields, language_registry, cx);
1137            current_call.status = status;
1138
1139            cx.emit(AcpThreadEvent::EntryUpdated(ix));
1140        } else {
1141            let call =
1142                ToolCall::from_acp(tool_call_update.try_into()?, status, language_registry, cx);
1143            self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1144        };
1145
1146        self.resolve_locations(id, cx);
1147        Ok(())
1148    }
1149
1150    fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1151        // The tool call we are looking for is typically the last one, or very close to the end.
1152        // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1153        self.entries
1154            .iter_mut()
1155            .enumerate()
1156            .rev()
1157            .find_map(|(index, tool_call)| {
1158                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1159                    && &tool_call.id == id
1160                {
1161                    Some((index, tool_call))
1162                } else {
1163                    None
1164                }
1165            })
1166    }
1167
1168    pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1169        self.entries
1170            .iter()
1171            .enumerate()
1172            .rev()
1173            .find_map(|(index, tool_call)| {
1174                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1175                    && &tool_call.id == id
1176                {
1177                    Some((index, tool_call))
1178                } else {
1179                    None
1180                }
1181            })
1182    }
1183
1184    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1185        let project = self.project.clone();
1186        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1187            return;
1188        };
1189        let task = tool_call.resolve_locations(project, cx);
1190        cx.spawn(async move |this, cx| {
1191            let resolved_locations = task.await;
1192            this.update(cx, |this, cx| {
1193                let project = this.project.clone();
1194                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1195                    return;
1196                };
1197                if let Some(Some(location)) = resolved_locations.last() {
1198                    project.update(cx, |project, cx| {
1199                        if let Some(agent_location) = project.agent_location() {
1200                            let should_ignore = agent_location.buffer == location.buffer
1201                                && location
1202                                    .buffer
1203                                    .update(cx, |buffer, _| {
1204                                        let snapshot = buffer.snapshot();
1205                                        let old_position =
1206                                            agent_location.position.to_point(&snapshot);
1207                                        let new_position = location.position.to_point(&snapshot);
1208                                        // ignore this so that when we get updates from the edit tool
1209                                        // the position doesn't reset to the startof line
1210                                        old_position.row == new_position.row
1211                                            && old_position.column > new_position.column
1212                                    })
1213                                    .ok()
1214                                    .unwrap_or_default();
1215                            if !should_ignore {
1216                                project.set_agent_location(Some(location.clone()), cx);
1217                            }
1218                        }
1219                    });
1220                }
1221                if tool_call.resolved_locations != resolved_locations {
1222                    tool_call.resolved_locations = resolved_locations;
1223                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
1224                }
1225            })
1226        })
1227        .detach();
1228    }
1229
1230    pub fn request_tool_call_authorization(
1231        &mut self,
1232        tool_call: acp::ToolCallUpdate,
1233        options: Vec<acp::PermissionOption>,
1234        cx: &mut Context<Self>,
1235    ) -> Result<BoxFuture<'static, acp::RequestPermissionOutcome>> {
1236        let (tx, rx) = oneshot::channel();
1237
1238        if AgentSettings::get_global(cx).always_allow_tool_actions {
1239            // Don't use AllowAlways, because then if you were to turn off always_allow_tool_actions,
1240            // some tools would (incorrectly) continue to auto-accept.
1241            if let Some(allow_once_option) = options.iter().find_map(|option| {
1242                if matches!(option.kind, acp::PermissionOptionKind::AllowOnce) {
1243                    Some(option.id.clone())
1244                } else {
1245                    None
1246                }
1247            }) {
1248                self.upsert_tool_call_inner(tool_call, ToolCallStatus::Pending, cx)?;
1249                return Ok(async {
1250                    acp::RequestPermissionOutcome::Selected {
1251                        option_id: allow_once_option,
1252                    }
1253                }
1254                .boxed());
1255            }
1256        }
1257
1258        let status = ToolCallStatus::WaitingForConfirmation {
1259            options,
1260            respond_tx: tx,
1261        };
1262
1263        self.upsert_tool_call_inner(tool_call, status, cx)?;
1264        cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1265
1266        let fut = async {
1267            match rx.await {
1268                Ok(option) => acp::RequestPermissionOutcome::Selected { option_id: option },
1269                Err(oneshot::Canceled) => acp::RequestPermissionOutcome::Cancelled,
1270            }
1271        }
1272        .boxed();
1273
1274        Ok(fut)
1275    }
1276
1277    pub fn authorize_tool_call(
1278        &mut self,
1279        id: acp::ToolCallId,
1280        option_id: acp::PermissionOptionId,
1281        option_kind: acp::PermissionOptionKind,
1282        cx: &mut Context<Self>,
1283    ) {
1284        let Some((ix, call)) = self.tool_call_mut(&id) else {
1285            return;
1286        };
1287
1288        let new_status = match option_kind {
1289            acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1290                ToolCallStatus::Rejected
1291            }
1292            acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1293                ToolCallStatus::InProgress
1294            }
1295        };
1296
1297        let curr_status = mem::replace(&mut call.status, new_status);
1298
1299        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1300            respond_tx.send(option_id).log_err();
1301        } else if cfg!(debug_assertions) {
1302            panic!("tried to authorize an already authorized tool call");
1303        }
1304
1305        cx.emit(AcpThreadEvent::EntryUpdated(ix));
1306    }
1307
1308    /// Returns true if the last turn is awaiting tool authorization
1309    pub fn waiting_for_tool_confirmation(&self) -> bool {
1310        for entry in self.entries.iter().rev() {
1311            match &entry {
1312                AgentThreadEntry::ToolCall(call) => match call.status {
1313                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
1314                    ToolCallStatus::Pending
1315                    | ToolCallStatus::InProgress
1316                    | ToolCallStatus::Completed
1317                    | ToolCallStatus::Failed
1318                    | ToolCallStatus::Rejected
1319                    | ToolCallStatus::Canceled => continue,
1320                },
1321                AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1322                    // Reached the beginning of the turn
1323                    return false;
1324                }
1325            }
1326        }
1327        false
1328    }
1329
1330    pub fn plan(&self) -> &Plan {
1331        &self.plan
1332    }
1333
1334    pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1335        let new_entries_len = request.entries.len();
1336        let mut new_entries = request.entries.into_iter();
1337
1338        // Reuse existing markdown to prevent flickering
1339        for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1340            let PlanEntry {
1341                content,
1342                priority,
1343                status,
1344            } = old;
1345            content.update(cx, |old, cx| {
1346                old.replace(new.content, cx);
1347            });
1348            *priority = new.priority;
1349            *status = new.status;
1350        }
1351        for new in new_entries {
1352            self.plan.entries.push(PlanEntry::from_acp(new, cx))
1353        }
1354        self.plan.entries.truncate(new_entries_len);
1355
1356        cx.notify();
1357    }
1358
1359    fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1360        self.plan
1361            .entries
1362            .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1363        cx.notify();
1364    }
1365
1366    #[cfg(any(test, feature = "test-support"))]
1367    pub fn send_raw(
1368        &mut self,
1369        message: &str,
1370        cx: &mut Context<Self>,
1371    ) -> BoxFuture<'static, Result<()>> {
1372        self.send(
1373            vec![acp::ContentBlock::Text(acp::TextContent {
1374                text: message.to_string(),
1375                annotations: None,
1376            })],
1377            cx,
1378        )
1379    }
1380
1381    pub fn send(
1382        &mut self,
1383        message: Vec<acp::ContentBlock>,
1384        cx: &mut Context<Self>,
1385    ) -> BoxFuture<'static, Result<()>> {
1386        let block = ContentBlock::new_combined(
1387            message.clone(),
1388            self.project.read(cx).languages().clone(),
1389            cx,
1390        );
1391        let request = acp::PromptRequest {
1392            prompt: message.clone(),
1393            session_id: self.session_id.clone(),
1394        };
1395        let git_store = self.project.read(cx).git_store().clone();
1396
1397        let message_id = if self.connection.truncate(&self.session_id, cx).is_some() {
1398            Some(UserMessageId::new())
1399        } else {
1400            None
1401        };
1402
1403        self.run_turn(cx, async move |this, cx| {
1404            this.update(cx, |this, cx| {
1405                this.push_entry(
1406                    AgentThreadEntry::UserMessage(UserMessage {
1407                        id: message_id.clone(),
1408                        content: block,
1409                        chunks: message,
1410                        checkpoint: None,
1411                    }),
1412                    cx,
1413                );
1414            })
1415            .ok();
1416
1417            let old_checkpoint = git_store
1418                .update(cx, |git, cx| git.checkpoint(cx))?
1419                .await
1420                .context("failed to get old checkpoint")
1421                .log_err();
1422            this.update(cx, |this, cx| {
1423                if let Some((_ix, message)) = this.last_user_message() {
1424                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
1425                        git_checkpoint,
1426                        show: false,
1427                    });
1428                }
1429                this.connection.prompt(message_id, request, cx)
1430            })?
1431            .await
1432        })
1433    }
1434
1435    pub fn can_resume(&self, cx: &App) -> bool {
1436        self.connection.resume(&self.session_id, cx).is_some()
1437    }
1438
1439    pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1440        self.run_turn(cx, async move |this, cx| {
1441            this.update(cx, |this, cx| {
1442                this.connection
1443                    .resume(&this.session_id, cx)
1444                    .map(|resume| resume.run(cx))
1445            })?
1446            .context("resuming a session is not supported")?
1447            .await
1448        })
1449    }
1450
    /// Runs `f` as a new turn and returns a future for the turn's outcome.
    ///
    /// Completed plan entries are cleared first and any running turn is
    /// cancelled. `f` starts only after the previous send task has finished,
    /// and it is stored as `send_task` so a later `cancel` can stop it.
    ///
    /// Once `f` completes:
    /// - The last user message's checkpoint `show` flag is refreshed.
    /// - The agent location is cleared.
    /// - On error, [`AcpThreadEvent::Error`] is emitted and the error is
    ///   returned.
    /// - Otherwise [`AcpThreadEvent::Stopped`] is emitted. On a `Refusal` stop
    ///   reason, the last user message and every entry after it are removed
    ///   first.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<()>> {
        self.clear_completed_plan_entries(cx);

        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        // The send task waits for the previous turn to wind down before
        // running `f`, then hands `f`'s result to the watcher below.
        self.send_task = Some(cx.spawn(async move |this, cx| {
            cancel_task.await;
            tx.send(f(this, cx).await).ok();
        }));

        cx.spawn(async move |this, cx| {
            // `Err(Canceled)` here means the send task was dropped before it
            // produced a result.
            let response = rx.await;

            // NOTE(review): if this fails, the function returns early without
            // clearing the agent location or emitting Stopped/Error.
            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));
                match response {
                    Ok(Err(e)) => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        Err(e)
                    }
                    result => {
                        let canceled = matches!(
                            result,
                            Ok(Ok(acp::PromptResponse {
                                stop_reason: acp::StopReason::Cancelled
                            }))
                        );

                        // We only take the task if the current prompt wasn't canceled.
                        //
                        // This prompt may have been canceled because another one was sent
                        // while it was still generating. In these cases, dropping `send_task`
                        // would cause the next generation to be canceled.
                        //
                        // NOTE(review): the guard only recognises a `Cancelled` stop
                        // reason. A superseded turn that ends with another stop reason
                        // (or whose receiver saw `Err(Canceled)`) still reaches
                        // `take()` and may drop the *newer* turn's task. Fixing this
                        // needs a per-turn identity on `send_task`; confirm whether
                        // agents always report `Cancelled` after `cancel`.
                        if !canceled {
                            this.send_task.take();
                        }

                        // Truncate entries if the last prompt was refused.
                        if let Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                        })) = result
                            && let Some((ix, _)) = this.last_user_message()
                        {
                            let range = ix..this.entries.len();
                            this.entries.truncate(ix);
                            cx.emit(AcpThreadEvent::EntriesRemoved(range));
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(())
                    }
                }
            })?
        })
        .boxed()
    }
1517
1518    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1519        let Some(send_task) = self.send_task.take() else {
1520            return Task::ready(());
1521        };
1522
1523        for entry in self.entries.iter_mut() {
1524            if let AgentThreadEntry::ToolCall(call) = entry {
1525                let cancel = matches!(
1526                    call.status,
1527                    ToolCallStatus::Pending
1528                        | ToolCallStatus::WaitingForConfirmation { .. }
1529                        | ToolCallStatus::InProgress
1530                );
1531
1532                if cancel {
1533                    call.status = ToolCallStatus::Canceled;
1534                }
1535            }
1536        }
1537
1538        self.connection.cancel(&self.session_id, cx);
1539
1540        // Wait for the send task to complete
1541        cx.foreground_executor().spawn(send_task)
1542    }
1543
1544    /// Rewinds this thread to before the entry at `index`, removing it and all
1545    /// subsequent entries while reverting any changes made from that point.
1546    pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
1547        let Some(truncate) = self.connection.truncate(&self.session_id, cx) else {
1548            return Task::ready(Err(anyhow!("not supported")));
1549        };
1550        let Some(message) = self.user_message(&id) else {
1551            return Task::ready(Err(anyhow!("message not found")));
1552        };
1553
1554        let checkpoint = message
1555            .checkpoint
1556            .as_ref()
1557            .map(|c| c.git_checkpoint.clone());
1558
1559        let git_store = self.project.read(cx).git_store().clone();
1560        cx.spawn(async move |this, cx| {
1561            if let Some(checkpoint) = checkpoint {
1562                git_store
1563                    .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1564                    .await?;
1565            }
1566
1567            cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
1568            this.update(cx, |this, cx| {
1569                if let Some((ix, _)) = this.user_message_mut(&id) {
1570                    let range = ix..this.entries.len();
1571                    this.entries.truncate(ix);
1572                    cx.emit(AcpThreadEvent::EntriesRemoved(range));
1573                }
1574            })
1575        })
1576    }
1577
1578    fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1579        let git_store = self.project.read(cx).git_store().clone();
1580
1581        let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1582            if let Some(checkpoint) = message.checkpoint.as_ref() {
1583                checkpoint.git_checkpoint.clone()
1584            } else {
1585                return Task::ready(Ok(()));
1586            }
1587        } else {
1588            return Task::ready(Ok(()));
1589        };
1590
1591        let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1592        cx.spawn(async move |this, cx| {
1593            let new_checkpoint = new_checkpoint
1594                .await
1595                .context("failed to get new checkpoint")
1596                .log_err();
1597            if let Some(new_checkpoint) = new_checkpoint {
1598                let equal = git_store
1599                    .update(cx, |git, cx| {
1600                        git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1601                    })?
1602                    .await
1603                    .unwrap_or(true);
1604                this.update(cx, |this, cx| {
1605                    let (ix, message) = this.last_user_message().context("no user message")?;
1606                    let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1607                    checkpoint.show = !equal;
1608                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
1609                    anyhow::Ok(())
1610                })??;
1611            }
1612
1613            Ok(())
1614        })
1615    }
1616
1617    fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1618        self.entries
1619            .iter_mut()
1620            .enumerate()
1621            .rev()
1622            .find_map(|(ix, entry)| {
1623                if let AgentThreadEntry::UserMessage(message) = entry {
1624                    Some((ix, message))
1625                } else {
1626                    None
1627                }
1628            })
1629    }
1630
1631    fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
1632        self.entries.iter().find_map(|entry| {
1633            if let AgentThreadEntry::UserMessage(message) = entry {
1634                if message.id.as_ref() == Some(id) {
1635                    Some(message)
1636                } else {
1637                    None
1638                }
1639            } else {
1640                None
1641            }
1642        })
1643    }
1644
1645    fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
1646        self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1647            if let AgentThreadEntry::UserMessage(message) = entry {
1648                if message.id.as_ref() == Some(id) {
1649                    Some((ix, message))
1650                } else {
1651                    None
1652                }
1653            } else {
1654                None
1655            }
1656        })
1657    }
1658
1659    pub fn read_text_file(
1660        &self,
1661        path: PathBuf,
1662        line: Option<u32>,
1663        limit: Option<u32>,
1664        reuse_shared_snapshot: bool,
1665        cx: &mut Context<Self>,
1666    ) -> Task<Result<String>> {
1667        let project = self.project.clone();
1668        let action_log = self.action_log.clone();
1669        cx.spawn(async move |this, cx| {
1670            let load = project.update(cx, |project, cx| {
1671                let path = project
1672                    .project_path_for_absolute_path(&path, cx)
1673                    .context("invalid path")?;
1674                anyhow::Ok(project.open_buffer(path, cx))
1675            });
1676            let buffer = load??.await?;
1677
1678            let snapshot = if reuse_shared_snapshot {
1679                this.read_with(cx, |this, _| {
1680                    this.shared_buffers.get(&buffer.clone()).cloned()
1681                })
1682                .log_err()
1683                .flatten()
1684            } else {
1685                None
1686            };
1687
1688            let snapshot = if let Some(snapshot) = snapshot {
1689                snapshot
1690            } else {
1691                action_log.update(cx, |action_log, cx| {
1692                    action_log.buffer_read(buffer.clone(), cx);
1693                })?;
1694                project.update(cx, |project, cx| {
1695                    let position = buffer
1696                        .read(cx)
1697                        .snapshot()
1698                        .anchor_before(Point::new(line.unwrap_or_default(), 0));
1699                    project.set_agent_location(
1700                        Some(AgentLocation {
1701                            buffer: buffer.downgrade(),
1702                            position,
1703                        }),
1704                        cx,
1705                    );
1706                })?;
1707
1708                buffer.update(cx, |buffer, _| buffer.snapshot())?
1709            };
1710
1711            this.update(cx, |this, _| {
1712                let text = snapshot.text();
1713                this.shared_buffers.insert(buffer.clone(), snapshot);
1714                if line.is_none() && limit.is_none() {
1715                    return Ok(text);
1716                }
1717                let limit = limit.unwrap_or(u32::MAX) as usize;
1718                let Some(line) = line else {
1719                    return Ok(text.lines().take(limit).collect::<String>());
1720                };
1721
1722                let count = text.lines().count();
1723                if count < line as usize {
1724                    anyhow::bail!("There are only {} lines", count);
1725                }
1726                Ok(text
1727                    .lines()
1728                    .skip(line as usize + 1)
1729                    .take(limit)
1730                    .collect::<String>())
1731            })?
1732        })
1733    }
1734
1735    pub fn write_text_file(
1736        &self,
1737        path: PathBuf,
1738        content: String,
1739        cx: &mut Context<Self>,
1740    ) -> Task<Result<()>> {
1741        let project = self.project.clone();
1742        let action_log = self.action_log.clone();
1743        cx.spawn(async move |this, cx| {
1744            let load = project.update(cx, |project, cx| {
1745                let path = project
1746                    .project_path_for_absolute_path(&path, cx)
1747                    .context("invalid path")?;
1748                anyhow::Ok(project.open_buffer(path, cx))
1749            });
1750            let buffer = load??.await?;
1751            let snapshot = this.update(cx, |this, cx| {
1752                this.shared_buffers
1753                    .get(&buffer)
1754                    .cloned()
1755                    .unwrap_or_else(|| buffer.read(cx).snapshot())
1756            })?;
1757            let edits = cx
1758                .background_executor()
1759                .spawn(async move {
1760                    let old_text = snapshot.text();
1761                    text_diff(old_text.as_str(), &content)
1762                        .into_iter()
1763                        .map(|(range, replacement)| {
1764                            (
1765                                snapshot.anchor_after(range.start)
1766                                    ..snapshot.anchor_before(range.end),
1767                                replacement,
1768                            )
1769                        })
1770                        .collect::<Vec<_>>()
1771                })
1772                .await;
1773
1774            project.update(cx, |project, cx| {
1775                project.set_agent_location(
1776                    Some(AgentLocation {
1777                        buffer: buffer.downgrade(),
1778                        position: edits
1779                            .last()
1780                            .map(|(range, _)| range.end)
1781                            .unwrap_or(Anchor::MIN),
1782                    }),
1783                    cx,
1784                );
1785            })?;
1786
1787            let format_on_save = cx.update(|cx| {
1788                action_log.update(cx, |action_log, cx| {
1789                    action_log.buffer_read(buffer.clone(), cx);
1790                });
1791
1792                let format_on_save = buffer.update(cx, |buffer, cx| {
1793                    buffer.edit(edits, None, cx);
1794
1795                    let settings = language::language_settings::language_settings(
1796                        buffer.language().map(|l| l.name()),
1797                        buffer.file(),
1798                        cx,
1799                    );
1800
1801                    settings.format_on_save != FormatOnSave::Off
1802                });
1803                action_log.update(cx, |action_log, cx| {
1804                    action_log.buffer_edited(buffer.clone(), cx);
1805                });
1806                format_on_save
1807            })?;
1808
1809            if format_on_save {
1810                let format_task = project.update(cx, |project, cx| {
1811                    project.format(
1812                        HashSet::from_iter([buffer.clone()]),
1813                        LspFormatTarget::Buffers,
1814                        false,
1815                        FormatTrigger::Save,
1816                        cx,
1817                    )
1818                })?;
1819                format_task.await.log_err();
1820
1821                action_log.update(cx, |action_log, cx| {
1822                    action_log.buffer_edited(buffer.clone(), cx);
1823                })?;
1824            }
1825
1826            project
1827                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
1828                .await
1829        })
1830    }
1831
1832    pub fn to_markdown(&self, cx: &App) -> String {
1833        self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1834    }
1835
1836    pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
1837        cx.emit(AcpThreadEvent::LoadError(error));
1838    }
1839}
1840
1841fn markdown_for_raw_output(
1842    raw_output: &serde_json::Value,
1843    language_registry: &Arc<LanguageRegistry>,
1844    cx: &mut App,
1845) -> Option<Entity<Markdown>> {
1846    match raw_output {
1847        serde_json::Value::Null => None,
1848        serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1849            Markdown::new(
1850                value.to_string().into(),
1851                Some(language_registry.clone()),
1852                None,
1853                cx,
1854            )
1855        })),
1856        serde_json::Value::Number(value) => Some(cx.new(|cx| {
1857            Markdown::new(
1858                value.to_string().into(),
1859                Some(language_registry.clone()),
1860                None,
1861                cx,
1862            )
1863        })),
1864        serde_json::Value::String(value) => Some(cx.new(|cx| {
1865            Markdown::new(
1866                value.clone().into(),
1867                Some(language_registry.clone()),
1868                None,
1869                cx,
1870            )
1871        })),
1872        value => Some(cx.new(|cx| {
1873            Markdown::new(
1874                format!("```json\n{}\n```", value).into(),
1875                Some(language_registry.clone()),
1876                None,
1877                cx,
1878            )
1879        })),
1880    }
1881}
1882
1883#[cfg(test)]
1884mod tests {
1885    use super::*;
1886    use anyhow::anyhow;
1887    use futures::{channel::mpsc, future::LocalBoxFuture, select};
1888    use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
1889    use indoc::indoc;
1890    use project::{FakeFs, Fs};
1891    use rand::Rng as _;
1892    use serde_json::json;
1893    use settings::SettingsStore;
1894    use smol::stream::StreamExt as _;
1895    use std::{
1896        any::Any,
1897        cell::RefCell,
1898        path::Path,
1899        rc::Rc,
1900        sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
1901        time::Duration,
1902    };
1903    use util::path;
1904
1905    fn init_test(cx: &mut TestAppContext) {
1906        env_logger::try_init().ok();
1907        cx.update(|cx| {
1908            let settings_store = SettingsStore::test(cx);
1909            cx.set_global(settings_store);
1910            Project::init_settings(cx);
1911            language::init(cx);
1912        });
1913    }
1914
    // Verifies how `push_user_content_block` groups chunks into user message entries:
    // - the first chunk (with no id) creates a user message whose id is `None`;
    // - a following chunk is appended to that same entry and the entry adopts the id
    //   passed with it;
    // - a chunk that arrives after an assistant message starts a new user message entry.
    #[gpui::test]
    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let connection = Rc::new(FakeAgentConnection::new());
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // Test creating a new user message
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                None,
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Hello, ".to_string(),
                }),
                cx,
            );
        });

        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, None);
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test appending to existing user message
        let message_1_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_1_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "world!".to_string(),
                }),
                cx,
            );
        });

        // Still a single entry: the text was concatenated and the id was picked up.
        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 1);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
                assert_eq!(user_msg.id, Some(message_1_id));
                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
            } else {
                panic!("Expected UserMessage");
            }
        });

        // Test creating new user message after assistant message
        thread.update(cx, |thread, cx| {
            thread.push_assistant_content_block(
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "Assistant response".to_string(),
                }),
                false,
                cx,
            );
        });

        let message_2_id = UserMessageId::new();
        thread.update(cx, |thread, cx| {
            thread.push_user_content_block(
                Some(message_2_id.clone()),
                acp::ContentBlock::Text(acp::TextContent {
                    annotations: None,
                    text: "New user message".to_string(),
                }),
                cx,
            );
        });

        // Entries are now: user, assistant, user.
        thread.update(cx, |thread, cx| {
            assert_eq!(thread.entries.len(), 3);
            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
                assert_eq!(user_msg.id, Some(message_2_id));
                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
            } else {
                panic!("Expected UserMessage at index 2");
            }
        });
    }
2006
    // Verifies that consecutive `AgentThoughtChunk` updates within one turn are merged
    // into a single `<thinking>` block in the thread's Markdown rendering.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        // The fake agent answers every prompt with two thought chunks, then ends the turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            |_, thread, mut cx| {
                async move {
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "Thinking ".into(),
                                },
                                cx,
                            )
                            .unwrap();
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentThoughtChunk {
                                    content: "hard!".into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        thread
            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
            .await
            .unwrap();

        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
        assert_eq!(
            output,
            indoc! {r#"
            ## User

            Hello from Zed!

            ## Assistant

            <thinking>
            Thinking hard!
            </thinking>

            "#}
        );
    }
2069
    // Verifies that an agent write does not clobber an edit the user makes after the
    // agent has read the file.
    //
    // The agent reads "/tmp/foo", then writes an extended version. In between, the test
    // inserts "zero\n" at the start of the buffer. Both the buffer and the file on disk
    // must end up containing the user's line and the agent's additions.
    #[gpui::test]
    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
            .await;
        let project = Project::test(fs.clone(), [], cx).await;
        // One-shot signal fired by the fake agent right after its read completes, so the
        // user edit below happens after the agent's read snapshot was taken.
        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
            move |_, thread, mut cx| {
                let read_file_tx = read_file_tx.clone();
                async move {
                    let content = thread
                        .update(&mut cx, |thread, cx| {
                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    assert_eq!(content, "one\ntwo\nthree\n");
                    read_file_tx.take().unwrap().send(()).unwrap();
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.write_text_file(
                                path!("/tmp/foo").into(),
                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
                                cx,
                            )
                        })
                        .unwrap()
                        .await
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            },
        ));

        // Open the buffer from the test side as well, playing the role of the user's editor.
        let (worktree, pathbuf) = project
            .update(cx, |project, cx| {
                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
            })
            .await
            .unwrap();
        let buffer = project
            .update(cx, |project, cx| {
                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
            })
            .await
            .unwrap();

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Extend the count in /tmp/foo", cx)
        });
        // Wait for the agent's read, then make the "user" edit.
        read_file_rx.await.ok();
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
        });
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, _| buffer.text()),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        assert_eq!(
            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
            "zero\none\ntwo\nthree\nfour\nfive\n"
        );
        request.await.unwrap();
    }
2148
    // Verifies that a tool call marked `Canceled` by `thread.cancel` can still be moved to
    // `Completed` by a later `ToolCallUpdate` from the agent.
    #[gpui::test]
    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let id = acp::ToolCallId("test".into());

        // The fake agent starts an in-progress tool call with a fixed id and ends its turn.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let id = id.clone();
            move |_, thread, mut cx| {
                let id = id.clone();
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: id.clone(),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Fetch,
                                    status: acp::ToolCallStatus::InProgress,
                                    content: vec![],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        let request = thread.update(cx, |thread, cx| {
            thread.send_raw("Fetch https://example.com", cx)
        });

        run_until_first_tool_call(&thread, cx).await;

        // Entry 0 is the user message; entry 1 is the tool call.
        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::InProgress,
                    ..
                })
            ));
        });

        thread.update(cx, |thread, cx| thread.cancel(cx)).await;

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                &thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Canceled,
                    ..
                })
            ));
        });

        // A late completion from the agent overrides the canceled status.
        thread
            .update(cx, |thread, cx| {
                thread.handle_session_update(
                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
                        id,
                        fields: acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..Default::default()
                        },
                    }),
                    cx,
                )
            })
            .unwrap();

        request.await.unwrap();

        thread.read_with(cx, |thread, _| {
            assert!(matches!(
                thread.entries[1],
                AgentThreadEntry::ToolCall(ToolCall {
                    status: ToolCallStatus::Completed,
                    ..
                })
            ));
        });
    }
2248
    // Verifies that a tool call which arrives already `Completed` with diff content does
    // not leave `has_pending_edit_tool_calls()` true once the turn has finished.
    #[gpui::test]
    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/test"), json!({})).await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;

        // The fake agent reports a single, already-completed edit tool call carrying a diff.
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            move |_, thread, mut cx| {
                async move {
                    thread
                        .update(&mut cx, |thread, cx| {
                            thread.handle_session_update(
                                acp::SessionUpdate::ToolCall(acp::ToolCall {
                                    id: acp::ToolCallId("test".into()),
                                    title: "Label".into(),
                                    kind: acp::ToolKind::Edit,
                                    status: acp::ToolCallStatus::Completed,
                                    content: vec![acp::ToolCallContent::Diff {
                                        diff: acp::Diff {
                                            path: "/test/test.txt".into(),
                                            old_text: None,
                                            new_text: "foo".into(),
                                        },
                                    }],
                                    locations: vec![],
                                    raw_input: None,
                                    raw_output: None,
                                }),
                                cx,
                            )
                        })
                        .unwrap()
                        .unwrap();
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));

        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Hi".into()], cx)))
            .await
            .unwrap();

        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
    }
2302
    // Verifies git checkpoints around user messages:
    // - a prompt during which the fake agent creates a file is rendered as
    //   "## User (checkpoint)";
    // - a prompt that changes nothing is rendered as plain "## User";
    // - `rewind` to a user message truncates the thread from that message onward and
    //   restores the working tree (file-1 disappears again).
    #[gpui::test(iterations = 10)]
    async fn test_checkpoints(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(
            path!("/test"),
            json!({
                ".git": {}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

        // While `simulate_changes` is true, every prompt writes a new empty file named
        // file-0, file-1, ... so the turn has changes worth checkpointing. The agent then
        // replies with the prompt text upper-cased.
        let simulate_changes = Arc::new(AtomicBool::new(true));
        let next_filename = Arc::new(AtomicUsize::new(0));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let simulate_changes = simulate_changes.clone();
            let next_filename = next_filename.clone();
            let fs = fs.clone();
            move |request, thread, mut cx| {
                let fs = fs.clone();
                let simulate_changes = simulate_changes.clone();
                let next_filename = next_filename.clone();
                async move {
                    if simulate_changes.load(SeqCst) {
                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
                        fs.write(Path::new(&filename), b"").await?;
                    }

                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk {
                                    content: content.text.to_uppercase().into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        // First prompt: creates file-0, so the message gets a checkpoint.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);

        // Second prompt: creates file-1, also checkpointed.
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Checkpoint isn't stored when there are no changes.
        simulate_changes.store(false, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                    ## User (checkpoint)

                    ipsum

                    ## Assistant

                    IPSUM

                    ## User

                    dolor

                    ## Assistant

                    DOLOR

                "}
            );
        });
        assert_eq!(
            fs.files(),
            vec![
                Path::new(path!("/test/file-0")),
                Path::new(path!("/test/file-1"))
            ]
        );

        // Rewinding the conversation truncates the history and restores the checkpoint.
        // Entry 2 is the second user message ("ipsum").
        thread
            .update(cx, |thread, cx| {
                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
                    panic!("unexpected entries {:?}", thread.entries)
                };
                thread.rewind(message.id.clone().unwrap(), cx)
            })
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User (checkpoint)

                    Lorem

                    ## Assistant

                    LOREM

                "}
            );
        });
        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
    }
2482
    // Verifies that when the agent answers with `StopReason::Refusal`, the thread is
    // truncated back to its state before the refused user message was sent.
    #[gpui::test]
    async fn test_refusal(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.background_executor.clone());
        fs.insert_tree(path!("/"), json!({})).await;
        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;

        // The fake agent echoes prompts upper-cased, or refuses when `refuse_next` is set.
        let refuse_next = Arc::new(AtomicBool::new(false));
        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
            let refuse_next = refuse_next.clone();
            move |request, thread, mut cx| {
                let refuse_next = refuse_next.clone();
                async move {
                    if refuse_next.load(SeqCst) {
                        return Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                        });
                    }

                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
                        panic!("expected text content block");
                    };
                    thread.update(&mut cx, |thread, cx| {
                        thread
                            .handle_session_update(
                                acp::SessionUpdate::AgentMessageChunk {
                                    content: content.text.to_uppercase().into(),
                                },
                                cx,
                            )
                            .unwrap();
                    })?;
                    Ok(acp::PromptResponse {
                        stop_reason: acp::StopReason::EndTurn,
                    })
                }
                .boxed_local()
            }
        }));
        let thread = cx
            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
            .await
            .unwrap();

        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["hello".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User

                    hello

                    ## Assistant

                    HELLO

                "}
            );
        });

        // Simulate refusing the second message, ensuring the conversation gets
        // truncated to before sending it.
        refuse_next.store(true, SeqCst);
        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["world".into()], cx)))
            .await
            .unwrap();
        thread.read_with(cx, |thread, cx| {
            assert_eq!(
                thread.to_markdown(cx),
                indoc! {"
                    ## User

                    hello

                    ## Assistant

                    HELLO

                "}
            );
        });
    }
2568
2569    async fn run_until_first_tool_call(
2570        thread: &Entity<AcpThread>,
2571        cx: &mut TestAppContext,
2572    ) -> usize {
2573        let (mut tx, mut rx) = mpsc::channel::<usize>(1);
2574
2575        let subscription = cx.update(|cx| {
2576            cx.subscribe(thread, move |thread, _, cx| {
2577                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
2578                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
2579                        return tx.try_send(ix).unwrap();
2580                    }
2581                }
2582            })
2583        });
2584
2585        select! {
2586            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
2587                panic!("Timeout waiting for tool call")
2588            }
2589            ix = rx.next().fuse() => {
2590                drop(subscription);
2591                ix.unwrap()
2592            }
2593        }
2594    }
2595
2596    #[derive(Clone, Default)]
2597    struct FakeAgentConnection {
2598        auth_methods: Vec<acp::AuthMethod>,
2599        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
2600        on_user_message: Option<
2601            Rc<
2602                dyn Fn(
2603                        acp::PromptRequest,
2604                        WeakEntity<AcpThread>,
2605                        AsyncApp,
2606                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2607                    + 'static,
2608            >,
2609        >,
2610    }
2611
2612    impl FakeAgentConnection {
2613        fn new() -> Self {
2614            Self {
2615                auth_methods: Vec::new(),
2616                on_user_message: None,
2617                sessions: Arc::default(),
2618            }
2619        }
2620
2621        #[expect(unused)]
2622        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
2623            self.auth_methods = auth_methods;
2624            self
2625        }
2626
2627        fn on_user_message(
2628            mut self,
2629            handler: impl Fn(
2630                acp::PromptRequest,
2631                WeakEntity<AcpThread>,
2632                AsyncApp,
2633            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2634            + 'static,
2635        ) -> Self {
2636            self.on_user_message.replace(Rc::new(handler));
2637            self
2638        }
2639    }
2640
2641    impl AgentConnection for FakeAgentConnection {
2642        fn auth_methods(&self) -> &[acp::AuthMethod] {
2643            &self.auth_methods
2644        }
2645
2646        fn new_thread(
2647            self: Rc<Self>,
2648            project: Entity<Project>,
2649            _cwd: &Path,
2650            cx: &mut App,
2651        ) -> Task<gpui::Result<Entity<AcpThread>>> {
2652            let session_id = acp::SessionId(
2653                rand::thread_rng()
2654                    .sample_iter(&rand::distributions::Alphanumeric)
2655                    .take(7)
2656                    .map(char::from)
2657                    .collect::<String>()
2658                    .into(),
2659            );
2660            let action_log = cx.new(|_| ActionLog::new(project.clone()));
2661            let thread = cx.new(|cx| {
2662                AcpThread::new(
2663                    "Test",
2664                    self.clone(),
2665                    project,
2666                    action_log,
2667                    session_id.clone(),
2668                    watch::Receiver::constant(acp::PromptCapabilities {
2669                        image: true,
2670                        audio: true,
2671                        embedded_context: true,
2672                    }),
2673                    cx,
2674                )
2675            });
2676            self.sessions.lock().insert(session_id, thread.downgrade());
2677            Task::ready(Ok(thread))
2678        }
2679
2680        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
2681            if self.auth_methods().iter().any(|m| m.id == method) {
2682                Task::ready(Ok(()))
2683            } else {
2684                Task::ready(Err(anyhow!("Invalid Auth Method")))
2685            }
2686        }
2687
2688        fn prompt(
2689            &self,
2690            _id: Option<UserMessageId>,
2691            params: acp::PromptRequest,
2692            cx: &mut App,
2693        ) -> Task<gpui::Result<acp::PromptResponse>> {
2694            let sessions = self.sessions.lock();
2695            let thread = sessions.get(&params.session_id).unwrap();
2696            if let Some(handler) = &self.on_user_message {
2697                let handler = handler.clone();
2698                let thread = thread.clone();
2699                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
2700            } else {
2701                Task::ready(Ok(acp::PromptResponse {
2702                    stop_reason: acp::StopReason::EndTurn,
2703                }))
2704            }
2705        }
2706
2707        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
2708            let sessions = self.sessions.lock();
2709            let thread = sessions.get(session_id).unwrap().clone();
2710
2711            cx.spawn(async move |cx| {
2712                thread
2713                    .update(cx, |thread, cx| thread.cancel(cx))
2714                    .unwrap()
2715                    .await
2716            })
2717            .detach();
2718        }
2719
2720        fn truncate(
2721            &self,
2722            session_id: &acp::SessionId,
2723            _cx: &App,
2724        ) -> Option<Rc<dyn AgentSessionTruncate>> {
2725            Some(Rc::new(FakeAgentSessionEditor {
2726                _session_id: session_id.clone(),
2727            }))
2728        }
2729
2730        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
2731            self
2732        }
2733    }
2734
2735    struct FakeAgentSessionEditor {
2736        _session_id: acp::SessionId,
2737    }
2738
2739    impl AgentSessionTruncate for FakeAgentSessionEditor {
2740        fn run(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
2741            Task::ready(Ok(()))
2742        }
2743    }
2744}