1use super::*;
2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
3use agent_client_protocol::{self as acp};
4use agent_settings::AgentProfileId;
5use anyhow::Result;
6use client::{Client, UserStore};
7use cloud_llm_client::CompletionIntent;
8use collections::IndexMap;
9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
10use fs::{FakeFs, Fs};
11use futures::{
12 FutureExt as _, StreamExt,
13 channel::{
14 mpsc::{self, UnboundedReceiver},
15 oneshot,
16 },
17 future::{Fuse, Shared},
18};
19use gpui::{
20 App, AppContext, AsyncApp, Entity, Task, TestAppContext, UpdateGlobal,
21 http_client::FakeHttpClient,
22};
23use indoc::indoc;
24use language_model::{
25 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
26 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
27 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
28 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
29};
30use pretty_assertions::assert_eq;
31use project::{
32 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
33};
34use prompt_store::ProjectContext;
35use reqwest_client::ReqwestClient;
36use schemars::JsonSchema;
37use serde::{Deserialize, Serialize};
38use serde_json::json;
39use settings::{Settings, SettingsStore};
40use std::{
41 path::Path,
42 pin::Pin,
43 rc::Rc,
44 sync::{
45 Arc,
46 atomic::{AtomicBool, Ordering},
47 },
48 time::Duration,
49};
50use util::path;
51
52mod test_tools;
53use test_tools::*;
54
55fn init_test(cx: &mut TestAppContext) {
56 cx.update(|cx| {
57 let settings_store = SettingsStore::test(cx);
58 cx.set_global(settings_store);
59 });
60}
61
/// Test double for a terminal spawned by the terminal tool.
///
/// Tracks whether `kill` was invoked and exposes a `wait_for_exit` task that
/// only resolves once the handle has been killed.
struct FakeTerminalHandle {
    // Set to true by `kill`; read by the exit-polling task and `was_killed`.
    killed: Arc<AtomicBool>,
    // Shared task that resolves once `killed` flips to true.
    wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
    // Canned output returned from `current_output`.
    output: acp::TerminalOutputResponse,
    // Stable id returned from `id`.
    id: acp::TerminalId,
}

impl FakeTerminalHandle {
    /// Creates a handle whose `wait_for_exit` never resolves on its own: it
    /// polls the `killed` flag every millisecond and completes only after
    /// `kill` is called, simulating a command that runs forever.
    fn new_never_exits(cx: &mut App) -> Self {
        let killed = Arc::new(AtomicBool::new(false));

        let killed_for_task = killed.clone();
        let wait_for_exit = cx
            .spawn(async move |cx| {
                loop {
                    if killed_for_task.load(Ordering::SeqCst) {
                        return acp::TerminalExitStatus::new();
                    }
                    // Yield briefly so the test executor can make progress
                    // between polls of the kill flag.
                    cx.background_executor()
                        .timer(Duration::from_millis(1))
                        .await;
                }
            })
            .shared();

        Self {
            killed,
            wait_for_exit,
            output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
            id: acp::TerminalId::new("fake_terminal".to_string()),
        }
    }

    /// Returns whether `kill` has been called on this handle.
    fn was_killed(&self) -> bool {
        self.killed.load(Ordering::SeqCst)
    }
}
99
impl crate::TerminalHandle for FakeTerminalHandle {
    fn id(&self, _cx: &AsyncApp) -> Result<acp::TerminalId> {
        Ok(self.id.clone())
    }

    // Always reports the same canned "partial output" snapshot.
    fn current_output(&self, _cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
        Ok(self.output.clone())
    }

    fn wait_for_exit(&self, _cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
        Ok(self.wait_for_exit.clone())
    }

    // Flips the shared `killed` flag, which lets `wait_for_exit` resolve.
    fn kill(&self, _cx: &AsyncApp) -> Result<()> {
        self.killed.store(true, Ordering::SeqCst);
        Ok(())
    }
}
118
/// Test environment that hands out the same [`FakeTerminalHandle`] for every
/// terminal the tool under test creates.
struct FakeThreadEnvironment {
    handle: Rc<FakeTerminalHandle>,
}

