acp.rs

   1mod server;
   2mod thread_view;
   3
   4use agentic_coding_protocol::{self as acp, Role};
   5use anyhow::{Context as _, Result};
   6use buffer_diff::BufferDiff;
   7use chrono::{DateTime, Utc};
   8use editor::{MultiBuffer, PathKey};
   9use futures::channel::oneshot;
  10use gpui::{AppContext, Context, Entity, EventEmitter, SharedString, Task};
  11use language::{Anchor, Buffer, Capability, LanguageRegistry, OffsetRangeExt as _};
  12use markdown::Markdown;
  13use project::Project;
  14use std::{mem, ops::Range, path::PathBuf, sync::Arc};
  15use ui::{App, IconName};
  16use util::{ResultExt, debug_panic};
  17
  18pub use server::AcpServer;
  19pub use thread_view::AcpThreadView;
  20
/// Identifier of an ACP thread; passed to the server when sending or canceling
/// messages for that thread.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ThreadId(SharedString);
  23
/// Version tag attached to file contents (see `FileContent` and `MessageChunk::Symbol`).
// NOTE(review): how versions are assigned or compared is not visible in this file.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct FileVersion(u64);
  26
/// Summary of an agent thread: its id, title, and creation time (UTC).
#[derive(Debug)]
pub struct AgentThreadSummary {
    pub id: ThreadId,
    pub title: String,
    pub created_at: DateTime<Utc>,
}
  33
/// A file's path and contents, tagged with the `FileVersion` the contents belong to.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct FileContent {
    pub path: PathBuf,
    pub version: FileVersion,
    pub content: SharedString,
}
  40
/// A chat message in a thread: who authored it and its ordered content chunks.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Message {
    pub role: acp::Role,
    pub chunks: Vec<MessageChunk>,
}
  46
  47impl Message {
  48    fn into_acp(self, cx: &App) -> acp::Message {
  49        acp::Message {
  50            role: self.role,
  51            chunks: self
  52                .chunks
  53                .into_iter()
  54                .map(|chunk| chunk.into_acp(cx))
  55                .collect(),
  56        }
  57    }
  58}
  59
/// One piece of a `Message`'s content.
///
/// Only `Text` can currently be converted to and from ACP (`into_acp` hits `todo!()` for
/// the other variants).
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MessageChunk {
    /// Text held as a `Markdown` entity.
    Text {
        chunk: Entity<Markdown>,
    },
    /// The contents of a single file.
    File {
        content: FileContent,
    },
    /// A directory path together with file contents.
    // NOTE(review): presumably the files under `path` — confirm with producers.
    Directory {
        path: PathBuf,
        contents: Vec<FileContent>,
    },
    /// A named symbol within the file at `path`, with its `content`.
    // NOTE(review): the unit of `range` (bytes vs. lines/points) is not visible here.
    Symbol {
        path: PathBuf,
        range: Range<u64>,
        version: FileVersion,
        name: SharedString,
        content: SharedString,
    },
    /// Content associated with `url`.
    Fetch {
        url: SharedString,
        content: SharedString,
    },
}
  84
  85impl MessageChunk {
  86    pub fn from_acp(
  87        chunk: acp::MessageChunk,
  88        language_registry: Arc<LanguageRegistry>,
  89        cx: &mut App,
  90    ) -> Self {
  91        match chunk {
  92            acp::MessageChunk::Text { chunk } => MessageChunk::Text {
  93                chunk: cx.new(|cx| Markdown::new(chunk.into(), Some(language_registry), None, cx)),
  94            },
  95        }
  96    }
  97
  98    pub fn into_acp(self, cx: &App) -> acp::MessageChunk {
  99        match self {
 100            MessageChunk::Text { chunk } => acp::MessageChunk::Text {
 101                chunk: chunk.read(cx).source().to_string(),
 102            },
 103            MessageChunk::File { .. } => todo!(),
 104            MessageChunk::Directory { .. } => todo!(),
 105            MessageChunk::Symbol { .. } => todo!(),
 106            MessageChunk::Fetch { .. } => todo!(),
 107        }
 108    }
 109
 110    pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
 111        MessageChunk::Text {
 112            chunk: cx.new(|cx| {
 113                Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
 114            }),
 115        }
 116    }
 117}
 118
/// What a thread entry holds: either a chat message or a tool call.
#[derive(Debug)]
pub enum AgentThreadEntryContent {
    Message(Message),
    ToolCall(ToolCall),
}
 124
