1use super::*;
2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
3use action_log::ActionLog;
4use agent_client_protocol::{self as acp};
5use agent_settings::AgentProfileId;
6use anyhow::Result;
7use client::{Client, UserStore};
8use fs::{FakeFs, Fs};
9use futures::{StreamExt, channel::mpsc::UnboundedReceiver};
10use gpui::{
11 App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
12};
13use indoc::indoc;
14use language_model::{
15 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
16 LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequestMessage,
17 LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason,
18 fake_provider::FakeLanguageModel,
19};
20use pretty_assertions::assert_eq;
21use project::Project;
22use prompt_store::ProjectContext;
23use reqwest_client::ReqwestClient;
24use schemars::JsonSchema;
25use serde::{Deserialize, Serialize};
26use serde_json::json;
27use settings::SettingsStore;
28use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
29use util::path;
30
31mod test_tools;
32use test_tools::*;
33
#[gpui::test]
async fn test_echo(cx: &mut TestAppContext) {
    // Smoke test for the basic round trip: a user message goes out, the fake
    // model streams a single text chunk back, and the thread records it as
    // the assistant's reply.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Drive the fake completion: one text chunk, a normal end-of-turn stop,
    // then close the stream.
    fake_model.send_last_completion_stream_text_chunk("Hello");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events = events.collect().await;
    // The streamed chunk should appear verbatim as the last (assistant) message.
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Hello
            "}
        )
    });
    // Exactly one stop event, and it is a clean EndTurn.
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
63
#[gpui::test]
async fn test_thinking(cx: &mut TestAppContext) {
    // Verifies that streamed Thinking events are preserved in the final
    // assistant message, rendered as a <think>…</think> block preceding the
    // ordinary answer text.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                [indoc! {"
                    Testing:

                    Generate a thinking step where you just think the word 'Think',
                    and have your final answer be 'Hello'
                "}],
                cx,
            )
        })
        .unwrap();
    cx.run_until_parked();
    // Stream a thinking step followed by regular text, then end the turn.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
        text: "Think".to_string(),
        signature: None,
    });
    fake_model.send_last_completion_stream_text_chunk("Hello");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events = events.collect().await;
    // The markdown rendering wraps the thinking content in a <think> tag.
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                <think>Think</think>
                Hello
            "}
        )
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
107
#[gpui::test]
async fn test_system_prompt(cx: &mut TestAppContext) {
    // Checks that the first message of a completion request is a system
    // prompt built from the project context (e.g. the configured shell) and
    // that it includes tool-related guidance sections.
    let ThreadTest {
        model,
        thread,
        project_context,
        ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Inject a recognizable value into the project context so we can assert
    // that it flows into the rendered system prompt.
    project_context.update(cx, |project_context, _cx| {
        project_context.shell = "test-shell".into()
    });
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(
        pending_completions.len(),
        1,
        "unexpected pending completions: {:?}",
        pending_completions
    );

    // The system prompt is always the first message in the request.
    let pending_completion = pending_completions.pop().unwrap();
    assert_eq!(pending_completion.messages[0].role, Role::System);

    let system_message = &pending_completion.messages[0];
    let system_prompt = system_message.content[0].to_str().unwrap();
    assert!(
        system_prompt.contains("test-shell"),
        "unexpected system message: {:?}",
        system_message
    );
    assert!(
        system_prompt.contains("## Fixing Diagnostics"),
        "unexpected system message: {:?}",
        system_message
    );
}
152
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    // Verifies the prompt-caching policy: on every request, only the *last*
    // message carries `cache: true`; all earlier history is sent uncached.
    // `messages[1..]` is asserted throughout because messages[0] is the
    // system prompt.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool.name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // The follow-up request triggered by the tool call should mark only the
    // tool-result message (the new tail of the conversation) as cached.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool.name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true
            }
        ]
    );
}
286
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test against a real model (only runs with the "e2e"
    // feature): exercises tool calls that finish both before and after the
    // model's streaming completes.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&AgentTool::name(&EchoTool));
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    // The delay tool's output ("Ding") should have been echoed back in the
    // model's final answer.
    thread.update(cx, |thread, _cx| {
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
346
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (requires the "e2e" feature): verifies that tool-call
    // input is streamed incrementally, i.e. the thread's history exposes a
    // tool use whose input is still partial before streaming completes.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Invoke a tool whose input schema has many fields ("a" through "g"), so
    // partially-streamed input is observable while the call is Pending.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // While pending, input may be incomplete: the last
                        // field "g" should not have streamed yet.
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
398
399#[gpui::test]
400async fn test_tool_authorization(cx: &mut TestAppContext) {
401 let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
402 let fake_model = model.as_fake();
403
404 let mut events = thread
405 .update(cx, |thread, cx| {
406 thread.add_tool(ToolRequiringPermission);
407 thread.send(UserMessageId::new(), ["abc"], cx)
408 })
409 .unwrap();
410 cx.run_until_parked();
411 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
412 LanguageModelToolUse {
413 id: "tool_id_1".into(),
414 name: ToolRequiringPermission.name().into(),
415 raw_input: "{}".into(),
416 input: json!({}),
417 is_input_complete: true,
418 },
419 ));
420 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
421 LanguageModelToolUse {
422 id: "tool_id_2".into(),
423 name: ToolRequiringPermission.name().into(),
424 raw_input: "{}".into(),
425 input: json!({}),
426 is_input_complete: true,
427 },
428 ));
429 fake_model.end_last_completion_stream();
430 let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
431 let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
432
433 // Approve the first
434 tool_call_auth_1
435 .response
436 .send(tool_call_auth_1.options[1].id.clone())
437 .unwrap();
438 cx.run_until_parked();
439
440 // Reject the second
441 tool_call_auth_2
442 .response
443 .send(tool_call_auth_1.options[2].id.clone())
444 .unwrap();
445 cx.run_until_parked();
446
447 let completion = fake_model.pending_completions().pop().unwrap();
448 let message = completion.messages.last().unwrap();
449 assert_eq!(
450 message.content,
451 vec![
452 language_model::MessageContent::ToolResult(LanguageModelToolResult {
453 tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
454 tool_name: ToolRequiringPermission.name().into(),
455 is_error: false,
456 content: "Allowed".into(),
457 output: Some("Allowed".into())
458 }),
459 language_model::MessageContent::ToolResult(LanguageModelToolResult {
460 tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
461 tool_name: ToolRequiringPermission.name().into(),
462 is_error: true,
463 content: "Permission to run tool denied by user".into(),
464 output: None
465 })
466 ]
467 );
468
469 // Simulate yet another tool call.
470 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
471 LanguageModelToolUse {
472 id: "tool_id_3".into(),
473 name: ToolRequiringPermission.name().into(),
474 raw_input: "{}".into(),
475 input: json!({}),
476 is_input_complete: true,
477 },
478 ));
479 fake_model.end_last_completion_stream();
480
481 // Respond by always allowing tools.
482 let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
483 tool_call_auth_3
484 .response
485 .send(tool_call_auth_3.options[0].id.clone())
486 .unwrap();
487 cx.run_until_parked();
488 let completion = fake_model.pending_completions().pop().unwrap();
489 let message = completion.messages.last().unwrap();
490 assert_eq!(
491 message.content,
492 vec![language_model::MessageContent::ToolResult(
493 LanguageModelToolResult {
494 tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
495 tool_name: ToolRequiringPermission.name().into(),
496 is_error: false,
497 content: "Allowed".into(),
498 output: Some("Allowed".into())
499 }
500 )]
501 );
502
503 // Simulate a final tool call, ensuring we don't trigger authorization.
504 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
505 LanguageModelToolUse {
506 id: "tool_id_4".into(),
507 name: ToolRequiringPermission.name().into(),
508 raw_input: "{}".into(),
509 input: json!({}),
510 is_input_complete: true,
511 },
512 ));
513 fake_model.end_last_completion_stream();
514 cx.run_until_parked();
515 let completion = fake_model.pending_completions().pop().unwrap();
516 let message = completion.messages.last().unwrap();
517 assert_eq!(
518 message.content,
519 vec![language_model::MessageContent::ToolResult(
520 LanguageModelToolResult {
521 tool_use_id: "tool_id_4".into(),
522 tool_name: ToolRequiringPermission.name().into(),
523 is_error: false,
524 content: "Allowed".into(),
525 output: Some("Allowed".into())
526 }
527 )]
528 );
529}
530
#[gpui::test]
async fn test_tool_hallucination(cx: &mut TestAppContext) {
    // If the model invokes a tool that was never registered, the thread
    // should surface the call (Pending) and then immediately fail it rather
    // than panicking or silently dropping it.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "tool_id_1".into(),
            name: "nonexistent_tool".into(),
            raw_input: "{}".into(),
            input: json!({}),
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();

    // First a Pending tool call is reported, then an update flips it to Failed.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(tool_call.title, "nonexistent_tool");
    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
}
559
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    // When the provider reports ToolUseLimitReached, the send stream ends in
    // an error and `resume` continues the conversation by appending a
    // "Continue where you left off" user message. Resuming without having
    // hit the limit must fail.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool.name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    // The tool ran, so a follow-up completion carrying the tool result is
    // requested; only the final (tool-result) message is cached.
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool.name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The stream's final event should be a ToolUseLimitReachedError.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // Resuming appends a synthetic user message asking the model to continue.
    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });

    // Ensure we error if calling resume when tool use limit was *not* reached.
    let error = thread
        .update(cx, |thread, cx| thread.resume(cx))
        .unwrap_err();
    assert_eq!(
        error.to_string(),
        "can only resume after tool use limit is reached"
    )
}
677
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    // After hitting the tool-use limit, sending a brand-new user message
    // (instead of resuming) must still work: the history, including the tool
    // use/result pair, is preserved and the new message becomes the cached
    // tail of the request.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool.name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool.name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // Emit a tool use and immediately signal the tool-use limit.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // messages[0] is the system prompt; the new user message is last and cached.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true
            }
        ]
    );
}
751
752async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
753 let event = events
754 .next()
755 .await
756 .expect("no tool call authorization event received")
757 .unwrap();
758 match event {
759 ThreadEvent::ToolCall(tool_call) => tool_call,
760 event => {
761 panic!("Unexpected event {event:?}");
762 }
763 }
764}
765
766async fn expect_tool_call_update_fields(
767 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
768) -> acp::ToolCallUpdate {
769 let event = events
770 .next()
771 .await
772 .expect("no tool call authorization event received")
773 .unwrap();
774 match event {
775 ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
776 event => {
777 panic!("Unexpected event {event:?}");
778 }
779 }
780}
781
782async fn next_tool_call_authorization(
783 events: &mut UnboundedReceiver<Result<ThreadEvent>>,
784) -> ToolCallAuthorization {
785 loop {
786 let event = events
787 .next()
788 .await
789 .expect("no tool call authorization event received")
790 .unwrap();
791 if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
792 let permission_kinds = tool_call_authorization
793 .options
794 .iter()
795 .map(|o| o.kind)
796 .collect::<Vec<_>>();
797 assert_eq!(
798 permission_kinds,
799 vec![
800 acp::PermissionOptionKind::AllowAlways,
801 acp::PermissionOptionKind::AllowOnce,
802 acp::PermissionOptionKind::RejectOnce,
803 ]
804 );
805 return tool_call_authorization;
806 }
807 }
808}
809
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
    // End-to-end test (requires the "e2e" feature): two delay-tool calls in
    // the same turn should run concurrently and both complete before the
    // model describes their outputs.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test concurrent tool calls with different delay times
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Call the delay tool twice in the same message.",
                    "Once with 100ms. Once with 300ms.",
                    "When both timers are complete, describe the outputs.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;

    let stop_reasons = stop_events(events);
    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);

    // The final answer should mention the delay tool's output ("Ding").
    thread.update(cx, |thread, _cx| {
        let last_message = thread.last_message().unwrap();
        let agent_message = last_message.as_agent_message().unwrap();
        let text = agent_message
            .content
            .iter()
            .filter_map(|content| {
                if let AgentMessageContent::Text(text) = content {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<String>();

        assert!(text.contains("Ding"));
    });
}
854
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    // Verifies that the active agent profile filters which registered tools
    // are offered to the model, and that switching profiles mid-thread takes
    // effect on the next request.
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Register all three tools; the profiles below enable subsets of them.
    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool.name(): true,
                            DelayTool.name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool.name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()));
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool.name(), EchoTool.name()]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()));
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool.name()]);
}
934
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    // End-to-end test (requires the "e2e" feature): cancelling mid-turn must
    // close the event stream with a Cancelled stop reason — even while a
    // never-terminating tool is still running — and the thread must accept a
    // new message afterwards.
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                // Tools should be invoked in the requested order.
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            // Track completion of the echo call specifically; the infinite
            // tool never completes by design.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1019
