// mod.rs — agent thread tests

   1use super::*;
   2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
   3use agent_client_protocol::{self as acp};
   4use agent_settings::AgentProfileId;
   5use anyhow::Result;
   6use client::{Client, UserStore};
   7use context_server::{ContextServer, ContextServerCommand, ContextServerId};
   8use fs::{FakeFs, Fs};
   9use futures::{
  10    StreamExt,
  11    channel::{
  12        mpsc::{self, UnboundedReceiver},
  13        oneshot,
  14    },
  15};
  16use gpui::{
  17    App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
  18};
  19use indoc::indoc;
  20use language_model::{
  21    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
  22    LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
  23    LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
  24    LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
  25};
  26use pretty_assertions::assert_eq;
  27use project::{
  28    Project, context_server_store::ContextServerStore, project_settings::ProjectSettings,
  29};
  30use prompt_store::ProjectContext;
  31use reqwest_client::ReqwestClient;
  32use schemars::JsonSchema;
  33use serde::{Deserialize, Serialize};
  34use serde_json::json;
  35use settings::{Settings, SettingsStore};
  36use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
  37use util::path;
  38
  39mod test_tools;
  40use test_tools::*;
  41
  42#[gpui::test]
  43async fn test_echo(cx: &mut TestAppContext) {
  44    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
  45    let fake_model = model.as_fake();
  46
  47    let events = thread
  48        .update(cx, |thread, cx| {
  49            thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
  50        })
  51        .unwrap();
  52    cx.run_until_parked();
  53    fake_model.send_last_completion_stream_text_chunk("Hello");
  54    fake_model
  55        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
  56    fake_model.end_last_completion_stream();
  57
  58    let events = events.collect().await;
  59    thread.update(cx, |thread, _cx| {
  60        assert_eq!(
  61            thread.last_message().unwrap().to_markdown(),
  62            indoc! {"
  63                ## Assistant
  64
  65                Hello
  66            "}
  67        )
  68    });
  69    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
  70}
  71
  72#[gpui::test]
  73async fn test_thinking(cx: &mut TestAppContext) {
  74    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
  75    let fake_model = model.as_fake();
  76
  77    let events = thread
  78        .update(cx, |thread, cx| {
  79            thread.send(
  80                UserMessageId::new(),
  81                [indoc! {"
  82                    Testing:
  83
  84                    Generate a thinking step where you just think the word 'Think',
  85                    and have your final answer be 'Hello'
  86                "}],
  87                cx,
  88            )
  89        })
  90        .unwrap();
  91    cx.run_until_parked();
  92    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
  93        text: "Think".to_string(),
  94        signature: None,
  95    });
  96    fake_model.send_last_completion_stream_text_chunk("Hello");
  97    fake_model
  98        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
  99    fake_model.end_last_completion_stream();
 100
 101    let events = events.collect().await;
 102    thread.update(cx, |thread, _cx| {
 103        assert_eq!(
 104            thread.last_message().unwrap().to_markdown(),
 105            indoc! {"
 106                ## Assistant
 107
 108                <think>Think</think>
 109                Hello
 110            "}
 111        )
 112    });
 113    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 114}
 115
#[gpui::test]
async fn test_system_prompt(cx: &mut TestAppContext) {
    // Verifies that the system prompt leads the completion request and that it
    // reflects the project context (e.g. the user's shell).
    let ThreadTest {
        model,
        thread,
        project_context,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Customize the project context so we can recognize it inside the prompt.
    project_context.update(cx, |project_context, _cx| {
        project_context.shell = "test-shell".into()
    });
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Exactly one completion request should be in flight after sending.
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    let pending_completion = pending_completions.pop().unwrap();
    // The system prompt must be the first message of the conversation.
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    // The prompt should embed the project context set above...
    assert!(
        system_prompt.contains("test-shell"),
        "unexpected system message: {:?}",
        system_message
    );
    // ...as well as the standard instruction sections of the prompt template.
    assert!(
        system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}
 160
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    // Verifies that exactly one message per request — the most recent one — is
    // marked `cache: true`, so providers can cache the shared prefix.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // messages[0] is the system prompt; compare everything after it.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The follow-up request (carrying the tool result) should cache only that
    // final tool-result message.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true
            }
        ]
    );
}
 294
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    // E2E test against a real model (Sonnet 4): tool calls should execute and
    // their output should make it back into the assistant's reply.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::name());
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    // The reply must contain the delay tool's output ("Ding"), proving the
    // model saw the tool result even though streaming had already stopped.
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
 354
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    // E2E test: while a tool call's input is still streaming in, the thread
    // should expose a partial (incomplete) tool use that fills in over time.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // While pending, input may still be mid-stream; a missing
                        // "g" key (presumably the last field of WordListTool's
                        // input — see test_tools) marks a partial tool use.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        // Once the call leaves Pending, the input must be fully
                        // streamed: both the first and last keys are present.
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
 406
 407#[gpui::test]
 408async fn test_tool_authorization(cx: &mut TestAppContext) {
 409    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 410    let fake_model = model.as_fake();
 411
 412    let mut events = thread
 413        .update(cx, |thread, cx| {
 414            thread.add_tool(ToolRequiringPermission);
 415            thread.send(UserMessageId::new(), ["abc"], cx)
 416        })
 417        .unwrap();
 418    cx.run_until_parked();
 419    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 420        LanguageModelToolUse {
 421            id: "tool_id_1".into(),
 422            name: ToolRequiringPermission::name().into(),
 423            raw_input: "{}".into(),
 424            input: json!({}),
 425            is_input_complete: true,
 426        },
 427    ));
 428    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 429        LanguageModelToolUse {
 430            id: "tool_id_2".into(),
 431            name: ToolRequiringPermission::name().into(),
 432            raw_input: "{}".into(),
 433            input: json!({}),
 434            is_input_complete: true,
 435        },
 436    ));
 437    fake_model.end_last_completion_stream();
 438    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
 439    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
 440
 441    // Approve the first
 442    tool_call_auth_1
 443        .response
 444        .send(tool_call_auth_1.options[1].id.clone())
 445        .unwrap();
 446    cx.run_until_parked();
 447
 448    // Reject the second
 449    tool_call_auth_2
 450        .response
 451        .send(tool_call_auth_1.options[2].id.clone())
 452        .unwrap();
 453    cx.run_until_parked();
 454
 455    let completion = fake_model.pending_completions().pop().unwrap();
 456    let message = completion.messages.last().unwrap();
 457    assert_eq!(
 458        message.content,
 459        vec![
 460            language_model::MessageContent::ToolResult(LanguageModelToolResult {
 461                tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
 462                tool_name: ToolRequiringPermission::name().into(),
 463                is_error: false,
 464                content: "Allowed".into(),
 465                output: Some("Allowed".into())
 466            }),
 467            language_model::MessageContent::ToolResult(LanguageModelToolResult {
 468                tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
 469                tool_name: ToolRequiringPermission::name().into(),
 470                is_error: true,
 471                content: "Permission to run tool denied by user".into(),
 472                output: None
 473            })
 474        ]
 475    );
 476
 477    // Simulate yet another tool call.
 478    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 479        LanguageModelToolUse {
 480            id: "tool_id_3".into(),
 481            name: ToolRequiringPermission::name().into(),
 482            raw_input: "{}".into(),
 483            input: json!({}),
 484            is_input_complete: true,
 485        },
 486    ));
 487    fake_model.end_last_completion_stream();
 488
 489    // Respond by always allowing tools.
 490    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
 491    tool_call_auth_3
 492        .response
 493        .send(tool_call_auth_3.options[0].id.clone())
 494        .unwrap();
 495    cx.run_until_parked();
 496    let completion = fake_model.pending_completions().pop().unwrap();
 497    let message = completion.messages.last().unwrap();
 498    assert_eq!(
 499        message.content,
 500        vec![language_model::MessageContent::ToolResult(
 501            LanguageModelToolResult {
 502                tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
 503                tool_name: ToolRequiringPermission::name().into(),
 504                is_error: false,
 505                content: "Allowed".into(),
 506                output: Some("Allowed".into())
 507            }
 508        )]
 509    );
 510
 511    // Simulate a final tool call, ensuring we don't trigger authorization.
 512    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 513        LanguageModelToolUse {
 514            id: "tool_id_4".into(),
 515            name: ToolRequiringPermission::name().into(),
 516            raw_input: "{}".into(),
 517            input: json!({}),
 518            is_input_complete: true,
 519        },
 520    ));
 521    fake_model.end_last_completion_stream();
 522    cx.run_until_parked();
 523    let completion = fake_model.pending_completions().pop().unwrap();
 524    let message = completion.messages.last().unwrap();
 525    assert_eq!(
 526        message.content,
 527        vec![language_model::MessageContent::ToolResult(
 528            LanguageModelToolResult {
 529                tool_use_id: "tool_id_4".into(),
 530                tool_name: ToolRequiringPermission::name().into(),
 531                is_error: false,
 532                content: "Allowed".into(),
 533                output: Some("Allowed".into())
 534            }
 535        )]
 536    );
 537}
 538
 539#[gpui::test]
 540async fn test_tool_hallucination(cx: &mut TestAppContext) {
 541    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 542    let fake_model = model.as_fake();
 543
 544    let mut events = thread
 545        .update(cx, |thread, cx| {
 546            thread.send(UserMessageId::new(), ["abc"], cx)
 547        })
 548        .unwrap();
 549    cx.run_until_parked();
 550    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 551        LanguageModelToolUse {
 552            id: "tool_id_1".into(),
 553            name: "nonexistent_tool".into(),
 554            raw_input: "{}".into(),
 555            input: json!({}),
 556            is_input_complete: true,
 557        },
 558    ));
 559    fake_model.end_last_completion_stream();
 560
 561    let tool_call = expect_tool_call(&mut events).await;
 562    assert_eq!(tool_call.title, "nonexistent_tool");
 563    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
 564    let update = expect_tool_call_update_fields(&mut events).await;
 565    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
 566}
 567
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    // When the provider reports ToolUseLimitReached, the turn ends with an
    // error and `resume` continues the conversation with a synthetic
    // "Continue where you left off" user message.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    // The tool ran; the follow-up request carries its result (cached, since
    // it's the latest message). messages[0] is the system prompt.
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The final event of the turn surfaces the limit as a typed error.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // Resuming appends a "Continue where you left off" user message, which
    // becomes the sole cached message of the new request.
    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });

    // Ensure we error if calling resume when tool use limit was *not* reached.
    let error = thread
        .update(cx, |thread, cx| thread.resume(cx))
        .unwrap_err();
    assert_eq!(
        error.to_string(),
        "can only resume after tool use limit is reached"
    )
}
 685
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    // After hitting the tool-use limit, sending a brand-new user message
    // (rather than calling `resume`) should still work and keep the prior
    // history, including the interrupted tool exchange.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // The tool use and the limit notification arrive in the same turn.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // A fresh user message is accepted; history precedes it and only the new
    // message is cached. messages[0] is the system prompt.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true
            }
        ]
    );
}
 759
 760async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
 761    let event = events
 762        .next()
 763        .await
 764        .expect("no tool call authorization event received")
 765        .unwrap();
 766    match event {
 767        ThreadEvent::ToolCall(tool_call) => tool_call,
 768        event => {
 769            panic!("Unexpected event {event:?}");
 770        }
 771    }
 772}
 773
 774async fn expect_tool_call_update_fields(
 775    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
 776) -> acp::ToolCallUpdate {
 777    let event = events
 778        .next()
 779        .await
 780        .expect("no tool call authorization event received")
 781        .unwrap();
 782    match event {
 783        ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
 784        event => {
 785            panic!("Unexpected event {event:?}");
 786        }
 787    }
 788}
 789
 790async fn next_tool_call_authorization(
 791    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
 792) -> ToolCallAuthorization {
 793    loop {
 794        let event = events
 795            .next()
 796            .await
 797            .expect("no tool call authorization event received")
 798            .unwrap();
 799        if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
 800            let permission_kinds = tool_call_authorization
 801                .options
 802                .iter()
 803                .map(|o| o.kind)
 804                .collect::<Vec<_>>();
 805            assert_eq!(
 806                permission_kinds,
 807                vec![
 808                    acp::PermissionOptionKind::AllowAlways,
 809                    acp::PermissionOptionKind::AllowOnce,
 810                    acp::PermissionOptionKind::RejectOnce,
 811                ]
 812            );
 813            return tool_call_authorization;
 814        }
 815    }
 816}
 817
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
    // E2E test: two delay-tool calls issued in one message should both run,
    // and the model should describe both outputs once they complete.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test concurrent tool calls with different delay times
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Call the delay tool twice in the same message.",
                    "Once with 100ms. Once with 300ms.",
                    "When both timers are complete, describe the outputs.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;

    let stop_reasons = stop_events(events);
    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);

    // Concatenate all text content of the final agent message and check that
    // the delay tool's output ("Ding") made it into the reply.
    thread.update(cx, |thread, _cx| {
        let last_message = thread.last_message().unwrap();
        let agent_message = last_message.as_agent_message().unwrap();
        let text = agent_message
            .content
            .iter()
            .filter_map(|content| {
                if let AgentMessageContent::Text(text) = content {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<String>();

        assert!(text.contains("Ding"));
    });
}
 862
// Verifies that the thread's active agent profile controls which tools are
// included in completion requests: profile "test-1" exposes delay + echo,
// while "test-2" exposes only the infinite tool.
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Register all three native tools; the active profile decides which
    // subset is actually offered to the model.
    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()));
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // The fake model records each completion request; pop the single pending
    // one and inspect the tool list it carries.
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()));
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool::name()]);
}
 942
 943#[gpui::test]
 944async fn test_mcp_tools(cx: &mut TestAppContext) {
 945    let ThreadTest {
 946        model,
 947        thread,
 948        context_server_store,
 949        fs,
 950        ..
 951    } = setup(cx, TestModel::Fake).await;
 952    let fake_model = model.as_fake();
 953
 954    // Override profiles and wait for settings to be loaded.
 955    fs.insert_file(
 956        paths::settings_file(),
 957        json!({
 958            "agent": {
 959                "profiles": {
 960                    "test": {
 961                        "name": "Test Profile",
 962                        "enable_all_context_servers": true,
 963                        "tools": {
 964                            EchoTool::name(): true,
 965                        }
 966                    },
 967                }
 968            }
 969        })
 970        .to_string()
 971        .into_bytes(),
 972    )
 973    .await;
 974    cx.run_until_parked();
 975    thread.update(cx, |thread, _| {
 976        thread.set_profile(AgentProfileId("test".into()))
 977    });
 978
 979    let mut mcp_tool_calls = setup_context_server(
 980        "test_server",
 981        vec![context_server::types::Tool {
 982            name: "echo".into(),
 983            description: None,
 984            input_schema: serde_json::to_value(
 985                EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
 986            )
 987            .unwrap(),
 988            output_schema: None,
 989            annotations: None,
 990        }],
 991        &context_server_store,
 992        cx,
 993    );
 994
 995    let events = thread.update(cx, |thread, cx| {
 996        thread.send(UserMessageId::new(), ["Hey"], cx).unwrap()
 997    });
 998    cx.run_until_parked();
 999
1000    // Simulate the model calling the MCP tool.
1001    let completion = fake_model.pending_completions().pop().unwrap();
1002    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1003    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1004        LanguageModelToolUse {
1005            id: "tool_1".into(),
1006            name: "echo".into(),
1007            raw_input: json!({"text": "test"}).to_string(),
1008            input: json!({"text": "test"}),
1009            is_input_complete: true,
1010        },
1011    ));
1012    fake_model.end_last_completion_stream();
1013    cx.run_until_parked();
1014
1015    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1016    assert_eq!(tool_call_params.name, "echo");
1017    assert_eq!(tool_call_params.arguments, Some(json!({"text": "test"})));
1018    tool_call_response
1019        .send(context_server::types::CallToolResponse {
1020            content: vec![context_server::types::ToolResponseContent::Text {
1021                text: "test".into(),
1022            }],
1023            is_error: None,
1024            meta: None,
1025            structured_content: None,
1026        })
1027        .unwrap();
1028    cx.run_until_parked();
1029
1030    assert_eq!(tool_names_for_completion(&completion), vec!["echo"]);
1031    fake_model.send_last_completion_stream_text_chunk("Done!");
1032    fake_model.end_last_completion_stream();
1033    events.collect::<Vec<_>>().await;
1034
1035    // Send again after adding the echo tool, ensuring the name collision is resolved.
1036    let events = thread.update(cx, |thread, cx| {
1037        thread.add_tool(EchoTool);
1038        thread.send(UserMessageId::new(), ["Go"], cx).unwrap()
1039    });
1040    cx.run_until_parked();
1041    let completion = fake_model.pending_completions().pop().unwrap();
1042    assert_eq!(
1043        tool_names_for_completion(&completion),
1044        vec!["echo", "test_server_echo"]
1045    );
1046    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1047        LanguageModelToolUse {
1048            id: "tool_2".into(),
1049            name: "test_server_echo".into(),
1050            raw_input: json!({"text": "mcp"}).to_string(),
1051            input: json!({"text": "mcp"}),
1052            is_input_complete: true,
1053        },
1054    ));
1055    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1056        LanguageModelToolUse {
1057            id: "tool_3".into(),
1058            name: "echo".into(),
1059            raw_input: json!({"text": "native"}).to_string(),
1060            input: json!({"text": "native"}),
1061            is_input_complete: true,
1062        },
1063    ));
1064    fake_model.end_last_completion_stream();
1065    cx.run_until_parked();
1066
1067    let (tool_call_params, tool_call_response) = mcp_tool_calls.next().await.unwrap();
1068    assert_eq!(tool_call_params.name, "echo");
1069    assert_eq!(tool_call_params.arguments, Some(json!({"text": "mcp"})));
1070    tool_call_response
1071        .send(context_server::types::CallToolResponse {
1072            content: vec![context_server::types::ToolResponseContent::Text { text: "mcp".into() }],
1073            is_error: None,
1074            meta: None,
1075            structured_content: None,
1076        })
1077        .unwrap();
1078    cx.run_until_parked();
1079
1080    // Ensure the tool results were inserted with the correct names.
1081    let completion = fake_model.pending_completions().pop().unwrap();
1082    assert_eq!(
1083        completion.messages.last().unwrap().content,
1084        vec![
1085            MessageContent::ToolResult(LanguageModelToolResult {
1086                tool_use_id: "tool_3".into(),
1087                tool_name: "echo".into(),
1088                is_error: false,
1089                content: "native".into(),
1090                output: Some("native".into()),
1091            },),
1092            MessageContent::ToolResult(LanguageModelToolResult {
1093                tool_use_id: "tool_2".into(),
1094                tool_name: "test_server_echo".into(),
1095                is_error: false,
1096                content: "mcp".into(),
1097                output: Some("mcp".into()),
1098            },),
1099        ]
1100    );
1101    fake_model.end_last_completion_stream();
1102    events.collect::<Vec<_>>().await;
1103}
1104
// Exercises tool-name disambiguation and truncation when several context
// servers advertise tools whose names collide with native tools or exceed
// MAX_TOOL_NAME_LENGTH. The final assertion pins down the exact renaming
// scheme (server-prefixing on collision, truncation to the length limit).
#[gpui::test]
async fn test_mcp_tool_truncation(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        context_server_store,
        fs,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Set up a profile with all tools enabled
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test": {
                        "name": "Test Profile",
                        "enable_all_context_servers": true,
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    },
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Register the native tools that the MCP tool names will collide with.
    thread.update(cx, |thread, _| {
        thread.set_profile(AgentProfileId("test".into()));
        thread.add_tool(EchoTool);
        thread.add_tool(DelayTool);
        thread.add_tool(WordListTool);
        thread.add_tool(ToolRequiringPermission);
        thread.add_tool(InfiniteTool);
    });

    // Set up multiple context servers with some overlapping tool names
    let _server1_calls = setup_context_server(
        "xxx",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(
                    EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
                )
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_1".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    // Server "yyy" adds names at and just below the length limit; the "a…"
    // and "b…" names also collide with server "zzz" below.
    let _server2_calls = setup_context_server(
        "yyy",
        vec![
            context_server::types::Tool {
                name: "echo".into(), // Also conflicts with native EchoTool
                description: None,
                input_schema: serde_json::to_value(
                    EchoTool.input_schema(LanguageModelToolSchemaFormat::JsonSchema),
                )
                .unwrap(),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "unique_tool_2".into(),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );
    // Server "zzz" duplicates yyy's near-limit names and adds one that is
    // already over the limit ("c…", MAX + 1 chars).
    let _server3_calls = setup_context_server(
        "zzz",
        vec![
            context_server::types::Tool {
                name: "a".repeat(MAX_TOOL_NAME_LENGTH - 2),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "b".repeat(MAX_TOOL_NAME_LENGTH - 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
            context_server::types::Tool {
                name: "c".repeat(MAX_TOOL_NAME_LENGTH + 1),
                description: None,
                input_schema: json!({"type": "object", "properties": {}}),
                output_schema: None,
                annotations: None,
            },
        ],
        &context_server_store,
        cx,
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Go"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Expected list (sorted): colliding "echo" tools get a server prefix
    // ("xxx_echo"/"yyy_echo"); colliding near-limit "a…" names get a
    // truncated single-letter server prefix ("y_a…"/"z_a…"); over-limit
    // names are truncated to fit. NOTE(review): only one "b…" entry
    // survives the collision — presumably the duplicate is dropped when it
    // cannot be disambiguated within the limit; confirm against the
    // truncation implementation.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        tool_names_for_completion(&completion),
        vec![
            "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb",
            "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc",
            "delay",
            "echo",
            "infinite",
            "tool_requiring_permission",
            "unique_tool_1",
            "unique_tool_2",
            "word_list",
            "xxx_echo",
            "y_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
            "yyy_echo",
            "z_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
        ]
    );
}
1270
// E2E test (requires the "e2e" feature): cancelling mid-turn while a tool is
// still running must close the event stream with a Cancelled stop, and the
// thread must remain usable for a subsequent send.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    // Drain events until we've seen both tool calls (in the expected order)
    // and the echo tool has reported completion; the infinite tool never
    // finishes on its own.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                // Tool calls must arrive in the order the prompt requested.
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            // Watch for the echo tool's status flipping to Completed.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1355
1356#[gpui::test]
1357async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
1358    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1359    let fake_model = model.as_fake();
1360
1361    let events_1 = thread
1362        .update(cx, |thread, cx| {
1363            thread.send(UserMessageId::new(), ["Hello 1"], cx)
1364        })
1365        .unwrap();
1366    cx.run_until_parked();
1367    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1368    cx.run_until_parked();
1369
1370    let events_2 = thread
1371        .update(cx, |thread, cx| {
1372            thread.send(UserMessageId::new(), ["Hello 2"], cx)
1373        })
1374        .unwrap();
1375    cx.run_until_parked();
1376    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1377    fake_model
1378        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1379    fake_model.end_last_completion_stream();
1380
1381    let events_1 = events_1.collect::<Vec<_>>().await;
1382    assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
1383    let events_2 = events_2.collect::<Vec<_>>().await;
1384    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1385}
1386
1387#[gpui::test]
1388async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
1389    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1390    let fake_model = model.as_fake();
1391
1392    let events_1 = thread
1393        .update(cx, |thread, cx| {
1394            thread.send(UserMessageId::new(), ["Hello 1"], cx)
1395        })
1396        .unwrap();
1397    cx.run_until_parked();
1398    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1399    fake_model
1400        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1401    fake_model.end_last_completion_stream();
1402    let events_1 = events_1.collect::<Vec<_>>().await;
1403
1404    let events_2 = thread
1405        .update(cx, |thread, cx| {
1406            thread.send(UserMessageId::new(), ["Hello 2"], cx)
1407        })
1408        .unwrap();
1409    cx.run_until_parked();
1410    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1411    fake_model
1412        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1413    fake_model.end_last_completion_stream();
1414    let events_2 = events_2.collect::<Vec<_>>().await;
1415
1416    assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
1417    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1418}
1419
// When the model ends the stream with a Refusal stop reason, the thread must
// surface the refusal and roll back every message after (and including) the
// last user message, leaving the transcript empty here.
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Before the model responds, only the user message is in the transcript.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    // Stream a partial assistant response; it shows up in the transcript.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    // If the model refuses to continue, the thread should remove all the messages after the last user message.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    // The rollback also removes the user message, so the thread is empty.
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}
1468
// Truncating at the first (and only) user message must clear the transcript
// and reset token usage, and the thread must accept new messages afterwards
// with fresh token accounting.
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Keep the message id so we can truncate at this exact message later.
    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        // No usage is reported until the model streams a usage update.
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Stream a response plus a usage update (32k in / 16k out).
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        // Usage reflects input + output tokens against the model's window.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at the first message: transcript and usage are wiped.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    // New turn, new usage numbers — the old ones must not linger.
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });
}
1584
// Truncating at the second user message must restore the thread to exactly
// its state after the first exchange — both the transcript and the token
// usage reported for the first turn.
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First exchange: message 1 + response with 32k/16k usage.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Snapshot assertion reused both before the second send and after
    // truncation, to prove truncation restores the earlier state exactly.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    // Second exchange: keep the id so we can truncate at this message.
    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Both exchanges are present and usage reflects the latest turn.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at message 2: the thread reverts to the first-exchange state,
    // including the first turn's token usage.
    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