/// A tool invocation made by the agent, as tracked by the thread.
#[derive(Debug)]
pub struct ToolCall {
    /// Id of the thread entry holding this tool call.
    id: ToolCallId,
    /// Human-readable label, parsed as markdown.
    label: Entity<Markdown>,
    /// Icon mapped from the ACP icon the agent supplied.
    icon: IconName,
    /// Content reported for the call (markdown or a diff), if any.
    content: Option<ToolCallContent>,
    /// Authorization / progress state.
    status: ToolCallStatus,
}
 133
/// Lifecycle of a `ToolCall` from the thread's point of view.
#[derive(Debug)]
pub enum ToolCallStatus {
    /// The agent asked for permission. `respond_tx` delivers the decision back to whoever
    /// holds the matching `ToolCallRequest::outcome` receiver.
    WaitingForConfirmation {
        confirmation: ToolCallConfirmation,
        respond_tx: oneshot::Sender<acp::ToolCallConfirmationOutcome>,
    },
    /// Authorized (or pushed without needing authorization); `status` is the progress
    /// most recently reported for it.
    Allowed {
        status: acp::ToolCallStatus,
    },
    /// The user rejected the call.
    Rejected,
    /// The thread's send was canceled while this call was waiting or running.
    Canceled,
}
 146
/// What the user is asked to confirm for a tool call, with descriptions held as markdown.
#[derive(Debug)]
pub enum ToolCallConfirmation {
    /// Confirmation for an edit.
    Edit {
        description: Option<Entity<Markdown>>,
    },
    /// Confirmation for running `command`; `root_command` is reported separately
    /// (e.g. `echo` for an `echo ...` command, as the tests expect).
    Execute {
        command: String,
        root_command: String,
        description: Option<Entity<Markdown>>,
    },
    /// Confirmation for invoking a tool on an MCP server.
    Mcp {
        server_name: String,
        tool_name: String,
        tool_display_name: String,
        description: Option<Entity<Markdown>>,
    },
    /// Confirmation for fetching `urls`.
    Fetch {
        urls: Vec<String>,
        description: Option<Entity<Markdown>>,
    },
    /// Any other confirmation; only a description is available.
    Other {
        description: Entity<Markdown>,
    },
}
 171
 172impl ToolCallConfirmation {
 173    pub fn from_acp(
 174        confirmation: acp::ToolCallConfirmation,
 175        language_registry: Arc<LanguageRegistry>,
 176        cx: &mut App,
 177    ) -> Self {
 178        let to_md = |description: String, cx: &mut App| -> Entity<Markdown> {
 179            cx.new(|cx| {
 180                Markdown::new(
 181                    description.into(),
 182                    Some(language_registry.clone()),
 183                    None,
 184                    cx,
 185                )
 186            })
 187        };
 188
 189        match confirmation {
 190            acp::ToolCallConfirmation::Edit { description } => Self::Edit {
 191                description: description.map(|description| to_md(description, cx)),
 192            },
 193            acp::ToolCallConfirmation::Execute {
 194                command,
 195                root_command,
 196                description,
 197            } => Self::Execute {
 198                command,
 199                root_command,
 200                description: description.map(|description| to_md(description, cx)),
 201            },
 202            acp::ToolCallConfirmation::Mcp {
 203                server_name,
 204                tool_name,
 205                tool_display_name,
 206                description,
 207            } => Self::Mcp {
 208                server_name,
 209                tool_name,
 210                tool_display_name,
 211                description: description.map(|description| to_md(description, cx)),
 212            },
 213            acp::ToolCallConfirmation::Fetch { urls, description } => Self::Fetch {
 214                urls,
 215                description: description.map(|description| to_md(description, cx)),
 216            },
 217            acp::ToolCallConfirmation::Other { description } => Self::Other {
 218                description: to_md(description, cx),
 219            },
 220        }
 221    }
 222}
 223
/// Content reported for a tool call: markdown text or a file diff.
#[derive(Debug)]
pub enum ToolCallContent {
    Markdown { markdown: Entity<Markdown> },
    Diff { diff: Diff },
}
 229
 230impl ToolCallContent {
 231    pub fn from_acp(
 232        content: acp::ToolCallContent,
 233        language_registry: Arc<LanguageRegistry>,
 234        cx: &mut App,
 235    ) -> Self {
 236        match content {
 237            acp::ToolCallContent::Markdown { markdown } => Self::Markdown {
 238                markdown: cx.new(|cx| Markdown::new_text(markdown.into(), cx)),
 239            },
 240            acp::ToolCallContent::Diff { diff } => Self::Diff {
 241                diff: Diff::from_acp(diff, language_registry, cx),
 242            },
 243        }
 244    }
 245}
 246
