1use super::*;
2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
3use agent_client_protocol::{self as acp};
4use agent_settings::AgentProfileId;
5use anyhow::Result;
6use client::{Client, UserStore};
7use cloud_llm_client::CompletionIntent;
8use collections::IndexMap;
9use context_server::{ContextServer, ContextServerCommand, ContextServerId};
10use fs::{FakeFs, Fs};
11use futures::{
12 FutureExt as _, StreamExt,
13 channel::{
14 mpsc::{self, UnboundedReceiver},
15 oneshot,
16 },
17 future::{Fuse, Shared},
18};
19use gpui::{
20 App, AppContext, AsyncApp, Entity, Task, TestAppContext, UpdateGlobal,
21 http_client::FakeHttpClient,
22};
23use indoc::indoc;
24use language_model::{
25 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
26 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
27 LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
28 LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
29};
30use pretty_assertions::assert_eq;
31use project::{
32 Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
33};
34use prompt_store::ProjectContext;
35use reqwest_client::ReqwestClient;
36use schemars::JsonSchema;
37use serde::{Deserialize, Serialize};
38use serde_json::json;
39use settings::{Settings, SettingsStore};
40use std::{
41 path::Path,
42 pin::Pin,
43 rc::Rc,
44 sync::{
45 Arc,
46 atomic::{AtomicBool, Ordering},
47 },
48 time::Duration,
49};
50use util::path;
51
52mod test_tools;
53use test_tools::*;
54
55fn init_test(cx: &mut TestAppContext) {
56 cx.update(|cx| {
57 let settings_store = SettingsStore::test(cx);
58 cx.set_global(settings_store);
59 });
60}
61
/// Test double for a terminal handle used by the terminal tool tests.
struct FakeTerminalHandle {
    // Set to true once `kill` is called; polled by the wait-for-exit task.
    killed: Arc<AtomicBool>,
    // Shared task that resolves only after `killed` becomes true.
    wait_for_exit: Shared<Task<acp::TerminalExitStatus>>,
    // Canned output returned from every `current_output` call.
    output: acp::TerminalOutputResponse,
    // Fixed terminal id returned from `id`.
    id: acp::TerminalId,
}
68
69impl FakeTerminalHandle {
70 fn new_never_exits(cx: &mut App) -> Self {
71 let killed = Arc::new(AtomicBool::new(false));
72
73 let killed_for_task = killed.clone();
74 let wait_for_exit = cx
75 .spawn(async move |cx| {
76 loop {
77 if killed_for_task.load(Ordering::SeqCst) {
78 return acp::TerminalExitStatus::new();
79 }
80 cx.background_executor()
81 .timer(Duration::from_millis(1))
82 .await;
83 }
84 })
85 .shared();
86
87 Self {
88 killed,
89 wait_for_exit,
90 output: acp::TerminalOutputResponse::new("partial output".to_string(), false),
91 id: acp::TerminalId::new("fake_terminal".to_string()),
92 }
93 }
94
95 fn was_killed(&self) -> bool {
96 self.killed.load(Ordering::SeqCst)
97 }
98}
99
100impl crate::TerminalHandle for FakeTerminalHandle {
101 fn id(&self, _cx: &AsyncApp) -> Result<acp::TerminalId> {
102 Ok(self.id.clone())
103 }
104
105 fn current_output(&self, _cx: &AsyncApp) -> Result<acp::TerminalOutputResponse> {
106 Ok(self.output.clone())
107 }
108
109 fn wait_for_exit(&self, _cx: &AsyncApp) -> Result<Shared<Task<acp::TerminalExitStatus>>> {
110 Ok(self.wait_for_exit.clone())
111 }
112
113 fn kill(&self, _cx: &AsyncApp) -> Result<()> {
114 self.killed.store(true, Ordering::SeqCst);
115 Ok(())
116 }
117
118 fn was_stopped_by_user(&self, _cx: &AsyncApp) -> Result<bool> {
119 Ok(false)
120 }
121}
122
/// Thread-environment stub that vends the same fake terminal handle for
/// every `create_terminal` request.
struct FakeThreadEnvironment {
    // Handle returned by every `create_terminal` call.
    handle: Rc<FakeTerminalHandle>,
}
126
127impl crate::ThreadEnvironment for FakeThreadEnvironment {
128 fn create_terminal(
129 &self,
130 _command: String,
131 _cwd: Option<std::path::PathBuf>,
132 _output_byte_limit: Option<u64>,
133 _cx: &mut AsyncApp,
134 ) -> Task<Result<Rc<dyn crate::TerminalHandle>>> {
135 Task::ready(Ok(self.handle.clone() as Rc<dyn crate::TerminalHandle>))
136 }
137}
138
139fn always_allow_tools(cx: &mut TestAppContext) {
140 cx.update(|cx| {
141 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
142 settings.always_allow_tool_actions = true;
143 agent_settings::AgentSettings::override_global(settings, cx);
144 });
145}
146
/// Smoke test: one user prompt, one streamed text chunk back, and the turn
/// ends with `EndTurn`; the reply is recorded verbatim on the thread.
#[gpui::test]
async fn test_echo(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Drive the fake model: stream "Hello", signal end of turn, close stream.
    fake_model.send_last_completion_stream_text_chunk("Hello");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events = events.collect().await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
176
/// Verifies that when the terminal tool is given a `timeout_ms` and the
/// command outlives it, the tool kills the terminal handle and still reports
/// the partial output collected so far.
#[gpui::test]
async fn test_terminal_tool_timeout_kills_handle(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;

    // A handle whose wait-for-exit only resolves once it is killed.
    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment {
        handle: handle.clone(),
    });

    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::TerminalTool::new(project, environment));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    // Tiny timeout so the kill path triggers quickly.
    let task = cx.update(|cx| {
        tool.run(
            crate::TerminalToolInput {
                command: "sleep 1000".to_string(),
                cd: ".".to_string(),
                timeout_ms: Some(5),
            },
            event_stream,
            cx,
        )
    });

    // The tool should surface the terminal in its tool-call update.
    let update = rx.expect_update_fields().await;
    assert!(
        update.content.iter().any(|blocks| {
            blocks
                .iter()
                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
        }),
        "expected tool call update to include terminal content"
    );

    let mut task_future: Pin<Box<Fuse<Task<Result<String>>>>> = Box::pin(task.fuse());

    // Poll the task cooperatively (run_until_parked + 1ms timer per pass)
    // until it completes, or fail after 500ms of wall-clock time.
    let deadline = std::time::Instant::now() + Duration::from_millis(500);
    loop {
        if let Some(result) = task_future.as_mut().now_or_never() {
            let result = result.expect("terminal tool task should complete");

            assert!(
                handle.was_killed(),
                "expected terminal handle to be killed on timeout"
            );
            assert!(
                result.contains("partial output"),
                "expected result to include terminal output, got: {result}"
            );
            return;
        }

        if std::time::Instant::now() >= deadline {
            panic!("timed out waiting for terminal tool task to complete");
        }

        cx.run_until_parked();
        cx.background_executor.timer(Duration::from_millis(1)).await;
    }
}
242
/// Counterpart to the timeout test: with `timeout_ms: None` the tool must NOT
/// kill the handle while the command is still running.
/// NOTE(review): currently `#[ignore]`d — only runs when requested explicitly.
#[gpui::test]
#[ignore]
async fn test_terminal_tool_without_timeout_does_not_kill_handle(cx: &mut TestAppContext) {
    init_test(cx);
    always_allow_tools(cx);

    let fs = FakeFs::new(cx.executor());
    let project = Project::test(fs, [], cx).await;

    let handle = Rc::new(cx.update(|cx| FakeTerminalHandle::new_never_exits(cx)));
    let environment = Rc::new(FakeThreadEnvironment {
        handle: handle.clone(),
    });

    #[allow(clippy::arc_with_non_send_sync)]
    let tool = Arc::new(crate::TerminalTool::new(project, environment));
    let (event_stream, mut rx) = crate::ToolCallEventStream::test();

    // No timeout: the task is intentionally left running (and dropped).
    let _task = cx.update(|cx| {
        tool.run(
            crate::TerminalToolInput {
                command: "sleep 1000".to_string(),
                cd: ".".to_string(),
                timeout_ms: None,
            },
            event_stream,
            cx,
        )
    });

    let update = rx.expect_update_fields().await;
    assert!(
        update.content.iter().any(|blocks| {
            blocks
                .iter()
                .any(|c| matches!(c, acp::ToolCallContent::Terminal(_)))
        }),
        "expected tool call update to include terminal content"
    );

    // Give the tool a moment to (incorrectly) trigger a kill, then verify it
    // did not.
    smol::Timer::after(Duration::from_millis(25)).await;

    assert!(
        !handle.was_killed(),
        "did not expect terminal handle to be killed without a timeout"
    );
}
290
/// Verifies that a streamed `Thinking` event is rendered in the thread's
/// markdown as a `<think>` block preceding the final text.
#[gpui::test]
async fn test_thinking(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                [indoc! {"
                    Testing:

                    Generate a thinking step where you just think the word 'Think',
                    and have your final answer be 'Hello'
                "}],
                cx,
            )
        })
        .unwrap();
    cx.run_until_parked();
    // Stream a thinking step followed by the answer text.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
        text: "Think".to_string(),
        signature: None,
    });
    fake_model.send_last_completion_stream_text_chunk("Hello");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events = events.collect().await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                <think>Think</think>
                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
