mod.rs

   1use super::*;
   2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
   3use agent_client_protocol::{self as acp};
   4use agent_settings::AgentProfileId;
   5use anyhow::Result;
   6use client::{Client, UserStore};
   7use cloud_llm_client::CompletionIntent;
   8use collections::IndexMap;
   9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
  10use feature_flags::FeatureFlagAppExt as _;
  11use fs::{FakeFs, Fs};
  12use futures::{
  13    FutureExt as _, StreamExt,
  14    channel::{
  15        mpsc::{self, UnboundedReceiver},
  16        oneshot,
  17    },
  18    future::{Fuse, Shared},
  19};
  20use gpui::{
  21    App, AppContext, AsyncApp, Entity, Task, TestAppContext, UpdateGlobal,
  22    http_client::FakeHttpClient,
  23};
  24use indoc::indoc;
  25use language_model::{
  26    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
  27    LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
  28    LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
  29    LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
  30};
  31use pretty_assertions::assert_eq;
  32use project::{
  33    Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
  34};
  35use prompt_store::ProjectContext;
  36use reqwest_client::ReqwestClient;
  37use schemars::JsonSchema;
  38use serde::{Deserialize, Serialize};
  39use serde_json::json;
  40use settings::{Settings, SettingsStore};
  41use std::{
  42    path::Path,
  43    pin::Pin,
  44    rc::Rc,
  45    sync::{
  46        Arc,
  47        atomic::{AtomicBool, Ordering},
  48    },
  49    time::Duration,
  50};
  51use util::path;
  52
  53mod test_tools;
  54use test_tools::*;
  55
  56fn init_test(cx: &mut TestAppContext) {
  57    cx.update(|cx| {
  58        let settings_store = SettingsStore::test(cx);
  59        cx.set_global(settings_store);
  60    });
  61}
  62
  63struct FakeTerminalHandle {
  64    killed: Arc<AtomicBool>,
  65    stopped_by_user: Arc<AtomicBool>,
  66    exit_sender: std::cell::RefCell<Option<futures::channel::oneshot::Sender<()>>>,
  67    wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
  68    output: acp::TerminalOutputResponse,
  69    id: acp::TerminalId,
  70}
  71
  72impl FakeTerminalHandle {
  73    fn new_never_exits(cx: &mut App) -> Self {
  74        let killed = Arc::new(AtomicBool::new(false));
  75        let stopped_by_user = Arc::new(AtomicBool::new(false));
  76
  77        let (exit_sender, exit_receiver) = futures::channel::oneshot::channel();
  78
  79        let wait_for_exit = cx
  80            .spawn(async move |_cx| {
  81                // Wait for the exit signal (sent when kill() is called)
  82                let _ = exit_receiver.await;
  83                acp::TerminalExitStatus::new()
  84            })
  85            .shared();
  86
  87        Self {
  88            killed,
  89            stopped_by_user,
  90            exit_sender: std::cell::RefCell::new(Some(exit_sender)),
  91            wait_for_exit,
  92            output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
  93            id: acp::TerminalId::new("fake_terminal".to_string()),
  94        }
  95    }
  96
  97    fn new_with_immediate_exit(cx: &mut App, exit_code: u32) -> Self {
  98        let killed = Arc::new(AtomicBool::new(false));
  99        let stopped_by_user = Arc::new(AtomicBool::new(false));
 100        let (exit_sender, _exit_receiver) = futures::channel::oneshot::channel();
 101
 102        let wait_for_exit = cx
 103            .spawn(async move |_cx| acp::TerminalExitStatus::new().exit_code(exit_code))
 104            .shared();
 105
 106        Self {
 107            killed,
 108            stopped_by_user,
 109            exit_sender: std::cell::RefCell::new(Some(exit_sender)),
 110            wait_for_exit,
 111            output: acp::TerminalOutputResponse::new("command output".to_string(), false),
 112            id: acp::TerminalId::new("fake_terminal".to_string()),
 113        }
 114    }
 115
 116    fn was_killed(&self) -> bool {
 117        self.killed.load(Ordering::SeqCst)
 118    }
 119
 120    fn set_stopped_by_user(&self, stopped: bool) {
 121        self.stopped_by_user.store(stopped, Ordering::SeqCst);
 122    }
 123
 124    fn signal_exit(&self) {
 125        if let Some(sender) = self.exit_sender.borrow_mut().take() {
 126            let _ = sender.send(());
 127        }
 128    }
 129}
 130
 131impl crate::TerminalHandle for FakeTerminalHandle {
 132    fn id(&self, _cx: &AsyncApp) -> Result<acp::TerminalId> {
 133        Ok(self.id.clone())
 134    }
 135
 136    fn current_output(&self, _cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
 137        Ok(self.output.clone())
 138    }
 139
 140    fn wait_for_exit(&self, _cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
 141        Ok(self.wait_for_exit.clone())
 142    }
 143
 144    fn kill(&self, _cx: &AsyncApp) -> Result<()> {
 145        self.killed.store(true, Ordering::SeqCst);
 146        self.signal_exit();
 147        Ok(())
 148    }
 149
 150    fn was_stopped_by_user(&self, _cx: &AsyncApp) -> Result<bool> {
 151        Ok(self.stopped_by_user.load(Ordering::SeqCst))
 152    }
 153}
 154
 155struct FakeThreadEnvironment {
 156    handle: Rc<FakeTerminalHandle>,
 157}
 158
 159impl crate::ThreadEnvironment for FakeThreadEnvironment {
 160    fn create_terminal(
 161        &self,
 162        _command: String,
 163        _cwd: Option<std::path::PathBuf>,
 164        _output_byte_limit: Option<u64>,
 165        _cx: &mut AsyncApp,
 166    ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
 167        Task::ready(Ok(self.handle.clone() as Rc<dyn crate::TerminalHandle>))
 168    }
 169}
 170
 171/// Environment that creates multiple independent terminal handles for testing concurrent terminals.
 172struct MultiTerminalEnvironment {
 173    handles: std::cell::RefCell<Vec<Rc<FakeTerminalHandle>>>,
 174}
 175
 176impl MultiTerminalEnvironment {
 177    fn new() -> Self {
 178        Self {
 179            handles: std::cell::RefCell::new(Vec::new()),
 180        }
 181    }
 182
 183    fn handles(&self) -> Vec<Rc<FakeTerminalHandle>> {
 184        self.handles.borrow().clone()
 185    }
 186}
 187
 188impl crate::ThreadEnvironment for MultiTerminalEnvironment {
 189    fn create_terminal(
 190        &self,
 191        _command: String,
 192        _cwd: Option<std::path::PathBuf>,
 193        _output_byte_limit: Option<u64>,
 194        cx: &mut AsyncApp,
 195    ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
 196        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
 197        self.handles.borrow_mut().push(handle.clone());
 198        Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
 199    }
 200}
 201
 202fn always_allow_tools(cx: &mut TestAppContext) {
 203    cx.update(|cx| {
 204        let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
 205        settings.always_allow_tool_actions = true;
 206        agent_settings::AgentSettings::override_global(settings, cx);
 207    });
 208}
 209
 210#[gpui::test]
 211async fn test_echo(cx: &mut TestAppContext) {
 212    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 213    let fake_model = model.as_fake();
 214
 215    let events = thread
 216        .update(cx, |thread, cx| {
 217            thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
 218        })
 219        .unwrap();
 220    cx.run_until_parked();
 221    fake_model.send_last_completion_stream_text_chunk("Hello");
 222    fake_model
 223        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
 224    fake_model.end_last_completion_stream();
 225
 226    let events = events.collect().await;
 227    thread.update(cx, |thread, _cx| {
 228        assert_eq!(
 229            thread.last_message().unwrap().to_markdown(),
 230            indoc! {"
 231                ## Assistant
 232
 233                Hello
 234            "}
 235        )
 236    });
 237    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 238}
 239
 240#[gpui::test]
 241async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
 242    init_test(cx);
 243    always_allow_tools(cx);
 244
 245    let fs = FakeFs::new(cx.executor());
 246    let project = Project::test(fs, [], cx).await;
 247
 248    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
 249    let environment = Rc::new(FakeThreadEnvironment {
 250        handle: handle.clone(),
 251    });
 252
 253    #[allow(clippy::arc_with_non_send_sync)]
 254    let tool = Arc::new(crate::TerminalTool::new(project, environment));
 255    let (event_stream, mut rx) = crate::ToolCallEventStream::test();
 256
 257    let task = cx.update(|cx| {
 258        tool.run(
 259            crate::TerminalToolInput {
 260                command: "sleep 1000".to_string(),
 261                cd: ".".to_string(),
 262                timeout_ms: Some(5),
 263            },
 264            event_stream,
 265            cx,
 266        )
 267    });
 268
 269    let update = rx.expect_update_fields().await;
 270    assert!(
 271        update.content.iter().any(|blocks| {
 272            blocks
 273                .iter()
 274                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
 275        }),
 276        "expected tool call update to include terminal content"
 277    );
 278
 279    let mut task_future: Pin<Box<Fuse<Task<Result<String>>>>> = Box::pin(task.fuse());
 280
 281    let deadline = std::time::Instant::now() + Duration::from_millis(500);
 282    loop {
 283        if let Some(result) = task_future.as_mut().now_or_never() {
 284            let result = result.expect("terminal tool task should complete");
 285
 286            assert!(
 287                handle.was_killed(),
 288                "expected terminal handle to be killed on timeout"
 289            );
 290            assert!(
 291                result.contains("partial output"),
 292                "expected result to include terminal output, got: {result}"
 293            );
 294            return;
 295        }
 296
 297        if std::time::Instant::now() >= deadline {
 298            panic!("timed out waiting for terminal tool task to complete");
 299        }
 300
 301        cx.run_until_parked();
 302        cx.background_executor.timer(Duration::from_millis(1)).await;
 303    }
 304}
 305
 306#[gpui::test]
 307#[ignore]
 308async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAppContext) {
 309    init_test(cx);
 310    always_allow_tools(cx);
 311
 312    let fs = FakeFs::new(cx.executor());
 313    let project = Project::test(fs, [], cx).await;
 314
 315    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
 316    let environment = Rc::new(FakeThreadEnvironment {
 317        handle: handle.clone(),
 318    });
 319
 320    #[allow(clippy::arc_with_non_send_sync)]
 321    let tool = Arc::new(crate::TerminalTool::new(project, environment));
 322    let (event_stream, mut rx) = crate::ToolCallEventStream::test();
 323
 324    let _task = cx.update(|cx| {
 325        tool.run(
 326            crate::TerminalToolInput {
 327                command: "sleep 1000".to_string(),
 328                cd: ".".to_string(),
 329                timeout_ms: None,
 330            },
 331            event_stream,
 332            cx,
 333        )
 334    });
 335
 336    let update = rx.expect_update_fields().await;
 337    assert!(
 338        update.content.iter().any(|blocks| {
 339            blocks
 340                .iter()
 341                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
 342        }),
 343        "expected tool call update to include terminal content"
 344    );
 345
 346    cx.background_executor
 347        .timer(Duration::from_millis(25))
 348        .await;
 349
 350    assert!(
 351        !handle.was_killed(),
 352        "did not expect terminal handle to be killed without a timeout"
 353    );
 354}
 355
 356#[gpui::test]
 357async fn test_thinking(cx: &mut TestAppContext) {
 358    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 359    let fake_model = model.as_fake();
 360
 361    let events = thread
 362        .update(cx, |thread, cx| {
 363            thread.send(
 364                UserMessageId::new(),
 365                [indoc! {"
 366                    Testing:
 367
 368                    Generate a thinking step where you just think the word 'Think',
 369                    and have your final answer be 'Hello'
 370                "}],
 371                cx,
 372            )
 373        })
 374        .unwrap();
 375    cx.run_until_parked();
 376    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
 377        text: "Think".to_string(),
 378        signature: None,
 379    });
 380    fake_model.send_last_completion_stream_text_chunk("Hello");
 381    fake_model
 382        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
 383    fake_model.end_last_completion_stream();
 384
 385    let events = events.collect().await;
 386    thread.update(cx, |thread, _cx| {
 387        assert_eq!(
 388            thread.last_message().unwrap().to_markdown(),
 389            indoc! {"
 390                ## Assistant
 391
 392                <think>Think</think>
 393                Hello
 394            "}
 395        )
 396    });
 397    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 398}
 399
 400#[gpui::test]
 401async fn test_system_prompt(cx: &mut TestAppContext) {
 402    let ThreadTest {
 403        model,
 404        thread,
 405        project_context,
 406        ..
 407    } = setup(cx, TestModel::Fake).await;
 408    let fake_model = model.as_fake();
 409
 410    project_context.update(cx, |project_context, _cx| {
 411        project_context.shell = "test-shell".into()
 412    });
 413    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
 414    thread
 415        .update(cx, |thread, cx| {
 416            thread.send(UserMessageId::new(), ["abc"], cx)
 417        })
 418        .unwrap();
 419    cx.run_until_parked();
 420    let mut pending_completions = fake_model.pending_completions();
 421    assert_eq!(
 422        pending_completions.len(),
 423        1,
 424        "unexpected pending completions: {:?}",
 425        pending_completions
 426    );
 427
 428    let pending_completion = pending_completions.pop().unwrap();
 429    assert_eq!(pending_completion.messages[0].role, Role::System);
 430
 431    let system_message = &pending_completion.messages[0];
 432    let system_prompt = system_message.content[0].to_str().unwrap();
 433    assert!(
 434        system_prompt.contains("test-shell"),
 435        "unexpected system message: {:?}",
 436        system_message
 437    );
 438    assert!(
 439        system_prompt.contains("## Fixing Diagnostics"),
 440        "unexpected system message: {:?}",
 441        system_message
 442    );
 443}
 444
 445#[gpui::test]
 446async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
 447    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 448    let fake_model = model.as_fake();
 449
 450    thread
 451        .update(cx, |thread, cx| {
 452            thread.send(UserMessageId::new(), ["abc"], cx)
 453        })
 454        .unwrap();
 455    cx.run_until_parked();
 456    let mut pending_completions = fake_model.pending_completions();
 457    assert_eq!(
 458        pending_completions.len(),
 459        1,
 460        "unexpected pending completions: {:?}",
 461        pending_completions
 462    );
 463
 464    let pending_completion = pending_completions.pop().unwrap();
 465    assert_eq!(pending_completion.messages[0].role, Role::System);
 466
 467    let system_message = &pending_completion.messages[0];
 468    let system_prompt = system_message.content[0].to_str().unwrap();
 469    assert!(
 470        !system_prompt.contains("## Tool Use"),
 471        "unexpected system message: {:?}",
 472        system_message
 473    );
 474    assert!(
 475        !system_prompt.contains("## Fixing Diagnostics"),
 476        "unexpected system message: {:?}",
 477        system_message
 478    );
 479}
 480
 481#[gpui::test]
 482async fn test_prompt_caching(cx: &mut TestAppContext) {
 483    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 484    let fake_model = model.as_fake();
 485
 486    // Send initial user message and verify it's cached
 487    thread
 488        .update(cx, |thread, cx| {
 489            thread.send(UserMessageId::new(), ["Message 1"], cx)
 490        })
 491        .unwrap();
 492    cx.run_until_parked();
 493
 494    let completion = fake_model.pending_completions().pop().unwrap();
 495    assert_eq!(
 496        completion.messages[1..],
 497        vec![LanguageModelRequestMessage {
 498            role: Role::User,
 499            content: vec!["Message 1".into()],
 500            cache: true,
 501            reasoning_details: None,
 502        }]
 503    );
 504    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
 505        "Response to Message 1".into(),
 506    ));
 507    fake_model.end_last_completion_stream();
 508    cx.run_until_parked();
 509
 510    // Send another user message and verify only the latest is cached
 511    thread
 512        .update(cx, |thread, cx| {
 513            thread.send(UserMessageId::new(), ["Message 2"], cx)
 514        })
 515        .unwrap();
 516    cx.run_until_parked();
 517
 518    let completion = fake_model.pending_completions().pop().unwrap();
 519    assert_eq!(
 520        completion.messages[1..],
 521        vec![
 522            LanguageModelRequestMessage {
 523                role: Role::User,
 524                content: vec!["Message 1".into()],
 525                cache: false,
 526                reasoning_details: None,
 527            },
 528            LanguageModelRequestMessage {
 529                role: Role::Assistant,
 530                content: vec!["Response to Message 1".into()],
 531                cache: false,
 532                reasoning_details: None,
 533            },
 534            LanguageModelRequestMessage {
 535                role: Role::User,
 536                content: vec!["Message 2".into()],
 537                cache: true,
 538                reasoning_details: None,
 539            }
 540        ]
 541    );
 542    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
 543        "Response to Message 2".into(),
 544    ));
 545    fake_model.end_last_completion_stream();
 546    cx.run_until_parked();
 547
 548    // Simulate a tool call and verify that the latest tool result is cached
 549    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
 550    thread
 551        .update(cx, |thread, cx| {
 552            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
 553        })
 554        .unwrap();
 555    cx.run_until_parked();
 556
 557    let tool_use = LanguageModelToolUse {
 558        id: "tool_1".into(),
 559        name: EchoTool::name().into(),
 560        raw_input: json!({"text": "test"}).to_string(),
 561        input: json!({"text": "test"}),
 562        is_input_complete: true,
 563        thought_signature: None,
 564    };
 565    fake_model
 566        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
 567    fake_model.end_last_completion_stream();
 568    cx.run_until_parked();
 569
 570    let completion = fake_model.pending_completions().pop().unwrap();
 571    let tool_result = LanguageModelToolResult {
 572        tool_use_id: "tool_1".into(),
 573        tool_name: EchoTool::name().into(),
 574        is_error: false,
 575        content: "test".into(),
 576        output: Some("test".into()),
 577    };
 578    assert_eq!(
 579        completion.messages[1..],
 580        vec![
 581            LanguageModelRequestMessage {
 582                role: Role::User,
 583                content: vec!["Message 1".into()],
 584                cache: false,
 585                reasoning_details: None,
 586            },
 587            LanguageModelRequestMessage {
 588                role: Role::Assistant,
 589                content: vec!["Response to Message 1".into()],
 590                cache: false,
 591                reasoning_details: None,
 592            },
 593            LanguageModelRequestMessage {
 594                role: Role::User,
 595                content: vec!["Message 2".into()],
 596                cache: false,
 597                reasoning_details: None,
 598            },
 599            LanguageModelRequestMessage {
 600                role: Role::Assistant,
 601                content: vec!["Response to Message 2".into()],
 602                cache: false,
 603                reasoning_details: None,
 604            },
 605            LanguageModelRequestMessage {
 606                role: Role::User,
 607                content: vec!["Use the echo tool".into()],
 608                cache: false,
 609                reasoning_details: None,
 610            },
 611            LanguageModelRequestMessage {
 612                role: Role::Assistant,
 613                content: vec![MessageContent::ToolUse(tool_use)],
 614                cache: false,
 615                reasoning_details: None,
 616            },
 617            LanguageModelRequestMessage {
 618                role: Role::User,
 619                content: vec![MessageContent::ToolResult(tool_result)],
 620                cache: true,
 621                reasoning_details: None,
 622            }
 623        ]
 624    );
 625}
 626
 627#[gpui::test]
 628#[cfg_attr(not(feature = "e2e"), ignore)]
 629async fn test_basic_tool_calls(cx: &mut TestAppContext) {
 630    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 631
 632    // Test a tool call that's likely to complete *before* streaming stops.
 633    let events = thread
 634        .update(cx, |thread, cx| {
 635            thread.add_tool(EchoTool);
 636            thread.send(
 637                UserMessageId::new(),
 638                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
 639                cx,
 640            )
 641        })
 642        .unwrap()
 643        .collect()
 644        .await;
 645    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 646
 647    // Test a tool calls that's likely to complete *after* streaming stops.
 648    let events = thread
 649        .update(cx, |thread, cx| {
 650            thread.remove_tool(&EchoTool::name());
 651            thread.add_tool(DelayTool);
 652            thread.send(
 653                UserMessageId::new(),
 654                [
 655                    "Now call the delay tool with 200ms.",
 656                    "When the timer goes off, then you echo the output of the tool.",
 657                ],
 658                cx,
 659            )
 660        })
 661        .unwrap()
 662        .collect()
 663        .await;
 664    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 665    thread.update(cx, |thread, _cx| {
 666        assert!(
 667            thread
 668                .last_message()
 669                .unwrap()
 670                .as_agent_message()
 671                .unwrap()
 672                .content
 673                .iter()
 674                .any(|content| {
 675                    if let AgentMessageContent::Text(text) = content {
 676                        text.contains("Ding")
 677                    } else {
 678                        false
 679                    }
 680                }),
 681            "{}",
 682            thread.to_markdown()
 683        );
 684    });
 685}
 686
 687#[gpui::test]
 688#[cfg_attr(not(feature = "e2e"), ignore)]
 689async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
 690    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 691
 692    // Test a tool call that's likely to complete *before* streaming stops.
 693    let mut events = thread
 694        .update(cx, |thread, cx| {
 695            thread.add_tool(WordListTool);
 696            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
 697        })
 698        .unwrap();
 699
 700    let mut saw_partial_tool_use = false;
 701    while let Some(event) = events.next().await {
 702        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
 703            thread.update(cx, |thread, _cx| {
 704                // Look for a tool use in the thread's last message
 705                let message = thread.last_message().unwrap();
 706                let agent_message = message.as_agent_message().unwrap();
 707                let last_content = agent_message.content.last().unwrap();
 708                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
 709                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
 710                    if tool_call.status == acp::ToolCallStatus::Pending {
 711                        if !last_tool_use.is_input_complete
 712                            && last_tool_use.input.get("g").is_none()
 713                        {
 714                            saw_partial_tool_use = true;
 715                        }
 716                    } else {
 717                        last_tool_use
 718                            .input
 719                            .get("a")
 720                            .expect("'a' has streamed because input is now complete");
 721                        last_tool_use
 722                            .input
 723                            .get("g")
 724                            .expect("'g' has streamed because input is now complete");
 725                    }
 726                } else {
 727                    panic!("last content should be a tool use");
 728                }
 729            });
 730        }
 731    }
 732
 733    assert!(
 734        saw_partial_tool_use,
 735        "should see at least one partially streamed tool use in the history"
 736    );
 737}
 738
 739#[gpui::test]
 740async fn test_tool_authorization(cx: &mut TestAppContext) {
 741    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 742    let fake_model = model.as_fake();
 743
 744    let mut events = thread
 745        .update(cx, |thread, cx| {
 746            thread.add_tool(ToolRequiringPermission);
 747            thread.send(UserMessageId::new(), ["abc"], cx)
 748        })
 749        .unwrap();
 750    cx.run_until_parked();
 751    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 752        LanguageModelToolUse {
 753            id: "tool_id_1".into(),
 754            name: ToolRequiringPermission::name().into(),
 755            raw_input: "{}".into(),
 756            input: json!({}),
 757            is_input_complete: true,
 758            thought_signature: None,
 759        },
 760    ));
 761    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 762        LanguageModelToolUse {
 763            id: "tool_id_2".into(),
 764            name: ToolRequiringPermission::name().into(),
 765            raw_input: "{}".into(),
 766            input: json!({}),
 767            is_input_complete: true,
 768            thought_signature: None,
 769        },
 770    ));
 771    fake_model.end_last_completion_stream();
 772    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
 773    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
 774
 775    // Approve the first
 776    tool_call_auth_1
 777        .response
 778        .send(tool_call_auth_1.options[1].option_id.clone())
 779        .unwrap();
 780    cx.run_until_parked();
 781
 782    // Reject the second
 783    tool_call_auth_2
 784        .response
 785        .send(tool_call_auth_1.options[2].option_id.clone())
 786        .unwrap();
 787    cx.run_until_parked();
 788
 789    let completion = fake_model.pending_completions().pop().unwrap();
 790    let message = completion.messages.last().unwrap();
 791    assert_eq!(
 792        message.content,
 793        vec![
 794            language_model::MessageContent::ToolResult(LanguageModelToolResult {
 795                tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(),
 796                tool_name: ToolRequiringPermission::name().into(),
 797                is_error: false,
 798                content: "Allowed".into(),
 799                output: Some("Allowed".into())
 800            }),
 801            language_model::MessageContent::ToolResult(LanguageModelToolResult {
 802                tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(),
 803                tool_name: ToolRequiringPermission::name().into(),
 804                is_error: true,
 805                content: "Permission to run tool denied by user".into(),
 806                output: Some("Permission to run tool denied by user".into())
 807            })
 808        ]
 809    );
 810
 811    // Simulate yet another tool call.
 812    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 813        LanguageModelToolUse {
 814            id: "tool_id_3".into(),
 815            name: ToolRequiringPermission::name().into(),
 816            raw_input: "{}".into(),
 817            input: json!({}),
 818            is_input_complete: true,
 819            thought_signature: None,
 820        },
 821    ));
 822    fake_model.end_last_completion_stream();
 823
 824    // Respond by always allowing tools.
 825    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
 826    tool_call_auth_3
 827        .response
 828        .send(tool_call_auth_3.options[0].option_id.clone())
 829        .unwrap();
 830    cx.run_until_parked();
 831    let completion = fake_model.pending_completions().pop().unwrap();
 832    let message = completion.messages.last().unwrap();
 833    assert_eq!(
 834        message.content,
 835        vec![language_model::MessageContent::ToolResult(
 836            LanguageModelToolResult {
 837                tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(),
 838                tool_name: ToolRequiringPermission::name().into(),
 839                is_error: false,
 840                content: "Allowed".into(),
 841                output: Some("Allowed".into())
 842            }
 843        )]
 844    );
 845
 846    // Simulate a final tool call, ensuring we don't trigger authorization.
 847    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 848        LanguageModelToolUse {
 849            id: "tool_id_4".into(),
 850            name: ToolRequiringPermission::name().into(),
 851            raw_input: "{}".into(),
 852            input: json!({}),
 853            is_input_complete: true,
 854            thought_signature: None,
 855        },
 856    ));
 857    fake_model.end_last_completion_stream();
 858    cx.run_until_parked();
 859    let completion = fake_model.pending_completions().pop().unwrap();
 860    let message = completion.messages.last().unwrap();
 861    assert_eq!(
 862        message.content,
 863        vec![language_model::MessageContent::ToolResult(
 864            LanguageModelToolResult {
 865                tool_use_id: "tool_id_4".into(),
 866                tool_name: ToolRequiringPermission::name().into(),
 867                is_error: false,
 868                content: "Allowed".into(),
 869                output: Some("Allowed".into())
 870            }
 871        )]
 872    );
 873}
 874
 875#[gpui::test]
 876async fn test_tool_hallucination(cx: &mut TestAppContext) {
 877    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 878    let fake_model = model.as_fake();
 879
 880    let mut events = thread
 881        .update(cx, |thread, cx| {
 882            thread.send(UserMessageId::new(), ["abc"], cx)
 883        })
 884        .unwrap();
 885    cx.run_until_parked();
 886    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 887        LanguageModelToolUse {
 888            id: "tool_id_1".into(),
 889            name: "nonexistent_tool".into(),
 890            raw_input: "{}".into(),
 891            input: json!({}),
 892            is_input_complete: true,
 893            thought_signature: None,
 894        },
 895    ));
 896    fake_model.end_last_completion_stream();
 897
 898    let tool_call = expect_tool_call(&mut events).await;
 899    assert_eq!(tool_call.title, "nonexistent_tool");
 900    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
 901    let update = expect_tool_call_update_fields(&mut events).await;
 902    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
 903}
 904
 905#[gpui::test]
 906async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
 907    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 908    let fake_model = model.as_fake();
 909
 910    let events = thread
 911        .update(cx, |thread, cx| {
 912            thread.add_tool(EchoTool);
 913            thread.send(UserMessageId::new(), ["abc"], cx)
 914        })
 915        .unwrap();
 916    cx.run_until_parked();
 917    let tool_use = LanguageModelToolUse {
 918        id: "tool_id_1".into(),
 919        name: EchoTool::name().into(),
 920        raw_input: "{}".into(),
 921        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
 922        is_input_complete: true,
 923        thought_signature: None,
 924    };
 925    fake_model
 926        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
 927    fake_model.end_last_completion_stream();
 928
 929    cx.run_until_parked();
 930    let completion = fake_model.pending_completions().pop().unwrap();
 931    let tool_result = LanguageModelToolResult {
 932        tool_use_id: "tool_id_1".into(),
 933        tool_name: EchoTool::name().into(),
 934        is_error: false,
 935        content: "def".into(),
 936        output: Some("def".into()),
 937    };
 938    assert_eq!(
 939        completion.messages[1..],
 940        vec![
 941            LanguageModelRequestMessage {
 942                role: Role::User,
 943                content: vec!["abc".into()],
 944                cache: false,
 945                reasoning_details: None,
 946            },
 947            LanguageModelRequestMessage {
 948                role: Role::Assistant,
 949                content: vec![MessageContent::ToolUse(tool_use.clone())],
 950                cache: false,
 951                reasoning_details: None,
 952            },
 953            LanguageModelRequestMessage {
 954                role: Role::User,
 955                content: vec![MessageContent::ToolResult(tool_result.clone())],
 956                cache: true,
 957                reasoning_details: None,
 958            },
 959        ]
 960    );
 961
 962    // Simulate reaching tool use limit.
 963    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
 964    fake_model.end_last_completion_stream();
 965    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
 966    assert!(
 967        last_event
 968            .unwrap_err()
 969            .is::<language_model::ToolUseLimitReachedError>()
 970    );
 971
 972    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
 973    cx.run_until_parked();
 974    let completion = fake_model.pending_completions().pop().unwrap();
 975    assert_eq!(
 976        completion.messages[1..],
 977        vec![
 978            LanguageModelRequestMessage {
 979                role: Role::User,
 980                content: vec!["abc".into()],
 981                cache: false,
 982                reasoning_details: None,
 983            },
 984            LanguageModelRequestMessage {
 985                role: Role::Assistant,
 986                content: vec![MessageContent::ToolUse(tool_use)],
 987                cache: false,
 988                reasoning_details: None,
 989            },
 990            LanguageModelRequestMessage {
 991                role: Role::User,
 992                content: vec![MessageContent::ToolResult(tool_result)],
 993                cache: false,
 994                reasoning_details: None,
 995            },
 996            LanguageModelRequestMessage {
 997                role: Role::User,
 998                content: vec!["Continue where you left off".into()],
 999                cache: true,
1000                reasoning_details: None,
1001            }
1002        ]
1003    );
1004
1005    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
1006    fake_model.end_last_completion_stream();
1007    events.collect::<Vec<_>>().await;
1008    thread.read_with(cx, |thread, _cx| {
1009        assert_eq!(
1010            thread.last_message().unwrap().to_markdown(),
1011            indoc! {"
1012                ## Assistant
1013
1014                Done
1015            "}
1016        )
1017    });
1018}
1019
1020#[gpui::test]
1021async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
1022    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1023    let fake_model = model.as_fake();
1024
1025    let events = thread
1026        .update(cx, |thread, cx| {
1027            thread.add_tool(EchoTool);
1028            thread.send(UserMessageId::new(), ["abc"], cx)
1029        })
1030        .unwrap();
1031    cx.run_until_parked();
1032
1033    let tool_use = LanguageModelToolUse {
1034        id: "tool_id_1".into(),
1035        name: EchoTool::name().into(),
1036        raw_input: "{}".into(),
1037        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
1038        is_input_complete: true,
1039        thought_signature: None,
1040    };
1041    let tool_result = LanguageModelToolResult {
1042        tool_use_id: "tool_id_1".into(),
1043        tool_name: EchoTool::name().into(),
1044        is_error: false,
1045        content: "def".into(),
1046        output: Some("def".into()),
1047    };
1048    fake_model
1049        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
1050    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
1051    fake_model.end_last_completion_stream();
1052    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
1053    assert!(
1054        last_event
1055            .unwrap_err()
1056            .is::<language_model::ToolUseLimitReachedError>()
1057    );
1058
1059    thread
1060        .update(cx, |thread, cx| {
1061            thread.send(UserMessageId::new(), vec!["ghi"], cx)
1062        })
1063        .unwrap();
1064    cx.run_until_parked();
1065    let completion = fake_model.pending_completions().pop().unwrap();
1066    assert_eq!(
1067        completion.messages[1..],
1068        vec![
1069            LanguageModelRequestMessage {
1070                role: Role::User,
1071                content: vec!["abc".into()],
1072                cache: false,
1073                reasoning_details: None,
1074            },
1075            LanguageModelRequestMessage {
1076                role: Role::Assistant,
1077                content: vec![MessageContent::ToolUse(tool_use)],
1078                cache: false,
1079                reasoning_details: None,
1080            },
1081            LanguageModelRequestMessage {
1082                role: Role::User,
1083                content: vec![MessageContent::ToolResult(tool_result)],
1084                cache: false,
1085                reasoning_details: None,
1086            },
1087            LanguageModelRequestMessage {
1088                role: Role::User,
1089                content: vec!["ghi".into()],
1090                cache: true,
1091                reasoning_details: None,
1092            }
1093        ]
1094    );
1095}
1096
1097async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
1098    let event = events
1099        .next()
1100        .await
1101        .expect("no tool call authorization event received")
1102        .unwrap();
1103    match event {
1104        ThreadEvent::ToolCall(tool_call) => tool_call,
1105        event => {
1106            panic!("Unexpected event {event:?}");
1107        }
1108    }
1109}
1110
1111async fn expect_tool_call_update_fields(
1112    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
1113) -> acp::ToolCallUpdate {
1114    let event = events
1115        .next()
1116        .await
1117        .expect("no tool call authorization event received")
1118        .unwrap();
1119    match event {
1120        ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
1121        event => {
1122            panic!("Unexpected event {event:?}");
1123        }
1124    }
1125}
1126
1127async fn next_tool_call_authorization(
1128    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
1129) -> ToolCallAuthorization {
1130    loop {
1131        let event = events
1132            .next()
1133            .await
1134            .expect("no tool call authorization event received")
1135            .unwrap();
1136        if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
1137            let permission_kinds = tool_call_authorization
1138                .options
1139                .iter()
1140                .map(|o| o.kind)
1141                .collect::<Vec<_>>();
1142            assert_eq!(
1143                permission_kinds,
1144                vec![
1145                    acp::PermissionOptionKind::AllowAlways,
1146                    acp::PermissionOptionKind::AllowOnce,
1147                    acp::PermissionOptionKind::RejectOnce,
1148                ]
1149            );
1150            return tool_call_authorization;
1151        }
1152    }
1153}
1154
1155#[gpui::test]
1156#[cfg_attr(not(feature = "e2e"), ignore)]
1157async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
1158    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1159
1160    // Test concurrent tool calls with different delay times
1161    let events = thread
1162        .update(cx, |thread, cx| {
1163            thread.add_tool(DelayTool);
1164            thread.send(
1165                UserMessageId::new(),
1166                [
1167                    "Call the delay tool twice in the same message.",
1168                    "Once with 100ms. Once with 300ms.",
1169                    "When both timers are complete, describe the outputs.",
1170                ],
1171                cx,
1172            )
1173        })
1174        .unwrap()
1175        .collect()
1176        .await;
1177
1178    let stop_reasons = stop_events(events);
1179    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
1180
1181    thread.update(cx, |thread, _cx| {
1182        let last_message = thread.last_message().unwrap();
1183        let agent_message = last_message.as_agent_message().unwrap();
1184        let text = agent_message
1185            .content
1186            .iter()
1187            .filter_map(|content| {
1188                if let AgentMessageContent::Text(text) = content {
1189                    Some(text.as_str())
1190                } else {
1191                    None
1192                }
1193            })
1194            .collect::<String>();
1195
1196        assert!(text.contains("Ding"));
1197    });
1198}
1199
1200#[gpui::test]
1201async fn test_profiles(cx: &mut TestAppContext) {
1202    let ThreadTest {
1203        model, thread, fs, ..
1204    } = setup(cx, TestModel::Fake).await;
1205    let fake_model = model.as_fake();
1206
1207    thread.update(cx, |thread, _cx| {
1208        thread.add_tool(DelayTool);
1209        thread.add_tool(EchoTool);
1210        thread.add_tool(InfiniteTool);
1211    });
1212
1213    // Override profiles and wait for settings to be loaded.
1214    fs.insert_file(
1215        paths::settings_file(),
1216        json!({
1217            "agent": {
1218                "profiles": {
1219                    "test-1": {
1220                        "name": "Test Profile 1",
1221                        "tools": {
1222                            EchoTool::name(): true,
1223                            DelayTool::name(): true,
1224                        }
1225                    },
1226                    "test-2": {
1227                        "name": "Test Profile 2",
1228                        "tools": {
1229                            InfiniteTool::name(): true,
1230                        }
1231                    }
1232                }
1233            }
1234        })
1235        .to_string()
1236        .into_bytes(),
1237    )
1238    .await;
1239    cx.run_until_parked();
1240
1241    // Test that test-1 profile (default) has echo and delay tools
1242    thread
1243        .update(cx, |thread, cx| {
1244            thread.set_profile(AgentProfileId("test-1".into()), cx);
1245            thread.send(UserMessageId::new(), ["test"], cx)
1246        })
1247        .unwrap();
1248    cx.run_until_parked();
1249
1250    let mut pending_completions = fake_model.pending_completions();
1251    assert_eq!(pending_completions.len(), 1);
1252    let completion = pending_completions.pop().unwrap();
1253    let tool_names: Vec<String> = completion
1254        .tools
1255        .iter()
1256        .map(|tool| tool.name.clone())
1257        .collect();
1258    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
1259    fake_model.end_last_completion_stream();
1260
1261    // Switch to test-2 profile, and verify that it has only the infinite tool.
1262    thread
1263        .update(cx, |thread, cx| {
1264            thread.set_profile(AgentProfileId("test-2".into()), cx);
1265            thread.send(UserMessageId::new(), ["test2"], cx)
1266        })
1267        .unwrap();
1268    cx.run_until_parked();
1269    let mut pending_completions = fake_model.pending_completions();
1270    assert_eq!(pending_completions.len(), 1);
1271    let completion = pending_completions.pop().unwrap();
1272    let tool_names: Vec<String> = completion
1273        .tools
1274        .iter()
1275        .map(|tool| tool.name.clone())
1276        .collect();
1277    assert_eq!(tool_names, vec![InfiniteTool::name()]);
1278}
1279
1280#[gpui::test]
1281async fn test_mcp_tools(cx: &mut TestAppContext) {
1282    let ThreadTest {
1283        model,
1284        thread,
1285        context_server_store,
1286        fs,
1287        ..
1288    } = setup(cx, TestModel::Fake).await;
1289    let fake_model = model.as_fake();
1290
1291    // Override profiles and wait for settings to be loaded.
1292    fs.insert_file(
1293        paths::settings_file(),
1294        json!({
1295            "agent": {
1296                "always_allow_tool_actions": true,
1297                "profiles": {
1298                    "test": {
1299                        "name": "Test Profile",
1300                        "enable_all_context_servers": true,
1301                        "tools": {
1302                            EchoTool::name(): true,
1303                        }
1304                    },
1305                }
1306            }
1307        })
1308        .to_string()
1309        .into_bytes(),
1310    )
1311    .await;
1312    cx.run_until_parked();
1313    thread.update(cx, |thread, cx| {
1314        thread.set_profile(AgentProfileId("test".into()), cx)
1315    });
1316
1317    let mut mcp_tool_calls = setup_context_server(
1318        "test_server",
1319        vec![context_server::types::Tool {
1320            name: "echo".into(),
1321            description: None,
1322            input_schema: serde_json::to_value(EchoTool::input_schema(
1323                LanguageModelToolSchemaFormat::JsonSchema,
1324            ))
1325            .unwrap(),
1326            output_schema: None,
1327            annotations: None,
1328        }],
1329        &context_server_store,
1330        cx,
1331    );
1332
1333    let events = thread.update(cx, |thread, cx| {
1334        thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
1335    });
1336    cx.run_until_parked();
1337
1338    // Simulate the model calling the MCP tool.
1339    let completion = fake_model.pending_completions().pop().unwrap();
1340    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1341    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1342        LanguageModelToolUse {
1343            id: "tool_1".into(),
1344            name: "echo".into(),
1345            raw_input: json!({"text": "test"}).to_string(),
1346            input: json!({"text": "test"}),
1347            is_input_complete: true,
1348            thought_signature: None,
1349        },
1350    ));
1351    fake_model.end_last_completion_stream();
1352    cx.run_until_parked();
1353
1354    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1355    assert_eq!(tool_call_params.name, "echo");
1356    assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1357    tool_call_response
1358        .send(context_server::types::CallToolResponse {
1359            content: vec![context_server::types::ToolResponseContent::Text {
1360                text: "test".into(),
1361            }],
1362            is_error: None,
1363            meta: None,
1364            structured_content: None,
1365        })
1366        .unwrap();
1367    cx.run_until_parked();
1368
1369    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1370    fake_model.send_last_completion_stream_text_chunk("Done!");
1371    fake_model.end_last_completion_stream();
1372    events.collect::<Vec<_>>().await;
1373
1374    // Send again after adding the echo tool, ensuring the name collision is resolved.
1375    let events = thread.update(cx, |thread, cx| {
1376        thread.add_tool(EchoTool);
1377        thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1378    });
1379    cx.run_until_parked();
1380    let completion = fake_model.pending_completions().pop().unwrap();
1381    assert_eq!(
1382        tool_names_for_completion(&completion),
1383        vec!["echo", "test_server_echo"]
1384    );
1385    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1386        LanguageModelToolUse {
1387            id: "tool_2".into(),
1388            name: "test_server_echo".into(),
1389            raw_input: json!({"text": "mcp"}).to_string(),
1390            input: json!({"text": "mcp"}),
1391            is_input_complete: true,
1392            thought_signature: None,
1393        },
1394    ));
1395    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1396        LanguageModelToolUse {
1397            id: "tool_3".into(),
1398            name: "echo".into(),
1399            raw_input: json!({"text": "native"}).to_string(),
1400            input: json!({"text": "native"}),
1401            is_input_complete: true,
1402            thought_signature: None,
1403        },
1404    ));
1405    fake_model.end_last_completion_stream();
1406    cx.run_until_parked();
1407
1408    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1409    assert_eq!(tool_call_params.name, "echo");
1410    assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1411    tool_call_response
1412        .send(context_server::types::CallToolResponse {
1413            content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1414            is_error: None,
1415            meta: None,
1416            structured_content: None,
1417        })
1418        .unwrap();
1419    cx.run_until_parked();
1420
1421    // Ensure the tool results were inserted with the correct names.
1422    let completion = fake_model.pending_completions().pop().unwrap();
1423    assert_eq!(
1424        completion.messages.last().unwrap().content,
1425        vec![
1426            MessageContent::ToolResult(LanguageModelToolResult {
1427                tool_use_id: "tool_3".into(),
1428                tool_name: "echo".into(),
1429                is_error: false,
1430                content: "native".into(),
1431                output: Some("native".into()),
1432            },),
1433            MessageContent::ToolResult(LanguageModelToolResult {
1434                tool_use_id: "tool_2".into(),
1435                tool_name: "test_server_echo".into(),
1436                is_error: false,
1437                content: "mcp".into(),
1438                output: Some("mcp".into()),
1439            },),
1440        ]
1441    );
1442    fake_model.end_last_completion_stream();
1443    events.collect::<Vec<_>>().await;
1444}
1445
1446#[gpui::test]
1447async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
1448    let ThreadTest {
1449        model,
1450        thread,
1451        context_server_store,
1452        fs,
1453        ..
1454    } = setup(cx, TestModel::Fake).await;
1455    let fake_model = model.as_fake();
1456
1457    // Set up a profile with all tools enabled
1458    fs.insert_file(
1459        paths::settings_file(),
1460        json!({
1461            "agent": {
1462                "profiles": {
1463                    "test": {
1464                        "name": "Test Profile",
1465                        "enable_all_context_servers": true,
1466                        "tools": {
1467                            EchoTool::name(): true,
1468                            DelayTool::name(): true,
1469                            WordListTool::name(): true,
1470                            ToolRequiringPermission::name(): true,
1471                            InfiniteTool::name(): true,
1472                        }
1473                    },
1474                }
1475            }
1476        })
1477        .to_string()
1478        .into_bytes(),
1479    )
1480    .await;
1481    cx.run_until_parked();
1482
1483    thread.update(cx, |thread, cx| {
1484        thread.set_profile(AgentProfileId("test".into()), cx);
1485        thread.add_tool(EchoTool);
1486        thread.add_tool(DelayTool);
1487        thread.add_tool(WordListTool);
1488        thread.add_tool(ToolRequiringPermission);
1489        thread.add_tool(InfiniteTool);
1490    });
1491
1492    // Set up multiple context servers with some overlapping tool names
1493    let _server1_calls = setup_context_server(
1494        "xxx",
1495        vec![
1496            context_server::types::Tool {
1497                name: "echo".into(), // Conflicts with native EchoTool
1498                description: None,
1499                input_schema: serde_json::to_value(EchoTool::input_schema(
1500                    LanguageModelToolSchemaFormat::JsonSchema,
1501                ))
1502                .unwrap(),
1503                output_schema: None,
1504                annotations: None,
1505            },
1506            context_server::types::Tool {
1507                name: "unique_tool_1".into(),
1508                description: None,
1509                input_schema: json!({"type": "object", "properties": {}}),
1510                output_schema: None,
1511                annotations: None,
1512            },
1513        ],
1514        &context_server_store,
1515        cx,
1516    );
1517
1518    let _server2_calls = setup_context_server(
1519        "yyy",
1520        vec![
1521            context_server::types::Tool {
1522                name: "echo".into(), // Also conflicts with native EchoTool
1523                description: None,
1524                input_schema: serde_json::to_value(EchoTool::input_schema(
1525                    LanguageModelToolSchemaFormat::JsonSchema,
1526                ))
1527                .unwrap(),
1528                output_schema: None,
1529                annotations: None,
1530            },
1531            context_server::types::Tool {
1532                name: "unique_tool_2".into(),
1533                description: None,
1534                input_schema: json!({"type": "object", "properties": {}}),
1535                output_schema: None,
1536                annotations: None,
1537            },
1538            context_server::types::Tool {
1539                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1540                description: None,
1541                input_schema: json!({"type": "object", "properties": {}}),
1542                output_schema: None,
1543                annotations: None,
1544            },
1545            context_server::types::Tool {
1546                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1547                description: None,
1548                input_schema: json!({"type": "object", "properties": {}}),
1549                output_schema: None,
1550                annotations: None,
1551            },
1552        ],
1553        &context_server_store,
1554        cx,
1555    );
1556    let _server3_calls = setup_context_server(
1557        "zzz",
1558        vec![
1559            context_server::types::Tool {
1560                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1561                description: None,
1562                input_schema: json!({"type": "object", "properties": {}}),
1563                output_schema: None,
1564                annotations: None,
1565            },
1566            context_server::types::Tool {
1567                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1568                description: None,
1569                input_schema: json!({"type": "object", "properties": {}}),
1570                output_schema: None,
1571                annotations: None,
1572            },
1573            context_server::types::Tool {
1574                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
1575                description: None,
1576                input_schema: json!({"type": "object", "properties": {}}),
1577                output_schema: None,
1578                annotations: None,
1579            },
1580        ],
1581        &context_server_store,
1582        cx,
1583    );
1584
1585    thread
1586        .update(cx, |thread, cx| {
1587            thread.send(UserMessageId::new(), ["Go"], cx)
1588        })
1589        .unwrap();
1590    cx.run_until_parked();
1591    let completion = fake_model.pending_completions().pop().unwrap();
1592    assert_eq!(
1593        tool_names_for_completion(&completion),
1594        vec![
1595            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
1596            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
1597            "delay",
1598            "echo",
1599            "infinite",
1600            "tool_requiring_permission",
1601            "unique_tool_1",
1602            "unique_tool_2",
1603            "word_list",
1604            "xxx_echo",
1605            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1606            "yyy_echo",
1607            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1608        ]
1609    );
1610}
1611
1612#[gpui::test]
1613#[cfg_attr(not(feature = "e2e"), ignore)]
1614async fn test_cancellation(cx: &mut TestAppContext) {
1615    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1616
1617    let mut events = thread
1618        .update(cx, |thread, cx| {
1619            thread.add_tool(InfiniteTool);
1620            thread.add_tool(EchoTool);
1621            thread.send(
1622                UserMessageId::new(),
1623                ["Call the echo tool, then call the infinite tool, then explain their output"],
1624                cx,
1625            )
1626        })
1627        .unwrap();
1628
1629    // Wait until both tools are called.
1630    let mut expected_tools = vec!["Echo", "Infinite Tool"];
1631    let mut echo_id = None;
1632    let mut echo_completed = false;
1633    while let Some(event) = events.next().await {
1634        match event.unwrap() {
1635            ThreadEvent::ToolCall(tool_call) => {
1636                assert_eq!(tool_call.title, expected_tools.remove(0));
1637                if tool_call.title == "Echo" {
1638                    echo_id = Some(tool_call.tool_call_id);
1639                }
1640            }
1641            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
1642                acp::ToolCallUpdate {
1643                    tool_call_id,
1644                    fields:
1645                        acp::ToolCallUpdateFields {
1646                            status: Some(acp::ToolCallStatus::Completed),
1647                            ..
1648                        },
1649                    ..
1650                },
1651            )) if Some(&tool_call_id) == echo_id.as_ref() => {
1652                echo_completed = true;
1653            }
1654            _ => {}
1655        }
1656
1657        if expected_tools.is_empty() && echo_completed {
1658            break;
1659        }
1660    }
1661
1662    // Cancel the current send and ensure that the event stream is closed, even
1663    // if one of the tools is still running.
1664    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1665    let events = events.collect::<Vec<_>>().await;
1666    let last_event = events.last();
1667    assert!(
1668        matches!(
1669            last_event,
1670            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
1671        ),
1672        "unexpected event {last_event:?}"
1673    );
1674
1675    // Ensure we can still send a new message after cancellation.
1676    let events = thread
1677        .update(cx, |thread, cx| {
1678            thread.send(
1679                UserMessageId::new(),
1680                ["Testing: reply with 'Hello' then stop."],
1681                cx,
1682            )
1683        })
1684        .unwrap()
1685        .collect::<Vec<_>>()
1686        .await;
1687    thread.update(cx, |thread, _cx| {
1688        let message = thread.last_message().unwrap();
1689        let agent_message = message.as_agent_message().unwrap();
1690        assert_eq!(
1691            agent_message.content,
1692            vec![AgentMessageContent::Text("Hello".to_string())]
1693        );
1694    });
1695    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1696}
1697
1698#[gpui::test]
1699async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext) {
1700    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1701    always_allow_tools(cx);
1702    let fake_model = model.as_fake();
1703
1704    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
1705    let environment = Rc::new(FakeThreadEnvironment {
1706        handle: handle.clone(),
1707    });
1708
1709    let mut events = thread
1710        .update(cx, |thread, cx| {
1711            thread.add_tool(crate::TerminalTool::new(
1712                thread.project().clone(),
1713                environment,
1714            ));
1715            thread.send(UserMessageId::new(), ["run a command"], cx)
1716        })
1717        .unwrap();
1718
1719    cx.run_until_parked();
1720
1721    // Simulate the model calling the terminal tool
1722    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1723        LanguageModelToolUse {
1724            id: "terminal_tool_1".into(),
1725            name: "terminal".into(),
1726            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
1727            input: json!({"command": "sleep 1000", "cd": "."}),
1728            is_input_complete: true,
1729            thought_signature: None,
1730        },
1731    ));
1732    fake_model.end_last_completion_stream();
1733
1734    // Wait for the terminal tool to start running
1735    wait_for_terminal_tool_started(&mut events, cx).await;
1736
1737    // Cancel the thread while the terminal is running
1738    thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
1739
1740    // Collect remaining events, driving the executor to let cancellation complete
1741    let remaining_events = collect_events_until_stop(&mut events, cx).await;
1742
1743    // Verify the terminal was killed
1744    assert!(
1745        handle.was_killed(),
1746        "expected terminal handle to be killed on cancellation"
1747    );
1748
1749    // Verify we got a cancellation stop event
1750    assert_eq!(
1751        stop_events(remaining_events),
1752        vec![acp::StopReason::Cancelled],
1753    );
1754
1755    // Verify the tool result contains the terminal output, not just "Tool canceled by user"
1756    thread.update(cx, |thread, _cx| {
1757        let message = thread.last_message().unwrap();
1758        let agent_message = message.as_agent_message().unwrap();
1759
1760        let tool_use = agent_message
1761            .content
1762            .iter()
1763            .find_map(|content| match content {
1764                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
1765                _ => None,
1766            })
1767            .expect("expected tool use in agent message");
1768
1769        let tool_result = agent_message
1770            .tool_results
1771            .get(&tool_use.id)
1772            .expect("expected tool result");
1773
1774        let result_text = match &tool_result.content {
1775            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
1776            _ => panic!("expected text content in tool result"),
1777        };
1778
1779        // "partial output" comes from FakeTerminalHandle's output field
1780        assert!(
1781            result_text.contains("partial output"),
1782            "expected tool result to contain terminal output, got: {result_text}"
1783        );
1784        // Match the actual format from process_content in terminal_tool.rs
1785        assert!(
1786            result_text.contains("The user stopped this command"),
1787            "expected tool result to indicate user stopped, got: {result_text}"
1788        );
1789    });
1790
1791    // Verify we can send a new message after cancellation
1792    verify_thread_recovery(&thread, &fake_model, cx).await;
1793}
1794
1795#[gpui::test]
1796async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppContext) {
1797    // This test verifies that tools which properly handle cancellation via
1798    // `event_stream.cancelled_by_user()` (like edit_file_tool) respond promptly
1799    // to cancellation and report that they were cancelled.
1800    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1801    always_allow_tools(cx);
1802    let fake_model = model.as_fake();
1803
1804    let (tool, was_cancelled) = CancellationAwareTool::new();
1805
1806    let mut events = thread
1807        .update(cx, |thread, cx| {
1808            thread.add_tool(tool);
1809            thread.send(
1810                UserMessageId::new(),
1811                ["call the cancellation aware tool"],
1812                cx,
1813            )
1814        })
1815        .unwrap();
1816
1817    cx.run_until_parked();
1818
1819    // Simulate the model calling the cancellation-aware tool
1820    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1821        LanguageModelToolUse {
1822            id: "cancellation_aware_1".into(),
1823            name: "cancellation_aware".into(),
1824            raw_input: r#"{}"#.into(),
1825            input: json!({}),
1826            is_input_complete: true,
1827            thought_signature: None,
1828        },
1829    ));
1830    fake_model.end_last_completion_stream();
1831
1832    cx.run_until_parked();
1833
1834    // Wait for the tool call to be reported
1835    let mut tool_started = false;
1836    let deadline = cx.executor().num_cpus() * 100;
1837    for _ in 0..deadline {
1838        cx.run_until_parked();
1839
1840        while let Some(Some(event)) = events.next().now_or_never() {
1841            if let Ok(ThreadEvent::ToolCall(tool_call)) = &event {
1842                if tool_call.title == "Cancellation Aware Tool" {
1843                    tool_started = true;
1844                    break;
1845                }
1846            }
1847        }
1848
1849        if tool_started {
1850            break;
1851        }
1852
1853        cx.background_executor
1854            .timer(Duration::from_millis(10))
1855            .await;
1856    }
1857    assert!(tool_started, "expected cancellation aware tool to start");
1858
1859    // Cancel the thread and wait for it to complete
1860    let cancel_task = thread.update(cx, |thread, cx| thread.cancel(cx));
1861
1862    // The cancel task should complete promptly because the tool handles cancellation
1863    let timeout = cx.background_executor.timer(Duration::from_secs(5));
1864    futures::select! {
1865        _ = cancel_task.fuse() => {}
1866        _ = timeout.fuse() => {
1867            panic!("cancel task timed out - tool did not respond to cancellation");
1868        }
1869    }
1870
1871    // Verify the tool detected cancellation via its flag
1872    assert!(
1873        was_cancelled.load(std::sync::atomic::Ordering::SeqCst),
1874        "tool should have detected cancellation via event_stream.cancelled_by_user()"
1875    );
1876
1877    // Collect remaining events
1878    let remaining_events = collect_events_until_stop(&mut events, cx).await;
1879
1880    // Verify we got a cancellation stop event
1881    assert_eq!(
1882        stop_events(remaining_events),
1883        vec![acp::StopReason::Cancelled],
1884    );
1885
1886    // Verify we can send a new message after cancellation
1887    verify_thread_recovery(&thread, &fake_model, cx).await;
1888}
1889
1890/// Helper to verify thread can recover after cancellation by sending a simple message.
1891async fn verify_thread_recovery(
1892    thread: &Entity<Thread>,
1893    fake_model: &FakeLanguageModel,
1894    cx: &mut TestAppContext,
1895) {
1896    let events = thread
1897        .update(cx, |thread, cx| {
1898            thread.send(
1899                UserMessageId::new(),
1900                ["Testing: reply with 'Hello' then stop."],
1901                cx,
1902            )
1903        })
1904        .unwrap();
1905    cx.run_until_parked();
1906    fake_model.send_last_completion_stream_text_chunk("Hello");
1907    fake_model
1908        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1909    fake_model.end_last_completion_stream();
1910
1911    let events = events.collect::<Vec<_>>().await;
1912    thread.update(cx, |thread, _cx| {
1913        let message = thread.last_message().unwrap();
1914        let agent_message = message.as_agent_message().unwrap();
1915        assert_eq!(
1916            agent_message.content,
1917            vec![AgentMessageContent::Text("Hello".to_string())]
1918        );
1919    });
1920    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1921}
1922
1923/// Waits for a terminal tool to start by watching for a ToolCallUpdate with terminal content.
1924async fn wait_for_terminal_tool_started(
1925    events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1926    cx: &mut TestAppContext,
1927) {
1928    let deadline = cx.executor().num_cpus() * 100; // Scale with available parallelism
1929    for _ in 0..deadline {
1930        cx.run_until_parked();
1931
1932        while let Some(Some(event)) = events.next().now_or_never() {
1933            if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
1934                update,
1935            ))) = &event
1936            {
1937                if update.fields.content.as_ref().is_some_and(|content| {
1938                    content
1939                        .iter()
1940                        .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
1941                }) {
1942                    return;
1943                }
1944            }
1945        }
1946
1947        cx.background_executor
1948            .timer(Duration::from_millis(10))
1949            .await;
1950    }
1951    panic!("terminal tool did not start within the expected time");
1952}
1953
1954/// Collects events until a Stop event is received, driving the executor to completion.
1955async fn collect_events_until_stop(
1956    events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1957    cx: &mut TestAppContext,
1958) -> Vec<Result<ThreadEvent>> {
1959    let mut collected = Vec::new();
1960    let deadline = cx.executor().num_cpus() * 200;
1961
1962    for _ in 0..deadline {
1963        cx.executor().advance_clock(Duration::from_millis(10));
1964        cx.run_until_parked();
1965
1966        while let Some(Some(event)) = events.next().now_or_never() {
1967            let is_stop = matches!(&event, Ok(ThreadEvent::Stop(_)));
1968            collected.push(event);
1969            if is_stop {
1970                return collected;
1971            }
1972        }
1973    }
1974    panic!(
1975        "did not receive Stop event within the expected time; collected {} events",
1976        collected.len()
1977    );
1978}
1979
1980#[gpui::test]
1981async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) {
1982    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1983    always_allow_tools(cx);
1984    let fake_model = model.as_fake();
1985
1986    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
1987    let environment = Rc::new(FakeThreadEnvironment {
1988        handle: handle.clone(),
1989    });
1990
1991    let message_id = UserMessageId::new();
1992    let mut events = thread
1993        .update(cx, |thread, cx| {
1994            thread.add_tool(crate::TerminalTool::new(
1995                thread.project().clone(),
1996                environment,
1997            ));
1998            thread.send(message_id.clone(), ["run a command"], cx)
1999        })
2000        .unwrap();
2001
2002    cx.run_until_parked();
2003
2004    // Simulate the model calling the terminal tool
2005    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2006        LanguageModelToolUse {
2007            id: "terminal_tool_1".into(),
2008            name: "terminal".into(),
2009            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2010            input: json!({"command": "sleep 1000", "cd": "."}),
2011            is_input_complete: true,
2012            thought_signature: None,
2013        },
2014    ));
2015    fake_model.end_last_completion_stream();
2016
2017    // Wait for the terminal tool to start running
2018    wait_for_terminal_tool_started(&mut events, cx).await;
2019
2020    // Truncate the thread while the terminal is running
2021    thread
2022        .update(cx, |thread, cx| thread.truncate(message_id, cx))
2023        .unwrap();
2024
2025    // Drive the executor to let cancellation complete
2026    let _ = collect_events_until_stop(&mut events, cx).await;
2027
2028    // Verify the terminal was killed
2029    assert!(
2030        handle.was_killed(),
2031        "expected terminal handle to be killed on truncate"
2032    );
2033
2034    // Verify the thread is empty after truncation
2035    thread.update(cx, |thread, _cx| {
2036        assert_eq!(
2037            thread.to_markdown(),
2038            "",
2039            "expected thread to be empty after truncating the only message"
2040        );
2041    });
2042
2043    // Verify we can send a new message after truncation
2044    verify_thread_recovery(&thread, &fake_model, cx).await;
2045}
2046
2047#[gpui::test]
2048async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext) {
2049    // Tests that cancellation properly kills all running terminal tools when multiple are active.
2050    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2051    always_allow_tools(cx);
2052    let fake_model = model.as_fake();
2053
2054    let environment = Rc::new(MultiTerminalEnvironment::new());
2055
2056    let mut events = thread
2057        .update(cx, |thread, cx| {
2058            thread.add_tool(crate::TerminalTool::new(
2059                thread.project().clone(),
2060                environment.clone(),
2061            ));
2062            thread.send(UserMessageId::new(), ["run multiple commands"], cx)
2063        })
2064        .unwrap();
2065
2066    cx.run_until_parked();
2067
2068    // Simulate the model calling two terminal tools
2069    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2070        LanguageModelToolUse {
2071            id: "terminal_tool_1".into(),
2072            name: "terminal".into(),
2073            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2074            input: json!({"command": "sleep 1000", "cd": "."}),
2075            is_input_complete: true,
2076            thought_signature: None,
2077        },
2078    ));
2079    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2080        LanguageModelToolUse {
2081            id: "terminal_tool_2".into(),
2082            name: "terminal".into(),
2083            raw_input: r#"{"command": "sleep 2000", "cd": "."}"#.into(),
2084            input: json!({"command": "sleep 2000", "cd": "."}),
2085            is_input_complete: true,
2086            thought_signature: None,
2087        },
2088    ));
2089    fake_model.end_last_completion_stream();
2090
2091    // Wait for both terminal tools to start by counting terminal content updates
2092    let mut terminals_started = 0;
2093    let deadline = cx.executor().num_cpus() * 100;
2094    for _ in 0..deadline {
2095        cx.run_until_parked();
2096
2097        while let Some(Some(event)) = events.next().now_or_never() {
2098            if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
2099                update,
2100            ))) = &event
2101            {
2102                if update.fields.content.as_ref().is_some_and(|content| {
2103                    content
2104                        .iter()
2105                        .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
2106                }) {
2107                    terminals_started += 1;
2108                    if terminals_started >= 2 {
2109                        break;
2110                    }
2111                }
2112            }
2113        }
2114        if terminals_started >= 2 {
2115            break;
2116        }
2117
2118        cx.background_executor
2119            .timer(Duration::from_millis(10))
2120            .await;
2121    }
2122    assert!(
2123        terminals_started >= 2,
2124        "expected 2 terminal tools to start, got {terminals_started}"
2125    );
2126
2127    // Cancel the thread while both terminals are running
2128    thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
2129
2130    // Collect remaining events
2131    let remaining_events = collect_events_until_stop(&mut events, cx).await;
2132
2133    // Verify both terminal handles were killed
2134    let handles = environment.handles();
2135    assert_eq!(
2136        handles.len(),
2137        2,
2138        "expected 2 terminal handles to be created"
2139    );
2140    assert!(
2141        handles[0].was_killed(),
2142        "expected first terminal handle to be killed on cancellation"
2143    );
2144    assert!(
2145        handles[1].was_killed(),
2146        "expected second terminal handle to be killed on cancellation"
2147    );
2148
2149    // Verify we got a cancellation stop event
2150    assert_eq!(
2151        stop_events(remaining_events),
2152        vec![acp::StopReason::Cancelled],
2153    );
2154}
2155
2156#[gpui::test]
2157async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppContext) {
2158    // Tests that clicking the stop button on the terminal card (as opposed to the main
2159    // cancel button) properly reports user stopped via the was_stopped_by_user path.
2160    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2161    always_allow_tools(cx);
2162    let fake_model = model.as_fake();
2163
2164    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
2165    let environment = Rc::new(FakeThreadEnvironment {
2166        handle: handle.clone(),
2167    });
2168
2169    let mut events = thread
2170        .update(cx, |thread, cx| {
2171            thread.add_tool(crate::TerminalTool::new(
2172                thread.project().clone(),
2173                environment,
2174            ));
2175            thread.send(UserMessageId::new(), ["run a command"], cx)
2176        })
2177        .unwrap();
2178
2179    cx.run_until_parked();
2180
2181    // Simulate the model calling the terminal tool
2182    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2183        LanguageModelToolUse {
2184            id: "terminal_tool_1".into(),
2185            name: "terminal".into(),
2186            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
2187            input: json!({"command": "sleep 1000", "cd": "."}),
2188            is_input_complete: true,
2189            thought_signature: None,
2190        },
2191    ));
2192    fake_model.end_last_completion_stream();
2193
2194    // Wait for the terminal tool to start running
2195    wait_for_terminal_tool_started(&mut events, cx).await;
2196
2197    // Simulate user clicking stop on the terminal card itself.
2198    // This sets the flag and signals exit (simulating what the real UI would do).
2199    handle.set_stopped_by_user(true);
2200    handle.killed.store(true, Ordering::SeqCst);
2201    handle.signal_exit();
2202
2203    // Wait for the tool to complete
2204    cx.run_until_parked();
2205
2206    // The thread continues after tool completion - simulate the model ending its turn
2207    fake_model
2208        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2209    fake_model.end_last_completion_stream();
2210
2211    // Collect remaining events
2212    let remaining_events = collect_events_until_stop(&mut events, cx).await;
2213
2214    // Verify we got an EndTurn (not Cancelled, since we didn't cancel the thread)
2215    assert_eq!(
2216        stop_events(remaining_events),
2217        vec![acp::StopReason::EndTurn],
2218    );
2219
2220    // Verify the tool result indicates user stopped
2221    thread.update(cx, |thread, _cx| {
2222        let message = thread.last_message().unwrap();
2223        let agent_message = message.as_agent_message().unwrap();
2224
2225        let tool_use = agent_message
2226            .content
2227            .iter()
2228            .find_map(|content| match content {
2229                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2230                _ => None,
2231            })
2232            .expect("expected tool use in agent message");
2233
2234        let tool_result = agent_message
2235            .tool_results
2236            .get(&tool_use.id)
2237            .expect("expected tool result");
2238
2239        let result_text = match &tool_result.content {
2240            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2241            _ => panic!("expected text content in tool result"),
2242        };
2243
2244        assert!(
2245            result_text.contains("The user stopped this command"),
2246            "expected tool result to indicate user stopped, got: {result_text}"
2247        );
2248    });
2249}
2250
2251#[gpui::test]
2252async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) {
2253    // Tests that when a timeout is configured and expires, the tool result indicates timeout.
2254    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2255    always_allow_tools(cx);
2256    let fake_model = model.as_fake();
2257
2258    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
2259    let environment = Rc::new(FakeThreadEnvironment {
2260        handle: handle.clone(),
2261    });
2262
2263    let mut events = thread
2264        .update(cx, |thread, cx| {
2265            thread.add_tool(crate::TerminalTool::new(
2266                thread.project().clone(),
2267                environment,
2268            ));
2269            thread.send(UserMessageId::new(), ["run a command with timeout"], cx)
2270        })
2271        .unwrap();
2272
2273    cx.run_until_parked();
2274
2275    // Simulate the model calling the terminal tool with a short timeout
2276    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2277        LanguageModelToolUse {
2278            id: "terminal_tool_1".into(),
2279            name: "terminal".into(),
2280            raw_input: r#"{"command": "sleep 1000", "cd": ".", "timeout_ms": 100}"#.into(),
2281            input: json!({"command": "sleep 1000", "cd": ".", "timeout_ms": 100}),
2282            is_input_complete: true,
2283            thought_signature: None,
2284        },
2285    ));
2286    fake_model.end_last_completion_stream();
2287
2288    // Wait for the terminal tool to start running
2289    wait_for_terminal_tool_started(&mut events, cx).await;
2290
2291    // Advance clock past the timeout
2292    cx.executor().advance_clock(Duration::from_millis(200));
2293    cx.run_until_parked();
2294
2295    // The thread continues after tool completion - simulate the model ending its turn
2296    fake_model
2297        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2298    fake_model.end_last_completion_stream();
2299
2300    // Collect remaining events
2301    let remaining_events = collect_events_until_stop(&mut events, cx).await;
2302
2303    // Verify the terminal was killed due to timeout
2304    assert!(
2305        handle.was_killed(),
2306        "expected terminal handle to be killed on timeout"
2307    );
2308
2309    // Verify we got an EndTurn (the tool completed, just with timeout)
2310    assert_eq!(
2311        stop_events(remaining_events),
2312        vec![acp::StopReason::EndTurn],
2313    );
2314
2315    // Verify the tool result indicates timeout, not user stopped
2316    thread.update(cx, |thread, _cx| {
2317        let message = thread.last_message().unwrap();
2318        let agent_message = message.as_agent_message().unwrap();
2319
2320        let tool_use = agent_message
2321            .content
2322            .iter()
2323            .find_map(|content| match content {
2324                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2325                _ => None,
2326            })
2327            .expect("expected tool use in agent message");
2328
2329        let tool_result = agent_message
2330            .tool_results
2331            .get(&tool_use.id)
2332            .expect("expected tool result");
2333
2334        let result_text = match &tool_result.content {
2335            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2336            _ => panic!("expected text content in tool result"),
2337        };
2338
2339        assert!(
2340            result_text.contains("timed out"),
2341            "expected tool result to indicate timeout, got: {result_text}"
2342        );
2343        assert!(
2344            !result_text.contains("The user stopped"),
2345            "tool result should not mention user stopped when it timed out, got: {result_text}"
2346        );
2347    });
2348}
2349
2350#[gpui::test]
2351async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
2352    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2353    let fake_model = model.as_fake();
2354
2355    let events_1 = thread
2356        .update(cx, |thread, cx| {
2357            thread.send(UserMessageId::new(), ["Hello 1"], cx)
2358        })
2359        .unwrap();
2360    cx.run_until_parked();
2361    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2362    cx.run_until_parked();
2363
2364    let events_2 = thread
2365        .update(cx, |thread, cx| {
2366            thread.send(UserMessageId::new(), ["Hello 2"], cx)
2367        })
2368        .unwrap();
2369    cx.run_until_parked();
2370    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2371    fake_model
2372        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2373    fake_model.end_last_completion_stream();
2374
2375    let events_1 = events_1.collect::<Vec<_>>().await;
2376    assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
2377    let events_2 = events_2.collect::<Vec<_>>().await;
2378    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2379}
2380
2381#[gpui::test]
2382async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
2383    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2384    let fake_model = model.as_fake();
2385
2386    let events_1 = thread
2387        .update(cx, |thread, cx| {
2388            thread.send(UserMessageId::new(), ["Hello 1"], cx)
2389        })
2390        .unwrap();
2391    cx.run_until_parked();
2392    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2393    fake_model
2394        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2395    fake_model.end_last_completion_stream();
2396    let events_1 = events_1.collect::<Vec<_>>().await;
2397
2398    let events_2 = thread
2399        .update(cx, |thread, cx| {
2400            thread.send(UserMessageId::new(), ["Hello 2"], cx)
2401        })
2402        .unwrap();
2403    cx.run_until_parked();
2404    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2405    fake_model
2406        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2407    fake_model.end_last_completion_stream();
2408    let events_2 = events_2.collect::<Vec<_>>().await;
2409
2410    assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
2411    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2412}
2413
2414#[gpui::test]
2415async fn test_refusal(cx: &mut TestAppContext) {
2416    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2417    let fake_model = model.as_fake();
2418
2419    let events = thread
2420        .update(cx, |thread, cx| {
2421            thread.send(UserMessageId::new(), ["Hello"], cx)
2422        })
2423        .unwrap();
2424    cx.run_until_parked();
2425    thread.read_with(cx, |thread, _| {
2426        assert_eq!(
2427            thread.to_markdown(),
2428            indoc! {"
2429                ## User
2430
2431                Hello
2432            "}
2433        );
2434    });
2435
2436    fake_model.send_last_completion_stream_text_chunk("Hey!");
2437    cx.run_until_parked();
2438    thread.read_with(cx, |thread, _| {
2439        assert_eq!(
2440            thread.to_markdown(),
2441            indoc! {"
2442                ## User
2443
2444                Hello
2445
2446                ## Assistant
2447
2448                Hey!
2449            "}
2450        );
2451    });
2452
2453    // If the model refuses to continue, the thread should remove all the messages after the last user message.
2454    fake_model
2455        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
2456    let events = events.collect::<Vec<_>>().await;
2457    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
2458    thread.read_with(cx, |thread, _| {
2459        assert_eq!(thread.to_markdown(), "");
2460    });
2461}
2462
2463#[gpui::test]
2464async fn test_truncate_first_message(cx: &mut TestAppContext) {
2465    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2466    let fake_model = model.as_fake();
2467
2468    let message_id = UserMessageId::new();
2469    thread
2470        .update(cx, |thread, cx| {
2471            thread.send(message_id.clone(), ["Hello"], cx)
2472        })
2473        .unwrap();
2474    cx.run_until_parked();
2475    thread.read_with(cx, |thread, _| {
2476        assert_eq!(
2477            thread.to_markdown(),
2478            indoc! {"
2479                ## User
2480
2481                Hello
2482            "}
2483        );
2484        assert_eq!(thread.latest_token_usage(), None);
2485    });
2486
2487    fake_model.send_last_completion_stream_text_chunk("Hey!");
2488    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2489        language_model::TokenUsage {
2490            input_tokens: 32_000,
2491            output_tokens: 16_000,
2492            cache_creation_input_tokens: 0,
2493            cache_read_input_tokens: 0,
2494        },
2495    ));
2496    cx.run_until_parked();
2497    thread.read_with(cx, |thread, _| {
2498        assert_eq!(
2499            thread.to_markdown(),
2500            indoc! {"
2501                ## User
2502
2503                Hello
2504
2505                ## Assistant
2506
2507                Hey!
2508            "}
2509        );
2510        assert_eq!(
2511            thread.latest_token_usage(),
2512            Some(acp_thread::TokenUsage {
2513                used_tokens: 32_000 + 16_000,
2514                max_tokens: 1_000_000,
2515                output_tokens: 16_000,
2516            })
2517        );
2518    });
2519
2520    thread
2521        .update(cx, |thread, cx| thread.truncate(message_id, cx))
2522        .unwrap();
2523    cx.run_until_parked();
2524    thread.read_with(cx, |thread, _| {
2525        assert_eq!(thread.to_markdown(), "");
2526        assert_eq!(thread.latest_token_usage(), None);
2527    });
2528
2529    // Ensure we can still send a new message after truncation.
2530    thread
2531        .update(cx, |thread, cx| {
2532            thread.send(UserMessageId::new(), ["Hi"], cx)
2533        })
2534        .unwrap();
2535    thread.update(cx, |thread, _cx| {
2536        assert_eq!(
2537            thread.to_markdown(),
2538            indoc! {"
2539                ## User
2540
2541                Hi
2542            "}
2543        );
2544    });
2545    cx.run_until_parked();
2546    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
2547    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2548        language_model::TokenUsage {
2549            input_tokens: 40_000,
2550            output_tokens: 20_000,
2551            cache_creation_input_tokens: 0,
2552            cache_read_input_tokens: 0,
2553        },
2554    ));
2555    cx.run_until_parked();
2556    thread.read_with(cx, |thread, _| {
2557        assert_eq!(
2558            thread.to_markdown(),
2559            indoc! {"
2560                ## User
2561
2562                Hi
2563
2564                ## Assistant
2565
2566                Ahoy!
2567            "}
2568        );
2569
2570        assert_eq!(
2571            thread.latest_token_usage(),
2572            Some(acp_thread::TokenUsage {
2573                used_tokens: 40_000 + 20_000,
2574                max_tokens: 1_000_000,
2575                output_tokens: 20_000,
2576            })
2577        );
2578    });
2579}
2580
2581#[gpui::test]
2582async fn test_truncate_second_message(cx: &mut TestAppContext) {
2583    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2584    let fake_model = model.as_fake();
2585
2586    thread
2587        .update(cx, |thread, cx| {
2588            thread.send(UserMessageId::new(), ["Message 1"], cx)
2589        })
2590        .unwrap();
2591    cx.run_until_parked();
2592    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
2593    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2594        language_model::TokenUsage {
2595            input_tokens: 32_000,
2596            output_tokens: 16_000,
2597            cache_creation_input_tokens: 0,
2598            cache_read_input_tokens: 0,
2599        },
2600    ));
2601    fake_model.end_last_completion_stream();
2602    cx.run_until_parked();
2603
2604    let assert_first_message_state = |cx: &mut TestAppContext| {
2605        thread.clone().read_with(cx, |thread, _| {
2606            assert_eq!(
2607                thread.to_markdown(),
2608                indoc! {"
2609                    ## User
2610
2611                    Message 1
2612
2613                    ## Assistant
2614
2615                    Message 1 response
2616                "}
2617            );
2618
2619            assert_eq!(
2620                thread.latest_token_usage(),
2621                Some(acp_thread::TokenUsage {
2622                    used_tokens: 32_000 + 16_000,
2623                    max_tokens: 1_000_000,
2624                    output_tokens: 16_000,
2625                })
2626            );
2627        });
2628    };
2629
2630    assert_first_message_state(cx);
2631
2632    let second_message_id = UserMessageId::new();
2633    thread
2634        .update(cx, |thread, cx| {
2635            thread.send(second_message_id.clone(), ["Message 2"], cx)
2636        })
2637        .unwrap();
2638    cx.run_until_parked();
2639
2640    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
2641    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2642        language_model::TokenUsage {
2643            input_tokens: 40_000,
2644            output_tokens: 20_000,
2645            cache_creation_input_tokens: 0,
2646            cache_read_input_tokens: 0,
2647        },
2648    ));
2649    fake_model.end_last_completion_stream();
2650    cx.run_until_parked();
2651
2652    thread.read_with(cx, |thread, _| {
2653        assert_eq!(
2654            thread.to_markdown(),
2655            indoc! {"
2656                ## User
2657
2658                Message 1
2659
2660                ## Assistant
2661
2662                Message 1 response
2663
2664                ## User
2665
2666                Message 2
2667
2668                ## Assistant
2669
2670                Message 2 response
2671            "}
2672        );
2673
2674        assert_eq!(
2675            thread.latest_token_usage(),
2676            Some(acp_thread::TokenUsage {
2677                used_tokens: 40_000 + 20_000,
2678                max_tokens: 1_000_000,
2679                output_tokens: 20_000,
2680            })
2681        );
2682    });
2683
2684    thread
2685        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
2686        .unwrap();
2687    cx.run_until_parked();
2688
2689    assert_first_message_state(cx);
2690}
2691
2692#[gpui::test]
2693async fn test_title_generation(cx: &mut TestAppContext) {
2694    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2695    let fake_model = model.as_fake();
2696
2697    let summary_model = Arc::new(FakeLanguageModel::default());
2698    thread.update(cx, |thread, cx| {
2699        thread.set_summarization_model(Some(summary_model.clone()), cx)
2700    });
2701
2702    let send = thread
2703        .update(cx, |thread, cx| {
2704            thread.send(UserMessageId::new(), ["Hello"], cx)
2705        })
2706        .unwrap();
2707    cx.run_until_parked();
2708
2709    fake_model.send_last_completion_stream_text_chunk("Hey!");
2710    fake_model.end_last_completion_stream();
2711    cx.run_until_parked();
2712    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
2713
2714    // Ensure the summary model has been invoked to generate a title.
2715    summary_model.send_last_completion_stream_text_chunk("Hello ");
2716    summary_model.send_last_completion_stream_text_chunk("world\nG");
2717    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
2718    summary_model.end_last_completion_stream();
2719    send.collect::<Vec<_>>().await;
2720    cx.run_until_parked();
2721    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2722
2723    // Send another message, ensuring no title is generated this time.
2724    let send = thread
2725        .update(cx, |thread, cx| {
2726            thread.send(UserMessageId::new(), ["Hello again"], cx)
2727        })
2728        .unwrap();
2729    cx.run_until_parked();
2730    fake_model.send_last_completion_stream_text_chunk("Hey again!");
2731    fake_model.end_last_completion_stream();
2732    cx.run_until_parked();
2733    assert_eq!(summary_model.pending_completions(), Vec::new());
2734    send.collect::<Vec<_>>().await;
2735    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2736}
2737
2738#[gpui::test]
2739async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
2740    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2741    let fake_model = model.as_fake();
2742
2743    let _events = thread
2744        .update(cx, |thread, cx| {
2745            thread.add_tool(ToolRequiringPermission);
2746            thread.add_tool(EchoTool);
2747            thread.send(UserMessageId::new(), ["Hey!"], cx)
2748        })
2749        .unwrap();
2750    cx.run_until_parked();
2751
2752    let permission_tool_use = LanguageModelToolUse {
2753        id: "tool_id_1".into(),
2754        name: ToolRequiringPermission::name().into(),
2755        raw_input: "{}".into(),
2756        input: json!({}),
2757        is_input_complete: true,
2758        thought_signature: None,
2759    };
2760    let echo_tool_use = LanguageModelToolUse {
2761        id: "tool_id_2".into(),
2762        name: EchoTool::name().into(),
2763        raw_input: json!({"text": "test"}).to_string(),
2764        input: json!({"text": "test"}),
2765        is_input_complete: true,
2766        thought_signature: None,
2767    };
2768    fake_model.send_last_completion_stream_text_chunk("Hi!");
2769    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2770        permission_tool_use,
2771    ));
2772    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2773        echo_tool_use.clone(),
2774    ));
2775    fake_model.end_last_completion_stream();
2776    cx.run_until_parked();
2777
2778    // Ensure pending tools are skipped when building a request.
2779    let request = thread
2780        .read_with(cx, |thread, cx| {
2781            thread.build_completion_request(CompletionIntent::EditFile, cx)
2782        })
2783        .unwrap();
2784    assert_eq!(
2785        request.messages[1..],
2786        vec![
2787            LanguageModelRequestMessage {
2788                role: Role::User,
2789                content: vec!["Hey!".into()],
2790                cache: true,
2791                reasoning_details: None,
2792            },
2793            LanguageModelRequestMessage {
2794                role: Role::Assistant,
2795                content: vec![
2796                    MessageContent::Text("Hi!".into()),
2797                    MessageContent::ToolUse(echo_tool_use.clone())
2798                ],
2799                cache: false,
2800                reasoning_details: None,
2801            },
2802            LanguageModelRequestMessage {
2803                role: Role::User,
2804                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
2805                    tool_use_id: echo_tool_use.id.clone(),
2806                    tool_name: echo_tool_use.name,
2807                    is_error: false,
2808                    content: "test".into(),
2809                    output: Some("test".into())
2810                })],
2811                cache: false,
2812                reasoning_details: None,
2813            },
2814        ],
2815    );
2816}
2817
2818#[gpui::test]
2819async fn test_agent_connection(cx: &mut TestAppContext) {
2820    cx.update(settings::init);
2821    let templates = Templates::new();
2822
2823    // Initialize language model system with test provider
2824    cx.update(|cx| {
2825        gpui_tokio::init(cx);
2826
2827        let http_client = FakeHttpClient::with_404_response();
2828        let clock = Arc::new(clock::FakeSystemClock::new());
2829        let client = Client::new(clock, http_client, cx);
2830        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2831        language_model::init(client.clone(), cx);
2832        language_models::init(user_store, client.clone(), cx);
2833        LanguageModelRegistry::test(cx);
2834    });
2835    cx.executor().forbid_parking();
2836
2837    // Create a project for new_thread
2838    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
2839    fake_fs.insert_tree(path!("/test"), json!({})).await;
2840    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
2841    let cwd = Path::new("/test");
2842    let thread_store = cx.new(|cx| ThreadStore::new(cx));
2843
2844    // Create agent and connection
2845    let agent = NativeAgent::new(
2846        project.clone(),
2847        thread_store,
2848        templates.clone(),
2849        None,
2850        fake_fs.clone(),
2851        &mut cx.to_async(),
2852    )
2853    .await
2854    .unwrap();
2855    let connection = NativeAgentConnection(agent.clone());
2856
2857    // Create a thread using new_thread
2858    let connection_rc = Rc::new(connection.clone());
2859    let acp_thread = cx
2860        .update(|cx| connection_rc.new_thread(project, cwd, cx))
2861        .await
2862        .expect("new_thread should succeed");
2863
2864    // Get the session_id from the AcpThread
2865    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2866
2867    // Test model_selector returns Some
2868    let selector_opt = connection.model_selector(&session_id);
2869    assert!(
2870        selector_opt.is_some(),
2871        "agent should always support ModelSelector"
2872    );
2873    let selector = selector_opt.unwrap();
2874
2875    // Test list_models
2876    let listed_models = cx
2877        .update(|cx| selector.list_models(cx))
2878        .await
2879        .expect("list_models should succeed");
2880    let AgentModelList::Grouped(listed_models) = listed_models else {
2881        panic!("Unexpected model list type");
2882    };
2883    assert!(!listed_models.is_empty(), "should have at least one model");
2884    assert_eq!(
2885        listed_models[&AgentModelGroupName("Fake".into())][0]
2886            .id
2887            .0
2888            .as_ref(),
2889        "fake/fake"
2890    );
2891
2892    // Test selected_model returns the default
2893    let model = cx
2894        .update(|cx| selector.selected_model(cx))
2895        .await
2896        .expect("selected_model should succeed");
2897    let model = cx
2898        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
2899        .unwrap();
2900    let model = model.as_fake();
2901    assert_eq!(model.id().0, "fake", "should return default model");
2902
2903    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
2904    cx.run_until_parked();
2905    model.send_last_completion_stream_text_chunk("def");
2906    cx.run_until_parked();
2907    acp_thread.read_with(cx, |thread, cx| {
2908        assert_eq!(
2909            thread.to_markdown(cx),
2910            indoc! {"
2911                ## User
2912
2913                abc
2914
2915                ## Assistant
2916
2917                def
2918
2919            "}
2920        )
2921    });
2922
2923    // Test cancel
2924    cx.update(|cx| connection.cancel(&session_id, cx));
2925    request.await.expect("prompt should fail gracefully");
2926
2927    // Ensure that dropping the ACP thread causes the native thread to be
2928    // dropped as well.
2929    cx.update(|_| drop(acp_thread));
2930    let result = cx
2931        .update(|cx| {
2932            connection.prompt(
2933                Some(acp_thread::UserMessageId::new()),
2934                acp::PromptRequest::new(session_id.clone(), vec!["ghi".into()]),
2935                cx,
2936            )
2937        })
2938        .await;
2939    assert_eq!(
2940        result.as_ref().unwrap_err().to_string(),
2941        "Session not found",
2942        "unexpected result: {:?}",
2943        result
2944    );
2945}
2946
2947#[gpui::test]
2948async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
2949    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2950    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
2951    let fake_model = model.as_fake();
2952
2953    let mut events = thread
2954        .update(cx, |thread, cx| {
2955            thread.send(UserMessageId::new(), ["Think"], cx)
2956        })
2957        .unwrap();
2958    cx.run_until_parked();
2959
2960    // Simulate streaming partial input.
2961    let input = json!({});
2962    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2963        LanguageModelToolUse {
2964            id: "1".into(),
2965            name: ThinkingTool::name().into(),
2966            raw_input: input.to_string(),
2967            input,
2968            is_input_complete: false,
2969            thought_signature: None,
2970        },
2971    ));
2972
2973    // Input streaming completed
2974    let input = json!({ "content": "Thinking hard!" });
2975    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2976        LanguageModelToolUse {
2977            id: "1".into(),
2978            name: "thinking".into(),
2979            raw_input: input.to_string(),
2980            input,
2981            is_input_complete: true,
2982            thought_signature: None,
2983        },
2984    ));
2985    fake_model.end_last_completion_stream();
2986    cx.run_until_parked();
2987
2988    let tool_call = expect_tool_call(&mut events).await;
2989    assert_eq!(
2990        tool_call,
2991        acp::ToolCall::new("1", "Thinking")
2992            .kind(acp::ToolKind::Think)
2993            .raw_input(json!({}))
2994            .meta(acp::Meta::from_iter([(
2995                "tool_name".into(),
2996                "thinking".into()
2997            )]))
2998    );
2999    let update = expect_tool_call_update_fields(&mut events).await;
3000    assert_eq!(
3001        update,
3002        acp::ToolCallUpdate::new(
3003            "1",
3004            acp::ToolCallUpdateFields::new()
3005                .title("Thinking")
3006                .kind(acp::ToolKind::Think)
3007                .raw_input(json!({ "content": "Thinking hard!"}))
3008        )
3009    );
3010    let update = expect_tool_call_update_fields(&mut events).await;
3011    assert_eq!(
3012        update,
3013        acp::ToolCallUpdate::new(
3014            "1",
3015            acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
3016        )
3017    );
3018    let update = expect_tool_call_update_fields(&mut events).await;
3019    assert_eq!(
3020        update,
3021        acp::ToolCallUpdate::new(
3022            "1",
3023            acp::ToolCallUpdateFields::new().content(vec!["Thinking hard!".into()])
3024        )
3025    );
3026    let update = expect_tool_call_update_fields(&mut events).await;
3027    assert_eq!(
3028        update,
3029        acp::ToolCallUpdate::new(
3030            "1",
3031            acp::ToolCallUpdateFields::new()
3032                .status(acp::ToolCallStatus::Completed)
3033                .raw_output("Finished thinking.")
3034        )
3035    );
3036}
3037
3038#[gpui::test]
3039async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
3040    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3041    let fake_model = model.as_fake();
3042
3043    let mut events = thread
3044        .update(cx, |thread, cx| {
3045            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
3046            thread.send(UserMessageId::new(), ["Hello!"], cx)
3047        })
3048        .unwrap();
3049    cx.run_until_parked();
3050
3051    fake_model.send_last_completion_stream_text_chunk("Hey!");
3052    fake_model.end_last_completion_stream();
3053
3054    let mut retry_events = Vec::new();
3055    while let Some(Ok(event)) = events.next().await {
3056        match event {
3057            ThreadEvent::Retry(retry_status) => {
3058                retry_events.push(retry_status);
3059            }
3060            ThreadEvent::Stop(..) => break,
3061            _ => {}
3062        }
3063    }
3064
3065    assert_eq!(retry_events.len(), 0);
3066    thread.read_with(cx, |thread, _cx| {
3067        assert_eq!(
3068            thread.to_markdown(),
3069            indoc! {"
3070                ## User
3071
3072                Hello!
3073
3074                ## Assistant
3075
3076                Hey!
3077            "}
3078        )
3079    });
3080}
3081
3082#[gpui::test]
3083async fn test_send_retry_on_error(cx: &mut TestAppContext) {
3084    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3085    let fake_model = model.as_fake();
3086
3087    let mut events = thread
3088        .update(cx, |thread, cx| {
3089            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
3090            thread.send(UserMessageId::new(), ["Hello!"], cx)
3091        })
3092        .unwrap();
3093    cx.run_until_parked();
3094
3095    fake_model.send_last_completion_stream_text_chunk("Hey,");
3096    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
3097        provider: LanguageModelProviderName::new("Anthropic"),
3098        retry_after: Some(Duration::from_secs(3)),
3099    });
3100    fake_model.end_last_completion_stream();
3101
3102    cx.executor().advance_clock(Duration::from_secs(3));
3103    cx.run_until_parked();
3104
3105    fake_model.send_last_completion_stream_text_chunk("there!");
3106    fake_model.end_last_completion_stream();
3107    cx.run_until_parked();
3108
3109    let mut retry_events = Vec::new();
3110    while let Some(Ok(event)) = events.next().await {
3111        match event {
3112            ThreadEvent::Retry(retry_status) => {
3113                retry_events.push(retry_status);
3114            }
3115            ThreadEvent::Stop(..) => break,
3116            _ => {}
3117        }
3118    }
3119
3120    assert_eq!(retry_events.len(), 1);
3121    assert!(matches!(
3122        retry_events[0],
3123        acp_thread::RetryStatus { attempt: 1, .. }
3124    ));
3125    thread.read_with(cx, |thread, _cx| {
3126        assert_eq!(
3127            thread.to_markdown(),
3128            indoc! {"
3129                ## User
3130
3131                Hello!
3132
3133                ## Assistant
3134
3135                Hey,
3136
3137                [resume]
3138
3139                ## Assistant
3140
3141                there!
3142            "}
3143        )
3144    });
3145}
3146
3147#[gpui::test]
3148async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
3149    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3150    let fake_model = model.as_fake();
3151
3152    let events = thread
3153        .update(cx, |thread, cx| {
3154            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
3155            thread.add_tool(EchoTool);
3156            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
3157        })
3158        .unwrap();
3159    cx.run_until_parked();
3160
3161    let tool_use_1 = LanguageModelToolUse {
3162        id: "tool_1".into(),
3163        name: EchoTool::name().into(),
3164        raw_input: json!({"text": "test"}).to_string(),
3165        input: json!({"text": "test"}),
3166        is_input_complete: true,
3167        thought_signature: None,
3168    };
3169    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
3170        tool_use_1.clone(),
3171    ));
3172    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
3173        provider: LanguageModelProviderName::new("Anthropic"),
3174        retry_after: Some(Duration::from_secs(3)),
3175    });
3176    fake_model.end_last_completion_stream();
3177
3178    cx.executor().advance_clock(Duration::from_secs(3));
3179    let completion = fake_model.pending_completions().pop().unwrap();
3180    assert_eq!(
3181        completion.messages[1..],
3182        vec![
3183            LanguageModelRequestMessage {
3184                role: Role::User,
3185                content: vec!["Call the echo tool!".into()],
3186                cache: false,
3187                reasoning_details: None,
3188            },
3189            LanguageModelRequestMessage {
3190                role: Role::Assistant,
3191                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
3192                cache: false,
3193                reasoning_details: None,
3194            },
3195            LanguageModelRequestMessage {
3196                role: Role::User,
3197                content: vec![language_model::MessageContent::ToolResult(
3198                    LanguageModelToolResult {
3199                        tool_use_id: tool_use_1.id.clone(),
3200                        tool_name: tool_use_1.name.clone(),
3201                        is_error: false,
3202                        content: "test".into(),
3203                        output: Some("test".into())
3204                    }
3205                )],
3206                cache: true,
3207                reasoning_details: None,
3208            },
3209        ]
3210    );
3211
3212    fake_model.send_last_completion_stream_text_chunk("Done");
3213    fake_model.end_last_completion_stream();
3214    cx.run_until_parked();
3215    events.collect::<Vec<_>>().await;
3216    thread.read_with(cx, |thread, _cx| {
3217        assert_eq!(
3218            thread.last_message(),
3219            Some(Message::Agent(AgentMessage {
3220                content: vec![AgentMessageContent::Text("Done".into())],
3221                tool_results: IndexMap::default(),
3222                reasoning_details: None,
3223            }))
3224        );
3225    })
3226}
3227
3228#[gpui::test]
3229async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
3230    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3231    let fake_model = model.as_fake();
3232
3233    let mut events = thread
3234        .update(cx, |thread, cx| {
3235            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
3236            thread.send(UserMessageId::new(), ["Hello!"], cx)
3237        })
3238        .unwrap();
3239    cx.run_until_parked();
3240
3241    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
3242        fake_model.send_last_completion_stream_error(
3243            LanguageModelCompletionError::ServerOverloaded {
3244                provider: LanguageModelProviderName::new("Anthropic"),
3245                retry_after: Some(Duration::from_secs(3)),
3246            },
3247        );
3248        fake_model.end_last_completion_stream();
3249        cx.executor().advance_clock(Duration::from_secs(3));
3250        cx.run_until_parked();
3251    }
3252
3253    let mut errors = Vec::new();
3254    let mut retry_events = Vec::new();
3255    while let Some(event) = events.next().await {
3256        match event {
3257            Ok(ThreadEvent::Retry(retry_status)) => {
3258                retry_events.push(retry_status);
3259            }
3260            Ok(ThreadEvent::Stop(..)) => break,
3261            Err(error) => errors.push(error),
3262            _ => {}
3263        }
3264    }
3265
3266    assert_eq!(
3267        retry_events.len(),
3268        crate::thread::MAX_RETRY_ATTEMPTS as usize
3269    );
3270    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
3271        assert_eq!(retry_events[i].attempt, i + 1);
3272    }
3273    assert_eq!(errors.len(), 1);
3274    let error = errors[0]
3275        .downcast_ref::<LanguageModelCompletionError>()
3276        .unwrap();
3277    assert!(matches!(
3278        error,
3279        LanguageModelCompletionError::ServerOverloaded { .. }
3280    ));
3281}
3282
3283/// Filters out the stop events for asserting against in tests
3284fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
3285    result_events
3286        .into_iter()
3287        .filter_map(|event| match event.unwrap() {
3288            ThreadEvent::Stop(stop_reason) => Some(stop_reason),
3289            _ => None,
3290        })
3291        .collect()
3292}
3293
/// Fixture returned by `setup`, bundling everything a thread test may need.
struct ThreadTest {
    /// The language model the thread was created with.
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// The project context that was handed to the thread on creation.
    project_context: Entity<ProjectContext>,
    /// The context server store taken from the test project.
    context_server_store: Entity<ContextServerStore>,
    /// The fake file system holding the settings file and the `/test` tree.
    fs: Arc<FakeFs>,
}
3301
3302enum TestModel {
3303    Sonnet4,
3304    Fake,
3305}
3306
3307impl TestModel {
3308    fn id(&self) -> LanguageModelId {
3309        match self {
3310            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
3311            TestModel::Fake => unreachable!(),
3312        }
3313    }
3314}
3315
3316async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
3317    cx.executor().allow_parking();
3318
3319    let fs = FakeFs::new(cx.background_executor.clone());
3320    fs.create_dir(paths::settings_file().parent().unwrap())
3321        .await
3322        .unwrap();
3323    fs.insert_file(
3324        paths::settings_file(),
3325        json!({
3326            "agent": {
3327                "default_profile": "test-profile",
3328                "profiles": {
3329                    "test-profile": {
3330                        "name": "Test Profile",
3331                        "tools": {
3332                            EchoTool::name(): true,
3333                            DelayTool::name(): true,
3334                            WordListTool::name(): true,
3335                            ToolRequiringPermission::name(): true,
3336                            InfiniteTool::name(): true,
3337                            CancellationAwareTool::name(): true,
3338                            ThinkingTool::name(): true,
3339                            "terminal": true,
3340                        }
3341                    }
3342                }
3343            }
3344        })
3345        .to_string()
3346        .into_bytes(),
3347    )
3348    .await;
3349
3350    cx.update(|cx| {
3351        settings::init(cx);
3352
3353        match model {
3354            TestModel::Fake => {}
3355            TestModel::Sonnet4 => {
3356                gpui_tokio::init(cx);
3357                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
3358                cx.set_http_client(Arc::new(http_client));
3359                let client = Client::production(cx);
3360                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
3361                language_model::init(client.clone(), cx);
3362                language_models::init(user_store, client.clone(), cx);
3363            }
3364        };
3365
3366        watch_settings(fs.clone(), cx);
3367    });
3368
3369    let templates = Templates::new();
3370
3371    fs.insert_tree(path!("/test"), json!({})).await;
3372    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
3373
3374    let model = cx
3375        .update(|cx| {
3376            if let TestModel::Fake = model {
3377                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
3378            } else {
3379                let model_id = model.id();
3380                let models = LanguageModelRegistry::read_global(cx);
3381                let model = models
3382                    .available_models(cx)
3383                    .find(|model| model.id() == model_id)
3384                    .unwrap();
3385
3386                let provider = models.provider(&model.provider_id()).unwrap();
3387                let authenticated = provider.authenticate(cx);
3388
3389                cx.spawn(async move |_cx| {
3390                    authenticated.await.unwrap();
3391                    model
3392                })
3393            }
3394        })
3395        .await;
3396
3397    let project_context = cx.new(|_cx| ProjectContext::default());
3398    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3399    let context_server_registry =
3400        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3401    let thread = cx.new(|cx| {
3402        Thread::new(
3403            project,
3404            project_context.clone(),
3405            context_server_registry,
3406            templates,
3407            Some(model.clone()),
3408            cx,
3409        )
3410    });
3411    ThreadTest {
3412        model,
3413        thread,
3414        project_context,
3415        context_server_store,
3416        fs,
3417    }
3418}
3419
/// Initializes `env_logger` at load time, but only when `RUST_LOG` is set to
/// a readable value; otherwise logging stays uninitialized.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}
3427
3428fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
3429    let fs = fs.clone();
3430    cx.spawn({
3431        async move |cx| {
3432            let mut new_settings_content_rx = settings::watch_config_file(
3433                cx.background_executor(),
3434                fs,
3435                paths::settings_file().clone(),
3436            );
3437
3438            while let Some(new_settings_content) = new_settings_content_rx.next().await {
3439                cx.update(|cx| {
3440                    SettingsStore::update_global(cx, |settings, cx| {
3441                        settings.set_user_settings(&new_settings_content, cx)
3442                    })
3443                })
3444                .ok();
3445            }
3446        }
3447    })
3448    .detach();
3449}
3450
3451fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
3452    completion
3453        .tools
3454        .iter()
3455        .map(|tool| tool.name.clone())
3456        .collect()
3457}
3458
3459fn setup_context_server(
3460    name: &'static str,
3461    tools: Vec<context_server::types::Tool>,
3462    context_server_store: &Entity<ContextServerStore>,
3463    cx: &mut TestAppContext,
3464) -> mpsc::UnboundedReceiver<(
3465    context_server::types::CallToolParams,
3466    oneshot::Sender<context_server::types::CallToolResponse>,
3467)> {
3468    cx.update(|cx| {
3469        let mut settings = ProjectSettings::get_global(cx).clone();
3470        settings.context_servers.insert(
3471            name.into(),
3472            project::project_settings::ContextServerSettings::Stdio {
3473                enabled: true,
3474                command: ContextServerCommand {
3475                    path: "somebinary".into(),
3476                    args: Vec::new(),
3477                    env: None,
3478                    timeout: None,
3479                },
3480            },
3481        );
3482        ProjectSettings::override_global(settings, cx);
3483    });
3484
3485    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
3486    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
3487        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
3488            context_server::types::InitializeResponse {
3489                protocol_version: context_server::types::ProtocolVersion(
3490                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
3491                ),
3492                server_info: context_server::types::Implementation {
3493                    name: name.into(),
3494                    version: "1.0.0".to_string(),
3495                },
3496                capabilities: context_server::types::ServerCapabilities {
3497                    tools: Some(context_server::types::ToolsCapabilities {
3498                        list_changed: Some(true),
3499                    }),
3500                    ..Default::default()
3501                },
3502                meta: None,
3503            }
3504        })
3505        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
3506            let tools = tools.clone();
3507            async move {
3508                context_server::types::ListToolsResponse {
3509                    tools,
3510                    next_cursor: None,
3511                    meta: None,
3512                }
3513            }
3514        })
3515        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
3516            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
3517            async move {
3518                let (response_tx, response_rx) = oneshot::channel();
3519                mcp_tool_calls_tx
3520                    .unbounded_send((params, response_tx))
3521                    .unwrap();
3522                response_rx.await.unwrap()
3523            }
3524        });
3525    context_server_store.update(cx, |store, cx| {
3526        store.start_server(
3527            Arc::new(ContextServer::new(
3528                ContextServerId(name.into()),
3529                Arc::new(fake_transport),
3530            )),
3531            cx,
3532        );
3533    });
3534    cx.run_until_parked();
3535    mcp_tool_calls_rx
3536}
3537
3538#[gpui::test]
3539async fn test_tokens_before_message(cx: &mut TestAppContext) {
3540    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3541    let fake_model = model.as_fake();
3542
3543    // First message
3544    let message_1_id = UserMessageId::new();
3545    thread
3546        .update(cx, |thread, cx| {
3547            thread.send(message_1_id.clone(), ["First message"], cx)
3548        })
3549        .unwrap();
3550    cx.run_until_parked();
3551
3552    // Before any response, tokens_before_message should return None for first message
3553    thread.read_with(cx, |thread, _| {
3554        assert_eq!(
3555            thread.tokens_before_message(&message_1_id),
3556            None,
3557            "First message should have no tokens before it"
3558        );
3559    });
3560
3561    // Complete first message with usage
3562    fake_model.send_last_completion_stream_text_chunk("Response 1");
3563    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3564        language_model::TokenUsage {
3565            input_tokens: 100,
3566            output_tokens: 50,
3567            cache_creation_input_tokens: 0,
3568            cache_read_input_tokens: 0,
3569        },
3570    ));
3571    fake_model.end_last_completion_stream();
3572    cx.run_until_parked();
3573
3574    // First message still has no tokens before it
3575    thread.read_with(cx, |thread, _| {
3576        assert_eq!(
3577            thread.tokens_before_message(&message_1_id),
3578            None,
3579            "First message should still have no tokens before it after response"
3580        );
3581    });
3582
3583    // Second message
3584    let message_2_id = UserMessageId::new();
3585    thread
3586        .update(cx, |thread, cx| {
3587            thread.send(message_2_id.clone(), ["Second message"], cx)
3588        })
3589        .unwrap();
3590    cx.run_until_parked();
3591
3592    // Second message should have first message's input tokens before it
3593    thread.read_with(cx, |thread, _| {
3594        assert_eq!(
3595            thread.tokens_before_message(&message_2_id),
3596            Some(100),
3597            "Second message should have 100 tokens before it (from first request)"
3598        );
3599    });
3600
3601    // Complete second message
3602    fake_model.send_last_completion_stream_text_chunk("Response 2");
3603    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3604        language_model::TokenUsage {
3605            input_tokens: 250, // Total for this request (includes previous context)
3606            output_tokens: 75,
3607            cache_creation_input_tokens: 0,
3608            cache_read_input_tokens: 0,
3609        },
3610    ));
3611    fake_model.end_last_completion_stream();
3612    cx.run_until_parked();
3613
3614    // Third message
3615    let message_3_id = UserMessageId::new();
3616    thread
3617        .update(cx, |thread, cx| {
3618            thread.send(message_3_id.clone(), ["Third message"], cx)
3619        })
3620        .unwrap();
3621    cx.run_until_parked();
3622
3623    // Third message should have second message's input tokens (250) before it
3624    thread.read_with(cx, |thread, _| {
3625        assert_eq!(
3626            thread.tokens_before_message(&message_3_id),
3627            Some(250),
3628            "Third message should have 250 tokens before it (from second request)"
3629        );
3630        // Second message should still have 100
3631        assert_eq!(
3632            thread.tokens_before_message(&message_2_id),
3633            Some(100),
3634            "Second message should still have 100 tokens before it"
3635        );
3636        // First message still has none
3637        assert_eq!(
3638            thread.tokens_before_message(&message_1_id),
3639            None,
3640            "First message should still have no tokens before it"
3641        );
3642    });
3643}
3644
3645#[gpui::test]
3646async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
3647    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3648    let fake_model = model.as_fake();
3649
3650    // Set up three messages with responses
3651    let message_1_id = UserMessageId::new();
3652    thread
3653        .update(cx, |thread, cx| {
3654            thread.send(message_1_id.clone(), ["Message 1"], cx)
3655        })
3656        .unwrap();
3657    cx.run_until_parked();
3658    fake_model.send_last_completion_stream_text_chunk("Response 1");
3659    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3660        language_model::TokenUsage {
3661            input_tokens: 100,
3662            output_tokens: 50,
3663            cache_creation_input_tokens: 0,
3664            cache_read_input_tokens: 0,
3665        },
3666    ));
3667    fake_model.end_last_completion_stream();
3668    cx.run_until_parked();
3669
3670    let message_2_id = UserMessageId::new();
3671    thread
3672        .update(cx, |thread, cx| {
3673            thread.send(message_2_id.clone(), ["Message 2"], cx)
3674        })
3675        .unwrap();
3676    cx.run_until_parked();
3677    fake_model.send_last_completion_stream_text_chunk("Response 2");
3678    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3679        language_model::TokenUsage {
3680            input_tokens: 250,
3681            output_tokens: 75,
3682            cache_creation_input_tokens: 0,
3683            cache_read_input_tokens: 0,
3684        },
3685    ));
3686    fake_model.end_last_completion_stream();
3687    cx.run_until_parked();
3688
3689    // Verify initial state
3690    thread.read_with(cx, |thread, _| {
3691        assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
3692    });
3693
3694    // Truncate at message 2 (removes message 2 and everything after)
3695    thread
3696        .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
3697        .unwrap();
3698    cx.run_until_parked();
3699
3700    // After truncation, message_2_id no longer exists, so lookup should return None
3701    thread.read_with(cx, |thread, _| {
3702        assert_eq!(
3703            thread.tokens_before_message(&message_2_id),
3704            None,
3705            "After truncation, message 2 no longer exists"
3706        );
3707        // Message 1 still exists but has no tokens before it
3708        assert_eq!(
3709            thread.tokens_before_message(&message_1_id),
3710            None,
3711            "First message still has no tokens before it"
3712        );
3713    });
3714}
3715
3716#[gpui::test]
3717async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
3718    init_test(cx);
3719
3720    let fs = FakeFs::new(cx.executor());
3721    fs.insert_tree("/root", json!({})).await;
3722    let project = Project::test(fs, ["/root".as_ref()], cx).await;
3723
3724    // Test 1: Deny rule blocks command
3725    {
3726        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
3727        let environment = Rc::new(FakeThreadEnvironment {
3728            handle: handle.clone(),
3729        });
3730
3731        cx.update(|cx| {
3732            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3733            settings.tool_permissions.tools.insert(
3734                "terminal".into(),
3735                agent_settings::ToolRules {
3736                    default_mode: settings::ToolPermissionMode::Confirm,
3737                    always_allow: vec![],
3738                    always_deny: vec![
3739                        agent_settings::CompiledRegex::new(r"rm\s+-rf", false).unwrap(),
3740                    ],
3741                    always_confirm: vec![],
3742                    invalid_patterns: vec![],
3743                },
3744            );
3745            agent_settings::AgentSettings::override_global(settings, cx);
3746        });
3747
3748        #[allow(clippy::arc_with_non_send_sync)]
3749        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3750        let (event_stream, _rx) = crate::ToolCallEventStream::test();
3751
3752        let task = cx.update(|cx| {
3753            tool.run(
3754                crate::TerminalToolInput {
3755                    command: "rm -rf /".to_string(),
3756                    cd: ".".to_string(),
3757                    timeout_ms: None,
3758                },
3759                event_stream,
3760                cx,
3761            )
3762        });
3763
3764        let result = task.await;
3765        assert!(
3766            result.is_err(),
3767            "expected command to be blocked by deny rule"
3768        );
3769        assert!(
3770            result.unwrap_err().to_string().contains("blocked"),
3771            "error should mention the command was blocked"
3772        );
3773    }
3774
3775    // Test 2: Allow rule skips confirmation (and overrides default_mode: Deny)
3776    {
3777        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_with_immediate_exit(cx, 0)));
3778        let environment = Rc::new(FakeThreadEnvironment {
3779            handle: handle.clone(),
3780        });
3781
3782        cx.update(|cx| {
3783            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3784            settings.always_allow_tool_actions = false;
3785            settings.tool_permissions.tools.insert(
3786                "terminal".into(),
3787                agent_settings::ToolRules {
3788                    default_mode: settings::ToolPermissionMode::Deny,
3789                    always_allow: vec![
3790                        agent_settings::CompiledRegex::new(r"^echo\s", false).unwrap(),
3791                    ],
3792                    always_deny: vec![],
3793                    always_confirm: vec![],
3794                    invalid_patterns: vec![],
3795                },
3796            );
3797            agent_settings::AgentSettings::override_global(settings, cx);
3798        });
3799
3800        #[allow(clippy::arc_with_non_send_sync)]
3801        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3802        let (event_stream, mut rx) = crate::ToolCallEventStream::test();
3803
3804        let task = cx.update(|cx| {
3805            tool.run(
3806                crate::TerminalToolInput {
3807                    command: "echo hello".to_string(),
3808                    cd: ".".to_string(),
3809                    timeout_ms: None,
3810                },
3811                event_stream,
3812                cx,
3813            )
3814        });
3815
3816        let update = rx.expect_update_fields().await;
3817        assert!(
3818            update.content.iter().any(|blocks| {
3819                blocks
3820                    .iter()
3821                    .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
3822            }),
3823            "expected terminal content (allow rule should skip confirmation and override default deny)"
3824        );
3825
3826        let result = task.await;
3827        assert!(
3828            result.is_ok(),
3829            "expected command to succeed without confirmation"
3830        );
3831    }
3832
3833    // Test 3: Confirm rule forces confirmation even with always_allow_tool_actions=true
3834    {
3835        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_with_immediate_exit(cx, 0)));
3836        let environment = Rc::new(FakeThreadEnvironment {
3837            handle: handle.clone(),
3838        });
3839
3840        cx.update(|cx| {
3841            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3842            settings.always_allow_tool_actions = true;
3843            settings.tool_permissions.tools.insert(
3844                "terminal".into(),
3845                agent_settings::ToolRules {
3846                    default_mode: settings::ToolPermissionMode::Allow,
3847                    always_allow: vec![],
3848                    always_deny: vec![],
3849                    always_confirm: vec![
3850                        agent_settings::CompiledRegex::new(r"sudo", false).unwrap(),
3851                    ],
3852                    invalid_patterns: vec![],
3853                },
3854            );
3855            agent_settings::AgentSettings::override_global(settings, cx);
3856        });
3857
3858        #[allow(clippy::arc_with_non_send_sync)]
3859        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3860        let (event_stream, mut rx) = crate::ToolCallEventStream::test();
3861
3862        let _task = cx.update(|cx| {
3863            tool.run(
3864                crate::TerminalToolInput {
3865                    command: "sudo rm file".to_string(),
3866                    cd: ".".to_string(),
3867                    timeout_ms: None,
3868                },
3869                event_stream,
3870                cx,
3871            )
3872        });
3873
3874        let auth = rx.expect_authorization().await;
3875        assert!(
3876            auth.tool_call.fields.title.is_some(),
3877            "expected authorization request for sudo command despite always_allow_tool_actions=true"
3878        );
3879    }
3880
3881    // Test 4: default_mode: Deny blocks commands when no pattern matches
3882    {
3883        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
3884        let environment = Rc::new(FakeThreadEnvironment {
3885            handle: handle.clone(),
3886        });
3887
3888        cx.update(|cx| {
3889            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3890            settings.always_allow_tool_actions = true;
3891            settings.tool_permissions.tools.insert(
3892                "terminal".into(),
3893                agent_settings::ToolRules {
3894                    default_mode: settings::ToolPermissionMode::Deny,
3895                    always_allow: vec![],
3896                    always_deny: vec![],
3897                    always_confirm: vec![],
3898                    invalid_patterns: vec![],
3899                },
3900            );
3901            agent_settings::AgentSettings::override_global(settings, cx);
3902        });
3903
3904        #[allow(clippy::arc_with_non_send_sync)]
3905        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3906        let (event_stream, _rx) = crate::ToolCallEventStream::test();
3907
3908        let task = cx.update(|cx| {
3909            tool.run(
3910                crate::TerminalToolInput {
3911                    command: "echo hello".to_string(),
3912                    cd: ".".to_string(),
3913                    timeout_ms: None,
3914                },
3915                event_stream,
3916                cx,
3917            )
3918        });
3919
3920        let result = task.await;
3921        assert!(
3922            result.is_err(),
3923            "expected command to be blocked by default_mode: Deny"
3924        );
3925        assert!(
3926            result.unwrap_err().to_string().contains("disabled"),
3927            "error should mention the tool is disabled"
3928        );
3929    }
3930}
3931
3932#[gpui::test]
3933async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) {
3934    init_test(cx);
3935
3936    cx.update(|cx| {
3937        cx.update_flags(true, vec!["subagents".to_string()]);
3938    });
3939
3940    let fs = FakeFs::new(cx.executor());
3941    fs.insert_tree(path!("/test"), json!({})).await;
3942    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3943    let project_context = cx.new(|_cx| ProjectContext::default());
3944    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3945    let context_server_registry =
3946        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3947    let model = Arc::new(FakeLanguageModel::default());
3948
3949    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
3950    let environment = Rc::new(FakeThreadEnvironment { handle });
3951
3952    let thread = cx.new(|cx| {
3953        let mut thread = Thread::new(
3954            project.clone(),
3955            project_context,
3956            context_server_registry,
3957            Templates::new(),
3958            Some(model),
3959            cx,
3960        );
3961        thread.add_default_tools(environment, cx);
3962        thread
3963    });
3964
3965    thread.read_with(cx, |thread, _| {
3966        assert!(
3967            thread.has_registered_tool("subagent"),
3968            "subagent tool should be present when feature flag is enabled"
3969        );
3970    });
3971}
3972
3973#[gpui::test]
3974async fn test_subagent_thread_inherits_parent_model(cx: &mut TestAppContext) {
3975    init_test(cx);
3976
3977    cx.update(|cx| {
3978        cx.update_flags(true, vec!["subagents".to_string()]);
3979    });
3980
3981    let fs = FakeFs::new(cx.executor());
3982    fs.insert_tree(path!("/test"), json!({})).await;
3983    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3984    let project_context = cx.new(|_cx| ProjectContext::default());
3985    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3986    let context_server_registry =
3987        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3988    let model = Arc::new(FakeLanguageModel::default());
3989
3990    let subagent_context = SubagentContext {
3991        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
3992        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
3993        depth: 1,
3994        summary_prompt: "Summarize".to_string(),
3995        context_low_prompt: "Context low".to_string(),
3996    };
3997
3998    let subagent = cx.new(|cx| {
3999        Thread::new_subagent(
4000            project.clone(),
4001            project_context,
4002            context_server_registry,
4003            Templates::new(),
4004            model.clone(),
4005            subagent_context,
4006            std::collections::BTreeMap::new(),
4007            cx,
4008        )
4009    });
4010
4011    subagent.read_with(cx, |thread, _| {
4012        assert!(thread.is_subagent());
4013        assert_eq!(thread.depth(), 1);
4014        assert!(thread.model().is_some());
4015    });
4016}
4017
4018#[gpui::test]
4019async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppContext) {
4020    init_test(cx);
4021
4022    cx.update(|cx| {
4023        cx.update_flags(true, vec!["subagents".to_string()]);
4024    });
4025
4026    let fs = FakeFs::new(cx.executor());
4027    fs.insert_tree(path!("/test"), json!({})).await;
4028    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4029    let project_context = cx.new(|_cx| ProjectContext::default());
4030    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4031    let context_server_registry =
4032        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4033    let model = Arc::new(FakeLanguageModel::default());
4034
4035    let subagent_context = SubagentContext {
4036        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4037        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4038        depth: MAX_SUBAGENT_DEPTH,
4039        summary_prompt: "Summarize".to_string(),
4040        context_low_prompt: "Context low".to_string(),
4041    };
4042
4043    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
4044    let environment = Rc::new(FakeThreadEnvironment { handle });
4045
4046    let deep_subagent = cx.new(|cx| {
4047        let mut thread = Thread::new_subagent(
4048            project.clone(),
4049            project_context,
4050            context_server_registry,
4051            Templates::new(),
4052            model.clone(),
4053            subagent_context,
4054            std::collections::BTreeMap::new(),
4055            cx,
4056        );
4057        thread.add_default_tools(environment, cx);
4058        thread
4059    });
4060
4061    deep_subagent.read_with(cx, |thread, _| {
4062        assert_eq!(thread.depth(), MAX_SUBAGENT_DEPTH);
4063        assert!(
4064            !thread.has_registered_tool("subagent"),
4065            "subagent tool should not be present at max depth"
4066        );
4067    });
4068}
4069
4070#[gpui::test]
4071async fn test_subagent_receives_task_prompt(cx: &mut TestAppContext) {
4072    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4073    let fake_model = model.as_fake();
4074
4075    cx.update(|cx| {
4076        cx.update_flags(true, vec!["subagents".to_string()]);
4077    });
4078
4079    let subagent_context = SubagentContext {
4080        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4081        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4082        depth: 1,
4083        summary_prompt: "Summarize your work".to_string(),
4084        context_low_prompt: "Context low, wrap up".to_string(),
4085    };
4086
4087    let project = thread.read_with(cx, |t, _| t.project.clone());
4088    let project_context = cx.new(|_cx| ProjectContext::default());
4089    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4090    let context_server_registry =
4091        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4092
4093    let subagent = cx.new(|cx| {
4094        Thread::new_subagent(
4095            project.clone(),
4096            project_context,
4097            context_server_registry,
4098            Templates::new(),
4099            model.clone(),
4100            subagent_context,
4101            std::collections::BTreeMap::new(),
4102            cx,
4103        )
4104    });
4105
4106    let task_prompt = "Find all TODO comments in the codebase";
4107    subagent
4108        .update(cx, |thread, cx| thread.submit_user_message(task_prompt, cx))
4109        .unwrap();
4110    cx.run_until_parked();
4111
4112    let pending = fake_model.pending_completions();
4113    assert_eq!(pending.len(), 1, "should have one pending completion");
4114
4115    let messages = &pending[0].messages;
4116    let user_messages: Vec<_> = messages
4117        .iter()
4118        .filter(|m| m.role == language_model::Role::User)
4119        .collect();
4120    assert_eq!(user_messages.len(), 1, "should have one user message");
4121
4122    let content = &user_messages[0].content[0];
4123    assert!(
4124        content.to_str().unwrap().contains("TODO"),
4125        "task prompt should be in user message"
4126    );
4127}
4128
4129#[gpui::test]
4130async fn test_subagent_returns_summary_on_completion(cx: &mut TestAppContext) {
4131    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4132    let fake_model = model.as_fake();
4133
4134    cx.update(|cx| {
4135        cx.update_flags(true, vec!["subagents".to_string()]);
4136    });
4137
4138    let subagent_context = SubagentContext {
4139        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4140        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4141        depth: 1,
4142        summary_prompt: "Please summarize what you found".to_string(),
4143        context_low_prompt: "Context low, wrap up".to_string(),
4144    };
4145
4146    let project = thread.read_with(cx, |t, _| t.project.clone());
4147    let project_context = cx.new(|_cx| ProjectContext::default());
4148    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4149    let context_server_registry =
4150        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4151
4152    let subagent = cx.new(|cx| {
4153        Thread::new_subagent(
4154            project.clone(),
4155            project_context,
4156            context_server_registry,
4157            Templates::new(),
4158            model.clone(),
4159            subagent_context,
4160            std::collections::BTreeMap::new(),
4161            cx,
4162        )
4163    });
4164
4165    subagent
4166        .update(cx, |thread, cx| {
4167            thread.submit_user_message("Do some work", cx)
4168        })
4169        .unwrap();
4170    cx.run_until_parked();
4171
4172    fake_model.send_last_completion_stream_text_chunk("I did the work");
4173    fake_model
4174        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4175    fake_model.end_last_completion_stream();
4176    cx.run_until_parked();
4177
4178    subagent
4179        .update(cx, |thread, cx| thread.request_final_summary(cx))
4180        .unwrap();
4181    cx.run_until_parked();
4182
4183    let pending = fake_model.pending_completions();
4184    assert!(
4185        !pending.is_empty(),
4186        "should have pending completion for summary"
4187    );
4188
4189    let messages = &pending.last().unwrap().messages;
4190    let user_messages: Vec<_> = messages
4191        .iter()
4192        .filter(|m| m.role == language_model::Role::User)
4193        .collect();
4194
4195    let last_user = user_messages.last().unwrap();
4196    assert!(
4197        last_user.content[0].to_str().unwrap().contains("summarize"),
4198        "summary prompt should be sent"
4199    );
4200}
4201
4202#[gpui::test]
4203async fn test_allowed_tools_restricts_subagent_capabilities(cx: &mut TestAppContext) {
4204    init_test(cx);
4205
4206    cx.update(|cx| {
4207        cx.update_flags(true, vec!["subagents".to_string()]);
4208    });
4209
4210    let fs = FakeFs::new(cx.executor());
4211    fs.insert_tree(path!("/test"), json!({})).await;
4212    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4213    let project_context = cx.new(|_cx| ProjectContext::default());
4214    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4215    let context_server_registry =
4216        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4217    let model = Arc::new(FakeLanguageModel::default());
4218
4219    let subagent_context = SubagentContext {
4220        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4221        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4222        depth: 1,
4223        summary_prompt: "Summarize".to_string(),
4224        context_low_prompt: "Context low".to_string(),
4225    };
4226
4227    let subagent = cx.new(|cx| {
4228        let mut thread = Thread::new_subagent(
4229            project.clone(),
4230            project_context,
4231            context_server_registry,
4232            Templates::new(),
4233            model.clone(),
4234            subagent_context,
4235            std::collections::BTreeMap::new(),
4236            cx,
4237        );
4238        thread.add_tool(EchoTool);
4239        thread.add_tool(DelayTool);
4240        thread.add_tool(WordListTool);
4241        thread
4242    });
4243
4244    subagent.read_with(cx, |thread, _| {
4245        assert!(thread.has_registered_tool("echo"));
4246        assert!(thread.has_registered_tool("delay"));
4247        assert!(thread.has_registered_tool("word_list"));
4248    });
4249
4250    let allowed: collections::HashSet<gpui::SharedString> =
4251        vec!["echo".into()].into_iter().collect();
4252
4253    subagent.update(cx, |thread, _cx| {
4254        thread.restrict_tools(&allowed);
4255    });
4256
4257    subagent.read_with(cx, |thread, _| {
4258        assert!(
4259            thread.has_registered_tool("echo"),
4260            "echo should still be available"
4261        );
4262        assert!(
4263            !thread.has_registered_tool("delay"),
4264            "delay should be removed"
4265        );
4266        assert!(
4267            !thread.has_registered_tool("word_list"),
4268            "word_list should be removed"
4269        );
4270    });
4271}
4272
4273#[gpui::test]
4274async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) {
4275    init_test(cx);
4276
4277    cx.update(|cx| {
4278        cx.update_flags(true, vec!["subagents".to_string()]);
4279    });
4280
4281    let fs = FakeFs::new(cx.executor());
4282    fs.insert_tree(path!("/test"), json!({})).await;
4283    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4284    let project_context = cx.new(|_cx| ProjectContext::default());
4285    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4286    let context_server_registry =
4287        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4288    let model = Arc::new(FakeLanguageModel::default());
4289
4290    let parent = cx.new(|cx| {
4291        Thread::new(
4292            project.clone(),
4293            project_context.clone(),
4294            context_server_registry.clone(),
4295            Templates::new(),
4296            Some(model.clone()),
4297            cx,
4298        )
4299    });
4300
4301    let subagent_context = SubagentContext {
4302        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4303        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4304        depth: 1,
4305        summary_prompt: "Summarize".to_string(),
4306        context_low_prompt: "Context low".to_string(),
4307    };
4308
4309    let subagent = cx.new(|cx| {
4310        Thread::new_subagent(
4311            project.clone(),
4312            project_context.clone(),
4313            context_server_registry.clone(),
4314            Templates::new(),
4315            model.clone(),
4316            subagent_context,
4317            std::collections::BTreeMap::new(),
4318            cx,
4319        )
4320    });
4321
4322    parent.update(cx, |thread, _cx| {
4323        thread.register_running_subagent(subagent.downgrade());
4324    });
4325
4326    subagent
4327        .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4328        .unwrap();
4329    cx.run_until_parked();
4330
4331    subagent.read_with(cx, |thread, _| {
4332        assert!(!thread.is_turn_complete(), "subagent should be running");
4333    });
4334
4335    parent.update(cx, |thread, cx| {
4336        thread.cancel(cx).detach();
4337    });
4338
4339    subagent.read_with(cx, |thread, _| {
4340        assert!(
4341            thread.is_turn_complete(),
4342            "subagent should be cancelled when parent cancels"
4343        );
4344    });
4345}
4346
4347#[gpui::test]
4348async fn test_subagent_tool_cancellation(cx: &mut TestAppContext) {
4349    // This test verifies that the subagent tool properly handles user cancellation
4350    // via `event_stream.cancelled_by_user()` and stops all running subagents.
4351    init_test(cx);
4352    always_allow_tools(cx);
4353
4354    cx.update(|cx| {
4355        cx.update_flags(true, vec!["subagents".to_string()]);
4356    });
4357
4358    let fs = FakeFs::new(cx.executor());
4359    fs.insert_tree(path!("/test"), json!({})).await;
4360    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4361    let project_context = cx.new(|_cx| ProjectContext::default());
4362    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4363    let context_server_registry =
4364        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4365    let model = Arc::new(FakeLanguageModel::default());
4366
4367    let parent = cx.new(|cx| {
4368        Thread::new(
4369            project.clone(),
4370            project_context.clone(),
4371            context_server_registry.clone(),
4372            Templates::new(),
4373            Some(model.clone()),
4374            cx,
4375        )
4376    });
4377
4378    let parent_tools: std::collections::BTreeMap<gpui::SharedString, Arc<dyn crate::AnyAgentTool>> =
4379        std::collections::BTreeMap::new();
4380
4381    #[allow(clippy::arc_with_non_send_sync)]
4382    let tool = Arc::new(SubagentTool::new(
4383        parent.downgrade(),
4384        project.clone(),
4385        project_context,
4386        context_server_registry,
4387        Templates::new(),
4388        0,
4389        parent_tools,
4390    ));
4391
4392    let (event_stream, _rx, mut cancellation_tx) =
4393        crate::ToolCallEventStream::test_with_cancellation();
4394
4395    // Start the subagent tool
4396    let task = cx.update(|cx| {
4397        tool.run(
4398            SubagentToolInput {
4399                subagents: vec![crate::SubagentConfig {
4400                    label: "Long running task".to_string(),
4401                    task_prompt: "Do a very long task that takes forever".to_string(),
4402                    summary_prompt: "Summarize".to_string(),
4403                    context_low_prompt: "Context low".to_string(),
4404                    timeout_ms: None,
4405                    allowed_tools: None,
4406                }],
4407            },
4408            event_stream.clone(),
4409            cx,
4410        )
4411    });
4412
4413    cx.run_until_parked();
4414
4415    // Signal cancellation via the event stream
4416    crate::ToolCallEventStream::signal_cancellation_with_sender(&mut cancellation_tx);
4417
4418    // The task should complete promptly with a cancellation error
4419    let timeout = cx.background_executor.timer(Duration::from_secs(5));
4420    let result = futures::select! {
4421        result = task.fuse() => result,
4422        _ = timeout.fuse() => {
4423            panic!("subagent tool did not respond to cancellation within timeout");
4424        }
4425    };
4426
4427    // Verify we got a cancellation error
4428    let err = result.unwrap_err();
4429    assert!(
4430        err.to_string().contains("cancelled by user"),
4431        "expected cancellation error, got: {}",
4432        err
4433    );
4434}
4435
4436#[gpui::test]
4437async fn test_subagent_model_error_returned_as_tool_error(cx: &mut TestAppContext) {
4438    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4439    let fake_model = model.as_fake();
4440
4441    cx.update(|cx| {
4442        cx.update_flags(true, vec!["subagents".to_string()]);
4443    });
4444
4445    let subagent_context = SubagentContext {
4446        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4447        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4448        depth: 1,
4449        summary_prompt: "Summarize".to_string(),
4450        context_low_prompt: "Context low".to_string(),
4451    };
4452
4453    let project = thread.read_with(cx, |t, _| t.project.clone());
4454    let project_context = cx.new(|_cx| ProjectContext::default());
4455    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4456    let context_server_registry =
4457        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4458
4459    let subagent = cx.new(|cx| {
4460        Thread::new_subagent(
4461            project.clone(),
4462            project_context,
4463            context_server_registry,
4464            Templates::new(),
4465            model.clone(),
4466            subagent_context,
4467            std::collections::BTreeMap::new(),
4468            cx,
4469        )
4470    });
4471
4472    subagent
4473        .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4474        .unwrap();
4475    cx.run_until_parked();
4476
4477    subagent.read_with(cx, |thread, _| {
4478        assert!(!thread.is_turn_complete(), "turn should be in progress");
4479    });
4480
4481    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::NoApiKey {
4482        provider: LanguageModelProviderName::from("Fake".to_string()),
4483    });
4484    fake_model.end_last_completion_stream();
4485    cx.run_until_parked();
4486
4487    subagent.read_with(cx, |thread, _| {
4488        assert!(
4489            thread.is_turn_complete(),
4490            "turn should be complete after non-retryable error"
4491        );
4492    });
4493}
4494
4495#[gpui::test]
4496async fn test_subagent_timeout_triggers_early_summary(cx: &mut TestAppContext) {
4497    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4498    let fake_model = model.as_fake();
4499
4500    cx.update(|cx| {
4501        cx.update_flags(true, vec!["subagents".to_string()]);
4502    });
4503
4504    let subagent_context = SubagentContext {
4505        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4506        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4507        depth: 1,
4508        summary_prompt: "Summarize your work".to_string(),
4509        context_low_prompt: "Context low, stop and summarize".to_string(),
4510    };
4511
4512    let project = thread.read_with(cx, |t, _| t.project.clone());
4513    let project_context = cx.new(|_cx| ProjectContext::default());
4514    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4515    let context_server_registry =
4516        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4517
4518    let subagent = cx.new(|cx| {
4519        Thread::new_subagent(
4520            project.clone(),
4521            project_context.clone(),
4522            context_server_registry.clone(),
4523            Templates::new(),
4524            model.clone(),
4525            subagent_context.clone(),
4526            std::collections::BTreeMap::new(),
4527            cx,
4528        )
4529    });
4530
4531    subagent.update(cx, |thread, _| {
4532        thread.add_tool(EchoTool);
4533    });
4534
4535    subagent
4536        .update(cx, |thread, cx| {
4537            thread.submit_user_message("Do some work", cx)
4538        })
4539        .unwrap();
4540    cx.run_until_parked();
4541
4542    fake_model.send_last_completion_stream_text_chunk("Working on it...");
4543    fake_model
4544        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4545    fake_model.end_last_completion_stream();
4546    cx.run_until_parked();
4547
4548    let interrupt_result = subagent.update(cx, |thread, cx| thread.interrupt_for_summary(cx));
4549    assert!(
4550        interrupt_result.is_ok(),
4551        "interrupt_for_summary should succeed"
4552    );
4553
4554    cx.run_until_parked();
4555
4556    let pending = fake_model.pending_completions();
4557    assert!(
4558        !pending.is_empty(),
4559        "should have pending completion for interrupted summary"
4560    );
4561
4562    let messages = &pending.last().unwrap().messages;
4563    let user_messages: Vec<_> = messages
4564        .iter()
4565        .filter(|m| m.role == language_model::Role::User)
4566        .collect();
4567
4568    let last_user = user_messages.last().unwrap();
4569    let content_str = last_user.content[0].to_str().unwrap();
4570    assert!(
4571        content_str.contains("Context low") || content_str.contains("stop and summarize"),
4572        "context_low_prompt should be sent when interrupting: got {:?}",
4573        content_str
4574    );
4575}
4576
4577#[gpui::test]
4578async fn test_context_low_check_returns_true_when_usage_high(cx: &mut TestAppContext) {
4579    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4580    let fake_model = model.as_fake();
4581
4582    cx.update(|cx| {
4583        cx.update_flags(true, vec!["subagents".to_string()]);
4584    });
4585
4586    let subagent_context = SubagentContext {
4587        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4588        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4589        depth: 1,
4590        summary_prompt: "Summarize".to_string(),
4591        context_low_prompt: "Context low".to_string(),
4592    };
4593
4594    let project = thread.read_with(cx, |t, _| t.project.clone());
4595    let project_context = cx.new(|_cx| ProjectContext::default());
4596    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4597    let context_server_registry =
4598        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4599
4600    let subagent = cx.new(|cx| {
4601        Thread::new_subagent(
4602            project.clone(),
4603            project_context,
4604            context_server_registry,
4605            Templates::new(),
4606            model.clone(),
4607            subagent_context,
4608            std::collections::BTreeMap::new(),
4609            cx,
4610        )
4611    });
4612
4613    subagent
4614        .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4615        .unwrap();
4616    cx.run_until_parked();
4617
4618    let max_tokens = model.max_token_count();
4619    let high_usage = language_model::TokenUsage {
4620        input_tokens: (max_tokens as f64 * 0.80) as u64,
4621        output_tokens: 0,
4622        cache_creation_input_tokens: 0,
4623        cache_read_input_tokens: 0,
4624    };
4625
4626    fake_model
4627        .send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(high_usage));
4628    fake_model.send_last_completion_stream_text_chunk("Working...");
4629    fake_model
4630        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4631    fake_model.end_last_completion_stream();
4632    cx.run_until_parked();
4633
4634    let usage = subagent.read_with(cx, |thread, _| thread.latest_token_usage());
4635    assert!(usage.is_some(), "should have token usage after completion");
4636
4637    let usage = usage.unwrap();
4638    let remaining_ratio = 1.0 - (usage.used_tokens as f32 / usage.max_tokens as f32);
4639    assert!(
4640        remaining_ratio <= 0.25,
4641        "remaining ratio should be at or below 25% (got {}%), indicating context is low",
4642        remaining_ratio * 100.0
4643    );
4644}
4645
4646#[gpui::test]
4647async fn test_allowed_tools_rejects_unknown_tool(cx: &mut TestAppContext) {
4648    init_test(cx);
4649
4650    cx.update(|cx| {
4651        cx.update_flags(true, vec!["subagents".to_string()]);
4652    });
4653
4654    let fs = FakeFs::new(cx.executor());
4655    fs.insert_tree(path!("/test"), json!({})).await;
4656    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4657    let project_context = cx.new(|_cx| ProjectContext::default());
4658    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4659    let context_server_registry =
4660        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4661    let model = Arc::new(FakeLanguageModel::default());
4662
4663    let parent = cx.new(|cx| {
4664        let mut thread = Thread::new(
4665            project.clone(),
4666            project_context.clone(),
4667            context_server_registry.clone(),
4668            Templates::new(),
4669            Some(model.clone()),
4670            cx,
4671        );
4672        thread.add_tool(EchoTool);
4673        thread
4674    });
4675
4676    let mut parent_tools: std::collections::BTreeMap<
4677        gpui::SharedString,
4678        Arc<dyn crate::AnyAgentTool>,
4679    > = std::collections::BTreeMap::new();
4680    parent_tools.insert("echo".into(), EchoTool.erase());
4681
4682    #[allow(clippy::arc_with_non_send_sync)]
4683    let tool = Arc::new(SubagentTool::new(
4684        parent.downgrade(),
4685        project,
4686        project_context,
4687        context_server_registry,
4688        Templates::new(),
4689        0,
4690        parent_tools,
4691    ));
4692
4693    let subagent_configs = vec![crate::SubagentConfig {
4694        label: "Test".to_string(),
4695        task_prompt: "Do something".to_string(),
4696        summary_prompt: "Summarize".to_string(),
4697        context_low_prompt: "Context low".to_string(),
4698        timeout_ms: None,
4699        allowed_tools: Some(vec!["nonexistent_tool".to_string()]),
4700    }];
4701    let result = tool.validate_subagents(&subagent_configs);
4702    assert!(result.is_err(), "should reject unknown tool");
4703    let err_msg = result.unwrap_err().to_string();
4704    assert!(
4705        err_msg.contains("nonexistent_tool"),
4706        "error should mention the invalid tool name: {}",
4707        err_msg
4708    );
4709    assert!(
4710        err_msg.contains("do not exist"),
4711        "error should explain the tool does not exist: {}",
4712        err_msg
4713    );
4714}
4715
4716#[gpui::test]
4717async fn test_subagent_empty_response_handled(cx: &mut TestAppContext) {
4718    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4719    let fake_model = model.as_fake();
4720
4721    cx.update(|cx| {
4722        cx.update_flags(true, vec!["subagents".to_string()]);
4723    });
4724
4725    let subagent_context = SubagentContext {
4726        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4727        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4728        depth: 1,
4729        summary_prompt: "Summarize".to_string(),
4730        context_low_prompt: "Context low".to_string(),
4731    };
4732
4733    let project = thread.read_with(cx, |t, _| t.project.clone());
4734    let project_context = cx.new(|_cx| ProjectContext::default());
4735    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4736    let context_server_registry =
4737        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4738
4739    let subagent = cx.new(|cx| {
4740        Thread::new_subagent(
4741            project.clone(),
4742            project_context,
4743            context_server_registry,
4744            Templates::new(),
4745            model.clone(),
4746            subagent_context,
4747            std::collections::BTreeMap::new(),
4748            cx,
4749        )
4750    });
4751
4752    subagent
4753        .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4754        .unwrap();
4755    cx.run_until_parked();
4756
4757    fake_model
4758        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4759    fake_model.end_last_completion_stream();
4760    cx.run_until_parked();
4761
4762    subagent.read_with(cx, |thread, _| {
4763        assert!(
4764            thread.is_turn_complete(),
4765            "turn should complete even with empty response"
4766        );
4767    });
4768}
4769
4770#[gpui::test]
4771async fn test_nested_subagent_at_depth_2_succeeds(cx: &mut TestAppContext) {
4772    init_test(cx);
4773
4774    cx.update(|cx| {
4775        cx.update_flags(true, vec!["subagents".to_string()]);
4776    });
4777
4778    let fs = FakeFs::new(cx.executor());
4779    fs.insert_tree(path!("/test"), json!({})).await;
4780    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4781    let project_context = cx.new(|_cx| ProjectContext::default());
4782    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4783    let context_server_registry =
4784        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4785    let model = Arc::new(FakeLanguageModel::default());
4786
4787    let depth_1_context = SubagentContext {
4788        parent_thread_id: agent_client_protocol::SessionId::new("root-id"),
4789        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-1"),
4790        depth: 1,
4791        summary_prompt: "Summarize".to_string(),
4792        context_low_prompt: "Context low".to_string(),
4793    };
4794
4795    let depth_1_subagent = cx.new(|cx| {
4796        Thread::new_subagent(
4797            project.clone(),
4798            project_context.clone(),
4799            context_server_registry.clone(),
4800            Templates::new(),
4801            model.clone(),
4802            depth_1_context,
4803            std::collections::BTreeMap::new(),
4804            cx,
4805        )
4806    });
4807
4808    depth_1_subagent.read_with(cx, |thread, _| {
4809        assert_eq!(thread.depth(), 1);
4810        assert!(thread.is_subagent());
4811    });
4812
4813    let depth_2_context = SubagentContext {
4814        parent_thread_id: agent_client_protocol::SessionId::new("depth-1-id"),
4815        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-2"),
4816        depth: 2,
4817        summary_prompt: "Summarize depth 2".to_string(),
4818        context_low_prompt: "Context low depth 2".to_string(),
4819    };
4820
4821    let depth_2_subagent = cx.new(|cx| {
4822        Thread::new_subagent(
4823            project.clone(),
4824            project_context.clone(),
4825            context_server_registry.clone(),
4826            Templates::new(),
4827            model.clone(),
4828            depth_2_context,
4829            std::collections::BTreeMap::new(),
4830            cx,
4831        )
4832    });
4833
4834    depth_2_subagent.read_with(cx, |thread, _| {
4835        assert_eq!(thread.depth(), 2);
4836        assert!(thread.is_subagent());
4837    });
4838
4839    depth_2_subagent
4840        .update(cx, |thread, cx| {
4841            thread.submit_user_message("Nested task", cx)
4842        })
4843        .unwrap();
4844    cx.run_until_parked();
4845
4846    let pending = model.as_fake().pending_completions();
4847    assert!(
4848        !pending.is_empty(),
4849        "depth-2 subagent should be able to submit messages"
4850    );
4851}
4852
4853#[gpui::test]
4854async fn test_subagent_uses_tool_and_returns_result(cx: &mut TestAppContext) {
4855    init_test(cx);
4856    always_allow_tools(cx);
4857
4858    cx.update(|cx| {
4859        cx.update_flags(true, vec!["subagents".to_string()]);
4860    });
4861
4862    let fs = FakeFs::new(cx.executor());
4863    fs.insert_tree(path!("/test"), json!({})).await;
4864    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4865    let project_context = cx.new(|_cx| ProjectContext::default());
4866    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4867    let context_server_registry =
4868        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4869    let model = Arc::new(FakeLanguageModel::default());
4870    let fake_model = model.as_fake();
4871
4872    let subagent_context = SubagentContext {
4873        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4874        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4875        depth: 1,
4876        summary_prompt: "Summarize what you did".to_string(),
4877        context_low_prompt: "Context low".to_string(),
4878    };
4879
4880    let subagent = cx.new(|cx| {
4881        let mut thread = Thread::new_subagent(
4882            project.clone(),
4883            project_context,
4884            context_server_registry,
4885            Templates::new(),
4886            model.clone(),
4887            subagent_context,
4888            std::collections::BTreeMap::new(),
4889            cx,
4890        );
4891        thread.add_tool(EchoTool);
4892        thread
4893    });
4894
4895    subagent.read_with(cx, |thread, _| {
4896        assert!(
4897            thread.has_registered_tool("echo"),
4898            "subagent should have echo tool"
4899        );
4900    });
4901
4902    subagent
4903        .update(cx, |thread, cx| {
4904            thread.submit_user_message("Use the echo tool to echo 'hello world'", cx)
4905        })
4906        .unwrap();
4907    cx.run_until_parked();
4908
4909    let tool_use = LanguageModelToolUse {
4910        id: "tool_call_1".into(),
4911        name: EchoTool::name().into(),
4912        raw_input: json!({"text": "hello world"}).to_string(),
4913        input: json!({"text": "hello world"}),
4914        is_input_complete: true,
4915        thought_signature: None,
4916    };
4917    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use));
4918    fake_model.end_last_completion_stream();
4919    cx.run_until_parked();
4920
4921    let pending = fake_model.pending_completions();
4922    assert!(
4923        !pending.is_empty(),
4924        "should have pending completion after tool use"
4925    );
4926
4927    let last_completion = pending.last().unwrap();
4928    let has_tool_result = last_completion.messages.iter().any(|m| {
4929        m.content
4930            .iter()
4931            .any(|c| matches!(c, MessageContent::ToolResult(_)))
4932    });
4933    assert!(
4934        has_tool_result,
4935        "tool result should be in the messages sent back to the model"
4936    );
4937}
4938
4939#[gpui::test]
4940async fn test_max_parallel_subagents_enforced(cx: &mut TestAppContext) {
4941    init_test(cx);
4942
4943    cx.update(|cx| {
4944        cx.update_flags(true, vec!["subagents".to_string()]);
4945    });
4946
4947    let fs = FakeFs::new(cx.executor());
4948    fs.insert_tree(path!("/test"), json!({})).await;
4949    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4950    let project_context = cx.new(|_cx| ProjectContext::default());
4951    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4952    let context_server_registry =
4953        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4954    let model = Arc::new(FakeLanguageModel::default());
4955
4956    let parent = cx.new(|cx| {
4957        Thread::new(
4958            project.clone(),
4959            project_context.clone(),
4960            context_server_registry.clone(),
4961            Templates::new(),
4962            Some(model.clone()),
4963            cx,
4964        )
4965    });
4966
4967    let mut subagents = Vec::new();
4968    for i in 0..MAX_PARALLEL_SUBAGENTS {
4969        let subagent_context = SubagentContext {
4970            parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4971            tool_use_id: language_model::LanguageModelToolUseId::from(format!("tool-use-{}", i)),
4972            depth: 1,
4973            summary_prompt: "Summarize".to_string(),
4974            context_low_prompt: "Context low".to_string(),
4975        };
4976
4977        let subagent = cx.new(|cx| {
4978            Thread::new_subagent(
4979                project.clone(),
4980                project_context.clone(),
4981                context_server_registry.clone(),
4982                Templates::new(),
4983                model.clone(),
4984                subagent_context,
4985                std::collections::BTreeMap::new(),
4986                cx,
4987            )
4988        });
4989
4990        parent.update(cx, |thread, _cx| {
4991            thread.register_running_subagent(subagent.downgrade());
4992        });
4993        subagents.push(subagent);
4994    }
4995
4996    parent.read_with(cx, |thread, _| {
4997        assert_eq!(
4998            thread.running_subagent_count(),
4999            MAX_PARALLEL_SUBAGENTS,
5000            "should have MAX_PARALLEL_SUBAGENTS registered"
5001        );
5002    });
5003
5004    let parent_tools: std::collections::BTreeMap<gpui::SharedString, Arc<dyn crate::AnyAgentTool>> =
5005        std::collections::BTreeMap::new();
5006
5007    #[allow(clippy::arc_with_non_send_sync)]
5008    let tool = Arc::new(SubagentTool::new(
5009        parent.downgrade(),
5010        project.clone(),
5011        project_context,
5012        context_server_registry,
5013        Templates::new(),
5014        0,
5015        parent_tools,
5016    ));
5017
5018    let (event_stream, _rx) = crate::ToolCallEventStream::test();
5019
5020    let result = cx.update(|cx| {
5021        tool.run(
5022            SubagentToolInput {
5023                subagents: vec![crate::SubagentConfig {
5024                    label: "Test".to_string(),
5025                    task_prompt: "Do something".to_string(),
5026                    summary_prompt: "Summarize".to_string(),
5027                    context_low_prompt: "Context low".to_string(),
5028                    timeout_ms: None,
5029                    allowed_tools: None,
5030                }],
5031            },
5032            event_stream,
5033            cx,
5034        )
5035    });
5036
5037    let err = result.await.unwrap_err();
5038    assert!(
5039        err.to_string().contains("Maximum parallel subagents"),
5040        "should reject when max parallel subagents reached: {}",
5041        err
5042    );
5043
5044    drop(subagents);
5045}
5046
5047#[gpui::test]
5048async fn test_subagent_tool_end_to_end(cx: &mut TestAppContext) {
5049    init_test(cx);
5050    always_allow_tools(cx);
5051
5052    cx.update(|cx| {
5053        cx.update_flags(true, vec!["subagents".to_string()]);
5054    });
5055
5056    let fs = FakeFs::new(cx.executor());
5057    fs.insert_tree(path!("/test"), json!({})).await;
5058    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
5059    let project_context = cx.new(|_cx| ProjectContext::default());
5060    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
5061    let context_server_registry =
5062        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
5063    let model = Arc::new(FakeLanguageModel::default());
5064    let fake_model = model.as_fake();
5065
5066    let parent = cx.new(|cx| {
5067        let mut thread = Thread::new(
5068            project.clone(),
5069            project_context.clone(),
5070            context_server_registry.clone(),
5071            Templates::new(),
5072            Some(model.clone()),
5073            cx,
5074        );
5075        thread.add_tool(EchoTool);
5076        thread
5077    });
5078
5079    let mut parent_tools: std::collections::BTreeMap<
5080        gpui::SharedString,
5081        Arc<dyn crate::AnyAgentTool>,
5082    > = std::collections::BTreeMap::new();
5083    parent_tools.insert("echo".into(), EchoTool.erase());
5084
5085    #[allow(clippy::arc_with_non_send_sync)]
5086    let tool = Arc::new(SubagentTool::new(
5087        parent.downgrade(),
5088        project.clone(),
5089        project_context,
5090        context_server_registry,
5091        Templates::new(),
5092        0,
5093        parent_tools,
5094    ));
5095
5096    let (event_stream, _rx) = crate::ToolCallEventStream::test();
5097
5098    let task = cx.update(|cx| {
5099        tool.run(
5100            SubagentToolInput {
5101                subagents: vec![crate::SubagentConfig {
5102                    label: "Research task".to_string(),
5103                    task_prompt: "Find all TODOs in the codebase".to_string(),
5104                    summary_prompt: "Summarize what you found".to_string(),
5105                    context_low_prompt: "Context low, wrap up".to_string(),
5106                    timeout_ms: None,
5107                    allowed_tools: None,
5108                }],
5109            },
5110            event_stream,
5111            cx,
5112        )
5113    });
5114
5115    cx.run_until_parked();
5116
5117    let pending = fake_model.pending_completions();
5118    assert!(
5119        !pending.is_empty(),
5120        "subagent should have started and sent a completion request"
5121    );
5122
5123    let first_completion = &pending[0];
5124    let has_task_prompt = first_completion.messages.iter().any(|m| {
5125        m.role == language_model::Role::User
5126            && m.content
5127                .iter()
5128                .any(|c| c.to_str().map(|s| s.contains("TODO")).unwrap_or(false))
5129    });
5130    assert!(has_task_prompt, "task prompt should be sent to subagent");
5131
5132    fake_model.send_last_completion_stream_text_chunk("I found 5 TODOs in the codebase.");
5133    fake_model
5134        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
5135    fake_model.end_last_completion_stream();
5136    cx.run_until_parked();
5137
5138    let pending = fake_model.pending_completions();
5139    assert!(
5140        !pending.is_empty(),
5141        "should have pending completion for summary request"
5142    );
5143
5144    let last_completion = pending.last().unwrap();
5145    let has_summary_prompt = last_completion.messages.iter().any(|m| {
5146        m.role == language_model::Role::User
5147            && m.content.iter().any(|c| {
5148                c.to_str()
5149                    .map(|s| s.contains("Summarize") || s.contains("summarize"))
5150                    .unwrap_or(false)
5151            })
5152    });
5153    assert!(
5154        has_summary_prompt,
5155        "summary prompt should be sent after task completion"
5156    );
5157
5158    fake_model.send_last_completion_stream_text_chunk("Summary: Found 5 TODOs across 3 files.");
5159    fake_model
5160        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
5161    fake_model.end_last_completion_stream();
5162    cx.run_until_parked();
5163
5164    let result = task.await;
5165    assert!(result.is_ok(), "subagent tool should complete successfully");
5166
5167    let summary = result.unwrap();
5168    assert!(
5169        summary.contains("Summary") || summary.contains("TODO") || summary.contains("5"),
5170        "summary should contain subagent's response: {}",
5171        summary
5172    );
5173}
5174
5175#[gpui::test]
5176async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
5177    init_test(cx);
5178
5179    let fs = FakeFs::new(cx.executor());
5180    fs.insert_tree("/root", json!({"sensitive_config.txt": "secret data"}))
5181        .await;
5182    let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5183
5184    cx.update(|cx| {
5185        let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5186        settings.tool_permissions.tools.insert(
5187            "edit_file".into(),
5188            agent_settings::ToolRules {
5189                default_mode: settings::ToolPermissionMode::Allow,
5190                always_allow: vec![],
5191                always_deny: vec![agent_settings::CompiledRegex::new(r"sensitive", false).unwrap()],
5192                always_confirm: vec![],
5193                invalid_patterns: vec![],
5194            },
5195        );
5196        agent_settings::AgentSettings::override_global(settings, cx);
5197    });
5198
5199    let context_server_registry =
5200        cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5201    let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5202    let templates = crate::Templates::new();
5203    let thread = cx.new(|cx| {
5204        crate::Thread::new(
5205            project.clone(),
5206            cx.new(|_cx| prompt_store::ProjectContext::default()),
5207            context_server_registry,
5208            templates.clone(),
5209            None,
5210            cx,
5211        )
5212    });
5213
5214    #[allow(clippy::arc_with_non_send_sync)]
5215    let tool = Arc::new(crate::EditFileTool::new(
5216        project.clone(),
5217        thread.downgrade(),
5218        language_registry,
5219        templates,
5220    ));
5221    let (event_stream, _rx) = crate::ToolCallEventStream::test();
5222
5223    let task = cx.update(|cx| {
5224        tool.run(
5225            crate::EditFileToolInput {
5226                display_description: "Edit sensitive file".to_string(),
5227                path: "root/sensitive_config.txt".into(),
5228                mode: crate::EditFileMode::Edit,
5229            },
5230            event_stream,
5231            cx,
5232        )
5233    });
5234
5235    let result = task.await;
5236    assert!(result.is_err(), "expected edit to be blocked");
5237    assert!(
5238        result.unwrap_err().to_string().contains("blocked"),
5239        "error should mention the edit was blocked"
5240    );
5241}
5242
5243#[gpui::test]
5244async fn test_delete_path_tool_deny_rule_blocks_deletion(cx: &mut TestAppContext) {
5245    init_test(cx);
5246
5247    let fs = FakeFs::new(cx.executor());
5248    fs.insert_tree("/root", json!({"important_data.txt": "critical info"}))
5249        .await;
5250    let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5251
5252    cx.update(|cx| {
5253        let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5254        settings.tool_permissions.tools.insert(
5255            "delete_path".into(),
5256            agent_settings::ToolRules {
5257                default_mode: settings::ToolPermissionMode::Allow,
5258                always_allow: vec![],
5259                always_deny: vec![agent_settings::CompiledRegex::new(r"important", false).unwrap()],
5260                always_confirm: vec![],
5261                invalid_patterns: vec![],
5262            },
5263        );
5264        agent_settings::AgentSettings::override_global(settings, cx);
5265    });
5266
5267    let action_log = cx.new(|_cx| action_log::ActionLog::new(project.clone()));
5268
5269    #[allow(clippy::arc_with_non_send_sync)]
5270    let tool = Arc::new(crate::DeletePathTool::new(project, action_log));
5271    let (event_stream, _rx) = crate::ToolCallEventStream::test();
5272
5273    let task = cx.update(|cx| {
5274        tool.run(
5275            crate::DeletePathToolInput {
5276                path: "root/important_data.txt".to_string(),
5277            },
5278            event_stream,
5279            cx,
5280        )
5281    });
5282
5283    let result = task.await;
5284    assert!(result.is_err(), "expected deletion to be blocked");
5285    assert!(
5286        result.unwrap_err().to_string().contains("blocked"),
5287        "error should mention the deletion was blocked"
5288    );
5289}
5290
5291#[gpui::test]
5292async fn test_move_path_tool_denies_if_destination_denied(cx: &mut TestAppContext) {
5293    init_test(cx);
5294
5295    let fs = FakeFs::new(cx.executor());
5296    fs.insert_tree(
5297        "/root",
5298        json!({
5299            "safe.txt": "content",
5300            "protected": {}
5301        }),
5302    )
5303    .await;
5304    let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5305
5306    cx.update(|cx| {
5307        let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5308        settings.tool_permissions.tools.insert(
5309            "move_path".into(),
5310            agent_settings::ToolRules {
5311                default_mode: settings::ToolPermissionMode::Allow,
5312                always_allow: vec![],
5313                always_deny: vec![agent_settings::CompiledRegex::new(r"protected", false).unwrap()],
5314                always_confirm: vec![],
5315                invalid_patterns: vec![],
5316            },
5317        );
5318        agent_settings::AgentSettings::override_global(settings, cx);
5319    });
5320
5321    #[allow(clippy::arc_with_non_send_sync)]
5322    let tool = Arc::new(crate::MovePathTool::new(project));
5323    let (event_stream, _rx) = crate::ToolCallEventStream::test();
5324
5325    let task = cx.update(|cx| {
5326        tool.run(
5327            crate::MovePathToolInput {
5328                source_path: "root/safe.txt".to_string(),
5329                destination_path: "root/protected/safe.txt".to_string(),
5330            },
5331            event_stream,
5332            cx,
5333        )
5334    });
5335
5336    let result = task.await;
5337    assert!(
5338        result.is_err(),
5339        "expected move to be blocked due to destination path"
5340    );
5341    assert!(
5342        result.unwrap_err().to_string().contains("blocked"),
5343        "error should mention the move was blocked"
5344    );
5345}
5346
5347#[gpui::test]
5348async fn test_move_path_tool_denies_if_source_denied(cx: &mut TestAppContext) {
5349    init_test(cx);
5350
5351    let fs = FakeFs::new(cx.executor());
5352    fs.insert_tree(
5353        "/root",
5354        json!({
5355            "secret.txt": "secret content",
5356            "public": {}
5357        }),
5358    )
5359    .await;
5360    let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5361
5362    cx.update(|cx| {
5363        let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5364        settings.tool_permissions.tools.insert(
5365            "move_path".into(),
5366            agent_settings::ToolRules {
5367                default_mode: settings::ToolPermissionMode::Allow,
5368                always_allow: vec![],
5369                always_deny: vec![agent_settings::CompiledRegex::new(r"secret", false).unwrap()],
5370                always_confirm: vec![],
5371                invalid_patterns: vec![],
5372            },
5373        );
5374        agent_settings::AgentSettings::override_global(settings, cx);
5375    });
5376
5377    #[allow(clippy::arc_with_non_send_sync)]
5378    let tool = Arc::new(crate::MovePathTool::new(project));
5379    let (event_stream, _rx) = crate::ToolCallEventStream::test();
5380
5381    let task = cx.update(|cx| {
5382        tool.run(
5383            crate::MovePathToolInput {
5384                source_path: "root/secret.txt".to_string(),
5385                destination_path: "root/public/not_secret.txt".to_string(),
5386            },
5387            event_stream,
5388            cx,
5389        )
5390    });
5391
5392    let result = task.await;
5393    assert!(
5394        result.is_err(),
5395        "expected move to be blocked due to source path"
5396    );
5397    assert!(
5398        result.unwrap_err().to_string().contains("blocked"),
5399        "error should mention the move was blocked"
5400    );
5401}
5402
5403#[gpui::test]
5404async fn test_copy_path_tool_deny_rule_blocks_copy(cx: &mut TestAppContext) {
5405    init_test(cx);
5406
5407    let fs = FakeFs::new(cx.executor());
5408    fs.insert_tree(
5409        "/root",
5410        json!({
5411            "confidential.txt": "confidential data",
5412            "dest": {}
5413        }),
5414    )
5415    .await;
5416    let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5417
5418    cx.update(|cx| {
5419        let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5420        settings.tool_permissions.tools.insert(
5421            "copy_path".into(),
5422            agent_settings::ToolRules {
5423                default_mode: settings::ToolPermissionMode::Allow,
5424                always_allow: vec![],
5425                always_deny: vec![
5426                    agent_settings::CompiledRegex::new(r"confidential", false).unwrap(),
5427                ],
5428                always_confirm: vec![],
5429                invalid_patterns: vec![],
5430            },
5431        );
5432        agent_settings::AgentSettings::override_global(settings, cx);
5433    });
5434
5435    #[allow(clippy::arc_with_non_send_sync)]
5436    let tool = Arc::new(crate::CopyPathTool::new(project));
5437    let (event_stream, _rx) = crate::ToolCallEventStream::test();
5438
5439    let task = cx.update(|cx| {
5440        tool.run(
5441            crate::CopyPathToolInput {
5442                source_path: "root/confidential.txt".to_string(),
5443                destination_path: "root/dest/copy.txt".to_string(),
5444            },
5445            event_stream,
5446            cx,
5447        )
5448    });
5449
5450    let result = task.await;
5451    assert!(result.is_err(), "expected copy to be blocked");
5452    assert!(
5453        result.unwrap_err().to_string().contains("blocked"),
5454        "error should mention the copy was blocked"
5455    );
5456}
5457
5458#[gpui::test]
5459async fn test_save_file_tool_denies_if_any_path_denied(cx: &mut TestAppContext) {
5460    init_test(cx);
5461
5462    let fs = FakeFs::new(cx.executor());
5463    fs.insert_tree(
5464        "/root",
5465        json!({
5466            "normal.txt": "normal content",
5467            "readonly": {
5468                "config.txt": "readonly content"
5469            }
5470        }),
5471    )
5472    .await;
5473    let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5474
5475    cx.update(|cx| {
5476        let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5477        settings.tool_permissions.tools.insert(
5478            "save_file".into(),
5479            agent_settings::ToolRules {
5480                default_mode: settings::ToolPermissionMode::Allow,
5481                always_allow: vec![],
5482                always_deny: vec![agent_settings::CompiledRegex::new(r"readonly", false).unwrap()],
5483                always_confirm: vec![],
5484                invalid_patterns: vec![],
5485            },
5486        );
5487        agent_settings::AgentSettings::override_global(settings, cx);
5488    });
5489
5490    #[allow(clippy::arc_with_non_send_sync)]
5491    let tool = Arc::new(crate::SaveFileTool::new(project));
5492    let (event_stream, _rx) = crate::ToolCallEventStream::test();
5493
5494    let task = cx.update(|cx| {
5495        tool.run(
5496            crate::SaveFileToolInput {
5497                paths: vec![
5498                    std::path::PathBuf::from("root/normal.txt"),
5499                    std::path::PathBuf::from("root/readonly/config.txt"),
5500                ],
5501            },
5502            event_stream,
5503            cx,
5504        )
5505    });
5506
5507    let result = task.await;
5508    assert!(
5509        result.is_err(),
5510        "expected save to be blocked due to denied path"
5511    );
5512    assert!(
5513        result.unwrap_err().to_string().contains("blocked"),
5514        "error should mention the save was blocked"
5515    );
5516}
5517
5518#[gpui::test]
5519async fn test_save_file_tool_respects_deny_rules(cx: &mut TestAppContext) {
5520    init_test(cx);
5521
5522    let fs = FakeFs::new(cx.executor());
5523    fs.insert_tree("/root", json!({"config.secret": "secret config"}))
5524        .await;
5525    let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5526
5527    cx.update(|cx| {
5528        let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5529        settings.always_allow_tool_actions = false;
5530        settings.tool_permissions.tools.insert(
5531            "save_file".into(),
5532            agent_settings::ToolRules {
5533                default_mode: settings::ToolPermissionMode::Allow,
5534                always_allow: vec![],
5535                always_deny: vec![agent_settings::CompiledRegex::new(r"\.secret$", false).unwrap()],
5536                always_confirm: vec![],
5537                invalid_patterns: vec![],
5538            },
5539        );
5540        agent_settings::AgentSettings::override_global(settings, cx);
5541    });
5542
5543    #[allow(clippy::arc_with_non_send_sync)]
5544    let tool = Arc::new(crate::SaveFileTool::new(project));
5545    let (event_stream, _rx) = crate::ToolCallEventStream::test();
5546
5547    let task = cx.update(|cx| {
5548        tool.run(
5549            crate::SaveFileToolInput {
5550                paths: vec![std::path::PathBuf::from("root/config.secret")],
5551            },
5552            event_stream,
5553            cx,
5554        )
5555    });
5556
5557    let result = task.await;
5558    assert!(result.is_err(), "expected save to be blocked");
5559    assert!(
5560        result.unwrap_err().to_string().contains("blocked"),
5561        "error should mention the save was blocked"
5562    );
5563}
5564
5565#[gpui::test]
5566async fn test_web_search_tool_deny_rule_blocks_search(cx: &mut TestAppContext) {
5567    init_test(cx);
5568
5569    cx.update(|cx| {
5570        let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5571        settings.tool_permissions.tools.insert(
5572            "web_search".into(),
5573            agent_settings::ToolRules {
5574                default_mode: settings::ToolPermissionMode::Allow,
5575                always_allow: vec![],
5576                always_deny: vec![
5577                    agent_settings::CompiledRegex::new(r"internal\.company", false).unwrap(),
5578                ],
5579                always_confirm: vec![],
5580                invalid_patterns: vec![],
5581            },
5582        );
5583        agent_settings::AgentSettings::override_global(settings, cx);
5584    });
5585
5586    #[allow(clippy::arc_with_non_send_sync)]
5587    let tool = Arc::new(crate::WebSearchTool);
5588    let (event_stream, _rx) = crate::ToolCallEventStream::test();
5589
5590    let input: crate::WebSearchToolInput =
5591        serde_json::from_value(json!({"query": "internal.company.com secrets"})).unwrap();
5592
5593    let task = cx.update(|cx| tool.run(input, event_stream, cx));
5594
5595    let result = task.await;
5596    assert!(result.is_err(), "expected search to be blocked");
5597    assert!(
5598        result.unwrap_err().to_string().contains("blocked"),
5599        "error should mention the search was blocked"
5600    );
5601}
5602
5603#[gpui::test]
5604async fn test_edit_file_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5605    init_test(cx);
5606
5607    let fs = FakeFs::new(cx.executor());
5608    fs.insert_tree("/root", json!({"README.md": "# Hello"}))
5609        .await;
5610    let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5611
5612    cx.update(|cx| {
5613        let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5614        settings.always_allow_tool_actions = false;
5615        settings.tool_permissions.tools.insert(
5616            "edit_file".into(),
5617            agent_settings::ToolRules {
5618                default_mode: settings::ToolPermissionMode::Confirm,
5619                always_allow: vec![agent_settings::CompiledRegex::new(r"\.md$", false).unwrap()],
5620                always_deny: vec![],
5621                always_confirm: vec![],
5622                invalid_patterns: vec![],
5623            },
5624        );
5625        agent_settings::AgentSettings::override_global(settings, cx);
5626    });
5627
5628    let context_server_registry =
5629        cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5630    let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5631    let templates = crate::Templates::new();
5632    let thread = cx.new(|cx| {
5633        crate::Thread::new(
5634            project.clone(),
5635            cx.new(|_cx| prompt_store::ProjectContext::default()),
5636            context_server_registry,
5637            templates.clone(),
5638            None,
5639            cx,
5640        )
5641    });
5642
5643    #[allow(clippy::arc_with_non_send_sync)]
5644    let tool = Arc::new(crate::EditFileTool::new(
5645        project,
5646        thread.downgrade(),
5647        language_registry,
5648        templates,
5649    ));
5650    let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5651
5652    let _task = cx.update(|cx| {
5653        tool.run(
5654            crate::EditFileToolInput {
5655                display_description: "Edit README".to_string(),
5656                path: "root/README.md".into(),
5657                mode: crate::EditFileMode::Edit,
5658            },
5659            event_stream,
5660            cx,
5661        )
5662    });
5663
5664    cx.run_until_parked();
5665
5666    let event = rx.try_next();
5667    assert!(
5668        !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5669        "expected no authorization request for allowed .md file"
5670    );
5671}
5672
5673#[gpui::test]
5674async fn test_fetch_tool_deny_rule_blocks_url(cx: &mut TestAppContext) {
5675    init_test(cx);
5676
5677    cx.update(|cx| {
5678        let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5679        settings.tool_permissions.tools.insert(
5680            "fetch".into(),
5681            agent_settings::ToolRules {
5682                default_mode: settings::ToolPermissionMode::Allow,
5683                always_allow: vec![],
5684                always_deny: vec![
5685                    agent_settings::CompiledRegex::new(r"internal\.company\.com", false).unwrap(),
5686                ],
5687                always_confirm: vec![],
5688                invalid_patterns: vec![],
5689            },
5690        );
5691        agent_settings::AgentSettings::override_global(settings, cx);
5692    });
5693
5694    let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5695
5696    #[allow(clippy::arc_with_non_send_sync)]
5697    let tool = Arc::new(crate::FetchTool::new(http_client));
5698    let (event_stream, _rx) = crate::ToolCallEventStream::test();
5699
5700    let input: crate::FetchToolInput =
5701        serde_json::from_value(json!({"url": "https://internal.company.com/api"})).unwrap();
5702
5703    let task = cx.update(|cx| tool.run(input, event_stream, cx));
5704
5705    let result = task.await;
5706    assert!(result.is_err(), "expected fetch to be blocked");
5707    assert!(
5708        result.unwrap_err().to_string().contains("blocked"),
5709        "error should mention the fetch was blocked"
5710    );
5711}
5712
5713#[gpui::test]
5714async fn test_fetch_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5715    init_test(cx);
5716
5717    cx.update(|cx| {
5718        let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5719        settings.always_allow_tool_actions = false;
5720        settings.tool_permissions.tools.insert(
5721            "fetch".into(),
5722            agent_settings::ToolRules {
5723                default_mode: settings::ToolPermissionMode::Confirm,
5724                always_allow: vec![agent_settings::CompiledRegex::new(r"docs\.rs", false).unwrap()],
5725                always_deny: vec![],
5726                always_confirm: vec![],
5727                invalid_patterns: vec![],
5728            },
5729        );
5730        agent_settings::AgentSettings::override_global(settings, cx);
5731    });
5732
5733    let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5734
5735    #[allow(clippy::arc_with_non_send_sync)]
5736    let tool = Arc::new(crate::FetchTool::new(http_client));
5737    let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5738
5739    let input: crate::FetchToolInput =
5740        serde_json::from_value(json!({"url": "https://docs.rs/some-crate"})).unwrap();
5741
5742    let _task = cx.update(|cx| tool.run(input, event_stream, cx));
5743
5744    cx.run_until_parked();
5745
5746    let event = rx.try_next();
5747    assert!(
5748        !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5749        "expected no authorization request for allowed docs.rs URL"
5750    );
5751}