mod.rs

   1use super::*;
   2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
   3use action_log::ActionLog;
   4use agent_client_protocol::{self as acp};
   5use agent_settings::AgentProfileId;
   6use anyhow::Result;
   7use client::{Client, UserStore};
   8use fs::{FakeFs, Fs};
   9use futures::channel::mpsc::UnboundedReceiver;
  10use gpui::{
  11    App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
  12};
  13use indoc::indoc;
  14use language_model::{
  15    LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelRegistry,
  16    LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, MessageContent,
  17    Role, StopReason, fake_provider::FakeLanguageModel,
  18};
  19use project::Project;
  20use prompt_store::ProjectContext;
  21use reqwest_client::ReqwestClient;
  22use schemars::JsonSchema;
  23use serde::{Deserialize, Serialize};
  24use serde_json::json;
  25use settings::SettingsStore;
  26use smol::stream::StreamExt;
  27use std::{cell::RefCell, path::Path, rc::Rc, sync::Arc, time::Duration};
  28use util::path;
  29
  30mod test_tools;
  31use test_tools::*;
  32
  33#[gpui::test]
  34#[ignore = "can't run on CI yet"]
  35async fn test_echo(cx: &mut TestAppContext) {
  36    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
  37
  38    let events = thread
  39        .update(cx, |thread, cx| {
  40            thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
  41        })
  42        .collect()
  43        .await;
  44    thread.update(cx, |thread, _cx| {
  45        assert_eq!(
  46            thread.last_message().unwrap().to_markdown(),
  47            indoc! {"
  48                ## Assistant
  49
  50                Hello
  51            "}
  52        )
  53    });
  54    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
  55}
  56
  57#[gpui::test]
  58#[ignore = "can't run on CI yet"]
  59async fn test_thinking(cx: &mut TestAppContext) {
  60    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
  61
  62    let events = thread
  63        .update(cx, |thread, cx| {
  64            thread.send(
  65                UserMessageId::new(),
  66                [indoc! {"
  67                    Testing:
  68
  69                    Generate a thinking step where you just think the word 'Think',
  70                    and have your final answer be 'Hello'
  71                "}],
  72                cx,
  73            )
  74        })
  75        .collect()
  76        .await;
  77    thread.update(cx, |thread, _cx| {
  78        assert_eq!(
  79            thread.last_message().unwrap().to_markdown(),
  80            indoc! {"
  81                ## Assistant
  82
  83                <think>Think</think>
  84                Hello
  85            "}
  86        )
  87    });
  88    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
  89}
  90
  91#[gpui::test]
  92async fn test_system_prompt(cx: &mut TestAppContext) {
  93    let ThreadTest {
  94        model,
  95        thread,
  96        project_context,
  97        ..
  98    } = setup(cx, TestModel::Fake).await;
  99    let fake_model = model.as_fake();
 100
 101    project_context.borrow_mut().shell = "test-shell".into();
 102    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
 103    thread.update(cx, |thread, cx| {
 104        thread.send(UserMessageId::new(), ["abc"], cx)
 105    });
 106    cx.run_until_parked();
 107    let mut pending_completions = fake_model.pending_completions();
 108    assert_eq!(
 109        pending_completions.len(),
 110        1,
 111        "unexpected pending completions: {:?}",
 112        pending_completions
 113    );
 114
 115    let pending_completion = pending_completions.pop().unwrap();
 116    assert_eq!(pending_completion.messages[0].role, Role::System);
 117
 118    let system_message = &pending_completion.messages[0];
 119    let system_prompt = system_message.content[0].to_str().unwrap();
 120    assert!(
 121        system_prompt.contains("test-shell"),
 122        "unexpected system message: {:?}",
 123        system_message
 124    );
 125    assert!(
 126        system_prompt.contains("## Fixing Diagnostics"),
 127        "unexpected system message: {:?}",
 128        system_message
 129    );
 130}
 131
 132#[gpui::test]
 133#[ignore = "can't run on CI yet"]
 134async fn test_basic_tool_calls(cx: &mut TestAppContext) {
 135    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 136
 137    // Test a tool call that's likely to complete *before* streaming stops.
 138    let events = thread
 139        .update(cx, |thread, cx| {
 140            thread.add_tool(EchoTool);
 141            thread.send(
 142                UserMessageId::new(),
 143                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
 144                cx,
 145            )
 146        })
 147        .collect()
 148        .await;
 149    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 150
 151    // Test a tool calls that's likely to complete *after* streaming stops.
 152    let events = thread
 153        .update(cx, |thread, cx| {
 154            thread.remove_tool(&AgentTool::name(&EchoTool));
 155            thread.add_tool(DelayTool);
 156            thread.send(
 157                UserMessageId::new(),
 158                [
 159                    "Now call the delay tool with 200ms.",
 160                    "When the timer goes off, then you echo the output of the tool.",
 161                ],
 162                cx,
 163            )
 164        })
 165        .collect()
 166        .await;
 167    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 168    thread.update(cx, |thread, _cx| {
 169        assert!(
 170            thread
 171                .last_message()
 172                .unwrap()
 173                .as_agent_message()
 174                .unwrap()
 175                .content
 176                .iter()
 177                .any(|content| {
 178                    if let AgentMessageContent::Text(text) = content {
 179                        text.contains("Ding")
 180                    } else {
 181                        false
 182                    }
 183                }),
 184            "{}",
 185            thread.to_markdown()
 186        );
 187    });
 188}
 189
 190#[gpui::test]
 191#[ignore = "can't run on CI yet"]
 192async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
 193    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 194
 195    // Test a tool call that's likely to complete *before* streaming stops.
 196    let mut events = thread.update(cx, |thread, cx| {
 197        thread.add_tool(WordListTool);
 198        thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
 199    });
 200
 201    let mut saw_partial_tool_use = false;
 202    while let Some(event) = events.next().await {
 203        if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
 204            thread.update(cx, |thread, _cx| {
 205                // Look for a tool use in the thread's last message
 206                let message = thread.last_message().unwrap();
 207                let agent_message = message.as_agent_message().unwrap();
 208                let last_content = agent_message.content.last().unwrap();
 209                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
 210                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
 211                    if tool_call.status == acp::ToolCallStatus::Pending {
 212                        if !last_tool_use.is_input_complete
 213                            && last_tool_use.input.get("g").is_none()
 214                        {
 215                            saw_partial_tool_use = true;
 216                        }
 217                    } else {
 218                        last_tool_use
 219                            .input
 220                            .get("a")
 221                            .expect("'a' has streamed because input is now complete");
 222                        last_tool_use
 223                            .input
 224                            .get("g")
 225                            .expect("'g' has streamed because input is now complete");
 226                    }
 227                } else {
 228                    panic!("last content should be a tool use");
 229                }
 230            });
 231        }
 232    }
 233
 234    assert!(
 235        saw_partial_tool_use,
 236        "should see at least one partially streamed tool use in the history"
 237    );
 238}
 239
 240#[gpui::test]
 241async fn test_tool_authorization(cx: &mut TestAppContext) {
 242    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 243    let fake_model = model.as_fake();
 244
 245    let mut events = thread.update(cx, |thread, cx| {
 246        thread.add_tool(ToolRequiringPermission);
 247        thread.send(UserMessageId::new(), ["abc"], cx)
 248    });
 249    cx.run_until_parked();
 250    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 251        LanguageModelToolUse {
 252            id: "tool_id_1".into(),
 253            name: ToolRequiringPermission.name().into(),
 254            raw_input: "{}".into(),
 255            input: json!({}),
 256            is_input_complete: true,
 257        },
 258    ));
 259    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 260        LanguageModelToolUse {
 261            id: "tool_id_2".into(),
 262            name: ToolRequiringPermission.name().into(),
 263            raw_input: "{}".into(),
 264            input: json!({}),
 265            is_input_complete: true,
 266        },
 267    ));
 268    fake_model.end_last_completion_stream();
 269    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
 270    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
 271
 272    // Approve the first
 273    tool_call_auth_1
 274        .response
 275        .send(tool_call_auth_1.options[1].id.clone())
 276        .unwrap();
 277    cx.run_until_parked();
 278
 279    // Reject the second
 280    tool_call_auth_2
 281        .response
 282        .send(tool_call_auth_1.options[2].id.clone())
 283        .unwrap();
 284    cx.run_until_parked();
 285
 286    let completion = fake_model.pending_completions().pop().unwrap();
 287    let message = completion.messages.last().unwrap();
 288    assert_eq!(
 289        message.content,
 290        vec![
 291            language_model::MessageContent::ToolResult(LanguageModelToolResult {
 292                tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
 293                tool_name: ToolRequiringPermission.name().into(),
 294                is_error: false,
 295                content: "Allowed".into(),
 296                output: Some("Allowed".into())
 297            }),
 298            language_model::MessageContent::ToolResult(LanguageModelToolResult {
 299                tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
 300                tool_name: ToolRequiringPermission.name().into(),
 301                is_error: true,
 302                content: "Permission to run tool denied by user".into(),
 303                output: None
 304            })
 305        ]
 306    );
 307
 308    // Simulate yet another tool call.
 309    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 310        LanguageModelToolUse {
 311            id: "tool_id_3".into(),
 312            name: ToolRequiringPermission.name().into(),
 313            raw_input: "{}".into(),
 314            input: json!({}),
 315            is_input_complete: true,
 316        },
 317    ));
 318    fake_model.end_last_completion_stream();
 319
 320    // Respond by always allowing tools.
 321    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
 322    tool_call_auth_3
 323        .response
 324        .send(tool_call_auth_3.options[0].id.clone())
 325        .unwrap();
 326    cx.run_until_parked();
 327    let completion = fake_model.pending_completions().pop().unwrap();
 328    let message = completion.messages.last().unwrap();
 329    assert_eq!(
 330        message.content,
 331        vec![language_model::MessageContent::ToolResult(
 332            LanguageModelToolResult {
 333                tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
 334                tool_name: ToolRequiringPermission.name().into(),
 335                is_error: false,
 336                content: "Allowed".into(),
 337                output: Some("Allowed".into())
 338            }
 339        )]
 340    );
 341
 342    // Simulate a final tool call, ensuring we don't trigger authorization.
 343    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 344        LanguageModelToolUse {
 345            id: "tool_id_4".into(),
 346            name: ToolRequiringPermission.name().into(),
 347            raw_input: "{}".into(),
 348            input: json!({}),
 349            is_input_complete: true,
 350        },
 351    ));
 352    fake_model.end_last_completion_stream();
 353    cx.run_until_parked();
 354    let completion = fake_model.pending_completions().pop().unwrap();
 355    let message = completion.messages.last().unwrap();
 356    assert_eq!(
 357        message.content,
 358        vec![language_model::MessageContent::ToolResult(
 359            LanguageModelToolResult {
 360                tool_use_id: "tool_id_4".into(),
 361                tool_name: ToolRequiringPermission.name().into(),
 362                is_error: false,
 363                content: "Allowed".into(),
 364                output: Some("Allowed".into())
 365            }
 366        )]
 367    );
 368}
 369
 370#[gpui::test]
 371async fn test_tool_hallucination(cx: &mut TestAppContext) {
 372    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 373    let fake_model = model.as_fake();
 374
 375    let mut events = thread.update(cx, |thread, cx| {
 376        thread.send(UserMessageId::new(), ["abc"], cx)
 377    });
 378    cx.run_until_parked();
 379    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 380        LanguageModelToolUse {
 381            id: "tool_id_1".into(),
 382            name: "nonexistent_tool".into(),
 383            raw_input: "{}".into(),
 384            input: json!({}),
 385            is_input_complete: true,
 386        },
 387    ));
 388    fake_model.end_last_completion_stream();
 389
 390    let tool_call = expect_tool_call(&mut events).await;
 391    assert_eq!(tool_call.title, "nonexistent_tool");
 392    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
 393    let update = expect_tool_call_update_fields(&mut events).await;
 394    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
 395}
 396
 397#[gpui::test]
 398async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
 399    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 400    let fake_model = model.as_fake();
 401
 402    let events = thread.update(cx, |thread, cx| {
 403        thread.add_tool(EchoTool);
 404        thread.send(UserMessageId::new(), ["abc"], cx)
 405    });
 406    cx.run_until_parked();
 407    let tool_use = LanguageModelToolUse {
 408        id: "tool_id_1".into(),
 409        name: EchoTool.name().into(),
 410        raw_input: "{}".into(),
 411        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
 412        is_input_complete: true,
 413    };
 414    fake_model
 415        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
 416    fake_model.end_last_completion_stream();
 417
 418    cx.run_until_parked();
 419    let completion = fake_model.pending_completions().pop().unwrap();
 420    let tool_result = LanguageModelToolResult {
 421        tool_use_id: "tool_id_1".into(),
 422        tool_name: EchoTool.name().into(),
 423        is_error: false,
 424        content: "def".into(),
 425        output: Some("def".into()),
 426    };
 427    assert_eq!(
 428        completion.messages[1..],
 429        vec![
 430            LanguageModelRequestMessage {
 431                role: Role::User,
 432                content: vec!["abc".into()],
 433                cache: false
 434            },
 435            LanguageModelRequestMessage {
 436                role: Role::Assistant,
 437                content: vec![MessageContent::ToolUse(tool_use.clone())],
 438                cache: false
 439            },
 440            LanguageModelRequestMessage {
 441                role: Role::User,
 442                content: vec![MessageContent::ToolResult(tool_result.clone())],
 443                cache: false
 444            },
 445        ]
 446    );
 447
 448    // Simulate reaching tool use limit.
 449    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
 450        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
 451    ));
 452    fake_model.end_last_completion_stream();
 453    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
 454    assert!(
 455        last_event
 456            .unwrap_err()
 457            .is::<language_model::ToolUseLimitReachedError>()
 458    );
 459
 460    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
 461    cx.run_until_parked();
 462    let completion = fake_model.pending_completions().pop().unwrap();
 463    assert_eq!(
 464        completion.messages[1..],
 465        vec![
 466            LanguageModelRequestMessage {
 467                role: Role::User,
 468                content: vec!["abc".into()],
 469                cache: false
 470            },
 471            LanguageModelRequestMessage {
 472                role: Role::Assistant,
 473                content: vec![MessageContent::ToolUse(tool_use)],
 474                cache: false
 475            },
 476            LanguageModelRequestMessage {
 477                role: Role::User,
 478                content: vec![MessageContent::ToolResult(tool_result)],
 479                cache: false
 480            },
 481            LanguageModelRequestMessage {
 482                role: Role::User,
 483                content: vec!["Continue where you left off".into()],
 484                cache: false
 485            }
 486        ]
 487    );
 488
 489    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
 490    fake_model.end_last_completion_stream();
 491    events.collect::<Vec<_>>().await;
 492    thread.read_with(cx, |thread, _cx| {
 493        assert_eq!(
 494            thread.last_message().unwrap().to_markdown(),
 495            indoc! {"
 496                ## Assistant
 497
 498                Done
 499            "}
 500        )
 501    });
 502
 503    // Ensure we error if calling resume when tool use limit was *not* reached.
 504    let error = thread
 505        .update(cx, |thread, cx| thread.resume(cx))
 506        .unwrap_err();
 507    assert_eq!(
 508        error.to_string(),
 509        "can only resume after tool use limit is reached"
 510    )
 511}
 512
 513#[gpui::test]
 514async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
 515    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 516    let fake_model = model.as_fake();
 517
 518    let events = thread.update(cx, |thread, cx| {
 519        thread.add_tool(EchoTool);
 520        thread.send(UserMessageId::new(), ["abc"], cx)
 521    });
 522    cx.run_until_parked();
 523
 524    let tool_use = LanguageModelToolUse {
 525        id: "tool_id_1".into(),
 526        name: EchoTool.name().into(),
 527        raw_input: "{}".into(),
 528        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
 529        is_input_complete: true,
 530    };
 531    let tool_result = LanguageModelToolResult {
 532        tool_use_id: "tool_id_1".into(),
 533        tool_name: EchoTool.name().into(),
 534        is_error: false,
 535        content: "def".into(),
 536        output: Some("def".into()),
 537    };
 538    fake_model
 539        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
 540    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
 541        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
 542    ));
 543    fake_model.end_last_completion_stream();
 544    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
 545    assert!(
 546        last_event
 547            .unwrap_err()
 548            .is::<language_model::ToolUseLimitReachedError>()
 549    );
 550
 551    thread.update(cx, |thread, cx| {
 552        thread.send(UserMessageId::new(), vec!["ghi"], cx)
 553    });
 554    cx.run_until_parked();
 555    let completion = fake_model.pending_completions().pop().unwrap();
 556    assert_eq!(
 557        completion.messages[1..],
 558        vec![
 559            LanguageModelRequestMessage {
 560                role: Role::User,
 561                content: vec!["abc".into()],
 562                cache: false
 563            },
 564            LanguageModelRequestMessage {
 565                role: Role::Assistant,
 566                content: vec![MessageContent::ToolUse(tool_use)],
 567                cache: false
 568            },
 569            LanguageModelRequestMessage {
 570                role: Role::User,
 571                content: vec![MessageContent::ToolResult(tool_result)],
 572                cache: false
 573            },
 574            LanguageModelRequestMessage {
 575                role: Role::User,
 576                content: vec!["ghi".into()],
 577                cache: false
 578            }
 579        ]
 580    );
 581}
 582
 583async fn expect_tool_call(
 584    events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
 585) -> acp::ToolCall {
 586    let event = events
 587        .next()
 588        .await
 589        .expect("no tool call authorization event received")
 590        .unwrap();
 591    match event {
 592        AgentResponseEvent::ToolCall(tool_call) => return tool_call,
 593        event => {
 594            panic!("Unexpected event {event:?}");
 595        }
 596    }
 597}
 598
 599async fn expect_tool_call_update_fields(
 600    events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
 601) -> acp::ToolCallUpdate {
 602    let event = events
 603        .next()
 604        .await
 605        .expect("no tool call authorization event received")
 606        .unwrap();
 607    match event {
 608        AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => {
 609            return update;
 610        }
 611        event => {
 612            panic!("Unexpected event {event:?}");
 613        }
 614    }
 615}
 616
 617async fn next_tool_call_authorization(
 618    events: &mut UnboundedReceiver<Result<AgentResponseEvent>>,
 619) -> ToolCallAuthorization {
 620    loop {
 621        let event = events
 622            .next()
 623            .await
 624            .expect("no tool call authorization event received")
 625            .unwrap();
 626        if let AgentResponseEvent::ToolCallAuthorization(tool_call_authorization) = event {
 627            let permission_kinds = tool_call_authorization
 628                .options
 629                .iter()
 630                .map(|o| o.kind)
 631                .collect::<Vec<_>>();
 632            assert_eq!(
 633                permission_kinds,
 634                vec![
 635                    acp::PermissionOptionKind::AllowAlways,
 636                    acp::PermissionOptionKind::AllowOnce,
 637                    acp::PermissionOptionKind::RejectOnce,
 638                ]
 639            );
 640            return tool_call_authorization;
 641        }
 642    }
 643}
 644
 645#[gpui::test]
 646#[ignore = "can't run on CI yet"]
 647async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
 648    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 649
 650    // Test concurrent tool calls with different delay times
 651    let events = thread
 652        .update(cx, |thread, cx| {
 653            thread.add_tool(DelayTool);
 654            thread.send(
 655                UserMessageId::new(),
 656                [
 657                    "Call the delay tool twice in the same message.",
 658                    "Once with 100ms. Once with 300ms.",
 659                    "When both timers are complete, describe the outputs.",
 660                ],
 661                cx,
 662            )
 663        })
 664        .collect()
 665        .await;
 666
 667    let stop_reasons = stop_events(events);
 668    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
 669
 670    thread.update(cx, |thread, _cx| {
 671        let last_message = thread.last_message().unwrap();
 672        let agent_message = last_message.as_agent_message().unwrap();
 673        let text = agent_message
 674            .content
 675            .iter()
 676            .filter_map(|content| {
 677                if let AgentMessageContent::Text(text) = content {
 678                    Some(text.as_str())
 679                } else {
 680                    None
 681                }
 682            })
 683            .collect::<String>();
 684
 685        assert!(text.contains("Ding"));
 686    });
 687}
 688
 689#[gpui::test]
 690async fn test_profiles(cx: &mut TestAppContext) {
 691    let ThreadTest {
 692        model, thread, fs, ..
 693    } = setup(cx, TestModel::Fake).await;
 694    let fake_model = model.as_fake();
 695
 696    thread.update(cx, |thread, _cx| {
 697        thread.add_tool(DelayTool);
 698        thread.add_tool(EchoTool);
 699        thread.add_tool(InfiniteTool);
 700    });
 701
 702    // Override profiles and wait for settings to be loaded.
 703    fs.insert_file(
 704        paths::settings_file(),
 705        json!({
 706            "agent": {
 707                "profiles": {
 708                    "test-1": {
 709                        "name": "Test Profile 1",
 710                        "tools": {
 711                            EchoTool.name(): true,
 712                            DelayTool.name(): true,
 713                        }
 714                    },
 715                    "test-2": {
 716                        "name": "Test Profile 2",
 717                        "tools": {
 718                            InfiniteTool.name(): true,
 719                        }
 720                    }
 721                }
 722            }
 723        })
 724        .to_string()
 725        .into_bytes(),
 726    )
 727    .await;
 728    cx.run_until_parked();
 729
 730    // Test that test-1 profile (default) has echo and delay tools
 731    thread.update(cx, |thread, cx| {
 732        thread.set_profile(AgentProfileId("test-1".into()));
 733        thread.send(UserMessageId::new(), ["test"], cx);
 734    });
 735    cx.run_until_parked();
 736
 737    let mut pending_completions = fake_model.pending_completions();
 738    assert_eq!(pending_completions.len(), 1);
 739    let completion = pending_completions.pop().unwrap();
 740    let tool_names: Vec<String> = completion
 741        .tools
 742        .iter()
 743        .map(|tool| tool.name.clone())
 744        .collect();
 745    assert_eq!(tool_names, vec![DelayTool.name(), EchoTool.name()]);
 746    fake_model.end_last_completion_stream();
 747
 748    // Switch to test-2 profile, and verify that it has only the infinite tool.
 749    thread.update(cx, |thread, cx| {
 750        thread.set_profile(AgentProfileId("test-2".into()));
 751        thread.send(UserMessageId::new(), ["test2"], cx)
 752    });
 753    cx.run_until_parked();
 754    let mut pending_completions = fake_model.pending_completions();
 755    assert_eq!(pending_completions.len(), 1);
 756    let completion = pending_completions.pop().unwrap();
 757    let tool_names: Vec<String> = completion
 758        .tools
 759        .iter()
 760        .map(|tool| tool.name.clone())
 761        .collect();
 762    assert_eq!(tool_names, vec![InfiniteTool.name()]);
 763}
 764
 765#[gpui::test]
 766#[ignore = "can't run on CI yet"]
 767async fn test_cancellation(cx: &mut TestAppContext) {
 768    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 769
 770    let mut events = thread.update(cx, |thread, cx| {
 771        thread.add_tool(InfiniteTool);
 772        thread.add_tool(EchoTool);
 773        thread.send(
 774            UserMessageId::new(),
 775            ["Call the echo tool, then call the infinite tool, then explain their output"],
 776            cx,
 777        )
 778    });
 779
 780    // Wait until both tools are called.
 781    let mut expected_tools = vec!["Echo", "Infinite Tool"];
 782    let mut echo_id = None;
 783    let mut echo_completed = false;
 784    while let Some(event) = events.next().await {
 785        match event.unwrap() {
 786            AgentResponseEvent::ToolCall(tool_call) => {
 787                assert_eq!(tool_call.title, expected_tools.remove(0));
 788                if tool_call.title == "Echo" {
 789                    echo_id = Some(tool_call.id);
 790                }
 791            }
 792            AgentResponseEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
 793                acp::ToolCallUpdate {
 794                    id,
 795                    fields:
 796                        acp::ToolCallUpdateFields {
 797                            status: Some(acp::ToolCallStatus::Completed),
 798                            ..
 799                        },
 800                },
 801            )) if Some(&id) == echo_id.as_ref() => {
 802                echo_completed = true;
 803            }
 804            _ => {}
 805        }
 806
 807        if expected_tools.is_empty() && echo_completed {
 808            break;
 809        }
 810    }
 811
 812    // Cancel the current send and ensure that the event stream is closed, even
 813    // if one of the tools is still running.
 814    thread.update(cx, |thread, _cx| thread.cancel());
 815    events.collect::<Vec<_>>().await;
 816
 817    // Ensure we can still send a new message after cancellation.
 818    let events = thread
 819        .update(cx, |thread, cx| {
 820            thread.send(
 821                UserMessageId::new(),
 822                ["Testing: reply with 'Hello' then stop."],
 823                cx,
 824            )
 825        })
 826        .collect::<Vec<_>>()
 827        .await;
 828    thread.update(cx, |thread, _cx| {
 829        let message = thread.last_message().unwrap();
 830        let agent_message = message.as_agent_message().unwrap();
 831        assert_eq!(
 832            agent_message.content,
 833            vec![AgentMessageContent::Text("Hello".to_string())]
 834        );
 835    });
 836    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 837}
 838
 839#[gpui::test]
 840async fn test_refusal(cx: &mut TestAppContext) {
 841    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 842    let fake_model = model.as_fake();
 843
 844    let events = thread.update(cx, |thread, cx| {
 845        thread.send(UserMessageId::new(), ["Hello"], cx)
 846    });
 847    cx.run_until_parked();
 848    thread.read_with(cx, |thread, _| {
 849        assert_eq!(
 850            thread.to_markdown(),
 851            indoc! {"
 852                ## User
 853
 854                Hello
 855            "}
 856        );
 857    });
 858
 859    fake_model.send_last_completion_stream_text_chunk("Hey!");
 860    cx.run_until_parked();
 861    thread.read_with(cx, |thread, _| {
 862        assert_eq!(
 863            thread.to_markdown(),
 864            indoc! {"
 865                ## User
 866
 867                Hello
 868
 869                ## Assistant
 870
 871                Hey!
 872            "}
 873        );
 874    });
 875
 876    // If the model refuses to continue, the thread should remove all the messages after the last user message.
 877    fake_model
 878        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
 879    let events = events.collect::<Vec<_>>().await;
 880    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
 881    thread.read_with(cx, |thread, _| {
 882        assert_eq!(thread.to_markdown(), "");
 883    });
 884}
 885
 886#[gpui::test]
 887async fn test_truncate(cx: &mut TestAppContext) {
 888    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 889    let fake_model = model.as_fake();
 890
 891    let message_id = UserMessageId::new();
 892    thread.update(cx, |thread, cx| {
 893        thread.send(message_id.clone(), ["Hello"], cx)
 894    });
 895    cx.run_until_parked();
 896    thread.read_with(cx, |thread, _| {
 897        assert_eq!(
 898            thread.to_markdown(),
 899            indoc! {"
 900                ## User
 901
 902                Hello
 903            "}
 904        );
 905    });
 906
 907    fake_model.send_last_completion_stream_text_chunk("Hey!");
 908    cx.run_until_parked();
 909    thread.read_with(cx, |thread, _| {
 910        assert_eq!(
 911            thread.to_markdown(),
 912            indoc! {"
 913                ## User
 914
 915                Hello
 916
 917                ## Assistant
 918
 919                Hey!
 920            "}
 921        );
 922    });
 923
 924    thread
 925        .update(cx, |thread, _cx| thread.truncate(message_id))
 926        .unwrap();
 927    cx.run_until_parked();
 928    thread.read_with(cx, |thread, _| {
 929        assert_eq!(thread.to_markdown(), "");
 930    });
 931
 932    // Ensure we can still send a new message after truncation.
 933    thread.update(cx, |thread, cx| {
 934        thread.send(UserMessageId::new(), ["Hi"], cx)
 935    });
 936    thread.update(cx, |thread, _cx| {
 937        assert_eq!(
 938            thread.to_markdown(),
 939            indoc! {"
 940                ## User
 941
 942                Hi
 943            "}
 944        );
 945    });
 946    cx.run_until_parked();
 947    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
 948    cx.run_until_parked();
 949    thread.read_with(cx, |thread, _| {
 950        assert_eq!(
 951            thread.to_markdown(),
 952            indoc! {"
 953                ## User
 954
 955                Hi
 956
 957                ## Assistant
 958
 959                Ahoy!
 960            "}
 961        );
 962    });
 963}
 964
 965#[gpui::test]
 966async fn test_agent_connection(cx: &mut TestAppContext) {
 967    cx.update(settings::init);
 968    let templates = Templates::new();
 969
 970    // Initialize language model system with test provider
 971    cx.update(|cx| {
 972        gpui_tokio::init(cx);
 973        client::init_settings(cx);
 974
 975        let http_client = FakeHttpClient::with_404_response();
 976        let clock = Arc::new(clock::FakeSystemClock::new());
 977        let client = Client::new(clock, http_client, cx);
 978        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
 979        language_model::init(client.clone(), cx);
 980        language_models::init(user_store.clone(), client.clone(), cx);
 981        Project::init_settings(cx);
 982        LanguageModelRegistry::test(cx);
 983        agent_settings::init(cx);
 984    });
 985    cx.executor().forbid_parking();
 986
 987    // Create a project for new_thread
 988    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
 989    fake_fs.insert_tree(path!("/test"), json!({})).await;
 990    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
 991    let cwd = Path::new("/test");
 992
 993    // Create agent and connection
 994    let agent = NativeAgent::new(
 995        project.clone(),
 996        templates.clone(),
 997        None,
 998        fake_fs.clone(),
 999        &mut cx.to_async(),
1000    )
1001    .await
1002    .unwrap();
1003    let connection = NativeAgentConnection(agent.clone());
1004
1005    // Test model_selector returns Some
1006    let selector_opt = connection.model_selector();
1007    assert!(
1008        selector_opt.is_some(),
1009        "agent2 should always support ModelSelector"
1010    );
1011    let selector = selector_opt.unwrap();
1012
1013    // Test list_models
1014    let listed_models = cx
1015        .update(|cx| selector.list_models(cx))
1016        .await
1017        .expect("list_models should succeed");
1018    let AgentModelList::Grouped(listed_models) = listed_models else {
1019        panic!("Unexpected model list type");
1020    };
1021    assert!(!listed_models.is_empty(), "should have at least one model");
1022    assert_eq!(
1023        listed_models[&AgentModelGroupName("Fake".into())][0].id.0,
1024        "fake/fake"
1025    );
1026
1027    // Create a thread using new_thread
1028    let connection_rc = Rc::new(connection.clone());
1029    let acp_thread = cx
1030        .update(|cx| connection_rc.new_thread(project, cwd, cx))
1031        .await
1032        .expect("new_thread should succeed");
1033
1034    // Get the session_id from the AcpThread
1035    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
1036
1037    // Test selected_model returns the default
1038    let model = cx
1039        .update(|cx| selector.selected_model(&session_id, cx))
1040        .await
1041        .expect("selected_model should succeed");
1042    let model = cx
1043        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
1044        .unwrap();
1045    let model = model.as_fake();
1046    assert_eq!(model.id().0, "fake", "should return default model");
1047
1048    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
1049    cx.run_until_parked();
1050    model.send_last_completion_stream_text_chunk("def");
1051    cx.run_until_parked();
1052    acp_thread.read_with(cx, |thread, cx| {
1053        assert_eq!(
1054            thread.to_markdown(cx),
1055            indoc! {"
1056                ## User
1057
1058                abc
1059
1060                ## Assistant
1061
1062                def
1063
1064            "}
1065        )
1066    });
1067
1068    // Test cancel
1069    cx.update(|cx| connection.cancel(&session_id, cx));
1070    request.await.expect("prompt should fail gracefully");
1071
1072    // Ensure that dropping the ACP thread causes the native thread to be
1073    // dropped as well.
1074    cx.update(|_| drop(acp_thread));
1075    let result = cx
1076        .update(|cx| {
1077            connection.prompt(
1078                Some(acp_thread::UserMessageId::new()),
1079                acp::PromptRequest {
1080                    session_id: session_id.clone(),
1081                    prompt: vec!["ghi".into()],
1082                },
1083                cx,
1084            )
1085        })
1086        .await;
1087    assert_eq!(
1088        result.as_ref().unwrap_err().to_string(),
1089        "Session not found",
1090        "unexpected result: {:?}",
1091        result
1092    );
1093}
1094
1095#[gpui::test]
1096async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
1097    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
1098    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
1099    let fake_model = model.as_fake();
1100
1101    let mut events = thread.update(cx, |thread, cx| {
1102        thread.send(UserMessageId::new(), ["Think"], cx)
1103    });
1104    cx.run_until_parked();
1105
1106    // Simulate streaming partial input.
1107    let input = json!({});
1108    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1109        LanguageModelToolUse {
1110            id: "1".into(),
1111            name: ThinkingTool.name().into(),
1112            raw_input: input.to_string(),
1113            input,
1114            is_input_complete: false,
1115        },
1116    ));
1117
1118    // Input streaming completed
1119    let input = json!({ "content": "Thinking hard!" });
1120    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1121        LanguageModelToolUse {
1122            id: "1".into(),
1123            name: "thinking".into(),
1124            raw_input: input.to_string(),
1125            input,
1126            is_input_complete: true,
1127        },
1128    ));
1129    fake_model.end_last_completion_stream();
1130    cx.run_until_parked();
1131
1132    let tool_call = expect_tool_call(&mut events).await;
1133    assert_eq!(
1134        tool_call,
1135        acp::ToolCall {
1136            id: acp::ToolCallId("1".into()),
1137            title: "Thinking".into(),
1138            kind: acp::ToolKind::Think,
1139            status: acp::ToolCallStatus::Pending,
1140            content: vec![],
1141            locations: vec![],
1142            raw_input: Some(json!({})),
1143            raw_output: None,
1144        }
1145    );
1146    let update = expect_tool_call_update_fields(&mut events).await;
1147    assert_eq!(
1148        update,
1149        acp::ToolCallUpdate {
1150            id: acp::ToolCallId("1".into()),
1151            fields: acp::ToolCallUpdateFields {
1152                title: Some("Thinking".into()),
1153                kind: Some(acp::ToolKind::Think),
1154                raw_input: Some(json!({ "content": "Thinking hard!" })),
1155                ..Default::default()
1156            },
1157        }
1158    );
1159    let update = expect_tool_call_update_fields(&mut events).await;
1160    assert_eq!(
1161        update,
1162        acp::ToolCallUpdate {
1163            id: acp::ToolCallId("1".into()),
1164            fields: acp::ToolCallUpdateFields {
1165                status: Some(acp::ToolCallStatus::InProgress),
1166                ..Default::default()
1167            },
1168        }
1169    );
1170    let update = expect_tool_call_update_fields(&mut events).await;
1171    assert_eq!(
1172        update,
1173        acp::ToolCallUpdate {
1174            id: acp::ToolCallId("1".into()),
1175            fields: acp::ToolCallUpdateFields {
1176                content: Some(vec!["Thinking hard!".into()]),
1177                ..Default::default()
1178            },
1179        }
1180    );
1181    let update = expect_tool_call_update_fields(&mut events).await;
1182    assert_eq!(
1183        update,
1184        acp::ToolCallUpdate {
1185            id: acp::ToolCallId("1".into()),
1186            fields: acp::ToolCallUpdateFields {
1187                status: Some(acp::ToolCallStatus::Completed),
1188                raw_output: Some("Finished thinking.".into()),
1189                ..Default::default()
1190            },
1191        }
1192    );
1193}
1194
1195/// Filters out the stop events for asserting against in tests
1196fn stop_events(result_events: Vec<Result<AgentResponseEvent>>) -> Vec<acp::StopReason> {
1197    result_events
1198        .into_iter()
1199        .filter_map(|event| match event.unwrap() {
1200            AgentResponseEvent::Stop(stop_reason) => Some(stop_reason),
1201            _ => None,
1202        })
1203        .collect()
1204}
1205
1206struct ThreadTest {
1207    model: Arc<dyn LanguageModel>,
1208    thread: Entity<Thread>,
1209    project_context: Rc<RefCell<ProjectContext>>,
1210    fs: Arc<FakeFs>,
1211}
1212
1213enum TestModel {
1214    Sonnet4,
1215    Sonnet4Thinking,
1216    Fake,
1217}
1218
1219impl TestModel {
1220    fn id(&self) -> LanguageModelId {
1221        match self {
1222            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
1223            TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
1224            TestModel::Fake => unreachable!(),
1225        }
1226    }
1227}
1228
1229async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
1230    cx.executor().allow_parking();
1231
1232    let fs = FakeFs::new(cx.background_executor.clone());
1233    fs.create_dir(paths::settings_file().parent().unwrap())
1234        .await
1235        .unwrap();
1236    fs.insert_file(
1237        paths::settings_file(),
1238        json!({
1239            "agent": {
1240                "default_profile": "test-profile",
1241                "profiles": {
1242                    "test-profile": {
1243                        "name": "Test Profile",
1244                        "tools": {
1245                            EchoTool.name(): true,
1246                            DelayTool.name(): true,
1247                            WordListTool.name(): true,
1248                            ToolRequiringPermission.name(): true,
1249                            InfiniteTool.name(): true,
1250                        }
1251                    }
1252                }
1253            }
1254        })
1255        .to_string()
1256        .into_bytes(),
1257    )
1258    .await;
1259
1260    cx.update(|cx| {
1261        settings::init(cx);
1262        Project::init_settings(cx);
1263        agent_settings::init(cx);
1264        gpui_tokio::init(cx);
1265        let http_client = ReqwestClient::user_agent("agent tests").unwrap();
1266        cx.set_http_client(Arc::new(http_client));
1267
1268        client::init_settings(cx);
1269        let client = Client::production(cx);
1270        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1271        language_model::init(client.clone(), cx);
1272        language_models::init(user_store.clone(), client.clone(), cx);
1273
1274        watch_settings(fs.clone(), cx);
1275    });
1276
1277    let templates = Templates::new();
1278
1279    fs.insert_tree(path!("/test"), json!({})).await;
1280    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
1281
1282    let model = cx
1283        .update(|cx| {
1284            if let TestModel::Fake = model {
1285                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
1286            } else {
1287                let model_id = model.id();
1288                let models = LanguageModelRegistry::read_global(cx);
1289                let model = models
1290                    .available_models(cx)
1291                    .find(|model| model.id() == model_id)
1292                    .unwrap();
1293
1294                let provider = models.provider(&model.provider_id()).unwrap();
1295                let authenticated = provider.authenticate(cx);
1296
1297                cx.spawn(async move |_cx| {
1298                    authenticated.await.unwrap();
1299                    model
1300                })
1301            }
1302        })
1303        .await;
1304
1305    let project_context = Rc::new(RefCell::new(ProjectContext::default()));
1306    let context_server_registry =
1307        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1308    let action_log = cx.new(|_| ActionLog::new(project.clone()));
1309    let thread = cx.new(|cx| {
1310        Thread::new(
1311            project,
1312            project_context.clone(),
1313            context_server_registry,
1314            action_log,
1315            templates,
1316            model.clone(),
1317            cx,
1318        )
1319    });
1320    ThreadTest {
1321        model,
1322        thread,
1323        project_context,
1324        fs,
1325    }
1326}
1327
/// Initializes `env_logger`, but only when the `RUST_LOG` environment variable
/// can be read (`std::env::var` fails when it is unset or not valid Unicode).
///
/// Registered through `#[ctor::ctor]`, so no test has to call it explicitly.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}
1335
1336fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
1337    let fs = fs.clone();
1338    cx.spawn({
1339        async move |cx| {
1340            let mut new_settings_content_rx = settings::watch_config_file(
1341                cx.background_executor(),
1342                fs,
1343                paths::settings_file().clone(),
1344            );
1345
1346            while let Some(new_settings_content) = new_settings_content_rx.next().await {
1347                cx.update(|cx| {
1348                    SettingsStore::update_global(cx, |settings, cx| {
1349                        settings.set_user_settings(&new_settings_content, cx)
1350                    })
1351                })
1352                .ok();
1353            }
1354        }
1355    })
1356    .detach();
1357}