acp.rs

   1mod server;
   2mod thread_view;
   3
   4use agentic_coding_protocol::{self as acp};
   5use anyhow::{Context as _, Result};
   6use buffer_diff::BufferDiff;
   7use chrono::{DateTime, Utc};
   8use editor::{MultiBuffer, PathKey};
   9use futures::channel::oneshot;
  10use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
  11use language::{Anchor, Buffer, Capability, LanguageRegistry, OffsetRangeExt as _};
  12use markdown::Markdown;
  13use project::Project;
  14use std::{mem, ops::Range, path::PathBuf, sync::Arc};
  15use ui::{App, IconName};
  16use util::{ResultExt, debug_panic};
  17
  18pub use server::AcpServer;
  19pub use thread_view::AcpThreadView;
  20
/// Identifier of an agent thread, stored as a `SharedString`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ThreadId(SharedString);
  23
/// Numeric version tag attached to file and symbol content.
// NOTE(review): nothing in this part of the file constructs or compares versions;
// confirm what the number counts before relying on ordering.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct FileVersion(u64);
  26
/// Summary information about a thread: its id, title and creation time (UTC).
#[derive(Debug)]
pub struct AgentThreadSummary {
    pub id: ThreadId,
    pub title: String,
    pub created_at: DateTime<Utc>,
}
  33
/// The content of a file together with its path and version tag.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct FileContent {
    pub path: PathBuf,
    pub version: FileVersion,
    pub content: SharedString,
}
  40
/// A message authored by the user, made of an ordered list of chunks.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct UserMessage {
    pub chunks: Vec<UserMessageChunk>,
}
  45
  46impl UserMessage {
  47    fn into_acp(self, cx: &App) -> acp::UserMessage {
  48        acp::UserMessage {
  49            chunks: self
  50                .chunks
  51                .into_iter()
  52                .map(|chunk| chunk.into_acp(cx))
  53                .collect(),
  54        }
  55    }
  56}
  57
/// One piece of a [`UserMessage`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum UserMessageChunk {
    /// Text, held as a `Markdown` entity.
    Text {
        chunk: Entity<Markdown>,
    },
    /// A single file and its content.
    File {
        content: FileContent,
    },
    /// A directory path plus a list of file contents.
    // NOTE(review): presumably the files found under `path`; confirm with the code that builds this variant.
    Directory {
        path: PathBuf,
        contents: Vec<FileContent>,
    },
    /// A named symbol with the file it lives in, its range, version and content.
    // NOTE(review): the unit of `range` (bytes vs. characters) is not established in this file.
    Symbol {
        path: PathBuf,
        range: Range<u64>,
        version: FileVersion,
        name: SharedString,
        content: SharedString,
    },
    /// Content associated with a URL.
    Fetch {
        url: SharedString,
        content: SharedString,
    },
}
  82
  83impl UserMessageChunk {
  84    pub fn into_acp(self, cx: &App) -> acp::UserMessageChunk {
  85        match self {
  86            Self::Text { chunk } => acp::UserMessageChunk::Text {
  87                chunk: chunk.read(cx).source().to_string(),
  88            },
  89            Self::File { .. } => todo!(),
  90            Self::Directory { .. } => todo!(),
  91            Self::Symbol { .. } => todo!(),
  92            Self::Fetch { .. } => todo!(),
  93        }
  94    }
  95
  96    pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
  97        Self::Text {
  98            chunk: cx.new(|cx| {
  99                Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
 100            }),
 101        }
 102    }
 103}
 104
/// A message produced by the assistant, made of an ordered list of chunks.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct AssistantMessage {
    pub chunks: Vec<AssistantMessageChunk>,
}
 109
/// One piece of an [`AssistantMessage`]: either regular text or a "thought",
/// each held as a `Markdown` entity.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AssistantMessageChunk {
    Text { chunk: Entity<Markdown> },
    Thought { chunk: Entity<Markdown> },
}
 115
 116impl AssistantMessageChunk {
 117    pub fn from_acp(
 118        chunk: acp::AssistantMessageChunk,
 119        language_registry: Arc<LanguageRegistry>,
 120        cx: &mut App,
 121    ) -> Self {
 122        match chunk {
 123            acp::AssistantMessageChunk::Text { chunk } => Self::Text {
 124                chunk: cx.new(|cx| Markdown::new(chunk.into(), Some(language_registry), None, cx)),
 125            },
 126            acp::AssistantMessageChunk::Thought { chunk } => Self::Thought {
 127                chunk: cx.new(|cx| Markdown::new(chunk.into(), Some(language_registry), None, cx)),
 128            },
 129        }
 130    }
 131
 132    pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
 133        Self::Text {
 134            chunk: cx.new(|cx| {
 135                Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
 136            }),
 137        }
 138    }
 139}
 140
/// The payload of a thread entry: a user message, an assistant message,
/// or a tool call.
#[derive(Debug)]
pub enum AgentThreadEntryContent {
    UserMessage(UserMessage),
    AssistantMessage(AssistantMessage),
    ToolCall(ToolCall),
}
 147
/// A tool invocation recorded in the thread.
#[derive(Debug)]
pub struct ToolCall {
    /// Id of this tool call.
    id: ToolCallId,
    /// Label text, held as a `Markdown` entity.
    label: Entity<Markdown>,
    /// Icon, mapped from the protocol's icon by `acp_icon_to_ui_icon`.
    icon: IconName,
    /// Optional content (Markdown or a diff); replaced wholesale by `update_tool_call`.
    content: Option<ToolCallContent>,
    /// Where the call is in its authorization / execution lifecycle.
    status: ToolCallStatus,
}
 156
