mod.rs

use super::*;
use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
use action_log::ActionLog;
use agent_client_protocol::{self as acp};
use agent_settings::AgentProfileId;
use anyhow::Result;
use client::{Client, UserStore};
use fs::{FakeFs, Fs};
use futures::channel::mpsc::UnboundedReceiver;
use gpui::{
    App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
};
use indoc::indoc;
use language_model::{
    LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry,
    LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, MessageContent,
    Role, StopReason, fake_provider::FakeLanguageModel,
};
use pretty_assertions::assert_eq;
use project::Project;
use prompt_store::ProjectContext;
use reqwest_client::ReqwestClient;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
use settings::SettingsStore;
use smol::stream::StreamExt;
use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
use util::path;

mod test_tools;
use test_tools::*;

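// Tests marked #[ignore] exercise a real Claude Sonnet model over the network (via the
// production client configured in `setup`) and are skipped on CI; the remaining tests drive a
// deterministic FakeLanguageModel by pushing completion-stream events by hand.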
#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_echo(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
        })
        .collect()
        .await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}

#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_thinking(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                [indoc! {"
                    Testing:

                    Generate a thinking step where you just think the word 'Think',
                    and have your final answer be 'Hello'
                "}],
                cx,
            )
        })
        .collect()
        .await;
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                <think>Think</think>
                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}

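// The system prompt is rendered from the shared ProjectContext and the prompt templates, so
// mutating the context (here, the shell) before sending should show up in the first message
// of the resulting completion request.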
#[gpui::test]
async fn test_system_prompt(cx: &mut TestAppContext) {
    let ThreadTest {
        model,
        thread,
        project_context,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    project_context.borrow_mut().shell = "test-shell".into();
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["abc"], cx)
    });
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    assert!(
        system_prompt.contains("test-shell"),
        "unexpected system message: {:?}",
        system_message
    );
    assert!(
        system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}

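// Prompt caching marks only the most recent user message (or tool result) with `cache: true`;
// every earlier message in the conversation is sent with `cache: false`.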
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Message 1"], cx)
    });
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Message 2"], cx)
    });
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
    });
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool.name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool.name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true
            }
        ]
    );
}

#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&AgentTool::name(&EchoTool));
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}

#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Stream a tool call and verify that we observe partially streamed input while it's pending.
    let mut events = thread.update(cx, |thread, cx| {
        thread.add_tool(WordListTool);
        thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
    });

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}

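// Tools that require permission surface a ToolCallAuthorization event whose options are, in
// order: AllowAlways, AllowOnce, RejectOnce (see next_tool_call_authorization below). This test
// approves one call, rejects another, and then uses AllowAlways to suppress further prompts.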
#[gpui::test]
async fn test_tool_authorization(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread.update(cx, |thread, cx| {
        thread.add_tool(ToolRequiringPermission);
        thread.send(UserMessageId::new(), ["abc"], cx)
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: ToolRequiringPermission.name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
        },
    ));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_2".into(),
            name: ToolRequiringPermission.name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;

    // Approve the first
    tool_call_auth_1
        .response
        .send(tool_call_auth_1.options[1].id.clone())
        .unwrap();
    cx.run_until_parked();

    // Reject the second
    tool_call_auth_2
        .response
        .send(tool_call_auth_2.options[2].id.clone())
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
                tool_name: ToolRequiringPermission.name().into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }),
            language_model::MessageContent::ToolResult(LanguageModelToolResult {
                tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
                tool_name: ToolRequiringPermission.name().into(),
                is_error: true,
                content: "Permission to run tool denied by user".into(),
                output: None
            })
        ]
    );

    // Simulate yet another tool call.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_3".into(),
            name: ToolRequiringPermission.name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();

    // Respond by always allowing tools.
    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
    tool_call_auth_3
        .response
        .send(tool_call_auth_3.options[0].id.clone())
        .unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
                tool_name: ToolRequiringPermission.name().into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );

    // Simulate a final tool call, ensuring we don't trigger authorization.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_4".into(),
            name: ToolRequiringPermission.name().into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let message = completion.messages.last().unwrap();
    assert_eq!(
        message.content,
        vec![language_model::MessageContent::ToolResult(
            LanguageModelToolResult {
                tool_use_id: "tool_id_4".into(),
                tool_name: ToolRequiringPermission.name().into(),
                is_error: false,
                content: "Allowed".into(),
                output: Some("Allowed".into())
            }
        )]
    );
}

#[gpui::test]
async fn test_tool_hallucination(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["abc"], cx)
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: "nonexistent_tool".into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();

    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(tool_call.title, "nonexistent_tool");
    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
}

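// When the provider reports ToolUseLimitReached, the turn ends with an error; resuming appends
// a synthetic "Continue where you left off" user message and starts a fresh completion.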
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread.update(cx, |thread, cx| {
        thread.add_tool(EchoTool);
        thread.send(UserMessageId::new(), ["abc"], cx)
    });
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool.name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool.name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });

    // Ensure we error if calling resume when tool use limit was *not* reached.
    let error = thread
        .update(cx, |thread, cx| thread.resume(cx))
        .unwrap_err();
    assert_eq!(
        error.to_string(),
        "can only resume after tool use limit is reached"
    )
}

#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread.update(cx, |thread, cx| {
        thread.add_tool(EchoTool);
        thread.send(UserMessageId::new(), ["abc"], cx)
    });
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool.name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool.name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), vec!["ghi"], cx)
    });
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true
            }
        ]
    );
}

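/// Waits for the next event and asserts that it is a `ToolCall`, panicking on anything else.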
async fn expect_tool_call(
    events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
) -> acp::ToolCall {
    let event = events
        .next()
        .await
        .expect("no tool call event received")
        .unwrap();
    match event {
        AgentResponseEvent::ToolCall(tool_call) => return tool_call,
        event => {
            panic!("Unexpected event {event:?}");
        }
    }
}

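/// Waits for the next event and asserts that it is a `ToolCallUpdate::UpdateFields`,
/// panicking on anything else.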
async fn expect_tool_call_update_fields(
    events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
) -> acp::ToolCallUpdate {
    let event = events
        .next()
        .await
        .expect("no tool call update event received")
        .unwrap();
    match event {
        AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
            return update;
        }
        event => {
            panic!("Unexpected event {event:?}");
        }
    }
}

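/// Skips ahead to the next `ToolCallAuthorization` event, asserting that the permission options
/// are always presented as [AllowAlways, AllowOnce, RejectOnce].
///
/// Typical usage (mirroring `test_tool_authorization`):
///
/// ```ignore
/// let auth = next_tool_call_authorization(&mut events).await;
/// auth.response.send(auth.options[1].id.clone()).unwrap(); // AllowOnce
/// ```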
async fn next_tool_call_authorization(
    events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
) -> ToolCallAuthorization {
    loop {
        let event = events
            .next()
            .await
            .expect("no tool call authorization event received")
            .unwrap();
        if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
            let permission_kinds = tool_call_authorization
                .options
                .iter()
                .map(|o| o.kind)
                .collect::<Vec<_>>();
            assert_eq!(
                permission_kinds,
                vec![
                    acp::PermissionOptionKind::AllowAlways,
                    acp::PermissionOptionKind::AllowOnce,
                    acp::PermissionOptionKind::RejectOnce,
                ]
            );
            return tool_call_authorization;
        }
    }
}

#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test concurrent tool calls with different delay times
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Call the delay tool twice in the same message.",
                    "Once with 100ms. Once with 300ms.",
                    "When both timers are complete, describe the outputs.",
                ],
                cx,
            )
        })
        .collect()
        .await;

    let stop_reasons = stop_events(events);
    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);

    thread.update(cx, |thread, _cx| {
        let last_message = thread.last_message().unwrap();
        let agent_message = last_message.as_agent_message().unwrap();
        let text = agent_message
            .content
            .iter()
            .filter_map(|content| {
                if let AgentMessageContent::Text(text) = content {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<String>();

        assert!(text.contains("Ding"));
    });
}

#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool.name(): true,
                            DelayTool.name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool.name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Switch to the test-1 profile and verify it has the echo and delay tools
    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test-1".into()));
        thread.send(UserMessageId::new(), ["test"], cx);
    });
    cx.run_until_parked();

    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool.name(), EchoTool.name()]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread.update(cx, |thread, cx| {
        thread.set_profile(AgentProfileId("test-2".into()));
        thread.send(UserMessageId::new(), ["test2"], cx)
    });
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool.name()]);
}

#[gpui::test]
#[ignore = "can't run on CI yet"]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread.update(cx, |thread, cx| {
        thread.add_tool(InfiniteTool);
        thread.add_tool(EchoTool);
        thread.send(
            UserMessageId::new(),
            ["Call the echo tool, then call the infinite tool, then explain their output"],
            cx,
        )
    });

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            AgentResponseEvent::ToolCall(tool_call) => {
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, _cx| thread.cancel());
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(AgentResponseEvent::Stop(acp::StopReason::Canceled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}

#[gpui::test]
async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events_1 = thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Hello 1"], cx)
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
    cx.run_until_parked();

    let events_2 = thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Hello 2"], cx)
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events_1 = events_1.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events_1), vec![acp::StopReason::Canceled]);
    let events_2 = events_2.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
}

#[gpui::test]
async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events_1 = thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Hello 1"], cx)
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    let events_1 = events_1.collect::<Vec<_>>().await;

    let events_2 = thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Hello 2"], cx)
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    let events_2 = events_2.collect::<Vec<_>>().await;

    assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
}

#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Hello"], cx)
    });
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    // If the model refuses to continue, the thread should drop the last user message and everything after it.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}

#[gpui::test]
async fn test_truncate(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let message_id = UserMessageId::new();
    thread.update(cx, |thread, cx| {
        thread.send(message_id.clone(), ["Hello"], cx)
    });
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    thread
        .update(cx, |thread, _cx| thread.truncate(message_id))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });

    // Ensure we can still send a new message after truncation.
    thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Hi"], cx)
    });
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );
    });
}

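// Exercises the ACP layer end to end: NativeAgent/NativeAgentConnection, model listing and
// selection, prompting through an AcpThread, cancellation, and session cleanup on drop.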
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);
        client::init_settings(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store.clone(), client.clone(), cx);
        Project::init_settings(cx);
        LanguageModelRegistry::test(cx);
        agent_settings::init(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector();
    assert!(
        selector_opt.is_some(),
        "agent2 should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0].id.0,
        "fake/fake"
    );

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(&session_id, cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                },
                cx,
            )
        })
        .await;
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}

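// A single tool call should produce updates in this order: the initial ToolCall, a fields
// update once the input finishes streaming, an InProgress status, streamed content, and
// finally a Completed status carrying the raw output.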
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread.update(cx, |thread, cx| {
        thread.send(UserMessageId::new(), ["Think"], cx)
    });
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool.name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
        }
    );
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
        }
    );
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
        }
    );
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
        }
    );
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
        }
    );
}

/// Extracts just the stop events from a turn's event stream, for asserting against in tests.
fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<acp::StopReason> {
    result_events
        .into_iter()
        .filter_map(|event| match event.unwrap() {
            AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
            _ => None,
        })
        .collect()
}

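/// Everything a test needs from [`setup`]: the model under test, the thread entity, the shared
/// project context, and the fake filesystem backing the settings file.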
struct ThreadTest {
    model: Arc<dyn LanguageModel>,
    thread: Entity<Thread>,
    project_context: Rc<RefCell<ProjectContext>>,
    fs: Arc<FakeFs>,
}

enum TestModel {
    Sonnet4,
    Sonnet4Thinking,
    Fake,
}

impl TestModel {
    fn id(&self) -> LanguageModelId {
        match self {
            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
            TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
            TestModel::Fake => unreachable!(),
        }
    }
}

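/// Builds the test fixture: a FakeFs seeded with a settings file that enables the test tools,
/// the global settings/client/language-model state, a test Project, and a Thread wired to
/// either a FakeLanguageModel or an authenticated live model.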
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool.name(): true,
                            DelayTool.name(): true,
                            WordListTool.name(): true,
                            ToolRequiringPermission.name(): true,
                            InfiniteTool.name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);
        Project::init_settings(cx);
        agent_settings::init(cx);
        gpui_tokio::init(cx);
        let http_client = ReqwestClient::user_agent("agent tests").unwrap();
        cx.set_http_client(Arc::new(http_client));

        client::init_settings(cx);
        let client = Client::production(cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store.clone(), client.clone(), cx);

        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = Rc::new(RefCell::new(ProjectContext::default()));
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
    let action_log = cx.new(|_| ActionLog::new(project.clone()));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            action_log,
            templates,
            model.clone(),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        fs,
    }
}

#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_ok() {
        env_logger::init();
    }
}

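/// Mirrors changes to the settings file on the fake filesystem into the global SettingsStore,
/// so tests (like test_profiles) can rewrite settings at runtime and have them take effect.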
fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
    let fs = fs.clone();
    cx.spawn({
        async move |cx| {
            let mut new_settings_content_rx = settings::watch_config_file(
                cx.background_executor(),
                fs,
                paths::settings_file().clone(),
            );

            while let Some(new_settings_content) = new_settings_content_rx.next().await {
                cx.update(|cx| {
                    SettingsStore::update_global(cx, |settings, cx| {
                        settings.set_user_settings(&new_settings_content, cx)
                    })
                })
                .ok();
            }
        }
    })
    .detach();
}