1693
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Use a dedicated fake model for summarization so title-generation
    // completions can be driven independently of chat completions.
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // The title keeps its default value until the summary model streams text.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    cx.run_until_parked();
    // Only the first line of the streamed summary is used as the title.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // No second summarization request should have been issued.
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
1739
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);
        client::init_settings(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        Project::init_settings(cx);
        agent_settings::init(cx);
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        LanguageModelRegistry::test(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector();
    assert!(
        selector_opt.is_some(),
        "agent2 should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    // The test registry exposes the fake provider's single model.
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0].id.0,
        "fake/fake"
    );

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(&session_id, cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Round-trip one prompt through the fake model and verify the transcript.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                },
                cx,
            )
        })
        .await;
    // Prompting a dropped session must surface a "Session not found" error.
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
1872
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Think"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool::name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // First the tool call is surfaced in a pending state with empty input.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
        }
    );
    // Then the completed input is delivered as a field update.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
        }
    );
    // The tool starts running...
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
        }
    );
    // ...streams its content...
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
        }
    );
    // ...and finally completes with its raw output.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
        }
    );
}
1974
#[gpui::test]
async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            // Burn mode enables automatic retries; a successful completion
            // should still produce zero retry events.
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();

    // Drain the event stream until the turn stops, collecting retries.
    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 0);
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey!
            "}
        )
    });
}
2018
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            // Burn mode enables automatic retries on retryable errors.
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Fail the first attempt with a retryable "server overloaded" error.
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Advance the fake clock past the retry_after delay so the retry fires.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    // The retried completion succeeds.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();

    // Drain the event stream until the turn stops, collecting retries.
    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey!
            "}
        )
    });
}
2075
#[gpui::test]
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Fail the initial attempt plus every retry (MAX_RETRY_ATTEMPTS of them),
    // advancing the fake clock past each retry_after delay.
    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
        fake_model.send_last_completion_stream_error(
            LanguageModelCompletionError::ServerOverloaded {
                provider: LanguageModelProviderName::new("Anthropic"),
                retry_after: Some(Duration::from_secs(3)),
            },
        );
        fake_model.end_last_completion_stream();
        cx.executor().advance_clock(Duration::from_secs(3));
        cx.run_until_parked();
    }

    // Drain the stream, separating retry notifications from terminal errors.
    let mut errors = Vec::new();
    let mut retry_events = Vec::new();
    while let Some(event) = events.next().await {
        match event {
            Ok(ThreadEvent::Retry(retry_status)) => {
                retry_events.push(retry_status);
            }
            Ok(ThreadEvent::Stop(..)) => break,
            Err(error) => errors.push(error),
            _ => {}
        }
    }

    assert_eq!(
        retry_events.len(),
        crate::thread::MAX_RETRY_ATTEMPTS as usize
    );
    // Attempts are numbered starting at 1.
    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
        assert_eq!(retry_events[i].attempt, i + 1);
    }
    // After retries are exhausted, the original error is surfaced once.
    assert_eq!(errors.len(), 1);
    let error = errors[0]
        .downcast_ref::<LanguageModelCompletionError>()
        .unwrap();
    assert!(matches!(
        error,
        LanguageModelCompletionError::ServerOverloaded { .. }
    ));
}
2130
2131/// Filters out the stop events for asserting against in tests
2132fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
2133    result_events
2134        .into_iter()
2135        .filter_map(|event| match event.unwrap() {
2136            ThreadEvent::Stop(stop_reason) => Some(stop_reason),
2137            _ => None,
2138        })
2139        .collect()
2140}
2141
/// Bundle of entities produced by [`setup`] that individual tests pull apart.
struct ThreadTest {
    // The model under test (fake or real, depending on `TestModel`).
    model: Arc<dyn LanguageModel>,
    // The thread wired to `model`, the project, and the context-server registry.
    thread: Entity<Thread>,
    project_context: Entity<ProjectContext>,
    context_server_store: Entity<ContextServerStore>,
    fs: Arc<FakeFs>,
}
2149
/// Which language model a test should run against: a real Anthropic model
/// (requires authentication) or the in-process fake.
enum TestModel {
    Sonnet4,
    Fake,
}
2154
2155impl TestModel {
2156    fn id(&self) -> LanguageModelId {
2157        match self {
2158            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
2159            TestModel::Fake => unreachable!(),
2160        }
2161    }
2162}
2163
/// Shared fixture: initializes settings, the language-model system, a test
/// project, and a `Thread` wired to the requested model.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    cx.executor().allow_parking();

    // Write a settings file enabling all the test tools in a dedicated
    // profile, then watch it so changes propagate into the SettingsStore.
    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                            ThinkingTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);
        Project::init_settings(cx);
        agent_settings::init(cx);
        gpui_tokio::init(cx);
        // A real HTTP client is used so TestModel::Sonnet4 can authenticate.
        let http_client = ReqwestClient::user_agent("agent tests").unwrap();
        cx.set_http_client(Arc::new(http_client));

        client::init_settings(cx);
        let client = Client::production(cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);

        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    // Fake models are constructed directly; real models are resolved from the
    // registry and their provider is authenticated before use.
    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(context_server_store.clone(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        context_server_store,
        fs,
    }
}
2263
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Only wire up env_logger when the caller explicitly asked for logging.
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}
2271
/// Spawns a detached task that watches the settings file on `fs` and applies
/// each new version to the global `SettingsStore`, so tests can edit settings
/// at runtime by writing to the fake filesystem.
fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
    let fs = fs.clone();
    cx.spawn({
        async move |cx| {
            let mut new_settings_content_rx = settings::watch_config_file(
                cx.background_executor(),
                fs,
                paths::settings_file().clone(),
            );

            while let Some(new_settings_content) = new_settings_content_rx.next().await {
                // `.ok()` — ignore failures once the app context is gone
                // (e.g. during test teardown).
                cx.update(|cx| {
                    SettingsStore::update_global(cx, |settings, cx| {
                        settings.set_user_settings(&new_settings_content, cx)
                    })
                })
                .ok();
            }
        }
    })
    .detach();
}
2294
2295fn tool_names_for_completion(completion: &LanguageModelRequest) -> Vec<String> {
2296    completion
2297        .tools
2298        .iter()
2299        .map(|tool| tool.name.clone())
2300        .collect()
2301}
2302
/// Registers and starts a fake context (MCP) server named `name` that exposes
/// `tools`, and returns a receiver yielding each incoming tool call paired
/// with a oneshot sender the test uses to reply to that call.
fn setup_context_server(
    name: &'static str,
    tools: Vec<context_server::types::Tool>,
    context_server_store: &Entity<ContextServerStore>,
    cx: &mut TestAppContext,
) -> mpsc::UnboundedReceiver<(
    context_server::types::CallToolParams,
    oneshot::Sender<context_server::types::CallToolResponse>,
)> {
    // Declare the server in project settings so the store will start it.
    cx.update(|cx| {
        let mut settings = ProjectSettings::get_global(cx).clone();
        settings.context_servers.insert(
            name.into(),
            project::project_settings::ContextServerSettings::Custom {
                enabled: true,
                // The command is never executed; the fake transport below
                // intercepts all traffic.
                command: ContextServerCommand {
                    path: "somebinary".into(),
                    args: Vec::new(),
                    env: None,
                },
            },
        );
        ProjectSettings::override_global(settings, cx);
    });

    // Fake transport: answer the MCP handshake, advertise `tools`, and
    // forward CallTool requests to the test through the channel.
    let (mcp_tool_calls_tx, mcp_tool_calls_rx) = mpsc::unbounded();
    let fake_transport = context_server::test::create_fake_transport(name, cx.executor())
        .on_request::<context_server::types::requests::Initialize, _>(move |_params| async move {
            context_server::types::InitializeResponse {
                protocol_version: context_server::types::ProtocolVersion(
                    context_server::types::LATEST_PROTOCOL_VERSION.to_string(),
                ),
                server_info: context_server::types::Implementation {
                    name: name.into(),
                    version: "1.0.0".to_string(),
                },
                capabilities: context_server::types::ServerCapabilities {
                    tools: Some(context_server::types::ToolsCapabilities {
                        list_changed: Some(true),
                    }),
                    ..Default::default()
                },
                meta: None,
            }
        })
        .on_request::<context_server::types::requests::ListTools, _>(move |_params| {
            let tools = tools.clone();
            async move {
                context_server::types::ListToolsResponse {
                    tools,
                    next_cursor: None,
                    meta: None,
                }
            }
        })
        .on_request::<context_server::types::requests::CallTool, _>(move |params| {
            let mcp_tool_calls_tx = mcp_tool_calls_tx.clone();
            async move {
                // Hand the call to the test and block this request until the
                // test responds through the oneshot.
                let (response_tx, response_rx) = oneshot::channel();
                mcp_tool_calls_tx
                    .unbounded_send((params, response_tx))
                    .unwrap();
                response_rx.await.unwrap()
            }
        });
    context_server_store.update(cx, |store, cx| {
        store.start_server(
            Arc::new(ContextServer::new(
                ContextServerId(name.into()),
                Arc::new(fake_transport),
            )),
            cx,
        );
    });
    // Let the handshake and tool listing complete before tests proceed.
    cx.run_until_parked();
    mcp_tool_calls_rx
}