/// Authorization / execution state of a [`ToolCall`].
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The call is waiting for a confirmation outcome, which is delivered
    /// through `respond_tx`.
    WaitingForConfirmation {
        confirmation: ToolCallConfirmation,
        respond_tx: oneshot::Sender<acp::ToolCallConfirmationOutcome>,
    },
    /// The call may run; `status` is the status reported through the protocol.
    Allowed {
        status: acp::ToolCallStatus,
    },
    /// Set by `authorize_tool_call` when the outcome is `Reject`.
    Rejected,
    /// Set by `AcpThread::cancel` for calls that were waiting or running.
    Canceled,
}
 169
/// Local form of `acp::ToolCallConfirmation`: what a tool call is asking
/// permission for. Descriptions are held as `Markdown` entities.
#[derive(Debug)]
pub enum ToolCallConfirmation {
    /// Confirmation for an edit.
    Edit {
        description: Option<Entity<Markdown>>,
    },
    /// Confirmation for executing a command.
    Execute {
        command: String,
        root_command: String,
        description: Option<Entity<Markdown>>,
    },
    /// Confirmation for a tool provided by an MCP server.
    Mcp {
        server_name: String,
        tool_name: String,
        tool_display_name: String,
        description: Option<Entity<Markdown>>,
    },
    /// Confirmation for fetching one or more URLs.
    Fetch {
        urls: Vec<String>,
        description: Option<Entity<Markdown>>,
    },
    /// Any other confirmation; the description is mandatory here.
    Other {
        description: Entity<Markdown>,
    },
}
 194
 195impl ToolCallConfirmation {
 196    pub fn from_acp(
 197        confirmation: acp::ToolCallConfirmation,
 198        language_registry: Arc<LanguageRegistry>,
 199        cx: &mut App,
 200    ) -> Self {
 201        let to_md = |description: String, cx: &mut App| -> Entity<Markdown> {
 202            cx.new(|cx| {
 203                Markdown::new(
 204                    description.into(),
 205                    Some(language_registry.clone()),
 206                    None,
 207                    cx,
 208                )
 209            })
 210        };
 211
 212        match confirmation {
 213            acp::ToolCallConfirmation::Edit { description } => Self::Edit {
 214                description: description.map(|description| to_md(description, cx)),
 215            },
 216            acp::ToolCallConfirmation::Execute {
 217                command,
 218                root_command,
 219                description,
 220            } => Self::Execute {
 221                command,
 222                root_command,
 223                description: description.map(|description| to_md(description, cx)),
 224            },
 225            acp::ToolCallConfirmation::Mcp {
 226                server_name,
 227                tool_name,
 228                tool_display_name,
 229                description,
 230            } => Self::Mcp {
 231                server_name,
 232                tool_name,
 233                tool_display_name,
 234                description: description.map(|description| to_md(description, cx)),
 235            },
 236            acp::ToolCallConfirmation::Fetch { urls, description } => Self::Fetch {
 237                urls,
 238                description: description.map(|description| to_md(description, cx)),
 239            },
 240            acp::ToolCallConfirmation::Other { description } => Self::Other {
 241                description: to_md(description, cx),
 242            },
 243        }
 244    }
 245}
 246
/// Content attached to a [`ToolCall`]: either a `Markdown` entity or a [`Diff`].
#[derive(Debug)]
pub enum ToolCallContent {
    Markdown { markdown: Entity<Markdown> },
    Diff { diff: Diff },
}
 252
 253impl ToolCallContent {
 254    pub fn from_acp(
 255        content: acp::ToolCallContent,
 256        language_registry: Arc<LanguageRegistry>,
 257        cx: &mut App,
 258    ) -> Self {
 259        match content {
 260            acp::ToolCallContent::Markdown { markdown } => Self::Markdown {
 261                markdown: cx.new(|cx| Markdown::new_text(markdown.into(), cx)),
 262            },
 263            acp::ToolCallContent::Diff { diff } => Self::Diff {
 264                diff: Diff::from_acp(diff, language_registry, cx),
 265            },
 266        }
 267    }
 268}
 269
/// A diff attached to a tool call, backed by a multibuffer.
#[derive(Debug)]
pub struct Diff {
    /// Read-only multibuffer that `from_acp` fills with the diff's excerpts.
    multibuffer: Entity<MultiBuffer>,
    /// Path taken from the protocol diff.
    path: PathBuf,
    /// Background task spawned by `from_acp`; stored so it lives as long as the `Diff`.
    _task: Task<Result<()>>,
}
 276