impl crate::ThreadEnvironment for FakeThreadEnvironment {
    // Ignores the requested command/cwd/output limit and returns the shared
    // fake handle immediately.
    fn create_terminal(
        &self,
        _command: String,
        _cwd: Option<std::path::PathBuf>,
        _output_byte_limit: Option<u64>,
        _cx: &mut AsyncApp,
    ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
        Task::ready(Ok(self.handle.clone() as Rc<dyn crate::TerminalHandle>))
    }
}
134
135fn always_allow_tools(cx: &mut TestAppContext) {
136 cx.update(|cx| {
137 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
138 settings.always_allow_tool_actions = true;
139 agent_settings::AgentSettings::override_global(settings, cx);
140 });
141}
142
143#[gpui::test]
144async fn test_echo(cx: &mut TestAppContext) {
145 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
146 let fake_model = model.as_fake();
147
148 let events = thread
149 .update(cx, |thread, cx| {
150 thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
151 })
152 .unwrap();
153 cx.run_until_parked();
154 fake_model.send_last_completion_stream_text_chunk("Hello");
155 fake_model
156 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
157 fake_model.end_last_completion_stream();
158
159 let events = events.collect().await;
160 thread.update(cx, |thread, _cx| {
161 assert_eq!(
162 thread.last_message().unwrap().to_markdown(),
163 indoc! {"
164 ## Assistant
165
166 Hello
167 "}
168 )
169 });
170 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
171}
172
#[gpui::test]
async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;

    // A terminal that never exits on its own; it only finishes once killed.
    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment {
        handle: handle.clone(),
    });

    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::TerminalTool::new(project, environment));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    // Run with a very short timeout so the tool is forced to kill the
    // never-exiting terminal.
    let task = cx.update(|cx| {
        tool.run(
            crate::TerminalToolInput {
                command: "sleep 1000".to_string(),
                cd: ".".to_string(),
                timeout_ms: Some(5),
            },
            event_stream,
            cx,
        )
    });

    // The tool should first surface the terminal in a tool-call update.
    let update = rx.expect_update_fields().await;
    assert!(
        update.content.iter().any(|blocks| {
            blocks
                .iter()
                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
        }),
        "expected tool call update to include terminal content"
    );

    let mut task_future: Pin<Box<Fuse<Task<Result<String>>>>> = Box::pin(task.fuse());

    // Poll the task cooperatively (run_until_parked plus tiny timer ticks)
    // until it completes, bounded by a wall-clock deadline so a regression
    // fails fast instead of hanging the test forever.
    let deadline = std::time::Instant::now() + Duration::from_millis(500);
    loop {
        if let Some(result) = task_future.as_mut().now_or_never() {
            let result = result.expect("terminal tool task should complete");

            assert!(
                handle.was_killed(),
                "expected terminal handle to be killed on timeout"
            );
            // On timeout the tool should still report whatever output the
            // terminal produced before being killed.
            assert!(
                result.contains("partial output"),
                "expected result to include terminal output, got: {result}"
            );
            return;
        }

        if std::time::Instant::now() >= deadline {
            panic!("timed out waiting for terminal tool task to complete");
        }

        cx.run_until_parked();
        cx.background_executor.timer(Duration::from_millis(1)).await;
    }
}
238
#[gpui::test]
#[ignore] // NOTE(review): disabled; relies on a real 25ms wall-clock sleep — confirm why before re-enabling.
async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;

    // A terminal that never exits on its own; it only finishes once killed.
    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment {
        handle: handle.clone(),
    });

    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::TerminalTool::new(project, environment));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    // No timeout: the tool must never kill the terminal of its own accord.
    let _task = cx.update(|cx| {
        tool.run(
            crate::TerminalToolInput {
                command: "sleep 1000".to_string(),
                cd: ".".to_string(),
                timeout_ms: None,
            },
            event_stream,
            cx,
        )
    });

    let update = rx.expect_update_fields().await;
    assert!(
        update.content.iter().any(|blocks| {
            blocks
                .iter()
                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
        }),
        "expected tool call update to include terminal content"
    );

    // Give the tool some real time to (incorrectly) fire a kill, then verify
    // the handle is still alive.
    smol::Timer::after(Duration::from_millis(25)).await;

    assert!(
        !handle.was_killed(),
        "did not expect terminal handle to be killed without a timeout"
    );
}
286
287#[gpui::test]
288async fn test_thinking(cx: &mut TestAppContext) {
289 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
290 let fake_model = model.as_fake();
291
292 let events = thread
293 .update(cx, |thread, cx| {
294 thread.send(
295 UserMessageId::new(),
296 [indoc! {"
297 Testing:
298
299 Generate a thinking step where you just think the word 'Think',
300 and have your final answer be 'Hello'
301 "}],
302 cx,
303 )
304 })
305 .unwrap();
306 cx.run_until_parked();
307 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
308 text: "Think".to_string(),
309 signature: None,
310 });
311 fake_model.send_last_completion_stream_text_chunk("Hello");
312 fake_model
313 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
314 fake_model.end_last_completion_stream();
315
316 let events = events.collect().await;
317 thread.update(cx, |thread, _cx| {
318 assert_eq!(
319 thread.last_message().unwrap().to_markdown(),
320 indoc! {"
321 ## Assistant
322
323 <think>Think</think>
324 Hello
325 "}
326 )
327 });
328 assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
329}
330
#[gpui::test]
async fn test_system_prompt(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        project_context,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Give the project context a recognizable shell name so we can assert it
    // shows up in the rendered system prompt.
    project_context.update(cx, |project_context, _cx| {
        project_context.shell = "test-shell".into()
    });
    // Register a tool so the tool-related prompt sections are included.
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    // The system prompt must be the first message of the request.
    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    assert!(
        system_prompt.contains("test-shell"),
        "unexpected system message: {:?}",
        system_message
    );
    // With tools registered, the diagnostics section should be present.
    assert!(
        system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}
375
376#[gpui::test]
377async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
378 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
379 let fake_model = model.as_fake();
380
381 thread
382 .update(cx, |thread, cx| {
383 thread.send(UserMessageId::new(), ["abc"], cx)
384 })
385 .unwrap();
386 cx.run_until_parked();
387 let mut pending_completions = fake_model.pending_completions();
388 assert_eq!(
389 pending_completions.len(),
390 1,
391 "unexpected pending completions: {:?}",
392 pending_completions
393 );
394
395 let pending_completion = pending_completions.pop().unwrap();
396 assert_eq!(pending_completion.messages[0].role, Role::System);
397
398 let system_message = &pending_completion.messages[0];
399 let system_prompt = system_message.content[0].to_str().unwrap();
400 assert!(
401 !system_prompt.contains("## Tool Use"),
402 "unexpected system message: {:?}",
403 system_message
404 );
405 assert!(
406 !system_prompt.contains("## Fixing Diagnostics"),
407 "unexpected system message: {:?}",
408 system_message
409 );
410}
411
// Verifies the prompt-cache policy: only the most recent user message (or
// tool result) in a request is marked `cache: true`; earlier messages are
// sent uncached.
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // messages[0] is the system prompt, so comparisons start at index 1.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true,
            reasoning_details: None,
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    // Only the trailing tool-result message carries `cache: true`.
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
}
557
// End-to-end test against a real model (only runs with the "e2e" feature).
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool calls that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::name());
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    thread.update(cx, |thread, _cx| {
        // The delay tool's output ("Ding") should be echoed back somewhere in
        // the agent's final message text.
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
617
// End-to-end test (only runs with the "e2e" feature): verifies that tool-call
// input streams incrementally, i.e. we observe at least one partial input
// before the input is complete.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // While pending, the input may still be mid-stream:
                        // incomplete and missing its final key ("g").
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        // Once past pending, the full input must have arrived.
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
669
670#[gpui::test]
671async fn test_tool_authorization(cx: &mut TestAppContext) {
672 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
673 let fake_model = model.as_fake();
674
675 let mut events = thread
676 .update(cx, |thread, cx| {
677 thread.add_tool(ToolRequiringPermission);
678 thread.send(UserMessageId::new(), ["abc"], cx)
679 })
680 .unwrap();
681 cx.run_until_parked();
682 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
683 LanguageModelToolUse {
684 id: "tool_id_1".into(),
685 name: ToolRequiringPermission::name().into(),
686 raw_input: "{}".into(),
687 input: json!({}),
688 is_input_complete: true,
689 thought_signature: None,
690 },
691 ));
692 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
693 LanguageModelToolUse {
694 id: "tool_id_2".into(),
695 name: ToolRequiringPermission::name().into(),
696 raw_input: "{}".into(),
697 input: json!({}),
698 is_input_complete: true,
699 thought_signature: None,
700 },
701 ));
702 fake_model.end_last_completion_stream();
703 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
704 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
705
706 // Approve the first
707 tool_call_auth_1
708 .response
709 .send(tool_call_auth_1.options[1].option_id.clone())
710 .unwrap();
711 cx.run_until_parked();
712
713 // Reject the second
714 tool_call_auth_2
715 .response
716 .send(tool_call_auth_1.options[2].option_id.clone())
717 .unwrap();
718 cx.run_until_parked();
719
720 let completion = fake_model.pending_completions().pop().unwrap();
721 let message = completion.messages.last().unwrap();
722 assert_eq!(
723 message.content,
724 vec![
725 language_model::MessageContent::ToolResult(LanguageModelToolResult {
726 tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(),
727 tool_name: ToolRequiringPermission::name().into(),
728 is_error: false,
729 content: "Allowed".into(),
730 output: Some("Allowed".into())
731 }),
732 language_model::MessageContent::ToolResult(LanguageModelToolResult {
733 tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(),
734 tool_name: ToolRequiringPermission::name().into(),
735 is_error: true,
736 content: "Permission to run tool denied by user".into(),
737 output: Some("Permission to run tool denied by user".into())
738 })
739 ]
740 );
741
742 // Simulate yet another tool call.
743 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
744 LanguageModelToolUse {
745 id: "tool_id_3".into(),
746 name: ToolRequiringPermission::name().into(),
747 raw_input: "{}".into(),
748 input: json!({}),
749 is_input_complete: true,
750 thought_signature: None,
751 },
752 ));
753 fake_model.end_last_completion_stream();
754
755 // Respond by always allowing tools.
756 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
757 tool_call_auth_3
758 .response
759 .send(tool_call_auth_3.options[0].option_id.clone())
760 .unwrap();
761 cx.run_until_parked();
762 let completion = fake_model.pending_completions().pop().unwrap();
763 let message = completion.messages.last().unwrap();
764 assert_eq!(
765 message.content,
766 vec![language_model::MessageContent::ToolResult(
767 LanguageModelToolResult {
768 tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(),
769 tool_name: ToolRequiringPermission::name().into(),
770 is_error: false,
771 content: "Allowed".into(),
772 output: Some("Allowed".into())
773 }
774 )]
775 );
776
777 // Simulate a final tool call, ensuring we don't trigger authorization.
778 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
779 LanguageModelToolUse {
780 id: "tool_id_4".into(),
781 name: ToolRequiringPermission::name().into(),
782 raw_input: "{}".into(),
783 input: json!({}),
784 is_input_complete: true,
785 thought_signature: None,
786 },
787 ));
788 fake_model.end_last_completion_stream();
789 cx.run_until_parked();
790 let completion = fake_model.pending_completions().pop().unwrap();
791 let message = completion.messages.last().unwrap();
792 assert_eq!(
793 message.content,
794 vec![language_model::MessageContent::ToolResult(
795 LanguageModelToolResult {
796 tool_use_id: "tool_id_4".into(),
797 tool_name: ToolRequiringPermission::name().into(),
798 is_error: false,
799 content: "Allowed".into(),
800 output: Some("Allowed".into())
801 }
802 )]
803 );
804}
805
806#[gpui::test]
807async fn test_tool_hallucination(cx: &mut TestAppContext) {
808 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
809 let fake_model = model.as_fake();
810
811 let mut events = thread
812 .update(cx, |thread, cx| {
813 thread.send(UserMessageId::new(), ["abc"], cx)
814 })
815 .unwrap();
816 cx.run_until_parked();
817 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
818 LanguageModelToolUse {
819 id: "tool_id_1".into(),
820 name: "nonexistent_tool".into(),
821 raw_input: "{}".into(),
822 input: json!({}),
823 is_input_complete: true,
824 thought_signature: None,
825 },
826 ));
827 fake_model.end_last_completion_stream();
828
829 let tool_call = expect_tool_call(&mut events).await;
830 assert_eq!(tool_call.title, "nonexistent_tool");
831 assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
832 let update = expect_tool_call_update_fields(&mut events).await;
833 assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
834}
835
// After hitting the tool-use limit, `resume` should continue the same turn by
// appending a "Continue where you left off" user message.
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // messages[0] is the system prompt; the tool result is the latest message
    // and therefore the one marked `cache: true`.
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // Resuming should re-send the history plus a synthetic continuation
    // message, which becomes the newly cached message.
    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });
}
950
// After hitting the tool-use limit, sending a *new* user message should work:
// the history (including the tool use/result pair) is preserved and the new
// message becomes the cached one.
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // The limit event arrives in the same stream as the tool use.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // messages[0] is the system prompt; the new user message is cached.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
}
1027
1028async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
1029 let event = events
1030 .next()
1031 .await
1032 .expect("no tool call authorization event received")
1033 .unwrap();
1034 match event {
1035 ThreadEvent::ToolCall(tool_call) => tool_call,
1036 event => {
1037 panic!("Unexpected event {event:?}");
1038 }
1039 }
1040}
1041
1042async fn expect_tool_call_update_fields(
1043 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
1044) -> acp::ToolCallUpdate {
1045 let event = events
1046 .next()
1047 .await
1048 .expect("no tool call authorization event received")
1049 .unwrap();
1050 match event {
1051 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
1052 event => {
1053 panic!("Unexpected event {event:?}");
1054 }
1055 }
1056}
1057
1058async fn next_tool_call_authorization(
1059 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
1060) -> ToolCallAuthorization {
1061 loop {
1062 let event = events
1063 .next()
1064 .await
1065 .expect("no tool call authorization event received")
1066 .unwrap();
1067 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
1068 let permission_kinds = tool_call_authorization
1069 .options
1070 .iter()
1071 .map(|o| o.kind)
1072 .collect::<Vec<_>>();
1073 assert_eq!(
1074 permission_kinds,
1075 vec![
1076 acp::PermissionOptionKind::AllowAlways,
1077 acp::PermissionOptionKind::AllowOnce,
1078 acp::PermissionOptionKind::RejectOnce,
1079 ]
1080 );
1081 return tool_call_authorization;
1082 }
1083 }
1084}
1085
// End-to-end test (only runs with the "e2e" feature): two delay-tool calls in
// one message should both complete and be described in the final answer.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test concurrent tool calls with different delay times
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Call the delay tool twice in the same message.",
                    "Once with 100ms. Once with 300ms.",
                    "When both timers are complete, describe the outputs.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;

    let stop_reasons = stop_events(events);
    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);

    thread.update(cx, |thread, _cx| {
        // Concatenate all text content of the final agent message and check
        // that the delay tool's output ("Ding") was echoed back.
        let last_message = thread.last_message().unwrap();
        let agent_message = last_message.as_agent_message().unwrap();
        let text = agent_message
            .content
            .iter()
            .filter_map(|content| {
                if let AgentMessageContent::Text(text) = content {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<String>();

        assert!(text.contains("Ding"));
    });
}
1130
/// Verifies that switching agent profiles changes which tools are offered to
/// the model: profile "test-1" exposes echo + delay, profile "test-2" exposes
/// only the infinite tool.
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Register all three tools on the thread; the active profile decides
    // which subset is actually sent to the model.
    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()), cx);
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Inspect the completion request the fake model received and extract the
    // tool names it was offered.
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
    // Close the stream so the first turn finishes before starting the next.
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()), cx);
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool::name()]);
}
1210
/// Verifies MCP (context server) tool integration: an MCP tool can be called
/// by the model, and when a native tool later shares its name ("echo"), the
/// MCP tool is re-exposed under a server-prefixed name ("test_server_echo")
/// and results are routed back under the correct names.
#[gpui::test]
async fn test_mcp_tools(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "always_allow_tool_actions": true,
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();
    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx)
    });

    // Fake MCP server exposing a single "echo" tool; the receiver yields each
    // incoming tool call together with a oneshot-style responder.
    let mut mcp_tool_calls = setup_context_server(
        "test_server",
        vec![context_server::types::Tool {
            name: "echo".into(),
            description: None,
            input_schema: serde_json::to_value(EchoTool::input_schema(
                LanguageModelToolSchemaFormat::JsonSchema,
            ))
            .unwrap(),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    let events = thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
    });
    cx.run_until_parked();

    // Simulate the model calling the MCP tool.
    let completion = fake_model.pending_completions().pop().unwrap();
    // No name collision yet, so the MCP tool keeps its bare name.
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "echo".into(),
            raw_input: json!({"text": "test"}).to_string(),
            input: json!({"text": "test"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The MCP server should have received the call; answer it.
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text {
                text: "test".into(),
            }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // NOTE(review): re-asserts on the completion captured above (not a fresh
    // one), so this checks the same request again.
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_text_chunk("Done!");
    fake_model.end_last_completion_stream();
    // Drain the first turn's event stream to completion.
    events.collect::<Vec<_>>().await;

    // Send again after adding the echo tool, ensuring the name collision is resolved.
    let events = thread.update(cx, |thread, cx| {
        thread.add_tool(EchoTool);
        thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
    });
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    // The MCP tool is now disambiguated with its server-name prefix.
    assert_eq!(
        tool_names_for_completion(&completion),
        vec!["echo", "test_server_echo"]
    );
    // Model calls both the prefixed MCP tool and the native echo tool.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_2".into(),
            name: "test_server_echo".into(),
            raw_input: json!({"text": "mcp"}).to_string(),
            input: json!({"text": "mcp"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_3".into(),
            name: "echo".into(),
            raw_input: json!({"text": "native"}).to_string(),
            input: json!({"text": "native"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Only the MCP call goes to the server; the native echo runs in-process.
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // Ensure the tool results were inserted with the correct names.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages.last().unwrap().content,
        vec![
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_3".into(),
                tool_name: "echo".into(),
                is_error: false,
                content: "native".into(),
                output: Some("native".into()),
            },),
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_2".into(),
                tool_name: "test_server_echo".into(),
                is_error: false,
                content: "mcp".into(),
                output: Some("mcp".into()),
            },),
        ]
    );
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
}
1376
/// Verifies how tool names from multiple MCP servers are disambiguated and
/// truncated: conflicting names are prefixed with the server name, and names
/// at or over MAX_TOOL_NAME_LENGTH are shortened (apparently keeping a
/// single-character server prefix — TODO confirm against the resolution code).
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx);
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Long names: adding a server prefix would exceed the limit.
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    // Third server duplicates the long names and adds one over the limit.
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // The final, deduplicated and truncated tool list offered to the model.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1542
/// End-to-end test (requires the `e2e` feature and a real model): cancelling
/// a turn while a tool is still running closes the event stream with a
/// Cancelled stop reason, and the thread can accept a new message afterwards.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                // Tool calls must arrive in the order the prompt requested.
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.tool_call_id);
                }
            }
            // Watch for the echo tool's completion update so we know at least
            // one tool finished before we cancel.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    tool_call_id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                    ..
                },
            )) if Some(&tool_call_id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1628
1629#[gpui::test]
1630async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
1631 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1632 let fake_model = model.as_fake();
1633
1634 let events_1 = thread
1635 .update(cx, |thread, cx| {
1636 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1637 })
1638 .unwrap();
1639 cx.run_until_parked();
1640 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1641 cx.run_until_parked();
1642
1643 let events_2 = thread
1644 .update(cx, |thread, cx| {
1645 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1646 })
1647 .unwrap();
1648 cx.run_until_parked();
1649 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1650 fake_model
1651 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1652 fake_model.end_last_completion_stream();
1653
1654 let events_1 = events_1.collect::<Vec<_>>().await;
1655 assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
1656 let events_2 = events_2.collect::<Vec<_>>().await;
1657 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1658}
1659
1660#[gpui::test]
1661async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
1662 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1663 let fake_model = model.as_fake();
1664
1665 let events_1 = thread
1666 .update(cx, |thread, cx| {
1667 thread.send(UserMessageId::new(), ["Hello 1"], cx)
1668 })
1669 .unwrap();
1670 cx.run_until_parked();
1671 fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1672 fake_model
1673 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1674 fake_model.end_last_completion_stream();
1675 let events_1 = events_1.collect::<Vec<_>>().await;
1676
1677 let events_2 = thread
1678 .update(cx, |thread, cx| {
1679 thread.send(UserMessageId::new(), ["Hello 2"], cx)
1680 })
1681 .unwrap();
1682 cx.run_until_parked();
1683 fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1684 fake_model
1685 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1686 fake_model.end_last_completion_stream();
1687 let events_2 = events_2.collect::<Vec<_>>().await;
1688
1689 assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
1690 assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1691}
1692
/// When the model stops with a Refusal, the thread drops everything back to
/// (and including) the last user message, leaving the transcript empty here.
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Snapshot: only the user message so far.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    // Snapshot: partial assistant reply has been appended.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    // If the model refuses to continue, the thread should remove all the messages after the last user message.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    // With only one user message in the thread, the refusal empties it entirely.
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}
1741
/// Truncating at the first (and only) user message empties the thread and
/// resets token usage, and the thread remains usable for a fresh send.
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Keep the id so we can truncate at this exact message later.
    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        // No usage reported until the model sends a UsageUpdate.
        assert_eq!(thread.latest_token_usage(), None);
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        // Usage reflects the streamed UsageUpdate (input + output tokens).
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncating at the first message clears the transcript and usage.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        // Usage from before the truncation must not leak into the new turn.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });
}
1857
/// Truncating at the second user message restores the thread to exactly the
/// state it had after the first exchange, including the earlier token usage.
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First exchange: user message + streamed response + usage update.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Reusable snapshot check: asserted both before the second exchange and
    // again after truncating back to it.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    // Second exchange; keep the id so we can truncate at this message.
    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate away the second exchange; state must match the first snapshot.
    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
