mod.rs

   1use super::*;
   2use crate::MessageContent;
   3use acp_thread::AgentConnection;
   4use action_log::ActionLog;
   5use agent_client_protocol::{self as acp};
   6use agent_settings::AgentProfileId;
   7use anyhow::Result;
   8use client::{Client, UserStore};
   9use fs::{FakeFs, Fs};
  10use futures::channel::mpsc::UnboundedReceiver;
  11use gpui::{
  12    App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
  13};
  14use indoc::indoc;
  15use language_model::{
  16    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
  17    LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, Role, StopReason,
  18    fake_provider::FakeLanguageModel,
  19};
  20use project::Project;
  21use prompt_store::ProjectContext;
  22use reqwest_client::ReqwestClient;
  23use schemars::JsonSchema;
  24use serde::{Deserialize, Serialize};
  25use serde_json::json;
  26use settings::SettingsStore;
  27use smol::stream::StreamExt;
  28use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
  29use util::path;
  30
  31mod test_tools;
  32use test_tools::*;
  33
  34#[gpui::test]
  35#[ignore = "can't run on CI yet"]
  36async fn test_echo(cx: &mut TestAppContext) {
  37    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
  38
  39    let events = thread
  40        .update(cx, |thread, cx| {
  41            thread.send("Testing: Reply with 'Hello'", cx)
  42        })
  43        .collect()
  44        .await;
  45    thread.update(cx, |thread, _cx| {
  46        assert_eq!(
  47            thread.messages().last().unwrap().content,
  48            vec![MessageContent::Text("Hello".to_string())]
  49        );
  50    });
  51    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
  52}
  53
  54#[gpui::test]
  55#[ignore = "can't run on CI yet"]
  56async fn test_thinking(cx: &mut TestAppContext) {
  57    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
  58
  59    let events = thread
  60        .update(cx, |thread, cx| {
  61            thread.send(
  62                indoc! {"
  63                    Testing:
  64
  65                    Generate a thinking step where you just think the word 'Think',
  66                    and have your final answer be 'Hello'
  67                "},
  68                cx,
  69            )
  70        })
  71        .collect()
  72        .await;
  73    thread.update(cx, |thread, _cx| {
  74        assert_eq!(
  75            thread.messages().last().unwrap().to_markdown(),
  76            indoc! {"
  77                ## assistant
  78                <think>Think</think>
  79                Hello
  80            "}
  81        )
  82    });
  83    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
  84}
  85
  86#[gpui::test]
  87async fn test_system_prompt(cx: &mut TestAppContext) {
  88    let ThreadTest {
  89        model,
  90        thread,
  91        project_context,
  92        ..
  93    } = setup(cx, TestModel::Fake).await;
  94    let fake_model = model.as_fake();
  95
  96    project_context.borrow_mut().shell = "test-shell".into();
  97    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
  98    thread.update(cx, |thread, cx| thread.send("abc", cx));
  99    cx.run_until_parked();
 100    let mut pending_completions = fake_model.pending_completions();
 101    assert_eq!(
 102        pending_completions.len(),
 103        1,
 104        "unexpected pending completions: {:?}",
 105        pending_completions
 106    );
 107
 108    let pending_completion = pending_completions.pop().unwrap();
 109    assert_eq!(pending_completion.messages[0].role, Role::System);
 110
 111    let system_message = &pending_completion.messages[0];
 112    let system_prompt = system_message.content[0].to_str().unwrap();
 113    assert!(
 114        system_prompt.contains("test-shell"),
 115        "unexpected system message: {:?}",
 116        system_message
 117    );
 118    assert!(
 119        system_prompt.contains("## Fixing Diagnostics"),
 120        "unexpected system message: {:?}",
 121        system_message
 122    );
 123}
 124
 125#[gpui::test]
 126#[ignore = "can't run on CI yet"]
 127async fn test_basic_tool_calls(cx: &mut TestAppContext) {
 128    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 129
 130    // Test a tool call that's likely to complete *before* streaming stops.
 131    let events = thread
 132        .update(cx, |thread, cx| {
 133            thread.add_tool(EchoTool);
 134            thread.send(
 135                "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
 136                cx,
 137            )
 138        })
 139        .collect()
 140        .await;
 141    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 142
 143    // Test a tool calls that's likely to complete *after* streaming stops.
 144    let events = thread
 145        .update(cx, |thread, cx| {
 146            thread.remove_tool(&AgentTool::name(&EchoTool));
 147            thread.add_tool(DelayTool);
 148            thread.send(
 149                "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
 150                cx,
 151            )
 152        })
 153        .collect()
 154        .await;
 155    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 156    thread.update(cx, |thread, _cx| {
 157        assert!(
 158            thread
 159                .messages()
 160                .last()
 161                .unwrap()
 162                .content
 163                .iter()
 164                .any(|content| {
 165                    if let MessageContent::Text(text) = content {
 166                        text.contains("Ding")
 167                    } else {
 168                        false
 169                    }
 170                }),
 171            "{}",
 172            thread.to_markdown()
 173        );
 174    });
 175}
 176
 177#[gpui::test]
 178#[ignore = "can't run on CI yet"]
 179async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
 180    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 181
 182    // Test a tool call that's likely to complete *before* streaming stops.
 183    let mut events = thread.update(cx, |thread, cx| {
 184        thread.add_tool(WordListTool);
 185        thread.send("Test the word_list tool.", cx)
 186    });
 187
 188    let mut saw_partial_tool_use = false;
 189    while let Some(event) = events.next().await {
 190        if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
 191            thread.update(cx, |thread, _cx| {
 192                // Look for a tool use in the thread's last message
 193                let last_content = thread.messages().last().unwrap().content.last().unwrap();
 194                if let MessageContent::ToolUse(last_tool_use) = last_content {
 195                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
 196                    if tool_call.status == acp::ToolCallStatus::Pending {
 197                        if !last_tool_use.is_input_complete
 198                            && last_tool_use.input.get("g").is_none()
 199                        {
 200                            saw_partial_tool_use = true;
 201                        }
 202                    } else {
 203                        last_tool_use
 204                            .input
 205                            .get("a")
 206                            .expect("'a' has streamed because input is now complete");
 207                        last_tool_use
 208                            .input
 209                            .get("g")
 210                            .expect("'g' has streamed because input is now complete");
 211                    }
 212                } else {
 213                    panic!("last content should be a tool use");
 214                }
 215            });
 216        }
 217    }
 218
 219    assert!(
 220        saw_partial_tool_use,
 221        "should see at least one partially streamed tool use in the history"
 222    );
 223}
 224
 225#[gpui::test]
 226async fn test_tool_authorization(cx: &mut TestAppContext) {
 227    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 228    let fake_model = model.as_fake();
 229
 230    let mut events = thread.update(cx, |thread, cx| {
 231        thread.add_tool(ToolRequiringPermission);
 232        thread.send("abc", cx)
 233    });
 234    cx.run_until_parked();
 235    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 236        LanguageModelToolUse {
 237            id: "tool_id_1".into(),
 238            name: ToolRequiringPermission.name().into(),
 239            raw_input: "{}".into(),
 240            input: json!({}),
 241            is_input_complete: true,
 242        },
 243    ));
 244    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 245        LanguageModelToolUse {
 246            id: "tool_id_2".into(),
 247            name: ToolRequiringPermission.name().into(),
 248            raw_input: "{}".into(),
 249            input: json!({}),
 250            is_input_complete: true,
 251        },
 252    ));
 253    fake_model.end_last_completion_stream();
 254    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
 255    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
 256
 257    // Approve the first
 258    tool_call_auth_1
 259        .response
 260        .send(tool_call_auth_1.options[1].id.clone())
 261        .unwrap();
 262    cx.run_until_parked();
 263
 264    // Reject the second
 265    tool_call_auth_2
 266        .response
 267        .send(tool_call_auth_1.options[2].id.clone())
 268        .unwrap();
 269    cx.run_until_parked();
 270
 271    let completion = fake_model.pending_completions().pop().unwrap();
 272    let message = completion.messages.last().unwrap();
 273    assert_eq!(
 274        message.content,
 275        vec![
 276            language_model::MessageContent::ToolResult(LanguageModelToolResult {
 277                tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
 278                tool_name: ToolRequiringPermission.name().into(),
 279                is_error: false,
 280                content: "Allowed".into(),
 281                output: Some("Allowed".into())
 282            }),
 283            language_model::MessageContent::ToolResult(LanguageModelToolResult {
 284                tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
 285                tool_name: ToolRequiringPermission.name().into(),
 286                is_error: true,
 287                content: "Permission to run tool denied by user".into(),
 288                output: None
 289            })
 290        ]
 291    );
 292
 293    // Simulate yet another tool call.
 294    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 295        LanguageModelToolUse {
 296            id: "tool_id_3".into(),
 297            name: ToolRequiringPermission.name().into(),
 298            raw_input: "{}".into(),
 299            input: json!({}),
 300            is_input_complete: true,
 301        },
 302    ));
 303    fake_model.end_last_completion_stream();
 304
 305    // Respond by always allowing tools.
 306    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
 307    tool_call_auth_3
 308        .response
 309        .send(tool_call_auth_3.options[0].id.clone())
 310        .unwrap();
 311    cx.run_until_parked();
 312    let completion = fake_model.pending_completions().pop().unwrap();
 313    let message = completion.messages.last().unwrap();
 314    assert_eq!(
 315        message.content,
 316        vec![language_model::MessageContent::ToolResult(
 317            LanguageModelToolResult {
 318                tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
 319                tool_name: ToolRequiringPermission.name().into(),
 320                is_error: false,
 321                content: "Allowed".into(),
 322                output: Some("Allowed".into())
 323            }
 324        )]
 325    );
 326
 327    // Simulate a final tool call, ensuring we don't trigger authorization.
 328    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 329        LanguageModelToolUse {
 330            id: "tool_id_4".into(),
 331            name: ToolRequiringPermission.name().into(),
 332            raw_input: "{}".into(),
 333            input: json!({}),
 334            is_input_complete: true,
 335        },
 336    ));
 337    fake_model.end_last_completion_stream();
 338    cx.run_until_parked();
 339    let completion = fake_model.pending_completions().pop().unwrap();
 340    let message = completion.messages.last().unwrap();
 341    assert_eq!(
 342        message.content,
 343        vec![language_model::MessageContent::ToolResult(
 344            LanguageModelToolResult {
 345                tool_use_id: "tool_id_4".into(),
 346                tool_name: ToolRequiringPermission.name().into(),
 347                is_error: false,
 348                content: "Allowed".into(),
 349                output: Some("Allowed".into())
 350            }
 351        )]
 352    );
 353}
 354
 355#[gpui::test]
 356async fn test_tool_hallucination(cx: &mut TestAppContext) {
 357    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 358    let fake_model = model.as_fake();
 359
 360    let mut events = thread.update(cx, |thread, cx| thread.send("abc", cx));
 361    cx.run_until_parked();
 362    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 363        LanguageModelToolUse {
 364            id: "tool_id_1".into(),
 365            name: "nonexistent_tool".into(),
 366            raw_input: "{}".into(),
 367            input: json!({}),
 368            is_input_complete: true,
 369        },
 370    ));
 371    fake_model.end_last_completion_stream();
 372
 373    let tool_call = expect_tool_call(&mut events).await;
 374    assert_eq!(tool_call.title, "nonexistent_tool");
 375    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
 376    let update = expect_tool_call_update_fields(&mut events).await;
 377    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
 378}
 379
 380async fn expect_tool_call(
 381    events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
 382) -> acp::ToolCall {
 383    let event = events
 384        .next()
 385        .await
 386        .expect("no tool call authorization event received")
 387        .unwrap();
 388    match event {
 389        AgentResponseEvent::ToolCall(tool_call) => return tool_call,
 390        event => {
 391            panic!("Unexpected event {event:?}");
 392        }
 393    }
 394}
 395
 396async fn expect_tool_call_update_fields(
 397    events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
 398) -> acp::ToolCallUpdate {
 399    let event = events
 400        .next()
 401        .await
 402        .expect("no tool call authorization event received")
 403        .unwrap();
 404    match event {
 405        AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
 406            return update;
 407        }
 408        event => {
 409            panic!("Unexpected event {event:?}");
 410        }
 411    }
 412}
 413
 414async fn next_tool_call_authorization(
 415    events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
 416) -> ToolCallAuthorization {
 417    loop {
 418        let event = events
 419            .next()
 420            .await
 421            .expect("no tool call authorization event received")
 422            .unwrap();
 423        if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
 424            let permission_kinds = tool_call_authorization
 425                .options
 426                .iter()
 427                .map(|o| o.kind)
 428                .collect::<Vec<_>>();
 429            assert_eq!(
 430                permission_kinds,
 431                vec![
 432                    acp::PermissionOptionKind::AllowAlways,
 433                    acp::PermissionOptionKind::AllowOnce,
 434                    acp::PermissionOptionKind::RejectOnce,
 435                ]
 436            );
 437            return tool_call_authorization;
 438        }
 439    }
 440}
 441
 442#[gpui::test]
 443#[ignore = "can't run on CI yet"]
 444async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
 445    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 446
 447    // Test concurrent tool calls with different delay times
 448    let events = thread
 449        .update(cx, |thread, cx| {
 450            thread.add_tool(DelayTool);
 451            thread.send(
 452                "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
 453                cx,
 454            )
 455        })
 456        .collect()
 457        .await;
 458
 459    let stop_reasons = stop_events(events);
 460    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
 461
 462    thread.update(cx, |thread, _cx| {
 463        let last_message = thread.messages().last().unwrap();
 464        let text = last_message
 465            .content
 466            .iter()
 467            .filter_map(|content| {
 468                if let MessageContent::Text(text) = content {
 469                    Some(text.as_str())
 470                } else {
 471                    None
 472                }
 473            })
 474            .collect::<String>();
 475
 476        assert!(text.contains("Ding"));
 477    });
 478}
 479
 480#[gpui::test]
 481async fn test_profiles(cx: &mut TestAppContext) {
 482    let ThreadTest {
 483        model, thread, fs, ..
 484    } = setup(cx, TestModel::Fake).await;
 485    let fake_model = model.as_fake();
 486
 487    thread.update(cx, |thread, _cx| {
 488        thread.add_tool(DelayTool);
 489        thread.add_tool(EchoTool);
 490        thread.add_tool(InfiniteTool);
 491    });
 492
 493    // Override profiles and wait for settings to be loaded.
 494    fs.insert_file(
 495        paths::settings_file(),
 496        json!({
 497            "agent": {
 498                "profiles": {
 499                    "test-1": {
 500                        "name": "Test Profile 1",
 501                        "tools": {
 502                            EchoTool.name(): true,
 503                            DelayTool.name(): true,
 504                        }
 505                    },
 506                    "test-2": {
 507                        "name": "Test Profile 2",
 508                        "tools": {
 509                            InfiniteTool.name(): true,
 510                        }
 511                    }
 512                }
 513            }
 514        })
 515        .to_string()
 516        .into_bytes(),
 517    )
 518    .await;
 519    cx.run_until_parked();
 520
 521    // Test that test-1 profile (default) has echo and delay tools
 522    thread.update(cx, |thread, cx| {
 523        thread.set_profile(AgentProfileId("test-1".into()));
 524        thread.send("test", cx);
 525    });
 526    cx.run_until_parked();
 527
 528    let mut pending_completions = fake_model.pending_completions();
 529    assert_eq!(pending_completions.len(), 1);
 530    let completion = pending_completions.pop().unwrap();
 531    let tool_names: Vec<String> = completion
 532        .tools
 533        .iter()
 534        .map(|tool| tool.name.clone())
 535        .collect();
 536    assert_eq!(tool_names, vec![DelayTool.name(), EchoTool.name()]);
 537    fake_model.end_last_completion_stream();
 538
 539    // Switch to test-2 profile, and verify that it has only the infinite tool.
 540    thread.update(cx, |thread, cx| {
 541        thread.set_profile(AgentProfileId("test-2".into()));
 542        thread.send("test2", cx)
 543    });
 544    cx.run_until_parked();
 545    let mut pending_completions = fake_model.pending_completions();
 546    assert_eq!(pending_completions.len(), 1);
 547    let completion = pending_completions.pop().unwrap();
 548    let tool_names: Vec<String> = completion
 549        .tools
 550        .iter()
 551        .map(|tool| tool.name.clone())
 552        .collect();
 553    assert_eq!(tool_names, vec![InfiniteTool.name()]);
 554}
 555
 556#[gpui::test]
 557#[ignore = "can't run on CI yet"]
 558async fn test_cancellation(cx: &mut TestAppContext) {
 559    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 560
 561    let mut events = thread.update(cx, |thread, cx| {
 562        thread.add_tool(InfiniteTool);
 563        thread.add_tool(EchoTool);
 564        thread.send(
 565            "Call the echo tool and then call the infinite tool, then explain their output",
 566            cx,
 567        )
 568    });
 569
 570    // Wait until both tools are called.
 571    let mut expected_tools = vec!["Echo", "Infinite Tool"];
 572    let mut echo_id = None;
 573    let mut echo_completed = false;
 574    while let Some(event) = events.next().await {
 575        match event.unwrap() {
 576            AgentResponseEvent::ToolCall(tool_call) => {
 577                assert_eq!(tool_call.title, expected_tools.remove(0));
 578                if tool_call.title == "Echo" {
 579                    echo_id = Some(tool_call.id);
 580                }
 581            }
 582            AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
 583                acp::ToolCallUpdate {
 584                    id,
 585                    fields:
 586                        acp::ToolCallUpdateFields {
 587                            status: Some(acp::ToolCallStatus::Completed),
 588                            ..
 589                        },
 590                },
 591            )) if Some(&id) == echo_id.as_ref() => {
 592                echo_completed = true;
 593            }
 594            _ => {}
 595        }
 596
 597        if expected_tools.is_empty() && echo_completed {
 598            break;
 599        }
 600    }
 601
 602    // Cancel the current send and ensure that the event stream is closed, even
 603    // if one of the tools is still running.
 604    thread.update(cx, |thread, _cx| thread.cancel());
 605    events.collect::<Vec<_>>().await;
 606
 607    // Ensure we can still send a new message after cancellation.
 608    let events = thread
 609        .update(cx, |thread, cx| {
 610            thread.send("Testing: reply with 'Hello' then stop.", cx)
 611        })
 612        .collect::<Vec<_>>()
 613        .await;
 614    thread.update(cx, |thread, _cx| {
 615        assert_eq!(
 616            thread.messages().last().unwrap().content,
 617            vec![MessageContent::Text("Hello".to_string())]
 618        );
 619    });
 620    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 621}
 622
 623#[gpui::test]
 624async fn test_refusal(cx: &mut TestAppContext) {
 625    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 626    let fake_model = model.as_fake();
 627
 628    let events = thread.update(cx, |thread, cx| thread.send("Hello", cx));
 629    cx.run_until_parked();
 630    thread.read_with(cx, |thread, _| {
 631        assert_eq!(
 632            thread.to_markdown(),
 633            indoc! {"
 634                ## user
 635                Hello
 636            "}
 637        );
 638    });
 639
 640    fake_model.send_last_completion_stream_text_chunk("Hey!");
 641    cx.run_until_parked();
 642    thread.read_with(cx, |thread, _| {
 643        assert_eq!(
 644            thread.to_markdown(),
 645            indoc! {"
 646                ## user
 647                Hello
 648                ## assistant
 649                Hey!
 650            "}
 651        );
 652    });
 653
 654    // If the model refuses to continue, the thread should remove all the messages after the last user message.
 655    fake_model
 656        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
 657    let events = events.collect::<Vec<_>>().await;
 658    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
 659    thread.read_with(cx, |thread, _| {
 660        assert_eq!(thread.to_markdown(), "");
 661    });
 662}
 663
 664#[gpui::test]
 665async fn test_agent_connection(cx: &mut TestAppContext) {
 666    cx.update(settings::init);
 667    let templates = Templates::new();
 668
 669    // Initialize language model system with test provider
 670    cx.update(|cx| {
 671        gpui_tokio::init(cx);
 672        client::init_settings(cx);
 673
 674        let http_client = FakeHttpClient::with_404_response();
 675        let clock = Arc::new(clock::FakeSystemClock::new());
 676        let client = Client::new(clock, http_client, cx);
 677        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
 678        language_model::init(client.clone(), cx);
 679        language_models::init(user_store.clone(), client.clone(), cx);
 680        Project::init_settings(cx);
 681        LanguageModelRegistry::test(cx);
 682        agent_settings::init(cx);
 683    });
 684    cx.executor().forbid_parking();
 685
 686    // Create a project for new_thread
 687    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
 688    fake_fs.insert_tree(path!("/test"), json!({})).await;
 689    let project = Project::test(fake_fs, [Path::new("/test")], cx).await;
 690    let cwd = Path::new("/test");
 691
 692    // Create agent and connection
 693    let agent = NativeAgent::new(project.clone(), templates.clone(), None, &mut cx.to_async())
 694        .await
 695        .unwrap();
 696    let connection = NativeAgentConnection(agent.clone());
 697
 698    // Test model_selector returns Some
 699    let selector_opt = connection.model_selector();
 700    assert!(
 701        selector_opt.is_some(),
 702        "agent2 should always support ModelSelector"
 703    );
 704    let selector = selector_opt.unwrap();
 705
 706    // Test list_models
 707    let listed_models = cx
 708        .update(|cx| {
 709            let mut async_cx = cx.to_async();
 710            selector.list_models(&mut async_cx)
 711        })
 712        .await
 713        .expect("list_models should succeed");
 714    assert!(!listed_models.is_empty(), "should have at least one model");
 715    assert_eq!(listed_models[0].id().0, "fake");
 716
 717    // Create a thread using new_thread
 718    let connection_rc = Rc::new(connection.clone());
 719    let acp_thread = cx
 720        .update(|cx| {
 721            let mut async_cx = cx.to_async();
 722            connection_rc.new_thread(project, cwd, &mut async_cx)
 723        })
 724        .await
 725        .expect("new_thread should succeed");
 726
 727    // Get the session_id from the AcpThread
 728    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
 729
 730    // Test selected_model returns the default
 731    let model = cx
 732        .update(|cx| {
 733            let mut async_cx = cx.to_async();
 734            selector.selected_model(&session_id, &mut async_cx)
 735        })
 736        .await
 737        .expect("selected_model should succeed");
 738    let model = model.as_fake();
 739    assert_eq!(model.id().0, "fake", "should return default model");
 740
 741    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
 742    cx.run_until_parked();
 743    model.send_last_completion_stream_text_chunk("def");
 744    cx.run_until_parked();
 745    acp_thread.read_with(cx, |thread, cx| {
 746        assert_eq!(
 747            thread.to_markdown(cx),
 748            indoc! {"
 749                ## User
 750
 751                abc
 752
 753                ## Assistant
 754
 755                def
 756
 757            "}
 758        )
 759    });
 760
 761    // Test cancel
 762    cx.update(|cx| connection.cancel(&session_id, cx));
 763    request.await.expect("prompt should fail gracefully");
 764
 765    // Ensure that dropping the ACP thread causes the native thread to be
 766    // dropped as well.
 767    cx.update(|_| drop(acp_thread));
 768    let result = cx
 769        .update(|cx| {
 770            connection.prompt(
 771                acp::PromptRequest {
 772                    session_id: session_id.clone(),
 773                    prompt: vec!["ghi".into()],
 774                },
 775                cx,
 776            )
 777        })
 778        .await;
 779    assert_eq!(
 780        result.as_ref().unwrap_err().to_string(),
 781        "Session not found",
 782        "unexpected result: {:?}",
 783        result
 784    );
 785}
 786
 787#[gpui::test]
 788async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
 789    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
 790    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
 791    let fake_model = model.as_fake();
 792
 793    let mut events = thread.update(cx, |thread, cx| thread.send("Think", cx));
 794    cx.run_until_parked();
 795
 796    // Simulate streaming partial input.
 797    let input = json!({});
 798    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 799        LanguageModelToolUse {
 800            id: "1".into(),
 801            name: ThinkingTool.name().into(),
 802            raw_input: input.to_string(),
 803            input,
 804            is_input_complete: false,
 805        },
 806    ));
 807
 808    // Input streaming completed
 809    let input = json!({ "content": "Thinking hard!" });
 810    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 811        LanguageModelToolUse {
 812            id: "1".into(),
 813            name: "thinking".into(),
 814            raw_input: input.to_string(),
 815            input,
 816            is_input_complete: true,
 817        },
 818    ));
 819    fake_model.end_last_completion_stream();
 820    cx.run_until_parked();
 821
 822    let tool_call = expect_tool_call(&mut events).await;
 823    assert_eq!(
 824        tool_call,
 825        acp::ToolCall {
 826            id: acp::ToolCallId("1".into()),
 827            title: "Thinking".into(),
 828            kind: acp::ToolKind::Think,
 829            status: acp::ToolCallStatus::Pending,
 830            content: vec![],
 831            locations: vec![],
 832            raw_input: Some(json!({})),
 833            raw_output: None,
 834        }
 835    );
 836    let update = expect_tool_call_update_fields(&mut events).await;
 837    assert_eq!(
 838        update,
 839        acp::ToolCallUpdate {
 840            id: acp::ToolCallId("1".into()),
 841            fields: acp::ToolCallUpdateFields {
 842                title: Some("Thinking".into()),
 843                kind: Some(acp::ToolKind::Think),
 844                raw_input: Some(json!({ "content": "Thinking hard!" })),
 845                ..Default::default()
 846            },
 847        }
 848    );
 849    let update = expect_tool_call_update_fields(&mut events).await;
 850    assert_eq!(
 851        update,
 852        acp::ToolCallUpdate {
 853            id: acp::ToolCallId("1".into()),
 854            fields: acp::ToolCallUpdateFields {
 855                status: Some(acp::ToolCallStatus::InProgress),
 856                ..Default::default()
 857            },
 858        }
 859    );
 860    let update = expect_tool_call_update_fields(&mut events).await;
 861    assert_eq!(
 862        update,
 863        acp::ToolCallUpdate {
 864            id: acp::ToolCallId("1".into()),
 865            fields: acp::ToolCallUpdateFields {
 866                content: Some(vec!["Thinking hard!".into()]),
 867                ..Default::default()
 868            },
 869        }
 870    );
 871    let update = expect_tool_call_update_fields(&mut events).await;
 872    assert_eq!(
 873        update,
 874        acp::ToolCallUpdate {
 875            id: acp::ToolCallId("1".into()),
 876            fields: acp::ToolCallUpdateFields {
 877                status: Some(acp::ToolCallStatus::Completed),
 878                raw_output: Some("Finished thinking.".into()),
 879                ..Default::default()
 880            },
 881        }
 882    );
 883}
 884
 885/// Filters out the stop events for asserting against in tests
 886fn stop_events(
 887    result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
 888) -> Vec<acp::StopReason> {
 889    result_events
 890        .into_iter()
 891        .filter_map(|event| match event.unwrap() {
 892            AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
 893            _ => None,
 894        })
 895        .collect()
 896}
 897
 898struct ThreadTest {
 899    model: Arc<dyn LanguageModel>,
 900    thread: Entity<Thread>,
 901    project_context: Rc<RefCell<ProjectContext>>,
 902    fs: Arc<FakeFs>,
 903}
 904
 905enum TestModel {
 906    Sonnet4,
 907    Sonnet4Thinking,
 908    Fake,
 909}
 910
 911impl TestModel {
 912    fn id(&self) -> LanguageModelId {
 913        match self {
 914            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
 915            TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
 916            TestModel::Fake => unreachable!(),
 917        }
 918    }
 919}
 920
 921async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
 922    cx.executor().allow_parking();
 923
 924    let fs = FakeFs::new(cx.background_executor.clone());
 925    fs.create_dir(paths::settings_file().parent().unwrap())
 926        .await
 927        .unwrap();
 928    fs.insert_file(
 929        paths::settings_file(),
 930        json!({
 931            "agent": {
 932                "default_profile": "test-profile",
 933                "profiles": {
 934                    "test-profile": {
 935                        "name": "Test Profile",
 936                        "tools": {
 937                            EchoTool.name(): true,
 938                            DelayTool.name(): true,
 939                            WordListTool.name(): true,
 940                            ToolRequiringPermission.name(): true,
 941                            InfiniteTool.name(): true,
 942                        }
 943                    }
 944                }
 945            }
 946        })
 947        .to_string()
 948        .into_bytes(),
 949    )
 950    .await;
 951
 952    cx.update(|cx| {
 953        settings::init(cx);
 954        Project::init_settings(cx);
 955        agent_settings::init(cx);
 956        gpui_tokio::init(cx);
 957        let http_client = ReqwestClient::user_agent("agent tests").unwrap();
 958        cx.set_http_client(Arc::new(http_client));
 959
 960        client::init_settings(cx);
 961        let client = Client::production(cx);
 962        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
 963        language_model::init(client.clone(), cx);
 964        language_models::init(user_store.clone(), client.clone(), cx);
 965
 966        watch_settings(fs.clone(), cx);
 967    });
 968
 969    let templates = Templates::new();
 970
 971    fs.insert_tree(path!("/test"), json!({})).await;
 972    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
 973
 974    let model = cx
 975        .update(|cx| {
 976            if let TestModel::Fake = model {
 977                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
 978            } else {
 979                let model_id = model.id();
 980                let models = LanguageModelRegistry::read_global(cx);
 981                let model = models
 982                    .available_models(cx)
 983                    .find(|model| model.id() == model_id)
 984                    .unwrap();
 985
 986                let provider = models.provider(&model.provider_id()).unwrap();
 987                let authenticated = provider.authenticate(cx);
 988
 989                cx.spawn(async move |_cx| {
 990                    authenticated.await.unwrap();
 991                    model
 992                })
 993            }
 994        })
 995        .await;
 996
 997    let project_context = Rc::new(RefCell::new(ProjectContext::default()));
 998    let context_server_registry =
 999        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1000    let action_log = cx.new(|_| ActionLog::new(project.clone()));