impl Diff {
    /// Builds a [`Diff`] from a protocol (`acp`) diff.
    ///
    /// Creates a buffer for the new text and one for the old text (empty when
    /// `old_text` is `None`), starts computing a `BufferDiff` between them, and
    /// spawns a task that fills the returned multibuffer once the diff is ready.
    /// The multibuffer is therefore initially empty.
    pub fn from_acp(
        diff: acp::Diff,
        language_registry: Arc<LanguageRegistry>,
        cx: &mut App,
    ) -> Self {
        let acp::Diff {
            path,
            old_text,
            new_text,
        } = diff;

        let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));

        let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
        // A missing old text is treated as an empty base.
        let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
        let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
        let old_buffer_snapshot = old_buffer.read(cx).snapshot();
        let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
        // Use the old buffer as the diff base; the result is awaited in the task below.
        let diff_task = buffer_diff.update(cx, |diff, cx| {
            diff.set_base_text(
                old_buffer_snapshot,
                Some(language_registry.clone()),
                new_buffer_snapshot,
                cx,
            )
        });

        let task = cx.spawn({
            let multibuffer = multibuffer.clone();
            let path = path.clone();
            async move |cx| {
                // A failure here ends the task before anything is shown.
                diff_task.await?;

                multibuffer
                    .update(cx, |multibuffer, cx| {
                        // Collect every hunk in the new buffer as a point range.
                        let hunk_ranges = {
                            let buffer = new_buffer.read(cx);
                            let diff = buffer_diff.read(cx);
                            diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
                                .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
                                .collect::<Vec<_>>()
                        };

                        multibuffer.set_excerpts_for_path(
                            PathKey::for_buffer(&new_buffer, cx),
                            new_buffer.clone(),
                            hunk_ranges,
                            editor::DEFAULT_MULTIBUFFER_CONTEXT,
                            cx,
                        );
                        multibuffer.add_diff(buffer_diff.clone(), cx);
                    })
                    // An update failure is logged, not propagated.
                    .log_err();

                // Language lookup is best-effort: a lookup error is logged and skipped.
                if let Some(language) = language_registry
                    .language_for_file_path(&path)
                    .await
                    .log_err()
                {
                    new_buffer.update(cx, |buffer, cx| buffer.set_language(Some(language), cx))?;
                }

                anyhow::Ok(())
            }
        });

        Self {
            multibuffer,
            path,
            _task: task,
        }
    }
}
 351
 352/// A `ThreadEntryId` that is known to be a ToolCall
 353#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
 354pub struct ToolCallId(ThreadEntryId);
 355
 356impl ToolCallId {
 357    pub fn as_u64(&self) -> u64 {
 358        self.0.0
 359    }
 360}
 361
 362#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
 363pub struct ThreadEntryId(pub u64);
 364
 365impl ThreadEntryId {
 366    pub fn post_inc(&mut self) -> Self {
 367        let id = *self;
 368        self.0 += 1;
 369        id
 370    }
 371}
 372
/// One entry of a thread: its id plus the message or tool call it holds.
#[derive(Debug)]
pub struct ThreadEntry {
    pub id: ThreadEntryId,
    pub content: AgentThreadEntryContent,
}
 378
/// A conversation thread backed by an [`AcpServer`].
pub struct AcpThread {
    /// Id passed to the server when sending or cancelling.
    id: ThreadId,
    /// Id that the next pushed entry will receive.
    next_entry_id: ThreadEntryId,
    /// Entries in insertion order; `entry_mut` uses an entry id as an index into this list.
    entries: Vec<ThreadEntry>,
    server: Arc<AcpServer>,
    title: SharedString,
    /// Project whose language registry is used when creating Markdown entities.
    project: Entity<Project>,
    /// The in-flight `send` task, if any; `status` reports `Idle` when this is `None`.
    send_task: Option<Task<()>>,
}
 388
/// Events emitted by [`AcpThread`].
enum AcpThreadEvent {
    /// A new entry was appended (emitted by `push_entry`).
    NewEntry,
    /// The entry at the given index in `entries` changed.
    EntryUpdated(usize),
}
 393
/// Coarse state of a thread, as computed by `AcpThread::status`.
#[derive(PartialEq, Eq)]
pub enum ThreadStatus {
    /// No send is in flight.
    Idle,
    /// A send is in flight and a tool call in the last turn awaits confirmation.
    WaitingForToolConfirmation,
    /// A send is in flight and nothing is awaiting confirmation.
    Generating,
}
 400