1966
/// A summarization model generates the thread title after the first turn
/// (using only the first line of its output), and is not invoked again for
/// subsequent turns.
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Separate fake model dedicated to title generation.
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // Title stays at the placeholder until the summary model responds.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    // The chunks deliberately split words and contain a newline: only the
    // first line ("Hello world") becomes the title.
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // The summary model received no new completion request.
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
2012
/// When building a completion request while one tool call is still awaiting
/// permission, the pending tool use (and its result) is omitted from the
/// request, while the completed tool's use and result are included.
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // This tool call will stay pending: permission is never granted.
    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::name().into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
        thought_signature: None,
    };
    // This tool call runs to completion immediately.
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    // messages[0] is presumably the system prompt — compare from index 1 on.
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    // Only the completed echo tool use appears; the
                    // permission-gated tool use is omitted.
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false,
                reasoning_details: None,
            },
        ],
    );
}
2092
/// Integration test of the ACP `AgentConnection` surface: creating a thread,
/// listing/selecting models, prompting, cancelling, and verifying the native
/// thread is dropped (session lost) when the ACP thread is dropped.
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);

        // 404-everything HTTP client: no real network traffic in this test.
        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        LanguageModelRegistry::test(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let text_thread_store =
        cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector(&session_id);
    assert!(
        selector_opt.is_some(),
        "agent should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    // The test registry exposes the fake provider's model as "fake/fake".
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0]
            .id
            .0
            .as_ref(),
        "fake/fake"
    );

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Round-trip a prompt through the fake model and check the transcript.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest::new(session_id.clone(), vec!["ghi".into()]),
                cx,
            )
        })
        .await;
    // Prompting the dropped session must fail with "Session not found".
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
2223
2224#[gpui::test]
2225async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
2226 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2227 thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
2228 let fake_model = model.as_fake();
2229
2230 let mut events = thread
2231 .update(cx, |thread, cx| {
2232 thread.send(UserMessageId::new(), ["Think"], cx)
2233 })
2234 .unwrap();
2235 cx.run_until_parked();
2236
2237 // Simulate streaming partial input.
2238 let input = json!({});
2239 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2240 LanguageModelToolUse {
2241 id: "1".into(),
2242 name: ThinkingTool::name().into(),
2243 raw_input: input.to_string(),
2244 input,
2245 is_input_complete: false,
2246 thought_signature: None,
2247 },
2248 ));
2249
2250 // Input streaming completed
2251 let input = json!({ "content": "Thinking hard!" });
2252 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
2253 LanguageModelToolUse {
2254 id: "1".into(),
2255 name: "thinking".into(),
2256 raw_input: input.to_string(),
2257 input,
2258 is_input_complete: true,
2259 thought_signature: None,
2260 },
2261 ));
2262 fake_model.end_last_completion_stream();
2263 cx.run_until_parked();
2264
2265 let tool_call = expect_tool_call(&mut events).await;
2266 assert_eq!(
2267 tool_call,
2268 acp::ToolCall::new("1", "Thinking")
2269 .kind(acp::ToolKind::Think)
2270 .raw_input(json!({}))
2271 .meta(acp::Meta::from_iter([(
2272 "tool_name".into(),
2273 "thinking".into()
2274 )]))
2275 );
2276 let update = expect_tool_call_update_fields(&mut events).await;
2277 assert_eq!(
2278 update,
2279 acp::ToolCallUpdate::new(
2280 "1",
2281 acp::ToolCallUpdateFields::new()
2282 .title("Thinking")
2283 .kind(acp::ToolKind::Think)
2284 .raw_input(json!({ "content": "Thinking hard!"}))
2285 )
2286 );
2287 let update = expect_tool_call_update_fields(&mut events).await;
2288 assert_eq!(
2289 update,
2290 acp::ToolCallUpdate::new(
2291 "1",
2292 acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
2293 )
2294 );
2295 let update = expect_tool_call_update_fields(&mut events).await;
2296 assert_eq!(
2297 update,
2298 acp::ToolCallUpdate::new(
2299 "1",
2300 acp::ToolCallUpdateFields::new().content(vec!["Thinking hard!".into()])
2301 )
2302 );
2303 let update = expect_tool_call_update_fields(&mut events).await;
2304 assert_eq!(
2305 update,
2306 acp::ToolCallUpdate::new(
2307 "1",
2308 acp::ToolCallUpdateFields::new()
2309 .status(acp::ToolCallStatus::Completed)
2310 .raw_output("Finished thinking.")
2311 )
2312 );
2313}
2314
2315#[gpui::test]
2316async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
2317 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
2318 let fake_model = model.as_fake();
2319
2320 let mut events = thread
2321 .update(cx, |thread, cx| {
2322 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
2323 thread.send(UserMessageId::new(), ["Hello!"], cx)
2324 })
2325 .unwrap();
2326 cx.run_until_parked();
2327
2328 fake_model.send_last_completion_stream_text_chunk("Hey!");
2329 fake_model.end_last_completion_stream();
2330
2331 let mut retry_events = Vec::new();
2332 while let Some(Ok(event)) = events.next().await {
2333 match event {
2334 ThreadEvent::Retry(retry_status) => {
2335 retry_events.push(retry_status);
2336 }
2337 ThreadEvent::Stop(..) => break,
2338 _ => {}
2339 }
2340 }
2341
2342 assert_eq!(retry_events.len(), 0);
2343 thread.read_with(cx, |thread, _cx| {
2344 assert_eq!(
2345 thread.to_markdown(),
2346 indoc! {"
2347 ## User
2348
2349 Hello!
2350
2351 ## Assistant
2352
2353 Hey!
2354 "}
2355 )
2356 });
2357}
2358
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    // In Burn mode, a retryable provider error (server overloaded) should
    // pause the turn, wait out `retry_after`, and resume with a fresh
    // completion request — emitting exactly one Retry event.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream a partial response, then fail with a retryable error that
    // asks the client to back off for 3 seconds.
    fake_model.send_last_completion_stream_text_chunk("Hey,");
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance past the retry_after delay so the retry is issued.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    // Complete the retried request successfully.
    fake_model.send_last_completion_stream_text_chunk("there!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    // Exactly one retry, reported as attempt #1.
    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    // The transcript keeps the partial response and marks the resume point.
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey,

                [resume]

                ## Assistant

                there!
            "}
        )
    });
}
2423
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    // If a completion fails after streaming a complete tool use, the tool
    // call should still be executed before retrying, and the retried
    // request must include both the tool use and its result.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream a fully-formed tool use, then fail with a retryable error.
    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // After the backoff elapses, the retried request should carry the user
    // message, the assistant's tool use, and the (already-run) tool result.
    // (messages[0] is the system prompt, hence the [1..] slice.)
    cx.executor().advance_clock(Duration::from_secs(3));
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    // Finish the retried completion and confirm the final agent message.
    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default(),
                reasoning_details: None,
            }))
        );
    })
}
2504
#[gpui::test]
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
    // After MAX_RETRY_ATTEMPTS consecutive retryable failures, one more
    // failure should surface the error to the caller instead of retrying.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Fail every attempt (initial + MAX_RETRY_ATTEMPTS retries), advancing
    // the clock past each backoff so the next attempt is issued.
    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
        fake_model.send_last_completion_stream_error(
            LanguageModelCompletionError::ServerOverloaded {
                provider: LanguageModelProviderName::new("Anthropic"),
                retry_after: Some(Duration::from_secs(3)),
            },
        );
        fake_model.end_last_completion_stream();
        cx.executor().advance_clock(Duration::from_secs(3));
        cx.run_until_parked();
    }

    let mut errors = Vec::new();
    let mut retry_events = Vec::new();
    while let Some(event) = events.next().await {
        match event {
            Ok(ThreadEvent::Retry(retry_status)) => {
                retry_events.push(retry_status);
            }
            Ok(ThreadEvent::Stop(..)) => break,
            Err(error) => errors.push(error),
            _ => {}
        }
    }

    // One Retry event per attempt, numbered 1..=MAX_RETRY_ATTEMPTS.
    assert_eq!(
        retry_events.len(),
        crate::thread::MAX_RETRY_ATTEMPTS as usize
    );
    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
        assert_eq!(retry_events[i].attempt, i + 1);
    }
    // The final failure is reported as the original provider error.
    assert_eq!(errors.len(), 1);
    let error = errors[0]
        .downcast_ref::<LanguageModelCompletionError>()
        .unwrap();
    assert!(matches!(
        error,
        LanguageModelCompletionError::ServerOverloaded { .. }
    ));
}
2559
2560/// Filters out the stop events for asserting against in tests
2561fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
2562 result_events
2563 .into_iter()
2564 .filter_map(|event| match event.unwrap() {
2565 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
2566 _ => None,
2567 })
2568 .collect()
2569}
2570
/// Bundle of entities produced by `setup` that individual tests operate on.
struct ThreadTest {
    // Model the thread was created with (fake, or a real provider model).
    model: Arc<dyn LanguageModel>,
    // The thread under test.
    thread: Entity<Thread>,
    // Project context shared with the thread; tests may mutate it directly.
    project_context: Entity<ProjectContext>,
    // Store used to register fake MCP context servers.
    context_server_store: Entity<ContextServerStore>,
    // Fake filesystem backing the test project and the settings file.
    fs: Arc<FakeFs>,
}
2578
/// Which language model a test runs against: the in-process fake model, or
/// a real remote model (requires network access and credentials).
enum TestModel {
    Sonnet4,
    Fake,
}

