acp_thread.rs

   1mod connection;
   2mod diff;
   3mod mention;
   4mod terminal;
   5
   6use collections::HashSet;
   7pub use connection::*;
   8pub use diff::*;
   9use language::language_settings::FormatOnSave;
  10pub use mention::*;
  11use project::lsp_store::{FormatTrigger, LspFormatTarget};
  12use serde::{Deserialize, Serialize};
  13pub use terminal::*;
  14
  15use action_log::ActionLog;
  16use agent_client_protocol as acp;
  17use anyhow::{Context as _, Result, anyhow};
  18use editor::Bias;
  19use futures::{FutureExt, channel::oneshot, future::BoxFuture};
  20use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  21use itertools::Itertools;
  22use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
  23use markdown::Markdown;
  24use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
  25use std::collections::HashMap;
  26use std::error::Error;
  27use std::fmt::{Formatter, Write};
  28use std::ops::Range;
  29use std::process::ExitStatus;
  30use std::rc::Rc;
  31use std::time::{Duration, Instant};
  32use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
  33use ui::App;
  34use util::ResultExt;
  35use uuid::Uuid;
  36
  37pub fn new_prompt_id() -> acp::PromptId {
  38    acp::PromptId(Uuid::new_v4().to_string().into())
  39}
  40
  41#[derive(Debug)]
  42pub struct UserMessage {
  43    pub prompt_id: Option<acp::PromptId>,
  44    pub content: ContentBlock,
  45    pub chunks: Vec<acp::ContentBlock>,
  46    pub checkpoint: Option<Checkpoint>,
  47}
  48
  49#[derive(Debug)]
  50pub struct Checkpoint {
  51    git_checkpoint: GitStoreCheckpoint,
  52    pub show: bool,
  53}
  54
  55impl UserMessage {
  56    fn to_markdown(&self, cx: &App) -> String {
  57        let mut markdown = String::new();
  58        if self
  59            .checkpoint
  60            .as_ref()
  61            .is_some_and(|checkpoint| checkpoint.show)
  62        {
  63            writeln!(markdown, "## User (checkpoint)").unwrap();
  64        } else {
  65            writeln!(markdown, "## User").unwrap();
  66        }
  67        writeln!(markdown).unwrap();
  68        writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
  69        writeln!(markdown).unwrap();
  70        markdown
  71    }
  72}
  73
  74#[derive(Debug, PartialEq)]
  75pub struct AssistantMessage {
  76    pub chunks: Vec<AssistantMessageChunk>,
  77}
  78
  79impl AssistantMessage {
  80    pub fn to_markdown(&self, cx: &App) -> String {
  81        format!(
  82            "## Assistant\n\n{}\n\n",
  83            self.chunks
  84                .iter()
  85                .map(|chunk| chunk.to_markdown(cx))
  86                .join("\n\n")
  87        )
  88    }
  89}
  90
  91#[derive(Debug, PartialEq)]
  92pub enum AssistantMessageChunk {
  93    Message { block: ContentBlock },
  94    Thought { block: ContentBlock },
  95}
  96
  97impl AssistantMessageChunk {
  98    pub fn from_str(chunk: &str, language_registry: &Arc<LanguageRegistry>, cx: &mut App) -> Self {
  99        Self::Message {
 100            block: ContentBlock::new(chunk.into(), language_registry, cx),
 101        }
 102    }
 103
 104    fn to_markdown(&self, cx: &App) -> String {
 105        match self {
 106            Self::Message { block } => block.to_markdown(cx).to_string(),
 107            Self::Thought { block } => {
 108                format!("<thinking>\n{}\n</thinking>", block.to_markdown(cx))
 109            }
 110        }
 111    }
 112}
 113
 114#[derive(Debug)]
 115pub enum AgentThreadEntry {
 116    UserMessage(UserMessage),
 117    AssistantMessage(AssistantMessage),
 118    ToolCall(ToolCall),
 119}
 120
 121impl AgentThreadEntry {
 122    pub fn to_markdown(&self, cx: &App) -> String {
 123        match self {
 124            Self::UserMessage(message) => message.to_markdown(cx),
 125            Self::AssistantMessage(message) => message.to_markdown(cx),
 126            Self::ToolCall(tool_call) => tool_call.to_markdown(cx),
 127        }
 128    }
 129
 130    pub fn user_message(&self) -> Option<&UserMessage> {
 131        if let AgentThreadEntry::UserMessage(message) = self {
 132            Some(message)
 133        } else {
 134            None
 135        }
 136    }
 137
 138    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 139        if let AgentThreadEntry::ToolCall(call) = self {
 140            itertools::Either::Left(call.diffs())
 141        } else {
 142            itertools::Either::Right(std::iter::empty())
 143        }
 144    }
 145
 146    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 147        if let AgentThreadEntry::ToolCall(call) = self {
 148            itertools::Either::Left(call.terminals())
 149        } else {
 150            itertools::Either::Right(std::iter::empty())
 151        }
 152    }
 153
 154    pub fn location(&self, ix: usize) -> Option<(acp::ToolCallLocation, AgentLocation)> {
 155        if let AgentThreadEntry::ToolCall(ToolCall {
 156            locations,
 157            resolved_locations,
 158            ..
 159        }) = self
 160        {
 161            Some((
 162                locations.get(ix)?.clone(),
 163                resolved_locations.get(ix)?.clone()?,
 164            ))
 165        } else {
 166            None
 167        }
 168    }
 169}
 170
 171#[derive(Debug)]
 172pub struct ToolCall {
 173    pub id: acp::ToolCallId,
 174    pub label: Entity<Markdown>,
 175    pub kind: acp::ToolKind,
 176    pub content: Vec<ToolCallContent>,
 177    pub status: ToolCallStatus,
 178    pub locations: Vec<acp::ToolCallLocation>,
 179    pub resolved_locations: Vec<Option<AgentLocation>>,
 180    pub raw_input: Option<serde_json::Value>,
 181    pub raw_output: Option<serde_json::Value>,
 182}
 183
 184impl ToolCall {
 185    fn from_acp(
 186        tool_call: acp::ToolCall,
 187        status: ToolCallStatus,
 188        language_registry: Arc<LanguageRegistry>,
 189        cx: &mut App,
 190    ) -> Self {
 191        let title = if let Some((first_line, _)) = tool_call.title.split_once("\n") {
 192            first_line.to_owned() + ""
 193        } else {
 194            tool_call.title
 195        };
 196        Self {
 197            id: tool_call.id,
 198            label: cx
 199                .new(|cx| Markdown::new(title.into(), Some(language_registry.clone()), None, cx)),
 200            kind: tool_call.kind,
 201            content: tool_call
 202                .content
 203                .into_iter()
 204                .map(|content| ToolCallContent::from_acp(content, language_registry.clone(), cx))
 205                .collect(),
 206            locations: tool_call.locations,
 207            resolved_locations: Vec::default(),
 208            status,
 209            raw_input: tool_call.raw_input,
 210            raw_output: tool_call.raw_output,
 211        }
 212    }
 213
 214    fn update_fields(
 215        &mut self,
 216        fields: acp::ToolCallUpdateFields,
 217        language_registry: Arc<LanguageRegistry>,
 218        cx: &mut App,
 219    ) {
 220        let acp::ToolCallUpdateFields {
 221            kind,
 222            status,
 223            title,
 224            content,
 225            locations,
 226            raw_input,
 227            raw_output,
 228        } = fields;
 229
 230        if let Some(kind) = kind {
 231            self.kind = kind;
 232        }
 233
 234        if let Some(status) = status {
 235            self.status = status.into();
 236        }
 237
 238        if let Some(title) = title {
 239            self.label.update(cx, |label, cx| {
 240                if let Some((first_line, _)) = title.split_once("\n") {
 241                    label.replace(first_line.to_owned() + "", cx)
 242                } else {
 243                    label.replace(title, cx);
 244                }
 245            });
 246        }
 247
 248        if let Some(content) = content {
 249            let new_content_len = content.len();
 250            let mut content = content.into_iter();
 251
 252            // Reuse existing content if we can
 253            for (old, new) in self.content.iter_mut().zip(content.by_ref()) {
 254                old.update_from_acp(new, language_registry.clone(), cx);
 255            }
 256            for new in content {
 257                self.content.push(ToolCallContent::from_acp(
 258                    new,
 259                    language_registry.clone(),
 260                    cx,
 261                ))
 262            }
 263            self.content.truncate(new_content_len);
 264        }
 265
 266        if let Some(locations) = locations {
 267            self.locations = locations;
 268        }
 269
 270        if let Some(raw_input) = raw_input {
 271            self.raw_input = Some(raw_input);
 272        }
 273
 274        if let Some(raw_output) = raw_output {
 275            if self.content.is_empty()
 276                && let Some(markdown) = markdown_for_raw_output(&raw_output, &language_registry, cx)
 277            {
 278                self.content
 279                    .push(ToolCallContent::ContentBlock(ContentBlock::Markdown {
 280                        markdown,
 281                    }));
 282            }
 283            self.raw_output = Some(raw_output);
 284        }
 285    }
 286
 287    pub fn diffs(&self) -> impl Iterator<Item = &Entity<Diff>> {
 288        self.content.iter().filter_map(|content| match content {
 289            ToolCallContent::Diff(diff) => Some(diff),
 290            ToolCallContent::ContentBlock(_) => None,
 291            ToolCallContent::Terminal(_) => None,
 292        })
 293    }
 294
 295    pub fn terminals(&self) -> impl Iterator<Item = &Entity<Terminal>> {
 296        self.content.iter().filter_map(|content| match content {
 297            ToolCallContent::Terminal(terminal) => Some(terminal),
 298            ToolCallContent::ContentBlock(_) => None,
 299            ToolCallContent::Diff(_) => None,
 300        })
 301    }
 302
 303    fn to_markdown(&self, cx: &App) -> String {
 304        let mut markdown = format!(
 305            "**Tool Call: {}**\nStatus: {}\n\n",
 306            self.label.read(cx).source(),
 307            self.status
 308        );
 309        for content in &self.content {
 310            markdown.push_str(content.to_markdown(cx).as_str());
 311            markdown.push_str("\n\n");
 312        }
 313        markdown
 314    }
 315
 316    async fn resolve_location(
 317        location: acp::ToolCallLocation,
 318        project: WeakEntity<Project>,
 319        cx: &mut AsyncApp,
 320    ) -> Option<AgentLocation> {
 321        let buffer = project
 322            .update(cx, |project, cx| {
 323                project
 324                    .project_path_for_absolute_path(&location.path, cx)
 325                    .map(|path| project.open_buffer(path, cx))
 326            })
 327            .ok()??;
 328        let buffer = buffer.await.log_err()?;
 329        let position = buffer
 330            .update(cx, |buffer, _| {
 331                if let Some(row) = location.line {
 332                    let snapshot = buffer.snapshot();
 333                    let column = snapshot.indent_size_for_line(row).len;
 334                    let point = snapshot.clip_point(Point::new(row, column), Bias::Left);
 335                    snapshot.anchor_before(point)
 336                } else {
 337                    Anchor::MIN
 338                }
 339            })
 340            .ok()?;
 341
 342        Some(AgentLocation {
 343            buffer: buffer.downgrade(),
 344            position,
 345        })
 346    }
 347
 348    fn resolve_locations(
 349        &self,
 350        project: Entity<Project>,
 351        cx: &mut App,
 352    ) -> Task<Vec<Option<AgentLocation>>> {
 353        let locations = self.locations.clone();
 354        project.update(cx, |_, cx| {
 355            cx.spawn(async move |project, cx| {
 356                let mut new_locations = Vec::new();
 357                for location in locations {
 358                    new_locations.push(Self::resolve_location(location, project.clone(), cx).await);
 359                }
 360                new_locations
 361            })
 362        })
 363    }
 364}
 365
 366#[derive(Debug)]
 367pub enum ToolCallStatus {
 368    /// The tool call hasn't started running yet, but we start showing it to
 369    /// the user.
 370    Pending,
 371    /// The tool call is waiting for confirmation from the user.
 372    WaitingForConfirmation {
 373        options: Vec<acp::PermissionOption>,
 374        respond_tx: oneshot::Sender<acp::PermissionOptionId>,
 375    },
 376    /// The tool call is currently running.
 377    InProgress,
 378    /// The tool call completed successfully.
 379    Completed,
 380    /// The tool call failed.
 381    Failed,
 382    /// The user rejected the tool call.
 383    Rejected,
 384    /// The user canceled generation so the tool call was canceled.
 385    Canceled,
 386}
 387
 388impl From<acp::ToolCallStatus> for ToolCallStatus {
 389    fn from(status: acp::ToolCallStatus) -> Self {
 390        match status {
 391            acp::ToolCallStatus::Pending => Self::Pending,
 392            acp::ToolCallStatus::InProgress => Self::InProgress,
 393            acp::ToolCallStatus::Completed => Self::Completed,
 394            acp::ToolCallStatus::Failed => Self::Failed,
 395        }
 396    }
 397}
 398
 399impl Display for ToolCallStatus {
 400    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 401        write!(
 402            f,
 403            "{}",
 404            match self {
 405                ToolCallStatus::Pending => "Pending",
 406                ToolCallStatus::WaitingForConfirmation { .. } => "Waiting for confirmation",
 407                ToolCallStatus::InProgress => "In Progress",
 408                ToolCallStatus::Completed => "Completed",
 409                ToolCallStatus::Failed => "Failed",
 410                ToolCallStatus::Rejected => "Rejected",
 411                ToolCallStatus::Canceled => "Canceled",
 412            }
 413        )
 414    }
 415}
 416
 417#[derive(Debug, PartialEq, Clone)]
 418pub enum ContentBlock {
 419    Empty,
 420    Markdown { markdown: Entity<Markdown> },
 421    ResourceLink { resource_link: acp::ResourceLink },
 422}
 423
 424impl ContentBlock {
 425    pub fn new(
 426        block: acp::ContentBlock,
 427        language_registry: &Arc<LanguageRegistry>,
 428        cx: &mut App,
 429    ) -> Self {
 430        let mut this = Self::Empty;
 431        this.append(block, language_registry, cx);
 432        this
 433    }
 434
 435    pub fn new_combined(
 436        blocks: impl IntoIterator<Item = acp::ContentBlock>,
 437        language_registry: Arc<LanguageRegistry>,
 438        cx: &mut App,
 439    ) -> Self {
 440        let mut this = Self::Empty;
 441        for block in blocks {
 442            this.append(block, &language_registry, cx);
 443        }
 444        this
 445    }
 446
 447    pub fn append(
 448        &mut self,
 449        block: acp::ContentBlock,
 450        language_registry: &Arc<LanguageRegistry>,
 451        cx: &mut App,
 452    ) {
 453        if matches!(self, ContentBlock::Empty)
 454            && let acp::ContentBlock::ResourceLink(resource_link) = block
 455        {
 456            *self = ContentBlock::ResourceLink { resource_link };
 457            return;
 458        }
 459
 460        let new_content = self.block_string_contents(block);
 461
 462        match self {
 463            ContentBlock::Empty => {
 464                *self = Self::create_markdown_block(new_content, language_registry, cx);
 465            }
 466            ContentBlock::Markdown { markdown } => {
 467                markdown.update(cx, |markdown, cx| markdown.append(&new_content, cx));
 468            }
 469            ContentBlock::ResourceLink { resource_link } => {
 470                let existing_content = Self::resource_link_md(&resource_link.uri);
 471                let combined = format!("{}\n{}", existing_content, new_content);
 472
 473                *self = Self::create_markdown_block(combined, language_registry, cx);
 474            }
 475        }
 476    }
 477
 478    fn create_markdown_block(
 479        content: String,
 480        language_registry: &Arc<LanguageRegistry>,
 481        cx: &mut App,
 482    ) -> ContentBlock {
 483        ContentBlock::Markdown {
 484            markdown: cx
 485                .new(|cx| Markdown::new(content.into(), Some(language_registry.clone()), None, cx)),
 486        }
 487    }
 488
 489    fn block_string_contents(&self, block: acp::ContentBlock) -> String {
 490        match block {
 491            acp::ContentBlock::Text(text_content) => text_content.text,
 492            acp::ContentBlock::ResourceLink(resource_link) => {
 493                Self::resource_link_md(&resource_link.uri)
 494            }
 495            acp::ContentBlock::Resource(acp::EmbeddedResource {
 496                resource:
 497                    acp::EmbeddedResourceResource::TextResourceContents(acp::TextResourceContents {
 498                        uri,
 499                        ..
 500                    }),
 501                ..
 502            }) => Self::resource_link_md(&uri),
 503            acp::ContentBlock::Image(image) => Self::image_md(&image),
 504            acp::ContentBlock::Audio(_) | acp::ContentBlock::Resource(_) => String::new(),
 505        }
 506    }
 507
 508    fn resource_link_md(uri: &str) -> String {
 509        if let Some(uri) = MentionUri::parse(uri).log_err() {
 510            uri.as_link().to_string()
 511        } else {
 512            uri.to_string()
 513        }
 514    }
 515
 516    fn image_md(_image: &acp::ImageContent) -> String {
 517        "`Image`".into()
 518    }
 519
 520    pub fn to_markdown<'a>(&'a self, cx: &'a App) -> &'a str {
 521        match self {
 522            ContentBlock::Empty => "",
 523            ContentBlock::Markdown { markdown } => markdown.read(cx).source(),
 524            ContentBlock::ResourceLink { resource_link } => &resource_link.uri,
 525        }
 526    }
 527
 528    pub fn markdown(&self) -> Option<&Entity<Markdown>> {
 529        match self {
 530            ContentBlock::Empty => None,
 531            ContentBlock::Markdown { markdown } => Some(markdown),
 532            ContentBlock::ResourceLink { .. } => None,
 533        }
 534    }
 535
 536    pub fn resource_link(&self) -> Option<&acp::ResourceLink> {
 537        match self {
 538            ContentBlock::ResourceLink { resource_link } => Some(resource_link),
 539            _ => None,
 540        }
 541    }
 542}
 543
 544#[derive(Debug)]
 545pub enum ToolCallContent {
 546    ContentBlock(ContentBlock),
 547    Diff(Entity<Diff>),
 548    Terminal(Entity<Terminal>),
 549}
 550
 551impl ToolCallContent {
 552    pub fn from_acp(
 553        content: acp::ToolCallContent,
 554        language_registry: Arc<LanguageRegistry>,
 555        cx: &mut App,
 556    ) -> Self {
 557        match content {
 558            acp::ToolCallContent::Content { content } => {
 559                Self::ContentBlock(ContentBlock::new(content, &language_registry, cx))
 560            }
 561            acp::ToolCallContent::Diff { diff } => Self::Diff(cx.new(|cx| {
 562                Diff::finalized(
 563                    diff.path,
 564                    diff.old_text,
 565                    diff.new_text,
 566                    language_registry,
 567                    cx,
 568                )
 569            })),
 570        }
 571    }
 572
 573    pub fn update_from_acp(
 574        &mut self,
 575        new: acp::ToolCallContent,
 576        language_registry: Arc<LanguageRegistry>,
 577        cx: &mut App,
 578    ) {
 579        let needs_update = match (&self, &new) {
 580            (Self::Diff(old_diff), acp::ToolCallContent::Diff { diff: new_diff }) => {
 581                old_diff.read(cx).needs_update(
 582                    new_diff.old_text.as_deref().unwrap_or(""),
 583                    &new_diff.new_text,
 584                    cx,
 585                )
 586            }
 587            _ => true,
 588        };
 589
 590        if needs_update {
 591            *self = Self::from_acp(new, language_registry, cx);
 592        }
 593    }
 594
 595    pub fn to_markdown(&self, cx: &App) -> String {
 596        match self {
 597            Self::ContentBlock(content) => content.to_markdown(cx).to_string(),
 598            Self::Diff(diff) => diff.read(cx).to_markdown(cx),
 599            Self::Terminal(terminal) => terminal.read(cx).to_markdown(cx),
 600        }
 601    }
 602}
 603
 604#[derive(Debug, PartialEq)]
 605pub enum ToolCallUpdate {
 606    UpdateFields(acp::ToolCallUpdate),
 607    UpdateDiff(ToolCallUpdateDiff),
 608    UpdateTerminal(ToolCallUpdateTerminal),
 609}
 610
 611impl ToolCallUpdate {
 612    fn id(&self) -> &acp::ToolCallId {
 613        match self {
 614            Self::UpdateFields(update) => &update.id,
 615            Self::UpdateDiff(diff) => &diff.id,
 616            Self::UpdateTerminal(terminal) => &terminal.id,
 617        }
 618    }
 619}
 620
 621impl From<acp::ToolCallUpdate> for ToolCallUpdate {
 622    fn from(update: acp::ToolCallUpdate) -> Self {
 623        Self::UpdateFields(update)
 624    }
 625}
 626
 627impl From<ToolCallUpdateDiff> for ToolCallUpdate {
 628    fn from(diff: ToolCallUpdateDiff) -> Self {
 629        Self::UpdateDiff(diff)
 630    }
 631}
 632
 633#[derive(Debug, PartialEq)]
 634pub struct ToolCallUpdateDiff {
 635    pub id: acp::ToolCallId,
 636    pub diff: Entity<Diff>,
 637}
 638
 639impl From<ToolCallUpdateTerminal> for ToolCallUpdate {
 640    fn from(terminal: ToolCallUpdateTerminal) -> Self {
 641        Self::UpdateTerminal(terminal)
 642    }
 643}
 644
 645#[derive(Debug, PartialEq)]
 646pub struct ToolCallUpdateTerminal {
 647    pub id: acp::ToolCallId,
 648    pub terminal: Entity<Terminal>,
 649}
 650
 651#[derive(Debug, Default)]
 652pub struct Plan {
 653    pub entries: Vec<PlanEntry>,
 654}
 655
 656#[derive(Debug)]
 657pub struct PlanStats<'a> {
 658    pub in_progress_entry: Option<&'a PlanEntry>,
 659    pub pending: u32,
 660    pub completed: u32,
 661}
 662
 663impl Plan {
 664    pub fn is_empty(&self) -> bool {
 665        self.entries.is_empty()
 666    }
 667
 668    pub fn stats(&self) -> PlanStats<'_> {
 669        let mut stats = PlanStats {
 670            in_progress_entry: None,
 671            pending: 0,
 672            completed: 0,
 673        };
 674
 675        for entry in &self.entries {
 676            match &entry.status {
 677                acp::PlanEntryStatus::Pending => {
 678                    stats.pending += 1;
 679                }
 680                acp::PlanEntryStatus::InProgress => {
 681                    stats.in_progress_entry = stats.in_progress_entry.or(Some(entry));
 682                }
 683                acp::PlanEntryStatus::Completed => {
 684                    stats.completed += 1;
 685                }
 686            }
 687        }
 688
 689        stats
 690    }
 691}
 692
 693#[derive(Debug)]
 694pub struct PlanEntry {
 695    pub content: Entity<Markdown>,
 696    pub priority: acp::PlanEntryPriority,
 697    pub status: acp::PlanEntryStatus,
 698}
 699
 700impl PlanEntry {
 701    pub fn from_acp(entry: acp::PlanEntry, cx: &mut App) -> Self {
 702        Self {
 703            content: cx.new(|cx| Markdown::new(entry.content.into(), None, None, cx)),
 704            priority: entry.priority,
 705            status: entry.status,
 706        }
 707    }
 708}
 709
 710#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 711pub struct TokenUsage {
 712    pub max_tokens: u64,
 713    pub used_tokens: u64,
 714}
 715
 716impl TokenUsage {
 717    pub fn ratio(&self) -> TokenUsageRatio {
 718        #[cfg(debug_assertions)]
 719        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 720            .unwrap_or("0.8".to_string())
 721            .parse()
 722            .unwrap();
 723        #[cfg(not(debug_assertions))]
 724        let warning_threshold: f32 = 0.8;
 725
 726        // When the maximum is unknown because there is no selected model,
 727        // avoid showing the token limit warning.
 728        if self.max_tokens == 0 {
 729            TokenUsageRatio::Normal
 730        } else if self.used_tokens >= self.max_tokens {
 731            TokenUsageRatio::Exceeded
 732        } else if self.used_tokens as f32 / self.max_tokens as f32 >= warning_threshold {
 733            TokenUsageRatio::Warning
 734        } else {
 735            TokenUsageRatio::Normal
 736        }
 737    }
 738}
 739
 740#[derive(Debug, Clone, PartialEq, Eq)]
 741pub enum TokenUsageRatio {
 742    Normal,
 743    Warning,
 744    Exceeded,
 745}
 746
 747#[derive(Debug, Clone)]
 748pub struct RetryStatus {
 749    pub last_error: SharedString,
 750    pub attempt: usize,
 751    pub max_attempts: usize,
 752    pub started_at: Instant,
 753    pub duration: Duration,
 754}
 755
 756pub struct AcpThread {
 757    title: SharedString,
 758    entries: Vec<AgentThreadEntry>,
 759    plan: Plan,
 760    project: Entity<Project>,
 761    action_log: Entity<ActionLog>,
 762    shared_buffers: HashMap<Entity<Buffer>, BufferSnapshot>,
 763    send_task: Option<Task<()>>,
 764    connection: Rc<dyn AgentConnection>,
 765    session_id: acp::SessionId,
 766    token_usage: Option<TokenUsage>,
 767    prompt_capabilities: acp::PromptCapabilities,
 768    _observe_prompt_capabilities: Task<anyhow::Result<()>>,
 769}
 770
 771#[derive(Debug)]
 772pub enum AcpThreadEvent {
 773    NewEntry,
 774    TitleUpdated,
 775    TokenUsageUpdated,
 776    EntryUpdated(usize),
 777    EntriesRemoved(Range<usize>),
 778    ToolAuthorizationRequired,
 779    Retry(RetryStatus),
 780    Stopped,
 781    Error,
 782    LoadError(LoadError),
 783    PromptCapabilitiesUpdated,
 784}
 785
 786impl EventEmitter<AcpThreadEvent> for AcpThread {}
 787
 788#[derive(PartialEq, Eq, Debug)]
 789pub enum ThreadStatus {
 790    Idle,
 791    WaitingForToolConfirmation,
 792    Generating,
 793}
 794
 795#[derive(Debug, Clone)]
 796pub enum LoadError {
 797    NotInstalled {
 798        error_message: SharedString,
 799        install_message: SharedString,
 800        install_command: String,
 801    },
 802    Unsupported {
 803        error_message: SharedString,
 804        upgrade_message: SharedString,
 805        upgrade_command: String,
 806    },
 807    Exited {
 808        status: ExitStatus,
 809    },
 810    Other(SharedString),
 811}
 812
 813impl Display for LoadError {
 814    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
 815        match self {
 816            LoadError::NotInstalled { error_message, .. }
 817            | LoadError::Unsupported { error_message, .. } => {
 818                write!(f, "{error_message}")
 819            }
 820            LoadError::Exited { status } => write!(f, "Server exited with status {status}"),
 821            LoadError::Other(msg) => write!(f, "{}", msg),
 822        }
 823    }
 824}
 825
 826impl Error for LoadError {}
 827
 828impl AcpThread {
 829    pub fn new(
 830        title: impl Into<SharedString>,
 831        connection: Rc<dyn AgentConnection>,
 832        project: Entity<Project>,
 833        action_log: Entity<ActionLog>,
 834        session_id: acp::SessionId,
 835        mut prompt_capabilities_rx: watch::Receiver<acp::PromptCapabilities>,
 836        cx: &mut Context<Self>,
 837    ) -> Self {
 838        let prompt_capabilities = *prompt_capabilities_rx.borrow();
 839        let task = cx.spawn::<_, anyhow::Result<()>>(async move |this, cx| {
 840            loop {
 841                let caps = prompt_capabilities_rx.recv().await?;
 842                this.update(cx, |this, cx| {
 843                    this.prompt_capabilities = caps;
 844                    cx.emit(AcpThreadEvent::PromptCapabilitiesUpdated);
 845                })?;
 846            }
 847        });
 848
 849        Self {
 850            action_log,
 851            shared_buffers: Default::default(),
 852            entries: Default::default(),
 853            plan: Default::default(),
 854            title: title.into(),
 855            project,
 856            send_task: None,
 857            connection,
 858            session_id,
 859            token_usage: None,
 860            prompt_capabilities,
 861            _observe_prompt_capabilities: task,
 862        }
 863    }
 864
 865    pub fn prompt_capabilities(&self) -> acp::PromptCapabilities {
 866        self.prompt_capabilities
 867    }
 868
 869    pub fn connection(&self) -> &Rc<dyn AgentConnection> {
 870        &self.connection
 871    }
 872
 873    pub fn action_log(&self) -> &Entity<ActionLog> {
 874        &self.action_log
 875    }
 876
 877    pub fn project(&self) -> &Entity<Project> {
 878        &self.project
 879    }
 880
 881    pub fn title(&self) -> SharedString {
 882        self.title.clone()
 883    }
 884
 885    pub fn entries(&self) -> &[AgentThreadEntry] {
 886        &self.entries
 887    }
 888
 889    pub fn session_id(&self) -> &acp::SessionId {
 890        &self.session_id
 891    }
 892
 893    pub fn status(&self) -> ThreadStatus {
 894        if self.send_task.is_some() {
 895            if self.waiting_for_tool_confirmation() {
 896                ThreadStatus::WaitingForToolConfirmation
 897            } else {
 898                ThreadStatus::Generating
 899            }
 900        } else {
 901            ThreadStatus::Idle
 902        }
 903    }
 904
 905    pub fn token_usage(&self) -> Option<&TokenUsage> {
 906        self.token_usage.as_ref()
 907    }
 908
 909    pub fn has_pending_edit_tool_calls(&self) -> bool {
 910        for entry in self.entries.iter().rev() {
 911            match entry {
 912                AgentThreadEntry::UserMessage(_) => return false,
 913                AgentThreadEntry::ToolCall(
 914                    call @ ToolCall {
 915                        status: ToolCallStatus::InProgress | ToolCallStatus::Pending,
 916                        ..
 917                    },
 918                ) if call.diffs().next().is_some() => {
 919                    return true;
 920                }
 921                AgentThreadEntry::ToolCall(_) | AgentThreadEntry::AssistantMessage(_) => {}
 922            }
 923        }
 924
 925        false
 926    }
 927
 928    pub fn used_tools_since_last_user_message(&self) -> bool {
 929        for entry in self.entries.iter().rev() {
 930            match entry {
 931                AgentThreadEntry::UserMessage(..) => return false,
 932                AgentThreadEntry::AssistantMessage(..) => continue,
 933                AgentThreadEntry::ToolCall(..) => return true,
 934            }
 935        }
 936
 937        false
 938    }
 939
 940    pub fn handle_session_update(
 941        &mut self,
 942        update: acp::SessionUpdate,
 943        cx: &mut Context<Self>,
 944    ) -> Result<(), acp::Error> {
 945        match update {
 946            acp::SessionUpdate::UserMessageChunk { content } => {
 947                self.push_user_content_block(None, content, cx);
 948            }
 949            acp::SessionUpdate::AgentMessageChunk { content } => {
 950                self.push_assistant_content_block(content, false, cx);
 951            }
 952            acp::SessionUpdate::AgentThoughtChunk { content } => {
 953                self.push_assistant_content_block(content, true, cx);
 954            }
 955            acp::SessionUpdate::ToolCall(tool_call) => {
 956                self.upsert_tool_call(tool_call, cx)?;
 957            }
 958            acp::SessionUpdate::ToolCallUpdate(tool_call_update) => {
 959                self.update_tool_call(tool_call_update, cx)?;
 960            }
 961            acp::SessionUpdate::Plan(plan) => {
 962                self.update_plan(plan, cx);
 963            }
 964        }
 965        Ok(())
 966    }
 967
 968    pub fn push_user_content_block(
 969        &mut self,
 970        prompt_id: Option<acp::PromptId>,
 971        chunk: acp::ContentBlock,
 972        cx: &mut Context<Self>,
 973    ) {
 974        let language_registry = self.project.read(cx).languages().clone();
 975        let entries_len = self.entries.len();
 976
 977        if let Some(last_entry) = self.entries.last_mut()
 978            && let AgentThreadEntry::UserMessage(UserMessage {
 979                prompt_id: id,
 980                content,
 981                chunks,
 982                ..
 983            }) = last_entry
 984        {
 985            *id = prompt_id.or(id.take());
 986            content.append(chunk.clone(), &language_registry, cx);
 987            chunks.push(chunk);
 988            let idx = entries_len - 1;
 989            cx.emit(AcpThreadEvent::EntryUpdated(idx));
 990        } else {
 991            let content = ContentBlock::new(chunk.clone(), &language_registry, cx);
 992            self.push_entry(
 993                AgentThreadEntry::UserMessage(UserMessage {
 994                    prompt_id,
 995                    content,
 996                    chunks: vec![chunk],
 997                    checkpoint: None,
 998                }),
 999                cx,
1000            );
1001        }
1002    }
1003
1004    pub fn push_assistant_content_block(
1005        &mut self,
1006        chunk: acp::ContentBlock,
1007        is_thought: bool,
1008        cx: &mut Context<Self>,
1009    ) {
1010        let language_registry = self.project.read(cx).languages().clone();
1011        let entries_len = self.entries.len();
1012        if let Some(last_entry) = self.entries.last_mut()
1013            && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
1014        {
1015            let idx = entries_len - 1;
1016            cx.emit(AcpThreadEvent::EntryUpdated(idx));
1017            match (chunks.last_mut(), is_thought) {
1018                (Some(AssistantMessageChunk::Message { block }), false)
1019                | (Some(AssistantMessageChunk::Thought { block }), true) => {
1020                    block.append(chunk, &language_registry, cx)
1021                }
1022                _ => {
1023                    let block = ContentBlock::new(chunk, &language_registry, cx);
1024                    if is_thought {
1025                        chunks.push(AssistantMessageChunk::Thought { block })
1026                    } else {
1027                        chunks.push(AssistantMessageChunk::Message { block })
1028                    }
1029                }
1030            }
1031        } else {
1032            let block = ContentBlock::new(chunk, &language_registry, cx);
1033            let chunk = if is_thought {
1034                AssistantMessageChunk::Thought { block }
1035            } else {
1036                AssistantMessageChunk::Message { block }
1037            };
1038
1039            self.push_entry(
1040                AgentThreadEntry::AssistantMessage(AssistantMessage {
1041                    chunks: vec![chunk],
1042                }),
1043                cx,
1044            );
1045        }
1046    }
1047
    /// Appends `entry` to the end of the thread and emits
    /// [`AcpThreadEvent::NewEntry`].
    fn push_entry(&mut self, entry: AgentThreadEntry, cx: &mut Context<Self>) {
        self.entries.push(entry);
        cx.emit(AcpThreadEvent::NewEntry);
    }
1052
1053    pub fn can_set_title(&mut self, cx: &mut Context<Self>) -> bool {
1054        self.connection.set_title(&self.session_id, cx).is_some()
1055    }
1056
1057    pub fn set_title(&mut self, title: SharedString, cx: &mut Context<Self>) -> Task<Result<()>> {
1058        if title != self.title {
1059            self.title = title.clone();
1060            cx.emit(AcpThreadEvent::TitleUpdated);
1061            if let Some(set_title) = self.connection.set_title(&self.session_id, cx) {
1062                return set_title.run(title, cx);
1063            }
1064        }
1065        Task::ready(Ok(()))
1066    }
1067
    /// Replaces the stored token usage with `usage` and emits
    /// [`AcpThreadEvent::TokenUsageUpdated`].
    pub fn update_token_usage(&mut self, usage: Option<TokenUsage>, cx: &mut Context<Self>) {
        self.token_usage = usage;
        cx.emit(AcpThreadEvent::TokenUsageUpdated);
    }
1072
    /// Emits [`AcpThreadEvent::Retry`] carrying `status`. The status is not
    /// stored on the thread.
    pub fn update_retry_status(&mut self, status: RetryStatus, cx: &mut Context<Self>) {
        cx.emit(AcpThreadEvent::Retry(status));
    }
1076
1077    pub fn update_tool_call(
1078        &mut self,
1079        update: impl Into<ToolCallUpdate>,
1080        cx: &mut Context<Self>,
1081    ) -> Result<()> {
1082        let update = update.into();
1083        let languages = self.project.read(cx).languages().clone();
1084
1085        let (ix, current_call) = self
1086            .tool_call_mut(update.id())
1087            .context("Tool call not found")?;
1088        match update {
1089            ToolCallUpdate::UpdateFields(update) => {
1090                let location_updated = update.fields.locations.is_some();
1091                current_call.update_fields(update.fields, languages, cx);
1092                if location_updated {
1093                    self.resolve_locations(update.id, cx);
1094                }
1095            }
1096            ToolCallUpdate::UpdateDiff(update) => {
1097                current_call.content.clear();
1098                current_call
1099                    .content
1100                    .push(ToolCallContent::Diff(update.diff));
1101            }
1102            ToolCallUpdate::UpdateTerminal(update) => {
1103                current_call.content.clear();
1104                current_call
1105                    .content
1106                    .push(ToolCallContent::Terminal(update.terminal));
1107            }
1108        }
1109
1110        cx.emit(AcpThreadEvent::EntryUpdated(ix));
1111
1112        Ok(())
1113    }
1114
1115    /// Updates a tool call if id matches an existing entry, otherwise inserts a new one.
1116    pub fn upsert_tool_call(
1117        &mut self,
1118        tool_call: acp::ToolCall,
1119        cx: &mut Context<Self>,
1120    ) -> Result<(), acp::Error> {
1121        let status = tool_call.status.into();
1122        self.upsert_tool_call_inner(tool_call.into(), status, cx)
1123    }
1124
1125    /// Fails if id does not match an existing entry.
1126    pub fn upsert_tool_call_inner(
1127        &mut self,
1128        tool_call_update: acp::ToolCallUpdate,
1129        status: ToolCallStatus,
1130        cx: &mut Context<Self>,
1131    ) -> Result<(), acp::Error> {
1132        let language_registry = self.project.read(cx).languages().clone();
1133        let id = tool_call_update.id.clone();
1134
1135        if let Some((ix, current_call)) = self.tool_call_mut(&id) {
1136            current_call.update_fields(tool_call_update.fields, language_registry, cx);
1137            current_call.status = status;
1138
1139            cx.emit(AcpThreadEvent::EntryUpdated(ix));
1140        } else {
1141            let call =
1142                ToolCall::from_acp(tool_call_update.try_into()?, status, language_registry, cx);
1143            self.push_entry(AgentThreadEntry::ToolCall(call), cx);
1144        };
1145
1146        self.resolve_locations(id, cx);
1147        Ok(())
1148    }
1149
1150    fn tool_call_mut(&mut self, id: &acp::ToolCallId) -> Option<(usize, &mut ToolCall)> {
1151        // The tool call we are looking for is typically the last one, or very close to the end.
1152        // At the moment, it doesn't seem like a hashmap would be a good fit for this use case.
1153        self.entries
1154            .iter_mut()
1155            .enumerate()
1156            .rev()
1157            .find_map(|(index, tool_call)| {
1158                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1159                    && &tool_call.id == id
1160                {
1161                    Some((index, tool_call))
1162                } else {
1163                    None
1164                }
1165            })
1166    }
1167
1168    pub fn tool_call(&mut self, id: &acp::ToolCallId) -> Option<(usize, &ToolCall)> {
1169        self.entries
1170            .iter()
1171            .enumerate()
1172            .rev()
1173            .find_map(|(index, tool_call)| {
1174                if let AgentThreadEntry::ToolCall(tool_call) = tool_call
1175                    && &tool_call.id == id
1176                {
1177                    Some((index, tool_call))
1178                } else {
1179                    None
1180                }
1181            })
1182    }
1183
1184    pub fn resolve_locations(&mut self, id: acp::ToolCallId, cx: &mut Context<Self>) {
1185        let project = self.project.clone();
1186        let Some((_, tool_call)) = self.tool_call_mut(&id) else {
1187            return;
1188        };
1189        let task = tool_call.resolve_locations(project, cx);
1190        cx.spawn(async move |this, cx| {
1191            let resolved_locations = task.await;
1192            this.update(cx, |this, cx| {
1193                let project = this.project.clone();
1194                let Some((ix, tool_call)) = this.tool_call_mut(&id) else {
1195                    return;
1196                };
1197                if let Some(Some(location)) = resolved_locations.last() {
1198                    project.update(cx, |project, cx| {
1199                        if let Some(agent_location) = project.agent_location() {
1200                            let should_ignore = agent_location.buffer == location.buffer
1201                                && location
1202                                    .buffer
1203                                    .update(cx, |buffer, _| {
1204                                        let snapshot = buffer.snapshot();
1205                                        let old_position =
1206                                            agent_location.position.to_point(&snapshot);
1207                                        let new_position = location.position.to_point(&snapshot);
1208                                        // ignore this so that when we get updates from the edit tool
1209                                        // the position doesn't reset to the startof line
1210                                        old_position.row == new_position.row
1211                                            && old_position.column > new_position.column
1212                                    })
1213                                    .ok()
1214                                    .unwrap_or_default();
1215                            if !should_ignore {
1216                                project.set_agent_location(Some(location.clone()), cx);
1217                            }
1218                        }
1219                    });
1220                }
1221                if tool_call.resolved_locations != resolved_locations {
1222                    tool_call.resolved_locations = resolved_locations;
1223                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
1224                }
1225            })
1226        })
1227        .detach();
1228    }
1229
1230    pub fn request_tool_call_authorization(
1231        &mut self,
1232        tool_call: acp::ToolCallUpdate,
1233        options: Vec<acp::PermissionOption>,
1234        cx: &mut Context<Self>,
1235    ) -> Result<oneshot::Receiver<acp::PermissionOptionId>, acp::Error> {
1236        let (tx, rx) = oneshot::channel();
1237
1238        let status = ToolCallStatus::WaitingForConfirmation {
1239            options,
1240            respond_tx: tx,
1241        };
1242
1243        self.upsert_tool_call_inner(tool_call, status, cx)?;
1244        cx.emit(AcpThreadEvent::ToolAuthorizationRequired);
1245        Ok(rx)
1246    }
1247
1248    pub fn authorize_tool_call(
1249        &mut self,
1250        id: acp::ToolCallId,
1251        option_id: acp::PermissionOptionId,
1252        option_kind: acp::PermissionOptionKind,
1253        cx: &mut Context<Self>,
1254    ) {
1255        let Some((ix, call)) = self.tool_call_mut(&id) else {
1256            return;
1257        };
1258
1259        let new_status = match option_kind {
1260            acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
1261                ToolCallStatus::Rejected
1262            }
1263            acp::PermissionOptionKind::AllowOnce | acp::PermissionOptionKind::AllowAlways => {
1264                ToolCallStatus::InProgress
1265            }
1266        };
1267
1268        let curr_status = mem::replace(&mut call.status, new_status);
1269
1270        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
1271            respond_tx.send(option_id).log_err();
1272        } else if cfg!(debug_assertions) {
1273            panic!("tried to authorize an already authorized tool call");
1274        }
1275
1276        cx.emit(AcpThreadEvent::EntryUpdated(ix));
1277    }
1278
1279    /// Returns true if the last turn is awaiting tool authorization
1280    pub fn waiting_for_tool_confirmation(&self) -> bool {
1281        for entry in self.entries.iter().rev() {
1282            match &entry {
1283                AgentThreadEntry::ToolCall(call) => match call.status {
1284                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
1285                    ToolCallStatus::Pending
1286                    | ToolCallStatus::InProgress
1287                    | ToolCallStatus::Completed
1288                    | ToolCallStatus::Failed
1289                    | ToolCallStatus::Rejected
1290                    | ToolCallStatus::Canceled => continue,
1291                },
1292                AgentThreadEntry::UserMessage(_) | AgentThreadEntry::AssistantMessage(_) => {
1293                    // Reached the beginning of the turn
1294                    return false;
1295                }
1296            }
1297        }
1298        false
1299    }
1300
    /// Returns the thread's current plan.
    pub fn plan(&self) -> &Plan {
        &self.plan
    }