// Allows `AcpThread` to emit `AcpThreadEvent`s via `cx.emit`.
impl EventEmitter<AcpThreadEvent> for AcpThread {}
 402
 403impl AcpThread {
 404    pub fn new(
 405        server: Arc<AcpServer>,
 406        thread_id: ThreadId,
 407        entries: Vec<AgentThreadEntryContent>,
 408        project: Entity<Project>,
 409        _: &mut Context<Self>,
 410    ) -> Self {
 411        let mut next_entry_id = ThreadEntryId(0);
 412        Self {
 413            title: "ACP Thread".into(),
 414            entries: entries
 415                .into_iter()
 416                .map(|entry| ThreadEntry {
 417                    id: next_entry_id.post_inc(),
 418                    content: entry,
 419                })
 420                .collect(),
 421            server,
 422            id: thread_id,
 423            next_entry_id,
 424            project,
 425            send_task: None,
 426        }
 427    }
 428
 429    pub fn title(&self) -> SharedString {
 430        self.title.clone()
 431    }
 432
 433    pub fn entries(&self) -> &[ThreadEntry] {
 434        &self.entries
 435    }
 436
 437    pub fn status(&self) -> ThreadStatus {
 438        if self.send_task.is_some() {
 439            if self.waiting_for_tool_confirmation() {
 440                ThreadStatus::WaitingForToolConfirmation
 441            } else {
 442                ThreadStatus::Generating
 443            }
 444        } else {
 445            ThreadStatus::Idle
 446        }
 447    }
 448
 449    pub fn push_entry(
 450        &mut self,
 451        entry: AgentThreadEntryContent,
 452        cx: &mut Context<Self>,
 453    ) -> ThreadEntryId {
 454        let id = self.next_entry_id.post_inc();
 455        self.entries.push(ThreadEntry { id, content: entry });
 456        cx.emit(AcpThreadEvent::NewEntry);
 457        id
 458    }
 459
 460    pub fn push_assistant_chunk(
 461        &mut self,
 462        chunk: acp::AssistantMessageChunk,
 463        cx: &mut Context<Self>,
 464    ) {
 465        let entries_len = self.entries.len();
 466        if let Some(last_entry) = self.entries.last_mut()
 467            && let AgentThreadEntryContent::AssistantMessage(AssistantMessage { ref mut chunks }) =
 468                last_entry.content
 469        {
 470            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
 471
 472            match (chunks.last_mut(), &chunk) {
 473                (
 474                    Some(AssistantMessageChunk::Text { chunk: old_chunk }),
 475                    acp::AssistantMessageChunk::Text { chunk: new_chunk },
 476                )
 477                | (
 478                    Some(AssistantMessageChunk::Thought { chunk: old_chunk }),
 479                    acp::AssistantMessageChunk::Thought { chunk: new_chunk },
 480                ) => {
 481                    old_chunk.update(cx, |old_chunk, cx| {
 482                        old_chunk.append(&new_chunk, cx);
 483                    });
 484                }
 485                _ => {
 486                    chunks.push(AssistantMessageChunk::from_acp(
 487                        chunk,
 488                        self.project.read(cx).languages().clone(),
 489                        cx,
 490                    ));
 491                }
 492            }
 493        } else {
 494            let chunk = AssistantMessageChunk::from_acp(
 495                chunk,
 496                self.project.read(cx).languages().clone(),
 497                cx,
 498            );
 499
 500            self.push_entry(
 501                AgentThreadEntryContent::AssistantMessage(AssistantMessage {
 502                    chunks: vec![chunk],
 503                }),
 504                cx,
 505            );
 506        }
 507    }
 508
 509    pub fn request_tool_call(
 510        &mut self,
 511        label: String,
 512        icon: acp::Icon,
 513        content: Option<acp::ToolCallContent>,
 514        confirmation: acp::ToolCallConfirmation,
 515        cx: &mut Context<Self>,
 516    ) -> ToolCallRequest {
 517        let (tx, rx) = oneshot::channel();
 518
 519        let status = ToolCallStatus::WaitingForConfirmation {
 520            confirmation: ToolCallConfirmation::from_acp(
 521                confirmation,
 522                self.project.read(cx).languages().clone(),
 523                cx,
 524            ),
 525            respond_tx: tx,
 526        };
 527
 528        let id = self.insert_tool_call(label, status, icon, content, cx);
 529        ToolCallRequest { id, outcome: rx }
 530    }
 531
 532    pub fn push_tool_call(
 533        &mut self,
 534        label: String,
 535        icon: acp::Icon,
 536        content: Option<acp::ToolCallContent>,
 537        cx: &mut Context<Self>,
 538    ) -> ToolCallId {
 539        let status = ToolCallStatus::Allowed {
 540            status: acp::ToolCallStatus::Running,
 541        };
 542
 543        self.insert_tool_call(label, status, icon, content, cx)
 544    }
 545
 546    fn insert_tool_call(
 547        &mut self,
 548        label: String,
 549        status: ToolCallStatus,
 550        icon: acp::Icon,
 551        content: Option<acp::ToolCallContent>,
 552        cx: &mut Context<Self>,
 553    ) -> ToolCallId {
 554        let language_registry = self.project.read(cx).languages().clone();
 555
 556        let entry_id = self.push_entry(
 557            AgentThreadEntryContent::ToolCall(ToolCall {
 558                // todo! clean up id creation
 559                id: ToolCallId(ThreadEntryId(self.entries.len() as u64)),
 560                label: cx.new(|cx| {
 561                    Markdown::new(label.into(), Some(language_registry.clone()), None, cx)
 562                }),
 563                icon: acp_icon_to_ui_icon(icon),
 564                content: content
 565                    .map(|content| ToolCallContent::from_acp(content, language_registry, cx)),
 566                status,
 567            }),
 568            cx,
 569        );
 570
 571        ToolCallId(entry_id)
 572    }
 573
 574    pub fn authorize_tool_call(
 575        &mut self,
 576        id: ToolCallId,
 577        outcome: acp::ToolCallConfirmationOutcome,
 578        cx: &mut Context<Self>,
 579    ) {
 580        let Some(entry) = self.entry_mut(id.0) else {
 581            return;
 582        };
 583
 584        let AgentThreadEntryContent::ToolCall(call) = &mut entry.content else {
 585            debug_panic!("expected ToolCall");
 586            return;
 587        };
 588
 589        let new_status = if outcome == acp::ToolCallConfirmationOutcome::Reject {
 590            ToolCallStatus::Rejected
 591        } else {
 592            ToolCallStatus::Allowed {
 593                status: acp::ToolCallStatus::Running,
 594            }
 595        };
 596
 597        let curr_status = mem::replace(&mut call.status, new_status);
 598
 599        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
 600            respond_tx.send(outcome).log_err();
 601        } else {
 602            debug_panic!("tried to authorize an already authorized tool call");
 603        }
 604
 605        cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
 606    }
 607
 608    pub fn update_tool_call(
 609        &mut self,
 610        id: ToolCallId,
 611        new_status: acp::ToolCallStatus,
 612        new_content: Option<acp::ToolCallContent>,
 613        cx: &mut Context<Self>,
 614    ) -> Result<()> {
 615        let language_registry = self.project.read(cx).languages().clone();
 616        let entry = self.entry_mut(id.0).context("Entry not found")?;
 617
 618        match &mut entry.content {
 619            AgentThreadEntryContent::ToolCall(call) => {
 620                call.content = new_content.map(|new_content| {
 621                    ToolCallContent::from_acp(new_content, language_registry, cx)
 622                });
 623
 624                match &mut call.status {
 625                    ToolCallStatus::Allowed { status } => {
 626                        *status = new_status;
 627                    }
 628                    ToolCallStatus::WaitingForConfirmation { .. } => {
 629                        anyhow::bail!("Tool call hasn't been authorized yet")
 630                    }
 631                    ToolCallStatus::Rejected => {
 632                        anyhow::bail!("Tool call was rejected and therefore can't be updated")
 633                    }
 634                    ToolCallStatus::Canceled => {
 635                        // todo! test this case with fake server
 636                        call.status = ToolCallStatus::Allowed { status: new_status };
 637                    }
 638                }
 639            }
 640            _ => anyhow::bail!("Entry is not a tool call"),
 641        }
 642
 643        cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
 644        Ok(())
 645    }
 646
 647    fn entry_mut(&mut self, id: ThreadEntryId) -> Option<&mut ThreadEntry> {
 648        let entry = self.entries.get_mut(id.0 as usize);
 649        debug_assert!(
 650            entry.is_some(),
 651            "We shouldn't give out ids to entries that don't exist"
 652        );
 653        entry
 654    }
 655
 656    /// Returns true if the last turn is awaiting tool authorization
 657    pub fn waiting_for_tool_confirmation(&self) -> bool {
 658        // todo!("should we use a hashmap?")
 659        for entry in self.entries.iter().rev() {
 660            match &entry.content {
 661                AgentThreadEntryContent::ToolCall(call) => match call.status {
 662                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
 663                    ToolCallStatus::Allowed { .. }
 664                    | ToolCallStatus::Rejected
 665                    | ToolCallStatus::Canceled => continue,
 666                },
 667                AgentThreadEntryContent::UserMessage(_)
 668                | AgentThreadEntryContent::AssistantMessage(_) => {
 669                    // Reached the beginning of the turn
 670                    return false;
 671                }
 672            }
 673        }
 674        false
 675    }
 676
    /// Sends a user message to the server.
    ///
    /// The text is wrapped in a single text chunk and appended to the thread
    /// as a user entry right away. Any send already in flight is cancelled
    /// first. The returned future resolves to the server's result for this
    /// message, or to `Ok(())` if the result channel is dropped before a
    /// result is sent.
    pub fn send(
        &mut self,
        message: &str,
        cx: &mut Context<Self>,
    ) -> impl use<> + Future<Output = Result<()>> {
        let agent = self.server.clone();
        let id = self.id.clone();
        let chunk =
            UserMessageChunk::from_str(message, self.project.read(cx).languages().clone(), cx);
        let message = UserMessage {
            chunks: vec![chunk],
        };
        self.push_entry(AgentThreadEntryContent::UserMessage(message.clone()), cx);
        let acp_message = message.into_acp(cx);

        let (tx, rx) = oneshot::channel();
        // Takes the previous send task (if any) and asks the server to cancel it.
        let cancel = self.cancel(cx);

        self.send_task = Some(cx.spawn(async move |this, cx| {
            // Wait for the cancellation to finish; a failure is logged and the send proceeds.
            cancel.await.log_err();

            let result = agent.send_message(id, acp_message, cx).await;
            tx.send(result).log_err();
            // Clear the stored task so `status` goes back to `Idle`.
            this.update(cx, |this, _cx| this.send_task.take()).log_err();
        }));

        async move {
            match rx.await {
                Ok(result) => result,
                // The sender was dropped without sending a result.
                Err(_) => Ok(()),
            }
        }
    }
 710
 711    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 712        let agent = self.server.clone();
 713        let id = self.id.clone();
 714
 715        if self.send_task.take().is_some() {
 716            cx.spawn(async move |this, cx| {
 717                agent.cancel_send_message(id, cx).await?;
 718
 719                this.update(cx, |this, _cx| {
 720                    for entry in this.entries.iter_mut() {
 721                        if let AgentThreadEntryContent::ToolCall(call) = &mut entry.content {
 722                            let cancel = matches!(
 723                                call.status,
 724                                ToolCallStatus::WaitingForConfirmation { .. }
 725                                    | ToolCallStatus::Allowed {
 726                                        status: acp::ToolCallStatus::Running
 727                                    }
 728                            );
 729
 730                            if cancel {
 731                                let curr_status =
 732                                    mem::replace(&mut call.status, ToolCallStatus::Canceled);
 733
 734                                if let ToolCallStatus::WaitingForConfirmation {
 735                                    respond_tx, ..
 736                                } = curr_status
 737                                {
 738                                    respond_tx
 739                                        .send(acp::ToolCallConfirmationOutcome::Cancel)
 740                                        .ok();
 741                                }
 742                            }
 743                        }
 744                    }
 745                })
 746            })
 747        } else {
 748            Task::ready(Ok(()))
 749        }
 750    }
 751
 752    #[cfg(test)]
 753    pub fn to_string(&self, cx: &App) -> String {
 754        let mut result = String::new();
 755        for entry in &self.entries {
 756            match &entry.content {
 757                AgentThreadEntryContent::UserMessage(user_message) => {
 758                    result.push_str("# User\n");
 759                    for chunk in &user_message.chunks {
 760                        match chunk {
 761                            UserMessageChunk::Text { chunk } => {
 762                                result.push_str(chunk.read(cx).source());
 763                                result.push('\n');
 764                            }
 765                            _ => unimplemented!(),
 766                        }
 767                    }
 768                }
 769                AgentThreadEntryContent::AssistantMessage(assistant_message) => {
 770                    result.push_str("# Assistant\n");
 771                    for chunk in &assistant_message.chunks {
 772                        match chunk {
 773                            AssistantMessageChunk::Text { chunk } => {
 774                                result.push_str(chunk.read(cx).source());
 775                                result.push('\n')
 776                            }
 777                            AssistantMessageChunk::Thought { chunk } => {
 778                                result.push_str("<thinking>\n");
 779                                result.push_str(chunk.read(cx).source());
 780                                result.push_str("\n</thinking>\n");
 781                            }
 782                        }
 783                    }
 784                }
 785                AgentThreadEntryContent::ToolCall(_tool_call) => unimplemented!(),
 786            }
 787        }
 788        result
 789    }
 790}
 791
 792fn acp_icon_to_ui_icon(icon: acp::Icon) -> IconName {
 793    match icon {
 794        acp::Icon::FileSearch => IconName::FileSearch,
 795        acp::Icon::Folder => IconName::Folder,
 796        acp::Icon::Globe => IconName::Globe,
 797        acp::Icon::Hammer => IconName::Hammer,
 798        acp::Icon::LightBulb => IconName::LightBulb,
 799        acp::Icon::Pencil => IconName::Pencil,
 800        acp::Icon::Regex => IconName::Regex,
 801        acp::Icon::Terminal => IconName::Terminal,
 802    }
 803}
 804