impl TestModel {
    /// Registry id used to look up a real model.
    /// `Fake` is constructed directly and never looked up by id, so
    /// reaching that arm is a bug in the test setup.
    fn id(&self) -> LanguageModelId {
        match self {
            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
            TestModel::Fake => unreachable!(),
        }
    }
}
2592
/// Builds the common test fixture: a fake filesystem with a settings file
/// enabling every test tool, a test project, the requested model, and a
/// fresh `Thread` wired to all of them.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    // Real-model runs perform blocking work on test threads.
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    // Write a settings file selecting a profile with all test tools
    // enabled, so prompts sent in tests may invoke any of them.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            ThinkingTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);

        match model {
            TestModel::Fake => {}
            TestModel::Sonnet4 => {
                // Real model: needs an HTTP client, a production client
                // connection, and the language-model providers registered.
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        // Keep the global SettingsStore in sync with later edits tests
        // make to the settings file on the fake fs.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                // Real model: resolve it from the registry and authenticate
                // its provider before handing it to the thread.
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
2694
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Only install env_logger when the caller opted in via RUST_LOG, so
    // normal test runs stay quiet.
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}
2702
2703fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
2704 let fs = fs.clone();
2705 cx.spawn({
2706 async move |cx| {
2707 let mut new_settings_content_rx = settings::watch_config_file(
2708 cx.background_executor(),
2709 fs,
2710 paths::settings_file().clone(),
2711 );
2712
2713 while let Some(new_settings_content) = new_settings_content_rx.next().await {
2714 cx.update(|cx| {
2715 SettingsStore::update_global(cx, |settings, cx| {
2716 settings.set_user_settings(&new_settings_content, cx)
2717 })
2718 })
2719 .ok();
2720 }
2721 }
2722 })
2723 .detach();
2724}
2725
2726fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
2727 completion
2728 .tools
2729 .iter()
2730 .map(|tool| tool.name.clone())
2731 .collect()
2732}
2733
/// Registers a fake stdio MCP context server named `name` that exposes
/// `tools`, and starts it in `context_server_store`.
///
/// Returns a channel yielding each incoming tool call together with a
/// oneshot sender the test must use to supply that call's response.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in project settings so the store recognizes it.
    // The command is never executed — the fake transport below intercepts
    // the protocol — so "somebinary" is just a placeholder.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Stdio {
                enabled: true,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    // Fake transport speaking just enough MCP: handshake, tool listing,
    // and forwarding tool calls to the test through the channel above.
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                // Hand the call to the test and block this request until
                // the test responds through the oneshot sender.
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let the server finish starting before tests interact with it.
    cx.run_until_parked();
    mcp_tool_calls_rx
}
2812
#[gpui::test]
async fn test_tokens_before_message(cx: &mut TestAppContext) {
    // `tokens_before_message` should report, for each user message, the
    // input-token count of the *previous* request — None for the first
    // message — and earlier answers must stay stable as the thread grows.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First message
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["First message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Before any response, tokens_before_message should return None for first message
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should have no tokens before it"
        );
    });

    // Complete first message with usage
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // First message still has no tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it after response"
        );
    });

    // Second message
    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Second message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Second message should have first message's input tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should have 100 tokens before it (from first request)"
        );
    });

    // Complete second message
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250, // Total for this request (includes previous context)
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Third message
    let message_3_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_3_id.clone(), ["Third message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Third message should have second message's input tokens (250) before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_3_id),
            Some(250),
            "Third message should have 250 tokens before it (from second request)"
        );
        // Second message should still have 100
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should still have 100 tokens before it"
        );
        // First message still has none
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it"
        );
    });
}
2919
#[gpui::test]
async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
    // Truncating the thread at a message removes it (and everything after),
    // so `tokens_before_message` lookups for removed messages return None.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up three messages with responses
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250,
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Verify initial state
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
    });

    // Truncate at message 2 (removes message 2 and everything after)
    thread
        .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
        .unwrap();
    cx.run_until_parked();

    // After truncation, message_2_id no longer exists, so lookup should return None
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            None,
            "After truncation, message 2 no longer exists"
        );
        // Message 1 still exists but has no tokens before it
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message still has no tokens before it"
        );
    });
}