/// A read-only diff between two versions of a file, shown through a multibuffer.
#[derive(Debug)]
pub struct Diff {
    /// Read-only multibuffer that receives one excerpt per changed hunk.
    multibuffer: Entity<MultiBuffer>,
    /// Path of the diffed file, as reported by the agent.
    path: PathBuf,
    /// Background task that fills `multibuffer` and sets the buffer language; held so it
    /// stays alive as long as the diff does.
    _task: Task<Result<()>>,
}
 253
impl Diff {
    /// Builds a read-only diff view from an ACP diff.
    ///
    /// `new_text` becomes a local buffer and `old_text` (or an empty string when absent)
    /// becomes the diff base. Hunk computation, excerpt population, and language detection
    /// for `path` happen asynchronously in the task stored on the returned value.
    pub fn from_acp(
        diff: acp::Diff,
        language_registry: Arc<LanguageRegistry>,
        cx: &mut App,
    ) -> Self {
        let acp::Diff {
            path,
            old_text,
            new_text,
        } = diff;

        // Read-only multibuffer without excerpt headers.
        let multibuffer = cx.new(|_cx| MultiBuffer::without_headers(Capability::ReadOnly));

        let new_buffer = cx.new(|cx| Buffer::local(new_text, cx));
        // A missing `old_text` is diffed against an empty base
        // (NOTE(review): presumably meaning a newly created file).
        let old_buffer = cx.new(|cx| Buffer::local(old_text.unwrap_or("".into()), cx));
        let new_buffer_snapshot = new_buffer.read(cx).text_snapshot();
        let old_buffer_snapshot = old_buffer.read(cx).snapshot();
        let buffer_diff = cx.new(|cx| BufferDiff::new(&new_buffer_snapshot, cx));
        // Start diffing the new text against the old text; awaited inside the task below.
        let diff_task = buffer_diff.update(cx, |diff, cx| {
            diff.set_base_text(
                old_buffer_snapshot,
                Some(language_registry.clone()),
                new_buffer_snapshot,
                cx,
            )
        });

        let task = cx.spawn({
            let multibuffer = multibuffer.clone();
            let path = path.clone();
            async move |cx| {
                diff_task.await?;

                // Show only the changed regions: one excerpt per hunk (with the default
                // amount of surrounding context), plus the diff itself for rendering.
                multibuffer
                    .update(cx, |multibuffer, cx| {
                        let hunk_ranges = {
                            let buffer = new_buffer.read(cx);
                            let diff = buffer_diff.read(cx);
                            diff.hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &buffer, cx)
                                .map(|diff_hunk| diff_hunk.buffer_range.to_point(&buffer))
                                .collect::<Vec<_>>()
                        };

                        multibuffer.set_excerpts_for_path(
                            PathKey::for_buffer(&new_buffer, cx),
                            new_buffer.clone(),
                            hunk_ranges,
                            editor::DEFAULT_MULTIBUFFER_CONTEXT,
                            cx,
                        );
                        multibuffer.add_diff(buffer_diff.clone(), cx);
                    })
                    .log_err();

                // Best-effort: pick a language from the file path; lookup failures are only logged.
                if let Some(language) = language_registry
                    .language_for_file_path(&path)
                    .await
                    .log_err()
                {
                    new_buffer.update(cx, |buffer, cx| buffer.set_language(Some(language), cx))?;
                }

                anyhow::Ok(())
            }
        });

        Self {
            multibuffer,
            path,
            _task: task,
        }
    }
}
 328
 329/// A `ThreadEntryId` that is known to be a ToolCall
 330#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
 331pub struct ToolCallId(ThreadEntryId);
 332
 333impl ToolCallId {
 334    pub fn as_u64(&self) -> u64 {
 335        self.0.0
 336    }
 337}
 338
 339#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
 340pub struct ThreadEntryId(pub u64);
 341
 342impl ThreadEntryId {
 343    pub fn post_inc(&mut self) -> Self {
 344        let id = *self;
 345        self.0 += 1;
 346        id
 347    }
 348}
 349
/// An entry of an `AcpThread` together with the id it was assigned when pushed.
#[derive(Debug)]
pub struct ThreadEntry {
    pub id: ThreadEntryId,
    pub content: AgentThreadEntryContent,
}
 355