/// Returned by `AcpThread::request_tool_call`: the id of the pending tool call
/// and the receiver on which its confirmation outcome will arrive.
pub struct ToolCallRequest {
    pub id: ToolCallId,
    pub outcome: oneshot::Receiver<acp::ToolCallConfirmationOutcome>,
}
 809
 810#[cfg(test)]
 811mod tests {
 812    use super::*;
 813    use async_pipe::{PipeReader, PipeWriter};
 814    use async_trait::async_trait;
 815    use futures::{FutureExt as _, channel::mpsc, future::LocalBoxFuture, select};
 816    use gpui::{AsyncApp, TestAppContext};
 817    use indoc::indoc;
 818    use project::FakeFs;
 819    use serde_json::json;
 820    use settings::SettingsStore;
 821    use smol::{future::BoxedLocal, stream::StreamExt as _};
 822    use std::{env, path::Path, process::Stdio, rc::Rc, time::Duration};
 823    use util::path;
 824
    /// Shared test setup: initialises logging (ignoring the error if it was
    /// already initialised) and installs a test settings store before
    /// initialising project settings and the language crate.
    fn init_test(cx: &mut TestAppContext) {
        env_logger::try_init().ok();
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            Project::init_settings(cx);
            language::init(cx);
        });
    }
 834
    /// Two consecutive `Thought` chunks streamed by the fake server must end
    /// up concatenated in a single `<thinking>` block of one assistant message.
    #[gpui::test]
    async fn test_thinking_concatenation(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let (server, fake_server) = fake_acp_server(project, cx);

        server.initialize().await.unwrap();

        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();

        // On each user message the fake server streams "Thinking " and then "hard!"
        // as two separate thought chunks before acknowledging the message.
        fake_server.update(cx, |fake_server, _| {
            fake_server.on_user_message(move |params, server, mut cx| async move {
                server
                    .update(&mut cx, |server, _| {
                        server.send_to_zed(acp::StreamAssistantMessageChunkParams {
                            thread_id: params.thread_id.clone(),
                            chunk: acp::AssistantMessageChunk::Thought {
                                chunk: "Thinking ".into(),
                            },
                        })
                    })?
                    .await
                    .unwrap();
                server
                    .update(&mut cx, |server, _| {
                        server.send_to_zed(acp::StreamAssistantMessageChunkParams {
                            thread_id: params.thread_id,
                            chunk: acp::AssistantMessageChunk::Thought {
                                chunk: "hard!".into(),
                            },
                        })
                    })?
                    .await
                    .unwrap();

                Ok(acp::SendUserMessageResponse)
            })
        });

        thread
            .update(cx, |thread, cx| thread.send("Hello from Zed!", cx))
            .await
            .unwrap();

        // The two thought chunks appear as one merged thought.
        let output = thread.read_with(cx, |thread, cx| thread.to_string(cx));
        assert_eq!(
            output,
            indoc! {r#"
            # User
            Hello from Zed!
            # Assistant
            <thinking>
            Thinking hard!
            </thinking>
            "#}
        );
    }
 896
    /// Sends one message through the server returned by `gemini_acp_server`
    /// and expects the thread to contain exactly a user message followed by
    /// an assistant message.
    // NOTE(review): `gemini_acp_server` is defined outside this part of the file;
    // presumably it talks to a real Gemini process — confirm before running in CI.
    #[gpui::test]
    async fn test_gemini_basic(cx: &mut TestAppContext) {
        init_test(cx);

        cx.executor().allow_parking();

        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, [], cx).await;
        let server = gemini_acp_server(project.clone(), cx).await;
        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
        thread
            .update(cx, |thread, cx| thread.send("Hello from Zed!", cx))
            .await
            .unwrap();

        thread.read_with(cx, |thread, _| {
            assert_eq!(thread.entries.len(), 2);
            assert!(matches!(
                thread.entries[0].content,
                AgentThreadEntryContent::UserMessage(_)
            ));
            assert!(matches!(
                thread.entries[1].content,
                AgentThreadEntryContent::AssistantMessage(_)
            ));
        });
    }
 924
 925    #[gpui::test]
 926    async fn test_gemini_tool_call(cx: &mut TestAppContext) {
 927        init_test(cx);
 928
 929        cx.executor().allow_parking();
 930
 931        let fs = FakeFs::new(cx.executor());
 932        fs.insert_tree(
 933            path!("/private/tmp"),
 934            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
 935        )
 936        .await;
 937        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
 938        let server = gemini_acp_server(project.clone(), cx).await;
 939        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
 940        thread
 941            .update(cx, |thread, cx| {
 942                thread.send(
 943                    "Read the '/private/tmp/foo' file and tell me what you see.",
 944                    cx,
 945                )
 946            })
 947            .await
 948            .unwrap();
 949        thread.read_with(cx, |thread, _cx| {
 950            assert!(matches!(
 951                &thread.entries()[2].content,
 952                AgentThreadEntryContent::ToolCall(ToolCall {
 953                    status: ToolCallStatus::Allowed { .. },
 954                    ..
 955                })
 956            ));
 957
 958            assert!(matches!(
 959                thread.entries[3].content,
 960                AgentThreadEntryContent::AssistantMessage(_)
 961            ));
 962        });
 963    }
 964
 965    #[gpui::test]
 966    async fn test_gemini_tool_call_with_confirmation(cx: &mut TestAppContext) {
 967        init_test(cx);
 968
 969        cx.executor().allow_parking();
 970
 971        let fs = FakeFs::new(cx.executor());
 972        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
 973        let server = gemini_acp_server(project.clone(), cx).await;
 974        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
 975        let full_turn = thread.update(cx, |thread, cx| {
 976            thread.send(r#"Run `echo "Hello, world!"`"#, cx)
 977        });
 978
 979        run_until_first_tool_call(&thread, cx).await;
 980
 981        let tool_call_id = thread.read_with(cx, |thread, _cx| {
 982            let AgentThreadEntryContent::ToolCall(ToolCall {
 983                id,
 984                status:
 985                    ToolCallStatus::WaitingForConfirmation {
 986                        confirmation: ToolCallConfirmation::Execute { root_command, .. },
 987                        ..
 988                    },
 989                ..
 990            }) = &thread.entries()[2].content
 991            else {
 992                panic!();
 993            };
 994
 995            assert_eq!(root_command, "echo");
 996
 997            *id
 998        });
 999
1000        thread.update(cx, |thread, cx| {
1001            thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);
1002
1003            assert!(matches!(
1004                &thread.entries()[2].content,
1005                AgentThreadEntryContent::ToolCall(ToolCall {
1006                    status: ToolCallStatus::Allowed { .. },
1007                    ..
1008                })
1009            ));
1010        });
1011
1012        full_turn.await.unwrap();
1013
1014        thread.read_with(cx, |thread, cx| {
1015            let AgentThreadEntryContent::ToolCall(ToolCall {
1016                content: Some(ToolCallContent::Markdown { markdown }),
1017                status: ToolCallStatus::Allowed { .. },
1018                ..
1019            }) = &thread.entries()[2].content
1020            else {
1021                panic!();
1022            };
1023
1024            markdown.read_with(cx, |md, _cx| {
1025                assert!(
1026                    md.source().contains("Hello, world!"),
1027                    r#"Expected '{}' to contain "Hello, world!""#,
1028                    md.source()
1029                );
1030            });
1031        });
1032    }
1033
1034    #[gpui::test]
1035    async fn test_gemini_cancel(cx: &mut TestAppContext) {
1036        init_test(cx);
1037
1038        cx.executor().allow_parking();
1039
1040        let fs = FakeFs::new(cx.executor());
1041        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
1042        let server = gemini_acp_server(project.clone(), cx).await;
1043        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
1044        let full_turn = thread.update(cx, |thread, cx| {
1045            thread.send(r#"Run `echo "Hello, world!"`"#, cx)
1046        });
1047
1048        let first_tool_call_ix = run_until_first_tool_call(&thread, cx).await;
1049
1050        thread.read_with(cx, |thread, _cx| {
1051            let AgentThreadEntryContent::ToolCall(ToolCall {
1052                id,
1053                status:
1054                    ToolCallStatus::WaitingForConfirmation {
1055                        confirmation: ToolCallConfirmation::Execute { root_command, .. },
1056                        ..
1057                    },
1058                ..
1059            }) = &thread.entries()[first_tool_call_ix].content
1060            else {
1061                panic!("{:?}", thread.entries()[1].content);
1062            };
1063
1064            assert_eq!(root_command, "echo");
1065
1066            *id
1067        });
1068
1069        thread
1070            .update(cx, |thread, cx| thread.cancel(cx))
1071            .await
1072            .unwrap();
1073        full_turn.await.unwrap();
1074        thread.read_with(cx, |thread, _| {
1075            let AgentThreadEntryContent::ToolCall(ToolCall {
1076                status: ToolCallStatus::Canceled,
1077                ..
1078            }) = &thread.entries()[first_tool_call_ix].content
1079            else {
1080                panic!();
1081            };
1082        });
1083
1084        thread
1085            .update(cx, |thread, cx| {
1086                thread.send(r#"Stop running and say goodbye to me."#, cx)
1087            })
1088            .await
1089            .unwrap();
1090        thread.read_with(cx, |thread, _| {
1091            assert!(matches!(
1092                &thread.entries().last().unwrap().content,
1093                AgentThreadEntryContent::AssistantMessage(..),
1094            ))
1095        });
1096    }
1097
1098    async fn run_until_first_tool_call(
1099        thread: &Entity<AcpThread>,
1100        cx: &mut TestAppContext,
1101    ) -> usize {
1102        let (mut tx, mut rx) = mpsc::channel::<usize>(1);
1103
1104        let subscription = cx.update(|cx| {
1105            cx.subscribe(thread, move |thread, _, cx| {
1106                for (ix, entry) in thread.read(cx).entries.iter().enumerate() {
1107                    if matches!(entry.content, AgentThreadEntryContent::ToolCall(_)) {
1108                        return tx.try_send(ix).unwrap();
1109                    }
1110                }
1111            })
1112        });
1113
1114        select! {
1115            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
1116                panic!("Timeout waiting for tool call")
1117            }
1118            ix = rx.next().fuse() => {
1119                drop(subscription);
1120                ix.unwrap()
1121            }
1122        }
1123    }
1124
1125    pub async fn gemini_acp_server(
1126        project: Entity<Project>,
1127        cx: &mut TestAppContext,
1128    ) -> Arc<AcpServer> {
1129        let cli_path =
1130            Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini-cli/packages/cli");
1131        let mut command = util::command::new_smol_command("node");
1132        command
1133            .arg(cli_path)
1134            .arg("--acp")
1135            .current_dir("/private/tmp")
1136            .stdin(Stdio::piped())
1137            .stdout(Stdio::piped())
1138            .stderr(Stdio::inherit())
1139            .kill_on_drop(true);
1140
1141        if let Ok(gemini_key) = std::env::var("GEMINI_API_KEY") {
1142            command.env("GEMINI_API_KEY", gemini_key);
1143        }
1144
1145        let child = command.spawn().unwrap();
1146        let server = cx.update(|cx| AcpServer::stdio(child, project, cx));
1147        server.initialize().await.unwrap();
1148        server
1149    }
1150
1151    pub fn fake_acp_server(
1152        project: Entity<Project>,
1153        cx: &mut TestAppContext,
1154    ) -> (Arc<AcpServer>, Entity<FakeAcpServer>) {
1155        let (stdin_tx, stdin_rx) = async_pipe::pipe();
1156        let (stdout_tx, stdout_rx) = async_pipe::pipe();
1157        let server = cx.update(|cx| AcpServer::fake(stdin_tx, stdout_rx, project, cx));
1158        let agent = cx.update(|cx| cx.new(|cx| FakeAcpServer::new(stdin_rx, stdout_tx, cx)));
1159        (server, agent)
1160    }
1161
1162    pub struct FakeAcpServer {
1163        connection: acp::ClientConnection,
1164        _handler_task: Task<()>,
1165        _io_task: Task<()>,
1166        on_user_message: Option<
1167            Rc<
1168                dyn Fn(
1169                    acp::SendUserMessageParams,
1170                    Entity<FakeAcpServer>,
1171                    AsyncApp,
1172                )
1173                    -> LocalBoxFuture<'static, Result<acp::SendUserMessageResponse>>,
1174            >,
1175        >,
1176    }
1177
1178    #[derive(Clone)]
1179    struct FakeAgent {
1180        server: Entity<FakeAcpServer>,
1181        cx: AsyncApp,
1182    }
1183
1184    #[async_trait(?Send)]
1185    impl acp::Agent for FakeAgent {
1186        async fn initialize(
1187            &self,
1188            _request: acp::InitializeParams,
1189        ) -> Result<acp::InitializeResponse> {
1190            Ok(acp::InitializeResponse {
1191                is_authenticated: true,
1192            })
1193        }
1194
1195        async fn authenticate(
1196            &self,
1197            _request: acp::AuthenticateParams,
1198        ) -> Result<acp::AuthenticateResponse> {
1199            Ok(acp::AuthenticateResponse)
1200        }
1201
1202        async fn create_thread(
1203            &self,
1204            _request: acp::CreateThreadParams,
1205        ) -> Result<acp::CreateThreadResponse> {
1206            Ok(acp::CreateThreadResponse {
1207                thread_id: acp::ThreadId("test-thread".into()),
1208            })
1209        }
1210
1211        async fn send_user_message(
1212            &self,
1213            request: acp::SendUserMessageParams,
1214        ) -> Result<acp::SendUserMessageResponse> {
1215            let mut cx = self.cx.clone();
1216            let handler = self
1217                .server
1218                .update(&mut cx, |server, _| server.on_user_message.clone())
1219                .ok()
1220                .flatten();
1221            if let Some(handler) = handler {
1222                handler(request, self.server.clone(), self.cx.clone()).await
1223            } else {
1224                anyhow::bail!("No handler for on_user_message")
1225            }
1226        }
1227    }
1228
1229    impl FakeAcpServer {
1230        fn new(stdin: PipeReader, stdout: PipeWriter, cx: &Context<Self>) -> Self {
1231            let agent = FakeAgent {
1232                server: cx.entity(),
1233                cx: cx.to_async(),
1234            };
1235
1236            let (connection, handler_fut, io_fut) =
1237                acp::ClientConnection::connect_to_client(agent.clone(), stdout, stdin);
1238            FakeAcpServer {
1239                connection: connection,
1240                on_user_message: None,
1241                _handler_task: cx.foreground_executor().spawn(handler_fut),
1242                _io_task: cx.background_spawn(async move {
1243                    io_fut.await.log_err();
1244                }),
1245            }
1246        }
1247
1248        fn on_user_message<F>(
1249            &mut self,
1250            handler: impl for<'a> Fn(acp::SendUserMessageParams, Entity<FakeAcpServer>, AsyncApp) -> F
1251            + 'static,
1252        ) where
1253            F: Future<Output = Result<acp::SendUserMessageResponse>> + 'static,
1254        {
1255            self.on_user_message
1256                .replace(Rc::new(move |request, server, cx| {
1257                    handler(request, server, cx).boxed_local()
1258                }));
1259        }
1260
1261        fn send_to_zed<T: acp::ClientRequest>(
1262            &self,
1263            message: T,
1264        ) -> BoxedLocal<Result<T::Response, acp::Error>> {
1265            self.connection.request(message).boxed_local()
1266        }
1267    }
1268}