334
/// Verifies the system prompt is the first request message and reflects the
/// project context (shell name) and, because a tool is registered, includes
/// the "Fixing Diagnostics" section.
#[gpui::test]
async fn test_system_prompt(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        project_context,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Customize the project context so we can spot it in the prompt.
    project_context.update(cx, |project_context, _cx| {
        project_context.shell = "test-shell".into()
    });
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    let pending_completion = pending_completions.pop().unwrap();
    // The system prompt must lead the message list.
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    assert!(
        system_prompt.contains("test-shell"),
        "unexpected system message: {:?}",
        system_message
    );
    assert!(
        system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}
379
/// Verifies that when no tools are registered, the system prompt omits the
/// tool-related sections ("Tool Use" and "Fixing Diagnostics").
#[gpui::test]
async fn test_system_prompt_without_tools(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    assert!(
        !system_prompt.contains("## Tool Use"),
        "unexpected system message: {:?}",
        system_message
    );
    assert!(
        !system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}
415
/// Verifies the cache-point policy: only the most recent user message (or
/// tool result) in a request carries `cache: true`; all earlier messages are
/// sent uncached.
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // messages[0] is the system prompt, so compare from index 1 onward.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true,
            reasoning_details: None,
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    // The trailing tool-result message is now the cache point.
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
}
561
/// End-to-end test (requires the "e2e" feature and a real Sonnet 4 model):
/// exercises a tool that finishes before streaming ends and one that finishes
/// after, asserting both turns end normally.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool calls that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::name());
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    // The model should have echoed the delay tool's "Ding" output.
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
621
/// End-to-end test: asserts that tool-use input streams incrementally — we
/// must observe at least one tool call whose input is still incomplete
/// (missing later keys) before the final, complete input arrives.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // While pending, input may still be streaming: the
                        // late key "g" being absent marks a partial tool use.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        // Once past pending, the full input must have arrived.
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
673
674#[gpui::test]
675async fn test_tool_authorization(cx: &mut TestAppContext) {
676 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
677 let fake_model = model.as_fake();
678
679 let mut events = thread
680 .update(cx, |thread, cx| {
681 thread.add_tool(ToolRequiringPermission);
682 thread.send(UserMessageId::new(), ["abc"], cx)
683 })
684 .unwrap();
685 cx.run_until_parked();
686 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
687 LanguageModelToolUse {
688 id: "tool_id_1".into(),
689 name: ToolRequiringPermission::name().into(),
690 raw_input: "{}".into(),
691 input: json!({}),
692 is_input_complete: true,
693 thought_signature: None,
694 },
695 ));
696 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
697 LanguageModelToolUse {
698 id: "tool_id_2".into(),
699 name: ToolRequiringPermission::name().into(),
700 raw_input: "{}".into(),
701 input: json!({}),
702 is_input_complete: true,
703 thought_signature: None,
704 },
705 ));
706 fake_model.end_last_completion_stream();
707 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
708 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
709
710 // Approve the first
711 tool_call_auth_1
712 .response
713 .send(tool_call_auth_1.options[1].option_id.clone())
714 .unwrap();
715 cx.run_until_parked();
716
717 // Reject the second
718 tool_call_auth_2
719 .response
720 .send(tool_call_auth_1.options[2].option_id.clone())
721 .unwrap();
722 cx.run_until_parked();
723
724 let completion = fake_model.pending_completions().pop().unwrap();
725 let message = completion.messages.last().unwrap();
726 assert_eq!(
727 message.content,
728 vec![
729 language_model::MessageContent::ToolResult(LanguageModelToolResult {
730 tool_use_id: tool_call_auth_1.tool_call.tool_call_id.0.to_string().into(),
731 tool_name: ToolRequiringPermission::name().into(),
732 is_error: false,
733 content: "Allowed".into(),
734 output: Some("Allowed".into())
735 }),
736 language_model::MessageContent::ToolResult(LanguageModelToolResult {
737 tool_use_id: tool_call_auth_2.tool_call.tool_call_id.0.to_string().into(),
738 tool_name: ToolRequiringPermission::name().into(),
739 is_error: true,
740 content: "Permission to run tool denied by user".into(),
741 output: Some("Permission to run tool denied by user".into())
742 })
743 ]
744 );
745
746 // Simulate yet another tool call.
747 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
748 LanguageModelToolUse {
749 id: "tool_id_3".into(),
750 name: ToolRequiringPermission::name().into(),
751 raw_input: "{}".into(),
752 input: json!({}),
753 is_input_complete: true,
754 thought_signature: None,
755 },
756 ));
757 fake_model.end_last_completion_stream();
758
759 // Respond by always allowing tools.
760 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
761 tool_call_auth_3
762 .response
763 .send(tool_call_auth_3.options[0].option_id.clone())
764 .unwrap();
765 cx.run_until_parked();
766 let completion = fake_model.pending_completions().pop().unwrap();
767 let message = completion.messages.last().unwrap();
768 assert_eq!(
769 message.content,
770 vec![language_model::MessageContent::ToolResult(
771 LanguageModelToolResult {
772 tool_use_id: tool_call_auth_3.tool_call.tool_call_id.0.to_string().into(),
773 tool_name: ToolRequiringPermission::name().into(),
774 is_error: false,
775 content: "Allowed".into(),
776 output: Some("Allowed".into())
777 }
778 )]
779 );
780
781 // Simulate a final tool call, ensuring we don't trigger authorization.
782 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
783 LanguageModelToolUse {
784 id: "tool_id_4".into(),
785 name: ToolRequiringPermission::name().into(),
786 raw_input: "{}".into(),
787 input: json!({}),
788 is_input_complete: true,
789 thought_signature: None,
790 },
791 ));
792 fake_model.end_last_completion_stream();
793 cx.run_until_parked();
794 let completion = fake_model.pending_completions().pop().unwrap();
795 let message = completion.messages.last().unwrap();
796 assert_eq!(
797 message.content,
798 vec![language_model::MessageContent::ToolResult(
799 LanguageModelToolResult {
800 tool_use_id: "tool_id_4".into(),
801 tool_name: ToolRequiringPermission::name().into(),
802 is_error: false,
803 content: "Allowed".into(),
804 output: Some("Allowed".into())
805 }
806 )]
807 );
808}
809
/// Verifies that a tool use naming a tool that was never registered still
/// surfaces as a tool call, then immediately transitions to `Failed`.
#[gpui::test]
async fn test_tool_hallucination(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // The model "hallucinates" a tool that doesn't exist on the thread.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: "nonexistent_tool".into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();

    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(tool_call.title, "nonexistent_tool");
    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
}
839
/// Verifies that after hitting the tool-use limit, `resume` replays the whole
/// conversation plus a "Continue where you left off" user message (which
/// becomes the new cache point), and the thread can then finish normally.
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    // After the tool runs, the follow-up request carries the tool result as
    // the cache point.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
    fake_model.end_last_completion_stream();
    // The stream's final event must be the ToolUseLimitReached error.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    // Resume appends a synthetic user message, which takes over as the
    // cache point.
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });
}
954
/// Verifies that after hitting the tool-use limit, sending a brand-new user
/// message (instead of resuming) keeps the prior tool use/result in history
/// and makes the new message the cache point.
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
        thought_signature: None,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // Stream a tool use immediately followed by the limit being reached.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true,
                reasoning_details: None,
            }
        ]
    );
}
1031
1032async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
1033 let event = events
1034 .next()
1035 .await
1036 .expect("no tool call authorization event received")
1037 .unwrap();
1038 match event {
1039 ThreadEvent::ToolCall(tool_call) => tool_call,
1040 event => {
1041 panic!("Unexpected event {event:?}");
1042 }
1043 }
1044}
1045
1046async fn expect_tool_call_update_fields(
1047 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
1048) -> acp::ToolCallUpdate {
1049 let event = events
1050 .next()
1051 .await
1052 .expect("no tool call authorization event received")
1053 .unwrap();
1054 match event {
1055 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
1056 event => {
1057 panic!("Unexpected event {event:?}");
1058 }
1059 }
1060}
1061
1062async fn next_tool_call_authorization(
1063 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
1064) -> ToolCallAuthorization {
1065 loop {
1066 let event = events
1067 .next()
1068 .await
1069 .expect("no tool call authorization event received")
1070 .unwrap();
1071 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
1072 let permission_kinds = tool_call_authorization
1073 .options
1074 .iter()
1075 .map(|o| o.kind)
1076 .collect::<Vec<_>>();
1077 assert_eq!(
1078 permission_kinds,
1079 vec![
1080 acp::PermissionOptionKind::AllowAlways,
1081 acp::PermissionOptionKind::AllowOnce,
1082 acp::PermissionOptionKind::RejectOnce,
1083 ]
1084 );
1085 return tool_call_authorization;
1086 }
1087 }
1088}
1089
/// End-to-end test: two delay-tool calls issued in the same message must both
/// complete, the turn must end normally, and the model must describe the
/// tools' "Ding" outputs.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test concurrent tool calls with different delay times
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Call the delay tool twice in the same message.",
                    "Once with 100ms. Once with 300ms.",
                    "When both timers are complete, describe the outputs.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;

    let stop_reasons = stop_events(events);
    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);

    thread.update(cx, |thread, _cx| {
        let last_message = thread.last_message().unwrap();
        let agent_message = last_message.as_agent_message().unwrap();
        // Concatenate all text content of the final assistant message.
        let text = agent_message
            .content
            .iter()
            .filter_map(|content| {
                if let AgentMessageContent::Text(text) = content {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<String>();

        assert!(text.contains("Ding"));
    });
}
1134
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    // Verifies that the active agent profile filters which of the thread's
    // registered tools are offered to the model, and that switching profiles
    // takes effect on the next completion request.
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Register three tools; the profiles below each enable a subset.
    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()), cx);
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Inspect the request the fake model received to see which tools were
    // actually offered.
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()), cx);
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool::name()]);
}
1214
#[gpui::test]
async fn test_mcp_tools(cx: &mut TestAppContext) {
    // Exercises the full round trip for MCP (context-server) tools: the model
    // calls an MCP tool, the tool responds, and the result is fed back into
    // the next request. Also verifies that a name collision between a native
    // tool and an MCP tool is resolved by prefixing the MCP tool with its
    // server name.
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "always_allow_tool_actions": true,
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();
    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx)
    });

    // Register a context server exposing a single "echo" tool; the returned
    // stream yields (params, response-sender) pairs for each tool invocation.
    let mut mcp_tool_calls = setup_context_server(
        "test_server",
        vec![context_server::types::Tool {
            name: "echo".into(),
            description: None,
            input_schema: serde_json::to_value(EchoTool::input_schema(
                LanguageModelToolSchemaFormat::JsonSchema,
            ))
            .unwrap(),
            output_schema: None,
            annotations: None,
        }],
        &context_server_store,
        cx,
    );

    let events = thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
    });
    cx.run_until_parked();

    // Simulate the model calling the MCP tool.
    let completion = fake_model.pending_completions().pop().unwrap();
    // No native "echo" is registered yet, so the MCP tool keeps its bare name.
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_1".into(),
            name: "echo".into(),
            raw_input: json!({"text": "test"}).to_string(),
            input: json!({"text": "test"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The context server should have received the call; answer it.
    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text {
                text: "test".into(),
            }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // NOTE(review): `completion` here is still the first request popped
    // above, so this re-asserts the same snapshot; a fresh
    // `fake_model.pending_completions().pop()` was likely intended — confirm.
    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
    fake_model.send_last_completion_stream_text_chunk("Done!");
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;

    // Send again after adding the echo tool, ensuring the name collision is resolved.
    let events = thread.update(cx, |thread, cx| {
        thread.add_tool(EchoTool);
        thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
    });
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    // The MCP tool is now disambiguated as "test_server_echo".
    assert_eq!(
        tool_names_for_completion(&completion),
        vec!["echo", "test_server_echo"]
    );
    // Model calls both the MCP echo and the native echo in the same turn.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_2".into(),
            name: "test_server_echo".into(),
            raw_input: json!({"text": "mcp"}).to_string(),
            input: json!({"text": "mcp"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_3".into(),
            name: "echo".into(),
            raw_input: json!({"text": "native"}).to_string(),
            input: json!({"text": "native"}),
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
    assert_eq!(tool_call_params.name, "echo");
    assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
    tool_call_response
        .send(context_server::types::CallToolResponse {
            content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
            is_error: None,
            meta: None,
            structured_content: None,
        })
        .unwrap();
    cx.run_until_parked();

    // Ensure the tool results were inserted with the correct names.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages.last().unwrap().content,
        vec![
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_3".into(),
                tool_name: "echo".into(),
                is_error: false,
                content: "native".into(),
                output: Some("native".into()),
            },),
            MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: "tool_2".into(),
                tool_name: "test_server_echo".into(),
                is_error: false,
                content: "mcp".into(),
                output: Some("mcp".into()),
            },),
        ]
    );
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
}
1380
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    // Verifies how tool names are disambiguated and truncated when multiple
    // context servers expose conflicting and/or overly long tool names:
    // conflicts get a server-name prefix, and names at or over
    // MAX_TOOL_NAME_LENGTH are truncated (prefixes shortened as needed).
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test".into()), cx);
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(EchoTool::input_schema(
                    LanguageModelToolSchemaFormat::JsonSchema,
                ))
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // Long names that conflict across servers "yyy" and "zzz":
            // prefixing would overflow MAX_TOOL_NAME_LENGTH.
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            // A name already over the limit, with no conflict.
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // The completion request shows the final, de-duplicated, truncated names.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1546
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    // End-to-end test (only runs with the "e2e" feature): cancelling a turn
    // while a tool is still running must close the event stream with a
    // Cancelled stop reason, and the thread must remain usable afterwards.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                // Tool calls must arrive in the order the prompt requested.
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.tool_call_id);
                }
            }
            // Track when the echo tool (the one that can finish) completes.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    tool_call_id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                    ..
                },
            )) if Some(&tool_call_id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1632
#[gpui::test]
async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
    // Sending a new message while a previous send is still streaming should
    // cancel the first send (its stream ends with Cancelled) while the second
    // proceeds to a normal EndTurn.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Start the first send and let it stream a partial response.
    let events_1 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
    cx.run_until_parked();

    // Start a second send before the first completes; then finish it normally.
    let events_2 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events_1 = events_1.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
    let events_2 = events_2.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
}
1663
#[gpui::test]
async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
    // Counterpart to the cancellation test above: when the first send is
    // allowed to finish before the second starts, both turns should end with
    // EndTurn and neither should be cancelled.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First send, run to completion before starting the next.
    let events_1 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    let events_1 = events_1.collect::<Vec<_>>().await;

    // Second send, also run to completion.
    let events_2 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    let events_2 = events_2.collect::<Vec<_>>().await;

    assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
}
1696
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
    // When the model stops with a Refusal, the thread should surface the
    // Refusal stop reason and roll back the transcript to before the last
    // user message.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Before any response streams, the transcript holds only the user turn.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    // A partial assistant response appears once a chunk is streamed.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    // If the model refuses to continue, the thread should remove all the messages after the last user message.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    // The entire exchange (including the user message) is gone after refusal.
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}
1745
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    // Truncating at the first (and only) user message should empty the thread
    // and reset token usage, while leaving the thread usable for new sends.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        // No usage has been reported yet.
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Stream a response plus a usage update for this turn.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at the first message: transcript and usage are both cleared.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        // Usage reflects only the post-truncation turn.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });
}
1861
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    // Truncating at the second user message should restore the thread to its
    // state after the first exchange, including the first turn's token usage.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First exchange, with a usage update.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Shared assertion: the thread looks exactly as it did after the first
    // exchange (used both before the second send and after truncation).
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    // Second exchange; remember its id so we can truncate at it.
    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at the second message and verify the first exchange survives.
    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
