mod.rs

   1use super::*;
   2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
   3use action_log::ActionLog;
   4use agent_client_protocol::{self as acp};
   5use agent_settings::AgentProfileId;
   6use anyhow::Result;
   7use client::{Client, UserStore};
   8use fs::{FakeFs, Fs};
   9use futures::channel::mpsc::UnboundedReceiver;
  10use gpui::{
  11    App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
  12};
  13use indoc::indoc;
  14use language_model::{
  15    LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry,
  16    LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, MessageContent,
  17    Role, StopReason, fake_provider::FakeLanguageModel,
  18};
  19use pretty_assertions::assert_eq;
  20use project::Project;
  21use prompt_store::ProjectContext;
  22use reqwest_client::ReqwestClient;
  23use schemars::JsonSchema;
  24use serde::{Deserialize, Serialize};
  25use serde_json::json;
  26use settings::SettingsStore;
  27use smol::stream::StreamExt;
  28use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
  29use util::path;
  30
  31mod test_tools;
  32use test_tools::*;
  33
  34#[gpui::test]
  35#[ignore = "can't run on CI yet"]
  36async fn test_echo(cx: &mut TestAppContext) {
  37    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
  38
  39    let events = thread
  40        .update(cx, |thread, cx| {
  41            thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
  42        })
  43        .collect()
  44        .await;
  45    thread.update(cx, |thread, _cx| {
  46        assert_eq!(
  47            thread.last_message().unwrap().to_markdown(),
  48            indoc! {"
  49                ## Assistant
  50
  51                Hello
  52            "}
  53        )
  54    });
  55    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
  56}
  57
  58#[gpui::test]
  59#[ignore = "can't run on CI yet"]
  60async fn test_thinking(cx: &mut TestAppContext) {
  61    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
  62
  63    let events = thread
  64        .update(cx, |thread, cx| {
  65            thread.send(
  66                UserMessageId::new(),
  67                [indoc! {"
  68                    Testing:
  69
  70                    Generate a thinking step where you just think the word 'Think',
  71                    and have your final answer be 'Hello'
  72                "}],
  73                cx,
  74            )
  75        })
  76        .collect()
  77        .await;
  78    thread.update(cx, |thread, _cx| {
  79        assert_eq!(
  80            thread.last_message().unwrap().to_markdown(),
  81            indoc! {"
  82                ## Assistant
  83
  84                <think>Think</think>
  85                Hello
  86            "}
  87        )
  88    });
  89    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
  90}
  91
  92#[gpui::test]
  93async fn test_system_prompt(cx: &mut TestAppContext) {
  94    let ThreadTest {
  95        model,
  96        thread,
  97        project_context,
  98        ..
  99    } = setup(cx, TestModel::Fake).await;
 100    let fake_model = model.as_fake();
 101
 102    project_context.borrow_mut().shell = "test-shell".into();
 103    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
 104    thread.update(cx, |thread, cx| {
 105        thread.send(UserMessageId::new(), ["abc"], cx)
 106    });
 107    cx.run_until_parked();
 108    let mut pending_completions = fake_model.pending_completions();
 109    assert_eq!(
 110        pending_completions.len(),
 111        1,
 112        "unexpected pending completions: {:?}",
 113        pending_completions
 114    );
 115
 116    let pending_completion = pending_completions.pop().unwrap();
 117    assert_eq!(pending_completion.messages[0].role, Role::System);
 118
 119    let system_message = &pending_completion.messages[0];
 120    let system_prompt = system_message.content[0].to_str().unwrap();
 121    assert!(
 122        system_prompt.contains("test-shell"),
 123        "unexpected system message: {:?}",
 124        system_message
 125    );
 126    assert!(
 127        system_prompt.contains("## Fixing Diagnostics"),
 128        "unexpected system message: {:?}",
 129        system_message
 130    );
 131}
 132
 133#[gpui::test]
 134async fn test_prompt_caching(cx: &mut TestAppContext) {
 135    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 136    let fake_model = model.as_fake();
 137
 138    // Send initial user message and verify it's cached
 139    thread.update(cx, |thread, cx| {
 140        thread.send(UserMessageId::new(), ["Message 1"], cx)
 141    });
 142    cx.run_until_parked();
 143
 144    let completion = fake_model.pending_completions().pop().unwrap();
 145    assert_eq!(
 146        completion.messages[1..],
 147        vec![LanguageModelRequestMessage {
 148            role: Role::User,
 149            content: vec!["Message 1".into()],
 150            cache: true
 151        }]
 152    );
 153    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
 154        "Response to Message 1".into(),
 155    ));
 156    fake_model.end_last_completion_stream();
 157    cx.run_until_parked();
 158
 159    // Send another user message and verify only the latest is cached
 160    thread.update(cx, |thread, cx| {
 161        thread.send(UserMessageId::new(), ["Message 2"], cx)
 162    });
 163    cx.run_until_parked();
 164
 165    let completion = fake_model.pending_completions().pop().unwrap();
 166    assert_eq!(
 167        completion.messages[1..],
 168        vec![
 169            LanguageModelRequestMessage {
 170                role: Role::User,
 171                content: vec!["Message 1".into()],
 172                cache: false
 173            },
 174            LanguageModelRequestMessage {
 175                role: Role::Assistant,
 176                content: vec!["Response to Message 1".into()],
 177                cache: false
 178            },
 179            LanguageModelRequestMessage {
 180                role: Role::User,
 181                content: vec!["Message 2".into()],
 182                cache: true
 183            }
 184        ]
 185    );
 186    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
 187        "Response to Message 2".into(),
 188    ));
 189    fake_model.end_last_completion_stream();
 190    cx.run_until_parked();
 191
 192    // Simulate a tool call and verify that the latest tool result is cached
 193    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
 194    thread.update(cx, |thread, cx| {
 195        thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
 196    });
 197    cx.run_until_parked();
 198
 199    let tool_use = LanguageModelToolUse {
 200        id: "tool_1".into(),
 201        name: EchoTool.name().into(),
 202        raw_input: json!({"text": "test"}).to_string(),
 203        input: json!({"text": "test"}),
 204        is_input_complete: true,
 205    };
 206    fake_model
 207        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
 208    fake_model.end_last_completion_stream();
 209    cx.run_until_parked();
 210
 211    let completion = fake_model.pending_completions().pop().unwrap();
 212    let tool_result = LanguageModelToolResult {
 213        tool_use_id: "tool_1".into(),
 214        tool_name: EchoTool.name().into(),
 215        is_error: false,
 216        content: "test".into(),
 217        output: Some("test".into()),
 218    };
 219    assert_eq!(
 220        completion.messages[1..],
 221        vec![
 222            LanguageModelRequestMessage {
 223                role: Role::User,
 224                content: vec!["Message 1".into()],
 225                cache: false
 226            },
 227            LanguageModelRequestMessage {
 228                role: Role::Assistant,
 229                content: vec!["Response to Message 1".into()],
 230                cache: false
 231            },
 232            LanguageModelRequestMessage {
 233                role: Role::User,
 234                content: vec!["Message 2".into()],
 235                cache: false
 236            },
 237            LanguageModelRequestMessage {
 238                role: Role::Assistant,
 239                content: vec!["Response to Message 2".into()],
 240                cache: false
 241            },
 242            LanguageModelRequestMessage {
 243                role: Role::User,
 244                content: vec!["Use the echo tool".into()],
 245                cache: false
 246            },
 247            LanguageModelRequestMessage {
 248                role: Role::Assistant,
 249                content: vec![MessageContent::ToolUse(tool_use)],
 250                cache: false
 251            },
 252            LanguageModelRequestMessage {
 253                role: Role::User,
 254                content: vec![MessageContent::ToolResult(tool_result)],
 255                cache: true
 256            }
 257        ]
 258    );
 259}
 260
 261#[gpui::test]
 262#[ignore = "can't run on CI yet"]
 263async fn test_basic_tool_calls(cx: &mut TestAppContext) {
 264    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 265
 266    // Test a tool call that's likely to complete *before* streaming stops.
 267    let events = thread
 268        .update(cx, |thread, cx| {
 269            thread.add_tool(EchoTool);
 270            thread.send(
 271                UserMessageId::new(),
 272                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
 273                cx,
 274            )
 275        })
 276        .collect()
 277        .await;
 278    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 279
 280    // Test a tool calls that's likely to complete *after* streaming stops.
 281    let events = thread
 282        .update(cx, |thread, cx| {
 283            thread.remove_tool(&AgentTool::name(&EchoTool));
 284            thread.add_tool(DelayTool);
 285            thread.send(
 286                UserMessageId::new(),
 287                [
 288                    "Now call the delay tool with 200ms.",
 289                    "When the timer goes off, then you echo the output of the tool.",
 290                ],
 291                cx,
 292            )
 293        })
 294        .collect()
 295        .await;
 296    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 297    thread.update(cx, |thread, _cx| {
 298        assert!(
 299            thread
 300                .last_message()
 301                .unwrap()
 302                .as_agent_message()
 303                .unwrap()
 304                .content
 305                .iter()
 306                .any(|content| {
 307                    if let AgentMessageContent::Text(text) = content {
 308                        text.contains("Ding")
 309                    } else {
 310                        false
 311                    }
 312                }),
 313            "{}",
 314            thread.to_markdown()
 315        );
 316    });
 317}
 318
 319#[gpui::test]
 320#[ignore = "can't run on CI yet"]
 321async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
 322    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 323
 324    // Test a tool call that's likely to complete *before* streaming stops.
 325    let mut events = thread.update(cx, |thread, cx| {
 326        thread.add_tool(WordListTool);
 327        thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
 328    });
 329
 330    let mut saw_partial_tool_use = false;
 331    while let Some(event) = events.next().await {
 332        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
 333            thread.update(cx, |thread, _cx| {
 334                // Look for a tool use in the thread's last message
 335                let message = thread.last_message().unwrap();
 336                let agent_message = message.as_agent_message().unwrap();
 337                let last_content = agent_message.content.last().unwrap();
 338                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
 339                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
 340                    if tool_call.status == acp::ToolCallStatus::Pending {
 341                        if !last_tool_use.is_input_complete
 342                            && last_tool_use.input.get("g").is_none()
 343                        {
 344                            saw_partial_tool_use = true;
 345                        }
 346                    } else {
 347                        last_tool_use
 348                            .input
 349                            .get("a")
 350                            .expect("'a' has streamed because input is now complete");
 351                        last_tool_use
 352                            .input
 353                            .get("g")
 354                            .expect("'g' has streamed because input is now complete");
 355                    }
 356                } else {
 357                    panic!("last content should be a tool use");
 358                }
 359            });
 360        }
 361    }
 362
 363    assert!(
 364        saw_partial_tool_use,
 365        "should see at least one partially streamed tool use in the history"
 366    );
 367}
 368
 369#[gpui::test]
 370async fn test_tool_authorization(cx: &mut TestAppContext) {
 371    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 372    let fake_model = model.as_fake();
 373
 374    let mut events = thread.update(cx, |thread, cx| {
 375        thread.add_tool(ToolRequiringPermission);
 376        thread.send(UserMessageId::new(), ["abc"], cx)
 377    });
 378    cx.run_until_parked();
 379    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 380        LanguageModelToolUse {
 381            id: "tool_id_1".into(),
 382            name: ToolRequiringPermission.name().into(),
 383            raw_input: "{}".into(),
 384            input: json!({}),
 385            is_input_complete: true,
 386        },
 387    ));
 388    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 389        LanguageModelToolUse {
 390            id: "tool_id_2".into(),
 391            name: ToolRequiringPermission.name().into(),
 392            raw_input: "{}".into(),
 393            input: json!({}),
 394            is_input_complete: true,
 395        },
 396    ));
 397    fake_model.end_last_completion_stream();
 398    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
 399    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
 400
 401    // Approve the first
 402    tool_call_auth_1
 403        .response
 404        .send(tool_call_auth_1.options[1].id.clone())
 405        .unwrap();
 406    cx.run_until_parked();
 407
 408    // Reject the second
 409    tool_call_auth_2
 410        .response
 411        .send(tool_call_auth_1.options[2].id.clone())
 412        .unwrap();
 413    cx.run_until_parked();
 414
 415    let completion = fake_model.pending_completions().pop().unwrap();
 416    let message = completion.messages.last().unwrap();
 417    assert_eq!(
 418        message.content,
 419        vec![
 420            language_model::MessageContent::ToolResult(LanguageModelToolResult {
 421                tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
 422                tool_name: ToolRequiringPermission.name().into(),
 423                is_error: false,
 424                content: "Allowed".into(),
 425                output: Some("Allowed".into())
 426            }),
 427            language_model::MessageContent::ToolResult(LanguageModelToolResult {
 428                tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
 429                tool_name: ToolRequiringPermission.name().into(),
 430                is_error: true,
 431                content: "Permission to run tool denied by user".into(),
 432                output: None
 433            })
 434        ]
 435    );
 436
 437    // Simulate yet another tool call.
 438    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 439        LanguageModelToolUse {
 440            id: "tool_id_3".into(),
 441            name: ToolRequiringPermission.name().into(),
 442            raw_input: "{}".into(),
 443            input: json!({}),
 444            is_input_complete: true,
 445        },
 446    ));
 447    fake_model.end_last_completion_stream();
 448
 449    // Respond by always allowing tools.
 450    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
 451    tool_call_auth_3
 452        .response
 453        .send(tool_call_auth_3.options[0].id.clone())
 454        .unwrap();
 455    cx.run_until_parked();
 456    let completion = fake_model.pending_completions().pop().unwrap();
 457    let message = completion.messages.last().unwrap();
 458    assert_eq!(
 459        message.content,
 460        vec![language_model::MessageContent::ToolResult(
 461            LanguageModelToolResult {
 462                tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
 463                tool_name: ToolRequiringPermission.name().into(),
 464                is_error: false,
 465                content: "Allowed".into(),
 466                output: Some("Allowed".into())
 467            }
 468        )]
 469    );
 470
 471    // Simulate a final tool call, ensuring we don't trigger authorization.
 472    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 473        LanguageModelToolUse {
 474            id: "tool_id_4".into(),
 475            name: ToolRequiringPermission.name().into(),
 476            raw_input: "{}".into(),
 477            input: json!({}),
 478            is_input_complete: true,
 479        },
 480    ));
 481    fake_model.end_last_completion_stream();
 482    cx.run_until_parked();
 483    let completion = fake_model.pending_completions().pop().unwrap();
 484    let message = completion.messages.last().unwrap();
 485    assert_eq!(
 486        message.content,
 487        vec![language_model::MessageContent::ToolResult(
 488            LanguageModelToolResult {
 489                tool_use_id: "tool_id_4".into(),
 490                tool_name: ToolRequiringPermission.name().into(),
 491                is_error: false,
 492                content: "Allowed".into(),
 493                output: Some("Allowed".into())
 494            }
 495        )]
 496    );
 497}
 498
 499#[gpui::test]
 500async fn test_tool_hallucination(cx: &mut TestAppContext) {
 501    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 502    let fake_model = model.as_fake();
 503
 504    let mut events = thread.update(cx, |thread, cx| {
 505        thread.send(UserMessageId::new(), ["abc"], cx)
 506    });
 507    cx.run_until_parked();
 508    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 509        LanguageModelToolUse {
 510            id: "tool_id_1".into(),
 511            name: "nonexistent_tool".into(),
 512            raw_input: "{}".into(),
 513            input: json!({}),
 514            is_input_complete: true,
 515        },
 516    ));
 517    fake_model.end_last_completion_stream();
 518
 519    let tool_call = expect_tool_call(&mut events).await;
 520    assert_eq!(tool_call.title, "nonexistent_tool");
 521    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
 522    let update = expect_tool_call_update_fields(&mut events).await;
 523    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
 524}
 525
 526#[gpui::test]
 527async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
 528    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 529    let fake_model = model.as_fake();
 530
 531    let events = thread.update(cx, |thread, cx| {
 532        thread.add_tool(EchoTool);
 533        thread.send(UserMessageId::new(), ["abc"], cx)
 534    });
 535    cx.run_until_parked();
 536    let tool_use = LanguageModelToolUse {
 537        id: "tool_id_1".into(),
 538        name: EchoTool.name().into(),
 539        raw_input: "{}".into(),
 540        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
 541        is_input_complete: true,
 542    };
 543    fake_model
 544        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
 545    fake_model.end_last_completion_stream();
 546
 547    cx.run_until_parked();
 548    let completion = fake_model.pending_completions().pop().unwrap();
 549    let tool_result = LanguageModelToolResult {
 550        tool_use_id: "tool_id_1".into(),
 551        tool_name: EchoTool.name().into(),
 552        is_error: false,
 553        content: "def".into(),
 554        output: Some("def".into()),
 555    };
 556    assert_eq!(
 557        completion.messages[1..],
 558        vec![
 559            LanguageModelRequestMessage {
 560                role: Role::User,
 561                content: vec!["abc".into()],
 562                cache: false
 563            },
 564            LanguageModelRequestMessage {
 565                role: Role::Assistant,
 566                content: vec![MessageContent::ToolUse(tool_use.clone())],
 567                cache: false
 568            },
 569            LanguageModelRequestMessage {
 570                role: Role::User,
 571                content: vec![MessageContent::ToolResult(tool_result.clone())],
 572                cache: true
 573            },
 574        ]
 575    );
 576
 577    // Simulate reaching tool use limit.
 578    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
 579        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
 580    ));
 581    fake_model.end_last_completion_stream();
 582    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
 583    assert!(
 584        last_event
 585            .unwrap_err()
 586            .is::<language_model::ToolUseLimitReachedError>()
 587    );
 588
 589    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
 590    cx.run_until_parked();
 591    let completion = fake_model.pending_completions().pop().unwrap();
 592    assert_eq!(
 593        completion.messages[1..],
 594        vec![
 595            LanguageModelRequestMessage {
 596                role: Role::User,
 597                content: vec!["abc".into()],
 598                cache: false
 599            },
 600            LanguageModelRequestMessage {
 601                role: Role::Assistant,
 602                content: vec![MessageContent::ToolUse(tool_use)],
 603                cache: false
 604            },
 605            LanguageModelRequestMessage {
 606                role: Role::User,
 607                content: vec![MessageContent::ToolResult(tool_result)],
 608                cache: false
 609            },
 610            LanguageModelRequestMessage {
 611                role: Role::User,
 612                content: vec!["Continue where you left off".into()],
 613                cache: true
 614            }
 615        ]
 616    );
 617
 618    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
 619    fake_model.end_last_completion_stream();
 620    events.collect::<Vec<_>>().await;
 621    thread.read_with(cx, |thread, _cx| {
 622        assert_eq!(
 623            thread.last_message().unwrap().to_markdown(),
 624            indoc! {"
 625                ## Assistant
 626
 627                Done
 628            "}
 629        )
 630    });
 631
 632    // Ensure we error if calling resume when tool use limit was *not* reached.
 633    let error = thread
 634        .update(cx, |thread, cx| thread.resume(cx))
 635        .unwrap_err();
 636    assert_eq!(
 637        error.to_string(),
 638        "can only resume after tool use limit is reached"
 639    )
 640}
 641
 642#[gpui::test]
 643async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
 644    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 645    let fake_model = model.as_fake();
 646
 647    let events = thread.update(cx, |thread, cx| {
 648        thread.add_tool(EchoTool);
 649        thread.send(UserMessageId::new(), ["abc"], cx)
 650    });
 651    cx.run_until_parked();
 652
 653    let tool_use = LanguageModelToolUse {
 654        id: "tool_id_1".into(),
 655        name: EchoTool.name().into(),
 656        raw_input: "{}".into(),
 657        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
 658        is_input_complete: true,
 659    };
 660    let tool_result = LanguageModelToolResult {
 661        tool_use_id: "tool_id_1".into(),
 662        tool_name: EchoTool.name().into(),
 663        is_error: false,
 664        content: "def".into(),
 665        output: Some("def".into()),
 666    };
 667    fake_model
 668        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
 669    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
 670        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
 671    ));
 672    fake_model.end_last_completion_stream();
 673    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
 674    assert!(
 675        last_event
 676            .unwrap_err()
 677            .is::<language_model::ToolUseLimitReachedError>()
 678    );
 679
 680    thread.update(cx, |thread, cx| {
 681        thread.send(UserMessageId::new(), vec!["ghi"], cx)
 682    });
 683    cx.run_until_parked();
 684    let completion = fake_model.pending_completions().pop().unwrap();
 685    assert_eq!(
 686        completion.messages[1..],
 687        vec![
 688            LanguageModelRequestMessage {
 689                role: Role::User,
 690                content: vec!["abc".into()],
 691                cache: false
 692            },
 693            LanguageModelRequestMessage {
 694                role: Role::Assistant,
 695                content: vec![MessageContent::ToolUse(tool_use)],
 696                cache: false
 697            },
 698            LanguageModelRequestMessage {
 699                role: Role::User,
 700                content: vec![MessageContent::ToolResult(tool_result)],
 701                cache: false
 702            },
 703            LanguageModelRequestMessage {
 704                role: Role::User,
 705                content: vec!["ghi".into()],
 706                cache: true
 707            }
 708        ]
 709    );
 710}
 711
 712async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
 713    let event = events
 714        .next()
 715        .await
 716        .expect("no tool call authorization event received")
 717        .unwrap();
 718    match event {
 719        ThreadEvent::ToolCall(tool_call) => return tool_call,
 720        event => {
 721            panic!("Unexpected event {event:?}");
 722        }
 723    }
 724}
 725
 726async fn expect_tool_call_update_fields(
 727    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
 728) -> acp::ToolCallUpdate {
 729    let event = events
 730        .next()
 731        .await
 732        .expect("no tool call authorization event received")
 733        .unwrap();
 734    match event {
 735        ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
 736            return update;
 737        }
 738        event => {
 739            panic!("Unexpected event {event:?}");
 740        }
 741    }
 742}
 743
 744async fn next_tool_call_authorization(
 745    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
 746) -> ToolCallAuthorization {
 747    loop {
 748        let event = events
 749            .next()
 750            .await
 751            .expect("no tool call authorization event received")
 752            .unwrap();
 753        if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
 754            let permission_kinds = tool_call_authorization
 755                .options
 756                .iter()
 757                .map(|o| o.kind)
 758                .collect::<Vec<_>>();
 759            assert_eq!(
 760                permission_kinds,
 761                vec![
 762                    acp::PermissionOptionKind::AllowAlways,
 763                    acp::PermissionOptionKind::AllowOnce,
 764                    acp::PermissionOptionKind::RejectOnce,
 765                ]
 766            );
 767            return tool_call_authorization;
 768        }
 769    }
 770}
 771
 772#[gpui::test]
 773#[ignore = "can't run on CI yet"]
 774async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
 775    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 776
 777    // Test concurrent tool calls with different delay times
 778    let events = thread
 779        .update(cx, |thread, cx| {
 780            thread.add_tool(DelayTool);
 781            thread.send(
 782                UserMessageId::new(),
 783                [
 784                    "Call the delay tool twice in the same message.",
 785                    "Once with 100ms. Once with 300ms.",
 786                    "When both timers are complete, describe the outputs.",
 787                ],
 788                cx,
 789            )
 790        })
 791        .collect()
 792        .await;
 793
 794    let stop_reasons = stop_events(events);
 795    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
 796
 797    thread.update(cx, |thread, _cx| {
 798        let last_message = thread.last_message().unwrap();
 799        let agent_message = last_message.as_agent_message().unwrap();
 800        let text = agent_message
 801            .content
 802            .iter()
 803            .filter_map(|content| {
 804                if let AgentMessageContent::Text(text) = content {
 805                    Some(text.as_str())
 806                } else {
 807                    None
 808                }
 809            })
 810            .collect::<String>();
 811
 812        assert!(text.contains("Ding"));
 813    });
 814}
 815
 816#[gpui::test]
 817async fn test_profiles(cx: &mut TestAppContext) {
 818    let ThreadTest {
 819        model, thread, fs, ..
 820    } = setup(cx, TestModel::Fake).await;
 821    let fake_model = model.as_fake();
 822
 823    thread.update(cx, |thread, _cx| {
 824        thread.add_tool(DelayTool);
 825        thread.add_tool(EchoTool);
 826        thread.add_tool(InfiniteTool);
 827    });
 828
 829    // Override profiles and wait for settings to be loaded.
 830    fs.insert_file(
 831        paths::settings_file(),
 832        json!({
 833            "agent": {
 834                "profiles": {
 835                    "test-1": {
 836                        "name": "Test Profile 1",
 837                        "tools": {
 838                            EchoTool.name(): true,
 839                            DelayTool.name(): true,
 840                        }
 841                    },
 842                    "test-2": {
 843                        "name": "Test Profile 2",
 844                        "tools": {
 845                            InfiniteTool.name(): true,
 846                        }
 847                    }
 848                }
 849            }
 850        })
 851        .to_string()
 852        .into_bytes(),
 853    )
 854    .await;
 855    cx.run_until_parked();
 856
 857    // Test that test-1 profile (default) has echo and delay tools
 858    thread.update(cx, |thread, cx| {
 859        thread.set_profile(AgentProfileId("test-1".into()));
 860        thread.send(UserMessageId::new(), ["test"], cx);
 861    });
 862    cx.run_until_parked();
 863
 864    let mut pending_completions = fake_model.pending_completions();
 865    assert_eq!(pending_completions.len(), 1);
 866    let completion = pending_completions.pop().unwrap();
 867    let tool_names: Vec<String> = completion
 868        .tools
 869        .iter()
 870        .map(|tool| tool.name.clone())
 871        .collect();
 872    assert_eq!(tool_names, vec![DelayTool.name(), EchoTool.name()]);
 873    fake_model.end_last_completion_stream();
 874
 875    // Switch to test-2 profile, and verify that it has only the infinite tool.
 876    thread.update(cx, |thread, cx| {
 877        thread.set_profile(AgentProfileId("test-2".into()));
 878        thread.send(UserMessageId::new(), ["test2"], cx)
 879    });
 880    cx.run_until_parked();
 881    let mut pending_completions = fake_model.pending_completions();
 882    assert_eq!(pending_completions.len(), 1);
 883    let completion = pending_completions.pop().unwrap();
 884    let tool_names: Vec<String> = completion
 885        .tools
 886        .iter()
 887        .map(|tool| tool.name.clone())
 888        .collect();
 889    assert_eq!(tool_names, vec![InfiniteTool.name()]);
 890}
 891
 892#[gpui::test]
 893#[ignore = "can't run on CI yet"]
 894async fn test_cancellation(cx: &mut TestAppContext) {
 895    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 896
 897    let mut events = thread.update(cx, |thread, cx| {
 898        thread.add_tool(InfiniteTool);
 899        thread.add_tool(EchoTool);
 900        thread.send(
 901            UserMessageId::new(),
 902            ["Call the echo tool, then call the infinite tool, then explain their output"],
 903            cx,
 904        )
 905    });
 906
 907    // Wait until both tools are called.
 908    let mut expected_tools = vec!["Echo", "Infinite Tool"];
 909    let mut echo_id = None;
 910    let mut echo_completed = false;
 911    while let Some(event) = events.next().await {
 912        match event.unwrap() {
 913            ThreadEvent::ToolCall(tool_call) => {
 914                assert_eq!(tool_call.title, expected_tools.remove(0));
 915                if tool_call.title == "Echo" {
 916                    echo_id = Some(tool_call.id);
 917                }
 918            }
 919            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
 920                acp::ToolCallUpdate {
 921                    id,
 922                    fields:
 923                        acp::ToolCallUpdateFields {
 924                            status: Some(acp::ToolCallStatus::Completed),
 925                            ..
 926                        },
 927                },
 928            )) if Some(&id) == echo_id.as_ref() => {
 929                echo_completed = true;
 930            }
 931            _ => {}
 932        }
 933
 934        if expected_tools.is_empty() && echo_completed {
 935            break;
 936        }
 937    }
 938
 939    // Cancel the current send and ensure that the event stream is closed, even
 940    // if one of the tools is still running.
 941    thread.update(cx, |thread, cx| thread.cancel(cx));
 942    let events = events.collect::<Vec<_>>().await;
 943    let last_event = events.last();
 944    assert!(
 945        matches!(
 946            last_event,
 947            Some(Ok(ThreadEvent::Stop(acp::StopReason::Canceled)))
 948        ),
 949        "unexpected event {last_event:?}"
 950    );
 951
 952    // Ensure we can still send a new message after cancellation.
 953    let events = thread
 954        .update(cx, |thread, cx| {
 955            thread.send(
 956                UserMessageId::new(),
 957                ["Testing: reply with 'Hello' then stop."],
 958                cx,
 959            )
 960        })
 961        .collect::<Vec<_>>()
 962        .await;
 963    thread.update(cx, |thread, _cx| {
 964        let message = thread.last_message().unwrap();
 965        let agent_message = message.as_agent_message().unwrap();
 966        assert_eq!(
 967            agent_message.content,
 968            vec![AgentMessageContent::Text("Hello".to_string())]
 969        );
 970    });
 971    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 972}
 973
 974#[gpui::test]
 975async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
 976    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 977    let fake_model = model.as_fake();
 978
 979    let events_1 = thread.update(cx, |thread, cx| {
 980        thread.send(UserMessageId::new(), ["Hello 1"], cx)
 981    });
 982    cx.run_until_parked();
 983    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
 984    cx.run_until_parked();
 985
 986    let events_2 = thread.update(cx, |thread, cx| {
 987        thread.send(UserMessageId::new(), ["Hello 2"], cx)
 988    });
 989    cx.run_until_parked();
 990    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
 991    fake_model
 992        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
 993    fake_model.end_last_completion_stream();
 994
 995    let events_1 = events_1.collect::<Vec<_>>().await;
 996    assert_eq!(stop_events(events_1), vec![acp::StopReason::Canceled]);
 997    let events_2 = events_2.collect::<Vec<_>>().await;
 998    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
 999}