1001    let thread = cx.new(|cx| {
1002        Thread::new(
1003            project,
1004            project_context.clone(),
1005            context_server_registry,
1006            action_log,
1007            templates,
1008            model.clone(),
1009            cx,
1010        )
1011    });
1012    ThreadTest {
1013        model,
1014        thread,
1015        project_context,
1016        fs,
1017    }
1018}
1019
/// Initializes `env_logger` when the `RUST_LOG` environment variable is set
/// (to a valid Unicode value); otherwise does nothing.
///
/// Registered through `#[ctor::ctor]` and compiled only for test builds.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log_is_set = std::env::var("RUST_LOG").is_ok();
    if rust_log_is_set {
        env_logger::init();
    }
}
1027
1028fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
1029    let fs = fs.clone();
1030    cx.spawn({
1031        async move |cx| {
1032            let mut new_settings_content_rx = settings::watch_config_file(
1033                cx.background_executor(),
1034                fs,
1035                paths::settings_file().clone(),
1036            );
1037
1038            while let Some(new_settings_content) = new_settings_content_rx.next().await {
1039                cx.update(|cx| {
1040                    SettingsStore::update_global(cx, |settings, cx| {
1041                        settings.set_user_settings(&new_settings_content, cx)
1042                    })
1043                })
1044                .ok();
1045            }
1046        }
1047    })
1048    .detach();
1049}