1970
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
    // A separate summarization model should generate the thread title from
    // the first exchange (only the first line of its output is used), and
    // should not be invoked again for subsequent messages.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Distinct fake model so title-generation requests can be observed
    // independently of the main completion stream.
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // Title is still the placeholder until the summary model responds.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    cx.run_until_parked();
    // Only the first line ("Hello world") becomes the title; the text after
    // the newline is discarded.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
2016
#[gpui::test]
async fn test_building_request_with_pending_tools(cx: &mut TestAppContext) {
    // When building a completion request while a tool call is still awaiting
    // user permission, that pending tool use (and its missing result) must be
    // omitted; completed tool uses are included with their results.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let _events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(ToolRequiringPermission);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Hey!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // One tool that blocks on permission, one that completes immediately.
    let permission_tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: ToolRequiringPermission::name().into(),
        raw_input: "{}".into(),
        input: json!({}),
        is_input_complete: true,
        thought_signature: None,
    };
    let echo_tool_use = LanguageModelToolUse {
        id: "tool_id_2".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_text_chunk("Hi!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        permission_tool_use,
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        echo_tool_use.clone(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Ensure pending tools are skipped when building a request.
    let request = thread
        .read_with(cx, |thread, cx| {
            thread.build_completion_request(CompletionIntent::EditFile, cx)
        })
        .unwrap();
    // messages[0] is skipped here — presumably the system message; confirm.
    assert_eq!(
        request.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Hey!".into()],
                cache: true,
                reasoning_details: None,
            },
            // Only the echo tool use appears; the permission-gated tool use
            // ("tool_id_1") is excluded because it has no result yet.
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![
                    MessageContent::Text("Hi!".into()),
                    MessageContent::ToolUse(echo_tool_use.clone())
                ],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(LanguageModelToolResult {
                    tool_use_id: echo_tool_use.id.clone(),
                    tool_name: echo_tool_use.name,
                    is_error: false,
                    content: "test".into(),
                    output: Some("test".into())
                })],
                cache: false,
                reasoning_details: None,
            },
        ],
    );
}
2096
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    // Integration test for the ACP agent connection: thread creation, model
    // listing/selection, prompting, cancellation, and session teardown when
    // the ACP thread is dropped.
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        LanguageModelRegistry::test(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let text_thread_store =
        cx.new(|cx| assistant_text_thread::TextThreadStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(text_thread_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector(&session_id);
    assert!(
        selector_opt.is_some(),
        "agent should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0]
            .id
            .0
            .as_ref(),
        "fake/fake"
    );

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Round-trip a prompt through the connection and check the transcript.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest::new(session_id.clone(), vec!["ghi".into()]),
                cx,
            )
        })
        .await;
    // Prompting the dropped session must fail with "Session not found".
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
2227
/// Verifies the sequence of ACP tool-call updates emitted while a tool's input
/// streams in and the tool subsequently runs to completion: an initial
/// `ToolCall` for the partial input, a fields update once input is complete,
/// an `InProgress` status, streamed content, and a final `Completed` status
/// with the raw output.
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Think"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool::name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
            thought_signature: None,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
            thought_signature: None,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The partial tool use surfaces as an initial ToolCall carrying the
    // (still-empty) raw input and the tool name in metadata.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall::new("1", "Thinking")
            .kind(acp::ToolKind::Think)
            .raw_input(json!({}))
            .meta(acp::Meta::from_iter([(
                "tool_name".into(),
                "thinking".into()
            )]))
    );
    // Completed input arrives as a fields update with the full raw input.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate::new(
            "1",
            acp::ToolCallUpdateFields::new()
                .title("Thinking")
                .kind(acp::ToolKind::Think)
                .raw_input(json!({ "content": "Thinking hard!"}))
        )
    );
    // The tool starts executing.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate::new(
            "1",
            acp::ToolCallUpdateFields::new().status(acp::ToolCallStatus::InProgress)
        )
    );
    // The tool streams its content while running.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate::new(
            "1",
            acp::ToolCallUpdateFields::new().content(vec!["Thinking hard!".into()])
        )
    );
    // Finally the call completes with the tool's raw output attached.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate::new(
            "1",
            acp::ToolCallUpdateFields::new()
                .status(acp::ToolCallStatus::Completed)
                .raw_output("Finished thinking.")
        )
    );
}
2318
/// Verifies that a completion which streams successfully to the end produces
/// no `Retry` events and leaves the expected transcript in the thread.
#[gpui::test]
async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            // Burn mode is the completion mode under which retries are issued
            // in the sibling retry tests; here the stream succeeds, so none
            // should be observed.
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream a successful response and close the stream cleanly.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();

    // Drain events until the turn stops, collecting any retry notifications.
    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 0);
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey!
            "}
        )
    });
}
2362
/// Verifies that a mid-stream `ServerOverloaded` error triggers exactly one
/// retry after the provider-supplied `retry_after` delay, and that the resumed
/// completion is recorded as a separate assistant message in the transcript.
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Stream part of a response, then fail with an overloaded error that asks
    // the client to retry after 3 seconds.
    fake_model.send_last_completion_stream_text_chunk("Hey,");
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance the fake clock past the retry_after delay so the retry fires.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    // Complete the retried request successfully.
    fake_model.send_last_completion_stream_text_chunk("there!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    // Exactly one retry, reported as attempt #1.
    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey,

                [resume]

                ## Assistant

                there!
            "}
        )
    });
}
2427
/// Verifies that when a completion fails after the model has emitted a tool
/// use, the pending tool call is still executed and its result is included in
/// the retried request's message history.
#[gpui::test]
async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["Call the echo tool!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // The model requests an echo-tool invocation, then the stream errors out
    // with a retryable overloaded error before producing any text.
    let tool_use_1 = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
        thought_signature: None,
    };
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        tool_use_1.clone(),
    ));
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // After the retry delay, the retried request must contain the user
    // message, the assistant's tool use, and the echo tool's result —
    // proving the tool call was finished despite the stream error.
    // (messages[0] is the system prompt, hence the [1..] slice.)
    cx.executor().advance_clock(Duration::from_secs(3));
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Call the echo tool!".into()],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![language_model::MessageContent::ToolUse(tool_use_1.clone())],
                cache: false,
                reasoning_details: None,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![language_model::MessageContent::ToolResult(
                    LanguageModelToolResult {
                        tool_use_id: tool_use_1.id.clone(),
                        tool_name: tool_use_1.name.clone(),
                        is_error: false,
                        content: "test".into(),
                        output: Some("test".into())
                    }
                )],
                cache: true,
                reasoning_details: None,
            },
        ]
    );

    // Finish the retried completion and confirm the final agent message.
    fake_model.send_last_completion_stream_text_chunk("Done");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message(),
            Some(Message::Agent(AgentMessage {
                content: vec![AgentMessageContent::Text("Done".into())],
                tool_results: IndexMap::default(),
                reasoning_details: None,
            }))
        );
    })
}
2508
/// Verifies that after `MAX_RETRY_ATTEMPTS` consecutive retryable failures the
/// thread stops retrying: it emits one `Retry` event per attempt (numbered
/// from 1) and then surfaces the final `ServerOverloaded` error to the caller.
#[gpui::test]
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Fail the initial attempt plus every retry (hence MAX_RETRY_ATTEMPTS + 1
    // failures), advancing the clock past each retry_after delay.
    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
        fake_model.send_last_completion_stream_error(
            LanguageModelCompletionError::ServerOverloaded {
                provider: LanguageModelProviderName::new("Anthropic"),
                retry_after: Some(Duration::from_secs(3)),
            },
        );
        fake_model.end_last_completion_stream();
        cx.executor().advance_clock(Duration::from_secs(3));
        cx.run_until_parked();
    }

    // Drain the event stream, separating retry notifications from errors.
    let mut errors = Vec::new();
    let mut retry_events = Vec::new();
    while let Some(event) = events.next().await {
        match event {
            Ok(ThreadEvent::Retry(retry_status)) => {
                retry_events.push(retry_status);
            }
            Ok(ThreadEvent::Stop(..)) => break,
            Err(error) => errors.push(error),
            _ => {}
        }
    }

    assert_eq!(
        retry_events.len(),
        crate::thread::MAX_RETRY_ATTEMPTS as usize
    );
    // Attempts are numbered 1..=MAX_RETRY_ATTEMPTS.
    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
        assert_eq!(retry_events[i].attempt, i + 1);
    }
    // Only the final failure (after retries are exhausted) becomes an error.
    assert_eq!(errors.len(), 1);
    let error = errors[0]
        .downcast_ref::<LanguageModelCompletionError>()
        .unwrap();
    assert!(matches!(
        error,
        LanguageModelCompletionError::ServerOverloaded { .. }
    ));
}
2563
2564/// Filters out the stop events for asserting against in tests
2565fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
2566 result_events
2567 .into_iter()
2568 .filter_map(|event| match event.unwrap() {
2569 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
2570 _ => None,
2571 })
2572 .collect()
2573}
2574
/// Bundle of entities produced by `setup` that tests use to drive a thread.
struct ThreadTest {
    // The language model completions are sent to (fake or real, per TestModel).
    model: Arc<dyn LanguageModel>,
    // The thread under test.
    thread: Entity<Thread>,
    // Project context handed to the thread at construction.
    project_context: Entity<ProjectContext>,
    // Store used by tests that register fake MCP context servers.
    context_server_store: Entity<ContextServerStore>,
    // The fake filesystem backing the test project and settings file.
    fs: Arc<FakeFs>,
}
2582
/// Which language model a test runs against: a real Claude Sonnet 4 (requires
/// authentication; see `setup`) or the in-process fake model.
enum TestModel {
    Sonnet4,
    Fake,
}
2587
2588impl TestModel {
2589 fn id(&self) -> LanguageModelId {
2590 match self {
2591 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
2592 TestModel::Fake => unreachable!(),
2593 }
2594 }
2595}
2596
/// Builds the full test fixture: a fake filesystem with a settings file that
/// enables every test tool, an initialized settings store watching that file,
/// a test project, the requested language model, and a `Thread` wired to all
/// of the above.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    // Real-model runs perform blocking work (e.g. HTTP); allow it.
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    // Write a settings file whose default profile enables all of the tools
    // the tests in this module register on the thread.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            ThinkingTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);

        match model {
            TestModel::Fake => {}
            TestModel::Sonnet4 => {
                // Real model: set up HTTP, a production client, and the
                // language-model providers so the model can be resolved
                // from the registry and authenticated below.
                gpui_tokio::init(cx);
                let http_client = ReqwestClient::user_agent("agent tests").unwrap();
                cx.set_http_client(Arc::new(http_client));
                let client = Client::production(cx);
                let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
                language_model::init(client.clone(), cx);
                language_models::init(user_store, client.clone(), cx);
            }
        };

        // Keep the global SettingsStore in sync with the settings file above.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                // The fake model needs no registry lookup or authentication.
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                // Resolve the real model from the registry and authenticate
                // its provider before handing it to the thread.
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
2698
/// Runs once before the test binary starts (via `ctor`) and enables
/// `env_logger` output, but only when the caller opted in through `RUST_LOG`.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    match std::env::var("RUST_LOG") {
        Ok(_) => env_logger::init(),
        Err(_) => {}
    }
}
2706
2707fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
2708 let fs = fs.clone();
2709 cx.spawn({
2710 async move |cx| {
2711 let mut new_settings_content_rx = settings::watch_config_file(
2712 cx.background_executor(),
2713 fs,
2714 paths::settings_file().clone(),
2715 );
2716
2717 while let Some(new_settings_content) = new_settings_content_rx.next().await {
2718 cx.update(|cx| {
2719 SettingsStore::update_global(cx, |settings, cx| {
2720 settings.set_user_settings(&new_settings_content, cx)
2721 })
2722 })
2723 .ok();
2724 }
2725 }
2726 })
2727 .detach();
2728}
2729
2730fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
2731 completion
2732 .tools
2733 .iter()
2734 .map(|tool| tool.name.clone())
2735 .collect()
2736}
2737
/// Registers a fake MCP context server named `name` exposing `tools`, starts
/// it in `context_server_store`, and returns a channel of incoming tool calls.
/// Each received item is the call's parameters plus a oneshot sender the test
/// must use to supply that call's response.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in project settings so the store will manage it.
    // The command is a dummy; the fake transport below intercepts all I/O.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Stdio {
                enabled: true,
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                    timeout: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    // Fake transport answering the three requests the store issues:
    // Initialize, ListTools, and CallTool.
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                // Forward the call to the test and block this request until
                // the test sends a response through the oneshot channel.
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let the server finish initializing before the test proceeds.
    cx.run_until_parked();
    mcp_tool_calls_rx
}
2816
/// Verifies `tokens_before_message`: the first user message has no tokens
/// before it, and each subsequent message reports the `input_tokens` from the
/// request that preceded it (the cumulative context size at that point).
#[gpui::test]
async fn test_tokens_before_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First message
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["First message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Before any response, tokens_before_message should return None for first message
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should have no tokens before it"
        );
    });

    // Complete first message with usage
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // First message still has no tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it after response"
        );
    });

    // Second message
    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Second message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Second message should have first message's input tokens before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should have 100 tokens before it (from first request)"
        );
    });

    // Complete second message
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250, // Total for this request (includes previous context)
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Third message
    let message_3_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_3_id.clone(), ["Third message"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Third message should have second message's input tokens (250) before it
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_3_id),
            Some(250),
            "Third message should have 250 tokens before it (from second request)"
        );
        // Second message should still have 100
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            Some(100),
            "Second message should still have 100 tokens before it"
        );
        // First message still has none
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message should still have no tokens before it"
        );
    });
}
2923
/// Verifies that truncating the thread at a message removes that message's
/// token bookkeeping: `tokens_before_message` returns `None` for the removed
/// message, while earlier messages keep their (empty) accounting.
#[gpui::test]
async fn test_tokens_before_message_after_truncate(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up three messages with responses
    let message_1_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_1_id.clone(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 1");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 100,
            output_tokens: 50,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let message_2_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_2_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Response 2");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 250,
            output_tokens: 75,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Verify initial state
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.tokens_before_message(&message_2_id), Some(100));
    });

    // Truncate at message 2 (removes message 2 and everything after)
    thread
        .update(cx, |thread, cx| thread.truncate(message_2_id.clone(), cx))
        .unwrap();
    cx.run_until_parked();

    // After truncation, message_2_id no longer exists, so lookup should return None
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.tokens_before_message(&message_2_id),
            None,
            "After truncation, message 2 no longer exists"
        );
        // Message 1 still exists but has no tokens before it
        assert_eq!(
            thread.tokens_before_message(&message_1_id),
            None,
            "First message still has no tokens before it"
        );
    });
}