mod.rs

   1use super::*;
   2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
   3use action_log::ActionLog;
   4use agent_client_protocol::{self as acp};
   5use agent_settings::AgentProfileId;
   6use anyhow::Result;
   7use client::{Client, UserStore};
   8use fs::{FakeFs, Fs};
   9use futures::{StreamExt, channel::mpsc::UnboundedReceiver};
  10use gpui::{
  11    App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
  12};
  13use indoc::indoc;
  14use language_model::{
  15    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
  16    LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequestMessage,
  17    LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason,
  18    fake_provider::FakeLanguageModel,
  19};
  20use pretty_assertions::assert_eq;
  21use project::Project;
  22use prompt_store::ProjectContext;
  23use reqwest_client::ReqwestClient;
  24use schemars::JsonSchema;
  25use serde::{Deserialize, Serialize};
  26use serde_json::json;
  27use settings::SettingsStore;
  28use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
  29use util::path;
  30
  31mod test_tools;
  32use test_tools::*;
  33
  34#[gpui::test]
  35#[ignore = "can't run on CI yet"]
  36async fn test_echo(cx: &mut TestAppContext) {
  37    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
  38
  39    let events = thread
  40        .update(cx, |thread, cx| {
  41            thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
  42        })
  43        .unwrap()
  44        .collect()
  45        .await;
  46    thread.update(cx, |thread, _cx| {
  47        assert_eq!(
  48            thread.last_message().unwrap().to_markdown(),
  49            indoc! {"
  50                ## Assistant
  51
  52                Hello
  53            "}
  54        )
  55    });
  56    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
  57}
  58
  59#[gpui::test]
  60#[ignore = "can't run on CI yet"]
  61async fn test_thinking(cx: &mut TestAppContext) {
  62    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4Thinking).await;
  63
  64    let events = thread
  65        .update(cx, |thread, cx| {
  66            thread.send(
  67                UserMessageId::new(),
  68                [indoc! {"
  69                    Testing:
  70
  71                    Generate a thinking step where you just think the word 'Think',
  72                    and have your final answer be 'Hello'
  73                "}],
  74                cx,
  75            )
  76        })
  77        .unwrap()
  78        .collect()
  79        .await;
  80    thread.update(cx, |thread, _cx| {
  81        assert_eq!(
  82            thread.last_message().unwrap().to_markdown(),
  83            indoc! {"
  84                ## Assistant
  85
  86                <think>Think</think>
  87                Hello
  88            "}
  89        )
  90    });
  91    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
  92}
  93
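    // Verifies that the first request message is a system prompt reflecting the project
    // context (the configured shell) and containing the "## Fixing Diagnostics" section.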
  94#[gpui::test]
  95async fn test_system_prompt(cx: &mut TestAppContext) {
  96    let ThreadTest {
  97        model,
  98        thread,
  99        project_context,
 100        ..
 101    } = setup(cx, TestModel::Fake).await;
 102    let fake_model = model.as_fake();
 103
 104    project_context.update(cx, |project_context, _cx| {
 105        project_context.shell = "test-shell".into()
 106    });
 107    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
 108    thread
 109        .update(cx, |thread, cx| {
 110            thread.send(UserMessageId::new(), ["abc"], cx)
 111        })
 112        .unwrap();
 113    cx.run_until_parked();
 114    let mut pending_completions = fake_model.pending_completions();
 115    assert_eq!(
 116        pending_completions.len(),
 117        1,
 118        "unexpected pending completions: {:?}",
 119        pending_completions
 120    );
 121
 122    let pending_completion = pending_completions.pop().unwrap();
 123    assert_eq!(pending_completion.messages[0].role, Role::System);
 124
 125    let system_message = &pending_completion.messages[0];
 126    let system_prompt = system_message.content[0].to_str().unwrap();
 127    assert!(
 128        system_prompt.contains("test-shell"),
 129        "unexpected system message: {:?}",
 130        system_message
 131    );
 132    assert!(
 133        system_prompt.contains("## Fixing Diagnostics"),
 134        "unexpected system message: {:?}",
 135        system_message
 136    );
 137}
 138
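    // Verifies that only the final message of each request carries `cache: true`, so the
    // cache marker moves forward (including onto tool results) as the conversation grows.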
 139#[gpui::test]
 140async fn test_prompt_caching(cx: &mut TestAppContext) {
 141    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 142    let fake_model = model.as_fake();
 143
 144    // Send initial user message and verify it's cached
 145    thread
 146        .update(cx, |thread, cx| {
 147            thread.send(UserMessageId::new(), ["Message 1"], cx)
 148        })
 149        .unwrap();
 150    cx.run_until_parked();
 151
 152    let completion = fake_model.pending_completions().pop().unwrap();
 153    assert_eq!(
 154        completion.messages[1..],
 155        vec![LanguageModelRequestMessage {
 156            role: Role::User,
 157            content: vec!["Message 1".into()],
 158            cache: true
 159        }]
 160    );
 161    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
 162        "Response to Message 1".into(),
 163    ));
 164    fake_model.end_last_completion_stream();
 165    cx.run_until_parked();
 166
 167    // Send another user message and verify only the latest is cached
 168    thread
 169        .update(cx, |thread, cx| {
 170            thread.send(UserMessageId::new(), ["Message 2"], cx)
 171        })
 172        .unwrap();
 173    cx.run_until_parked();
 174
 175    let completion = fake_model.pending_completions().pop().unwrap();
 176    assert_eq!(
 177        completion.messages[1..],
 178        vec![
 179            LanguageModelRequestMessage {
 180                role: Role::User,
 181                content: vec!["Message 1".into()],
 182                cache: false
 183            },
 184            LanguageModelRequestMessage {
 185                role: Role::Assistant,
 186                content: vec!["Response to Message 1".into()],
 187                cache: false
 188            },
 189            LanguageModelRequestMessage {
 190                role: Role::User,
 191                content: vec!["Message 2".into()],
 192                cache: true
 193            }
 194        ]
 195    );
 196    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
 197        "Response to Message 2".into(),
 198    ));
 199    fake_model.end_last_completion_stream();
 200    cx.run_until_parked();
 201
 202    // Simulate a tool call and verify that the latest tool result is cached
 203    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
 204    thread
 205        .update(cx, |thread, cx| {
 206            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
 207        })
 208        .unwrap();
 209    cx.run_until_parked();
 210
 211    let tool_use = LanguageModelToolUse {
 212        id: "tool_1".into(),
 213        name: EchoTool.name().into(),
 214        raw_input: json!({"text": "test"}).to_string(),
 215        input: json!({"text": "test"}),
 216        is_input_complete: true,
 217    };
 218    fake_model
 219        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
 220    fake_model.end_last_completion_stream();
 221    cx.run_until_parked();
 222
 223    let completion = fake_model.pending_completions().pop().unwrap();
 224    let tool_result = LanguageModelToolResult {
 225        tool_use_id: "tool_1".into(),
 226        tool_name: EchoTool.name().into(),
 227        is_error: false,
 228        content: "test".into(),
 229        output: Some("test".into()),
 230    };
 231    assert_eq!(
 232        completion.messages[1..],
 233        vec![
 234            LanguageModelRequestMessage {
 235                role: Role::User,
 236                content: vec!["Message 1".into()],
 237                cache: false
 238            },
 239            LanguageModelRequestMessage {
 240                role: Role::Assistant,
 241                content: vec!["Response to Message 1".into()],
 242                cache: false
 243            },
 244            LanguageModelRequestMessage {
 245                role: Role::User,
 246                content: vec!["Message 2".into()],
 247                cache: false
 248            },
 249            LanguageModelRequestMessage {
 250                role: Role::Assistant,
 251                content: vec!["Response to Message 2".into()],
 252                cache: false
 253            },
 254            LanguageModelRequestMessage {
 255                role: Role::User,
 256                content: vec!["Use the echo tool".into()],
 257                cache: false
 258            },
 259            LanguageModelRequestMessage {
 260                role: Role::Assistant,
 261                content: vec![MessageContent::ToolUse(tool_use)],
 262                cache: false
 263            },
 264            LanguageModelRequestMessage {
 265                role: Role::User,
 266                content: vec![MessageContent::ToolResult(tool_result)],
 267                cache: true
 268            }
 269        ]
 270    );
 271}
 272
 273#[gpui::test]
 274#[ignore = "can't run on CI yet"]
 275async fn test_basic_tool_calls(cx: &mut TestAppContext) {
 276    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 277
 278    // Test a tool call that's likely to complete *before* streaming stops.
 279    let events = thread
 280        .update(cx, |thread, cx| {
 281            thread.add_tool(EchoTool);
 282            thread.send(
 283                UserMessageId::new(),
 284                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
 285                cx,
 286            )
 287        })
 288        .unwrap()
 289        .collect()
 290        .await;
 291    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 292
 293    // Test a tool call that's likely to complete *after* streaming stops.
 294    let events = thread
 295        .update(cx, |thread, cx| {
 296            thread.remove_tool(&AgentTool::name(&EchoTool));
 297            thread.add_tool(DelayTool);
 298            thread.send(
 299                UserMessageId::new(),
 300                [
 301                    "Now call the delay tool with 200ms.",
 302                    "When the timer goes off, then you echo the output of the tool.",
 303                ],
 304                cx,
 305            )
 306        })
 307        .unwrap()
 308        .collect()
 309        .await;
 310    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 311    thread.update(cx, |thread, _cx| {
 312        assert!(
 313            thread
 314                .last_message()
 315                .unwrap()
 316                .as_agent_message()
 317                .unwrap()
 318                .content
 319                .iter()
 320                .any(|content| {
 321                    if let AgentMessageContent::Text(text) = content {
 322                        text.contains("Ding")
 323                    } else {
 324                        false
 325                    }
 326                }),
 327            "{}",
 328            thread.to_markdown()
 329        );
 330    });
 331}
 332
 333#[gpui::test]
 334#[ignore = "can't run on CI yet"]
 335async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
 336    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 337
 338    // Send a prompt that causes the model to stream a tool call's input over multiple chunks.
 339    let mut events = thread
 340        .update(cx, |thread, cx| {
 341            thread.add_tool(WordListTool);
 342            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
 343        })
 344        .unwrap();
 345
 346    let mut saw_partial_tool_use = false;
 347    while let Some(event) = events.next().await {
 348        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
 349            thread.update(cx, |thread, _cx| {
 350                // Look for a tool use in the thread's last message
 351                let message = thread.last_message().unwrap();
 352                let agent_message = message.as_agent_message().unwrap();
 353                let last_content = agent_message.content.last().unwrap();
 354                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
 355                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
 356                    if tool_call.status == acp::ToolCallStatus::Pending {
 357                        if !last_tool_use.is_input_complete
 358                            && last_tool_use.input.get("g").is_none()
 359                        {
 360                            saw_partial_tool_use = true;
 361                        }
 362                    } else {
 363                        last_tool_use
 364                            .input
 365                            .get("a")
 366                            .expect("'a' has streamed because input is now complete");
 367                        last_tool_use
 368                            .input
 369                            .get("g")
 370                            .expect("'g' has streamed because input is now complete");
 371                    }
 372                } else {
 373                    panic!("last content should be a tool use");
 374                }
 375            });
 376        }
 377    }
 378
 379    assert!(
 380        saw_partial_tool_use,
 381        "should see at least one partially streamed tool use in the history"
 382    );
 383}
 384
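    // Exercises the tool permission flow: allowing once, rejecting once, and choosing
    // "always allow", and checks the tool results that are sent back to the model.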
 385#[gpui::test]
 386async fn test_tool_authorization(cx: &mut TestAppContext) {
 387    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 388    let fake_model = model.as_fake();
 389
 390    let mut events = thread
 391        .update(cx, |thread, cx| {
 392            thread.add_tool(ToolRequiringPermission);
 393            thread.send(UserMessageId::new(), ["abc"], cx)
 394        })
 395        .unwrap();
 396    cx.run_until_parked();
 397    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 398        LanguageModelToolUse {
 399            id: "tool_id_1".into(),
 400            name: ToolRequiringPermission.name().into(),
 401            raw_input: "{}".into(),
 402            input: json!({}),
 403            is_input_complete: true,
 404        },
 405    ));
 406    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 407        LanguageModelToolUse {
 408            id: "tool_id_2".into(),
 409            name: ToolRequiringPermission.name().into(),
 410            raw_input: "{}".into(),
 411            input: json!({}),
 412            is_input_complete: true,
 413        },
 414    ));
 415    fake_model.end_last_completion_stream();
 416    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
 417    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
 418
 419    // Approve the first
 420    tool_call_auth_1
 421        .response
 422        .send(tool_call_auth_1.options[1].id.clone())
 423        .unwrap();
 424    cx.run_until_parked();
 425
 426    // Reject the second
 427    tool_call_auth_2
 428        .response
 429        .send(tool_call_auth_2.options[2].id.clone())
 430        .unwrap();
 431    cx.run_until_parked();
 432
 433    let completion = fake_model.pending_completions().pop().unwrap();
 434    let message = completion.messages.last().unwrap();
 435    assert_eq!(
 436        message.content,
 437        vec![
 438            language_model::MessageContent::ToolResult(LanguageModelToolResult {
 439                tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
 440                tool_name: ToolRequiringPermission.name().into(),
 441                is_error: false,
 442                content: "Allowed".into(),
 443                output: Some("Allowed".into())
 444            }),
 445            language_model::MessageContent::ToolResult(LanguageModelToolResult {
 446                tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
 447                tool_name: ToolRequiringPermission.name().into(),
 448                is_error: true,
 449                content: "Permission to run tool denied by user".into(),
 450                output: None
 451            })
 452        ]
 453    );
 454
 455    // Simulate yet another tool call.
 456    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 457        LanguageModelToolUse {
 458            id: "tool_id_3".into(),
 459            name: ToolRequiringPermission.name().into(),
 460            raw_input: "{}".into(),
 461            input: json!({}),
 462            is_input_complete: true,
 463        },
 464    ));
 465    fake_model.end_last_completion_stream();
 466
 467    // Respond by always allowing tools.
 468    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
 469    tool_call_auth_3
 470        .response
 471        .send(tool_call_auth_3.options[0].id.clone())
 472        .unwrap();
 473    cx.run_until_parked();
 474    let completion = fake_model.pending_completions().pop().unwrap();
 475    let message = completion.messages.last().unwrap();
 476    assert_eq!(
 477        message.content,
 478        vec![language_model::MessageContent::ToolResult(
 479            LanguageModelToolResult {
 480                tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
 481                tool_name: ToolRequiringPermission.name().into(),
 482                is_error: false,
 483                content: "Allowed".into(),
 484                output: Some("Allowed".into())
 485            }
 486        )]
 487    );
 488
 489    // Simulate a final tool call, ensuring we don't trigger authorization.
 490    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 491        LanguageModelToolUse {
 492            id: "tool_id_4".into(),
 493            name: ToolRequiringPermission.name().into(),
 494            raw_input: "{}".into(),
 495            input: json!({}),
 496            is_input_complete: true,
 497        },
 498    ));
 499    fake_model.end_last_completion_stream();
 500    cx.run_until_parked();
 501    let completion = fake_model.pending_completions().pop().unwrap();
 502    let message = completion.messages.last().unwrap();
 503    assert_eq!(
 504        message.content,
 505        vec![language_model::MessageContent::ToolResult(
 506            LanguageModelToolResult {
 507                tool_use_id: "tool_id_4".into(),
 508                tool_name: ToolRequiringPermission.name().into(),
 509                is_error: false,
 510                content: "Allowed".into(),
 511                output: Some("Allowed".into())
 512            }
 513        )]
 514    );
 515}
 516
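    // A tool call naming a tool that was never registered should surface as a tool call
    // that immediately transitions to the Failed status.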
 517#[gpui::test]
 518async fn test_tool_hallucination(cx: &mut TestAppContext) {
 519    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 520    let fake_model = model.as_fake();
 521
 522    let mut events = thread
 523        .update(cx, |thread, cx| {
 524            thread.send(UserMessageId::new(), ["abc"], cx)
 525        })
 526        .unwrap();
 527    cx.run_until_parked();
 528    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 529        LanguageModelToolUse {
 530            id: "tool_id_1".into(),
 531            name: "nonexistent_tool".into(),
 532            raw_input: "{}".into(),
 533            input: json!({}),
 534            is_input_complete: true,
 535        },
 536    ));
 537    fake_model.end_last_completion_stream();
 538
 539    let tool_call = expect_tool_call(&mut events).await;
 540    assert_eq!(tool_call.title, "nonexistent_tool");
 541    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
 542    let update = expect_tool_call_update_fields(&mut events).await;
 543    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
 544}
 545
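    // After the model hits the tool use limit, `resume` should continue the thread with a
    // "Continue where you left off" user message; resuming without hitting the limit errors.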
 546#[gpui::test]
 547async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
 548    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 549    let fake_model = model.as_fake();
 550
 551    let events = thread
 552        .update(cx, |thread, cx| {
 553            thread.add_tool(EchoTool);
 554            thread.send(UserMessageId::new(), ["abc"], cx)
 555        })
 556        .unwrap();
 557    cx.run_until_parked();
 558    let tool_use = LanguageModelToolUse {
 559        id: "tool_id_1".into(),
 560        name: EchoTool.name().into(),
 561        raw_input: "{}".into(),
 562        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
 563        is_input_complete: true,
 564    };
 565    fake_model
 566        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
 567    fake_model.end_last_completion_stream();
 568
 569    cx.run_until_parked();
 570    let completion = fake_model.pending_completions().pop().unwrap();
 571    let tool_result = LanguageModelToolResult {
 572        tool_use_id: "tool_id_1".into(),
 573        tool_name: EchoTool.name().into(),
 574        is_error: false,
 575        content: "def".into(),
 576        output: Some("def".into()),
 577    };
 578    assert_eq!(
 579        completion.messages[1..],
 580        vec![
 581            LanguageModelRequestMessage {
 582                role: Role::User,
 583                content: vec!["abc".into()],
 584                cache: false
 585            },
 586            LanguageModelRequestMessage {
 587                role: Role::Assistant,
 588                content: vec![MessageContent::ToolUse(tool_use.clone())],
 589                cache: false
 590            },
 591            LanguageModelRequestMessage {
 592                role: Role::User,
 593                content: vec![MessageContent::ToolResult(tool_result.clone())],
 594                cache: true
 595            },
 596        ]
 597    );
 598
 599    // Simulate reaching tool use limit.
 600    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
 601        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
 602    ));
 603    fake_model.end_last_completion_stream();
 604    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
 605    assert!(
 606        last_event
 607            .unwrap_err()
 608            .is::<language_model::ToolUseLimitReachedError>()
 609    );
 610
 611    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
 612    cx.run_until_parked();
 613    let completion = fake_model.pending_completions().pop().unwrap();
 614    assert_eq!(
 615        completion.messages[1..],
 616        vec![
 617            LanguageModelRequestMessage {
 618                role: Role::User,
 619                content: vec!["abc".into()],
 620                cache: false
 621            },
 622            LanguageModelRequestMessage {
 623                role: Role::Assistant,
 624                content: vec![MessageContent::ToolUse(tool_use)],
 625                cache: false
 626            },
 627            LanguageModelRequestMessage {
 628                role: Role::User,
 629                content: vec![MessageContent::ToolResult(tool_result)],
 630                cache: false
 631            },
 632            LanguageModelRequestMessage {
 633                role: Role::User,
 634                content: vec!["Continue where you left off".into()],
 635                cache: true
 636            }
 637        ]
 638    );
 639
 640    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
 641    fake_model.end_last_completion_stream();
 642    events.collect::<Vec<_>>().await;
 643    thread.read_with(cx, |thread, _cx| {
 644        assert_eq!(
 645            thread.last_message().unwrap().to_markdown(),
 646            indoc! {"
 647                ## Assistant
 648
 649                Done
 650            "}
 651        )
 652    });
 653
 654    // Ensure resume returns an error when the tool use limit was *not* reached.
 655    let error = thread
 656        .update(cx, |thread, cx| thread.resume(cx))
 657        .unwrap_err();
 658    assert_eq!(
 659        error.to_string(),
 660        "can only resume after tool use limit is reached"
 661    )
 662}
 663
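    // After the tool use limit is reached, sending a fresh user message should still work
    // and keep the earlier tool use/result pair in the request history.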
 664#[gpui::test]
 665async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
 666    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 667    let fake_model = model.as_fake();
 668
 669    let events = thread
 670        .update(cx, |thread, cx| {
 671            thread.add_tool(EchoTool);
 672            thread.send(UserMessageId::new(), ["abc"], cx)
 673        })
 674        .unwrap();
 675    cx.run_until_parked();
 676
 677    let tool_use = LanguageModelToolUse {
 678        id: "tool_id_1".into(),
 679        name: EchoTool.name().into(),
 680        raw_input: "{}".into(),
 681        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
 682        is_input_complete: true,
 683    };
 684    let tool_result = LanguageModelToolResult {
 685        tool_use_id: "tool_id_1".into(),
 686        tool_name: EchoTool.name().into(),
 687        is_error: false,
 688        content: "def".into(),
 689        output: Some("def".into()),
 690    };
 691    fake_model
 692        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
 693    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
 694        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
 695    ));
 696    fake_model.end_last_completion_stream();
 697    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
 698    assert!(
 699        last_event
 700            .unwrap_err()
 701            .is::<language_model::ToolUseLimitReachedError>()
 702    );
 703
 704    thread
 705        .update(cx, |thread, cx| {
 706            thread.send(UserMessageId::new(), vec!["ghi"], cx)
 707        })
 708        .unwrap();
 709    cx.run_until_parked();
 710    let completion = fake_model.pending_completions().pop().unwrap();
 711    assert_eq!(
 712        completion.messages[1..],
 713        vec![
 714            LanguageModelRequestMessage {
 715                role: Role::User,
 716                content: vec!["abc".into()],
 717                cache: false
 718            },
 719            LanguageModelRequestMessage {
 720                role: Role::Assistant,
 721                content: vec![MessageContent::ToolUse(tool_use)],
 722                cache: false
 723            },
 724            LanguageModelRequestMessage {
 725                role: Role::User,
 726                content: vec![MessageContent::ToolResult(tool_result)],
 727                cache: false
 728            },
 729            LanguageModelRequestMessage {
 730                role: Role::User,
 731                content: vec!["ghi".into()],
 732                cache: true
 733            }
 734        ]
 735    );
 736}
 737
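    // Waits for the next thread event and asserts that it is a `ThreadEvent::ToolCall`.
    // Usage: `let tool_call = expect_tool_call(&mut events).await;`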
 738async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
 739    let event = events
 740        .next()
 741        .await
 742        .expect("no tool call event received")
 743        .unwrap();
 744    match event {
 745        ThreadEvent::ToolCall(tool_call) => tool_call,
 746        event => {
 747            panic!("Unexpected event {event:?}");
 748        }
 749    }
 750}
 751
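    // Waits for the next thread event and asserts that it is a tool call field update,
    // returning the inner `acp::ToolCallUpdate`.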
 752async fn expect_tool_call_update_fields(
 753    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
 754) -> acp::ToolCallUpdate {
 755    let event = events
 756        .next()
 757        .await
 758        .expect("no tool call update event received")
 759        .unwrap();
 760    match event {
 761        ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
 762        event => {
 763            panic!("Unexpected event {event:?}");
 764        }
 765    }
 766}
 767
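    // Skips events until a `ToolCallAuthorization` arrives, asserting that the permission
    // options are always [AllowAlways, AllowOnce, RejectOnce], in that order.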
 768async fn next_tool_call_authorization(
 769    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
 770) -> ToolCallAuthorization {
 771    loop {
 772        let event = events
 773            .next()
 774            .await
 775            .expect("no tool call authorization event received")
 776            .unwrap();
 777        if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
 778            let permission_kinds = tool_call_authorization
 779                .options
 780                .iter()
 781                .map(|o| o.kind)
 782                .collect::<Vec<_>>();
 783            assert_eq!(
 784                permission_kinds,
 785                vec![
 786                    acp::PermissionOptionKind::AllowAlways,
 787                    acp::PermissionOptionKind::AllowOnce,
 788                    acp::PermissionOptionKind::RejectOnce,
 789                ]
 790            );
 791            return tool_call_authorization;
 792        }
 793    }
 794}
 795
 796#[gpui::test]
 797#[ignore = "can't run on CI yet"]
 798async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
 799    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 800
 801    // Test concurrent tool calls with different delay times
 802    let events = thread
 803        .update(cx, |thread, cx| {
 804            thread.add_tool(DelayTool);
 805            thread.send(
 806                UserMessageId::new(),
 807                [
 808                    "Call the delay tool twice in the same message.",
 809                    "Once with 100ms. Once with 300ms.",
 810                    "When both timers are complete, describe the outputs.",
 811                ],
 812                cx,
 813            )
 814        })
 815        .unwrap()
 816        .collect()
 817        .await;
 818
 819    let stop_reasons = stop_events(events);
 820    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
 821
 822    thread.update(cx, |thread, _cx| {
 823        let last_message = thread.last_message().unwrap();
 824        let agent_message = last_message.as_agent_message().unwrap();
 825        let text = agent_message
 826            .content
 827            .iter()
 828            .filter_map(|content| {
 829                if let AgentMessageContent::Text(text) = content {
 830                    Some(text.as_str())
 831                } else {
 832                    None
 833                }
 834            })
 835            .collect::<String>();
 836
 837        assert!(text.contains("Ding"));
 838    });
 839}
 840
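    // Verifies that switching agent profiles in the settings file controls which tools are
    // sent to the model on the next completion request.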
 841#[gpui::test]
 842async fn test_profiles(cx: &mut TestAppContext) {
 843    let ThreadTest {
 844        model, thread, fs, ..
 845    } = setup(cx, TestModel::Fake).await;
 846    let fake_model = model.as_fake();
 847
 848    thread.update(cx, |thread, _cx| {
 849        thread.add_tool(DelayTool);
 850        thread.add_tool(EchoTool);
 851        thread.add_tool(InfiniteTool);
 852    });
 853
 854    // Override profiles and wait for settings to be loaded.
 855    fs.insert_file(
 856        paths::settings_file(),
 857        json!({
 858            "agent": {
 859                "profiles": {
 860                    "test-1": {
 861                        "name": "Test Profile 1",
 862                        "tools": {
 863                            EchoTool.name(): true,
 864                            DelayTool.name(): true,
 865                        }
 866                    },
 867                    "test-2": {
 868                        "name": "Test Profile 2",
 869                        "tools": {
 870                            InfiniteTool.name(): true,
 871                        }
 872                    }
 873                }
 874            }
 875        })
 876        .to_string()
 877        .into_bytes(),
 878    )
 879    .await;
 880    cx.run_until_parked();
 881
 882    // Select the test-1 profile and verify that it exposes the echo and delay tools.
 883    thread
 884        .update(cx, |thread, cx| {
 885            thread.set_profile(AgentProfileId("test-1".into()));
 886            thread.send(UserMessageId::new(), ["test"], cx)
 887        })
 888        .unwrap();
 889    cx.run_until_parked();
 890
 891    let mut pending_completions = fake_model.pending_completions();
 892    assert_eq!(pending_completions.len(), 1);
 893    let completion = pending_completions.pop().unwrap();
 894    let tool_names: Vec<String> = completion
 895        .tools
 896        .iter()
 897        .map(|tool| tool.name.clone())
 898        .collect();
 899    assert_eq!(tool_names, vec![DelayTool.name(), EchoTool.name()]);
 900    fake_model.end_last_completion_stream();
 901
 902    // Switch to the test-2 profile and verify that only the infinite tool is exposed.
 903    thread
 904        .update(cx, |thread, cx| {
 905            thread.set_profile(AgentProfileId("test-2".into()));
 906            thread.send(UserMessageId::new(), ["test2"], cx)
 907        })
 908        .unwrap();
 909    cx.run_until_parked();
 910    let mut pending_completions = fake_model.pending_completions();
 911    assert_eq!(pending_completions.len(), 1);
 912    let completion = pending_completions.pop().unwrap();
 913    let tool_names: Vec<String> = completion
 914        .tools
 915        .iter()
 916        .map(|tool| tool.name.clone())
 917        .collect();
 918    assert_eq!(tool_names, vec![InfiniteTool.name()]);
 919}
 920
 921#[gpui::test]
 922#[ignore = "can't run on CI yet"]
 923async fn test_cancellation(cx: &mut TestAppContext) {
 924    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;
 925
 926    let mut events = thread
 927        .update(cx, |thread, cx| {
 928            thread.add_tool(InfiniteTool);
 929            thread.add_tool(EchoTool);
 930            thread.send(
 931                UserMessageId::new(),
 932                ["Call the echo tool, then call the infinite tool, then explain their output"],
 933                cx,
 934            )
 935        })
 936        .unwrap();
 937
 938    // Wait until both tools are called.
 939    let mut expected_tools = vec!["Echo", "Infinite Tool"];
 940    let mut echo_id = None;
 941    let mut echo_completed = false;
 942    while let Some(event) = events.next().await {
 943        match event.unwrap() {
 944            ThreadEvent::ToolCall(tool_call) => {
 945                assert_eq!(tool_call.title, expected_tools.remove(0));
 946                if tool_call.title == "Echo" {
 947                    echo_id = Some(tool_call.id);
 948                }
 949            }
 950            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
 951                acp::ToolCallUpdate {
 952                    id,
 953                    fields:
 954                        acp::ToolCallUpdateFields {
 955                            status: Some(acp::ToolCallStatus::Completed),
 956                            ..
 957                        },
 958                },
 959            )) if Some(&id) == echo_id.as_ref() => {
 960                echo_completed = true;
 961            }
 962            _ => {}
 963        }
 964
 965        if expected_tools.is_empty() && echo_completed {
 966            break;
 967        }
 968    }
 969
 970    // Cancel the current send and ensure that the event stream is closed, even
 971    // if one of the tools is still running.
 972    thread.update(cx, |thread, cx| thread.cancel(cx));
 973    let events = events.collect::<Vec<_>>().await;
 974    let last_event = events.last();
 975    assert!(
 976        matches!(
 977            last_event,
 978            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
 979        ),
 980        "unexpected event {last_event:?}"
 981    );
 982
 983    // Ensure we can still send a new message after cancellation.
 984    let events = thread
 985        .update(cx, |thread, cx| {
 986            thread.send(
 987                UserMessageId::new(),
 988                ["Testing: reply with 'Hello' then stop."],
 989                cx,
 990            )
 991        })
 992        .unwrap()
 993        .collect::<Vec<_>>()
 994        .await;
 995    thread.update(cx, |thread, _cx| {
 996        let message = thread.last_message().unwrap();
 997        let agent_message = message.as_agent_message().unwrap();
 998        assert_eq!(
 999            agent_message.content,
1000            vec![AgentMessageContent::Text("Hello".to_string())]
1001        );
1002    });
1003    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
1004}
1005
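    // Starting a new send while another is in flight should cancel the first: its stream
    // ends with `Cancelled` while the second finishes with `EndTurn`.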
1006#[gpui::test]
1007async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
1008    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1009    let fake_model = model.as_fake();
1010
1011    let events_1 = thread
1012        .update(cx, |thread, cx| {
1013            thread.send(UserMessageId::new(), ["Hello 1"], cx)
1014        })
1015        .unwrap();
1016    cx.run_until_parked();
1017    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1018    cx.run_until_parked();
1019
1020    let events_2 = thread
1021        .update(cx, |thread, cx| {
1022            thread.send(UserMessageId::new(), ["Hello 2"], cx)
1023        })
1024        .unwrap();
1025    cx.run_until_parked();
1026    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1027    fake_model
1028        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1029    fake_model.end_last_completion_stream();
1030
1031    let events_1 = events_1.collect::<Vec<_>>().await;
1032    assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
1033    let events_2 = events_2.collect::<Vec<_>>().await;
1034    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1035}
1036
1037#[gpui::test]
1038async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
1039    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1040    let fake_model = model.as_fake();
1041
1042    let events_1 = thread
1043        .update(cx, |thread, cx| {
1044            thread.send(UserMessageId::new(), ["Hello 1"], cx)
1045        })
1046        .unwrap();
1047    cx.run_until_parked();
1048    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
1049    fake_model
1050        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1051    fake_model.end_last_completion_stream();
1052    let events_1 = events_1.collect::<Vec<_>>().await;
1053
1054    let events_2 = thread
1055        .update(cx, |thread, cx| {
1056            thread.send(UserMessageId::new(), ["Hello 2"], cx)
1057        })
1058        .unwrap();
1059    cx.run_until_parked();
1060    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
1061    fake_model
1062        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
1063    fake_model.end_last_completion_stream();
1064    let events_2 = events_2.collect::<Vec<_>>().await;
1065
1066    assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
1067    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
1068}
1069
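    // A `StopReason::Refusal` should surface as `acp::StopReason::Refusal` and roll the
    // thread back, discarding the refused exchange.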
1070#[gpui::test]
1071async fn test_refusal(cx: &mut TestAppContext) {
1072    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1073    let fake_model = model.as_fake();
1074
1075    let events = thread
1076        .update(cx, |thread, cx| {
1077            thread.send(UserMessageId::new(), ["Hello"], cx)
1078        })
1079        .unwrap();
1080    cx.run_until_parked();
1081    thread.read_with(cx, |thread, _| {
1082        assert_eq!(
1083            thread.to_markdown(),
1084            indoc! {"
1085                ## User
1086
1087                Hello
1088            "}
1089        );
1090    });
1091
1092    fake_model.send_last_completion_stream_text_chunk("Hey!");
1093    cx.run_until_parked();
1094    thread.read_with(cx, |thread, _| {
1095        assert_eq!(
1096            thread.to_markdown(),
1097            indoc! {"
1098                ## User
1099
1100                Hello
1101
1102                ## Assistant
1103
1104                Hey!
1105            "}
1106        );
1107    });
1108
1109    // If the model refuses to continue, the thread should remove the last user message and everything after it.
1110    fake_model
1111        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
1112    let events = events.collect::<Vec<_>>().await;
1113    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
1114    thread.read_with(cx, |thread, _| {
1115        assert_eq!(thread.to_markdown(), "");
1116    });
1117}
1118
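    // Truncating at the first message should clear the thread and its token usage, while
    // still allowing new messages to be sent afterwards.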
1119#[gpui::test]
1120async fn test_truncate_first_message(cx: &mut TestAppContext) {
1121    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1122    let fake_model = model.as_fake();
1123
1124    let message_id = UserMessageId::new();
1125    thread
1126        .update(cx, |thread, cx| {
1127            thread.send(message_id.clone(), ["Hello"], cx)
1128        })
1129        .unwrap();
1130    cx.run_until_parked();
1131    thread.read_with(cx, |thread, _| {
1132        assert_eq!(
1133            thread.to_markdown(),
1134            indoc! {"
1135                ## User
1136
1137                Hello
1138            "}
1139        );
1140        assert_eq!(thread.latest_token_usage(), None);
1141    });
1142
1143    fake_model.send_last_completion_stream_text_chunk("Hey!");
1144    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
1145        language_model::TokenUsage {
1146            input_tokens: 32_000,
1147            output_tokens: 16_000,
1148            cache_creation_input_tokens: 0,
1149            cache_read_input_tokens: 0,
1150        },
1151    ));
1152    cx.run_until_parked();
1153    thread.read_with(cx, |thread, _| {
1154        assert_eq!(
1155            thread.to_markdown(),
1156            indoc! {"
1157                ## User
1158
1159                Hello
1160
1161                ## Assistant
1162
1163                Hey!
1164            "}
1165        );
1166        assert_eq!(
1167            thread.latest_token_usage(),
1168            Some(acp_thread::TokenUsage {
1169                used_tokens: 32_000 + 16_000,
1170                max_tokens: 1_000_000,
1171            })
1172        );
1173    });
1174
1175    thread
1176        .update(cx, |thread, cx| thread.truncate(message_id, cx))
1177        .unwrap();
1178    cx.run_until_parked();
1179    thread.read_with(cx, |thread, _| {
1180        assert_eq!(thread.to_markdown(), "");
1181        assert_eq!(thread.latest_token_usage(), None);
1182    });
1183
1184    // Ensure we can still send a new message after truncation.
1185    thread
1186        .update(cx, |thread, cx| {
1187            thread.send(UserMessageId::new(), ["Hi"], cx)
1188        })
1189        .unwrap();
1190    thread.update(cx, |thread, _cx| {
1191        assert_eq!(
1192            thread.to_markdown(),
1193            indoc! {"
1194                ## User
1195
1196                Hi
1197            "}
1198        );
1199    });
1200    cx.run_until_parked();
1201    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
1202    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
1203        language_model::TokenUsage {
1204            input_tokens: 40_000,
1205            output_tokens: 20_000,
1206            cache_creation_input_tokens: 0,
1207            cache_read_input_tokens: 0,
1208        },
1209    ));
1210    cx.run_until_parked();
1211    thread.read_with(cx, |thread, _| {
1212        assert_eq!(
1213            thread.to_markdown(),
1214            indoc! {"
1215                ## User
1216
1217                Hi
1218
1219                ## Assistant
1220
1221                Ahoy!
1222            "}
1223        );
1224
1225        assert_eq!(
1226            thread.latest_token_usage(),
1227            Some(acp_thread::TokenUsage {
1228                used_tokens: 40_000 + 20_000,
1229                max_tokens: 1_000_000,
1230            })
1231        );
1232    });
1233}
1234
1235#[gpui::test]
1236async fn test_truncate_second_message(cx: &mut TestAppContext) {
1237    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1238    let fake_model = model.as_fake();
1239
1240    thread
1241        .update(cx, |thread, cx| {
1242            thread.send(UserMessageId::new(), ["Message 1"], cx)
1243        })
1244        .unwrap();
1245    cx.run_until_parked();
1246    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
1247    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
1248        language_model::TokenUsage {
1249            input_tokens: 32_000,
1250            output_tokens: 16_000,
1251            cache_creation_input_tokens: 0,
1252            cache_read_input_tokens: 0,
1253        },
1254    ));
1255    fake_model.end_last_completion_stream();
1256    cx.run_until_parked();
1257
1258    let assert_first_message_state = |cx: &mut TestAppContext| {
1259        thread.clone().read_with(cx, |thread, _| {
1260            assert_eq!(
1261                thread.to_markdown(),
1262                indoc! {"
1263                    ## User
1264
1265                    Message 1
1266
1267                    ## Assistant
1268
1269                    Message 1 response
1270                "}
1271            );
1272
1273            assert_eq!(
1274                thread.latest_token_usage(),
1275                Some(acp_thread::TokenUsage {
1276                    used_tokens: 32_000 + 16_000,
1277                    max_tokens: 1_000_000,
1278                })
1279            );
1280        });
1281    };
1282
1283    assert_first_message_state(cx);
1284
1285    let second_message_id = UserMessageId::new();
1286    thread
1287        .update(cx, |thread, cx| {
1288            thread.send(second_message_id.clone(), ["Message 2"], cx)
1289        })
1290        .unwrap();
1291    cx.run_until_parked();
1292
1293    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
1294    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
1295        language_model::TokenUsage {
1296            input_tokens: 40_000,
1297            output_tokens: 20_000,
1298            cache_creation_input_tokens: 0,
1299            cache_read_input_tokens: 0,
1300        },
1301    ));
1302    fake_model.end_last_completion_stream();
1303    cx.run_until_parked();
1304
1305    thread.read_with(cx, |thread, _| {
1306        assert_eq!(
1307            thread.to_markdown(),
1308            indoc! {"
1309                ## User
1310
1311                Message 1
1312
1313                ## Assistant
1314
1315                Message 1 response
1316
1317                ## User
1318
1319                Message 2
1320
1321                ## Assistant
1322
1323                Message 2 response
1324            "}
1325        );
1326
1327        assert_eq!(
1328            thread.latest_token_usage(),
1329            Some(acp_thread::TokenUsage {
1330                used_tokens: 40_000 + 20_000,
1331                max_tokens: 1_000_000,
1332            })
1333        );
1334    });
1335
1336    thread
1337        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
1338        .unwrap();
1339    cx.run_until_parked();
1340
1341    assert_first_message_state(cx);
1342}
1343
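    // The summarization model should be asked for a title only for the first exchange;
    // subsequent messages must not trigger another title request.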
1344#[gpui::test]
1345async fn test_title_generation(cx: &mut TestAppContext) {
1346    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1347    let fake_model = model.as_fake();
1348
1349    let summary_model = Arc::new(FakeLanguageModel::default());
1350    thread.update(cx, |thread, cx| {
1351        thread.set_summarization_model(Some(summary_model.clone()), cx)
1352    });
1353
1354    let send = thread
1355        .update(cx, |thread, cx| {
1356            thread.send(UserMessageId::new(), ["Hello"], cx)
1357        })
1358        .unwrap();
1359    cx.run_until_parked();
1360
1361    fake_model.send_last_completion_stream_text_chunk("Hey!");
1362    fake_model.end_last_completion_stream();
1363    cx.run_until_parked();
1364    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));
1365
1366    // The summary model should have been asked for a title; only the first line of its response becomes the title.
1367    summary_model.send_last_completion_stream_text_chunk("Hello ");
1368    summary_model.send_last_completion_stream_text_chunk("world\nG");
1369    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
1370    summary_model.end_last_completion_stream();
1371    send.collect::<Vec<_>>().await;
1372    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
1373
1374    // Send another message, ensuring no title is generated this time.
1375    let send = thread
1376        .update(cx, |thread, cx| {
1377            thread.send(UserMessageId::new(), ["Hello again"], cx)
1378        })
1379        .unwrap();
1380    cx.run_until_parked();
1381    fake_model.send_last_completion_stream_text_chunk("Hey again!");
1382    fake_model.end_last_completion_stream();
1383    cx.run_until_parked();
1384    assert_eq!(summary_model.pending_completions(), Vec::new());
1385    send.collect::<Vec<_>>().await;
1386    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
1387}
1388
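    // End-to-end test of `NativeAgentConnection`: listing and selecting models, creating a
    // thread over ACP, prompting, cancelling, and cleaning up the session when the ACP
    // thread is dropped.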
1389#[gpui::test]
1390async fn test_agent_connection(cx: &mut TestAppContext) {
1391    cx.update(settings::init);
1392    let templates = Templates::new();
1393
1394    // Initialize language model system with test provider
1395    cx.update(|cx| {
1396        gpui_tokio::init(cx);
1397        client::init_settings(cx);
1398
1399        let http_client = FakeHttpClient::with_404_response();
1400        let clock = Arc::new(clock::FakeSystemClock::new());
1401        let client = Client::new(clock, http_client, cx);
1402        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1403        language_model::init(client.clone(), cx);
1404        language_models::init(user_store, client.clone(), cx);
1405        Project::init_settings(cx);
1406        LanguageModelRegistry::test(cx);
1407        agent_settings::init(cx);
1408    });
1409    cx.executor().forbid_parking();
1410
1411    // Create a project for new_thread
1412    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
1413    fake_fs.insert_tree(path!("/test"), json!({})).await;
1414    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
1415    let cwd = Path::new("/test");
1416    let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
1417    let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
1418
1419    // Create agent and connection
1420    let agent = NativeAgent::new(
1421        project.clone(),
1422        history_store,
1423        templates.clone(),
1424        None,
1425        fake_fs.clone(),
1426        &mut cx.to_async(),
1427    )
1428    .await
1429    .unwrap();
1430    let connection = NativeAgentConnection(agent.clone());
1431
1432    // Test model_selector returns Some
1433    let selector_opt = connection.model_selector();
1434    assert!(
1435        selector_opt.is_some(),
1436        "agent2 should always support ModelSelector"
1437    );
1438    let selector = selector_opt.unwrap();
1439
1440    // Test list_models
1441    let listed_models = cx
1442        .update(|cx| selector.list_models(cx))
1443        .await
1444        .expect("list_models should succeed");
1445    let AgentModelList::Grouped(listed_models) = listed_models else {
1446        panic!("Unexpected model list type");
1447    };
1448    assert!(!listed_models.is_empty(), "should have at least one model");
1449    assert_eq!(
1450        listed_models[&AgentModelGroupName("Fake".into())][0].id.0,
1451        "fake/fake"
1452    );
1453
1454    // Create a thread using new_thread
1455    let connection_rc = Rc::new(connection.clone());
1456    let acp_thread = cx
1457        .update(|cx| connection_rc.new_thread(project, cwd, cx))
1458        .await
1459        .expect("new_thread should succeed");
1460
1461    // Get the session_id from the AcpThread
1462    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
1463
1464    // Test selected_model returns the default
1465    let model = cx
1466        .update(|cx| selector.selected_model(&session_id, cx))
1467        .await
1468        .expect("selected_model should succeed");
1469    let model = cx
1470        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
1471        .unwrap();
1472    let model = model.as_fake();
1473    assert_eq!(model.id().0, "fake", "should return default model");
1474
1475    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
1476    cx.run_until_parked();
1477    model.send_last_completion_stream_text_chunk("def");
1478    cx.run_until_parked();
1479    acp_thread.read_with(cx, |thread, cx| {
1480        assert_eq!(
1481            thread.to_markdown(cx),
1482            indoc! {"
1483                ## User
1484
1485                abc
1486
1487                ## Assistant
1488
1489                def
1490
1491            "}
1492        )
1493    });
1494
1495    // Test cancel
1496    cx.update(|cx| connection.cancel(&session_id, cx));
1497    request.await.expect("prompt should resolve gracefully after cancellation");
1498
1499    // Ensure that dropping the ACP thread causes the native thread to be
1500    // dropped as well.
1501    cx.update(|_| drop(acp_thread));
1502    let result = cx
1503        .update(|cx| {
1504            connection.prompt(
1505                Some(acp_thread::UserMessageId::new()),
1506                acp::PromptRequest {
1507                    session_id: session_id.clone(),
1508                    prompt: vec!["ghi".into()],
1509                },
1510                cx,
1511            )
1512        })
1513        .await;
1514    assert_eq!(
1515        result.as_ref().unwrap_err().to_string(),
1516        "Session not found",
1517        "unexpected result: {:?}",
1518        result
1519    );
1520}
1521
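    // Verifies the ACP updates emitted while a tool call's input streams in and the tool
    // runs: Pending, raw input update, InProgress, content, then Completed with raw output.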
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Think"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool.name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
        }
    );
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
        }
    );
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
        }
    );
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
        }
    );
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
        }
    );
}

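/// A completion that streams successfully should finish the turn without
/// emitting any retry events.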
#[gpui::test]
async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 0);
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey!
            "}
        )
    });
}

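/// A retryable provider error (here, a server overload) should produce a
/// single retry event and then succeed once the provider recovers.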
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey!
            "}
        )
    });
}

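/// When every attempt fails with a retryable error, the thread should retry
/// MAX_RETRY_ATTEMPTS times and then surface the final error to the caller.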
#[gpui::test]
async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
        fake_model.send_last_completion_stream_error(
            LanguageModelCompletionError::ServerOverloaded {
                provider: LanguageModelProviderName::new("Anthropic"),
                retry_after: Some(Duration::from_secs(3)),
            },
        );
        fake_model.end_last_completion_stream();
        cx.executor().advance_clock(Duration::from_secs(3));
        cx.run_until_parked();
    }

    let mut errors = Vec::new();
    let mut retry_events = Vec::new();
    while let Some(event) = events.next().await {
        match event {
            Ok(ThreadEvent::Retry(retry_status)) => {
                retry_events.push(retry_status);
            }
            Ok(ThreadEvent::Stop(..)) => break,
            Err(error) => errors.push(error),
            _ => {}
        }
    }

    assert_eq!(
        retry_events.len(),
        crate::thread::MAX_RETRY_ATTEMPTS as usize
    );
    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
        assert_eq!(retry_events[i].attempt, i + 1);
    }
    assert_eq!(errors.len(), 1);
    let error = errors[0]
        .downcast_ref::<LanguageModelCompletionError>()
        .unwrap();
    assert!(matches!(
        error,
        LanguageModelCompletionError::ServerOverloaded { .. }
    ));
}

/// Extracts the stop events from a turn's event stream so tests can assert on how the turn ended.
fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
    result_events
        .into_iter()
        .filter_map(|event| match event.unwrap() {
            ThreadEvent::Stop(stop_reason) => Some(stop_reason),
            _ => None,
        })
        .collect()
}

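/// Everything a test needs from `setup`: the model under test, the thread,
/// the shared project context, and the fake filesystem.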
struct ThreadTest {
    model: Arc<dyn LanguageModel>,
    thread: Entity<Thread>,
    project_context: Entity<ProjectContext>,
    fs: Arc<FakeFs>,
}

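/// Which language model a test runs against: a real Claude Sonnet 4 model
/// (with or without thinking) or the in-process fake model.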
enum TestModel {
    Sonnet4,
    Sonnet4Thinking,
    Fake,
}

impl TestModel {
    fn id(&self) -> LanguageModelId {
        match self {
            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
            TestModel::Sonnet4Thinking => LanguageModelId("claude-sonnet-4-thinking-latest".into()),
            TestModel::Fake => unreachable!(),
        }
    }
}

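/// Builds the test fixture: writes agent settings that enable the test tools,
/// initializes the global stores, resolves the requested model (authenticating
/// real providers when needed), and constructs a `Thread` backed by a test
/// project on a fake filesystem.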
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool.name(): true,
                            DelayTool.name(): true,
                            WordListTool.name(): true,
                            ToolRequiringPermission.name(): true,
                            InfiniteTool.name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);
        Project::init_settings(cx);
        agent_settings::init(cx);
        gpui_tokio::init(cx);
        let http_client = ReqwestClient::user_agent("agent tests").unwrap();
        cx.set_http_client(Arc::new(http_client));

        client::init_settings(cx);
        let client = Client::production(cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);

        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
    let action_log = cx.new(|_| ActionLog::new(project.clone()));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            action_log,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        fs,
    }
}

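/// Initializes `env_logger` for the test binary when `RUST_LOG` is set, so log
/// output is visible while running tests.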
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    if std::env::var("RUST_LOG").is_ok() {
        env_logger::init();
    }
}

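/// Spawns a background task that reloads the global `SettingsStore` whenever
/// the settings file changes on the fake filesystem.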
fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
    let fs = fs.clone();
    cx.spawn({
        async move |cx| {
            let mut new_settings_content_rx = settings::watch_config_file(
                cx.background_executor(),
                fs,
                paths::settings_file().clone(),
            );

            while let Some(new_settings_content) = new_settings_content_rx.next().await {
                cx.update(|cx| {
                    SettingsStore::update_global(cx, |settings, cx| {
                        settings.set_user_settings(&new_settings_content, cx)
                    })
                })
                .ok();
            }
        }
    })
    .detach();
}