1304
1305    pub fn update_plan(&mut self, request: acp::Plan, cx: &mut Context<Self>) {
1306        let new_entries_len = request.entries.len();
1307        let mut new_entries = request.entries.into_iter();
1308
1309        // Reuse existing markdown to prevent flickering
1310        for (old, new) in self.plan.entries.iter_mut().zip(new_entries.by_ref()) {
1311            let PlanEntry {
1312                content,
1313                priority,
1314                status,
1315            } = old;
1316            content.update(cx, |old, cx| {
1317                old.replace(new.content, cx);
1318            });
1319            *priority = new.priority;
1320            *status = new.status;
1321        }
1322        for new in new_entries {
1323            self.plan.entries.push(PlanEntry::from_acp(new, cx))
1324        }
1325        self.plan.entries.truncate(new_entries_len);
1326
1327        cx.notify();
1328    }
1329
1330    fn clear_completed_plan_entries(&mut self, cx: &mut Context<Self>) {
1331        self.plan
1332            .entries
1333            .retain(|entry| !matches!(entry.status, acp::PlanEntryStatus::Completed));
1334        cx.notify();
1335    }
1336
1337    #[cfg(any(test, feature = "test-support"))]
1338    pub fn send_raw(
1339        &mut self,
1340        message: &str,
1341        cx: &mut Context<Self>,
1342    ) -> BoxFuture<'static, Result<()>> {
1343        self.send(
1344            new_prompt_id(),
1345            vec![acp::ContentBlock::Text(acp::TextContent {
1346                text: message.to_string(),
1347                annotations: None,
1348            })],
1349            cx,
1350        )
1351    }
1352
1353    pub fn send(
1354        &mut self,
1355        prompt_id: acp::PromptId,
1356        message: Vec<acp::ContentBlock>,
1357        cx: &mut Context<Self>,
1358    ) -> BoxFuture<'static, Result<()>> {
1359        let block = ContentBlock::new_combined(
1360            message.clone(),
1361            self.project.read(cx).languages().clone(),
1362            cx,
1363        );
1364        let request = acp::PromptRequest {
1365            prompt_id: Some(prompt_id.clone()),
1366            prompt: message.clone(),
1367            session_id: self.session_id.clone(),
1368        };
1369        let git_store = self.project.read(cx).git_store().clone();
1370
1371        self.run_turn(cx, async move |this, cx| {
1372            this.update(cx, |this, cx| {
1373                this.push_entry(
1374                    AgentThreadEntry::UserMessage(UserMessage {
1375                        prompt_id: Some(prompt_id),
1376                        content: block,
1377                        chunks: message,
1378                        checkpoint: None,
1379                    }),
1380                    cx,
1381                );
1382            })
1383            .ok();
1384
1385            let old_checkpoint = git_store
1386                .update(cx, |git, cx| git.checkpoint(cx))?
1387                .await
1388                .context("failed to get old checkpoint")
1389                .log_err();
1390            this.update(cx, |this, cx| {
1391                if let Some((_ix, message)) = this.last_user_message() {
1392                    message.checkpoint = old_checkpoint.map(|git_checkpoint| Checkpoint {
1393                        git_checkpoint,
1394                        show: false,
1395                    });
1396                }
1397                this.connection.prompt(request, cx)
1398            })?
1399            .await
1400        })
1401    }
1402
1403    pub fn can_resume(&self, cx: &App) -> bool {
1404        self.connection.resume(&self.session_id, cx).is_some()
1405    }
1406
1407    pub fn resume(&mut self, cx: &mut Context<Self>) -> BoxFuture<'static, Result<()>> {
1408        self.run_turn(cx, async move |this, cx| {
1409            this.update(cx, |this, cx| {
1410                this.connection
1411                    .resume(&this.session_id, cx)
1412                    .map(|resume| resume.run(cx))
1413            })?
1414            .context("resuming a session is not supported")?
1415            .await
1416        })
1417    }
1418
    /// Runs one turn of the conversation.
    ///
    /// Completed plan entries are cleared, any turn still in flight is
    /// cancelled, and `f` is then executed inside `send_task` once that
    /// cancellation has finished. The returned future waits for the outcome of
    /// `f`, refreshes the last checkpoint, clears the agent location, and then
    /// reports the result:
    ///
    /// * if `f` failed, [`AcpThreadEvent::Error`] is emitted and the error is returned;
    /// * otherwise [`AcpThreadEvent::Stopped`] is emitted. When the stop reason
    ///   is a refusal, the last user message and all entries after it are
    ///   removed first.
    fn run_turn(
        &mut self,
        cx: &mut Context<Self>,
        f: impl 'static + AsyncFnOnce(WeakEntity<Self>, &mut AsyncApp) -> Result<acp::PromptResponse>,
    ) -> BoxFuture<'static, Result<()>> {
        self.clear_completed_plan_entries(cx);

        let (tx, rx) = oneshot::channel();
        let cancel_task = self.cancel(cx);

        // The turn itself lives in `send_task`; it only starts once the previous turn has
        // finished cancelling.
        self.send_task = Some(cx.spawn(async move |this, cx| {
            cancel_task.await;
            tx.send(f(this, cx).await).ok();
        }));

        cx.spawn(async move |this, cx| {
            // `response` is an `Err` when the sender was dropped before sending.
            let response = rx.await;

            // An error here returns early, before any of the events below are emitted.
            this.update(cx, |this, cx| this.update_last_checkpoint(cx))?
                .await?;

            this.update(cx, |this, cx| {
                this.project
                    .update(cx, |project, cx| project.set_agent_location(None, cx));
                match response {
                    Ok(Err(e)) => {
                        this.send_task.take();
                        cx.emit(AcpThreadEvent::Error);
                        Err(e)
                    }
                    result => {
                        let canceled = matches!(
                            result,
                            Ok(Ok(acp::PromptResponse {
                                stop_reason: acp::StopReason::Cancelled
                            }))
                        );

                        // We only take the task if the current prompt wasn't canceled.
                        //
                        // This prompt may have been canceled because another one was sent
                        // while it was still generating. In these cases, dropping `send_task`
                        // would cause the next generation to be canceled.
                        if !canceled {
                            this.send_task.take();
                        }

                        // Truncate entries if the last prompt was refused.
                        if let Ok(Ok(acp::PromptResponse {
                            stop_reason: acp::StopReason::Refusal,
                        })) = result
                            && let Some((ix, _)) = this.last_user_message()
                        {
                            let range = ix..this.entries.len();
                            this.entries.truncate(ix);
                            cx.emit(AcpThreadEvent::EntriesRemoved(range));
                        }

                        cx.emit(AcpThreadEvent::Stopped);
                        Ok(())
                    }
                }
            })?
        })
        .boxed()
    }