/// A conversation with an ACP agent: the ordered list of messages and tool calls, plus
/// the in-flight send (if any).
pub struct AcpThread {
    /// Thread id passed to the server on every send/cancel.
    id: ThreadId,
    /// Id that the next pushed entry will receive; entry ids equal their index in `entries`.
    next_entry_id: ThreadEntryId,
    entries: Vec<ThreadEntry>,
    server: Arc<AcpServer>,
    title: SharedString,
    /// Used to obtain the language registry for markdown and diffs.
    project: Entity<Project>,
    /// The in-flight send; `Some` while a turn is running (see `status`).
    send_task: Option<Task<()>>,
}
 365
/// Events emitted by an `AcpThread` when its entries change.
enum AcpThreadEvent {
    /// An entry was appended.
    NewEntry,
    /// The entry at this index was modified.
    EntryUpdated(usize),
}
 370
 371#[derive(PartialEq, Eq)]
 372pub enum ThreadStatus {
 373    Idle,
 374    WaitingForToolConfirmation,
 375    Generating,
 376}
 377
// Lets subscribers observe the `AcpThreadEvent`s this thread emits.
impl EventEmitter<AcpThreadEvent> for AcpThread {}
 379
 380impl AcpThread {
 381    pub fn new(
 382        server: Arc<AcpServer>,
 383        thread_id: ThreadId,
 384        entries: Vec<AgentThreadEntryContent>,
 385        project: Entity<Project>,
 386        _: &mut Context<Self>,
 387    ) -> Self {
 388        let mut next_entry_id = ThreadEntryId(0);
 389        Self {
 390            title: "ACP Thread".into(),
 391            entries: entries
 392                .into_iter()
 393                .map(|entry| ThreadEntry {
 394                    id: next_entry_id.post_inc(),
 395                    content: entry,
 396                })
 397                .collect(),
 398            server,
 399            id: thread_id,
 400            next_entry_id,
 401            project,
 402            send_task: None,
 403        }
 404    }
 405
 406    pub fn title(&self) -> SharedString {
 407        self.title.clone()
 408    }
 409
 410    pub fn entries(&self) -> &[ThreadEntry] {
 411        &self.entries
 412    }
 413
 414    pub fn status(&self) -> ThreadStatus {
 415        if self.send_task.is_some() {
 416            if self.waiting_for_tool_confirmation() {
 417                ThreadStatus::WaitingForToolConfirmation
 418            } else {
 419                ThreadStatus::Generating
 420            }
 421        } else {
 422            ThreadStatus::Idle
 423        }
 424    }
 425
 426    pub fn push_entry(
 427        &mut self,
 428        entry: AgentThreadEntryContent,
 429        cx: &mut Context<Self>,
 430    ) -> ThreadEntryId {
 431        let id = self.next_entry_id.post_inc();
 432        self.entries.push(ThreadEntry { id, content: entry });
 433        cx.emit(AcpThreadEvent::NewEntry);
 434        id
 435    }
 436
 437    pub fn push_assistant_chunk(&mut self, chunk: acp::MessageChunk, cx: &mut Context<Self>) {
 438        let entries_len = self.entries.len();
 439        if let Some(last_entry) = self.entries.last_mut()
 440            && let AgentThreadEntryContent::Message(Message {
 441                ref mut chunks,
 442                role: Role::Assistant,
 443            }) = last_entry.content
 444        {
 445            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
 446
 447            if let (
 448                Some(MessageChunk::Text { chunk: old_chunk }),
 449                acp::MessageChunk::Text { chunk: new_chunk },
 450            ) = (chunks.last_mut(), &chunk)
 451            {
 452                old_chunk.update(cx, |old_chunk, cx| {
 453                    old_chunk.append(&new_chunk, cx);
 454                });
 455            } else {
 456                chunks.push(MessageChunk::from_acp(
 457                    chunk,
 458                    self.project.read(cx).languages().clone(),
 459                    cx,
 460                ));
 461            }
 462
 463            return;
 464        }
 465
 466        let chunk = MessageChunk::from_acp(chunk, self.project.read(cx).languages().clone(), cx);
 467
 468        self.push_entry(
 469            AgentThreadEntryContent::Message(Message {
 470                role: Role::Assistant,
 471                chunks: vec![chunk],
 472            }),
 473            cx,
 474        );
 475    }
 476
 477    pub fn request_tool_call(
 478        &mut self,
 479        label: String,
 480        icon: acp::Icon,
 481        content: Option<acp::ToolCallContent>,
 482        confirmation: acp::ToolCallConfirmation,
 483        cx: &mut Context<Self>,
 484    ) -> ToolCallRequest {
 485        let (tx, rx) = oneshot::channel();
 486
 487        let status = ToolCallStatus::WaitingForConfirmation {
 488            confirmation: ToolCallConfirmation::from_acp(
 489                confirmation,
 490                self.project.read(cx).languages().clone(),
 491                cx,
 492            ),
 493            respond_tx: tx,
 494        };
 495
 496        let id = self.insert_tool_call(label, status, icon, content, cx);
 497        ToolCallRequest { id, outcome: rx }
 498    }
 499
 500    pub fn push_tool_call(
 501        &mut self,
 502        label: String,
 503        icon: acp::Icon,
 504        content: Option<acp::ToolCallContent>,
 505        cx: &mut Context<Self>,
 506    ) -> ToolCallId {
 507        let status = ToolCallStatus::Allowed {
 508            status: acp::ToolCallStatus::Running,
 509        };
 510
 511        self.insert_tool_call(label, status, icon, content, cx)
 512    }
 513
 514    fn insert_tool_call(
 515        &mut self,
 516        label: String,
 517        status: ToolCallStatus,
 518        icon: acp::Icon,
 519        content: Option<acp::ToolCallContent>,
 520        cx: &mut Context<Self>,
 521    ) -> ToolCallId {
 522        let language_registry = self.project.read(cx).languages().clone();
 523
 524        let entry_id = self.push_entry(
 525            AgentThreadEntryContent::ToolCall(ToolCall {
 526                // todo! clean up id creation
 527                id: ToolCallId(ThreadEntryId(self.entries.len() as u64)),
 528                label: cx.new(|cx| {
 529                    Markdown::new(label.into(), Some(language_registry.clone()), None, cx)
 530                }),
 531                icon: acp_icon_to_ui_icon(icon),
 532                content: content
 533                    .map(|content| ToolCallContent::from_acp(content, language_registry, cx)),
 534                status,
 535            }),
 536            cx,
 537        );
 538
 539        ToolCallId(entry_id)
 540    }
 541
 542    pub fn authorize_tool_call(
 543        &mut self,
 544        id: ToolCallId,
 545        outcome: acp::ToolCallConfirmationOutcome,
 546        cx: &mut Context<Self>,
 547    ) {
 548        let Some(entry) = self.entry_mut(id.0) else {
 549            return;
 550        };
 551
 552        let AgentThreadEntryContent::ToolCall(call) = &mut entry.content else {
 553            debug_panic!("expected ToolCall");
 554            return;
 555        };
 556
 557        let new_status = if outcome == acp::ToolCallConfirmationOutcome::Reject {
 558            ToolCallStatus::Rejected
 559        } else {
 560            ToolCallStatus::Allowed {
 561                status: acp::ToolCallStatus::Running,
 562            }
 563        };
 564
 565        let curr_status = mem::replace(&mut call.status, new_status);
 566
 567        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
 568            respond_tx.send(outcome).log_err();
 569        } else {
 570            debug_panic!("tried to authorize an already authorized tool call");
 571        }
 572
 573        cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
 574    }
 575
 576    pub fn update_tool_call(
 577        &mut self,
 578        id: ToolCallId,
 579        new_status: acp::ToolCallStatus,
 580        new_content: Option<acp::ToolCallContent>,
 581        cx: &mut Context<Self>,
 582    ) -> Result<()> {
 583        let language_registry = self.project.read(cx).languages().clone();
 584        let entry = self.entry_mut(id.0).context("Entry not found")?;
 585
 586        match &mut entry.content {
 587            AgentThreadEntryContent::ToolCall(call) => {
 588                call.content = new_content.map(|new_content| {
 589                    ToolCallContent::from_acp(new_content, language_registry, cx)
 590                });
 591
 592                match &mut call.status {
 593                    ToolCallStatus::Allowed { status } => {
 594                        *status = new_status;
 595                    }
 596                    ToolCallStatus::WaitingForConfirmation { .. } => {
 597                        anyhow::bail!("Tool call hasn't been authorized yet")
 598                    }
 599                    ToolCallStatus::Rejected => {
 600                        anyhow::bail!("Tool call was rejected and therefore can't be updated")
 601                    }
 602                    ToolCallStatus::Canceled => {
 603                        // todo! test this case with fake server
 604                        call.status = ToolCallStatus::Allowed { status: new_status };
 605                    }
 606                }
 607            }
 608            _ => anyhow::bail!("Entry is not a tool call"),
 609        }
 610
 611        cx.emit(AcpThreadEvent::EntryUpdated(id.as_u64() as usize));
 612        Ok(())
 613    }
 614
 615    fn entry_mut(&mut self, id: ThreadEntryId) -> Option<&mut ThreadEntry> {
 616        let entry = self.entries.get_mut(id.0 as usize);
 617        debug_assert!(
 618            entry.is_some(),
 619            "We shouldn't give out ids to entries that don't exist"
 620        );
 621        entry
 622    }
 623
 624    /// Returns true if the last turn is awaiting tool authorization
 625    pub fn waiting_for_tool_confirmation(&self) -> bool {
 626        // todo!("should we use a hashmap?")
 627        for entry in self.entries.iter().rev() {
 628            match &entry.content {
 629                AgentThreadEntryContent::ToolCall(call) => match call.status {
 630                    ToolCallStatus::WaitingForConfirmation { .. } => return true,
 631                    ToolCallStatus::Allowed { .. }
 632                    | ToolCallStatus::Rejected
 633                    | ToolCallStatus::Canceled => continue,
 634                },
 635                AgentThreadEntryContent::Message(_) => {
 636                    // Reached the beginning of the turn
 637                    return false;
 638                }
 639            }
 640        }
 641        false
 642    }
 643
 644    pub fn send(
 645        &mut self,
 646        message: &str,
 647        cx: &mut Context<Self>,
 648    ) -> impl use<> + Future<Output = Result<()>> {
 649        let agent = self.server.clone();
 650        let id = self.id.clone();
 651
 652        let chunk = MessageChunk::from_str(message, self.project.read(cx).languages().clone(), cx);
 653        let message = Message {
 654            role: Role::User,
 655            chunks: vec![chunk],
 656        };
 657        self.push_entry(AgentThreadEntryContent::Message(message.clone()), cx);
 658        let acp_message = message.into_acp(cx);
 659
 660        let (tx, rx) = oneshot::channel();
 661        let cancel = self.cancel(cx);
 662
 663        self.send_task = Some(cx.spawn(async move |this, cx| {
 664            cancel.await.log_err();
 665
 666            let result = agent.send_message(id, acp_message, cx).await;
 667            tx.send(result).log_err();
 668            this.update(cx, |this, _cx| this.send_task.take()).log_err();
 669        }));
 670
 671        async move {
 672            match rx.await {
 673                Ok(result) => result,
 674                Err(_) => Ok(()),
 675            }
 676        }
 677    }
 678
 679    pub fn cancel(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 680        let agent = self.server.clone();
 681        let id = self.id.clone();
 682
 683        if self.send_task.take().is_some() {
 684            cx.spawn(async move |this, cx| {
 685                agent.cancel_send_message(id, cx).await?;
 686
 687                this.update(cx, |this, _cx| {
 688                    for entry in this.entries.iter_mut() {
 689                        if let AgentThreadEntryContent::ToolCall(call) = &mut entry.content {
 690                            let cancel = matches!(
 691                                call.status,
 692                                ToolCallStatus::WaitingForConfirmation { .. }
 693                                    | ToolCallStatus::Allowed {
 694                                        status: acp::ToolCallStatus::Running
 695                                    }
 696                            );
 697
 698                            if cancel {
 699                                let curr_status =
 700                                    mem::replace(&mut call.status, ToolCallStatus::Canceled);
 701
 702                                if let ToolCallStatus::WaitingForConfirmation {
 703                                    respond_tx, ..
 704                                } = curr_status
 705                                {
 706                                    respond_tx
 707                                        .send(acp::ToolCallConfirmationOutcome::Cancel)
 708                                        .ok();
 709                                }
 710                            }
 711                        }
 712                    }
 713                })
 714            })
 715        } else {
 716            Task::ready(Ok(()))
 717        }
 718    }
 719}
 720
 721fn acp_icon_to_ui_icon(icon: acp::Icon) -> IconName {
 722    match icon {
 723        acp::Icon::FileSearch => IconName::FileSearch,
 724        acp::Icon::Folder => IconName::Folder,
 725        acp::Icon::Globe => IconName::Globe,
 726        acp::Icon::Hammer => IconName::Hammer,
 727        acp::Icon::LightBulb => IconName::LightBulb,
 728        acp::Icon::Pencil => IconName::Pencil,
 729        acp::Icon::Regex => IconName::Regex,
 730        acp::Icon::Terminal => IconName::Terminal,
 731    }
 732}
 733