#[gpui::test]
async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
    // Sending a new message while a previous send is still streaming should
    // cancel the first send (its stream ends with Cancelled) and let the
    // second run to completion.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Start the first send and stream a partial response without ending it.
    let events_1 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
    cx.run_until_parked();

    // Starting a second send implicitly cancels the first.
    let events_2 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events_1 = events_1.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
    let events_2 = events_2.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
}
1050
#[gpui::test]
async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
    // Counterpart to the cancellation test above: when the first send fully
    // completes before the second begins, both sends should end with EndTurn
    // and neither is cancelled.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First send runs to completion before the second starts.
    let events_1 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    let events_1 = events_1.collect::<Vec<_>>().await;

    let events_2 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    let events_2 = events_2.collect::<Vec<_>>().await;

    assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
}
1083
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
    // When the model stops with a Refusal, the thread should roll back every
    // message after (and including) the last user message, leaving the
    // transcript empty in this single-exchange case.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    // Stream a partial assistant reply before the refusal arrives.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    // If the model refuses to continue, the thread should remove all the messages after the last user message.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}
1132
#[gpui::test]
async fn test_truncate_first_message(cx: &mut TestAppContext) {
    // Truncating at the very first message should empty the thread, reset
    // token usage, and leave the thread usable for a fresh exchange.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Keep the id so we can truncate at this exact message later.
    let message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(message_id.clone(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
        // No usage is reported until the model sends a UsageUpdate.
        assert_eq!(thread.latest_token_usage(), None);
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
        // Usage reflects the UsageUpdate event (input + output).
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 32_000 + 16_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncating at the first message clears both the transcript and usage.
    thread
        .update(cx, |thread, cx| thread.truncate(message_id, cx))
        .unwrap();
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
        assert_eq!(thread.latest_token_usage(), None);
    });

    // Ensure we can still send a new message after truncation.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hi"], cx)
        })
        .unwrap();
    // The user message is appended synchronously, before the turn is driven.
    thread.update(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi
            "}
        );
    });
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hi

                ## Assistant

                Ahoy!
            "}
        );

        // Usage from the new turn replaces the pre-truncation numbers.
        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });
}
1248
#[gpui::test]
async fn test_truncate_second_message(cx: &mut TestAppContext) {
    // Truncating at the second user message should restore the thread to the
    // state it had after the first exchange — including token usage.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First exchange: completes normally with a usage report.
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 32_000,
            output_tokens: 16_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Shared assertion: the thread exactly matches the post-first-exchange
    // state. Used both before the second send and after truncation.
    let assert_first_message_state = |cx: &mut TestAppContext| {
        thread.clone().read_with(cx, |thread, _| {
            assert_eq!(
                thread.to_markdown(),
                indoc! {"
                    ## User

                    Message 1

                    ## Assistant

                    Message 1 response
                "}
            );

            assert_eq!(
                thread.latest_token_usage(),
                Some(acp_thread::TokenUsage {
                    used_tokens: 32_000 + 16_000,
                    max_tokens: 1_000_000,
                })
            );
        });
    };

    assert_first_message_state(cx);

    // Second exchange; keep the id so we can truncate at this message.
    let second_message_id = UserMessageId::new();
    thread
        .update(cx, |thread, cx| {
            thread.send(second_message_id.clone(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
        language_model::TokenUsage {
            input_tokens: 40_000,
            output_tokens: 20_000,
            cache_creation_input_tokens: 0,
            cache_read_input_tokens: 0,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Message 1

                ## Assistant

                Message 1 response

                ## User

                Message 2

                ## Assistant

                Message 2 response
            "}
        );

        assert_eq!(
            thread.latest_token_usage(),
            Some(acp_thread::TokenUsage {
                used_tokens: 40_000 + 20_000,
                max_tokens: 1_000_000,
            })
        );
    });

    // Truncate at message 2: everything from it onward is removed and the
    // first exchange's state (including usage) is restored.
    thread
        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
        .unwrap();
    cx.run_until_parked();

    assert_first_message_state(cx);
}
1357
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
    // A separate summarization model should be asked to title the thread
    // after the first exchange only; later exchanges keep the title as-is.
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Distinct fake model so title-generation requests can be observed
    // independently of the main completion stream.
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // Until the summary model responds, the placeholder title remains.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    // The title is the first line only; text after '\n' must be discarded.
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // No new completion request reached the summary model.
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
1403
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    // End-to-end exercise of `NativeAgentConnection`: model listing and
    // selection, thread creation, prompting, cancellation, and session
    // cleanup when the ACP thread is dropped.
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);
        client::init_settings(cx);

        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        Project::init_settings(cx);
        LanguageModelRegistry::test(cx);
        agent_settings::init(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector();
    assert!(
        selector_opt.is_some(),
        "agent2 should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0].id.0,
        "fake/fake"
    );

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(&session_id, cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Drive a prompt through the connection and confirm the streamed
    // response lands in the thread transcript.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                },
                cx,
            )
        })
        .await;
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
1536
1537#[gpui::test]
1538async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
1539 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
1540 thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
1541 let fake_model = model.as_fake();
1542
1543 let mut events = thread
1544 .update(cx, |thread, cx| {
1545 thread.send(UserMessageId::new(), ["Think"], cx)
1546 })
1547 .unwrap();
1548 cx.run_until_parked();
1549
1550 // Simulate streaming partial input.
1551 let input = json!({});
1552 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1553 LanguageModelToolUse {
1554 id: "1".into(),
1555 name: ThinkingTool.name().into(),
1556 raw_input: input.to_string(),
1557 input,
1558 is_input_complete: false,
1559 },
1560 ));
1561
1562 // Input streaming completed
1563 let input = json!({ "content": "Thinking hard!" });
1564 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
1565 LanguageModelToolUse {
1566 id: "1".into(),
1567 name: "thinking".into(),
1568 raw_input: input.to_string(),
1569 input,
1570 is_input_complete: true,
1571 },
1572 ));
1573 fake_model.end_last_completion_stream();
1574 cx.run_until_parked();
1575
1576 let tool_call = expect_tool_call(&mut events).await;
1577 assert_eq!(
1578 tool_call,
1579 acp::ToolCall {
1580 id: acp::ToolCallId("1".into()),
1581 title: "Thinking".into(),
1582 kind: acp::ToolKind::Think,
1583 status: acp::ToolCallStatus::Pending,
1584 content: vec![],
1585 locations: vec![],
1586 raw_input: Some(json!({})),
1587 raw_output: None,
1588 }
1589 );
1590 let update = expect_tool_call_update_fields(&mut events).await;
1591 assert_eq!(
1592 update,
1593 acp::ToolCallUpdate {
1594 id: acp::ToolCallId("1".into()),
1595 fields: acp::ToolCallUpdateFields {
1596 title: Some("Thinking".into()),
1597 kind: Some(acp::ToolKind::Think),
1598 raw_input: Some(json!({ "content": "Thinking hard!" })),
1599 ..Default::default()
1600 },
1601 }
1602 );
1603 let update = expect_tool_call_update_fields(&mut events).await;
1604 assert_eq!(
1605 update,
1606 acp::ToolCallUpdate {
1607 id: acp::ToolCallId("1".into()),
1608 fields: acp::ToolCallUpdateFields {
1609 status: Some(acp::ToolCallStatus::InProgress),
1610 ..Default::default()
1611 },
1612 }
1613 );
1614 let update = expect_tool_call_update_fields(&mut events).await;
1615 assert_eq!(
1616 update,
1617 acp::ToolCallUpdate {
1618 id: acp::ToolCallId("1".into()),
1619 fields: acp::ToolCallUpdateFields {
1620 content: Some(vec!["Thinking hard!".into()]),
1621 ..Default::default()
1622 },
1623 }
1624 );
1625 let update = expect_tool_call_update_fields(&mut events).await;
1626 assert_eq!(
1627 update,
1628 acp::ToolCallUpdate {
1629 id: acp::ToolCallId("1".into()),
1630 fields: acp::ToolCallUpdateFields {
1631 status: Some(acp::ToolCallStatus::Completed),
1632 raw_output: Some("Finished thinking.".into()),
1633 ..Default::default()
1634 },
1635 }
1636 );
1637}
1638
#[gpui::test]
async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
    // Baseline for the retry tests below: a turn that completes cleanly must
    // emit zero `ThreadEvent::Retry` events.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            // Burn mode is what enables retries; same mode as the retry tests.
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();

    // Drain the event stream until the turn stops, collecting retries.
    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 0);
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey!
            "}
        )
    });
}
1682
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    // A provider overload error should trigger exactly one retry (after the
    // provider-specified delay), and the retried turn should then succeed.
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            // Burn mode enables automatic retries.
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // First attempt fails with a retryable overload error.
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Let the retry_after backoff elapse so the retry is issued.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    // The retried attempt succeeds.
    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    // The transcript reflects only the successful attempt.
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey!
            "}
        )
    });
}
1739
1740#[gpui::test]
1741async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
1742 let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
1743 let fake_model = model.as_fake();
1744
1745 let mut events = thread
1746 .update(cx, |thread, cx| {
1747 thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
1748 thread.send(UserMessageId::new(), ["Hello!"], cx)
1749 })
1750 .unwrap();
1751 cx.run_until_parked();
1752
1753 for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
1754 fake_model.send_last_completion_stream_error(
1755 LanguageModelCompletionError::ServerOverloaded {
1756 provider: LanguageModelProviderName::new("Anthropic"),
1757 retry_after: Some(Duration::from_secs(3)),
1758 },
1759 );
1760 fake_model.end_last_completion_stream();
1761 cx.executor().advance_clock(Duration::from_secs(3));
1762 cx.run_until_parked();
1763 }
1764
1765 let mut errors = Vec::new();
1766 let mut retry_events = Vec::new();
1767 while let Some(event) = events.next().await {
1768 match event {
1769 Ok(ThreadEvent::Retry(retry_status)) => {
1770 retry_events.push(retry_status);
1771 }
1772 Ok(ThreadEvent::Stop(..)) => break,
1773 Err(error) => errors.push(error),
1774 _ => {}
1775 }
1776 }
1777
1778 assert_eq!(
1779 retry_events.len(),
1780 crate::thread::MAX_RETRY_ATTEMPTS as usize
1781 );
1782 for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
1783 assert_eq!(retry_events[i].attempt, i + 1);
1784 }
1785 assert_eq!(errors.len(), 1);
1786 let error = errors[0]
1787 .downcast_ref::<LanguageModelCompletionError>()
1788 .unwrap();
1789 assert!(matches!(
1790 error,
1791 LanguageModelCompletionError::ServerOverloaded { .. }
1792 ));
1793}
1794
1795/// Filters out the stop events for asserting against in tests
1796fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
1797 result_events
1798 .into_iter()
1799 .filter_map(|event| match event.unwrap() {
1800 ThreadEvent::Stop(stop_reason) => Some(stop_reason),
1801 _ => None,
1802 })
1803 .collect()
1804}
1805
/// Handles returned by [`setup`] for driving a `Thread` in tests.
struct ThreadTest {
    // The model the thread was configured with; tests using `TestModel::Fake`
    // downcast it via `.as_fake()` to drive completions manually.
    model: Arc<dyn LanguageModel>,
    thread: Entity<Thread>,
    project_context: Entity<ProjectContext>,
    // The fake filesystem backing the project and the settings file.
    fs: Arc<FakeFs>,
}
1812
/// Which language model a test thread should be backed by.
enum TestModel {
    // A real Anthropic model resolved from the registry; requires
    // authentication and network access (see `setup`).
    Sonnet4,
    // The in-process `FakeLanguageModel`, driven manually by each test.
    Fake,
}
1817
1818impl TestModel {
1819 fn id(&self) -> LanguageModelId {
1820 match self {
1821 TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
1822 TestModel::Fake => unreachable!(),
1823 }
1824 }
1825}
1826
/// Builds a [`ThreadTest`] backed by either the fake model or a real,
/// authenticated provider model, with a settings file that enables all the
/// test tools under a dedicated profile.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    // Real-model tests perform actual network I/O, which needs parking.
    cx.executor().allow_parking();

    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    // Seed a settings file whose "test-profile" enables every test tool, so
    // profile-based tool filtering doesn't hide them from the thread.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool.name(): true,
                            DelayTool.name(): true,
                            WordListTool.name(): true,
                            ToolRequiringPermission.name(): true,
                            InfiniteTool.name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    cx.update(|cx| {
        settings::init(cx);
        Project::init_settings(cx);
        agent_settings::init(cx);
        gpui_tokio::init(cx);
        // A real HTTP client so TestModel::Sonnet4 can reach its provider.
        let http_client = ReqwestClient::user_agent("agent tests").unwrap();
        cx.set_http_client(Arc::new(http_client));

        client::init_settings(cx);
        let client = Client::production(cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);

        // Keep the global SettingsStore in sync with edits to the fake
        // settings file made later in a test.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                // The fake model needs no registry lookup or authentication.
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                // Resolve the real model from the registry and authenticate
                // with its provider before handing it to the thread.
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
    let action_log = cx.new(|_| ActionLog::new(project.clone()));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            action_log,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        fs,
    }
}
1925
/// Constructor run once before tests start (via `ctor`); enables `env_logger`
/// only when the caller opted in by setting `RUST_LOG`.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}
1933
1934fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
1935 let fs = fs.clone();
1936 cx.spawn({
1937 async move |cx| {
1938 let mut new_settings_content_rx = settings::watch_config_file(
1939 cx.background_executor(),
1940 fs,
1941 paths::settings_file().clone(),
1942 );
1943
1944 while let Some(new_settings_content) = new_settings_content_rx.next().await {
1945 cx.update(|cx| {
1946 SettingsStore::update_global(cx, |settings, cx| {
1947 settings.set_user_settings(&new_settings_content, cx)
1948 })
1949 })
1950 .ok();
1951 }
1952 }
1953 })
1954 .detach();
1955}