1485
1486    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<()> {
1487        let Some(send_task) = self.send_task.take() else {
1488            return Task::ready(());
1489        };
1490
1491        for entry in self.entries.iter_mut() {
1492            if let AgentThreadEntry::ToolCall(call) = entry {
1493                let cancel = matches!(
1494                    call.status,
1495                    ToolCallStatus::Pending
1496                        | ToolCallStatus::WaitingForConfirmation { .. }
1497                        | ToolCallStatus::InProgress
1498                );
1499
1500                if cancel {
1501                    call.status = ToolCallStatus::Canceled;
1502                }
1503            }
1504        }
1505
1506        self.connection.cancel(&self.session_id, cx);
1507
1508        // Wait for the send task to complete
1509        cx.foreground_executor().spawn(send_task)
1510    }
1511
1512    /// Rewinds this thread to before the entry at `index`, removing it and all
1513    /// subsequent entries while reverting any changes made from that point.
1514    pub fn rewind(&mut self, id: acp::PromptId, cx: &mut Context<Self>) -> Task<Result<()>> {
1515        let Some(rewind) = self.connection.rewind(&self.session_id, cx) else {
1516            return Task::ready(Err(anyhow!("not supported")));
1517        };
1518        let Some(message) = self.user_message(&id) else {
1519            return Task::ready(Err(anyhow!("message not found")));
1520        };
1521
1522        let checkpoint = message
1523            .checkpoint
1524            .as_ref()
1525            .map(|c| c.git_checkpoint.clone());
1526
1527        let git_store = self.project.read(cx).git_store().clone();
1528        cx.spawn(async move |this, cx| {
1529            if let Some(checkpoint) = checkpoint {
1530                git_store
1531                    .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
1532                    .await?;
1533            }
1534
1535            cx.update(|cx| rewind.rewind(id.clone(), cx))?.await?;
1536            this.update(cx, |this, cx| {
1537                if let Some((ix, _)) = this.user_message_mut(&id) {
1538                    let range = ix..this.entries.len();
1539                    this.entries.truncate(ix);
1540                    cx.emit(AcpThreadEvent::EntriesRemoved(range));
1541                }
1542            })
1543        })
1544    }
1545
1546    fn update_last_checkpoint(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
1547        let git_store = self.project.read(cx).git_store().clone();
1548
1549        let old_checkpoint = if let Some((_, message)) = self.last_user_message() {
1550            if let Some(checkpoint) = message.checkpoint.as_ref() {
1551                checkpoint.git_checkpoint.clone()
1552            } else {
1553                return Task::ready(Ok(()));
1554            }
1555        } else {
1556            return Task::ready(Ok(()));
1557        };
1558
1559        let new_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
1560        cx.spawn(async move |this, cx| {
1561            let new_checkpoint = new_checkpoint
1562                .await
1563                .context("failed to get new checkpoint")
1564                .log_err();
1565            if let Some(new_checkpoint) = new_checkpoint {
1566                let equal = git_store
1567                    .update(cx, |git, cx| {
1568                        git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
1569                    })?
1570                    .await
1571                    .unwrap_or(true);
1572                this.update(cx, |this, cx| {
1573                    let (ix, message) = this.last_user_message().context("no user message")?;
1574                    let checkpoint = message.checkpoint.as_mut().context("no checkpoint")?;
1575                    checkpoint.show = !equal;
1576                    cx.emit(AcpThreadEvent::EntryUpdated(ix));
1577                    anyhow::Ok(())
1578                })??;
1579            }
1580
1581            Ok(())
1582        })
1583    }
1584
1585    fn last_user_message(&mut self) -> Option<(usize, &mut UserMessage)> {
1586        self.entries
1587            .iter_mut()
1588            .enumerate()
1589            .rev()
1590            .find_map(|(ix, entry)| {
1591                if let AgentThreadEntry::UserMessage(message) = entry {
1592                    Some((ix, message))
1593                } else {
1594                    None
1595                }
1596            })
1597    }
1598
1599    fn user_message(&self, id: &acp::PromptId) -> Option<&UserMessage> {
1600        self.entries.iter().find_map(|entry| {
1601            if let AgentThreadEntry::UserMessage(message) = entry {
1602                if message.prompt_id.as_ref() == Some(id) {
1603                    Some(message)
1604                } else {
1605                    None
1606                }
1607            } else {
1608                None
1609            }
1610        })
1611    }
1612
1613    fn user_message_mut(&mut self, id: &acp::PromptId) -> Option<(usize, &mut UserMessage)> {
1614        self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
1615            if let AgentThreadEntry::UserMessage(message) = entry {
1616                if message.prompt_id.as_ref() == Some(id) {
1617                    Some((ix, message))
1618                } else {
1619                    None
1620                }
1621            } else {
1622                None
1623            }
1624        })
1625    }
1626
1627    pub fn read_text_file(
1628        &self,
1629        path: PathBuf,
1630        line: Option<u32>,
1631        limit: Option<u32>,
1632        reuse_shared_snapshot: bool,
1633        cx: &mut Context<Self>,
1634    ) -> Task<Result<String>> {
1635        let project = self.project.clone();
1636        let action_log = self.action_log.clone();
1637        cx.spawn(async move |this, cx| {
1638            let load = project.update(cx, |project, cx| {
1639                let path = project
1640                    .project_path_for_absolute_path(&path, cx)
1641                    .context("invalid path")?;
1642                anyhow::Ok(project.open_buffer(path, cx))
1643            });
1644            let buffer = load??.await?;
1645
1646            let snapshot = if reuse_shared_snapshot {
1647                this.read_with(cx, |this, _| {
1648                    this.shared_buffers.get(&buffer.clone()).cloned()
1649                })
1650                .log_err()
1651                .flatten()
1652            } else {
1653                None
1654            };
1655
1656            let snapshot = if let Some(snapshot) = snapshot {
1657                snapshot
1658            } else {
1659                action_log.update(cx, |action_log, cx| {
1660                    action_log.buffer_read(buffer.clone(), cx);
1661                })?;
1662                project.update(cx, |project, cx| {
1663                    let position = buffer
1664                        .read(cx)
1665                        .snapshot()
1666                        .anchor_before(Point::new(line.unwrap_or_default(), 0));
1667                    project.set_agent_location(
1668                        Some(AgentLocation {
1669                            buffer: buffer.downgrade(),
1670                            position,
1671                        }),
1672                        cx,
1673                    );
1674                })?;
1675
1676                buffer.update(cx, |buffer, _| buffer.snapshot())?
1677            };
1678
1679            this.update(cx, |this, _| {
1680                let text = snapshot.text();
1681                this.shared_buffers.insert(buffer.clone(), snapshot);
1682                if line.is_none() && limit.is_none() {
1683                    return Ok(text);
1684                }
1685                let limit = limit.unwrap_or(u32::MAX) as usize;
1686                let Some(line) = line else {
1687                    return Ok(text.lines().take(limit).collect::<String>());
1688                };
1689
1690                let count = text.lines().count();
1691                if count < line as usize {
1692                    anyhow::bail!("There are only {} lines", count);
1693                }
1694                Ok(text
1695                    .lines()
1696                    .skip(line as usize + 1)
1697                    .take(limit)
1698                    .collect::<String>())
1699            })?
1700        })
1701    }
1702
    /// Replaces the contents of the file at `path` with `content` on behalf of
    /// the agent, then saves the buffer.
    ///
    /// The new content is diffed against the snapshot stored in
    /// `shared_buffers` for this buffer (falling back to the buffer's current
    /// snapshot), and the resulting edits are applied to the buffer. The agent
    /// location is moved to the end of the last edit. If the buffer's language
    /// settings enable format-on-save, the buffer is formatted before saving.
    ///
    /// # Errors
    ///
    /// Fails when `path` does not belong to the project, when the buffer
    /// cannot be opened, or when saving fails. Formatting errors are only
    /// logged.
    pub fn write_text_file(
        &self,
        path: PathBuf,
        content: String,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        let project = self.project.clone();
        let action_log = self.action_log.clone();
        cx.spawn(async move |this, cx| {
            let load = project.update(cx, |project, cx| {
                let path = project
                    .project_path_for_absolute_path(&path, cx)
                    .context("invalid path")?;
                anyhow::Ok(project.open_buffer(path, cx))
            });
            let buffer = load??.await?;
            // Prefer the snapshot recorded in `shared_buffers` as the diff base.
            let snapshot = this.update(cx, |this, cx| {
                this.shared_buffers
                    .get(&buffer)
                    .cloned()
                    .unwrap_or_else(|| buffer.read(cx).snapshot())
            })?;
            // Compute the diff on the background executor.
            let edits = cx
                .background_executor()
                .spawn(async move {
                    let old_text = snapshot.text();
                    text_diff(old_text.as_str(), &content)
                        .into_iter()
                        .map(|(range, replacement)| {
                            (
                                snapshot.anchor_after(range.start)
                                    ..snapshot.anchor_before(range.end),
                                replacement,
                            )
                        })
                        .collect::<Vec<_>>()
                })
                .await;

            // Point the agent location at the end of the last edit, or at the start of the
            // buffer when the diff produced no edits.
            project.update(cx, |project, cx| {
                project.set_agent_location(
                    Some(AgentLocation {
                        buffer: buffer.downgrade(),
                        position: edits
                            .last()
                            .map(|(range, _)| range.end)
                            .unwrap_or(Anchor::MIN),
                    }),
                    cx,
                );
            })?;

            let format_on_save = cx.update(|cx| {
                // The action log is notified of a read before the edit and of an edit after it.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_read(buffer.clone(), cx);
                });

                let format_on_save = buffer.update(cx, |buffer, cx| {
                    buffer.edit(edits, None, cx);

                    let settings = language::language_settings::language_settings(
                        buffer.language().map(|l| l.name()),
                        buffer.file(),
                        cx,
                    );

                    settings.format_on_save != FormatOnSave::Off
                });
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                });
                format_on_save
            })?;

            if format_on_save {
                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false,
                        FormatTrigger::Save,
                        cx,
                    )
                })?;
                // A formatting failure is logged but does not prevent the save below.
                format_task.await.log_err();

                // Report the buffer as edited again once formatting has run.
                action_log.update(cx, |action_log, cx| {
                    action_log.buffer_edited(buffer.clone(), cx);
                })?;
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
                .await
        })
    }
