mod.rs

   1use super::*;
   2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
   3use agent_client_protocol::{self as acp};
   4use agent_settings::AgentProfileId;
   5use anyhow::Result;
   6use client::{Client, UserStore};
   7use cloud_llm_client::CompletionIntent;
   8use collections::IndexMap;
   9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
  10use feature_flags::FeatureFlagAppExt as _;
  11use fs::{FakeFs, Fs};
  12use futures::{
  13    FutureExt as _, StreamExt,
  14    channel::{
  15        mpsc::{self, UnboundedReceiver},
  16        oneshot,
  17    },
  18    future::{Fuse, Shared},
  19};
  20use gpui::{
  21    App, AppContext, AsyncApp, Entity, Task, TestAppContext, UpdateGlobal,
  22    http_client::FakeHttpClient,
  23};
  24use indoc::indoc;
  25use language_model::{
  26    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
  27    LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
  28    LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
  29    LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
  30};
  31use pretty_assertions::assert_eq;
  32use project::{
  33    Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
  34};
  35use prompt_store::ProjectContext;
  36use reqwest_client::ReqwestClient;
  37use schemars::JsonSchema;
  38use serde::{Deserialize, Serialize};
  39use serde_json::json;
  40use settings::{Settings, SettingsStore};
  41use std::{
  42    path::Path,
  43    pin::Pin,
  44    rc::Rc,
  45    sync::{
  46        Arc,
  47        atomic::{AtomicBool, Ordering},
  48    },
  49    time::Duration,
  50};
  51use util::path;
  52
  53mod test_tools;
  54use test_tools::*;
  55
  56fn init_test(cx: &mut TestAppContext) {
  57    cx.update(|cx| {
  58        let settings_store = SettingsStore::test(cx);
  59        cx.set_global(settings_store);
  60    });
  61}
  62
  63struct FakeTerminalHandle {
  64    killed: Arc<AtomicBool>,
  65    stopped_by_user: Arc<AtomicBool>,
  66    exit_sender: std::cell::RefCell<Option<futures::channel::oneshot::Sender<()>>>,
  67    wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
  68    output: acp::TerminalOutputResponse,
  69    id: acp::TerminalId,
  70}
  71
  72impl FakeTerminalHandle {
  73    fn new_never_exits(cx: &mut App) -> Self {
  74        let killed = Arc::new(AtomicBool::new(false));
  75        let stopped_by_user = Arc::new(AtomicBool::new(false));
  76
  77        let (exit_sender, exit_receiver) = futures::channel::oneshot::channel();
  78
  79        let wait_for_exit = cx
  80            .spawn(async move |_cx| {
  81                // Wait for the exit signal (sent when kill() is called)
  82                let _ = exit_receiver.await;
  83                acp::TerminalExitStatus::new()
  84            })
  85            .shared();
  86
  87        Self {
  88            killed,
  89            stopped_by_user,
  90            exit_sender: std::cell::RefCell::new(Some(exit_sender)),
  91            wait_for_exit,
  92            output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
  93            id: acp::TerminalId::new("fake_terminal".to_string()),
  94        }
  95    }
  96
  97    fn new_with_immediate_exit(cx: &mut App, exit_code: u32) -> Self {
  98        let killed = Arc::new(AtomicBool::new(false));
  99        let stopped_by_user = Arc::new(AtomicBool::new(false));
 100        let (exit_sender, _exit_receiver) = futures::channel::oneshot::channel();
 101
 102        let wait_for_exit = cx
 103            .spawn(async move |_cx| acp::TerminalExitStatus::new().exit_code(exit_code))
 104            .shared();
 105
 106        Self {
 107            killed,
 108            stopped_by_user,
 109            exit_sender: std::cell::RefCell::new(Some(exit_sender)),
 110            wait_for_exit,
 111            output: acp::TerminalOutputResponse::new("command output".to_string(), false),
 112            id: acp::TerminalId::new("fake_terminal".to_string()),
 113        }
 114    }
 115
 116    fn was_killed(&self) -> bool {
 117        self.killed.load(Ordering::SeqCst)
 118    }
 119
 120    fn set_stopped_by_user(&self, stopped: bool) {
 121        self.stopped_by_user.store(stopped, Ordering::SeqCst);
 122    }
 123
 124    fn signal_exit(&self) {
 125        if let Some(sender) = self.exit_sender.borrow_mut().take() {
 126            let _ = sender.send(());
 127        }
 128    }
 129}
 130
 131impl crate::TerminalHandle for FakeTerminalHandle {
 132    fn id(&self, _cx: &AsyncApp) -> Result<acp::TerminalId> {
 133        Ok(self.id.clone())
 134    }
 135
 136    fn current_output(&self, _cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
 137        Ok(self.output.clone())
 138    }
 139
 140    fn wait_for_exit(&self, _cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
 141        Ok(self.wait_for_exit.clone())
 142    }
 143
 144    fn kill(&self, _cx: &AsyncApp) -> Result<()> {
 145        self.killed.store(true, Ordering::SeqCst);
 146        self.signal_exit();
 147        Ok(())
 148    }
 149
 150    fn was_stopped_by_user(&self, _cx: &AsyncApp) -> Result<bool> {
 151        Ok(self.stopped_by_user.load(Ordering::SeqCst))
 152    }
 153}
 154
 155struct FakeThreadEnvironment {
 156    handle: Rc<FakeTerminalHandle>,
 157}
 158
 159impl crate::ThreadEnvironment for FakeThreadEnvironment {
 160    fn create_terminal(
 161        &self,
 162        _command: String,
 163        _cwd: Option<std::path::PathBuf>,
 164        _output_byte_limit: Option<u64>,
 165        _cx: &mut AsyncApp,
 166    ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
 167        Task::ready(Ok(self.handle.clone() as Rc<dyn crate::TerminalHandle>))
 168    }
 169}
 170
 171/// Environment that creates multiple independent terminal handles for testing concurrent terminals.
 172struct MultiTerminalEnvironment {
 173    handles: std::cell::RefCell<Vec<Rc<FakeTerminalHandle>>>,
 174}
 175
 176impl MultiTerminalEnvironment {
 177    fn new() -> Self {
 178        Self {
 179            handles: std::cell::RefCell::new(Vec::new()),
 180        }
 181    }
 182
 183    fn handles(&self) -> Vec<Rc<FakeTerminalHandle>> {
 184        self.handles.borrow().clone()
 185    }
 186}
 187
 188impl crate::ThreadEnvironment for MultiTerminalEnvironment {
 189    fn create_terminal(
 190        &self,
 191        _command: String,
 192        _cwd: Option<std::path::PathBuf>,
 193        _output_byte_limit: Option<u64>,
 194        cx: &mut AsyncApp,
 195    ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
 196        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
 197        self.handles.borrow_mut().push(handle.clone());
 198        Task::ready(Ok(handle as Rc<dyn crate::TerminalHandle>))
 199    }
 200}
 201
 202fn always_allow_tools(cx: &mut TestAppContext) {
 203    cx.update(|cx| {
 204        let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
 205        settings.always_allow_tool_actions = true;
 206        agent_settings::AgentSettings::override_global(settings, cx);
 207    });
 208}
 209
 210#[gpui::test]
 211async fn test_echo(cx: &mut TestAppContext) {
 212    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 213    let fake_model = model.as_fake();
 214
 215    let events = thread
 216        .update(cx, |thread, cx| {
 217            thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
 218        })
 219        .unwrap();
 220    cx.run_until_parked();
 221    fake_model.send_last_completion_stream_text_chunk("Hello");
 222    fake_model
 223        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
 224    fake_model.end_last_completion_stream();
 225
 226    let events = events.collect().await;
 227    thread.update(cx, |thread, _cx| {
 228        assert_eq!(
 229            thread.last_message().unwrap().to_markdown(),
 230            indoc! {"
 231                ## Assistant
 232
 233                Hello
 234            "}
 235        )
 236    });
 237    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 238}
 239
 240#[gpui::test]
 241async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
 242    init_test(cx);
 243    always_allow_tools(cx);
 244
 245    let fs = FakeFs::new(cx.executor());
 246    let project = Project::test(fs, [], cx).await;
 247
 248    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
 249    let environment = Rc::new(FakeThreadEnvironment {
 250        handle: handle.clone(),
 251    });
 252
 253    #[allow(clippy::arc_with_non_send_sync)]
 254    let tool = Arc::new(crate::TerminalTool::new(project, environment));
 255    let (event_stream, mut rx) = crate::ToolCallEventStream::test();
 256
 257    let task = cx.update(|cx| {
 258        tool.run(
 259            crate::TerminalToolInput {
 260                command: "sleep 1000".to_string(),
 261                cd: ".".to_string(),
 262                timeout_ms: Some(5),
 263            },
 264            event_stream,
 265            cx,
 266        )
 267    });
 268
 269    let update = rx.expect_update_fields().await;
 270    assert!(
 271        update.content.iter().any(|blocks| {
 272            blocks
 273                .iter()
 274                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
 275        }),
 276        "expected tool call update to include terminal content"
 277    );
 278
 279    let mut task_future: Pin<Box<Fuse<Task<Result<String>>>>> = Box::pin(task.fuse());
 280
 281    let deadline = std::time::Instant::now() + Duration::from_millis(500);
 282    loop {
 283        if let Some(result) = task_future.as_mut().now_or_never() {
 284            let result = result.expect("terminal tool task should complete");
 285
 286            assert!(
 287                handle.was_killed(),
 288                "expected terminal handle to be killed on timeout"
 289            );
 290            assert!(
 291                result.contains("partial output"),
 292                "expected result to include terminal output, got: {result}"
 293            );
 294            return;
 295        }
 296
 297        if std::time::Instant::now() >= deadline {
 298            panic!("timed out waiting for terminal tool task to complete");
 299        }
 300
 301        cx.run_until_parked();
 302        cx.background_executor.timer(Duration::from_millis(1)).await;
 303    }
 304}
 305
 306#[gpui::test]
 307#[ignore]
 308async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAppContext) {
 309    init_test(cx);
 310    always_allow_tools(cx);
 311
 312    let fs = FakeFs::new(cx.executor());
 313    let project = Project::test(fs, [], cx).await;
 314
 315    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
 316    let environment = Rc::new(FakeThreadEnvironment {
 317        handle: handle.clone(),
 318    });
 319
 320    #[allow(clippy::arc_with_non_send_sync)]
 321    let tool = Arc::new(crate::TerminalTool::new(project, environment));
 322    let (event_stream, mut rx) = crate::ToolCallEventStream::test();
 323
 324    let _task = cx.update(|cx| {
 325        tool.run(
 326            crate::TerminalToolInput {
 327                command: "sleep 1000".to_string(),
 328                cd: ".".to_string(),
 329                timeout_ms: None,
 330            },
 331            event_stream,
 332            cx,
 333        )
 334    });
 335
 336    let update = rx.expect_update_fields().await;
 337    assert!(
 338        update.content.iter().any(|blocks| {
 339            blocks
 340                .iter()
 341                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
 342        }),
 343        "expected tool call update to include terminal content"
 344    );
 345
 346    cx.background_executor
 347        .timer(Duration::from_millis(25))
 348        .await;
 349
 350    assert!(
 351        !handle.was_killed(),
 352        "did not expect terminal handle to be killed without a timeout"
 353    );
 354}
 355
 356#[gpui::test]
 357async fn test_thinking(cx: &mut TestAppContext) {
 358    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 359    let fake_model = model.as_fake();
 360
 361    let events = thread
 362        .update(cx, |thread, cx| {
 363            thread.send(
 364                UserMessageId::new(),
 365                [indoc! {"
 366                    Testing:
 367
 368                    Generate a thinking step where you just think the word 'Think',
 369                    and have your final answer be 'Hello'
 370                "}],
 371                cx,
 372            )
 373        })
 374        .unwrap();
 375    cx.run_until_parked();
 376    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
 377        text: "Think".to_string(),
 378        signature: None,
 379    });
 380    fake_model.send_last_completion_stream_text_chunk("Hello");
 381    fake_model
 382        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
 383    fake_model.end_last_completion_stream();
 384
 385    let events = events.collect().await;
 386    thread.update(cx, |thread, _cx| {
 387        assert_eq!(
 388            thread.last_message().unwrap().to_markdown(),
 389            indoc! {"
 390                ## Assistant
 391
 392                <think>Think</think>
 393                Hello
 394            "}
 395        )
 396    });
 397    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 398}
 399
 400#[gpui::test]
 401async fn test_system_prompt(cx: &mut TestAppContext) {
 402    let ThreadTest {
 403        model,
 404        thread,
 405        project_context,
 406        ..
 407    } = setup(cx, TestModel::Fake).await;
 408    let fake_model = model.as_fake();
 409
 410    project_context.update(cx, |project_context, _cx| {
 411        project_context.shell = "test-shell".into()
 412    });
 413    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
 414    thread
 415        .update(cx, |thread, cx| {
 416            thread.send(UserMessageId::new(), ["abc"], cx)
 417        })
 418        .unwrap();
 419    cx.run_until_parked();
 420    let mut pending_completions = fake_model.pending_completions();
 421    assert_eq!(
 422        pending_completions.len(),
 423        1,
 424        "unexpected pending completions: {:?}",
 425        pending_completions
 426    );
 427
 428    let pending_completion = pending_completions.pop().unwrap();
 429    assert_eq!(pending_completion.messages[0].role, Role::System);
 430
 431    let system_message = &pending_completion.messages[0];
 432    let system_prompt = system_message.content[0].to_str().unwrap();
 433    assert!(
 434        system_prompt.contains("test-shell"),
 435        "unexpected system message: {:?}",
 436        system_message
 437    );
 438    assert!(
 439        system_prompt.contains("## Fixing Diagnostics"),
 440        "unexpected system message: {:?}",
 441        system_message
 442    );
 443}
 444
 445#[gpui::test]
 446async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
 447    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 448    let fake_model = model.as_fake();
 449
 450    thread
 451        .update(cx, |thread, cx| {
 452            thread.send(UserMessageId::new(), ["abc"], cx)
 453        })
 454        .unwrap();
 455    cx.run_until_parked();
 456    let mut pending_completions = fake_model.pending_completions();
 457    assert_eq!(
 458        pending_completions.len(),
 459        1,
 460        "unexpected pending completions: {:?}",
 461        pending_completions
 462    );
 463
 464    let pending_completion = pending_completions.pop().unwrap();
 465    assert_eq!(pending_completion.messages[0].role, Role::System);
 466
 467    let system_message = &pending_completion.messages[0];
 468    let system_prompt = system_message.content[0].to_str().unwrap();
 469    assert!(
 470        !system_prompt.contains("## Tool Use"),
 471        "unexpected system message: {:?}",
 472        system_message
 473    );
 474    assert!(
 475        !system_prompt.contains("## Fixing Diagnostics"),
 476        "unexpected system message: {:?}",
 477        system_message
 478    );
 479}
 480
 481#[gpui::test]
 482async fn test_prompt_caching(cx: &mut TestAppContext) {
 483    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 484    let fake_model = model.as_fake();
 485
 486    // Send initial user message and verify it's cached
 487    thread
 488        .update(cx, |thread, cx| {
 489            thread.send(UserMessageId::new(), ["Message 1"], cx)
 490        })
 491        .unwrap();
 492    cx.run_until_parked();
 493
 494    let completion = fake_model.pending_completions().pop().unwrap();
 495    assert_eq!(
 496        completion.messages[1..],
 497        vec![LanguageModelRequestMessage {
 498            role: Role::User,
 499            content: vec!["Message 1".into()],
 500            cache: true,
 501            reasoning_details: None,
 502        }]
 503    );
 504    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
 505        "Response to Message 1".into(),
 506    ));
 507    fake_model.end_last_completion_stream();
 508    cx.run_until_parked();
 509
 510    // Send another user message and verify only the latest is cached
 511    thread
 512        .update(cx, |thread, cx| {
 513            thread.send(UserMessageId::new(), ["Message 2"], cx)
 514        })
 515        .unwrap();
 516    cx.run_until_parked();
 517
 518    let completion = fake_model.pending_completions().pop().unwrap();
 519    assert_eq!(
 520        completion.messages[1..],
 521        vec![
 522            LanguageModelRequestMessage {
 523                role: Role::User,
 524                content: vec!["Message 1".into()],
 525                cache: false,
 526                reasoning_details: None,
 527            },
 528            LanguageModelRequestMessage {
 529                role: Role::Assistant,
 530                content: vec!["Response to Message 1".into()],
 531                cache: false,
 532                reasoning_details: None,
 533            },
 534            LanguageModelRequestMessage {
 535                role: Role::User,
 536                content: vec!["Message 2".into()],
 537                cache: true,
 538                reasoning_details: None,
 539            }
 540        ]
 541    );
 542    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
 543        "Response to Message 2".into(),
 544    ));
 545    fake_model.end_last_completion_stream();
 546    cx.run_until_parked();
 547
 548    // Simulate a tool call and verify that the latest tool result is cached
 549    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
 550    thread
 551        .update(cx, |thread, cx| {
 552            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
 553        })
 554        .unwrap();
 555    cx.run_until_parked();
 556
 557    let tool_use = LanguageModelToolUse {
 558        id: "tool_1".into(),
 559        name: EchoTool::name().into(),
 560        raw_input: json!({"text": "test"}).to_string(),
 561        input: json!({"text": "test"}),
 562        is_input_complete: true,
 563        thought_signature: None,
 564    };
 565    fake_model
 566        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
 567    fake_model.end_last_completion_stream();
 568    cx.run_until_parked();
 569
 570    let completion = fake_model.pending_completions().pop().unwrap();
 571    let tool_result = LanguageModelToolResult {
 572        tool_use_id: "tool_1".into(),
 573        tool_name: EchoTool::name().into(),
 574        is_error: false,
 575        content: "test".into(),
 576        output: Some("test".into()),
 577    };
 578    assert_eq!(
 579        completion.messages[1..],
 580        vec![
 581            LanguageModelRequestMessage {
 582                role: Role::User,
 583                content: vec!["Message 1".into()],
 584                cache: false,
 585                reasoning_details: None,
 586            },
 587            LanguageModelRequestMessage {
 588                role: Role::Assistant,
 589                content: vec!["Response to Message 1".into()],
 590                cache: false,
 591                reasoning_details: None,
 592            },
 593            LanguageModelRequestMessage {
 594                role: Role::User,
 595                content: vec!["Message 2".into()],
 596                cache: false,
 597                reasoning_details: None,
 598            },
 599            LanguageModelRequestMessage {
 600                role: Role::Assistant,
 601                content: vec!["Response to Message 2".into()],
 602                cache: false,
 603                reasoning_details: None,
 604            },
 605            LanguageModelRequestMessage {
 606                role: Role::User,
 607                content: vec!["Use the echo tool".into()],
 608                cache: false,
 609                reasoning_details: None,
 610            },
 611            LanguageModelRequestMessage {
 612                role: Role::Assistant,
 613                content: vec![MessageContent::ToolUse(tool_use)],
 614                cache: false,
 615                reasoning_details: None,
 616            },
 617            LanguageModelRequestMessage {
 618                role: Role::User,
 619                content: vec![MessageContent::ToolResult(tool_result)],
 620                cache: true,
 621                reasoning_details: None,
 622            }
 623        ]
 624    );
 625}
 626
 627#[gpui::test]
 628#[cfg_attr(not(feature = "e2e"), ignore)]
 629async fn test_basic_tool_calls(cx: &mut TestAppContext) {
 630    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 631
 632    // Test a tool call that's likely to complete *before* streaming stops.
 633    let events = thread
 634        .update(cx, |thread, cx| {
 635            thread.add_tool(EchoTool);
 636            thread.send(
 637                UserMessageId::new(),
 638                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
 639                cx,
 640            )
 641        })
 642        .unwrap()
 643        .collect()
 644        .await;
 645    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 646
 647    // Test a tool calls that's likely to complete *after* streaming stops.
 648    let events = thread
 649        .update(cx, |thread, cx| {
 650            thread.remove_tool(&EchoTool::name());
 651            thread.add_tool(DelayTool);
 652            thread.send(
 653                UserMessageId::new(),
 654                [
 655                    "Now call the delay tool with 200ms.",
 656                    "When the timer goes off, then you echo the output of the tool.",
 657                ],
 658                cx,
 659            )
 660        })
 661        .unwrap()
 662        .collect()
 663        .await;
 664    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 665    thread.update(cx, |thread, _cx| {
 666        assert!(
 667            thread
 668                .last_message()
 669                .unwrap()
 670                .as_agent_message()
 671                .unwrap()
 672                .content
 673                .iter()
 674                .any(|content| {
 675                    if let AgentMessageContent::Text(text) = content {
 676                        text.contains("Ding")
 677                    } else {
 678                        false
 679                    }
 680                }),
 681            "{}",
 682            thread.to_markdown()
 683        );
 684    });
 685}
 686
 687#[gpui::test]
 688#[cfg_attr(not(feature = "e2e"), ignore)]
 689async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
 690    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 691
 692    // Test a tool call that's likely to complete *before* streaming stops.
 693    let mut events = thread
 694        .update(cx, |thread, cx| {
 695            thread.add_tool(WordListTool);
 696            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
 697        })
 698        .unwrap();
 699
 700    let mut saw_partial_tool_use = false;
 701    while let Some(event) = events.next().await {
 702        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
 703            thread.update(cx, |thread, _cx| {
 704                // Look for a tool use in the thread's last message
 705                let message = thread.last_message().unwrap();
 706                let agent_message = message.as_agent_message().unwrap();
 707                let last_content = agent_message.content.last().unwrap();
 708                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
 709                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
 710                    if tool_call.status == acp::ToolCallStatus::Pending {
 711                        if !last_tool_use.is_input_complete
 712                            && last_tool_use.input.get("g").is_none()
 713                        {
 714                            saw_partial_tool_use = true;
 715                        }
 716                    } else {
 717                        last_tool_use
 718                            .input
 719                            .get("a")
 720                            .expect("'a' has streamed because input is now complete");
 721                        last_tool_use
 722                            .input
 723                            .get("g")
 724                            .expect("'g' has streamed because input is now complete");
 725                    }
 726                } else {
 727                    panic!("last content should be a tool use");
 728                }
 729            });
 730        }
 731    }
 732
 733    assert!(
 734        saw_partial_tool_use,
 735        "should see at least one partially streamed tool use in the history"
 736    );
 737}
 738
 739#[gpui::test]
 740async fn test_tool_authorization(cx: &mut TestAppContext) {
 741    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 742    let fake_model = model.as_fake();
 743
 744    let mut events = thread
 745        .update(cx, |thread, cx| {
 746            thread.add_tool(ToolRequiringPermission);
 747            thread.send(UserMessageId::new(), ["abc"], cx)
 748        })
 749        .unwrap();
 750    cx.run_until_parked();
 751    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 752        LanguageModelToolUse {
 753            id: "tool_id_1".into(),
 754            name: ToolRequiringPermission::name().into(),
 755            raw_input: "{}".into(),
 756            input: json!({}),
 757            is_input_complete: true,
 758            thought_signature: None,
 759        },
 760    ));
 761    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 762        LanguageModelToolUse {
 763            id: "tool_id_2".into(),
 764            name: ToolRequiringPermission::name().into(),
 765            raw_input: "{}".into(),
 766            input: json!({}),
 767            is_input_complete: true,
 768            thought_signature: None,
 769        },
 770    ));
 771    fake_model.end_last_completion_stream();
 772    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
 773    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
 774
 775    // Approve the first - send "allow" option_id (UI transforms "once" to "allow")
 776    tool_call_auth_1
 777        .response
 778        .send(acp::PermissionOptionId::new("allow"))
 779        .unwrap();
 780    cx.run_until_parked();
 781
 782    // Reject the second - send "deny" option_id directly since Deny is now a button
 783    tool_call_auth_2
 784        .response
 785        .send(acp::PermissionOptionId::new("deny"))
 786        .unwrap();
 787    cx.run_until_parked();
 788
 789    let completion = fake_model.pending_completions().pop().unwrap();
 790    let message = completion.messages.last().unwrap();
 791    assert_eq!(
 792        message.content,
 793        vec![
 794            language_model::MessageContent::ToolResult(LanguageModelToolResult {
 795                tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(),
 796                tool_name: ToolRequiringPermission::name().into(),
 797                is_error: false,
 798                content: "Allowed".into(),
 799                output: Some("Allowed".into())
 800            }),
 801            language_model::MessageContent::ToolResult(LanguageModelToolResult {
 802                tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(),
 803                tool_name: ToolRequiringPermission::name().into(),
 804                is_error: true,
 805                content: "Permission to run tool denied by user".into(),
 806                output: Some("Permission to run tool denied by user".into())
 807            })
 808        ]
 809    );
 810
 811    // Simulate yet another tool call.
 812    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 813        LanguageModelToolUse {
 814            id: "tool_id_3".into(),
 815            name: ToolRequiringPermission::name().into(),
 816            raw_input: "{}".into(),
 817            input: json!({}),
 818            is_input_complete: true,
 819            thought_signature: None,
 820        },
 821    ));
 822    fake_model.end_last_completion_stream();
 823
 824    // Respond by always allowing tools - send transformed option_id
 825    // (UI transforms "always:tool_requiring_permission" to "always_allow:tool_requiring_permission")
 826    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
 827    tool_call_auth_3
 828        .response
 829        .send(acp::PermissionOptionId::new(
 830            "always_allow:tool_requiring_permission",
 831        ))
 832        .unwrap();
 833    cx.run_until_parked();
 834    let completion = fake_model.pending_completions().pop().unwrap();
 835    let message = completion.messages.last().unwrap();
 836    assert_eq!(
 837        message.content,
 838        vec![language_model::MessageContent::ToolResult(
 839            LanguageModelToolResult {
 840                tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(),
 841                tool_name: ToolRequiringPermission::name().into(),
 842                is_error: false,
 843                content: "Allowed".into(),
 844                output: Some("Allowed".into())
 845            }
 846        )]
 847    );
 848
 849    // Simulate a final tool call, ensuring we don't trigger authorization.
 850    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 851        LanguageModelToolUse {
 852            id: "tool_id_4".into(),
 853            name: ToolRequiringPermission::name().into(),
 854            raw_input: "{}".into(),
 855            input: json!({}),
 856            is_input_complete: true,
 857            thought_signature: None,
 858        },
 859    ));
 860    fake_model.end_last_completion_stream();
 861    cx.run_until_parked();
 862    let completion = fake_model.pending_completions().pop().unwrap();
 863    let message = completion.messages.last().unwrap();
 864    assert_eq!(
 865        message.content,
 866        vec![language_model::MessageContent::ToolResult(
 867            LanguageModelToolResult {
 868                tool_use_id: "tool_id_4".into(),
 869                tool_name: ToolRequiringPermission::name().into(),
 870                is_error: false,
 871                content: "Allowed".into(),
 872                output: Some("Allowed".into())
 873            }
 874        )]
 875    );
 876}
 877
 878#[gpui::test]
 879async fn test_tool_hallucination(cx: &mut TestAppContext) {
 880    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 881    let fake_model = model.as_fake();
 882
 883    let mut events = thread
 884        .update(cx, |thread, cx| {
 885            thread.send(UserMessageId::new(), ["abc"], cx)
 886        })
 887        .unwrap();
 888    cx.run_until_parked();
 889    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 890        LanguageModelToolUse {
 891            id: "tool_id_1".into(),
 892            name: "nonexistent_tool".into(),
 893            raw_input: "{}".into(),
 894            input: json!({}),
 895            is_input_complete: true,
 896            thought_signature: None,
 897        },
 898    ));
 899    fake_model.end_last_completion_stream();
 900
 901    let tool_call = expect_tool_call(&mut events).await;
 902    assert_eq!(tool_call.title, "nonexistent_tool");
 903    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
 904    let update = expect_tool_call_update_fields(&mut events).await;
 905    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
 906}
 907
 908async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
 909    let event = events
 910        .next()
 911        .await
 912        .expect("no tool call authorization event received")
 913        .unwrap();
 914    match event {
 915        ThreadEvent::ToolCall(tool_call) => tool_call,
 916        event => {
 917            panic!("Unexpected event {event:?}");
 918        }
 919    }
 920}
 921
 922async fn expect_tool_call_update_fields(
 923    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
 924) -> acp::ToolCallUpdate {
 925    let event = events
 926        .next()
 927        .await
 928        .expect("no tool call authorization event received")
 929        .unwrap();
 930    match event {
 931        ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
 932        event => {
 933            panic!("Unexpected event {event:?}");
 934        }
 935    }
 936}
 937
 938async fn next_tool_call_authorization(
 939    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
 940) -> ToolCallAuthorization {
 941    loop {
 942        let event = events
 943            .next()
 944            .await
 945            .expect("no tool call authorization event received")
 946            .unwrap();
 947        if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
 948            let permission_kinds = tool_call_authorization
 949                .options
 950                .iter()
 951                .map(|o| o.kind)
 952                .collect::<Vec<_>>();
 953            // Only 2 options now: AllowAlways (for tool) and AllowOnce (granularity only)
 954            // Deny is handled by the UI buttons, not as a separate option
 955            assert_eq!(
 956                permission_kinds,
 957                vec![
 958                    acp::PermissionOptionKind::AllowAlways,
 959                    acp::PermissionOptionKind::AllowOnce,
 960                ]
 961            );
 962            return tool_call_authorization;
 963        }
 964    }
 965}
 966
 967#[gpui::test]
 968#[cfg_attr(not(feature = "e2e"), ignore)]
 969async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
 970    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 971
 972    // Test concurrent tool calls with different delay times
 973    let events = thread
 974        .update(cx, |thread, cx| {
 975            thread.add_tool(DelayTool);
 976            thread.send(
 977                UserMessageId::new(),
 978                [
 979                    "Call the delay tool twice in the same message.",
 980                    "Once with 100ms. Once with 300ms.",
 981                    "When both timers are complete, describe the outputs.",
 982                ],
 983                cx,
 984            )
 985        })
 986        .unwrap()
 987        .collect()
 988        .await;
 989
 990    let stop_reasons = stop_events(events);
 991    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
 992
 993    thread.update(cx, |thread, _cx| {
 994        let last_message = thread.last_message().unwrap();
 995        let agent_message = last_message.as_agent_message().unwrap();
 996        let text = agent_message
 997            .content
 998            .iter()
 999            .filter_map(|content| {
1000                if let AgentMessageContent::Text(text) = content {
1001                    Some(text.as_str())
1002                } else {
1003                    None
1004                }
1005            })
1006            .collect::<String>();
1007
1008        assert!(text.contains("Ding"));
1009    });
1010}
1011
1012#[gpui::test]
1013async fn test_profiles(cx: &mut TestAppContext) {
1014    let ThreadTest {
1015        model, thread, fs, ..
1016    } = setup(cx, TestModel::Fake).await;
1017    let fake_model = model.as_fake();
1018
1019    thread.update(cx, |thread, _cx| {
1020        thread.add_tool(DelayTool);
1021        thread.add_tool(EchoTool);
1022        thread.add_tool(InfiniteTool);
1023    });
1024
1025    // Override profiles and wait for settings to be loaded.
1026    fs.insert_file(
1027        paths::settings_file(),
1028        json!({
1029            "agent": {
1030                "profiles": {
1031                    "test-1": {
1032                        "name": "Test Profile 1",
1033                        "tools": {
1034                            EchoTool::name(): true,
1035                            DelayTool::name(): true,
1036                        }
1037                    },
1038                    "test-2": {
1039                        "name": "Test Profile 2",
1040                        "tools": {
1041                            InfiniteTool::name(): true,
1042                        }
1043                    }
1044                }
1045            }
1046        })
1047        .to_string()
1048        .into_bytes(),
1049    )
1050    .await;
1051    cx.run_until_parked();
1052
1053    // Test that test-1 profile (default) has echo and delay tools
1054    thread
1055        .update(cx, |thread, cx| {
1056            thread.set_profile(AgentProfileId("test-1".into()), cx);
1057            thread.send(UserMessageId::new(), ["test"], cx)
1058        })
1059        .unwrap();
1060    cx.run_until_parked();
1061
1062    let mut pending_completions = fake_model.pending_completions();
1063    assert_eq!(pending_completions.len(), 1);
1064    let completion = pending_completions.pop().unwrap();
1065    let tool_names: Vec<String> = completion
1066        .tools
1067        .iter()
1068        .map(|tool| tool.name.clone())
1069        .collect();
1070    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
1071    fake_model.end_last_completion_stream();
1072
1073    // Switch to test-2 profile, and verify that it has only the infinite tool.
1074    thread
1075        .update(cx, |thread, cx| {
1076            thread.set_profile(AgentProfileId("test-2".into()), cx);
1077            thread.send(UserMessageId::new(), ["test2"], cx)
1078        })
1079        .unwrap();
1080    cx.run_until_parked();
1081    let mut pending_completions = fake_model.pending_completions();
1082    assert_eq!(pending_completions.len(), 1);
1083    let completion = pending_completions.pop().unwrap();
1084    let tool_names: Vec<String> = completion
1085        .tools
1086        .iter()
1087        .map(|tool| tool.name.clone())
1088        .collect();
1089    assert_eq!(tool_names, vec![InfiniteTool::name()]);
1090}
1091
1092#[gpui::test]
1093async fn test_mcp_tools(cx: &mut TestAppContext) {
1094    let ThreadTest {
1095        model,
1096        thread,
1097        context_server_store,
1098        fs,
1099        ..
1100    } = setup(cx, TestModel::Fake).await;
1101    let fake_model = model.as_fake();
1102
1103    // Override profiles and wait for settings to be loaded.
1104    fs.insert_file(
1105        paths::settings_file(),
1106        json!({
1107            "agent": {
1108                "always_allow_tool_actions": true,
1109                "profiles": {
1110                    "test": {
1111                        "name": "Test Profile",
1112                        "enable_all_context_servers": true,
1113                        "tools": {
1114                            EchoTool::name(): true,
1115                        }
1116                    },
1117                }
1118            }
1119        })
1120        .to_string()
1121        .into_bytes(),
1122    )
1123    .await;
1124    cx.run_until_parked();
1125    thread.update(cx, |thread, cx| {
1126        thread.set_profile(AgentProfileId("test".into()), cx)
1127    });
1128
1129    let mut mcp_tool_calls = setup_context_server(
1130        "test_server",
1131        vec![context_server::types::Tool {
1132            name: "echo".into(),
1133            description: None,
1134            input_schema: serde_json::to_value(EchoTool::input_schema(
1135                LanguageModelToolSchemaFormat::JsonSchema,
1136            ))
1137            .unwrap(),
1138            output_schema: None,
1139            annotations: None,
1140        }],
1141        &context_server_store,
1142        cx,
1143    );
1144
1145    let events = thread.update(cx, |thread, cx| {
1146        thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
1147    });
1148    cx.run_until_parked();
1149
1150    // Simulate the model calling the MCP tool.
1151    let completion = fake_model.pending_completions().pop().unwrap();
1152    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1153    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1154        LanguageModelToolUse {
1155            id: "tool_1".into(),
1156            name: "echo".into(),
1157            raw_input: json!({"text": "test"}).to_string(),
1158            input: json!({"text": "test"}),
1159            is_input_complete: true,
1160            thought_signature: None,
1161        },
1162    ));
1163    fake_model.end_last_completion_stream();
1164    cx.run_until_parked();
1165
1166    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1167    assert_eq!(tool_call_params.name, "echo");
1168    assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1169    tool_call_response
1170        .send(context_server::types::CallToolResponse {
1171            content: vec![context_server::types::ToolResponseContent::Text {
1172                text: "test".into(),
1173            }],
1174            is_error: None,
1175            meta: None,
1176            structured_content: None,
1177        })
1178        .unwrap();
1179    cx.run_until_parked();
1180
1181    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1182    fake_model.send_last_completion_stream_text_chunk("Done!");
1183    fake_model.end_last_completion_stream();
1184    events.collect::<Vec<_>>().await;
1185
1186    // Send again after adding the echo tool, ensuring the name collision is resolved.
1187    let events = thread.update(cx, |thread, cx| {
1188        thread.add_tool(EchoTool);
1189        thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1190    });
1191    cx.run_until_parked();
1192    let completion = fake_model.pending_completions().pop().unwrap();
1193    assert_eq!(
1194        tool_names_for_completion(&completion),
1195        vec!["echo", "test_server_echo"]
1196    );
1197    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1198        LanguageModelToolUse {
1199            id: "tool_2".into(),
1200            name: "test_server_echo".into(),
1201            raw_input: json!({"text": "mcp"}).to_string(),
1202            input: json!({"text": "mcp"}),
1203            is_input_complete: true,
1204            thought_signature: None,
1205        },
1206    ));
1207    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1208        LanguageModelToolUse {
1209            id: "tool_3".into(),
1210            name: "echo".into(),
1211            raw_input: json!({"text": "native"}).to_string(),
1212            input: json!({"text": "native"}),
1213            is_input_complete: true,
1214            thought_signature: None,
1215        },
1216    ));
1217    fake_model.end_last_completion_stream();
1218    cx.run_until_parked();
1219
1220    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1221    assert_eq!(tool_call_params.name, "echo");
1222    assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1223    tool_call_response
1224        .send(context_server::types::CallToolResponse {
1225            content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1226            is_error: None,
1227            meta: None,
1228            structured_content: None,
1229        })
1230        .unwrap();
1231    cx.run_until_parked();
1232
1233    // Ensure the tool results were inserted with the correct names.
1234    let completion = fake_model.pending_completions().pop().unwrap();
1235    assert_eq!(
1236        completion.messages.last().unwrap().content,
1237        vec![
1238            MessageContent::ToolResult(LanguageModelToolResult {
1239                tool_use_id: "tool_3".into(),
1240                tool_name: "echo".into(),
1241                is_error: false,
1242                content: "native".into(),
1243                output: Some("native".into()),
1244            },),
1245            MessageContent::ToolResult(LanguageModelToolResult {
1246                tool_use_id: "tool_2".into(),
1247                tool_name: "test_server_echo".into(),
1248                is_error: false,
1249                content: "mcp".into(),
1250                output: Some("mcp".into()),
1251            },),
1252        ]
1253    );
1254    fake_model.end_last_completion_stream();
1255    events.collect::<Vec<_>>().await;
1256}
1257
1258#[gpui::test]
1259async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
1260    let ThreadTest {
1261        model,
1262        thread,
1263        context_server_store,
1264        fs,
1265        ..
1266    } = setup(cx, TestModel::Fake).await;
1267    let fake_model = model.as_fake();
1268
1269    // Set up a profile with all tools enabled
1270    fs.insert_file(
1271        paths::settings_file(),
1272        json!({
1273            "agent": {
1274                "profiles": {
1275                    "test": {
1276                        "name": "Test Profile",
1277                        "enable_all_context_servers": true,
1278                        "tools": {
1279                            EchoTool::name(): true,
1280                            DelayTool::name(): true,
1281                            WordListTool::name(): true,
1282                            ToolRequiringPermission::name(): true,
1283                            InfiniteTool::name(): true,
1284                        }
1285                    },
1286                }
1287            }
1288        })
1289        .to_string()
1290        .into_bytes(),
1291    )
1292    .await;
1293    cx.run_until_parked();
1294
1295    thread.update(cx, |thread, cx| {
1296        thread.set_profile(AgentProfileId("test".into()), cx);
1297        thread.add_tool(EchoTool);
1298        thread.add_tool(DelayTool);
1299        thread.add_tool(WordListTool);
1300        thread.add_tool(ToolRequiringPermission);
1301        thread.add_tool(InfiniteTool);
1302    });
1303
1304    // Set up multiple context servers with some overlapping tool names
1305    let _server1_calls = setup_context_server(
1306        "xxx",
1307        vec![
1308            context_server::types::Tool {
1309                name: "echo".into(), // Conflicts with native EchoTool
1310                description: None,
1311                input_schema: serde_json::to_value(EchoTool::input_schema(
1312                    LanguageModelToolSchemaFormat::JsonSchema,
1313                ))
1314                .unwrap(),
1315                output_schema: None,
1316                annotations: None,
1317            },
1318            context_server::types::Tool {
1319                name: "unique_tool_1".into(),
1320                description: None,
1321                input_schema: json!({"type": "object", "properties": {}}),
1322                output_schema: None,
1323                annotations: None,
1324            },
1325        ],
1326        &context_server_store,
1327        cx,
1328    );
1329
1330    let _server2_calls = setup_context_server(
1331        "yyy",
1332        vec![
1333            context_server::types::Tool {
1334                name: "echo".into(), // Also conflicts with native EchoTool
1335                description: None,
1336                input_schema: serde_json::to_value(EchoTool::input_schema(
1337                    LanguageModelToolSchemaFormat::JsonSchema,
1338                ))
1339                .unwrap(),
1340                output_schema: None,
1341                annotations: None,
1342            },
1343            context_server::types::Tool {
1344                name: "unique_tool_2".into(),
1345                description: None,
1346                input_schema: json!({"type": "object", "properties": {}}),
1347                output_schema: None,
1348                annotations: None,
1349            },
1350            context_server::types::Tool {
1351                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1352                description: None,
1353                input_schema: json!({"type": "object", "properties": {}}),
1354                output_schema: None,
1355                annotations: None,
1356            },
1357            context_server::types::Tool {
1358                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1359                description: None,
1360                input_schema: json!({"type": "object", "properties": {}}),
1361                output_schema: None,
1362                annotations: None,
1363            },
1364        ],
1365        &context_server_store,
1366        cx,
1367    );
1368    let _server3_calls = setup_context_server(
1369        "zzz",
1370        vec![
1371            context_server::types::Tool {
1372                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
1373                description: None,
1374                input_schema: json!({"type": "object", "properties": {}}),
1375                output_schema: None,
1376                annotations: None,
1377            },
1378            context_server::types::Tool {
1379                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
1380                description: None,
1381                input_schema: json!({"type": "object", "properties": {}}),
1382                output_schema: None,
1383                annotations: None,
1384            },
1385            context_server::types::Tool {
1386                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
1387                description: None,
1388                input_schema: json!({"type": "object", "properties": {}}),
1389                output_schema: None,
1390                annotations: None,
1391            },
1392        ],
1393        &context_server_store,
1394        cx,
1395    );
1396
1397    thread
1398        .update(cx, |thread, cx| {
1399            thread.send(UserMessageId::new(), ["Go"], cx)
1400        })
1401        .unwrap();
1402    cx.run_until_parked();
1403    let completion = fake_model.pending_completions().pop().unwrap();
1404    assert_eq!(
1405        tool_names_for_completion(&completion),
1406        vec![
1407            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
1408            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
1409            "delay",
1410            "echo",
1411            "infinite",
1412            "tool_requiring_permission",
1413            "unique_tool_1",
1414            "unique_tool_2",
1415            "word_list",
1416            "xxx_echo",
1417            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1418            "yyy_echo",
1419            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
1420        ]
1421    );
1422}
1423
1424#[gpui::test]
1425#[cfg_attr(not(feature = "e2e"), ignore)]
1426async fn test_cancellation(cx: &mut TestAppContext) {
1427    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
1428
1429    let mut events = thread
1430        .update(cx, |thread, cx| {
1431            thread.add_tool(InfiniteTool);
1432            thread.add_tool(EchoTool);
1433            thread.send(
1434                UserMessageId::new(),
1435                ["Call the echo tool, then call the infinite tool, then explain their output"],
1436                cx,
1437            )
1438        })
1439        .unwrap();
1440
1441    // Wait until both tools are called.
1442    let mut expected_tools = vec!["Echo", "Infinite Tool"];
1443    let mut echo_id = None;
1444    let mut echo_completed = false;
1445    while let Some(event) = events.next().await {
1446        match event.unwrap() {
1447            ThreadEvent::ToolCall(tool_call) => {
1448                assert_eq!(tool_call.title, expected_tools.remove(0));
1449                if tool_call.title == "Echo" {
1450                    echo_id = Some(tool_call.tool_call_id);
1451                }
1452            }
1453            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
1454                acp::ToolCallUpdate {
1455                    tool_call_id,
1456                    fields:
1457                        acp::ToolCallUpdateFields {
1458                            status: Some(acp::ToolCallStatus::Completed),
1459                            ..
1460                        },
1461                    ..
1462                },
1463            )) if Some(&tool_call_id) == echo_id.as_ref() => {
1464                echo_completed = true;
1465            }
1466            _ => {}
1467        }
1468
1469        if expected_tools.is_empty() && echo_completed {
1470            break;
1471        }
1472    }
1473
1474    // Cancel the current send and ensure that the event stream is closed, even
1475    // if one of the tools is still running.
1476    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
1477    let events = events.collect::<Vec<_>>().await;
1478    let last_event = events.last();
1479    assert!(
1480        matches!(
1481            last_event,
1482            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
1483        ),
1484        "unexpected event {last_event:?}"
1485    );
1486
1487    // Ensure we can still send a new message after cancellation.
1488    let events = thread
1489        .update(cx, |thread, cx| {
1490            thread.send(
1491                UserMessageId::new(),
1492                ["Testing: reply with 'Hello' then stop."],
1493                cx,
1494            )
1495        })
1496        .unwrap()
1497        .collect::<Vec<_>>()
1498        .await;
1499    thread.update(cx, |thread, _cx| {
1500        let message = thread.last_message().unwrap();
1501        let agent_message = message.as_agent_message().unwrap();
1502        assert_eq!(
1503            agent_message.content,
1504            vec![AgentMessageContent::Text("Hello".to_string())]
1505        );
1506    });
1507    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1508}
1509
1510#[gpui::test]
1511async fn test_terminal_tool_cancellation_captures_output(cx: &mut TestAppContext) {
1512    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1513    always_allow_tools(cx);
1514    let fake_model = model.as_fake();
1515
1516    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
1517    let environment = Rc::new(FakeThreadEnvironment {
1518        handle: handle.clone(),
1519    });
1520
1521    let mut events = thread
1522        .update(cx, |thread, cx| {
1523            thread.add_tool(crate::TerminalTool::new(
1524                thread.project().clone(),
1525                environment,
1526            ));
1527            thread.send(UserMessageId::new(), ["run a command"], cx)
1528        })
1529        .unwrap();
1530
1531    cx.run_until_parked();
1532
1533    // Simulate the model calling the terminal tool
1534    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1535        LanguageModelToolUse {
1536            id: "terminal_tool_1".into(),
1537            name: "terminal".into(),
1538            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
1539            input: json!({"command": "sleep 1000", "cd": "."}),
1540            is_input_complete: true,
1541            thought_signature: None,
1542        },
1543    ));
1544    fake_model.end_last_completion_stream();
1545
1546    // Wait for the terminal tool to start running
1547    wait_for_terminal_tool_started(&mut events, cx).await;
1548
1549    // Cancel the thread while the terminal is running
1550    thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
1551
1552    // Collect remaining events, driving the executor to let cancellation complete
1553    let remaining_events = collect_events_until_stop(&mut events, cx).await;
1554
1555    // Verify the terminal was killed
1556    assert!(
1557        handle.was_killed(),
1558        "expected terminal handle to be killed on cancellation"
1559    );
1560
1561    // Verify we got a cancellation stop event
1562    assert_eq!(
1563        stop_events(remaining_events),
1564        vec![acp::StopReason::Cancelled],
1565    );
1566
1567    // Verify the tool result contains the terminal output, not just "Tool canceled by user"
1568    thread.update(cx, |thread, _cx| {
1569        let message = thread.last_message().unwrap();
1570        let agent_message = message.as_agent_message().unwrap();
1571
1572        let tool_use = agent_message
1573            .content
1574            .iter()
1575            .find_map(|content| match content {
1576                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
1577                _ => None,
1578            })
1579            .expect("expected tool use in agent message");
1580
1581        let tool_result = agent_message
1582            .tool_results
1583            .get(&tool_use.id)
1584            .expect("expected tool result");
1585
1586        let result_text = match &tool_result.content {
1587            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
1588            _ => panic!("expected text content in tool result"),
1589        };
1590
1591        // "partial output" comes from FakeTerminalHandle's output field
1592        assert!(
1593            result_text.contains("partial output"),
1594            "expected tool result to contain terminal output, got: {result_text}"
1595        );
1596        // Match the actual format from process_content in terminal_tool.rs
1597        assert!(
1598            result_text.contains("The user stopped this command"),
1599            "expected tool result to indicate user stopped, got: {result_text}"
1600        );
1601    });
1602
1603    // Verify we can send a new message after cancellation
1604    verify_thread_recovery(&thread, &fake_model, cx).await;
1605}
1606
1607#[gpui::test]
1608async fn test_cancellation_aware_tool_responds_to_cancellation(cx: &mut TestAppContext) {
1609    // This test verifies that tools which properly handle cancellation via
1610    // `event_stream.cancelled_by_user()` (like edit_file_tool) respond promptly
1611    // to cancellation and report that they were cancelled.
1612    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1613    always_allow_tools(cx);
1614    let fake_model = model.as_fake();
1615
1616    let (tool, was_cancelled) = CancellationAwareTool::new();
1617
1618    let mut events = thread
1619        .update(cx, |thread, cx| {
1620            thread.add_tool(tool);
1621            thread.send(
1622                UserMessageId::new(),
1623                ["call the cancellation aware tool"],
1624                cx,
1625            )
1626        })
1627        .unwrap();
1628
1629    cx.run_until_parked();
1630
1631    // Simulate the model calling the cancellation-aware tool
1632    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1633        LanguageModelToolUse {
1634            id: "cancellation_aware_1".into(),
1635            name: "cancellation_aware".into(),
1636            raw_input: r#"{}"#.into(),
1637            input: json!({}),
1638            is_input_complete: true,
1639            thought_signature: None,
1640        },
1641    ));
1642    fake_model.end_last_completion_stream();
1643
1644    cx.run_until_parked();
1645
1646    // Wait for the tool call to be reported
1647    let mut tool_started = false;
1648    let deadline = cx.executor().num_cpus() * 100;
1649    for _ in 0..deadline {
1650        cx.run_until_parked();
1651
1652        while let Some(Some(event)) = events.next().now_or_never() {
1653            if let Ok(ThreadEvent::ToolCall(tool_call)) = &event {
1654                if tool_call.title == "Cancellation Aware Tool" {
1655                    tool_started = true;
1656                    break;
1657                }
1658            }
1659        }
1660
1661        if tool_started {
1662            break;
1663        }
1664
1665        cx.background_executor
1666            .timer(Duration::from_millis(10))
1667            .await;
1668    }
1669    assert!(tool_started, "expected cancellation aware tool to start");
1670
1671    // Cancel the thread and wait for it to complete
1672    let cancel_task = thread.update(cx, |thread, cx| thread.cancel(cx));
1673
1674    // The cancel task should complete promptly because the tool handles cancellation
1675    let timeout = cx.background_executor.timer(Duration::from_secs(5));
1676    futures::select! {
1677        _ = cancel_task.fuse() => {}
1678        _ = timeout.fuse() => {
1679            panic!("cancel task timed out - tool did not respond to cancellation");
1680        }
1681    }
1682
1683    // Verify the tool detected cancellation via its flag
1684    assert!(
1685        was_cancelled.load(std::sync::atomic::Ordering::SeqCst),
1686        "tool should have detected cancellation via event_stream.cancelled_by_user()"
1687    );
1688
1689    // Collect remaining events
1690    let remaining_events = collect_events_until_stop(&mut events, cx).await;
1691
1692    // Verify we got a cancellation stop event
1693    assert_eq!(
1694        stop_events(remaining_events),
1695        vec![acp::StopReason::Cancelled],
1696    );
1697
1698    // Verify we can send a new message after cancellation
1699    verify_thread_recovery(&thread, &fake_model, cx).await;
1700}
1701
1702/// Helper to verify thread can recover after cancellation by sending a simple message.
1703async fn verify_thread_recovery(
1704    thread: &Entity<Thread>,
1705    fake_model: &FakeLanguageModel,
1706    cx: &mut TestAppContext,
1707) {
1708    let events = thread
1709        .update(cx, |thread, cx| {
1710            thread.send(
1711                UserMessageId::new(),
1712                ["Testing: reply with 'Hello' then stop."],
1713                cx,
1714            )
1715        })
1716        .unwrap();
1717    cx.run_until_parked();
1718    fake_model.send_last_completion_stream_text_chunk("Hello");
1719    fake_model
1720        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1721    fake_model.end_last_completion_stream();
1722
1723    let events = events.collect::<Vec<_>>().await;
1724    thread.update(cx, |thread, _cx| {
1725        let message = thread.last_message().unwrap();
1726        let agent_message = message.as_agent_message().unwrap();
1727        assert_eq!(
1728            agent_message.content,
1729            vec![AgentMessageContent::Text("Hello".to_string())]
1730        );
1731    });
1732    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1733}
1734
1735/// Waits for a terminal tool to start by watching for a ToolCallUpdate with terminal content.
1736async fn wait_for_terminal_tool_started(
1737    events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1738    cx: &mut TestAppContext,
1739) {
1740    let deadline = cx.executor().num_cpus() * 100; // Scale with available parallelism
1741    for _ in 0..deadline {
1742        cx.run_until_parked();
1743
1744        while let Some(Some(event)) = events.next().now_or_never() {
1745            if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
1746                update,
1747            ))) = &event
1748            {
1749                if update.fields.content.as_ref().is_some_and(|content| {
1750                    content
1751                        .iter()
1752                        .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
1753                }) {
1754                    return;
1755                }
1756            }
1757        }
1758
1759        cx.background_executor
1760            .timer(Duration::from_millis(10))
1761            .await;
1762    }
1763    panic!("terminal tool did not start within the expected time");
1764}
1765
1766/// Collects events until a Stop event is received, driving the executor to completion.
1767async fn collect_events_until_stop(
1768    events: &mut mpsc::UnboundedReceiver<Result<ThreadEvent>>,
1769    cx: &mut TestAppContext,
1770) -> Vec<Result<ThreadEvent>> {
1771    let mut collected = Vec::new();
1772    let deadline = cx.executor().num_cpus() * 200;
1773
1774    for _ in 0..deadline {
1775        cx.executor().advance_clock(Duration::from_millis(10));
1776        cx.run_until_parked();
1777
1778        while let Some(Some(event)) = events.next().now_or_never() {
1779            let is_stop = matches!(&event, Ok(ThreadEvent::Stop(_)));
1780            collected.push(event);
1781            if is_stop {
1782                return collected;
1783            }
1784        }
1785    }
1786    panic!(
1787        "did not receive Stop event within the expected time; collected {} events",
1788        collected.len()
1789    );
1790}
1791
1792#[gpui::test]
1793async fn test_truncate_while_terminal_tool_running(cx: &mut TestAppContext) {
1794    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1795    always_allow_tools(cx);
1796    let fake_model = model.as_fake();
1797
1798    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
1799    let environment = Rc::new(FakeThreadEnvironment {
1800        handle: handle.clone(),
1801    });
1802
1803    let message_id = UserMessageId::new();
1804    let mut events = thread
1805        .update(cx, |thread, cx| {
1806            thread.add_tool(crate::TerminalTool::new(
1807                thread.project().clone(),
1808                environment,
1809            ));
1810            thread.send(message_id.clone(), ["run a command"], cx)
1811        })
1812        .unwrap();
1813
1814    cx.run_until_parked();
1815
1816    // Simulate the model calling the terminal tool
1817    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1818        LanguageModelToolUse {
1819            id: "terminal_tool_1".into(),
1820            name: "terminal".into(),
1821            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
1822            input: json!({"command": "sleep 1000", "cd": "."}),
1823            is_input_complete: true,
1824            thought_signature: None,
1825        },
1826    ));
1827    fake_model.end_last_completion_stream();
1828
1829    // Wait for the terminal tool to start running
1830    wait_for_terminal_tool_started(&mut events, cx).await;
1831
1832    // Truncate the thread while the terminal is running
1833    thread
1834        .update(cx, |thread, cx| thread.truncate(message_id, cx))
1835        .unwrap();
1836
1837    // Drive the executor to let cancellation complete
1838    let _ = collect_events_until_stop(&mut events, cx).await;
1839
1840    // Verify the terminal was killed
1841    assert!(
1842        handle.was_killed(),
1843        "expected terminal handle to be killed on truncate"
1844    );
1845
1846    // Verify the thread is empty after truncation
1847    thread.update(cx, |thread, _cx| {
1848        assert_eq!(
1849            thread.to_markdown(),
1850            "",
1851            "expected thread to be empty after truncating the only message"
1852        );
1853    });
1854
1855    // Verify we can send a new message after truncation
1856    verify_thread_recovery(&thread, &fake_model, cx).await;
1857}
1858
1859#[gpui::test]
1860async fn test_cancel_multiple_concurrent_terminal_tools(cx: &mut TestAppContext) {
1861    // Tests that cancellation properly kills all running terminal tools when multiple are active.
1862    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1863    always_allow_tools(cx);
1864    let fake_model = model.as_fake();
1865
1866    let environment = Rc::new(MultiTerminalEnvironment::new());
1867
1868    let mut events = thread
1869        .update(cx, |thread, cx| {
1870            thread.add_tool(crate::TerminalTool::new(
1871                thread.project().clone(),
1872                environment.clone(),
1873            ));
1874            thread.send(UserMessageId::new(), ["run multiple commands"], cx)
1875        })
1876        .unwrap();
1877
1878    cx.run_until_parked();
1879
1880    // Simulate the model calling two terminal tools
1881    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1882        LanguageModelToolUse {
1883            id: "terminal_tool_1".into(),
1884            name: "terminal".into(),
1885            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
1886            input: json!({"command": "sleep 1000", "cd": "."}),
1887            is_input_complete: true,
1888            thought_signature: None,
1889        },
1890    ));
1891    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1892        LanguageModelToolUse {
1893            id: "terminal_tool_2".into(),
1894            name: "terminal".into(),
1895            raw_input: r#"{"command": "sleep 2000", "cd": "."}"#.into(),
1896            input: json!({"command": "sleep 2000", "cd": "."}),
1897            is_input_complete: true,
1898            thought_signature: None,
1899        },
1900    ));
1901    fake_model.end_last_completion_stream();
1902
1903    // Wait for both terminal tools to start by counting terminal content updates
1904    let mut terminals_started = 0;
1905    let deadline = cx.executor().num_cpus() * 100;
1906    for _ in 0..deadline {
1907        cx.run_until_parked();
1908
1909        while let Some(Some(event)) = events.next().now_or_never() {
1910            if let Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
1911                update,
1912            ))) = &event
1913            {
1914                if update.fields.content.as_ref().is_some_and(|content| {
1915                    content
1916                        .iter()
1917                        .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
1918                }) {
1919                    terminals_started += 1;
1920                    if terminals_started >= 2 {
1921                        break;
1922                    }
1923                }
1924            }
1925        }
1926        if terminals_started >= 2 {
1927            break;
1928        }
1929
1930        cx.background_executor
1931            .timer(Duration::from_millis(10))
1932            .await;
1933    }
1934    assert!(
1935        terminals_started >= 2,
1936        "expected 2 terminal tools to start, got {terminals_started}"
1937    );
1938
1939    // Cancel the thread while both terminals are running
1940    thread.update(cx, |thread, cx| thread.cancel(cx)).detach();
1941
1942    // Collect remaining events
1943    let remaining_events = collect_events_until_stop(&mut events, cx).await;
1944
1945    // Verify both terminal handles were killed
1946    let handles = environment.handles();
1947    assert_eq!(
1948        handles.len(),
1949        2,
1950        "expected 2 terminal handles to be created"
1951    );
1952    assert!(
1953        handles[0].was_killed(),
1954        "expected first terminal handle to be killed on cancellation"
1955    );
1956    assert!(
1957        handles[1].was_killed(),
1958        "expected second terminal handle to be killed on cancellation"
1959    );
1960
1961    // Verify we got a cancellation stop event
1962    assert_eq!(
1963        stop_events(remaining_events),
1964        vec![acp::StopReason::Cancelled],
1965    );
1966}
1967
1968#[gpui::test]
1969async fn test_terminal_tool_stopped_via_terminal_card_button(cx: &mut TestAppContext) {
1970    // Tests that clicking the stop button on the terminal card (as opposed to the main
1971    // cancel button) properly reports user stopped via the was_stopped_by_user path.
1972    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1973    always_allow_tools(cx);
1974    let fake_model = model.as_fake();
1975
1976    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
1977    let environment = Rc::new(FakeThreadEnvironment {
1978        handle: handle.clone(),
1979    });
1980
1981    let mut events = thread
1982        .update(cx, |thread, cx| {
1983            thread.add_tool(crate::TerminalTool::new(
1984                thread.project().clone(),
1985                environment,
1986            ));
1987            thread.send(UserMessageId::new(), ["run a command"], cx)
1988        })
1989        .unwrap();
1990
1991    cx.run_until_parked();
1992
1993    // Simulate the model calling the terminal tool
1994    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1995        LanguageModelToolUse {
1996            id: "terminal_tool_1".into(),
1997            name: "terminal".into(),
1998            raw_input: r#"{"command": "sleep 1000", "cd": "."}"#.into(),
1999            input: json!({"command": "sleep 1000", "cd": "."}),
2000            is_input_complete: true,
2001            thought_signature: None,
2002        },
2003    ));
2004    fake_model.end_last_completion_stream();
2005
2006    // Wait for the terminal tool to start running
2007    wait_for_terminal_tool_started(&mut events, cx).await;
2008
2009    // Simulate user clicking stop on the terminal card itself.
2010    // This sets the flag and signals exit (simulating what the real UI would do).
2011    handle.set_stopped_by_user(true);
2012    handle.killed.store(true, Ordering::SeqCst);
2013    handle.signal_exit();
2014
2015    // Wait for the tool to complete
2016    cx.run_until_parked();
2017
2018    // The thread continues after tool completion - simulate the model ending its turn
2019    fake_model
2020        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2021    fake_model.end_last_completion_stream();
2022
2023    // Collect remaining events
2024    let remaining_events = collect_events_until_stop(&mut events, cx).await;
2025
2026    // Verify we got an EndTurn (not Cancelled, since we didn't cancel the thread)
2027    assert_eq!(
2028        stop_events(remaining_events),
2029        vec![acp::StopReason::EndTurn],
2030    );
2031
2032    // Verify the tool result indicates user stopped
2033    thread.update(cx, |thread, _cx| {
2034        let message = thread.last_message().unwrap();
2035        let agent_message = message.as_agent_message().unwrap();
2036
2037        let tool_use = agent_message
2038            .content
2039            .iter()
2040            .find_map(|content| match content {
2041                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2042                _ => None,
2043            })
2044            .expect("expected tool use in agent message");
2045
2046        let tool_result = agent_message
2047            .tool_results
2048            .get(&tool_use.id)
2049            .expect("expected tool result");
2050
2051        let result_text = match &tool_result.content {
2052            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2053            _ => panic!("expected text content in tool result"),
2054        };
2055
2056        assert!(
2057            result_text.contains("The user stopped this command"),
2058            "expected tool result to indicate user stopped, got: {result_text}"
2059        );
2060    });
2061}
2062
2063#[gpui::test]
2064async fn test_terminal_tool_timeout_expires(cx: &mut TestAppContext) {
2065    // Tests that when a timeout is configured and expires, the tool result indicates timeout.
2066    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2067    always_allow_tools(cx);
2068    let fake_model = model.as_fake();
2069
2070    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
2071    let environment = Rc::new(FakeThreadEnvironment {
2072        handle: handle.clone(),
2073    });
2074
2075    let mut events = thread
2076        .update(cx, |thread, cx| {
2077            thread.add_tool(crate::TerminalTool::new(
2078                thread.project().clone(),
2079                environment,
2080            ));
2081            thread.send(UserMessageId::new(), ["run a command with timeout"], cx)
2082        })
2083        .unwrap();
2084
2085    cx.run_until_parked();
2086
2087    // Simulate the model calling the terminal tool with a short timeout
2088    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2089        LanguageModelToolUse {
2090            id: "terminal_tool_1".into(),
2091            name: "terminal".into(),
2092            raw_input: r#"{"command": "sleep 1000", "cd": ".", "timeout_ms": 100}"#.into(),
2093            input: json!({"command": "sleep 1000", "cd": ".", "timeout_ms": 100}),
2094            is_input_complete: true,
2095            thought_signature: None,
2096        },
2097    ));
2098    fake_model.end_last_completion_stream();
2099
2100    // Wait for the terminal tool to start running
2101    wait_for_terminal_tool_started(&mut events, cx).await;
2102
2103    // Advance clock past the timeout
2104    cx.executor().advance_clock(Duration::from_millis(200));
2105    cx.run_until_parked();
2106
2107    // The thread continues after tool completion - simulate the model ending its turn
2108    fake_model
2109        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2110    fake_model.end_last_completion_stream();
2111
2112    // Collect remaining events
2113    let remaining_events = collect_events_until_stop(&mut events, cx).await;
2114
2115    // Verify the terminal was killed due to timeout
2116    assert!(
2117        handle.was_killed(),
2118        "expected terminal handle to be killed on timeout"
2119    );
2120
2121    // Verify we got an EndTurn (the tool completed, just with timeout)
2122    assert_eq!(
2123        stop_events(remaining_events),
2124        vec![acp::StopReason::EndTurn],
2125    );
2126
2127    // Verify the tool result indicates timeout, not user stopped
2128    thread.update(cx, |thread, _cx| {
2129        let message = thread.last_message().unwrap();
2130        let agent_message = message.as_agent_message().unwrap();
2131
2132        let tool_use = agent_message
2133            .content
2134            .iter()
2135            .find_map(|content| match content {
2136                AgentMessageContent::ToolUse(tool_use) => Some(tool_use),
2137                _ => None,
2138            })
2139            .expect("expected tool use in agent message");
2140
2141        let tool_result = agent_message
2142            .tool_results
2143            .get(&tool_use.id)
2144            .expect("expected tool result");
2145
2146        let result_text = match &tool_result.content {
2147            language_model::LanguageModelToolResultContent::Text(text) => text.to_string(),
2148            _ => panic!("expected text content in tool result"),
2149        };
2150
2151        assert!(
2152            result_text.contains("timed out"),
2153            "expected tool result to indicate timeout, got: {result_text}"
2154        );
2155        assert!(
2156            !result_text.contains("The user stopped"),
2157            "tool result should not mention user stopped when it timed out, got: {result_text}"
2158        );
2159    });
2160}
2161
2162#[gpui::test]
2163async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
2164    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2165    let fake_model = model.as_fake();
2166
2167    let events_1 = thread
2168        .update(cx, |thread, cx| {
2169            thread.send(UserMessageId::new(), ["Hello 1"], cx)
2170        })
2171        .unwrap();
2172    cx.run_until_parked();
2173    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2174    cx.run_until_parked();
2175
2176    let events_2 = thread
2177        .update(cx, |thread, cx| {
2178            thread.send(UserMessageId::new(), ["Hello 2"], cx)
2179        })
2180        .unwrap();
2181    cx.run_until_parked();
2182    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2183    fake_model
2184        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2185    fake_model.end_last_completion_stream();
2186
2187    let events_1 = events_1.collect::<Vec<_>>().await;
2188    assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
2189    let events_2 = events_2.collect::<Vec<_>>().await;
2190    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2191}
2192
2193#[gpui::test]
2194async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
2195    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2196    let fake_model = model.as_fake();
2197
2198    let events_1 = thread
2199        .update(cx, |thread, cx| {
2200            thread.send(UserMessageId::new(), ["Hello 1"], cx)
2201        })
2202        .unwrap();
2203    cx.run_until_parked();
2204    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
2205    fake_model
2206        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2207    fake_model.end_last_completion_stream();
2208    let events_1 = events_1.collect::<Vec<_>>().await;
2209
2210    let events_2 = thread
2211        .update(cx, |thread, cx| {
2212            thread.send(UserMessageId::new(), ["Hello 2"], cx)
2213        })
2214        .unwrap();
2215    cx.run_until_parked();
2216    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
2217    fake_model
2218        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
2219    fake_model.end_last_completion_stream();
2220    let events_2 = events_2.collect::<Vec<_>>().await;
2221
2222    assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
2223    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
2224}
2225
2226#[gpui::test]
2227async fn test_refusal(cx: &mut TestAppContext) {
2228    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2229    let fake_model = model.as_fake();
2230
2231    let events = thread
2232        .update(cx, |thread, cx| {
2233            thread.send(UserMessageId::new(), ["Hello"], cx)
2234        })
2235        .unwrap();
2236    cx.run_until_parked();
2237    thread.read_with(cx, |thread, _| {
2238        assert_eq!(
2239            thread.to_markdown(),
2240            indoc! {"
2241                ## User
2242
2243                Hello
2244            "}
2245        );
2246    });
2247
2248    fake_model.send_last_completion_stream_text_chunk("Hey!");
2249    cx.run_until_parked();
2250    thread.read_with(cx, |thread, _| {
2251        assert_eq!(
2252            thread.to_markdown(),
2253            indoc! {"
2254                ## User
2255
2256                Hello
2257
2258                ## Assistant
2259
2260                Hey!
2261            "}
2262        );
2263    });
2264
2265    // If the model refuses to continue, the thread should remove all the messages after the last user message.
2266    fake_model
2267        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
2268    let events = events.collect::<Vec<_>>().await;
2269    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
2270    thread.read_with(cx, |thread, _| {
2271        assert_eq!(thread.to_markdown(), "");
2272    });
2273}
2274
2275#[gpui::test]
2276async fn test_truncate_first_message(cx: &mut TestAppContext) {
2277    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2278    let fake_model = model.as_fake();
2279
2280    let message_id = UserMessageId::new();
2281    thread
2282        .update(cx, |thread, cx| {
2283            thread.send(message_id.clone(), ["Hello"], cx)
2284        })
2285        .unwrap();
2286    cx.run_until_parked();
2287    thread.read_with(cx, |thread, _| {
2288        assert_eq!(
2289            thread.to_markdown(),
2290            indoc! {"
2291                ## User
2292
2293                Hello
2294            "}
2295        );
2296        assert_eq!(thread.latest_token_usage(), None);
2297    });
2298
2299    fake_model.send_last_completion_stream_text_chunk("Hey!");
2300    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2301        language_model::TokenUsage {
2302            input_tokens: 32_000,
2303            output_tokens: 16_000,
2304            cache_creation_input_tokens: 0,
2305            cache_read_input_tokens: 0,
2306        },
2307    ));
2308    cx.run_until_parked();
2309    thread.read_with(cx, |thread, _| {
2310        assert_eq!(
2311            thread.to_markdown(),
2312            indoc! {"
2313                ## User
2314
2315                Hello
2316
2317                ## Assistant
2318
2319                Hey!
2320            "}
2321        );
2322        assert_eq!(
2323            thread.latest_token_usage(),
2324            Some(acp_thread::TokenUsage {
2325                used_tokens: 32_000 + 16_000,
2326                max_tokens: 1_000_000,
2327                input_tokens: 32_000,
2328                output_tokens: 16_000,
2329            })
2330        );
2331    });
2332
2333    thread
2334        .update(cx, |thread, cx| thread.truncate(message_id, cx))
2335        .unwrap();
2336    cx.run_until_parked();
2337    thread.read_with(cx, |thread, _| {
2338        assert_eq!(thread.to_markdown(), "");
2339        assert_eq!(thread.latest_token_usage(), None);
2340    });
2341
2342    // Ensure we can still send a new message after truncation.
2343    thread
2344        .update(cx, |thread, cx| {
2345            thread.send(UserMessageId::new(), ["Hi"], cx)
2346        })
2347        .unwrap();
2348    thread.update(cx, |thread, _cx| {
2349        assert_eq!(
2350            thread.to_markdown(),
2351            indoc! {"
2352                ## User
2353
2354                Hi
2355            "}
2356        );
2357    });
2358    cx.run_until_parked();
2359    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
2360    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2361        language_model::TokenUsage {
2362            input_tokens: 40_000,
2363            output_tokens: 20_000,
2364            cache_creation_input_tokens: 0,
2365            cache_read_input_tokens: 0,
2366        },
2367    ));
2368    cx.run_until_parked();
2369    thread.read_with(cx, |thread, _| {
2370        assert_eq!(
2371            thread.to_markdown(),
2372            indoc! {"
2373                ## User
2374
2375                Hi
2376
2377                ## Assistant
2378
2379                Ahoy!
2380            "}
2381        );
2382
2383        assert_eq!(
2384            thread.latest_token_usage(),
2385            Some(acp_thread::TokenUsage {
2386                used_tokens: 40_000 + 20_000,
2387                max_tokens: 1_000_000,
2388                input_tokens: 40_000,
2389                output_tokens: 20_000,
2390            })
2391        );
2392    });
2393}
2394
2395#[gpui::test]
2396async fn test_truncate_second_message(cx: &mut TestAppContext) {
2397    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2398    let fake_model = model.as_fake();
2399
2400    thread
2401        .update(cx, |thread, cx| {
2402            thread.send(UserMessageId::new(), ["Message 1"], cx)
2403        })
2404        .unwrap();
2405    cx.run_until_parked();
2406    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
2407    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2408        language_model::TokenUsage {
2409            input_tokens: 32_000,
2410            output_tokens: 16_000,
2411            cache_creation_input_tokens: 0,
2412            cache_read_input_tokens: 0,
2413        },
2414    ));
2415    fake_model.end_last_completion_stream();
2416    cx.run_until_parked();
2417
2418    let assert_first_message_state = |cx: &mut TestAppContext| {
2419        thread.clone().read_with(cx, |thread, _| {
2420            assert_eq!(
2421                thread.to_markdown(),
2422                indoc! {"
2423                    ## User
2424
2425                    Message 1
2426
2427                    ## Assistant
2428
2429                    Message 1 response
2430                "}
2431            );
2432
2433            assert_eq!(
2434                thread.latest_token_usage(),
2435                Some(acp_thread::TokenUsage {
2436                    used_tokens: 32_000 + 16_000,
2437                    max_tokens: 1_000_000,
2438                    input_tokens: 32_000,
2439                    output_tokens: 16_000,
2440                })
2441            );
2442        });
2443    };
2444
2445    assert_first_message_state(cx);
2446
2447    let second_message_id = UserMessageId::new();
2448    thread
2449        .update(cx, |thread, cx| {
2450            thread.send(second_message_id.clone(), ["Message 2"], cx)
2451        })
2452        .unwrap();
2453    cx.run_until_parked();
2454
2455    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
2456    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
2457        language_model::TokenUsage {
2458            input_tokens: 40_000,
2459            output_tokens: 20_000,
2460            cache_creation_input_tokens: 0,
2461            cache_read_input_tokens: 0,
2462        },
2463    ));
2464    fake_model.end_last_completion_stream();
2465    cx.run_until_parked();
2466
2467    thread.read_with(cx, |thread, _| {
2468        assert_eq!(
2469            thread.to_markdown(),
2470            indoc! {"
2471                ## User
2472
2473                Message 1
2474
2475                ## Assistant
2476
2477                Message 1 response
2478
2479                ## User
2480
2481                Message 2
2482
2483                ## Assistant
2484
2485                Message 2 response
2486            "}
2487        );
2488
2489        assert_eq!(
2490            thread.latest_token_usage(),
2491            Some(acp_thread::TokenUsage {
2492                used_tokens: 40_000 + 20_000,
2493                max_tokens: 1_000_000,
2494                input_tokens: 40_000,
2495                output_tokens: 20_000,
2496            })
2497        );
2498    });
2499
2500    thread
2501        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
2502        .unwrap();
2503    cx.run_until_parked();
2504
2505    assert_first_message_state(cx);
2506}
2507
2508#[gpui::test]
2509async fn test_title_generation(cx: &mut TestAppContext) {
2510    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2511    let fake_model = model.as_fake();
2512
2513    let summary_model = Arc::new(FakeLanguageModel::default());
2514    thread.update(cx, |thread, cx| {
2515        thread.set_summarization_model(Some(summary_model.clone()), cx)
2516    });
2517
2518    let send = thread
2519        .update(cx, |thread, cx| {
2520            thread.send(UserMessageId::new(), ["Hello"], cx)
2521        })
2522        .unwrap();
2523    cx.run_until_parked();
2524
2525    fake_model.send_last_completion_stream_text_chunk("Hey!");
2526    fake_model.end_last_completion_stream();
2527    cx.run_until_parked();
2528    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
2529
2530    // Ensure the summary model has been invoked to generate a title.
2531    summary_model.send_last_completion_stream_text_chunk("Hello ");
2532    summary_model.send_last_completion_stream_text_chunk("world\nG");
2533    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
2534    summary_model.end_last_completion_stream();
2535    send.collect::<Vec<_>>().await;
2536    cx.run_until_parked();
2537    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2538
2539    // Send another message, ensuring no title is generated this time.
2540    let send = thread
2541        .update(cx, |thread, cx| {
2542            thread.send(UserMessageId::new(), ["Hello again"], cx)
2543        })
2544        .unwrap();
2545    cx.run_until_parked();
2546    fake_model.send_last_completion_stream_text_chunk("Hey again!");
2547    fake_model.end_last_completion_stream();
2548    cx.run_until_parked();
2549    assert_eq!(summary_model.pending_completions(), Vec::new());
2550    send.collect::<Vec<_>>().await;
2551    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
2552}
2553
2554#[gpui::test]
2555async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
2556    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
2557    let fake_model = model.as_fake();
2558
2559    let _events = thread
2560        .update(cx, |thread, cx| {
2561            thread.add_tool(ToolRequiringPermission);
2562            thread.add_tool(EchoTool);
2563            thread.send(UserMessageId::new(), ["Hey!"], cx)
2564        })
2565        .unwrap();
2566    cx.run_until_parked();
2567
2568    let permission_tool_use = LanguageModelToolUse {
2569        id: "tool_id_1".into(),
2570        name: ToolRequiringPermission::name().into(),
2571        raw_input: "{}".into(),
2572        input: json!({}),
2573        is_input_complete: true,
2574        thought_signature: None,
2575    };
2576    let echo_tool_use = LanguageModelToolUse {
2577        id: "tool_id_2".into(),
2578        name: EchoTool::name().into(),
2579        raw_input: json!({"text": "test"}).to_string(),
2580        input: json!({"text": "test"}),
2581        is_input_complete: true,
2582        thought_signature: None,
2583    };
2584    fake_model.send_last_completion_stream_text_chunk("Hi!");
2585    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2586        permission_tool_use,
2587    ));
2588    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2589        echo_tool_use.clone(),
2590    ));
2591    fake_model.end_last_completion_stream();
2592    cx.run_until_parked();
2593
2594    // Ensure pending tools are skipped when building a request.
2595    let request = thread
2596        .read_with(cx, |thread, cx| {
2597            thread.build_completion_request(CompletionIntent::EditFile, cx)
2598        })
2599        .unwrap();
2600    assert_eq!(
2601        request.messages[1..],
2602        vec![
2603            LanguageModelRequestMessage {
2604                role: Role::User,
2605                content: vec!["Hey!".into()],
2606                cache: true,
2607                reasoning_details: None,
2608            },
2609            LanguageModelRequestMessage {
2610                role: Role::Assistant,
2611                content: vec![
2612                    MessageContent::Text("Hi!".into()),
2613                    MessageContent::ToolUse(echo_tool_use.clone())
2614                ],
2615                cache: false,
2616                reasoning_details: None,
2617            },
2618            LanguageModelRequestMessage {
2619                role: Role::User,
2620                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
2621                    tool_use_id: echo_tool_use.id.clone(),
2622                    tool_name: echo_tool_use.name,
2623                    is_error: false,
2624                    content: "test".into(),
2625                    output: Some("test".into())
2626                })],
2627                cache: false,
2628                reasoning_details: None,
2629            },
2630        ],
2631    );
2632}
2633
2634#[gpui::test]
2635async fn test_agent_connection(cx: &mut TestAppContext) {
2636    cx.update(settings::init);
2637    let templates = Templates::new();
2638
2639    // Initialize language model system with test provider
2640    cx.update(|cx| {
2641        gpui_tokio::init(cx);
2642
2643        let http_client = FakeHttpClient::with_404_response();
2644        let clock = Arc::new(clock::FakeSystemClock::new());
2645        let client = Client::new(clock, http_client, cx);
2646        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
2647        language_model::init(client.clone(), cx);
2648        language_models::init(user_store, client.clone(), cx);
2649        LanguageModelRegistry::test(cx);
2650    });
2651    cx.executor().forbid_parking();
2652
2653    // Create a project for new_thread
2654    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
2655    fake_fs.insert_tree(path!("/test"), json!({})).await;
2656    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
2657    let cwd = Path::new("/test");
2658    let thread_store = cx.new(|cx| ThreadStore::new(cx));
2659
2660    // Create agent and connection
2661    let agent = NativeAgent::new(
2662        project.clone(),
2663        thread_store,
2664        templates.clone(),
2665        None,
2666        fake_fs.clone(),
2667        &mut cx.to_async(),
2668    )
2669    .await
2670    .unwrap();
2671    let connection = NativeAgentConnection(agent.clone());
2672
2673    // Create a thread using new_thread
2674    let connection_rc = Rc::new(connection.clone());
2675    let acp_thread = cx
2676        .update(|cx| connection_rc.new_thread(project, cwd, cx))
2677        .await
2678        .expect("new_thread should succeed");
2679
2680    // Get the session_id from the AcpThread
2681    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
2682
2683    // Test model_selector returns Some
2684    let selector_opt = connection.model_selector(&session_id);
2685    assert!(
2686        selector_opt.is_some(),
2687        "agent should always support ModelSelector"
2688    );
2689    let selector = selector_opt.unwrap();
2690
2691    // Test list_models
2692    let listed_models = cx
2693        .update(|cx| selector.list_models(cx))
2694        .await
2695        .expect("list_models should succeed");
2696    let AgentModelList::Grouped(listed_models) = listed_models else {
2697        panic!("Unexpected model list type");
2698    };
2699    assert!(!listed_models.is_empty(), "should have at least one model");
2700    assert_eq!(
2701        listed_models[&AgentModelGroupName("Fake".into())][0]
2702            .id
2703            .0
2704            .as_ref(),
2705        "fake/fake"
2706    );
2707
2708    // Test selected_model returns the default
2709    let model = cx
2710        .update(|cx| selector.selected_model(cx))
2711        .await
2712        .expect("selected_model should succeed");
2713    let model = cx
2714        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
2715        .unwrap();
2716    let model = model.as_fake();
2717    assert_eq!(model.id().0, "fake", "should return default model");
2718
2719    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
2720    cx.run_until_parked();
2721    model.send_last_completion_stream_text_chunk("def");
2722    cx.run_until_parked();
2723    acp_thread.read_with(cx, |thread, cx| {
2724        assert_eq!(
2725            thread.to_markdown(cx),
2726            indoc! {"
2727                ## User
2728
2729                abc
2730
2731                ## Assistant
2732
2733                def
2734
2735            "}
2736        )
2737    });
2738
2739    // Test cancel
2740    cx.update(|cx| connection.cancel(&session_id, cx));
2741    request.await.expect("prompt should fail gracefully");
2742
2743    // Ensure that dropping the ACP thread causes the native thread to be
2744    // dropped as well.
2745    cx.update(|_| drop(acp_thread));
2746    let result = cx
2747        .update(|cx| {
2748            connection.prompt(
2749                Some(acp_thread::UserMessageId::new()),
2750                acp::PromptRequest::new(session_id.clone(), vec!["ghi".into()]),
2751                cx,
2752            )
2753        })
2754        .await;
2755    assert_eq!(
2756        result.as_ref().unwrap_err().to_string(),
2757        "Session not found",
2758        "unexpected result: {:?}",
2759        result
2760    );
2761}
2762
2763#[gpui::test]
2764async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
2765    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2766    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
2767    let fake_model = model.as_fake();
2768
2769    let mut events = thread
2770        .update(cx, |thread, cx| {
2771            thread.send(UserMessageId::new(), ["Think"], cx)
2772        })
2773        .unwrap();
2774    cx.run_until_parked();
2775
2776    // Simulate streaming partial input.
2777    let input = json!({});
2778    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2779        LanguageModelToolUse {
2780            id: "1".into(),
2781            name: ThinkingTool::name().into(),
2782            raw_input: input.to_string(),
2783            input,
2784            is_input_complete: false,
2785            thought_signature: None,
2786        },
2787    ));
2788
2789    // Input streaming completed
2790    let input = json!({ "content": "Thinking hard!" });
2791    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2792        LanguageModelToolUse {
2793            id: "1".into(),
2794            name: "thinking".into(),
2795            raw_input: input.to_string(),
2796            input,
2797            is_input_complete: true,
2798            thought_signature: None,
2799        },
2800    ));
2801    fake_model.end_last_completion_stream();
2802    cx.run_until_parked();
2803
2804    let tool_call = expect_tool_call(&mut events).await;
2805    assert_eq!(
2806        tool_call,
2807        acp::ToolCall::new("1", "Thinking")
2808            .kind(acp::ToolKind::Think)
2809            .raw_input(json!({}))
2810            .meta(acp::Meta::from_iter([(
2811                "tool_name".into(),
2812                "thinking".into()
2813            )]))
2814    );
2815    let update = expect_tool_call_update_fields(&mut events).await;
2816    assert_eq!(
2817        update,
2818        acp::ToolCallUpdate::new(
2819            "1",
2820            acp::ToolCallUpdateFields::new()
2821                .title("Thinking")
2822                .kind(acp::ToolKind::Think)
2823                .raw_input(json!({ "content": "Thinking hard!"}))
2824        )
2825    );
2826    let update = expect_tool_call_update_fields(&mut events).await;
2827    assert_eq!(
2828        update,
2829        acp::ToolCallUpdate::new(
2830            "1",
2831            acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
2832        )
2833    );
2834    let update = expect_tool_call_update_fields(&mut events).await;
2835    assert_eq!(
2836        update,
2837        acp::ToolCallUpdate::new(
2838            "1",
2839            acp::ToolCallUpdateFields::new().content(vec!["Thinking hard!".into()])
2840        )
2841    );
2842    let update = expect_tool_call_update_fields(&mut events).await;
2843    assert_eq!(
2844        update,
2845        acp::ToolCallUpdate::new(
2846            "1",
2847            acp::ToolCallUpdateFields::new()
2848                .status(acp::ToolCallStatus::Completed)
2849                .raw_output("Finished thinking.")
2850        )
2851    );
2852}
2853
2854#[gpui::test]
2855async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
2856    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2857    let fake_model = model.as_fake();
2858
2859    let mut events = thread
2860        .update(cx, |thread, cx| {
2861            thread.send(UserMessageId::new(), ["Hello!"], cx)
2862        })
2863        .unwrap();
2864    cx.run_until_parked();
2865
2866    fake_model.send_last_completion_stream_text_chunk("Hey!");
2867    fake_model.end_last_completion_stream();
2868
2869    let mut retry_events = Vec::new();
2870    while let Some(Ok(event)) = events.next().await {
2871        match event {
2872            ThreadEvent::Retry(retry_status) => {
2873                retry_events.push(retry_status);
2874            }
2875            ThreadEvent::Stop(..) => break,
2876            _ => {}
2877        }
2878    }
2879
2880    assert_eq!(retry_events.len(), 0);
2881    thread.read_with(cx, |thread, _cx| {
2882        assert_eq!(
2883            thread.to_markdown(),
2884            indoc! {"
2885                ## User
2886
2887                Hello!
2888
2889                ## Assistant
2890
2891                Hey!
2892            "}
2893        )
2894    });
2895}
2896
2897#[gpui::test]
2898async fn test_send_retry_on_error(cx: &mut TestAppContext) {
2899    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2900    let fake_model = model.as_fake();
2901
2902    let mut events = thread
2903        .update(cx, |thread, cx| {
2904            thread.send(UserMessageId::new(), ["Hello!"], cx)
2905        })
2906        .unwrap();
2907    cx.run_until_parked();
2908
2909    fake_model.send_last_completion_stream_text_chunk("Hey,");
2910    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
2911        provider: LanguageModelProviderName::new("Anthropic"),
2912        retry_after: Some(Duration::from_secs(3)),
2913    });
2914    fake_model.end_last_completion_stream();
2915
2916    cx.executor().advance_clock(Duration::from_secs(3));
2917    cx.run_until_parked();
2918
2919    fake_model.send_last_completion_stream_text_chunk("there!");
2920    fake_model.end_last_completion_stream();
2921    cx.run_until_parked();
2922
2923    let mut retry_events = Vec::new();
2924    while let Some(Ok(event)) = events.next().await {
2925        match event {
2926            ThreadEvent::Retry(retry_status) => {
2927                retry_events.push(retry_status);
2928            }
2929            ThreadEvent::Stop(..) => break,
2930            _ => {}
2931        }
2932    }
2933
2934    assert_eq!(retry_events.len(), 1);
2935    assert!(matches!(
2936        retry_events[0],
2937        acp_thread::RetryStatus { attempt: 1, .. }
2938    ));
2939    thread.read_with(cx, |thread, _cx| {
2940        assert_eq!(
2941            thread.to_markdown(),
2942            indoc! {"
2943                ## User
2944
2945                Hello!
2946
2947                ## Assistant
2948
2949                Hey,
2950
2951                [resume]
2952
2953                ## Assistant
2954
2955                there!
2956            "}
2957        )
2958    });
2959}
2960
2961#[gpui::test]
2962async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
2963    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2964    let fake_model = model.as_fake();
2965
2966    let events = thread
2967        .update(cx, |thread, cx| {
2968            thread.add_tool(EchoTool);
2969            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
2970        })
2971        .unwrap();
2972    cx.run_until_parked();
2973
2974    let tool_use_1 = LanguageModelToolUse {
2975        id: "tool_1".into(),
2976        name: EchoTool::name().into(),
2977        raw_input: json!({"text": "test"}).to_string(),
2978        input: json!({"text": "test"}),
2979        is_input_complete: true,
2980        thought_signature: None,
2981    };
2982    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2983        tool_use_1.clone(),
2984    ));
2985    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
2986        provider: LanguageModelProviderName::new("Anthropic"),
2987        retry_after: Some(Duration::from_secs(3)),
2988    });
2989    fake_model.end_last_completion_stream();
2990
2991    cx.executor().advance_clock(Duration::from_secs(3));
2992    let completion = fake_model.pending_completions().pop().unwrap();
2993    assert_eq!(
2994        completion.messages[1..],
2995        vec![
2996            LanguageModelRequestMessage {
2997                role: Role::User,
2998                content: vec!["Call the echo tool!".into()],
2999                cache: false,
3000                reasoning_details: None,
3001            },
3002            LanguageModelRequestMessage {
3003                role: Role::Assistant,
3004                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
3005                cache: false,
3006                reasoning_details: None,
3007            },
3008            LanguageModelRequestMessage {
3009                role: Role::User,
3010                content: vec![language_model::MessageContent::ToolResult(
3011                    LanguageModelToolResult {
3012                        tool_use_id: tool_use_1.id.clone(),
3013                        tool_name: tool_use_1.name.clone(),
3014                        is_error: false,
3015                        content: "test".into(),
3016                        output: Some("test".into())
3017                    }
3018                )],
3019                cache: true,
3020                reasoning_details: None,
3021            },
3022        ]
3023    );
3024
3025    fake_model.send_last_completion_stream_text_chunk("Done");
3026    fake_model.end_last_completion_stream();
3027    cx.run_until_parked();
3028    events.collect::<Vec<_>>().await;
3029    thread.read_with(cx, |thread, _cx| {
3030        assert_eq!(
3031            thread.last_message(),
3032            Some(Message::Agent(AgentMessage {
3033                content: vec![AgentMessageContent::Text("Done".into())],
3034                tool_results: IndexMap::default(),
3035                reasoning_details: None,
3036            }))
3037        );
3038    })
3039}
3040
3041#[gpui::test]
3042async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
3043    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
3044    let fake_model = model.as_fake();
3045
3046    let mut events = thread
3047        .update(cx, |thread, cx| {
3048            thread.send(UserMessageId::new(), ["Hello!"], cx)
3049        })
3050        .unwrap();
3051    cx.run_until_parked();
3052
3053    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
3054        fake_model.send_last_completion_stream_error(
3055            LanguageModelCompletionError::ServerOverloaded {
3056                provider: LanguageModelProviderName::new("Anthropic"),
3057                retry_after: Some(Duration::from_secs(3)),
3058            },
3059        );
3060        fake_model.end_last_completion_stream();
3061        cx.executor().advance_clock(Duration::from_secs(3));
3062        cx.run_until_parked();
3063    }
3064
3065    let mut errors = Vec::new();
3066    let mut retry_events = Vec::new();
3067    while let Some(event) = events.next().await {
3068        match event {
3069            Ok(ThreadEvent::Retry(retry_status)) => {
3070                retry_events.push(retry_status);
3071            }
3072            Ok(ThreadEvent::Stop(..)) => break,
3073            Err(error) => errors.push(error),
3074            _ => {}
3075        }
3076    }
3077
3078    assert_eq!(
3079        retry_events.len(),
3080        crate::thread::MAX_RETRY_ATTEMPTS as usize
3081    );
3082    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
3083        assert_eq!(retry_events[i].attempt, i + 1);
3084    }
3085    assert_eq!(errors.len(), 1);
3086    let error = errors[0]
3087        .downcast_ref::<LanguageModelCompletionError>()
3088        .unwrap();
3089    assert!(matches!(
3090        error,
3091        LanguageModelCompletionError::ServerOverloaded { .. }
3092    ));
3093}
3094
3095/// Filters out the stop events for asserting against in tests
3096fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
3097    result_events
3098        .into_iter()
3099        .filter_map(|event| match event.unwrap() {
3100            ThreadEvent::Stop(stop_reason) => Some(stop_reason),
3101            _ => None,
3102        })
3103        .collect()
3104}
3105
/// Bundle of handles returned by `setup` for tests to destructure.
struct ThreadTest {
    /// The language model that was handed to the thread at construction.
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// The project context entity the thread was created with.
    project_context: Entity<ProjectContext>,
    /// The context server store taken from the test project.
    context_server_store: Entity<ContextServerStore>,
    /// The fake filesystem holding the settings file and the project tree.
    fs: Arc<FakeFs>,
}
3113
/// Selects which language model `setup` wires into the thread.
enum TestModel {
    /// A model looked up in the `LanguageModelRegistry` by id; `setup`
    /// installs a real HTTP client and authenticates its provider.
    Sonnet4,
    /// A `FakeLanguageModel` driven manually by the test.
    Fake,
}
3118
3119impl TestModel {
3120    fn id(&self) -> LanguageModelId {
3121        match self {
3122            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
3123            TestModel::Fake => unreachable!(),
3124        }
3125    }
3126}
3127
3128async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
3129    cx.executor().allow_parking();
3130
3131    let fs = FakeFs::new(cx.background_executor.clone());
3132    fs.create_dir(paths::settings_file().parent().unwrap())
3133        .await
3134        .unwrap();
3135    fs.insert_file(
3136        paths::settings_file(),
3137        json!({
3138            "agent": {
3139                "default_profile": "test-profile",
3140                "profiles": {
3141                    "test-profile": {
3142                        "name": "Test Profile",
3143                        "tools": {
3144                            EchoTool::name(): true,
3145                            DelayTool::name(): true,
3146                            WordListTool::name(): true,
3147                            ToolRequiringPermission::name(): true,
3148                            InfiniteTool::name(): true,
3149                            CancellationAwareTool::name(): true,
3150                            ThinkingTool::name(): true,
3151                            "terminal": true,
3152                        }
3153                    }
3154                }
3155            }
3156        })
3157        .to_string()
3158        .into_bytes(),
3159    )
3160    .await;
3161
3162    cx.update(|cx| {
3163        settings::init(cx);
3164
3165        match model {
3166            TestModel::Fake => {}
3167            TestModel::Sonnet4 => {
3168                gpui_tokio::init(cx);
3169                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
3170                cx.set_http_client(Arc::new(http_client));
3171                let client = Client::production(cx);
3172                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
3173                language_model::init(client.clone(), cx);
3174                language_models::init(user_store, client.clone(), cx);
3175            }
3176        };
3177
3178        watch_settings(fs.clone(), cx);
3179    });
3180
3181    let templates = Templates::new();
3182
3183    fs.insert_tree(path!("/test"), json!({})).await;
3184    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
3185
3186    let model = cx
3187        .update(|cx| {
3188            if let TestModel::Fake = model {
3189                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
3190            } else {
3191                let model_id = model.id();
3192                let models = LanguageModelRegistry::read_global(cx);
3193                let model = models
3194                    .available_models(cx)
3195                    .find(|model| model.id() == model_id)
3196                    .unwrap();
3197
3198                let provider = models.provider(&model.provider_id()).unwrap();
3199                let authenticated = provider.authenticate(cx);
3200
3201                cx.spawn(async move |_cx| {
3202                    authenticated.await.unwrap();
3203                    model
3204                })
3205            }
3206        })
3207        .await;
3208
3209    let project_context = cx.new(|_cx| ProjectContext::default());
3210    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3211    let context_server_registry =
3212        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3213    let thread = cx.new(|cx| {
3214        Thread::new(
3215            project,
3216            project_context.clone(),
3217            context_server_registry,
3218            templates,
3219            Some(model.clone()),
3220            cx,
3221        )
3222    });
3223    ThreadTest {
3224        model,
3225        thread,
3226        project_context,
3227        context_server_store,
3228        fs,
3229    }
3230}
3231
/// Constructor hook (via `ctor`) that installs `env_logger`, but only when
/// the `RUST_LOG` environment variable can be read.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}
3239
3240fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
3241    let fs = fs.clone();
3242    cx.spawn({
3243        async move |cx| {
3244            let mut new_settings_content_rx = settings::watch_config_file(
3245                cx.background_executor(),
3246                fs,
3247                paths::settings_file().clone(),
3248            );
3249
3250            while let Some(new_settings_content) = new_settings_content_rx.next().await {
3251                cx.update(|cx| {
3252                    SettingsStore::update_global(cx, |settings, cx| {
3253                        settings.set_user_settings(&new_settings_content, cx)
3254                    })
3255                })
3256                .ok();
3257            }
3258        }
3259    })
3260    .detach();
3261}
3262
3263fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
3264    completion
3265        .tools
3266        .iter()
3267        .map(|tool| tool.name.clone())
3268        .collect()
3269}
3270
3271fn setup_context_server(
3272    name: &'static str,
3273    tools: Vec<context_server::types::Tool>,
3274    context_server_store: &Entity<ContextServerStore>,
3275    cx: &mut TestAppContext,
3276) -> mpsc::UnboundedReceiver<(
3277    context_server::types::CallToolParams,
3278    oneshot::Sender<context_server::types::CallToolResponse>,
3279)> {
3280    cx.update(|cx| {
3281        let mut settings = ProjectSettings::get_global(cx).clone();
3282        settings.context_servers.insert(
3283            name.into(),
3284            project::project_settings::ContextServerSettings::Stdio {
3285                enabled: true,
3286                remote: false,
3287                command: ContextServerCommand {
3288                    path: "somebinary".into(),
3289                    args: Vec::new(),
3290                    env: None,
3291                    timeout: None,
3292                },
3293            },
3294        );
3295        ProjectSettings::override_global(settings, cx);
3296    });
3297
3298    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
3299    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
3300        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
3301            context_server::types::InitializeResponse {
3302                protocol_version: context_server::types::ProtocolVersion(
3303                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
3304                ),
3305                server_info: context_server::types::Implementation {
3306                    name: name.into(),
3307                    version: "1.0.0".to_string(),
3308                },
3309                capabilities: context_server::types::ServerCapabilities {
3310                    tools: Some(context_server::types::ToolsCapabilities {
3311                        list_changed: Some(true),
3312                    }),
3313                    ..Default::default()
3314                },
3315                meta: None,
3316            }
3317        })
3318        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
3319            let tools = tools.clone();
3320            async move {
3321                context_server::types::ListToolsResponse {
3322                    tools,
3323                    next_cursor: None,
3324                    meta: None,
3325                }
3326            }
3327        })
3328        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
3329            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
3330            async move {
3331                let (response_tx, response_rx) = oneshot::channel();
3332                mcp_tool_calls_tx
3333                    .unbounded_send((params, response_tx))
3334                    .unwrap();
3335                response_rx.await.unwrap()
3336            }
3337        });
3338    context_server_store.update(cx, |store, cx| {
3339        store.start_server(
3340            Arc::new(ContextServer::new(
3341                ContextServerId(name.into()),
3342                Arc::new(fake_transport),
3343            )),
3344            cx,
3345        );
3346    });
3347    cx.run_until_parked();
3348    mcp_tool_calls_rx
3349}
3350
3351#[gpui::test]
3352async fn test_tokens_before_message(cx: &mut TestAppContext) {
3353    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3354    let fake_model = model.as_fake();
3355
3356    // First message
3357    let message_1_id = UserMessageId::new();
3358    thread
3359        .update(cx, |thread, cx| {
3360            thread.send(message_1_id.clone(), ["First message"], cx)
3361        })
3362        .unwrap();
3363    cx.run_until_parked();
3364
3365    // Before any response, tokens_before_message should return None for first message
3366    thread.read_with(cx, |thread, _| {
3367        assert_eq!(
3368            thread.tokens_before_message(&message_1_id),
3369            None,
3370            "First message should have no tokens before it"
3371        );
3372    });
3373
3374    // Complete first message with usage
3375    fake_model.send_last_completion_stream_text_chunk("Response 1");
3376    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3377        language_model::TokenUsage {
3378            input_tokens: 100,
3379            output_tokens: 50,
3380            cache_creation_input_tokens: 0,
3381            cache_read_input_tokens: 0,
3382        },
3383    ));
3384    fake_model.end_last_completion_stream();
3385    cx.run_until_parked();
3386
3387    // First message still has no tokens before it
3388    thread.read_with(cx, |thread, _| {
3389        assert_eq!(
3390            thread.tokens_before_message(&message_1_id),
3391            None,
3392            "First message should still have no tokens before it after response"
3393        );
3394    });
3395
3396    // Second message
3397    let message_2_id = UserMessageId::new();
3398    thread
3399        .update(cx, |thread, cx| {
3400            thread.send(message_2_id.clone(), ["Second message"], cx)
3401        })
3402        .unwrap();
3403    cx.run_until_parked();
3404
3405    // Second message should have first message's input tokens before it
3406    thread.read_with(cx, |thread, _| {
3407        assert_eq!(
3408            thread.tokens_before_message(&message_2_id),
3409            Some(100),
3410            "Second message should have 100 tokens before it (from first request)"
3411        );
3412    });
3413
3414    // Complete second message
3415    fake_model.send_last_completion_stream_text_chunk("Response 2");
3416    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3417        language_model::TokenUsage {
3418            input_tokens: 250, // Total for this request (includes previous context)
3419            output_tokens: 75,
3420            cache_creation_input_tokens: 0,
3421            cache_read_input_tokens: 0,
3422        },
3423    ));
3424    fake_model.end_last_completion_stream();
3425    cx.run_until_parked();
3426
3427    // Third message
3428    let message_3_id = UserMessageId::new();
3429    thread
3430        .update(cx, |thread, cx| {
3431            thread.send(message_3_id.clone(), ["Third message"], cx)
3432        })
3433        .unwrap();
3434    cx.run_until_parked();
3435
3436    // Third message should have second message's input tokens (250) before it
3437    thread.read_with(cx, |thread, _| {
3438        assert_eq!(
3439            thread.tokens_before_message(&message_3_id),
3440            Some(250),
3441            "Third message should have 250 tokens before it (from second request)"
3442        );
3443        // Second message should still have 100
3444        assert_eq!(
3445            thread.tokens_before_message(&message_2_id),
3446            Some(100),
3447            "Second message should still have 100 tokens before it"
3448        );
3449        // First message still has none
3450        assert_eq!(
3451            thread.tokens_before_message(&message_1_id),
3452            None,
3453            "First message should still have no tokens before it"
3454        );
3455    });
3456}
3457
3458#[gpui::test]
3459async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
3460    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3461    let fake_model = model.as_fake();
3462
3463    // Set up three messages with responses
3464    let message_1_id = UserMessageId::new();
3465    thread
3466        .update(cx, |thread, cx| {
3467            thread.send(message_1_id.clone(), ["Message 1"], cx)
3468        })
3469        .unwrap();
3470    cx.run_until_parked();
3471    fake_model.send_last_completion_stream_text_chunk("Response 1");
3472    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3473        language_model::TokenUsage {
3474            input_tokens: 100,
3475            output_tokens: 50,
3476            cache_creation_input_tokens: 0,
3477            cache_read_input_tokens: 0,
3478        },
3479    ));
3480    fake_model.end_last_completion_stream();
3481    cx.run_until_parked();
3482
3483    let message_2_id = UserMessageId::new();
3484    thread
3485        .update(cx, |thread, cx| {
3486            thread.send(message_2_id.clone(), ["Message 2"], cx)
3487        })
3488        .unwrap();
3489    cx.run_until_parked();
3490    fake_model.send_last_completion_stream_text_chunk("Response 2");
3491    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
3492        language_model::TokenUsage {
3493            input_tokens: 250,
3494            output_tokens: 75,
3495            cache_creation_input_tokens: 0,
3496            cache_read_input_tokens: 0,
3497        },
3498    ));
3499    fake_model.end_last_completion_stream();
3500    cx.run_until_parked();
3501
3502    // Verify initial state
3503    thread.read_with(cx, |thread, _| {
3504        assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
3505    });
3506
3507    // Truncate at message 2 (removes message 2 and everything after)
3508    thread
3509        .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
3510        .unwrap();
3511    cx.run_until_parked();
3512
3513    // After truncation, message_2_id no longer exists, so lookup should return None
3514    thread.read_with(cx, |thread, _| {
3515        assert_eq!(
3516            thread.tokens_before_message(&message_2_id),
3517            None,
3518            "After truncation, message 2 no longer exists"
3519        );
3520        // Message 1 still exists but has no tokens before it
3521        assert_eq!(
3522            thread.tokens_before_message(&message_1_id),
3523            None,
3524            "First message still has no tokens before it"
3525        );
3526    });
3527}
3528
3529#[gpui::test]
3530async fn test_terminal_tool_permission_rules(cx: &mut TestAppContext) {
3531    init_test(cx);
3532
3533    let fs = FakeFs::new(cx.executor());
3534    fs.insert_tree("/root", json!({})).await;
3535    let project = Project::test(fs, ["/root".as_ref()], cx).await;
3536
3537    // Test 1: Deny rule blocks command
3538    {
3539        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
3540        let environment = Rc::new(FakeThreadEnvironment {
3541            handle: handle.clone(),
3542        });
3543
3544        cx.update(|cx| {
3545            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3546            settings.tool_permissions.tools.insert(
3547                "terminal".into(),
3548                agent_settings::ToolRules {
3549                    default_mode: settings::ToolPermissionMode::Confirm,
3550                    always_allow: vec![],
3551                    always_deny: vec![
3552                        agent_settings::CompiledRegex::new(r"rm\s+-rf", false).unwrap(),
3553                    ],
3554                    always_confirm: vec![],
3555                    invalid_patterns: vec![],
3556                },
3557            );
3558            agent_settings::AgentSettings::override_global(settings, cx);
3559        });
3560
3561        #[allow(clippy::arc_with_non_send_sync)]
3562        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3563        let (event_stream, _rx) = crate::ToolCallEventStream::test();
3564
3565        let task = cx.update(|cx| {
3566            tool.run(
3567                crate::TerminalToolInput {
3568                    command: "rm -rf /".to_string(),
3569                    cd: ".".to_string(),
3570                    timeout_ms: None,
3571                },
3572                event_stream,
3573                cx,
3574            )
3575        });
3576
3577        let result = task.await;
3578        assert!(
3579            result.is_err(),
3580            "expected command to be blocked by deny rule"
3581        );
3582        assert!(
3583            result.unwrap_err().to_string().contains("blocked"),
3584            "error should mention the command was blocked"
3585        );
3586    }
3587
3588    // Test 2: Allow rule skips confirmation (and overrides default_mode: Deny)
3589    {
3590        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_with_immediate_exit(cx, 0)));
3591        let environment = Rc::new(FakeThreadEnvironment {
3592            handle: handle.clone(),
3593        });
3594
3595        cx.update(|cx| {
3596            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3597            settings.always_allow_tool_actions = false;
3598            settings.tool_permissions.tools.insert(
3599                "terminal".into(),
3600                agent_settings::ToolRules {
3601                    default_mode: settings::ToolPermissionMode::Deny,
3602                    always_allow: vec![
3603                        agent_settings::CompiledRegex::new(r"^echo\s", false).unwrap(),
3604                    ],
3605                    always_deny: vec![],
3606                    always_confirm: vec![],
3607                    invalid_patterns: vec![],
3608                },
3609            );
3610            agent_settings::AgentSettings::override_global(settings, cx);
3611        });
3612
3613        #[allow(clippy::arc_with_non_send_sync)]
3614        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3615        let (event_stream, mut rx) = crate::ToolCallEventStream::test();
3616
3617        let task = cx.update(|cx| {
3618            tool.run(
3619                crate::TerminalToolInput {
3620                    command: "echo hello".to_string(),
3621                    cd: ".".to_string(),
3622                    timeout_ms: None,
3623                },
3624                event_stream,
3625                cx,
3626            )
3627        });
3628
3629        let update = rx.expect_update_fields().await;
3630        assert!(
3631            update.content.iter().any(|blocks| {
3632                blocks
3633                    .iter()
3634                    .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
3635            }),
3636            "expected terminal content (allow rule should skip confirmation and override default deny)"
3637        );
3638
3639        let result = task.await;
3640        assert!(
3641            result.is_ok(),
3642            "expected command to succeed without confirmation"
3643        );
3644    }
3645
3646    // Test 3: Confirm rule forces confirmation even with always_allow_tool_actions=true
3647    {
3648        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_with_immediate_exit(cx, 0)));
3649        let environment = Rc::new(FakeThreadEnvironment {
3650            handle: handle.clone(),
3651        });
3652
3653        cx.update(|cx| {
3654            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3655            settings.always_allow_tool_actions = true;
3656            settings.tool_permissions.tools.insert(
3657                "terminal".into(),
3658                agent_settings::ToolRules {
3659                    default_mode: settings::ToolPermissionMode::Allow,
3660                    always_allow: vec![],
3661                    always_deny: vec![],
3662                    always_confirm: vec![
3663                        agent_settings::CompiledRegex::new(r"sudo", false).unwrap(),
3664                    ],
3665                    invalid_patterns: vec![],
3666                },
3667            );
3668            agent_settings::AgentSettings::override_global(settings, cx);
3669        });
3670
3671        #[allow(clippy::arc_with_non_send_sync)]
3672        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3673        let (event_stream, mut rx) = crate::ToolCallEventStream::test();
3674
3675        let _task = cx.update(|cx| {
3676            tool.run(
3677                crate::TerminalToolInput {
3678                    command: "sudo rm file".to_string(),
3679                    cd: ".".to_string(),
3680                    timeout_ms: None,
3681                },
3682                event_stream,
3683                cx,
3684            )
3685        });
3686
3687        let auth = rx.expect_authorization().await;
3688        assert!(
3689            auth.tool_call.fields.title.is_some(),
3690            "expected authorization request for sudo command despite always_allow_tool_actions=true"
3691        );
3692    }
3693
3694    // Test 4: default_mode: Deny blocks commands when no pattern matches
3695    {
3696        let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
3697        let environment = Rc::new(FakeThreadEnvironment {
3698            handle: handle.clone(),
3699        });
3700
3701        cx.update(|cx| {
3702            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
3703            settings.always_allow_tool_actions = true;
3704            settings.tool_permissions.tools.insert(
3705                "terminal".into(),
3706                agent_settings::ToolRules {
3707                    default_mode: settings::ToolPermissionMode::Deny,
3708                    always_allow: vec![],
3709                    always_deny: vec![],
3710                    always_confirm: vec![],
3711                    invalid_patterns: vec![],
3712                },
3713            );
3714            agent_settings::AgentSettings::override_global(settings, cx);
3715        });
3716
3717        #[allow(clippy::arc_with_non_send_sync)]
3718        let tool = Arc::new(crate::TerminalTool::new(project.clone(), environment));
3719        let (event_stream, _rx) = crate::ToolCallEventStream::test();
3720
3721        let task = cx.update(|cx| {
3722            tool.run(
3723                crate::TerminalToolInput {
3724                    command: "echo hello".to_string(),
3725                    cd: ".".to_string(),
3726                    timeout_ms: None,
3727                },
3728                event_stream,
3729                cx,
3730            )
3731        });
3732
3733        let result = task.await;
3734        assert!(
3735            result.is_err(),
3736            "expected command to be blocked by default_mode: Deny"
3737        );
3738        assert!(
3739            result.unwrap_err().to_string().contains("disabled"),
3740            "error should mention the tool is disabled"
3741        );
3742    }
3743}
3744
3745#[gpui::test]
3746async fn test_subagent_tool_is_present_when_feature_flag_enabled(cx: &mut TestAppContext) {
3747    init_test(cx);
3748
3749    cx.update(|cx| {
3750        cx.update_flags(true, vec!["subagents".to_string()]);
3751    });
3752
3753    let fs = FakeFs::new(cx.executor());
3754    fs.insert_tree(path!("/test"), json!({})).await;
3755    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3756    let project_context = cx.new(|_cx| ProjectContext::default());
3757    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3758    let context_server_registry =
3759        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3760    let model = Arc::new(FakeLanguageModel::default());
3761
3762    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
3763    let environment = Rc::new(FakeThreadEnvironment { handle });
3764
3765    let thread = cx.new(|cx| {
3766        let mut thread = Thread::new(
3767            project.clone(),
3768            project_context,
3769            context_server_registry,
3770            Templates::new(),
3771            Some(model),
3772            cx,
3773        );
3774        thread.add_default_tools(environment, cx);
3775        thread
3776    });
3777
3778    thread.read_with(cx, |thread, _| {
3779        assert!(
3780            thread.has_registered_tool("subagent"),
3781            "subagent tool should be present when feature flag is enabled"
3782        );
3783    });
3784}
3785
3786#[gpui::test]
3787async fn test_subagent_thread_inherits_parent_model(cx: &mut TestAppContext) {
3788    init_test(cx);
3789
3790    cx.update(|cx| {
3791        cx.update_flags(true, vec!["subagents".to_string()]);
3792    });
3793
3794    let fs = FakeFs::new(cx.executor());
3795    fs.insert_tree(path!("/test"), json!({})).await;
3796    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3797    let project_context = cx.new(|_cx| ProjectContext::default());
3798    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3799    let context_server_registry =
3800        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3801    let model = Arc::new(FakeLanguageModel::default());
3802
3803    let subagent_context = SubagentContext {
3804        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
3805        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
3806        depth: 1,
3807        summary_prompt: "Summarize".to_string(),
3808        context_low_prompt: "Context low".to_string(),
3809    };
3810
3811    let subagent = cx.new(|cx| {
3812        Thread::new_subagent(
3813            project.clone(),
3814            project_context,
3815            context_server_registry,
3816            Templates::new(),
3817            model.clone(),
3818            subagent_context,
3819            std::collections::BTreeMap::new(),
3820            cx,
3821        )
3822    });
3823
3824    subagent.read_with(cx, |thread, _| {
3825        assert!(thread.is_subagent());
3826        assert_eq!(thread.depth(), 1);
3827        assert!(thread.model().is_some());
3828    });
3829}
3830
3831#[gpui::test]
3832async fn test_max_subagent_depth_prevents_tool_registration(cx: &mut TestAppContext) {
3833    init_test(cx);
3834
3835    cx.update(|cx| {
3836        cx.update_flags(true, vec!["subagents".to_string()]);
3837    });
3838
3839    let fs = FakeFs::new(cx.executor());
3840    fs.insert_tree(path!("/test"), json!({})).await;
3841    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
3842    let project_context = cx.new(|_cx| ProjectContext::default());
3843    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3844    let context_server_registry =
3845        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3846    let model = Arc::new(FakeLanguageModel::default());
3847
3848    let subagent_context = SubagentContext {
3849        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
3850        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
3851        depth: MAX_SUBAGENT_DEPTH,
3852        summary_prompt: "Summarize".to_string(),
3853        context_low_prompt: "Context low".to_string(),
3854    };
3855
3856    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
3857    let environment = Rc::new(FakeThreadEnvironment { handle });
3858
3859    let deep_subagent = cx.new(|cx| {
3860        let mut thread = Thread::new_subagent(
3861            project.clone(),
3862            project_context,
3863            context_server_registry,
3864            Templates::new(),
3865            model.clone(),
3866            subagent_context,
3867            std::collections::BTreeMap::new(),
3868            cx,
3869        );
3870        thread.add_default_tools(environment, cx);
3871        thread
3872    });
3873
3874    deep_subagent.read_with(cx, |thread, _| {
3875        assert_eq!(thread.depth(), MAX_SUBAGENT_DEPTH);
3876        assert!(
3877            !thread.has_registered_tool("subagent"),
3878            "subagent tool should not be present at max depth"
3879        );
3880    });
3881}
3882
3883#[gpui::test]
3884async fn test_subagent_receives_task_prompt(cx: &mut TestAppContext) {
3885    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3886    let fake_model = model.as_fake();
3887
3888    cx.update(|cx| {
3889        cx.update_flags(true, vec!["subagents".to_string()]);
3890    });
3891
3892    let subagent_context = SubagentContext {
3893        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
3894        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
3895        depth: 1,
3896        summary_prompt: "Summarize your work".to_string(),
3897        context_low_prompt: "Context low, wrap up".to_string(),
3898    };
3899
3900    let project = thread.read_with(cx, |t, _| t.project.clone());
3901    let project_context = cx.new(|_cx| ProjectContext::default());
3902    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3903    let context_server_registry =
3904        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3905
3906    let subagent = cx.new(|cx| {
3907        Thread::new_subagent(
3908            project.clone(),
3909            project_context,
3910            context_server_registry,
3911            Templates::new(),
3912            model.clone(),
3913            subagent_context,
3914            std::collections::BTreeMap::new(),
3915            cx,
3916        )
3917    });
3918
3919    let task_prompt = "Find all TODO comments in the codebase";
3920    subagent
3921        .update(cx, |thread, cx| thread.submit_user_message(task_prompt, cx))
3922        .unwrap();
3923    cx.run_until_parked();
3924
3925    let pending = fake_model.pending_completions();
3926    assert_eq!(pending.len(), 1, "should have one pending completion");
3927
3928    let messages = &pending[0].messages;
3929    let user_messages: Vec<_> = messages
3930        .iter()
3931        .filter(|m| m.role == language_model::Role::User)
3932        .collect();
3933    assert_eq!(user_messages.len(), 1, "should have one user message");
3934
3935    let content = &user_messages[0].content[0];
3936    assert!(
3937        content.to_str().unwrap().contains("TODO"),
3938        "task prompt should be in user message"
3939    );
3940}
3941
3942#[gpui::test]
3943async fn test_subagent_returns_summary_on_completion(cx: &mut TestAppContext) {
3944    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
3945    let fake_model = model.as_fake();
3946
3947    cx.update(|cx| {
3948        cx.update_flags(true, vec!["subagents".to_string()]);
3949    });
3950
3951    let subagent_context = SubagentContext {
3952        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
3953        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
3954        depth: 1,
3955        summary_prompt: "Please summarize what you found".to_string(),
3956        context_low_prompt: "Context low, wrap up".to_string(),
3957    };
3958
3959    let project = thread.read_with(cx, |t, _| t.project.clone());
3960    let project_context = cx.new(|_cx| ProjectContext::default());
3961    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
3962    let context_server_registry =
3963        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
3964
3965    let subagent = cx.new(|cx| {
3966        Thread::new_subagent(
3967            project.clone(),
3968            project_context,
3969            context_server_registry,
3970            Templates::new(),
3971            model.clone(),
3972            subagent_context,
3973            std::collections::BTreeMap::new(),
3974            cx,
3975        )
3976    });
3977
3978    subagent
3979        .update(cx, |thread, cx| {
3980            thread.submit_user_message("Do some work", cx)
3981        })
3982        .unwrap();
3983    cx.run_until_parked();
3984
3985    fake_model.send_last_completion_stream_text_chunk("I did the work");
3986    fake_model
3987        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
3988    fake_model.end_last_completion_stream();
3989    cx.run_until_parked();
3990
3991    subagent
3992        .update(cx, |thread, cx| thread.request_final_summary(cx))
3993        .unwrap();
3994    cx.run_until_parked();
3995
3996    let pending = fake_model.pending_completions();
3997    assert!(
3998        !pending.is_empty(),
3999        "should have pending completion for summary"
4000    );
4001
4002    let messages = &pending.last().unwrap().messages;
4003    let user_messages: Vec<_> = messages
4004        .iter()
4005        .filter(|m| m.role == language_model::Role::User)
4006        .collect();
4007
4008    let last_user = user_messages.last().unwrap();
4009    assert!(
4010        last_user.content[0].to_str().unwrap().contains("summarize"),
4011        "summary prompt should be sent"
4012    );
4013}
4014
4015#[gpui::test]
4016async fn test_allowed_tools_restricts_subagent_capabilities(cx: &mut TestAppContext) {
4017    init_test(cx);
4018
4019    cx.update(|cx| {
4020        cx.update_flags(true, vec!["subagents".to_string()]);
4021    });
4022
4023    let fs = FakeFs::new(cx.executor());
4024    fs.insert_tree(path!("/test"), json!({})).await;
4025    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4026    let project_context = cx.new(|_cx| ProjectContext::default());
4027    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4028    let context_server_registry =
4029        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4030    let model = Arc::new(FakeLanguageModel::default());
4031
4032    let subagent_context = SubagentContext {
4033        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4034        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4035        depth: 1,
4036        summary_prompt: "Summarize".to_string(),
4037        context_low_prompt: "Context low".to_string(),
4038    };
4039
4040    let subagent = cx.new(|cx| {
4041        let mut thread = Thread::new_subagent(
4042            project.clone(),
4043            project_context,
4044            context_server_registry,
4045            Templates::new(),
4046            model.clone(),
4047            subagent_context,
4048            std::collections::BTreeMap::new(),
4049            cx,
4050        );
4051        thread.add_tool(EchoTool);
4052        thread.add_tool(DelayTool);
4053        thread.add_tool(WordListTool);
4054        thread
4055    });
4056
4057    subagent.read_with(cx, |thread, _| {
4058        assert!(thread.has_registered_tool("echo"));
4059        assert!(thread.has_registered_tool("delay"));
4060        assert!(thread.has_registered_tool("word_list"));
4061    });
4062
4063    let allowed: collections::HashSet<gpui::SharedString> =
4064        vec!["echo".into()].into_iter().collect();
4065
4066    subagent.update(cx, |thread, _cx| {
4067        thread.restrict_tools(&allowed);
4068    });
4069
4070    subagent.read_with(cx, |thread, _| {
4071        assert!(
4072            thread.has_registered_tool("echo"),
4073            "echo should still be available"
4074        );
4075        assert!(
4076            !thread.has_registered_tool("delay"),
4077            "delay should be removed"
4078        );
4079        assert!(
4080            !thread.has_registered_tool("word_list"),
4081            "word_list should be removed"
4082        );
4083    });
4084}
4085
4086#[gpui::test]
4087async fn test_parent_cancel_stops_subagent(cx: &mut TestAppContext) {
4088    init_test(cx);
4089
4090    cx.update(|cx| {
4091        cx.update_flags(true, vec!["subagents".to_string()]);
4092    });
4093
4094    let fs = FakeFs::new(cx.executor());
4095    fs.insert_tree(path!("/test"), json!({})).await;
4096    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4097    let project_context = cx.new(|_cx| ProjectContext::default());
4098    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4099    let context_server_registry =
4100        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4101    let model = Arc::new(FakeLanguageModel::default());
4102
4103    let parent = cx.new(|cx| {
4104        Thread::new(
4105            project.clone(),
4106            project_context.clone(),
4107            context_server_registry.clone(),
4108            Templates::new(),
4109            Some(model.clone()),
4110            cx,
4111        )
4112    });
4113
4114    let subagent_context = SubagentContext {
4115        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4116        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4117        depth: 1,
4118        summary_prompt: "Summarize".to_string(),
4119        context_low_prompt: "Context low".to_string(),
4120    };
4121
4122    let subagent = cx.new(|cx| {
4123        Thread::new_subagent(
4124            project.clone(),
4125            project_context.clone(),
4126            context_server_registry.clone(),
4127            Templates::new(),
4128            model.clone(),
4129            subagent_context,
4130            std::collections::BTreeMap::new(),
4131            cx,
4132        )
4133    });
4134
4135    parent.update(cx, |thread, _cx| {
4136        thread.register_running_subagent(subagent.downgrade());
4137    });
4138
4139    subagent
4140        .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4141        .unwrap();
4142    cx.run_until_parked();
4143
4144    subagent.read_with(cx, |thread, _| {
4145        assert!(!thread.is_turn_complete(), "subagent should be running");
4146    });
4147
4148    parent.update(cx, |thread, cx| {
4149        thread.cancel(cx).detach();
4150    });
4151
4152    subagent.read_with(cx, |thread, _| {
4153        assert!(
4154            thread.is_turn_complete(),
4155            "subagent should be cancelled when parent cancels"
4156        );
4157    });
4158}
4159
4160#[gpui::test]
4161async fn test_subagent_tool_cancellation(cx: &mut TestAppContext) {
4162    // This test verifies that the subagent tool properly handles user cancellation
4163    // via `event_stream.cancelled_by_user()` and stops all running subagents.
4164    init_test(cx);
4165    always_allow_tools(cx);
4166
4167    cx.update(|cx| {
4168        cx.update_flags(true, vec!["subagents".to_string()]);
4169    });
4170
4171    let fs = FakeFs::new(cx.executor());
4172    fs.insert_tree(path!("/test"), json!({})).await;
4173    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4174    let project_context = cx.new(|_cx| ProjectContext::default());
4175    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4176    let context_server_registry =
4177        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4178    let model = Arc::new(FakeLanguageModel::default());
4179
4180    let parent = cx.new(|cx| {
4181        Thread::new(
4182            project.clone(),
4183            project_context.clone(),
4184            context_server_registry.clone(),
4185            Templates::new(),
4186            Some(model.clone()),
4187            cx,
4188        )
4189    });
4190
4191    let parent_tools: std::collections::BTreeMap<gpui::SharedString, Arc<dyn crate::AnyAgentTool>> =
4192        std::collections::BTreeMap::new();
4193
4194    #[allow(clippy::arc_with_non_send_sync)]
4195    let tool = Arc::new(SubagentTool::new(
4196        parent.downgrade(),
4197        project.clone(),
4198        project_context,
4199        context_server_registry,
4200        Templates::new(),
4201        0,
4202        parent_tools,
4203    ));
4204
4205    let (event_stream, _rx, mut cancellation_tx) =
4206        crate::ToolCallEventStream::test_with_cancellation();
4207
4208    // Start the subagent tool
4209    let task = cx.update(|cx| {
4210        tool.run(
4211            SubagentToolInput {
4212                subagents: vec![crate::SubagentConfig {
4213                    label: "Long running task".to_string(),
4214                    task_prompt: "Do a very long task that takes forever".to_string(),
4215                    summary_prompt: "Summarize".to_string(),
4216                    context_low_prompt: "Context low".to_string(),
4217                    timeout_ms: None,
4218                    allowed_tools: None,
4219                }],
4220            },
4221            event_stream.clone(),
4222            cx,
4223        )
4224    });
4225
4226    cx.run_until_parked();
4227
4228    // Signal cancellation via the event stream
4229    crate::ToolCallEventStream::signal_cancellation_with_sender(&mut cancellation_tx);
4230
4231    // The task should complete promptly with a cancellation error
4232    let timeout = cx.background_executor.timer(Duration::from_secs(5));
4233    let result = futures::select! {
4234        result = task.fuse() => result,
4235        _ = timeout.fuse() => {
4236            panic!("subagent tool did not respond to cancellation within timeout");
4237        }
4238    };
4239
4240    // Verify we got a cancellation error
4241    let err = result.unwrap_err();
4242    assert!(
4243        err.to_string().contains("cancelled by user"),
4244        "expected cancellation error, got: {}",
4245        err
4246    );
4247}
4248
4249#[gpui::test]
4250async fn test_subagent_model_error_returned_as_tool_error(cx: &mut TestAppContext) {
4251    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4252    let fake_model = model.as_fake();
4253
4254    cx.update(|cx| {
4255        cx.update_flags(true, vec!["subagents".to_string()]);
4256    });
4257
4258    let subagent_context = SubagentContext {
4259        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4260        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4261        depth: 1,
4262        summary_prompt: "Summarize".to_string(),
4263        context_low_prompt: "Context low".to_string(),
4264    };
4265
4266    let project = thread.read_with(cx, |t, _| t.project.clone());
4267    let project_context = cx.new(|_cx| ProjectContext::default());
4268    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4269    let context_server_registry =
4270        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4271
4272    let subagent = cx.new(|cx| {
4273        Thread::new_subagent(
4274            project.clone(),
4275            project_context,
4276            context_server_registry,
4277            Templates::new(),
4278            model.clone(),
4279            subagent_context,
4280            std::collections::BTreeMap::new(),
4281            cx,
4282        )
4283    });
4284
4285    subagent
4286        .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4287        .unwrap();
4288    cx.run_until_parked();
4289
4290    subagent.read_with(cx, |thread, _| {
4291        assert!(!thread.is_turn_complete(), "turn should be in progress");
4292    });
4293
4294    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::NoApiKey {
4295        provider: LanguageModelProviderName::from("Fake".to_string()),
4296    });
4297    fake_model.end_last_completion_stream();
4298    cx.run_until_parked();
4299
4300    subagent.read_with(cx, |thread, _| {
4301        assert!(
4302            thread.is_turn_complete(),
4303            "turn should be complete after non-retryable error"
4304        );
4305    });
4306}
4307
4308#[gpui::test]
4309async fn test_subagent_timeout_triggers_early_summary(cx: &mut TestAppContext) {
4310    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4311    let fake_model = model.as_fake();
4312
4313    cx.update(|cx| {
4314        cx.update_flags(true, vec!["subagents".to_string()]);
4315    });
4316
4317    let subagent_context = SubagentContext {
4318        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4319        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4320        depth: 1,
4321        summary_prompt: "Summarize your work".to_string(),
4322        context_low_prompt: "Context low, stop and summarize".to_string(),
4323    };
4324
4325    let project = thread.read_with(cx, |t, _| t.project.clone());
4326    let project_context = cx.new(|_cx| ProjectContext::default());
4327    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4328    let context_server_registry =
4329        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4330
4331    let subagent = cx.new(|cx| {
4332        Thread::new_subagent(
4333            project.clone(),
4334            project_context.clone(),
4335            context_server_registry.clone(),
4336            Templates::new(),
4337            model.clone(),
4338            subagent_context.clone(),
4339            std::collections::BTreeMap::new(),
4340            cx,
4341        )
4342    });
4343
4344    subagent.update(cx, |thread, _| {
4345        thread.add_tool(EchoTool);
4346    });
4347
4348    subagent
4349        .update(cx, |thread, cx| {
4350            thread.submit_user_message("Do some work", cx)
4351        })
4352        .unwrap();
4353    cx.run_until_parked();
4354
4355    fake_model.send_last_completion_stream_text_chunk("Working on it...");
4356    fake_model
4357        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4358    fake_model.end_last_completion_stream();
4359    cx.run_until_parked();
4360
4361    let interrupt_result = subagent.update(cx, |thread, cx| thread.interrupt_for_summary(cx));
4362    assert!(
4363        interrupt_result.is_ok(),
4364        "interrupt_for_summary should succeed"
4365    );
4366
4367    cx.run_until_parked();
4368
4369    let pending = fake_model.pending_completions();
4370    assert!(
4371        !pending.is_empty(),
4372        "should have pending completion for interrupted summary"
4373    );
4374
4375    let messages = &pending.last().unwrap().messages;
4376    let user_messages: Vec<_> = messages
4377        .iter()
4378        .filter(|m| m.role == language_model::Role::User)
4379        .collect();
4380
4381    let last_user = user_messages.last().unwrap();
4382    let content_str = last_user.content[0].to_str().unwrap();
4383    assert!(
4384        content_str.contains("Context low") || content_str.contains("stop and summarize"),
4385        "context_low_prompt should be sent when interrupting: got {:?}",
4386        content_str
4387    );
4388}
4389
4390#[gpui::test]
4391async fn test_context_low_check_returns_true_when_usage_high(cx: &mut TestAppContext) {
4392    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4393    let fake_model = model.as_fake();
4394
4395    cx.update(|cx| {
4396        cx.update_flags(true, vec!["subagents".to_string()]);
4397    });
4398
4399    let subagent_context = SubagentContext {
4400        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4401        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4402        depth: 1,
4403        summary_prompt: "Summarize".to_string(),
4404        context_low_prompt: "Context low".to_string(),
4405    };
4406
4407    let project = thread.read_with(cx, |t, _| t.project.clone());
4408    let project_context = cx.new(|_cx| ProjectContext::default());
4409    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4410    let context_server_registry =
4411        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4412
4413    let subagent = cx.new(|cx| {
4414        Thread::new_subagent(
4415            project.clone(),
4416            project_context,
4417            context_server_registry,
4418            Templates::new(),
4419            model.clone(),
4420            subagent_context,
4421            std::collections::BTreeMap::new(),
4422            cx,
4423        )
4424    });
4425
4426    subagent
4427        .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4428        .unwrap();
4429    cx.run_until_parked();
4430
4431    let max_tokens = model.max_token_count();
4432    let high_usage = language_model::TokenUsage {
4433        input_tokens: (max_tokens as f64 * 0.80) as u64,
4434        output_tokens: 0,
4435        cache_creation_input_tokens: 0,
4436        cache_read_input_tokens: 0,
4437    };
4438
4439    fake_model
4440        .send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(high_usage));
4441    fake_model.send_last_completion_stream_text_chunk("Working...");
4442    fake_model
4443        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4444    fake_model.end_last_completion_stream();
4445    cx.run_until_parked();
4446
4447    let usage = subagent.read_with(cx, |thread, _| thread.latest_token_usage());
4448    assert!(usage.is_some(), "should have token usage after completion");
4449
4450    let usage = usage.unwrap();
4451    let remaining_ratio = 1.0 - (usage.used_tokens as f32 / usage.max_tokens as f32);
4452    assert!(
4453        remaining_ratio <= 0.25,
4454        "remaining ratio should be at or below 25% (got {}%), indicating context is low",
4455        remaining_ratio * 100.0
4456    );
4457}
4458
4459#[gpui::test]
4460async fn test_allowed_tools_rejects_unknown_tool(cx: &mut TestAppContext) {
4461    init_test(cx);
4462
4463    cx.update(|cx| {
4464        cx.update_flags(true, vec!["subagents".to_string()]);
4465    });
4466
4467    let fs = FakeFs::new(cx.executor());
4468    fs.insert_tree(path!("/test"), json!({})).await;
4469    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4470    let project_context = cx.new(|_cx| ProjectContext::default());
4471    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4472    let context_server_registry =
4473        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4474    let model = Arc::new(FakeLanguageModel::default());
4475
4476    let parent = cx.new(|cx| {
4477        let mut thread = Thread::new(
4478            project.clone(),
4479            project_context.clone(),
4480            context_server_registry.clone(),
4481            Templates::new(),
4482            Some(model.clone()),
4483            cx,
4484        );
4485        thread.add_tool(EchoTool);
4486        thread
4487    });
4488
4489    let mut parent_tools: std::collections::BTreeMap<
4490        gpui::SharedString,
4491        Arc<dyn crate::AnyAgentTool>,
4492    > = std::collections::BTreeMap::new();
4493    parent_tools.insert("echo".into(), EchoTool.erase());
4494
4495    #[allow(clippy::arc_with_non_send_sync)]
4496    let tool = Arc::new(SubagentTool::new(
4497        parent.downgrade(),
4498        project,
4499        project_context,
4500        context_server_registry,
4501        Templates::new(),
4502        0,
4503        parent_tools,
4504    ));
4505
4506    let subagent_configs = vec![crate::SubagentConfig {
4507        label: "Test".to_string(),
4508        task_prompt: "Do something".to_string(),
4509        summary_prompt: "Summarize".to_string(),
4510        context_low_prompt: "Context low".to_string(),
4511        timeout_ms: None,
4512        allowed_tools: Some(vec!["nonexistent_tool".to_string()]),
4513    }];
4514    let result = tool.validate_subagents(&subagent_configs);
4515    assert!(result.is_err(), "should reject unknown tool");
4516    let err_msg = result.unwrap_err().to_string();
4517    assert!(
4518        err_msg.contains("nonexistent_tool"),
4519        "error should mention the invalid tool name: {}",
4520        err_msg
4521    );
4522    assert!(
4523        err_msg.contains("do not exist"),
4524        "error should explain the tool does not exist: {}",
4525        err_msg
4526    );
4527}
4528
4529#[gpui::test]
4530async fn test_subagent_empty_response_handled(cx: &mut TestAppContext) {
4531    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
4532    let fake_model = model.as_fake();
4533
4534    cx.update(|cx| {
4535        cx.update_flags(true, vec!["subagents".to_string()]);
4536    });
4537
4538    let subagent_context = SubagentContext {
4539        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4540        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4541        depth: 1,
4542        summary_prompt: "Summarize".to_string(),
4543        context_low_prompt: "Context low".to_string(),
4544    };
4545
4546    let project = thread.read_with(cx, |t, _| t.project.clone());
4547    let project_context = cx.new(|_cx| ProjectContext::default());
4548    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4549    let context_server_registry =
4550        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4551
4552    let subagent = cx.new(|cx| {
4553        Thread::new_subagent(
4554            project.clone(),
4555            project_context,
4556            context_server_registry,
4557            Templates::new(),
4558            model.clone(),
4559            subagent_context,
4560            std::collections::BTreeMap::new(),
4561            cx,
4562        )
4563    });
4564
4565    subagent
4566        .update(cx, |thread, cx| thread.submit_user_message("Do work", cx))
4567        .unwrap();
4568    cx.run_until_parked();
4569
4570    fake_model
4571        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4572    fake_model.end_last_completion_stream();
4573    cx.run_until_parked();
4574
4575    subagent.read_with(cx, |thread, _| {
4576        assert!(
4577            thread.is_turn_complete(),
4578            "turn should complete even with empty response"
4579        );
4580    });
4581}
4582
4583#[gpui::test]
4584async fn test_nested_subagent_at_depth_2_succeeds(cx: &mut TestAppContext) {
4585    init_test(cx);
4586
4587    cx.update(|cx| {
4588        cx.update_flags(true, vec!["subagents".to_string()]);
4589    });
4590
4591    let fs = FakeFs::new(cx.executor());
4592    fs.insert_tree(path!("/test"), json!({})).await;
4593    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4594    let project_context = cx.new(|_cx| ProjectContext::default());
4595    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4596    let context_server_registry =
4597        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4598    let model = Arc::new(FakeLanguageModel::default());
4599
4600    let depth_1_context = SubagentContext {
4601        parent_thread_id: agent_client_protocol::SessionId::new("root-id"),
4602        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-1"),
4603        depth: 1,
4604        summary_prompt: "Summarize".to_string(),
4605        context_low_prompt: "Context low".to_string(),
4606    };
4607
4608    let depth_1_subagent = cx.new(|cx| {
4609        Thread::new_subagent(
4610            project.clone(),
4611            project_context.clone(),
4612            context_server_registry.clone(),
4613            Templates::new(),
4614            model.clone(),
4615            depth_1_context,
4616            std::collections::BTreeMap::new(),
4617            cx,
4618        )
4619    });
4620
4621    depth_1_subagent.read_with(cx, |thread, _| {
4622        assert_eq!(thread.depth(), 1);
4623        assert!(thread.is_subagent());
4624    });
4625
4626    let depth_2_context = SubagentContext {
4627        parent_thread_id: agent_client_protocol::SessionId::new("depth-1-id"),
4628        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-2"),
4629        depth: 2,
4630        summary_prompt: "Summarize depth 2".to_string(),
4631        context_low_prompt: "Context low depth 2".to_string(),
4632    };
4633
4634    let depth_2_subagent = cx.new(|cx| {
4635        Thread::new_subagent(
4636            project.clone(),
4637            project_context.clone(),
4638            context_server_registry.clone(),
4639            Templates::new(),
4640            model.clone(),
4641            depth_2_context,
4642            std::collections::BTreeMap::new(),
4643            cx,
4644        )
4645    });
4646
4647    depth_2_subagent.read_with(cx, |thread, _| {
4648        assert_eq!(thread.depth(), 2);
4649        assert!(thread.is_subagent());
4650    });
4651
4652    depth_2_subagent
4653        .update(cx, |thread, cx| {
4654            thread.submit_user_message("Nested task", cx)
4655        })
4656        .unwrap();
4657    cx.run_until_parked();
4658
4659    let pending = model.as_fake().pending_completions();
4660    assert!(
4661        !pending.is_empty(),
4662        "depth-2 subagent should be able to submit messages"
4663    );
4664}
4665
4666#[gpui::test]
4667async fn test_subagent_uses_tool_and_returns_result(cx: &mut TestAppContext) {
4668    init_test(cx);
4669    always_allow_tools(cx);
4670
4671    cx.update(|cx| {
4672        cx.update_flags(true, vec!["subagents".to_string()]);
4673    });
4674
4675    let fs = FakeFs::new(cx.executor());
4676    fs.insert_tree(path!("/test"), json!({})).await;
4677    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4678    let project_context = cx.new(|_cx| ProjectContext::default());
4679    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4680    let context_server_registry =
4681        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4682    let model = Arc::new(FakeLanguageModel::default());
4683    let fake_model = model.as_fake();
4684
4685    let subagent_context = SubagentContext {
4686        parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4687        tool_use_id: language_model::LanguageModelToolUseId::from("tool-use-id"),
4688        depth: 1,
4689        summary_prompt: "Summarize what you did".to_string(),
4690        context_low_prompt: "Context low".to_string(),
4691    };
4692
4693    let subagent = cx.new(|cx| {
4694        let mut thread = Thread::new_subagent(
4695            project.clone(),
4696            project_context,
4697            context_server_registry,
4698            Templates::new(),
4699            model.clone(),
4700            subagent_context,
4701            std::collections::BTreeMap::new(),
4702            cx,
4703        );
4704        thread.add_tool(EchoTool);
4705        thread
4706    });
4707
4708    subagent.read_with(cx, |thread, _| {
4709        assert!(
4710            thread.has_registered_tool("echo"),
4711            "subagent should have echo tool"
4712        );
4713    });
4714
4715    subagent
4716        .update(cx, |thread, cx| {
4717            thread.submit_user_message("Use the echo tool to echo 'hello world'", cx)
4718        })
4719        .unwrap();
4720    cx.run_until_parked();
4721
4722    let tool_use = LanguageModelToolUse {
4723        id: "tool_call_1".into(),
4724        name: EchoTool::name().into(),
4725        raw_input: json!({"text": "hello world"}).to_string(),
4726        input: json!({"text": "hello world"}),
4727        is_input_complete: true,
4728        thought_signature: None,
4729    };
4730    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use));
4731    fake_model.end_last_completion_stream();
4732    cx.run_until_parked();
4733
4734    let pending = fake_model.pending_completions();
4735    assert!(
4736        !pending.is_empty(),
4737        "should have pending completion after tool use"
4738    );
4739
4740    let last_completion = pending.last().unwrap();
4741    let has_tool_result = last_completion.messages.iter().any(|m| {
4742        m.content
4743            .iter()
4744            .any(|c| matches!(c, MessageContent::ToolResult(_)))
4745    });
4746    assert!(
4747        has_tool_result,
4748        "tool result should be in the messages sent back to the model"
4749    );
4750}
4751
4752#[gpui::test]
4753async fn test_max_parallel_subagents_enforced(cx: &mut TestAppContext) {
4754    init_test(cx);
4755
4756    cx.update(|cx| {
4757        cx.update_flags(true, vec!["subagents".to_string()]);
4758    });
4759
4760    let fs = FakeFs::new(cx.executor());
4761    fs.insert_tree(path!("/test"), json!({})).await;
4762    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4763    let project_context = cx.new(|_cx| ProjectContext::default());
4764    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4765    let context_server_registry =
4766        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4767    let model = Arc::new(FakeLanguageModel::default());
4768
4769    let parent = cx.new(|cx| {
4770        Thread::new(
4771            project.clone(),
4772            project_context.clone(),
4773            context_server_registry.clone(),
4774            Templates::new(),
4775            Some(model.clone()),
4776            cx,
4777        )
4778    });
4779
4780    let mut subagents = Vec::new();
4781    for i in 0..MAX_PARALLEL_SUBAGENTS {
4782        let subagent_context = SubagentContext {
4783            parent_thread_id: agent_client_protocol::SessionId::new("parent-id"),
4784            tool_use_id: language_model::LanguageModelToolUseId::from(format!("tool-use-{}", i)),
4785            depth: 1,
4786            summary_prompt: "Summarize".to_string(),
4787            context_low_prompt: "Context low".to_string(),
4788        };
4789
4790        let subagent = cx.new(|cx| {
4791            Thread::new_subagent(
4792                project.clone(),
4793                project_context.clone(),
4794                context_server_registry.clone(),
4795                Templates::new(),
4796                model.clone(),
4797                subagent_context,
4798                std::collections::BTreeMap::new(),
4799                cx,
4800            )
4801        });
4802
4803        parent.update(cx, |thread, _cx| {
4804            thread.register_running_subagent(subagent.downgrade());
4805        });
4806        subagents.push(subagent);
4807    }
4808
4809    parent.read_with(cx, |thread, _| {
4810        assert_eq!(
4811            thread.running_subagent_count(),
4812            MAX_PARALLEL_SUBAGENTS,
4813            "should have MAX_PARALLEL_SUBAGENTS registered"
4814        );
4815    });
4816
4817    let parent_tools: std::collections::BTreeMap<gpui::SharedString, Arc<dyn crate::AnyAgentTool>> =
4818        std::collections::BTreeMap::new();
4819
4820    #[allow(clippy::arc_with_non_send_sync)]
4821    let tool = Arc::new(SubagentTool::new(
4822        parent.downgrade(),
4823        project.clone(),
4824        project_context,
4825        context_server_registry,
4826        Templates::new(),
4827        0,
4828        parent_tools,
4829    ));
4830
4831    let (event_stream, _rx) = crate::ToolCallEventStream::test();
4832
4833    let result = cx.update(|cx| {
4834        tool.run(
4835            SubagentToolInput {
4836                subagents: vec![crate::SubagentConfig {
4837                    label: "Test".to_string(),
4838                    task_prompt: "Do something".to_string(),
4839                    summary_prompt: "Summarize".to_string(),
4840                    context_low_prompt: "Context low".to_string(),
4841                    timeout_ms: None,
4842                    allowed_tools: None,
4843                }],
4844            },
4845            event_stream,
4846            cx,
4847        )
4848    });
4849
4850    let err = result.await.unwrap_err();
4851    assert!(
4852        err.to_string().contains("Maximum parallel subagents"),
4853        "should reject when max parallel subagents reached: {}",
4854        err
4855    );
4856
4857    drop(subagents);
4858}
4859
4860#[gpui::test]
4861async fn test_subagent_tool_end_to_end(cx: &mut TestAppContext) {
4862    init_test(cx);
4863    always_allow_tools(cx);
4864
4865    cx.update(|cx| {
4866        cx.update_flags(true, vec!["subagents".to_string()]);
4867    });
4868
4869    let fs = FakeFs::new(cx.executor());
4870    fs.insert_tree(path!("/test"), json!({})).await;
4871    let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
4872    let project_context = cx.new(|_cx| ProjectContext::default());
4873    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
4874    let context_server_registry =
4875        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
4876    let model = Arc::new(FakeLanguageModel::default());
4877    let fake_model = model.as_fake();
4878
4879    let parent = cx.new(|cx| {
4880        let mut thread = Thread::new(
4881            project.clone(),
4882            project_context.clone(),
4883            context_server_registry.clone(),
4884            Templates::new(),
4885            Some(model.clone()),
4886            cx,
4887        );
4888        thread.add_tool(EchoTool);
4889        thread
4890    });
4891
4892    let mut parent_tools: std::collections::BTreeMap<
4893        gpui::SharedString,
4894        Arc<dyn crate::AnyAgentTool>,
4895    > = std::collections::BTreeMap::new();
4896    parent_tools.insert("echo".into(), EchoTool.erase());
4897
4898    #[allow(clippy::arc_with_non_send_sync)]
4899    let tool = Arc::new(SubagentTool::new(
4900        parent.downgrade(),
4901        project.clone(),
4902        project_context,
4903        context_server_registry,
4904        Templates::new(),
4905        0,
4906        parent_tools,
4907    ));
4908
4909    let (event_stream, _rx) = crate::ToolCallEventStream::test();
4910
4911    let task = cx.update(|cx| {
4912        tool.run(
4913            SubagentToolInput {
4914                subagents: vec![crate::SubagentConfig {
4915                    label: "Research task".to_string(),
4916                    task_prompt: "Find all TODOs in the codebase".to_string(),
4917                    summary_prompt: "Summarize what you found".to_string(),
4918                    context_low_prompt: "Context low, wrap up".to_string(),
4919                    timeout_ms: None,
4920                    allowed_tools: None,
4921                }],
4922            },
4923            event_stream,
4924            cx,
4925        )
4926    });
4927
4928    cx.run_until_parked();
4929
4930    let pending = fake_model.pending_completions();
4931    assert!(
4932        !pending.is_empty(),
4933        "subagent should have started and sent a completion request"
4934    );
4935
4936    let first_completion = &pending[0];
4937    let has_task_prompt = first_completion.messages.iter().any(|m| {
4938        m.role == language_model::Role::User
4939            && m.content
4940                .iter()
4941                .any(|c| c.to_str().map(|s| s.contains("TODO")).unwrap_or(false))
4942    });
4943    assert!(has_task_prompt, "task prompt should be sent to subagent");
4944
4945    fake_model.send_last_completion_stream_text_chunk("I found 5 TODOs in the codebase.");
4946    fake_model
4947        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4948    fake_model.end_last_completion_stream();
4949    cx.run_until_parked();
4950
4951    let pending = fake_model.pending_completions();
4952    assert!(
4953        !pending.is_empty(),
4954        "should have pending completion for summary request"
4955    );
4956
4957    let last_completion = pending.last().unwrap();
4958    let has_summary_prompt = last_completion.messages.iter().any(|m| {
4959        m.role == language_model::Role::User
4960            && m.content.iter().any(|c| {
4961                c.to_str()
4962                    .map(|s| s.contains("Summarize") || s.contains("summarize"))
4963                    .unwrap_or(false)
4964            })
4965    });
4966    assert!(
4967        has_summary_prompt,
4968        "summary prompt should be sent after task completion"
4969    );
4970
4971    fake_model.send_last_completion_stream_text_chunk("Summary: Found 5 TODOs across 3 files.");
4972    fake_model
4973        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
4974    fake_model.end_last_completion_stream();
4975    cx.run_until_parked();
4976
4977    let result = task.await;
4978    assert!(result.is_ok(), "subagent tool should complete successfully");
4979
4980    let summary = result.unwrap();
4981    assert!(
4982        summary.contains("Summary") || summary.contains("TODO") || summary.contains("5"),
4983        "summary should contain subagent's response: {}",
4984        summary
4985    );
4986}
4987
4988#[gpui::test]
4989async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
4990    init_test(cx);
4991
4992    let fs = FakeFs::new(cx.executor());
4993    fs.insert_tree("/root", json!({"sensitive_config.txt": "secret data"}))
4994        .await;
4995    let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
4996
4997    cx.update(|cx| {
4998        let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
4999        settings.tool_permissions.tools.insert(
5000            "edit_file".into(),
5001            agent_settings::ToolRules {
5002                default_mode: settings::ToolPermissionMode::Allow,
5003                always_allow: vec![],
5004                always_deny: vec![agent_settings::CompiledRegex::new(r"sensitive", false).unwrap()],
5005                always_confirm: vec![],
5006                invalid_patterns: vec![],
5007            },
5008        );
5009        agent_settings::AgentSettings::override_global(settings, cx);
5010    });
5011
5012    let context_server_registry =
5013        cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5014    let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5015    let templates = crate::Templates::new();
5016    let thread = cx.new(|cx| {
5017        crate::Thread::new(
5018            project.clone(),
5019            cx.new(|_cx| prompt_store::ProjectContext::default()),
5020            context_server_registry,
5021            templates.clone(),
5022            None,
5023            cx,
5024        )
5025    });
5026
5027    #[allow(clippy::arc_with_non_send_sync)]
5028    let tool = Arc::new(crate::EditFileTool::new(
5029        project.clone(),
5030        thread.downgrade(),
5031        language_registry,
5032        templates,
5033    ));
5034    let (event_stream, _rx) = crate::ToolCallEventStream::test();
5035
5036    let task = cx.update(|cx| {
5037        tool.run(
5038            crate::EditFileToolInput {
5039                display_description: "Edit sensitive file".to_string(),
5040                path: "root/sensitive_config.txt".into(),
5041                mode: crate::EditFileMode::Edit,
5042            },
5043            event_stream,
5044            cx,
5045        )
5046    });
5047
5048    let result = task.await;
5049    assert!(result.is_err(), "expected edit to be blocked");
5050    assert!(
5051        result.unwrap_err().to_string().contains("blocked"),
5052        "error should mention the edit was blocked"
5053    );
5054}
5055
5056#[gpui::test]
5057async fn test_delete_path_tool_deny_rule_blocks_deletion(cx: &mut TestAppContext) {
5058    init_test(cx);
5059
5060    let fs = FakeFs::new(cx.executor());
5061    fs.insert_tree("/root", json!({"important_data.txt": "critical info"}))
5062        .await;
5063    let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5064
5065    cx.update(|cx| {
5066        let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5067        settings.tool_permissions.tools.insert(
5068            "delete_path".into(),
5069            agent_settings::ToolRules {
5070                default_mode: settings::ToolPermissionMode::Allow,
5071                always_allow: vec![],
5072                always_deny: vec![agent_settings::CompiledRegex::new(r"important", false).unwrap()],
5073                always_confirm: vec![],
5074                invalid_patterns: vec![],
5075            },
5076        );
5077        agent_settings::AgentSettings::override_global(settings, cx);
5078    });
5079
5080    let action_log = cx.new(|_cx| action_log::ActionLog::new(project.clone()));
5081
5082    #[allow(clippy::arc_with_non_send_sync)]
5083    let tool = Arc::new(crate::DeletePathTool::new(project, action_log));
5084    let (event_stream, _rx) = crate::ToolCallEventStream::test();
5085
5086    let task = cx.update(|cx| {
5087        tool.run(
5088            crate::DeletePathToolInput {
5089                path: "root/important_data.txt".to_string(),
5090            },
5091            event_stream,
5092            cx,
5093        )
5094    });
5095
5096    let result = task.await;
5097    assert!(result.is_err(), "expected deletion to be blocked");
5098    assert!(
5099        result.unwrap_err().to_string().contains("blocked"),
5100        "error should mention the deletion was blocked"
5101    );
5102}
5103
5104#[gpui::test]
5105async fn test_move_path_tool_denies_if_destination_denied(cx: &mut TestAppContext) {
5106    init_test(cx);
5107
5108    let fs = FakeFs::new(cx.executor());
5109    fs.insert_tree(
5110        "/root",
5111        json!({
5112            "safe.txt": "content",
5113            "protected": {}
5114        }),
5115    )
5116    .await;
5117    let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5118
5119    cx.update(|cx| {
5120        let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5121        settings.tool_permissions.tools.insert(
5122            "move_path".into(),
5123            agent_settings::ToolRules {
5124                default_mode: settings::ToolPermissionMode::Allow,
5125                always_allow: vec![],
5126                always_deny: vec![agent_settings::CompiledRegex::new(r"protected", false).unwrap()],
5127                always_confirm: vec![],
5128                invalid_patterns: vec![],
5129            },
5130        );
5131        agent_settings::AgentSettings::override_global(settings, cx);
5132    });
5133
5134    #[allow(clippy::arc_with_non_send_sync)]
5135    let tool = Arc::new(crate::MovePathTool::new(project));
5136    let (event_stream, _rx) = crate::ToolCallEventStream::test();
5137
5138    let task = cx.update(|cx| {
5139        tool.run(
5140            crate::MovePathToolInput {
5141                source_path: "root/safe.txt".to_string(),
5142                destination_path: "root/protected/safe.txt".to_string(),
5143            },
5144            event_stream,
5145            cx,
5146        )
5147    });
5148
5149    let result = task.await;
5150    assert!(
5151        result.is_err(),
5152        "expected move to be blocked due to destination path"
5153    );
5154    assert!(
5155        result.unwrap_err().to_string().contains("blocked"),
5156        "error should mention the move was blocked"
5157    );
5158}
5159
5160#[gpui::test]
5161async fn test_move_path_tool_denies_if_source_denied(cx: &mut TestAppContext) {
5162    init_test(cx);
5163
5164    let fs = FakeFs::new(cx.executor());
5165    fs.insert_tree(
5166        "/root",
5167        json!({
5168            "secret.txt": "secret content",
5169            "public": {}
5170        }),
5171    )
5172    .await;
5173    let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5174
5175    cx.update(|cx| {
5176        let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5177        settings.tool_permissions.tools.insert(
5178            "move_path".into(),
5179            agent_settings::ToolRules {
5180                default_mode: settings::ToolPermissionMode::Allow,
5181                always_allow: vec![],
5182                always_deny: vec![agent_settings::CompiledRegex::new(r"secret", false).unwrap()],
5183                always_confirm: vec![],
5184                invalid_patterns: vec![],
5185            },
5186        );
5187        agent_settings::AgentSettings::override_global(settings, cx);
5188    });
5189
5190    #[allow(clippy::arc_with_non_send_sync)]
5191    let tool = Arc::new(crate::MovePathTool::new(project));
5192    let (event_stream, _rx) = crate::ToolCallEventStream::test();
5193
5194    let task = cx.update(|cx| {
5195        tool.run(
5196            crate::MovePathToolInput {
5197                source_path: "root/secret.txt".to_string(),
5198                destination_path: "root/public/not_secret.txt".to_string(),
5199            },
5200            event_stream,
5201            cx,
5202        )
5203    });
5204
5205    let result = task.await;
5206    assert!(
5207        result.is_err(),
5208        "expected move to be blocked due to source path"
5209    );
5210    assert!(
5211        result.unwrap_err().to_string().contains("blocked"),
5212        "error should mention the move was blocked"
5213    );
5214}
5215
5216#[gpui::test]
5217async fn test_copy_path_tool_deny_rule_blocks_copy(cx: &mut TestAppContext) {
5218    init_test(cx);
5219
5220    let fs = FakeFs::new(cx.executor());
5221    fs.insert_tree(
5222        "/root",
5223        json!({
5224            "confidential.txt": "confidential data",
5225            "dest": {}
5226        }),
5227    )
5228    .await;
5229    let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5230
5231    cx.update(|cx| {
5232        let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5233        settings.tool_permissions.tools.insert(
5234            "copy_path".into(),
5235            agent_settings::ToolRules {
5236                default_mode: settings::ToolPermissionMode::Allow,
5237                always_allow: vec![],
5238                always_deny: vec![
5239                    agent_settings::CompiledRegex::new(r"confidential", false).unwrap(),
5240                ],
5241                always_confirm: vec![],
5242                invalid_patterns: vec![],
5243            },
5244        );
5245        agent_settings::AgentSettings::override_global(settings, cx);
5246    });
5247
5248    #[allow(clippy::arc_with_non_send_sync)]
5249    let tool = Arc::new(crate::CopyPathTool::new(project));
5250    let (event_stream, _rx) = crate::ToolCallEventStream::test();
5251
5252    let task = cx.update(|cx| {
5253        tool.run(
5254            crate::CopyPathToolInput {
5255                source_path: "root/confidential.txt".to_string(),
5256                destination_path: "root/dest/copy.txt".to_string(),
5257            },
5258            event_stream,
5259            cx,
5260        )
5261    });
5262
5263    let result = task.await;
5264    assert!(result.is_err(), "expected copy to be blocked");
5265    assert!(
5266        result.unwrap_err().to_string().contains("blocked"),
5267        "error should mention the copy was blocked"
5268    );
5269}
5270
5271#[gpui::test]
5272async fn test_save_file_tool_denies_if_any_path_denied(cx: &mut TestAppContext) {
5273    init_test(cx);
5274
5275    let fs = FakeFs::new(cx.executor());
5276    fs.insert_tree(
5277        "/root",
5278        json!({
5279            "normal.txt": "normal content",
5280            "readonly": {
5281                "config.txt": "readonly content"
5282            }
5283        }),
5284    )
5285    .await;
5286    let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5287
5288    cx.update(|cx| {
5289        let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5290        settings.tool_permissions.tools.insert(
5291            "save_file".into(),
5292            agent_settings::ToolRules {
5293                default_mode: settings::ToolPermissionMode::Allow,
5294                always_allow: vec![],
5295                always_deny: vec![agent_settings::CompiledRegex::new(r"readonly", false).unwrap()],
5296                always_confirm: vec![],
5297                invalid_patterns: vec![],
5298            },
5299        );
5300        agent_settings::AgentSettings::override_global(settings, cx);
5301    });
5302
5303    #[allow(clippy::arc_with_non_send_sync)]
5304    let tool = Arc::new(crate::SaveFileTool::new(project));
5305    let (event_stream, _rx) = crate::ToolCallEventStream::test();
5306
5307    let task = cx.update(|cx| {
5308        tool.run(
5309            crate::SaveFileToolInput {
5310                paths: vec![
5311                    std::path::PathBuf::from("root/normal.txt"),
5312                    std::path::PathBuf::from("root/readonly/config.txt"),
5313                ],
5314            },
5315            event_stream,
5316            cx,
5317        )
5318    });
5319
5320    let result = task.await;
5321    assert!(
5322        result.is_err(),
5323        "expected save to be blocked due to denied path"
5324    );
5325    assert!(
5326        result.unwrap_err().to_string().contains("blocked"),
5327        "error should mention the save was blocked"
5328    );
5329}
5330
5331#[gpui::test]
5332async fn test_save_file_tool_respects_deny_rules(cx: &mut TestAppContext) {
5333    init_test(cx);
5334
5335    let fs = FakeFs::new(cx.executor());
5336    fs.insert_tree("/root", json!({"config.secret": "secret config"}))
5337        .await;
5338    let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5339
5340    cx.update(|cx| {
5341        let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5342        settings.always_allow_tool_actions = false;
5343        settings.tool_permissions.tools.insert(
5344            "save_file".into(),
5345            agent_settings::ToolRules {
5346                default_mode: settings::ToolPermissionMode::Allow,
5347                always_allow: vec![],
5348                always_deny: vec![agent_settings::CompiledRegex::new(r"\.secret$", false).unwrap()],
5349                always_confirm: vec![],
5350                invalid_patterns: vec![],
5351            },
5352        );
5353        agent_settings::AgentSettings::override_global(settings, cx);
5354    });
5355
5356    #[allow(clippy::arc_with_non_send_sync)]
5357    let tool = Arc::new(crate::SaveFileTool::new(project));
5358    let (event_stream, _rx) = crate::ToolCallEventStream::test();
5359
5360    let task = cx.update(|cx| {
5361        tool.run(
5362            crate::SaveFileToolInput {
5363                paths: vec![std::path::PathBuf::from("root/config.secret")],
5364            },
5365            event_stream,
5366            cx,
5367        )
5368    });
5369
5370    let result = task.await;
5371    assert!(result.is_err(), "expected save to be blocked");
5372    assert!(
5373        result.unwrap_err().to_string().contains("blocked"),
5374        "error should mention the save was blocked"
5375    );
5376}
5377
5378#[gpui::test]
5379async fn test_web_search_tool_deny_rule_blocks_search(cx: &mut TestAppContext) {
5380    init_test(cx);
5381
5382    cx.update(|cx| {
5383        let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5384        settings.tool_permissions.tools.insert(
5385            "web_search".into(),
5386            agent_settings::ToolRules {
5387                default_mode: settings::ToolPermissionMode::Allow,
5388                always_allow: vec![],
5389                always_deny: vec![
5390                    agent_settings::CompiledRegex::new(r"internal\.company", false).unwrap(),
5391                ],
5392                always_confirm: vec![],
5393                invalid_patterns: vec![],
5394            },
5395        );
5396        agent_settings::AgentSettings::override_global(settings, cx);
5397    });
5398
5399    #[allow(clippy::arc_with_non_send_sync)]
5400    let tool = Arc::new(crate::WebSearchTool);
5401    let (event_stream, _rx) = crate::ToolCallEventStream::test();
5402
5403    let input: crate::WebSearchToolInput =
5404        serde_json::from_value(json!({"query": "internal.company.com secrets"})).unwrap();
5405
5406    let task = cx.update(|cx| tool.run(input, event_stream, cx));
5407
5408    let result = task.await;
5409    assert!(result.is_err(), "expected search to be blocked");
5410    assert!(
5411        result.unwrap_err().to_string().contains("blocked"),
5412        "error should mention the search was blocked"
5413    );
5414}
5415
5416#[gpui::test]
5417async fn test_edit_file_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5418    init_test(cx);
5419
5420    let fs = FakeFs::new(cx.executor());
5421    fs.insert_tree("/root", json!({"README.md": "# Hello"}))
5422        .await;
5423    let project = Project::test(fs.clone(), ["/root".as_ref()], cx).await;
5424
5425    cx.update(|cx| {
5426        let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5427        settings.always_allow_tool_actions = false;
5428        settings.tool_permissions.tools.insert(
5429            "edit_file".into(),
5430            agent_settings::ToolRules {
5431                default_mode: settings::ToolPermissionMode::Confirm,
5432                always_allow: vec![agent_settings::CompiledRegex::new(r"\.md$", false).unwrap()],
5433                always_deny: vec![],
5434                always_confirm: vec![],
5435                invalid_patterns: vec![],
5436            },
5437        );
5438        agent_settings::AgentSettings::override_global(settings, cx);
5439    });
5440
5441    let context_server_registry =
5442        cx.new(|cx| crate::ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
5443    let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
5444    let templates = crate::Templates::new();
5445    let thread = cx.new(|cx| {
5446        crate::Thread::new(
5447            project.clone(),
5448            cx.new(|_cx| prompt_store::ProjectContext::default()),
5449            context_server_registry,
5450            templates.clone(),
5451            None,
5452            cx,
5453        )
5454    });
5455
5456    #[allow(clippy::arc_with_non_send_sync)]
5457    let tool = Arc::new(crate::EditFileTool::new(
5458        project,
5459        thread.downgrade(),
5460        language_registry,
5461        templates,
5462    ));
5463    let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5464
5465    let _task = cx.update(|cx| {
5466        tool.run(
5467            crate::EditFileToolInput {
5468                display_description: "Edit README".to_string(),
5469                path: "root/README.md".into(),
5470                mode: crate::EditFileMode::Edit,
5471            },
5472            event_stream,
5473            cx,
5474        )
5475    });
5476
5477    cx.run_until_parked();
5478
5479    let event = rx.try_next();
5480    assert!(
5481        !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5482        "expected no authorization request for allowed .md file"
5483    );
5484}
5485
5486#[gpui::test]
5487async fn test_fetch_tool_deny_rule_blocks_url(cx: &mut TestAppContext) {
5488    init_test(cx);
5489
5490    cx.update(|cx| {
5491        let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5492        settings.tool_permissions.tools.insert(
5493            "fetch".into(),
5494            agent_settings::ToolRules {
5495                default_mode: settings::ToolPermissionMode::Allow,
5496                always_allow: vec![],
5497                always_deny: vec![
5498                    agent_settings::CompiledRegex::new(r"internal\.company\.com", false).unwrap(),
5499                ],
5500                always_confirm: vec![],
5501                invalid_patterns: vec![],
5502            },
5503        );
5504        agent_settings::AgentSettings::override_global(settings, cx);
5505    });
5506
5507    let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5508
5509    #[allow(clippy::arc_with_non_send_sync)]
5510    let tool = Arc::new(crate::FetchTool::new(http_client));
5511    let (event_stream, _rx) = crate::ToolCallEventStream::test();
5512
5513    let input: crate::FetchToolInput =
5514        serde_json::from_value(json!({"url": "https://internal.company.com/api"})).unwrap();
5515
5516    let task = cx.update(|cx| tool.run(input, event_stream, cx));
5517
5518    let result = task.await;
5519    assert!(result.is_err(), "expected fetch to be blocked");
5520    assert!(
5521        result.unwrap_err().to_string().contains("blocked"),
5522        "error should mention the fetch was blocked"
5523    );
5524}
5525
5526#[gpui::test]
5527async fn test_fetch_tool_allow_rule_skips_confirmation(cx: &mut TestAppContext) {
5528    init_test(cx);
5529
5530    cx.update(|cx| {
5531        let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
5532        settings.always_allow_tool_actions = false;
5533        settings.tool_permissions.tools.insert(
5534            "fetch".into(),
5535            agent_settings::ToolRules {
5536                default_mode: settings::ToolPermissionMode::Confirm,
5537                always_allow: vec![agent_settings::CompiledRegex::new(r"docs\.rs", false).unwrap()],
5538                always_deny: vec![],
5539                always_confirm: vec![],
5540                invalid_patterns: vec![],
5541            },
5542        );
5543        agent_settings::AgentSettings::override_global(settings, cx);
5544    });
5545
5546    let http_client = gpui::http_client::FakeHttpClient::with_200_response();
5547
5548    #[allow(clippy::arc_with_non_send_sync)]
5549    let tool = Arc::new(crate::FetchTool::new(http_client));
5550    let (event_stream, mut rx) = crate::ToolCallEventStream::test();
5551
5552    let input: crate::FetchToolInput =
5553        serde_json::from_value(json!({"url": "https://docs.rs/some-crate"})).unwrap();
5554
5555    let _task = cx.update(|cx| tool.run(input, event_stream, cx));
5556
5557    cx.run_until_parked();
5558
5559    let event = rx.try_next();
5560    assert!(
5561        !matches!(event, Ok(Some(Ok(ThreadEvent::ToolCallAuthorization(_))))),
5562        "expected no authorization request for allowed docs.rs URL"
5563    );
5564}