mod.rs

   1use super::*;
   2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
   3use action_log::ActionLog;
   4use agent_client_protocol::{self as acp};
   5use agent_settings::AgentProfileId;
   6use anyhow::Result;
   7use client::{Client, UserStore};
   8use fs::{FakeFs, Fs};
   9use futures::channel::mpsc::UnboundedReceiver;
  10use gpui::{
  11    App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
  12};
  13use indoc::indoc;
  14use language_model::{
  15    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
  16    LanguageModelRegistry, LanguageModelToolResult, LanguageModelToolUse, Role, StopReason,
  17    fake_provider::FakeLanguageModel,
  18};
  19use project::Project;
  20use prompt_store::ProjectContext;
  21use reqwest_client::ReqwestClient;
  22use schemars::JsonSchema;
  23use serde::{Deserialize, Serialize};
  24use serde_json::json;
  25use settings::SettingsStore;
  26use smol::stream::StreamExt;
  27use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
  28use util::path;
  29
  30mod test_tools;
  31use test_tools::*;
  32
  33#[gpui::test]
  34#[ignore = "can't run on CI yet"]
  35async fn test_echo(cx: &mut TestAppContext) {
  36    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
  37
  38    let events = thread
  39        .update(cx, |thread, cx| {
  40            thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
  41        })
  42        .collect()
  43        .await;
  44    thread.update(cx, |thread, _cx| {
  45        assert_eq!(
  46            thread.last_message().unwrap().to_markdown(),
  47            indoc! {"
  48                ## Assistant
  49
  50                Hello
  51            "}
  52        )
  53    });
  54    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
  55}
  56
  57#[gpui::test]
  58#[ignore = "can't run on CI yet"]
  59async fn test_thinking(cx: &mut TestAppContext) {
  60    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
  61
  62    let events = thread
  63        .update(cx, |thread, cx| {
  64            thread.send(
  65                UserMessageId::new(),
  66                [indoc! {"
  67                    Testing:
  68
  69                    Generate a thinking step where you just think the word 'Think',
  70                    and have your final answer be 'Hello'
  71                "}],
  72                cx,
  73            )
  74        })
  75        .collect()
  76        .await;
  77    thread.update(cx, |thread, _cx| {
  78        assert_eq!(
  79            thread.last_message().unwrap().to_markdown(),
  80            indoc! {"
  81                ## Assistant
  82
  83                <think>Think</think>
  84                Hello
  85            "}
  86        )
  87    });
  88    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
  89}
  90
  91#[gpui::test]
  92async fn test_system_prompt(cx: &mut TestAppContext) {
  93    let ThreadTest {
  94        model,
  95        thread,
  96        project_context,
  97        ..
  98    } = setup(cx, TestModel::Fake).await;
  99    let fake_model = model.as_fake();
 100
 101    project_context.borrow_mut().shell = "test-shell".into();
 102    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
 103    thread.update(cx, |thread, cx| {
 104        thread.send(UserMessageId::new(), ["abc"], cx)
 105    });
 106    cx.run_until_parked();
 107    let mut pending_completions = fake_model.pending_completions();
 108    assert_eq!(
 109        pending_completions.len(),
 110        1,
 111        "unexpected pending completions: {:?}",
 112        pending_completions
 113    );
 114
 115    let pending_completion = pending_completions.pop().unwrap();
 116    assert_eq!(pending_completion.messages[0].role, Role::System);
 117
 118    let system_message = &pending_completion.messages[0];
 119    let system_prompt = system_message.content[0].to_str().unwrap();
 120    assert!(
 121        system_prompt.contains("test-shell"),
 122        "unexpected system message: {:?}",
 123        system_message
 124    );
 125    assert!(
 126        system_prompt.contains("## Fixing Diagnostics"),
 127        "unexpected system message: {:?}",
 128        system_message
 129    );
 130}
 131
 132#[gpui::test]
 133#[ignore = "can't run on CI yet"]
 134async fn test_basic_tool_calls(cx: &mut TestAppContext) {
 135    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 136
 137    // Test a tool call that's likely to complete *before* streaming stops.
 138    let events = thread
 139        .update(cx, |thread, cx| {
 140            thread.add_tool(EchoTool);
 141            thread.send(
 142                UserMessageId::new(),
 143                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
 144                cx,
 145            )
 146        })
 147        .collect()
 148        .await;
 149    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 150
 151    // Test a tool calls that's likely to complete *after* streaming stops.
 152    let events = thread
 153        .update(cx, |thread, cx| {
 154            thread.remove_tool(&AgentTool::name(&EchoTool));
 155            thread.add_tool(DelayTool);
 156            thread.send(
 157                UserMessageId::new(),
 158                [
 159                    "Now call the delay tool with 200ms.",
 160                    "When the timer goes off, then you echo the output of the tool.",
 161                ],
 162                cx,
 163            )
 164        })
 165        .collect()
 166        .await;
 167    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 168    thread.update(cx, |thread, _cx| {
 169        assert!(
 170            thread
 171                .last_message()
 172                .unwrap()
 173                .as_agent_message()
 174                .unwrap()
 175                .content
 176                .iter()
 177                .any(|content| {
 178                    if let AgentMessageContent::Text(text) = content {
 179                        text.contains("Ding")
 180                    } else {
 181                        false
 182                    }
 183                }),
 184            "{}",
 185            thread.to_markdown()
 186        );
 187    });
 188}
 189
 190#[gpui::test]
 191#[ignore = "can't run on CI yet"]
 192async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
 193    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 194
 195    // Test a tool call that's likely to complete *before* streaming stops.
 196    let mut events = thread.update(cx, |thread, cx| {
 197        thread.add_tool(WordListTool);
 198        thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
 199    });
 200
 201    let mut saw_partial_tool_use = false;
 202    while let Some(event) = events.next().await {
 203        if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
 204            thread.update(cx, |thread, _cx| {
 205                // Look for a tool use in the thread's last message
 206                let message = thread.last_message().unwrap();
 207                let agent_message = message.as_agent_message().unwrap();
 208                let last_content = agent_message.content.last().unwrap();
 209                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
 210                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
 211                    if tool_call.status == acp::ToolCallStatus::Pending {
 212                        if !last_tool_use.is_input_complete
 213                            && last_tool_use.input.get("g").is_none()
 214                        {
 215                            saw_partial_tool_use = true;
 216                        }
 217                    } else {
 218                        last_tool_use
 219                            .input
 220                            .get("a")
 221                            .expect("'a' has streamed because input is now complete");
 222                        last_tool_use
 223                            .input
 224                            .get("g")
 225                            .expect("'g' has streamed because input is now complete");
 226                    }
 227                } else {
 228                    panic!("last content should be a tool use");
 229                }
 230            });
 231        }
 232    }
 233
 234    assert!(
 235        saw_partial_tool_use,
 236        "should see at least one partially streamed tool use in the history"
 237    );
 238}
 239
 240#[gpui::test]
 241async fn test_tool_authorization(cx: &mut TestAppContext) {
 242    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 243    let fake_model = model.as_fake();
 244
 245    let mut events = thread.update(cx, |thread, cx| {
 246        thread.add_tool(ToolRequiringPermission);
 247        thread.send(UserMessageId::new(), ["abc"], cx)
 248    });
 249    cx.run_until_parked();
 250    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 251        LanguageModelToolUse {
 252            id: "tool_id_1".into(),
 253            name: ToolRequiringPermission.name().into(),
 254            raw_input: "{}".into(),
 255            input: json!({}),
 256            is_input_complete: true,
 257        },
 258    ));
 259    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 260        LanguageModelToolUse {
 261            id: "tool_id_2".into(),
 262            name: ToolRequiringPermission.name().into(),
 263            raw_input: "{}".into(),
 264            input: json!({}),
 265            is_input_complete: true,
 266        },
 267    ));
 268    fake_model.end_last_completion_stream();
 269    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
 270    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
 271
 272    // Approve the first
 273    tool_call_auth_1
 274        .response
 275        .send(tool_call_auth_1.options[1].id.clone())
 276        .unwrap();
 277    cx.run_until_parked();
 278
 279    // Reject the second
 280    tool_call_auth_2
 281        .response
 282        .send(tool_call_auth_1.options[2].id.clone())
 283        .unwrap();
 284    cx.run_until_parked();
 285
 286    let completion = fake_model.pending_completions().pop().unwrap();
 287    let message = completion.messages.last().unwrap();
 288    assert_eq!(
 289        message.content,
 290        vec![
 291            language_model::MessageContent::ToolResult(LanguageModelToolResult {
 292                tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
 293                tool_name: ToolRequiringPermission.name().into(),
 294                is_error: false,
 295                content: "Allowed".into(),
 296                output: Some("Allowed".into())
 297            }),
 298            language_model::MessageContent::ToolResult(LanguageModelToolResult {
 299                tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
 300                tool_name: ToolRequiringPermission.name().into(),
 301                is_error: true,
 302                content: "Permission to run tool denied by user".into(),
 303                output: None
 304            })
 305        ]
 306    );
 307
 308    // Simulate yet another tool call.
 309    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 310        LanguageModelToolUse {
 311            id: "tool_id_3".into(),
 312            name: ToolRequiringPermission.name().into(),
 313            raw_input: "{}".into(),
 314            input: json!({}),
 315            is_input_complete: true,
 316        },
 317    ));
 318    fake_model.end_last_completion_stream();
 319
 320    // Respond by always allowing tools.
 321    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
 322    tool_call_auth_3
 323        .response
 324        .send(tool_call_auth_3.options[0].id.clone())
 325        .unwrap();
 326    cx.run_until_parked();
 327    let completion = fake_model.pending_completions().pop().unwrap();
 328    let message = completion.messages.last().unwrap();
 329    assert_eq!(
 330        message.content,
 331        vec![language_model::MessageContent::ToolResult(
 332            LanguageModelToolResult {
 333                tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
 334                tool_name: ToolRequiringPermission.name().into(),
 335                is_error: false,
 336                content: "Allowed".into(),
 337                output: Some("Allowed".into())
 338            }
 339        )]
 340    );
 341
 342    // Simulate a final tool call, ensuring we don't trigger authorization.
 343    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 344        LanguageModelToolUse {
 345            id: "tool_id_4".into(),
 346            name: ToolRequiringPermission.name().into(),
 347            raw_input: "{}".into(),
 348            input: json!({}),
 349            is_input_complete: true,
 350        },
 351    ));
 352    fake_model.end_last_completion_stream();
 353    cx.run_until_parked();
 354    let completion = fake_model.pending_completions().pop().unwrap();
 355    let message = completion.messages.last().unwrap();
 356    assert_eq!(
 357        message.content,
 358        vec![language_model::MessageContent::ToolResult(
 359            LanguageModelToolResult {
 360                tool_use_id: "tool_id_4".into(),
 361                tool_name: ToolRequiringPermission.name().into(),
 362                is_error: false,
 363                content: "Allowed".into(),
 364                output: Some("Allowed".into())
 365            }
 366        )]
 367    );
 368}
 369
 370#[gpui::test]
 371async fn test_tool_hallucination(cx: &mut TestAppContext) {
 372    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 373    let fake_model = model.as_fake();
 374
 375    let mut events = thread.update(cx, |thread, cx| {
 376        thread.send(UserMessageId::new(), ["abc"], cx)
 377    });
 378    cx.run_until_parked();
 379    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 380        LanguageModelToolUse {
 381            id: "tool_id_1".into(),
 382            name: "nonexistent_tool".into(),
 383            raw_input: "{}".into(),
 384            input: json!({}),
 385            is_input_complete: true,
 386        },
 387    ));
 388    fake_model.end_last_completion_stream();
 389
 390    let tool_call = expect_tool_call(&mut events).await;
 391    assert_eq!(tool_call.title, "nonexistent_tool");
 392    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
 393    let update = expect_tool_call_update_fields(&mut events).await;
 394    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
 395}
 396
 397async fn expect_tool_call(
 398    events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
 399) -> acp::ToolCall {
 400    let event = events
 401        .next()
 402        .await
 403        .expect("no tool call authorization event received")
 404        .unwrap();
 405    match event {
 406        AgentResponseEvent::ToolCall(tool_call) => return tool_call,
 407        event => {
 408            panic!("Unexpected event {event:?}");
 409        }
 410    }
 411}
 412
 413async fn expect_tool_call_update_fields(
 414    events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
 415) -> acp::ToolCallUpdate {
 416    let event = events
 417        .next()
 418        .await
 419        .expect("no tool call authorization event received")
 420        .unwrap();
 421    match event {
 422        AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
 423            return update;
 424        }
 425        event => {
 426            panic!("Unexpected event {event:?}");
 427        }
 428    }
 429}
 430
 431async fn next_tool_call_authorization(
 432    events: &mut UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>,
 433) -> ToolCallAuthorization {
 434    loop {
 435        let event = events
 436            .next()
 437            .await
 438            .expect("no tool call authorization event received")
 439            .unwrap();
 440        if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
 441            let permission_kinds = tool_call_authorization
 442                .options
 443                .iter()
 444                .map(|o| o.kind)
 445                .collect::<Vec<_>>();
 446            assert_eq!(
 447                permission_kinds,
 448                vec![
 449                    acp::PermissionOptionKind::AllowAlways,
 450                    acp::PermissionOptionKind::AllowOnce,
 451                    acp::PermissionOptionKind::RejectOnce,
 452                ]
 453            );
 454            return tool_call_authorization;
 455        }
 456    }
 457}
 458
 459#[gpui::test]
 460#[ignore = "can't run on CI yet"]
 461async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
 462    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 463
 464    // Test concurrent tool calls with different delay times
 465    let events = thread
 466        .update(cx, |thread, cx| {
 467            thread.add_tool(DelayTool);
 468            thread.send(
 469                UserMessageId::new(),
 470                [
 471                    "Call the delay tool twice in the same message.",
 472                    "Once with 100ms. Once with 300ms.",
 473                    "When both timers are complete, describe the outputs.",
 474                ],
 475                cx,
 476            )
 477        })
 478        .collect()
 479        .await;
 480
 481    let stop_reasons = stop_events(events);
 482    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
 483
 484    thread.update(cx, |thread, _cx| {
 485        let last_message = thread.last_message().unwrap();
 486        let agent_message = last_message.as_agent_message().unwrap();
 487        let text = agent_message
 488            .content
 489            .iter()
 490            .filter_map(|content| {
 491                if let AgentMessageContent::Text(text) = content {
 492                    Some(text.as_str())
 493                } else {
 494                    None
 495                }
 496            })
 497            .collect::<String>();
 498
 499        assert!(text.contains("Ding"));
 500    });
 501}
 502
 503#[gpui::test]
 504async fn test_profiles(cx: &mut TestAppContext) {
 505    let ThreadTest {
 506        model, thread, fs, ..
 507    } = setup(cx, TestModel::Fake).await;
 508    let fake_model = model.as_fake();
 509
 510    thread.update(cx, |thread, _cx| {
 511        thread.add_tool(DelayTool);
 512        thread.add_tool(EchoTool);
 513        thread.add_tool(InfiniteTool);
 514    });
 515
 516    // Override profiles and wait for settings to be loaded.
 517    fs.insert_file(
 518        paths::settings_file(),
 519        json!({
 520            "agent": {
 521                "profiles": {
 522                    "test-1": {
 523                        "name": "Test Profile 1",
 524                        "tools": {
 525                            EchoTool.name(): true,
 526                            DelayTool.name(): true,
 527                        }
 528                    },
 529                    "test-2": {
 530                        "name": "Test Profile 2",
 531                        "tools": {
 532                            InfiniteTool.name(): true,
 533                        }
 534                    }
 535                }
 536            }
 537        })
 538        .to_string()
 539        .into_bytes(),
 540    )
 541    .await;
 542    cx.run_until_parked();
 543
 544    // Test that test-1 profile (default) has echo and delay tools
 545    thread.update(cx, |thread, cx| {
 546        thread.set_profile(AgentProfileId("test-1".into()));
 547        thread.send(UserMessageId::new(), ["test"], cx);
 548    });
 549    cx.run_until_parked();
 550
 551    let mut pending_completions = fake_model.pending_completions();
 552    assert_eq!(pending_completions.len(), 1);
 553    let completion = pending_completions.pop().unwrap();
 554    let tool_names: Vec<String> = completion
 555        .tools
 556        .iter()
 557        .map(|tool| tool.name.clone())
 558        .collect();
 559    assert_eq!(tool_names, vec![DelayTool.name(), EchoTool.name()]);
 560    fake_model.end_last_completion_stream();
 561
 562    // Switch to test-2 profile, and verify that it has only the infinite tool.
 563    thread.update(cx, |thread, cx| {
 564        thread.set_profile(AgentProfileId("test-2".into()));
 565        thread.send(UserMessageId::new(), ["test2"], cx)
 566    });
 567    cx.run_until_parked();
 568    let mut pending_completions = fake_model.pending_completions();
 569    assert_eq!(pending_completions.len(), 1);
 570    let completion = pending_completions.pop().unwrap();
 571    let tool_names: Vec<String> = completion
 572        .tools
 573        .iter()
 574        .map(|tool| tool.name.clone())
 575        .collect();
 576    assert_eq!(tool_names, vec![InfiniteTool.name()]);
 577}
 578
 579#[gpui::test]
 580#[ignore = "can't run on CI yet"]
 581async fn test_cancellation(cx: &mut TestAppContext) {
 582    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 583
 584    let mut events = thread.update(cx, |thread, cx| {
 585        thread.add_tool(InfiniteTool);
 586        thread.add_tool(EchoTool);
 587        thread.send(
 588            UserMessageId::new(),
 589            ["Call the echo tool, then call the infinite tool, then explain their output"],
 590            cx,
 591        )
 592    });
 593
 594    // Wait until both tools are called.
 595    let mut expected_tools = vec!["Echo", "Infinite Tool"];
 596    let mut echo_id = None;
 597    let mut echo_completed = false;
 598    while let Some(event) = events.next().await {
 599        match event.unwrap() {
 600            AgentResponseEvent::ToolCall(tool_call) => {
 601                assert_eq!(tool_call.title, expected_tools.remove(0));
 602                if tool_call.title == "Echo" {
 603                    echo_id = Some(tool_call.id);
 604                }
 605            }
 606            AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
 607                acp::ToolCallUpdate {
 608                    id,
 609                    fields:
 610                        acp::ToolCallUpdateFields {
 611                            status: Some(acp::ToolCallStatus::Completed),
 612                            ..
 613                        },
 614                },
 615            )) if Some(&id) == echo_id.as_ref() => {
 616                echo_completed = true;
 617            }
 618            _ => {}
 619        }
 620
 621        if expected_tools.is_empty() && echo_completed {
 622            break;
 623        }
 624    }
 625
 626    // Cancel the current send and ensure that the event stream is closed, even
 627    // if one of the tools is still running.
 628    thread.update(cx, |thread, _cx| thread.cancel());
 629    events.collect::<Vec<_>>().await;
 630
 631    // Ensure we can still send a new message after cancellation.
 632    let events = thread
 633        .update(cx, |thread, cx| {
 634            thread.send(
 635                UserMessageId::new(),
 636                ["Testing: reply with 'Hello' then stop."],
 637                cx,
 638            )
 639        })
 640        .collect::<Vec<_>>()
 641        .await;
 642    thread.update(cx, |thread, _cx| {
 643        let message = thread.last_message().unwrap();
 644        let agent_message = message.as_agent_message().unwrap();
 645        assert_eq!(
 646            agent_message.content,
 647            vec![AgentMessageContent::Text("Hello".to_string())]
 648        );
 649    });
 650    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 651}
 652
 653#[gpui::test]
 654async fn test_refusal(cx: &mut TestAppContext) {
 655    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 656    let fake_model = model.as_fake();
 657
 658    let events = thread.update(cx, |thread, cx| {
 659        thread.send(UserMessageId::new(), ["Hello"], cx)
 660    });
 661    cx.run_until_parked();
 662    thread.read_with(cx, |thread, _| {
 663        assert_eq!(
 664            thread.to_markdown(),
 665            indoc! {"
 666                ## User
 667
 668                Hello
 669            "}
 670        );
 671    });
 672
 673    fake_model.send_last_completion_stream_text_chunk("Hey!");
 674    cx.run_until_parked();
 675    thread.read_with(cx, |thread, _| {
 676        assert_eq!(
 677            thread.to_markdown(),
 678            indoc! {"
 679                ## User
 680
 681                Hello
 682
 683                ## Assistant
 684
 685                Hey!
 686            "}
 687        );
 688    });
 689
 690    // If the model refuses to continue, the thread should remove all the messages after the last user message.
 691    fake_model
 692        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
 693    let events = events.collect::<Vec<_>>().await;
 694    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
 695    thread.read_with(cx, |thread, _| {
 696        assert_eq!(thread.to_markdown(), "");
 697    });
 698}
 699
 700#[gpui::test]
 701async fn test_truncate(cx: &mut TestAppContext) {
 702    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 703    let fake_model = model.as_fake();
 704
 705    let message_id = UserMessageId::new();
 706    thread.update(cx, |thread, cx| {
 707        thread.send(message_id.clone(), ["Hello"], cx)
 708    });
 709    cx.run_until_parked();
 710    thread.read_with(cx, |thread, _| {
 711        assert_eq!(
 712            thread.to_markdown(),
 713            indoc! {"
 714                ## User
 715
 716                Hello
 717            "}
 718        );
 719    });
 720
 721    fake_model.send_last_completion_stream_text_chunk("Hey!");
 722    cx.run_until_parked();
 723    thread.read_with(cx, |thread, _| {
 724        assert_eq!(
 725            thread.to_markdown(),
 726            indoc! {"
 727                ## User
 728
 729                Hello
 730
 731                ## Assistant
 732
 733                Hey!
 734            "}
 735        );
 736    });
 737
 738    thread
 739        .update(cx, |thread, _cx| thread.truncate(message_id))
 740        .unwrap();
 741    cx.run_until_parked();
 742    thread.read_with(cx, |thread, _| {
 743        assert_eq!(thread.to_markdown(), "");
 744    });
 745
 746    // Ensure we can still send a new message after truncation.
 747    thread.update(cx, |thread, cx| {
 748        thread.send(UserMessageId::new(), ["Hi"], cx)
 749    });
 750    thread.update(cx, |thread, _cx| {
 751        assert_eq!(
 752            thread.to_markdown(),
 753            indoc! {"
 754                ## User
 755
 756                Hi
 757            "}
 758        );
 759    });
 760    cx.run_until_parked();
 761    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
 762    cx.run_until_parked();
 763    thread.read_with(cx, |thread, _| {
 764        assert_eq!(
 765            thread.to_markdown(),
 766            indoc! {"
 767                ## User
 768
 769                Hi
 770
 771                ## Assistant
 772
 773                Ahoy!
 774            "}
 775        );
 776    });
 777}
 778
 779#[gpui::test]
 780async fn test_agent_connection(cx: &mut TestAppContext) {
 781    cx.update(settings::init);
 782    let templates = Templates::new();
 783
 784    // Initialize language model system with test provider
 785    cx.update(|cx| {
 786        gpui_tokio::init(cx);
 787        client::init_settings(cx);
 788
 789        let http_client = FakeHttpClient::with_404_response();
 790        let clock = Arc::new(clock::FakeSystemClock::new());
 791        let client = Client::new(clock, http_client, cx);
 792        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
 793        language_model::init(client.clone(), cx);
 794        language_models::init(user_store.clone(), client.clone(), cx);
 795        Project::init_settings(cx);
 796        LanguageModelRegistry::test(cx);
 797        agent_settings::init(cx);
 798    });
 799    cx.executor().forbid_parking();
 800
 801    // Create a project for new_thread
 802    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
 803    fake_fs.insert_tree(path!("/test"), json!({})).await;
 804    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
 805    let cwd = Path::new("/test");
 806
 807    // Create agent and connection
 808    let agent = NativeAgent::new(
 809        project.clone(),
 810        templates.clone(),
 811        None,
 812        fake_fs.clone(),
 813        &mut cx.to_async(),
 814    )
 815    .await
 816    .unwrap();
 817    let connection = NativeAgentConnection(agent.clone());
 818
 819    // Test model_selector returns Some
 820    let selector_opt = connection.model_selector();
 821    assert!(
 822        selector_opt.is_some(),
 823        "agent2 should always support ModelSelector"
 824    );
 825    let selector = selector_opt.unwrap();
 826
 827    // Test list_models
 828    let listed_models = cx
 829        .update(|cx| selector.list_models(cx))
 830        .await
 831        .expect("list_models should succeed");
 832    let AgentModelList::Grouped(listed_models) = listed_models else {
 833        panic!("Unexpected model list type");
 834    };
 835    assert!(!listed_models.is_empty(), "should have at least one model");
 836    assert_eq!(
 837        listed_models[&AgentModelGroupName("Fake".into())][0].id.0,
 838        "fake/fake"
 839    );
 840
 841    // Create a thread using new_thread
 842    let connection_rc = Rc::new(connection.clone());
 843    let acp_thread = cx
 844        .update(|cx| connection_rc.new_thread(project, cwd, &mut cx.to_async()))
 845        .await
 846        .expect("new_thread should succeed");
 847
 848    // Get the session_id from the AcpThread
 849    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
 850
 851    // Test selected_model returns the default
 852    let model = cx
 853        .update(|cx| selector.selected_model(&session_id, cx))
 854        .await
 855        .expect("selected_model should succeed");
 856    let model = cx
 857        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
 858        .unwrap();
 859    let model = model.as_fake();
 860    assert_eq!(model.id().0, "fake", "should return default model");
 861
 862    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
 863    cx.run_until_parked();
 864    model.send_last_completion_stream_text_chunk("def");
 865    cx.run_until_parked();
 866    acp_thread.read_with(cx, |thread, cx| {
 867        assert_eq!(
 868            thread.to_markdown(cx),
 869            indoc! {"
 870                ## User
 871
 872                abc
 873
 874                ## Assistant
 875
 876                def
 877
 878            "}
 879        )
 880    });
 881
 882    // Test cancel
 883    cx.update(|cx| connection.cancel(&session_id, cx));
 884    request.await.expect("prompt should fail gracefully");
 885
 886    // Ensure that dropping the ACP thread causes the native thread to be
 887    // dropped as well.
 888    cx.update(|_| drop(acp_thread));
 889    let result = cx
 890        .update(|cx| {
 891            connection.prompt(
 892                Some(acp_thread::UserMessageId::new()),
 893                acp::PromptRequest {
 894                    session_id: session_id.clone(),
 895                    prompt: vec!["ghi".into()],
 896                },
 897                cx,
 898            )
 899        })
 900        .await;
 901    assert_eq!(
 902        result.as_ref().unwrap_err().to_string(),
 903        "Session not found",
 904        "unexpected result: {:?}",
 905        result
 906    );
 907}
 908
 909#[gpui::test]
 910async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
 911    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
 912    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
 913    let fake_model = model.as_fake();
 914
 915    let mut events = thread.update(cx, |thread, cx| {
 916        thread.send(UserMessageId::new(), ["Think"], cx)
 917    });
 918    cx.run_until_parked();
 919
 920    // Simulate streaming partial input.
 921    let input = json!({});
 922    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 923        LanguageModelToolUse {
 924            id: "1".into(),
 925            name: ThinkingTool.name().into(),
 926            raw_input: input.to_string(),
 927            input,
 928            is_input_complete: false,
 929        },
 930    ));
 931
 932    // Input streaming completed
 933    let input = json!({ "content": "Thinking hard!" });
 934    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 935        LanguageModelToolUse {
 936            id: "1".into(),
 937            name: "thinking".into(),
 938            raw_input: input.to_string(),
 939            input,
 940            is_input_complete: true,
 941        },
 942    ));
 943    fake_model.end_last_completion_stream();
 944    cx.run_until_parked();
 945
 946    let tool_call = expect_tool_call(&mut events).await;
 947    assert_eq!(
 948        tool_call,
 949        acp::ToolCall {
 950            id: acp::ToolCallId("1".into()),
 951            title: "Thinking".into(),
 952            kind: acp::ToolKind::Think,
 953            status: acp::ToolCallStatus::Pending,
 954            content: vec![],
 955            locations: vec![],
 956            raw_input: Some(json!({})),
 957            raw_output: None,
 958        }
 959    );
 960    let update = expect_tool_call_update_fields(&mut events).await;
 961    assert_eq!(
 962        update,
 963        acp::ToolCallUpdate {
 964            id: acp::ToolCallId("1".into()),
 965            fields: acp::ToolCallUpdateFields {
 966                title: Some("Thinking".into()),
 967                kind: Some(acp::ToolKind::Think),
 968                raw_input: Some(json!({ "content": "Thinking hard!" })),
 969                ..Default::default()
 970            },
 971        }
 972    );
 973    let update = expect_tool_call_update_fields(&mut events).await;
 974    assert_eq!(
 975        update,
 976        acp::ToolCallUpdate {
 977            id: acp::ToolCallId("1".into()),
 978            fields: acp::ToolCallUpdateFields {
 979                status: Some(acp::ToolCallStatus::InProgress),
 980                ..Default::default()
 981            },
 982        }
 983    );
 984    let update = expect_tool_call_update_fields(&mut events).await;
 985    assert_eq!(
 986        update,
 987        acp::ToolCallUpdate {
 988            id: acp::ToolCallId("1".into()),
 989            fields: acp::ToolCallUpdateFields {
 990                content: Some(vec!["Thinking hard!".into()]),
 991                ..Default::default()
 992            },
 993        }
 994    );
 995    let update = expect_tool_call_update_fields(&mut events).await;
 996    assert_eq!(
 997        update,
 998        acp::ToolCallUpdate {
 999            id: acp::ToolCallId("1".into()),
1000            fields: acp::ToolCallUpdateFields {
1001                status: Some(acp::ToolCallStatus::Completed),
1002                raw_output: Some("Finished thinking.".into()),
1003                ..Default::default()
1004            },
1005        }
1006    );
1007}
1008
1009/// Filters out the stop events for asserting against in tests
1010fn stop_events(
1011    result_events: Vec<Result<AgentResponseEvent, LanguageModelCompletionError>>,
1012) -> Vec<acp::StopReason> {
1013    result_events
1014        .into_iter()
1015        .filter_map(|event| match event.unwrap() {
1016            AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
1017            _ => None,
1018        })
1019        .collect()
1020}
1021
/// Fixtures produced by `setup` for driving a `Thread` in a test.
struct ThreadTest {
    /// The language model that `setup` passed (as a clone) to `Thread::new`.
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// Project context shared with the thread; `setup` hands the thread a
    /// clone of this `Rc`.
    project_context: Rc<RefCell<ProjectContext>>,
    /// Fake filesystem that `setup` populated with the settings file and an
    /// empty `/test` tree.
    fs: Arc<FakeFs>,
}
1028
/// Selects which language model `setup` wires into the test thread.
enum TestModel {
    /// Model looked up in the registry by the id `claude-sonnet-4-latest`.
    Sonnet4,
    /// Model looked up in the registry by the id
    /// `claude-sonnet-4-thinking-latest`.
    Sonnet4Thinking,
    /// A default-constructed `FakeLanguageModel`; it has no registry id, so
    /// calling `id()` on this variant panics.
    Fake,
}
1034
1035impl TestModel {
1036    fn id(&self) -> LanguageModelId {
1037        match self {
1038            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
1039            TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
1040            TestModel::Fake => unreachable!(),
1041        }
1042    }
1043}
1044
1045async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
1046    cx.executor().allow_parking();
1047
1048    let fs = FakeFs::new(cx.background_executor.clone());
1049    fs.create_dir(paths::settings_file().parent().unwrap())
1050        .await
1051        .unwrap();
1052    fs.insert_file(
1053        paths::settings_file(),
1054        json!({
1055            "agent": {
1056                "default_profile": "test-profile",
1057                "profiles": {
1058                    "test-profile": {
1059                        "name": "Test Profile",
1060                        "tools": {
1061                            EchoTool.name(): true,
1062                            DelayTool.name(): true,
1063                            WordListTool.name(): true,
1064                            ToolRequiringPermission.name(): true,
1065                            InfiniteTool.name(): true,
1066                        }
1067                    }
1068                }
1069            }
1070        })
1071        .to_string()
1072        .into_bytes(),
1073    )
1074    .await;
1075
1076    cx.update(|cx| {
1077        settings::init(cx);
1078        Project::init_settings(cx);
1079        agent_settings::init(cx);
1080        gpui_tokio::init(cx);
1081        let http_client = ReqwestClient::user_agent("agent tests").unwrap();
1082        cx.set_http_client(Arc::new(http_client));
1083
1084        client::init_settings(cx);
1085        let client = Client::production(cx);
1086        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1087        language_model::init(client.clone(), cx);
1088        language_models::init(user_store.clone(), client.clone(), cx);
1089
1090        watch_settings(fs.clone(), cx);
1091    });
1092
1093    let templates = Templates::new();
1094
1095    fs.insert_tree(path!("/test"), json!({})).await;
1096    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
1097
1098    let model = cx
1099        .update(|cx| {
1100            if let TestModel::Fake = model {
1101                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
1102            } else {
1103                let model_id = model.id();
1104                let models = LanguageModelRegistry::read_global(cx);
1105                let model = models
1106                    .available_models(cx)
1107                    .find(|model| model.id() == model_id)
1108                    .unwrap();
1109
1110                let provider = models.provider(&model.provider_id()).unwrap();
1111                let authenticated = provider.authenticate(cx);
1112
1113                cx.spawn(async move |_cx| {
1114                    authenticated.await.unwrap();
1115                    model
1116                })
1117            }
1118        })
1119        .await;
1120
1121    let project_context = Rc::new(RefCell::new(ProjectContext::default()));
1122    let context_server_registry =
1123        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1124    let action_log = cx.new(|_| ActionLog::new(project.clone()));
1125    let thread = cx.new(|cx| {
1126        Thread::new(
1127            project,
1128            project_context.clone(),
1129            context_server_registry,
1130            action_log,
1131            templates,
1132            model.clone(),
1133            cx,
1134        )
1135    });
1136    ThreadTest {
1137        model,
1138        thread,
1139        project_context,
1140        fs,
1141    }
1142}
1143
/// Initializes `env_logger` for the test binary, but only when the `RUST_LOG`
/// environment variable is set to valid Unicode; otherwise logging stays
/// uninitialized.
///
/// Registered through `#[ctor::ctor]`, so tests never call it explicitly.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let rust_log = std::env::var("RUST_LOG");
    if rust_log.is_err() {
        return;
    }
    env_logger::init();
}
1151
1152fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
1153    let fs = fs.clone();
1154    cx.spawn({
1155        async move |cx| {
1156            let mut new_settings_content_rx = settings::watch_config_file(
1157                cx.background_executor(),
1158                fs,
1159                paths::settings_file().clone(),
1160            );
1161
1162            while let Some(new_settings_content) = new_settings_content_rx.next().await {
1163                cx.update(|cx| {
1164                    SettingsStore::update_global(cx, |settings, cx| {
1165                        settings.set_user_settings(&new_settings_content, cx)
1166                    })
1167                })
1168                .ok();
1169            }
1170        }
1171    })
1172    .detach();
1173}