/// Result of `AcpThread::request_tool_call`: the new tool call's id and a receiver that
/// yields the confirmation outcome (from `authorize_tool_call`, or `Cancel` from `cancel`).
pub struct ToolCallRequest {
    pub id: ToolCallId,
    pub outcome: oneshot::Receiver<acp::ToolCallConfirmationOutcome>,
}
 738
 739#[cfg(test)]
 740mod tests {
 741    use super::*;
 742    use futures::{FutureExt as _, channel::mpsc, select};
 743    use gpui::TestAppContext;
 744    use project::FakeFs;
 745    use serde_json::json;
 746    use settings::SettingsStore;
 747    use smol::stream::StreamExt as _;
 748    use std::{env, path::Path, process::Stdio, time::Duration};
 749    use util::path;
 750
 751    fn init_test(cx: &mut TestAppContext) {
 752        env_logger::try_init().ok();
 753        cx.update(|cx| {
 754            let settings_store = SettingsStore::test(cx);
 755            cx.set_global(settings_store);
 756            Project::init_settings(cx);
 757            language::init(cx);
 758        });
 759    }
 760
 761    #[gpui::test]
 762    async fn test_gemini_basic(cx: &mut TestAppContext) {
 763        init_test(cx);
 764
 765        cx.executor().allow_parking();
 766
 767        let fs = FakeFs::new(cx.executor());
 768        let project = Project::test(fs, [], cx).await;
 769        let server = gemini_acp_server(project.clone(), cx).await;
 770        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
 771        thread
 772            .update(cx, |thread, cx| thread.send("Hello from Zed!", cx))
 773            .await
 774            .unwrap();
 775
 776        thread.read_with(cx, |thread, _| {
 777            assert_eq!(thread.entries.len(), 2);
 778            assert!(matches!(
 779                thread.entries[0].content,
 780                AgentThreadEntryContent::Message(Message {
 781                    role: Role::User,
 782                    ..
 783                })
 784            ));
 785            assert!(matches!(
 786                thread.entries[1].content,
 787                AgentThreadEntryContent::Message(Message {
 788                    role: Role::Assistant,
 789                    ..
 790                })
 791            ));
 792        });
 793    }
 794
 795    #[gpui::test]
 796    async fn test_gemini_tool_call(cx: &mut TestAppContext) {
 797        init_test(cx);
 798
 799        cx.executor().allow_parking();
 800
 801        let fs = FakeFs::new(cx.executor());
 802        fs.insert_tree(
 803            path!("/private/tmp"),
 804            json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
 805        )
 806        .await;
 807        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
 808        let server = gemini_acp_server(project.clone(), cx).await;
 809        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
 810        thread
 811            .update(cx, |thread, cx| {
 812                thread.send(
 813                    "Read the '/private/tmp/foo' file and tell me what you see.",
 814                    cx,
 815                )
 816            })
 817            .await
 818            .unwrap();
 819        thread.read_with(cx, |thread, _cx| {
 820            assert!(matches!(
 821                &thread.entries()[1].content,
 822                AgentThreadEntryContent::ToolCall(ToolCall {
 823                    status: ToolCallStatus::Allowed { .. },
 824                    ..
 825                })
 826            ));
 827
 828            assert!(matches!(
 829                thread.entries[2].content,
 830                AgentThreadEntryContent::Message(Message {
 831                    role: Role::Assistant,
 832                    ..
 833                })
 834            ));
 835        });
 836    }
 837
 838    #[gpui::test]
 839    async fn test_gemini_tool_call_with_confirmation(cx: &mut TestAppContext) {
 840        init_test(cx);
 841
 842        cx.executor().allow_parking();
 843
 844        let fs = FakeFs::new(cx.executor());
 845        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
 846        let server = gemini_acp_server(project.clone(), cx).await;
 847        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
 848        let full_turn = thread.update(cx, |thread, cx| {
 849            thread.send(r#"Run `echo "Hello, world!"`"#, cx)
 850        });
 851
 852        run_until_tool_call(&thread, cx).await;
 853
 854        let tool_call_id = thread.read_with(cx, |thread, _cx| {
 855            let AgentThreadEntryContent::ToolCall(ToolCall {
 856                id,
 857                status:
 858                    ToolCallStatus::WaitingForConfirmation {
 859                        confirmation: ToolCallConfirmation::Execute { root_command, .. },
 860                        ..
 861                    },
 862                ..
 863            }) = &thread.entries()[1].content
 864            else {
 865                panic!();
 866            };
 867
 868            assert_eq!(root_command, "echo");
 869
 870            *id
 871        });
 872
 873        thread.update(cx, |thread, cx| {
 874            thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);
 875
 876            assert!(matches!(
 877                &thread.entries()[1].content,
 878                AgentThreadEntryContent::ToolCall(ToolCall {
 879                    status: ToolCallStatus::Allowed { .. },
 880                    ..
 881                })
 882            ));
 883        });
 884
 885        full_turn.await.unwrap();
 886
 887        thread.read_with(cx, |thread, cx| {
 888            let AgentThreadEntryContent::ToolCall(ToolCall {
 889                content: Some(ToolCallContent::Markdown { markdown }),
 890                status: ToolCallStatus::Allowed { .. },
 891                ..
 892            }) = &thread.entries()[1].content
 893            else {
 894                panic!();
 895            };
 896
 897            markdown.read_with(cx, |md, _cx| {
 898                assert!(
 899                    md.source().contains("Hello, world!"),
 900                    r#"Expected '{}' to contain "Hello, world!""#,
 901                    md.source()
 902                );
 903            });
 904        });
 905    }
 906
 907    #[gpui::test]
 908    async fn test_gemini_cancel(cx: &mut TestAppContext) {
 909        init_test(cx);
 910
 911        cx.executor().allow_parking();
 912
 913        let fs = FakeFs::new(cx.executor());
 914        let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
 915        let server = gemini_acp_server(project.clone(), cx).await;
 916        let thread = server.create_thread(&mut cx.to_async()).await.unwrap();
 917        let full_turn = thread.update(cx, |thread, cx| {
 918            thread.send(r#"Run `echo "Hello, world!"`"#, cx)
 919        });
 920
 921        run_until_tool_call(&thread, cx).await;
 922
 923        thread.read_with(cx, |thread, _cx| {
 924            let AgentThreadEntryContent::ToolCall(ToolCall {
 925                id,
 926                status:
 927                    ToolCallStatus::WaitingForConfirmation {
 928                        confirmation: ToolCallConfirmation::Execute { root_command, .. },
 929                        ..
 930                    },
 931                ..
 932            }) = &thread.entries()[1].content
 933            else {
 934                panic!();
 935            };
 936
 937            assert_eq!(root_command, "echo");
 938
 939            *id
 940        });
 941
 942        thread
 943            .update(cx, |thread, cx| thread.cancel(cx))
 944            .await
 945            .unwrap();
 946        full_turn.await.unwrap();
 947        thread.read_with(cx, |thread, _| {
 948            let AgentThreadEntryContent::ToolCall(ToolCall {
 949                status: ToolCallStatus::Canceled,
 950                ..
 951            }) = &thread.entries()[1].content
 952            else {
 953                panic!();
 954            };
 955        });
 956
 957        thread
 958            .update(cx, |thread, cx| {
 959                thread.send(r#"Stop running and say goodbye to me."#, cx)
 960            })
 961            .await
 962            .unwrap();
 963        thread.read_with(cx, |thread, _| {
 964            let AgentThreadEntryContent::Message(Message {
 965                role: Role::Assistant,
 966                ..
 967            }) = &thread.entries()[3].content
 968            else {
 969                panic!();
 970            };
 971        });
 972    }
 973
 974    async fn run_until_tool_call(thread: &Entity<AcpThread>, cx: &mut TestAppContext) {
 975        let (mut tx, mut rx) = mpsc::channel::<()>(1);
 976
 977        let subscription = cx.update(|cx| {
 978            cx.subscribe(thread, move |thread, _, cx| {
 979                if thread
 980                    .read(cx)
 981                    .entries
 982                    .iter()
 983                    .any(|e| matches!(e.content, AgentThreadEntryContent::ToolCall(_)))
 984                {
 985                    tx.try_send(()).unwrap();
 986                }
 987            })
 988        });
 989
 990        select! {
 991            _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
 992                panic!("Timeout waiting for tool call")
 993            }
 994            _ = rx.next().fuse() => {
 995                drop(subscription);
 996            }
 997        }
 998    }
 999
1000    pub async fn gemini_acp_server(
1001        project: Entity<Project>,
1002        cx: &mut TestAppContext,
1003    ) -> Arc<AcpServer> {
1004        let cli_path =
1005            Path::new(env!("CARGO_MANIFEST_DIR")).join("../../../gemini-cli/packages/cli");
1006        let mut command = util::command::new_smol_command("node");
1007        command
1008            .arg(cli_path)
1009            .arg("--acp")
1010            .current_dir("/private/tmp")
1011            .stdin(Stdio::piped())
1012            .stdout(Stdio::piped())
1013            .stderr(Stdio::inherit())
1014            .kill_on_drop(true);
1015
1016        if let Ok(gemini_key) = std::env::var("GEMINI_API_KEY") {
1017            command.env("GEMINI_API_KEY", gemini_key);
1018        }
1019
1020        let child = command.spawn().unwrap();
1021        let server = cx.update(|cx| AcpServer::stdio(child, project, cx));
1022        server.initialize().await.unwrap();
1023        server
1024    }
1025}