1000
1001#[gpui::test]
1002async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
1003    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1004    let fake_model = model.as_fake();
1005
1006    let events_1 = thread.update(cx, |thread, cx| {
1007        thread.send(UserMessageId::new(), ["Hello 1"], cx)
1008    });
1009    cx.run_until_parked();
1010    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1011    fake_model
1012        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1013    fake_model.end_last_completion_stream();
1014    let events_1 = events_1.collect::<Vec<_>>().await;
1015
1016    let events_2 = thread.update(cx, |thread, cx| {
1017        thread.send(UserMessageId::new(), ["Hello 2"], cx)
1018    });
1019    cx.run_until_parked();
1020    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1021    fake_model
1022        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1023    fake_model.end_last_completion_stream();
1024    let events_2 = events_2.collect::<Vec<_>>().await;
1025
1026    assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
1027    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1028}
1029
1030#[gpui::test]
1031async fn test_refusal(cx: &mut TestAppContext) {
1032    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1033    let fake_model = model.as_fake();
1034
1035    let events = thread.update(cx, |thread, cx| {
1036        thread.send(UserMessageId::new(), ["Hello"], cx)
1037    });
1038    cx.run_until_parked();
1039    thread.read_with(cx, |thread, _| {
1040        assert_eq!(
1041            thread.to_markdown(),
1042            indoc! {"
1043                ## User
1044
1045                Hello
1046            "}
1047        );
1048    });
1049
1050    fake_model.send_last_completion_stream_text_chunk("Hey!");
1051    cx.run_until_parked();
1052    thread.read_with(cx, |thread, _| {
1053        assert_eq!(
1054            thread.to_markdown(),
1055            indoc! {"
1056                ## User
1057
1058                Hello
1059
1060                ## Assistant
1061
1062                Hey!
1063            "}
1064        );
1065    });
1066
1067    // If the model refuses to continue, the thread should remove all the messages after the last user message.
1068    fake_model
1069        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
1070    let events = events.collect::<Vec<_>>().await;
1071    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
1072    thread.read_with(cx, |thread, _| {
1073        assert_eq!(thread.to_markdown(), "");
1074    });
1075}
1076
1077#[gpui::test]
1078async fn test_truncate(cx: &mut TestAppContext) {
1079    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1080    let fake_model = model.as_fake();
1081
1082    let message_id = UserMessageId::new();
1083    thread.update(cx, |thread, cx| {
1084        thread.send(message_id.clone(), ["Hello"], cx)
1085    });
1086    cx.run_until_parked();
1087    thread.read_with(cx, |thread, _| {
1088        assert_eq!(
1089            thread.to_markdown(),
1090            indoc! {"
1091                ## User
1092
1093                Hello
1094            "}
1095        );
1096    });
1097
1098    fake_model.send_last_completion_stream_text_chunk("Hey!");
1099    cx.run_until_parked();
1100    thread.read_with(cx, |thread, _| {
1101        assert_eq!(
1102            thread.to_markdown(),
1103            indoc! {"
1104                ## User
1105
1106                Hello
1107
1108                ## Assistant
1109
1110                Hey!
1111            "}
1112        );
1113    });
1114
1115    thread
1116        .update(cx, |thread, cx| thread.truncate(message_id, cx))
1117        .unwrap();
1118    cx.run_until_parked();
1119    thread.read_with(cx, |thread, _| {
1120        assert_eq!(thread.to_markdown(), "");
1121    });
1122
1123    // Ensure we can still send a new message after truncation.
1124    thread.update(cx, |thread, cx| {
1125        thread.send(UserMessageId::new(), ["Hi"], cx)
1126    });
1127    thread.update(cx, |thread, _cx| {
1128        assert_eq!(
1129            thread.to_markdown(),
1130            indoc! {"
1131                ## User
1132
1133                Hi
1134            "}
1135        );
1136    });
1137    cx.run_until_parked();
1138    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
1139    cx.run_until_parked();
1140    thread.read_with(cx, |thread, _| {
1141        assert_eq!(
1142            thread.to_markdown(),
1143            indoc! {"
1144                ## User
1145
1146                Hi
1147
1148                ## Assistant
1149
1150                Ahoy!
1151            "}
1152        );
1153    });
1154}
1155
1156#[gpui::test]
1157async fn test_agent_connection(cx: &mut TestAppContext) {
1158    cx.update(settings::init);
1159    let templates = Templates::new();
1160
1161    // Initialize language model system with test provider
1162    cx.update(|cx| {
1163        gpui_tokio::init(cx);
1164        client::init_settings(cx);
1165
1166        let http_client = FakeHttpClient::with_404_response();
1167        let clock = Arc::new(clock::FakeSystemClock::new());
1168        let client = Client::new(clock, http_client, cx);
1169        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1170        language_model::init(client.clone(), cx);
1171        language_models::init(user_store.clone(), client.clone(), cx);
1172        Project::init_settings(cx);
1173        LanguageModelRegistry::test(cx);
1174        agent_settings::init(cx);
1175    });
1176    cx.executor().forbid_parking();
1177
1178    // Create a project for new_thread
1179    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
1180    fake_fs.insert_tree(path!("/test"), json!({})).await;
1181    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
1182    let cwd = Path::new("/test");
1183
1184    // Create agent and connection
1185    let agent = NativeAgent::new(
1186        project.clone(),
1187        templates.clone(),
1188        None,
1189        fake_fs.clone(),
1190        &mut cx.to_async(),
1191    )
1192    .await
1193    .unwrap();
1194    let connection = NativeAgentConnection(agent.clone());
1195
1196    // Test model_selector returns Some
1197    let selector_opt = connection.model_selector();
1198    assert!(
1199        selector_opt.is_some(),
1200        "agent2 should always support ModelSelector"
1201    );
1202    let selector = selector_opt.unwrap();
1203
1204    // Test list_models
1205    let listed_models = cx
1206        .update(|cx| selector.list_models(cx))
1207        .await
1208        .expect("list_models should succeed");
1209    let AgentModelList::Grouped(listed_models) = listed_models else {
1210        panic!("Unexpected model list type");
1211    };
1212    assert!(!listed_models.is_empty(), "should have at least one model");
1213    assert_eq!(
1214        listed_models[&AgentModelGroupName("Fake".into())][0].id.0,
1215        "fake/fake"
1216    );
1217
1218    // Create a thread using new_thread
1219    let connection_rc = Rc::new(connection.clone());
1220    let acp_thread = cx
1221        .update(|cx| connection_rc.new_thread(project, cwd, cx))
1222        .await
1223        .expect("new_thread should succeed");
1224
1225    // Get the session_id from the AcpThread
1226    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
1227
1228    // Test selected_model returns the default
1229    let model = cx
1230        .update(|cx| selector.selected_model(&session_id, cx))
1231        .await
1232        .expect("selected_model should succeed");
1233    let model = cx
1234        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
1235        .unwrap();
1236    let model = model.as_fake();
1237    assert_eq!(model.id().0, "fake", "should return default model");
1238
1239    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
1240    cx.run_until_parked();
1241    model.send_last_completion_stream_text_chunk("def");
1242    cx.run_until_parked();
1243    acp_thread.read_with(cx, |thread, cx| {
1244        assert_eq!(
1245            thread.to_markdown(cx),
1246            indoc! {"
1247                ## User
1248
1249                abc
1250
1251                ## Assistant
1252
1253                def
1254
1255            "}
1256        )
1257    });
1258
1259    // Test cancel
1260    cx.update(|cx| connection.cancel(&session_id, cx));
1261    request.await.expect("prompt should fail gracefully");
1262
1263    // Ensure that dropping the ACP thread causes the native thread to be
1264    // dropped as well.
1265    cx.update(|_| drop(acp_thread));
1266    let result = cx
1267        .update(|cx| {
1268            connection.prompt(
1269                Some(acp_thread::UserMessageId::new()),
1270                acp::PromptRequest {
1271                    session_id: session_id.clone(),
1272                    prompt: vec!["ghi".into()],
1273                },
1274                cx,
1275            )
1276        })
1277        .await;
1278    assert_eq!(
1279        result.as_ref().unwrap_err().to_string(),
1280        "Session not found",
1281        "unexpected result: {:?}",
1282        result
1283    );
1284}
1285
1286#[gpui::test]
1287async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
1288    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
1289    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
1290    let fake_model = model.as_fake();
1291
1292    let mut events = thread.update(cx, |thread, cx| {
1293        thread.send(UserMessageId::new(), ["Think"], cx)
1294    });
1295    cx.run_until_parked();
1296
1297    // Simulate streaming partial input.
1298    let input = json!({});
1299    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1300        LanguageModelToolUse {
1301            id: "1".into(),
1302            name: ThinkingTool.name().into(),
1303            raw_input: input.to_string(),
1304            input,
1305            is_input_complete: false,
1306        },
1307    ));
1308
1309    // Input streaming completed
1310    let input = json!({ "content": "Thinking hard!" });
1311    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1312        LanguageModelToolUse {
1313            id: "1".into(),
1314            name: "thinking".into(),
1315            raw_input: input.to_string(),
1316            input,
1317            is_input_complete: true,
1318        },
1319    ));
1320    fake_model.end_last_completion_stream();
1321    cx.run_until_parked();
1322
1323    let tool_call = expect_tool_call(&mut events).await;
1324    assert_eq!(
1325        tool_call,
1326        acp::ToolCall {
1327            id: acp::ToolCallId("1".into()),
1328            title: "Thinking".into(),
1329            kind: acp::ToolKind::Think,
1330            status: acp::ToolCallStatus::Pending,
1331            content: vec![],
1332            locations: vec![],
1333            raw_input: Some(json!({})),
1334            raw_output: None,
1335        }
1336    );
1337    let update = expect_tool_call_update_fields(&mut events).await;
1338    assert_eq!(
1339        update,
1340        acp::ToolCallUpdate {
1341            id: acp::ToolCallId("1".into()),
1342            fields: acp::ToolCallUpdateFields {
1343                title: Some("Thinking".into()),
1344                kind: Some(acp::ToolKind::Think),
1345                raw_input: Some(json!({ "content": "Thinking hard!" })),
1346                ..Default::default()
1347            },
1348        }
1349    );
1350    let update = expect_tool_call_update_fields(&mut events).await;
1351    assert_eq!(
1352        update,
1353        acp::ToolCallUpdate {
1354            id: acp::ToolCallId("1".into()),
1355            fields: acp::ToolCallUpdateFields {
1356                status: Some(acp::ToolCallStatus::InProgress),
1357                ..Default::default()
1358            },
1359        }
1360    );
1361    let update = expect_tool_call_update_fields(&mut events).await;
1362    assert_eq!(
1363        update,
1364        acp::ToolCallUpdate {
1365            id: acp::ToolCallId("1".into()),
1366            fields: acp::ToolCallUpdateFields {
1367                content: Some(vec!["Thinking hard!".into()]),
1368                ..Default::default()
1369            },
1370        }
1371    );
1372    let update = expect_tool_call_update_fields(&mut events).await;
1373    assert_eq!(
1374        update,
1375        acp::ToolCallUpdate {
1376            id: acp::ToolCallId("1".into()),
1377            fields: acp::ToolCallUpdateFields {
1378                status: Some(acp::ToolCallStatus::Completed),
1379                raw_output: Some("Finished thinking.".into()),
1380                ..Default::default()
1381            },
1382        }
1383    );
1384}
1385
1386/// Filters out the stop events for asserting against in tests
1387fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
1388    result_events
1389        .into_iter()
1390        .filter_map(|event| match event.unwrap() {
1391            ThreadEvent::Stop(stop_reason) => Some(stop_reason),
1392            _ => None,
1393        })
1394        .collect()
1395}
1396
1397struct ThreadTest {
1398    model: Arc<dyn LanguageModel>,
1399    thread: Entity<Thread>,
1400    project_context: Rc<RefCell<ProjectContext>>,
1401    fs: Arc<FakeFs>,
1402}
1403
1404enum TestModel {
1405    Sonnet4,
1406    Sonnet4Thinking,
1407    Fake,
1408}
1409
1410impl TestModel {
1411    fn id(&self) -> LanguageModelId {
1412        match self {
1413            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
1414            TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
1415            TestModel::Fake => unreachable!(),
1416        }
1417    }
1418}
1419
1420async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
1421    cx.executor().allow_parking();
1422
1423    let fs = FakeFs::new(cx.background_executor.clone());
1424    fs.create_dir(paths::settings_file().parent().unwrap())
1425        .await
1426        .unwrap();
1427    fs.insert_file(
1428        paths::settings_file(),
1429        json!({
1430            "agent": {
1431                "default_profile": "test-profile",
1432                "profiles": {
1433                    "test-profile": {
1434                        "name": "Test Profile",
1435                        "tools": {
1436                            EchoTool.name(): true,
1437                            DelayTool.name(): true,
1438                            WordListTool.name(): true,
1439                            ToolRequiringPermission.name(): true,
1440                            InfiniteTool.name(): true,
1441                        }
1442                    }
1443                }
1444            }
1445        })
1446        .to_string()
1447        .into_bytes(),
1448    )
1449    .await;
1450
1451    cx.update(|cx| {
1452        settings::init(cx);
1453        Project::init_settings(cx);
1454        agent_settings::init(cx);
1455        gpui_tokio::init(cx);
1456        let http_client = ReqwestClient::user_agent("agent tests").unwrap();
1457        cx.set_http_client(Arc::new(http_client));
1458
1459        client::init_settings(cx);
1460        let client = Client::production(cx);
1461        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1462        language_model::init(client.clone(), cx);
1463        language_models::init(user_store.clone(), client.clone(), cx);
1464
1465        watch_settings(fs.clone(), cx);
1466    });
1467
1468    let templates = Templates::new();
1469
1470    fs.insert_tree(path!("/test"), json!({})).await;
1471    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
1472
1473    let model = cx
1474        .update(|cx| {
1475            if let TestModel::Fake = model {
1476                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
1477            } else {
1478                let model_id = model.id();
1479                let models = LanguageModelRegistry::read_global(cx);
1480                let model = models
1481                    .available_models(cx)
1482                    .find(|model| model.id() == model_id)
1483                    .unwrap();
1484
1485                let provider = models.provider(&model.provider_id()).unwrap();
1486                let authenticated = provider.authenticate(cx);
1487
1488                cx.spawn(async move |_cx| {
1489                    authenticated.await.unwrap();
1490                    model
1491                })
1492            }
1493        })
1494        .await;
1495
1496    let project_context = Rc::new(RefCell::new(ProjectContext::default()));
1497    let context_server_registry =
1498        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1499    let action_log = cx.new(|_| ActionLog::new(project.clone()));
1500    let thread = cx.new(|cx| {
1501        Thread::new(
1502            generate_session_id(),
1503            project,
1504            project_context.clone(),
1505            context_server_registry,
1506            action_log,
1507            templates,
1508            model.clone(),
1509            None,
1510            cx,
1511        )
1512    });
1513    ThreadTest {
1514        model,
1515        thread,
1516        project_context,
1517        fs,
1518    }
1519}
1520
/// Initializes `env_logger` from a `ctor` hook, but only when the
/// `RUST_LOG` environment variable is set.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}
1528
1529fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
1530    let fs = fs.clone();
1531    cx.spawn({
1532        async move |cx| {
1533            let mut new_settings_content_rx = settings::watch_config_file(
1534                cx.background_executor(),
1535                fs,
1536                paths::settings_file().clone(),
1537            );
1538
1539            while let Some(new_settings_content) = new_settings_content_rx.next().await {
1540                cx.update(|cx| {
1541                    SettingsStore::update_global(cx, |settings, cx| {
1542                        settings.set_user_settings(&new_settings_content, cx)
1543                    })
1544                })
1545                .ok();
1546            }
1547        }
1548    })
1549    .detach();
1550}