1799
1800    pub fn to_markdown(&self, cx: &App) -> String {
1801        self.entries.iter().map(|e| e.to_markdown(cx)).collect()
1802    }
1803
1804    pub fn emit_load_error(&mut self, error: LoadError, cx: &mut Context<Self>) {
1805        cx.emit(AcpThreadEvent::LoadError(error));
1806    }
1807}
1808
1809fn markdown_for_raw_output(
1810    raw_output: &serde_json::Value,
1811    language_registry: &Arc<LanguageRegistry>,
1812    cx: &mut App,
1813) -> Option<Entity<Markdown>> {
1814    match raw_output {
1815        serde_json::Value::Null => None,
1816        serde_json::Value::Bool(value) => Some(cx.new(|cx| {
1817            Markdown::new(
1818                value.to_string().into(),
1819                Some(language_registry.clone()),
1820                None,
1821                cx,
1822            )
1823        })),
1824        serde_json::Value::Number(value) => Some(cx.new(|cx| {
1825            Markdown::new(
1826                value.to_string().into(),
1827                Some(language_registry.clone()),
1828                None,
1829                cx,
1830            )
1831        })),
1832        serde_json::Value::String(value) => Some(cx.new(|cx| {
1833            Markdown::new(
1834                value.clone().into(),
1835                Some(language_registry.clone()),
1836                None,
1837                cx,
1838            )
1839        })),
1840        value => Some(cx.new(|cx| {
1841            Markdown::new(
1842                format!("```json\n{}\n```", value).into(),
1843                Some(language_registry.clone()),
1844                None,
1845                cx,
1846            )
1847        })),
1848    }
1849}
1850
1851#[cfg(test)]
1852mod tests {
1853    use super::*;
1854    use anyhow::anyhow;
1855    use futures::{channel::mpsc, future::LocalBoxFuture, select};
1856    use gpui::{App, AsyncApp, TestAppContext, WeakEntity};
1857    use indoc::indoc;
1858    use project::{FakeFs, Fs};
1859    use rand::Rng as _;
1860    use serde_json::json;
1861    use settings::SettingsStore;
1862    use smol::stream::StreamExt as _;
1863    use std::{
1864        any::Any,
1865        cell::RefCell,
1866        path::Path,
1867        rc::Rc,
1868        sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
1869        time::Duration,
1870    };
1871    use util::path;
1872
1873    fn init_test(cx: &mut TestAppContext) {
1874        env_logger::try_init().ok();
1875        cx.update(|cx| {
1876            let settings_store = SettingsStore::test(cx);
1877            cx.set_global(settings_store);
1878            Project::init_settings(cx);
1879            language::init(cx);
1880        });
1881    }
1882
1883    #[gpui::test]
1884    async fn test_push_user_content_block(cx: &mut gpui::TestAppContext) {
1885        init_test(cx);
1886
1887        let fs = FakeFs::new(cx.executor());
1888        let project = Project::test(fs, [], cx).await;
1889        let connection = Rc::new(FakeAgentConnection::new());
1890        let thread = cx
1891            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
1892            .await
1893            .unwrap();
1894
1895        // Test creating a new user message
1896        thread.update(cx, |thread, cx| {
1897            thread.push_user_content_block(
1898                None,
1899                acp::ContentBlock::Text(acp::TextContent {
1900                    annotations: None,
1901                    text: "Hello, ".to_string(),
1902                }),
1903                cx,
1904            );
1905        });
1906
1907        thread.update(cx, |thread, cx| {
1908            assert_eq!(thread.entries.len(), 1);
1909            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1910                assert_eq!(user_msg.prompt_id, None);
1911                assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
1912            } else {
1913                panic!("Expected UserMessage");
1914            }
1915        });
1916
1917        // Test appending to existing user message
1918        let message_1_id = new_prompt_id();
1919        thread.update(cx, |thread, cx| {
1920            thread.push_user_content_block(
1921                Some(message_1_id.clone()),
1922                acp::ContentBlock::Text(acp::TextContent {
1923                    annotations: None,
1924                    text: "world!".to_string(),
1925                }),
1926                cx,
1927            );
1928        });
1929
1930        thread.update(cx, |thread, cx| {
1931            assert_eq!(thread.entries.len(), 1);
1932            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
1933                assert_eq!(user_msg.prompt_id, Some(message_1_id));
1934                assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
1935            } else {
1936                panic!("Expected UserMessage");
1937            }
1938        });
1939
1940        // Test creating new user message after assistant message
1941        thread.update(cx, |thread, cx| {
1942            thread.push_assistant_content_block(
1943                acp::ContentBlock::Text(acp::TextContent {
1944                    annotations: None,
1945                    text: "Assistant response".to_string(),
1946                }),
1947                false,
1948                cx,
1949            );
1950        });
1951
1952        let message_2_id = new_prompt_id();
1953        thread.update(cx, |thread, cx| {
1954            thread.push_user_content_block(
1955                Some(message_2_id.clone()),
1956                acp::ContentBlock::Text(acp::TextContent {
1957                    annotations: None,
1958                    text: "New user message".to_string(),
1959                }),
1960                cx,
1961            );
1962        });
1963
1964        thread.update(cx, |thread, cx| {
1965            assert_eq!(thread.entries.len(), 3);
1966            if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
1967                assert_eq!(user_msg.prompt_id, Some(message_2_id));
1968                assert_eq!(user_msg.content.to_markdown(cx), "New user message");
1969            } else {
1970                panic!("Expected UserMessage at index 2");
1971            }
1972        });
1973    }
1974
1975    #[gpui::test]
1976    async fn test_thinking_concatenation(cx: &mut gpui::TestAppContext) {
1977        init_test(cx);
1978
1979        let fs = FakeFs::new(cx.executor());
1980        let project = Project::test(fs, [], cx).await;
1981        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
1982            |_, thread, mut cx| {
1983                async move {
1984                    thread.update(&mut cx, |thread, cx| {
1985                        thread
1986                            .handle_session_update(
1987                                acp::SessionUpdate::AgentThoughtChunk {
1988                                    content: "Thinking ".into(),
1989                                },
1990                                cx,
1991                            )
1992                            .unwrap();
1993                        thread
1994                            .handle_session_update(
1995                                acp::SessionUpdate::AgentThoughtChunk {
1996                                    content: "hard!".into(),
1997                                },
1998                                cx,
1999                            )
2000                            .unwrap();
2001                    })?;
2002                    Ok(acp::PromptResponse {
2003                        stop_reason: acp::StopReason::EndTurn,
2004                    })
2005                }
2006                .boxed_local()
2007            },
2008        ));
2009
2010        let thread = cx
2011            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2012            .await
2013            .unwrap();
2014
2015        thread
2016            .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
2017            .await
2018            .unwrap();
2019
2020        let output = thread.read_with(cx, |thread, cx| thread.to_markdown(cx));
2021        assert_eq!(
2022            output,
2023            indoc! {r#"
2024            ## User
2025
2026            Hello from Zed!
2027
2028            ## Assistant
2029
2030            <thinking>
2031            Thinking hard!
2032            </thinking>
2033
2034            "#}
2035        );
2036    }
2037
2038    #[gpui::test]
2039    async fn test_edits_concurrently_to_user(cx: &mut TestAppContext) {
2040        init_test(cx);
2041
2042        let fs = FakeFs::new(cx.executor());
2043        fs.insert_tree(path!("/tmp"), json!({"foo": "one\ntwo\nthree\n"}))
2044            .await;
2045        let project = Project::test(fs.clone(), [], cx).await;
2046        let (read_file_tx, read_file_rx) = oneshot::channel::<()>();
2047        let read_file_tx = Rc::new(RefCell::new(Some(read_file_tx)));
2048        let connection = Rc::new(FakeAgentConnection::new().on_user_message(
2049            move |_, thread, mut cx| {
2050                let read_file_tx = read_file_tx.clone();
2051                async move {
2052                    let content = thread
2053                        .update(&mut cx, |thread, cx| {
2054                            thread.read_text_file(path!("/tmp/foo").into(), None, None, false, cx)
2055                        })
2056                        .unwrap()
2057                        .await
2058                        .unwrap();
2059                    assert_eq!(content, "one\ntwo\nthree\n");
2060                    read_file_tx.take().unwrap().send(()).unwrap();
2061                    thread
2062                        .update(&mut cx, |thread, cx| {
2063                            thread.write_text_file(
2064                                path!("/tmp/foo").into(),
2065                                "one\ntwo\nthree\nfour\nfive\n".to_string(),
2066                                cx,
2067                            )
2068                        })
2069                        .unwrap()
2070                        .await
2071                        .unwrap();
2072                    Ok(acp::PromptResponse {
2073                        stop_reason: acp::StopReason::EndTurn,
2074                    })
2075                }
2076                .boxed_local()
2077            },
2078        ));
2079
2080        let (worktree, pathbuf) = project
2081            .update(cx, |project, cx| {
2082                project.find_or_create_worktree(path!("/tmp/foo"), true, cx)
2083            })
2084            .await
2085            .unwrap();
2086        let buffer = project
2087            .update(cx, |project, cx| {
2088                project.open_buffer((worktree.read(cx).id(), pathbuf), cx)
2089            })
2090            .await
2091            .unwrap();
2092
2093        let thread = cx
2094            .update(|cx| connection.new_thread(project, Path::new(path!("/tmp")), cx))
2095            .await
2096            .unwrap();
2097
2098        let request = thread.update(cx, |thread, cx| {
2099            thread.send_raw("Extend the count in /tmp/foo", cx)
2100        });
2101        read_file_rx.await.ok();
2102        buffer.update(cx, |buffer, cx| {
2103            buffer.edit([(0..0, "zero\n".to_string())], None, cx);
2104        });
2105        cx.run_until_parked();
2106        assert_eq!(
2107            buffer.read_with(cx, |buffer, _| buffer.text()),
2108            "zero\none\ntwo\nthree\nfour\nfive\n"
2109        );
2110        assert_eq!(
2111            String::from_utf8(fs.read_file_sync(path!("/tmp/foo")).unwrap()).unwrap(),
2112            "zero\none\ntwo\nthree\nfour\nfive\n"
2113        );
2114        request.await.unwrap();
2115    }
2116
2117    #[gpui::test]
2118    async fn test_succeeding_canceled_toolcall(cx: &mut TestAppContext) {
2119        init_test(cx);
2120
2121        let fs = FakeFs::new(cx.executor());
2122        let project = Project::test(fs, [], cx).await;
2123        let id = acp::ToolCallId("test".into());
2124
2125        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2126            let id = id.clone();
2127            move |_, thread, mut cx| {
2128                let id = id.clone();
2129                async move {
2130                    thread
2131                        .update(&mut cx, |thread, cx| {
2132                            thread.handle_session_update(
2133                                acp::SessionUpdate::ToolCall(acp::ToolCall {
2134                                    id: id.clone(),
2135                                    title: "Label".into(),
2136                                    kind: acp::ToolKind::Fetch,
2137                                    status: acp::ToolCallStatus::InProgress,
2138                                    content: vec![],
2139                                    locations: vec![],
2140                                    raw_input: None,
2141                                    raw_output: None,
2142                                }),
2143                                cx,
2144                            )
2145                        })
2146                        .unwrap()
2147                        .unwrap();
2148                    Ok(acp::PromptResponse {
2149                        stop_reason: acp::StopReason::EndTurn,
2150                    })
2151                }
2152                .boxed_local()
2153            }
2154        }));
2155
2156        let thread = cx
2157            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2158            .await
2159            .unwrap();
2160
2161        let request = thread.update(cx, |thread, cx| {
2162            thread.send_raw("Fetch https://example.com", cx)
2163        });
2164
2165        run_until_first_tool_call(&thread, cx).await;
2166
2167        thread.read_with(cx, |thread, _| {
2168            assert!(matches!(
2169                thread.entries[1],
2170                AgentThreadEntry::ToolCall(ToolCall {
2171                    status: ToolCallStatus::InProgress,
2172                    ..
2173                })
2174            ));
2175        });
2176
2177        thread.update(cx, |thread, cx| thread.cancel(cx)).await;
2178
2179        thread.read_with(cx, |thread, _| {
2180            assert!(matches!(
2181                &thread.entries[1],
2182                AgentThreadEntry::ToolCall(ToolCall {
2183                    status: ToolCallStatus::Canceled,
2184                    ..
2185                })
2186            ));
2187        });
2188
2189        thread
2190            .update(cx, |thread, cx| {
2191                thread.handle_session_update(
2192                    acp::SessionUpdate::ToolCallUpdate(acp::ToolCallUpdate {
2193                        id,
2194                        fields: acp::ToolCallUpdateFields {
2195                            status: Some(acp::ToolCallStatus::Completed),
2196                            ..Default::default()
2197                        },
2198                    }),
2199                    cx,
2200                )
2201            })
2202            .unwrap();
2203
2204        request.await.unwrap();
2205
2206        thread.read_with(cx, |thread, _| {
2207            assert!(matches!(
2208                thread.entries[1],
2209                AgentThreadEntry::ToolCall(ToolCall {
2210                    status: ToolCallStatus::Completed,
2211                    ..
2212                })
2213            ));
2214        });
2215    }
2216
2217    #[gpui::test]
2218    async fn test_no_pending_edits_if_tool_calls_are_completed(cx: &mut TestAppContext) {
2219        init_test(cx);
2220        let fs = FakeFs::new(cx.background_executor.clone());
2221        fs.insert_tree(path!("/test"), json!({})).await;
2222        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
2223
2224        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2225            move |_, thread, mut cx| {
2226                async move {
2227                    thread
2228                        .update(&mut cx, |thread, cx| {
2229                            thread.handle_session_update(
2230                                acp::SessionUpdate::ToolCall(acp::ToolCall {
2231                                    id: acp::ToolCallId("test".into()),
2232                                    title: "Label".into(),
2233                                    kind: acp::ToolKind::Edit,
2234                                    status: acp::ToolCallStatus::Completed,
2235                                    content: vec![acp::ToolCallContent::Diff {
2236                                        diff: acp::Diff {
2237                                            path: "/test/test.txt".into(),
2238                                            old_text: None,
2239                                            new_text: "foo".into(),
2240                                        },
2241                                    }],
2242                                    locations: vec![],
2243                                    raw_input: None,
2244                                    raw_output: None,
2245                                }),
2246                                cx,
2247                            )
2248                        })
2249                        .unwrap()
2250                        .unwrap();
2251                    Ok(acp::PromptResponse {
2252                        stop_reason: acp::StopReason::EndTurn,
2253                    })
2254                }
2255                .boxed_local()
2256            }
2257        }));
2258
2259        let thread = cx
2260            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2261            .await
2262            .unwrap();
2263
2264        cx.update(|cx| {
2265            thread.update(cx, |thread, cx| {
2266                thread.send(new_prompt_id(), vec!["Hi".into()], cx)
2267            })
2268        })
2269        .await
2270        .unwrap();
2271
2272        assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
2273    }
2274
2275    #[gpui::test(iterations = 10)]
2276    async fn test_checkpoints(cx: &mut TestAppContext) {
2277        init_test(cx);
2278        let fs = FakeFs::new(cx.background_executor.clone());
2279        fs.insert_tree(
2280            path!("/test"),
2281            json!({
2282                ".git": {}
2283            }),
2284        )
2285        .await;
2286        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
2287
2288        let simulate_changes = Arc::new(AtomicBool::new(true));
2289        let next_filename = Arc::new(AtomicUsize::new(0));
2290        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2291            let simulate_changes = simulate_changes.clone();
2292            let next_filename = next_filename.clone();
2293            let fs = fs.clone();
2294            move |request, thread, mut cx| {
2295                let fs = fs.clone();
2296                let simulate_changes = simulate_changes.clone();
2297                let next_filename = next_filename.clone();
2298                async move {
2299                    if simulate_changes.load(SeqCst) {
2300                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
2301                        fs.write(Path::new(&filename), b"").await?;
2302                    }
2303
2304                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2305                        panic!("expected text content block");
2306                    };
2307                    thread.update(&mut cx, |thread, cx| {
2308                        thread
2309                            .handle_session_update(
2310                                acp::SessionUpdate::AgentMessageChunk {
2311                                    content: content.text.to_uppercase().into(),
2312                                },
2313                                cx,
2314                            )
2315                            .unwrap();
2316                    })?;
2317                    Ok(acp::PromptResponse {
2318                        stop_reason: acp::StopReason::EndTurn,
2319                    })
2320                }
2321                .boxed_local()
2322            }
2323        }));
2324        let thread = cx
2325            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2326            .await
2327            .unwrap();
2328
2329        cx.update(|cx| {
2330            thread.update(cx, |thread, cx| {
2331                thread.send(new_prompt_id(), vec!["Lorem".into()], cx)
2332            })
2333        })
2334        .await
2335        .unwrap();
2336        thread.read_with(cx, |thread, cx| {
2337            assert_eq!(
2338                thread.to_markdown(cx),
2339                indoc! {"
2340                    ## User (checkpoint)
2341
2342                    Lorem
2343
2344                    ## Assistant
2345
2346                    LOREM
2347
2348                "}
2349            );
2350        });
2351        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2352
2353        cx.update(|cx| {
2354            thread.update(cx, |thread, cx| {
2355                thread.send(new_prompt_id(), vec!["ipsum".into()], cx)
2356            })
2357        })
2358        .await
2359        .unwrap();
2360        thread.read_with(cx, |thread, cx| {
2361            assert_eq!(
2362                thread.to_markdown(cx),
2363                indoc! {"
2364                    ## User (checkpoint)
2365
2366                    Lorem
2367
2368                    ## Assistant
2369
2370                    LOREM
2371
2372                    ## User (checkpoint)
2373
2374                    ipsum
2375
2376                    ## Assistant
2377
2378                    IPSUM
2379
2380                "}
2381            );
2382        });
2383        assert_eq!(
2384            fs.files(),
2385            vec![
2386                Path::new(path!("/test/file-0")),
2387                Path::new(path!("/test/file-1"))
2388            ]
2389        );
2390
2391        // Checkpoint isn't stored when there are no changes.
2392        simulate_changes.store(false, SeqCst);
2393        cx.update(|cx| {
2394            thread.update(cx, |thread, cx| {
2395                thread.send(new_prompt_id(), vec!["dolor".into()], cx)
2396            })
2397        })
2398        .await
2399        .unwrap();
2400        thread.read_with(cx, |thread, cx| {
2401            assert_eq!(
2402                thread.to_markdown(cx),
2403                indoc! {"
2404                    ## User (checkpoint)
2405
2406                    Lorem
2407
2408                    ## Assistant
2409
2410                    LOREM
2411
2412                    ## User (checkpoint)
2413
2414                    ipsum
2415
2416                    ## Assistant
2417
2418                    IPSUM
2419
2420                    ## User
2421
2422                    dolor
2423
2424                    ## Assistant
2425
2426                    DOLOR
2427
2428                "}
2429            );
2430        });
2431        assert_eq!(
2432            fs.files(),
2433            vec![
2434                Path::new(path!("/test/file-0")),
2435                Path::new(path!("/test/file-1"))
2436            ]
2437        );
2438
2439        // Rewinding the conversation truncates the history and restores the checkpoint.
2440        thread
2441            .update(cx, |thread, cx| {
2442                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
2443                    panic!("unexpected entries {:?}", thread.entries)
2444                };
2445                thread.rewind(message.prompt_id.clone().unwrap(), cx)
2446            })
2447            .await
2448            .unwrap();
2449        thread.read_with(cx, |thread, cx| {
2450            assert_eq!(
2451                thread.to_markdown(cx),
2452                indoc! {"
2453                    ## User (checkpoint)
2454
2455                    Lorem
2456
2457                    ## Assistant
2458
2459                    LOREM
2460
2461                "}
2462            );
2463        });
2464        assert_eq!(fs.files(), vec![Path::new(path!("/test/file-0"))]);
2465    }
2466
2467    #[gpui::test]
2468    async fn test_refusal(cx: &mut TestAppContext) {
2469        init_test(cx);
2470        let fs = FakeFs::new(cx.background_executor.clone());
2471        fs.insert_tree(path!("/"), json!({})).await;
2472        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
2473
2474        let refuse_next = Arc::new(AtomicBool::new(false));
2475        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
2476            let refuse_next = refuse_next.clone();
2477            move |request, thread, mut cx| {
2478                let refuse_next = refuse_next.clone();
2479                async move {
2480                    if refuse_next.load(SeqCst) {
2481                        return Ok(acp::PromptResponse {
2482                            stop_reason: acp::StopReason::Refusal,
2483                        });
2484                    }
2485
2486                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
2487                        panic!("expected text content block");
2488                    };
2489                    thread.update(&mut cx, |thread, cx| {
2490                        thread
2491                            .handle_session_update(
2492                                acp::SessionUpdate::AgentMessageChunk {
2493                                    content: content.text.to_uppercase().into(),
2494                                },
2495                                cx,
2496                            )
2497                            .unwrap();
2498                    })?;
2499                    Ok(acp::PromptResponse {
2500                        stop_reason: acp::StopReason::EndTurn,
2501                    })
2502                }
2503                .boxed_local()
2504            }
2505        }));
2506        let thread = cx
2507            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
2508            .await
2509            .unwrap();
2510
2511        cx.update(|cx| {
2512            thread.update(cx, |thread, cx| {
2513                thread.send(new_prompt_id(), vec!["hello".into()], cx)
2514            })
2515        })
2516        .await
2517        .unwrap();
2518        thread.read_with(cx, |thread, cx| {
2519            assert_eq!(
2520                thread.to_markdown(cx),
2521                indoc! {"
2522                    ## User
2523
2524                    hello
2525
2526                    ## Assistant
2527
2528                    HELLO
2529
2530                "}
2531            );
2532        });
2533
2534        // Simulate refusing the second message, ensuring the conversation gets
2535        // truncated to before sending it.
2536        refuse_next.store(true, SeqCst);
2537        cx.update(|cx| {
2538            thread.update(cx, |thread, cx| {
2539                thread.send(new_prompt_id(), vec!["world".into()], cx)
2540            })
2541        })
2542        .await
2543        .unwrap();
2544        thread.read_with(cx, |thread, cx| {
2545            assert_eq!(
2546                thread.to_markdown(cx),
2547                indoc! {"
2548                    ## User
2549
2550                    hello
2551
2552                    ## Assistant
2553
2554                    HELLO
2555
2556                "}
2557            );
2558        });
2559    }
2560
2561    async fn run_until_first_tool_call(
2562        thread: &Entity<AcpThread>,
2563        cx: &mut TestAppContext,
2564    ) -> usize {
2565        let (mut tx, mut rx) = mpsc::channel::<usize>(1);
2566
2567        let subscription = cx.update(|cx| {
2568            cx.subscribe(thread, move |thread, _, cx| {
2569                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
2570                    if matches!(entry, AgentThreadEntry::ToolCall(_)) {
2571                        return tx.try_send(ix).unwrap();
2572                    }
2573                }
2574            })
2575        });
2576
2577        select! {
2578            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
2579                panic!("Timeout waiting for tool call")
2580            }
2581            ix = rx.next().fuse() => {
2582                drop(subscription);
2583                ix.unwrap()
2584            }
2585        }
2586    }
2587
2588    #[derive(Clone, Default)]
2589    struct FakeAgentConnection {
2590        auth_methods: Vec<acp::AuthMethod>,
2591        sessions: Arc<parking_lot::Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
2592        on_user_message: Option<
2593            Rc<
2594                dyn Fn(
2595                        acp::PromptRequest,
2596                        WeakEntity<AcpThread>,
2597                        AsyncApp,
2598                    ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2599                    + 'static,
2600            >,
2601        >,
2602    }
2603
2604    impl FakeAgentConnection {
2605        fn new() -> Self {
2606            Self {
2607                auth_methods: Vec::new(),
2608                on_user_message: None,
2609                sessions: Arc::default(),
2610            }
2611        }
2612
2613        #[expect(unused)]
2614        fn with_auth_methods(mut self, auth_methods: Vec<acp::AuthMethod>) -> Self {
2615            self.auth_methods = auth_methods;
2616            self
2617        }
2618
2619        fn on_user_message(
2620            mut self,
2621            handler: impl Fn(
2622                acp::PromptRequest,
2623                WeakEntity<AcpThread>,
2624                AsyncApp,
2625            ) -> LocalBoxFuture<'static, Result<acp::PromptResponse>>
2626            + 'static,
2627        ) -> Self {
2628            self.on_user_message.replace(Rc::new(handler));
2629            self
2630        }
2631    }
2632
2633    impl AgentConnection for FakeAgentConnection {
2634        fn auth_methods(&self) -> &[acp::AuthMethod] {
2635            &self.auth_methods
2636        }
2637
2638        fn new_thread(
2639            self: Rc<Self>,
2640            project: Entity<Project>,
2641            _cwd: &Path,
2642            cx: &mut App,
2643        ) -> Task<gpui::Result<Entity<AcpThread>>> {
2644            let session_id = acp::SessionId(
2645                rand::thread_rng()
2646                    .sample_iter(&rand::distributions::Alphanumeric)
2647                    .take(7)
2648                    .map(char::from)
2649                    .collect::<String>()
2650                    .into(),
2651            );
2652            let action_log = cx.new(|_| ActionLog::new(project.clone()));
2653            let thread = cx.new(|cx| {
2654                AcpThread::new(
2655                    "Test",
2656                    self.clone(),
2657                    project,
2658                    action_log,
2659                    session_id.clone(),
2660                    watch::Receiver::constant(acp::PromptCapabilities {
2661                        image: true,
2662                        audio: true,
2663                        embedded_context: true,
2664                    }),
2665                    cx,
2666                )
2667            });
2668            self.sessions.lock().insert(session_id, thread.downgrade());
2669            Task::ready(Ok(thread))
2670        }
2671
2672        fn authenticate(&self, method: acp::AuthMethodId, _cx: &mut App) -> Task<gpui::Result<()>> {
2673            if self.auth_methods().iter().any(|m| m.id == method) {
2674                Task::ready(Ok(()))
2675            } else {
2676                Task::ready(Err(anyhow!("Invalid Auth Method")))
2677            }
2678        }
2679
2680        fn prompt(
2681            &self,
2682            params: acp::PromptRequest,
2683            cx: &mut App,
2684        ) -> Task<gpui::Result<acp::PromptResponse>> {
2685            let sessions = self.sessions.lock();
2686            let thread = sessions.get(&params.session_id).unwrap();
2687            if let Some(handler) = &self.on_user_message {
2688                let handler = handler.clone();
2689                let thread = thread.clone();
2690                cx.spawn(async move |cx| handler(params, thread, cx.clone()).await)
2691            } else {
2692                Task::ready(Ok(acp::PromptResponse {
2693                    stop_reason: acp::StopReason::EndTurn,
2694                }))
2695            }
2696        }
2697
2698        fn cancel(&self, session_id: &acp::SessionId, cx: &mut App) {
2699            let sessions = self.sessions.lock();
2700            let thread = sessions.get(session_id).unwrap().clone();
2701
2702            cx.spawn(async move |cx| {
2703                thread
2704                    .update(cx, |thread, cx| thread.cancel(cx))
2705                    .unwrap()
2706                    .await
2707            })
2708            .detach();
2709        }
2710
2711        fn rewind(
2712            &self,
2713            session_id: &acp::SessionId,
2714            _cx: &App,
2715        ) -> Option<Rc<dyn AgentSessionRewind>> {
2716            Some(Rc::new(FakeAgentSessionEditor {
2717                _session_id: session_id.clone(),
2718            }))
2719        }
2720
2721        fn into_any(self: Rc<Self>) -> Rc<dyn Any> {
2722            self
2723        }
2724    }
2725
2726    struct FakeAgentSessionEditor {
2727        _session_id: acp::SessionId,
2728    }
2729
2730    impl AgentSessionRewind for FakeAgentSessionEditor {
2731        fn rewind(&self, _message_id: acp::PromptId, _cx: &mut App) -> Task<Result<()>> {
2732            Task::ready(Ok(()))
2733        }
2734    }
2735}