mod.rs

   1use super::*;
   2use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
   3use agent_client_protocol::{self as acp};
   4use agent_settings::AgentProfileId;
   5use anyhow::Result;
   6use client::{Client, UserStore};
   7use fs::{FakeFs, Fs};
   8use futures::{StreamExt, channel::mpsc::UnboundedReceiver};
   9use gpui::{
  10    App, AppContext, Entity, Task, TestAppContext, UpdateGlobal, http_client::FakeHttpClient,
  11};
  12use indoc::indoc;
  13use language_model::{
  14    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
  15    LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequestMessage,
  16    LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role, StopReason,
  17    fake_provider::FakeLanguageModel,
  18};
  19use pretty_assertions::assert_eq;
  20use project::Project;
  21use prompt_store::ProjectContext;
  22use reqwest_client::ReqwestClient;
  23use schemars::JsonSchema;
  24use serde::{Deserialize, Serialize};
  25use serde_json::json;
  26use settings::SettingsStore;
  27use std::{path::Path, rc::Rc, sync::Arc, time::Duration};
  28use util::path;
  29
  30mod test_tools;
  31use test_tools::*;
  32
  33#[gpui::test]
  34async fn test_echo(cx: &mut TestAppContext) {
  35    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
  36    let fake_model = model.as_fake();
  37
  38    let events = thread
  39        .update(cx, |thread, cx| {
  40            thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
  41        })
  42        .unwrap();
  43    cx.run_until_parked();
  44    fake_model.send_last_completion_stream_text_chunk("Hello");
  45    fake_model
  46        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
  47    fake_model.end_last_completion_stream();
  48
  49    let events = events.collect().await;
  50    thread.update(cx, |thread, _cx| {
  51        assert_eq!(
  52            thread.last_message().unwrap().to_markdown(),
  53            indoc! {"
  54                ## Assistant
  55
  56                Hello
  57            "}
  58        )
  59    });
  60    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
  61}
  62
  63#[gpui::test]
  64async fn test_thinking(cx: &mut TestAppContext) {
  65    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
  66    let fake_model = model.as_fake();
  67
  68    let events = thread
  69        .update(cx, |thread, cx| {
  70            thread.send(
  71                UserMessageId::new(),
  72                [indoc! {"
  73                    Testing:
  74
  75                    Generate a thinking step where you just think the word 'Think',
  76                    and have your final answer be 'Hello'
  77                "}],
  78                cx,
  79            )
  80        })
  81        .unwrap();
  82    cx.run_until_parked();
  83    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Thinking {
  84        text: "Think".to_string(),
  85        signature: None,
  86    });
  87    fake_model.send_last_completion_stream_text_chunk("Hello");
  88    fake_model
  89        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
  90    fake_model.end_last_completion_stream();
  91
  92    let events = events.collect().await;
  93    thread.update(cx, |thread, _cx| {
  94        assert_eq!(
  95            thread.last_message().unwrap().to_markdown(),
  96            indoc! {"
  97                ## Assistant
  98
  99                <think>Think</think>
 100                Hello
 101            "}
 102        )
 103    });
 104    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 105}
 106
 107#[gpui::test]
 108async fn test_system_prompt(cx: &mut TestAppContext) {
 109    let ThreadTest {
 110        model,
 111        thread,
 112        project_context,
 113        ..
 114    } = setup(cx, TestModel::Fake).await;
 115    let fake_model = model.as_fake();
 116
 117    project_context.update(cx, |project_context, _cx| {
 118        project_context.shell = "test-shell".into()
 119    });
 120    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
 121    thread
 122        .update(cx, |thread, cx| {
 123            thread.send(UserMessageId::new(), ["abc"], cx)
 124        })
 125        .unwrap();
 126    cx.run_until_parked();
 127    let mut pending_completions = fake_model.pending_completions();
 128    assert_eq!(
 129        pending_completions.len(),
 130        1,
 131        "unexpected pending completions: {:?}",
 132        pending_completions
 133    );
 134
 135    let pending_completion = pending_completions.pop().unwrap();
 136    assert_eq!(pending_completion.messages[0].role, Role::System);
 137
 138    let system_message = &pending_completion.messages[0];
 139    let system_prompt = system_message.content[0].to_str().unwrap();
 140    assert!(
 141        system_prompt.contains("test-shell"),
 142        "unexpected system message: {:?}",
 143        system_message
 144    );
 145    assert!(
 146        system_prompt.contains("## Fixing Diagnostics"),
 147        "unexpected system message: {:?}",
 148        system_message
 149    );
 150}
 151
// Verifies the prompt-caching policy: only the most recent message in a
// request is marked `cache: true`; everything earlier is sent uncached.
#[gpui::test]
async fn test_prompt_caching(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Send initial user message and verify it's cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // messages[0] is the system prompt; only the conversation tail is compared.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![LanguageModelRequestMessage {
            role: Role::User,
            content: vec!["Message 1".into()],
            cache: true
        }]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 1".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Send another user message and verify only the latest is cached
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Message 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: true
            }
        ]
    );
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text(
        "Response to Message 2".into(),
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // Simulate a tool call and verify that the latest tool result is cached
    thread.update(cx, |thread, _| thread.add_tool(EchoTool));
    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Use the echo tool"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_1".into(),
        name: EchoTool::name().into(),
        raw_input: json!({"text": "test"}).to_string(),
        input: json!({"text": "test"}),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // After the tool runs, the follow-up request must end with the tool result
    // as the only cached message.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "test".into(),
        output: Some("test".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 1".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec!["Response to Message 2".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Use the echo tool".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: true
            }
        ]
    );
}
 285
// End-to-end test against a real model (only runs with the "e2e" feature).
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_basic_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test a tool call that's likely to complete *before* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);

    // Test a tool call that's likely to complete *after* streaming stops.
    let events = thread
        .update(cx, |thread, cx| {
            thread.remove_tool(&EchoTool::name());
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Now call the delay tool with 200ms.",
                    "When the timer goes off, then you echo the output of the tool.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
    thread.update(cx, |thread, _cx| {
        // The model should echo the delay tool's output (expected to contain
        // "Ding" — see the DelayTool definition in test_tools) in its reply.
        assert!(
            thread
                .last_message()
                .unwrap()
                .as_agent_message()
                .unwrap()
                .content
                .iter()
                .any(|content| {
                    if let AgentMessageContent::Text(text) = content {
                        text.contains("Ding")
                    } else {
                        false
                    }
                }),
            "{}",
            thread.to_markdown()
        );
    });
}
 345
// End-to-end test (requires the "e2e" feature): verifies that tool-call input
// is observable in the thread while it is still being streamed by the model.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Trigger a tool call whose input streams in incrementally.
    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(WordListTool);
            thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
        })
        .unwrap();

    let mut saw_partial_tool_use = false;
    while let Some(event) = events.next().await {
        if let Ok(ThreadEvent::ToolCall(tool_call)) = event {
            thread.update(cx, |thread, _cx| {
                // Look for a tool use in the thread's last message
                let message = thread.last_message().unwrap();
                let agent_message = message.as_agent_message().unwrap();
                let last_content = agent_message.content.last().unwrap();
                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                    assert_eq!(last_tool_use.name.as_ref(), "word_list");
                    if tool_call.status == acp::ToolCallStatus::Pending {
                        // While pending, the input may be mid-stream. A missing
                        // "g" field with is_input_complete == false means we
                        // observed a partially-streamed tool use (presumably
                        // "g" is a later field of WordListTool's input — see
                        // test_tools).
                        if !last_tool_use.is_input_complete
                            && last_tool_use.input.get("g").is_none()
                        {
                            saw_partial_tool_use = true;
                        }
                    } else {
                        last_tool_use
                            .input
                            .get("a")
                            .expect("'a' has streamed because input is now complete");
                        last_tool_use
                            .input
                            .get("g")
                            .expect("'g' has streamed because input is now complete");
                    }
                } else {
                    panic!("last content should be a tool use");
                }
            });
        }
    }

    assert!(
        saw_partial_tool_use,
        "should see at least one partially streamed tool use in the history"
    );
}
 397
 398#[gpui::test]
 399async fn test_tool_authorization(cx: &mut TestAppContext) {
 400    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 401    let fake_model = model.as_fake();
 402
 403    let mut events = thread
 404        .update(cx, |thread, cx| {
 405            thread.add_tool(ToolRequiringPermission);
 406            thread.send(UserMessageId::new(), ["abc"], cx)
 407        })
 408        .unwrap();
 409    cx.run_until_parked();
 410    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 411        LanguageModelToolUse {
 412            id: "tool_id_1".into(),
 413            name: ToolRequiringPermission::name().into(),
 414            raw_input: "{}".into(),
 415            input: json!({}),
 416            is_input_complete: true,
 417        },
 418    ));
 419    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 420        LanguageModelToolUse {
 421            id: "tool_id_2".into(),
 422            name: ToolRequiringPermission::name().into(),
 423            raw_input: "{}".into(),
 424            input: json!({}),
 425            is_input_complete: true,
 426        },
 427    ));
 428    fake_model.end_last_completion_stream();
 429    let tool_call_auth_1 = next_tool_call_authorization(&mut events).await;
 430    let tool_call_auth_2 = next_tool_call_authorization(&mut events).await;
 431
 432    // Approve the first
 433    tool_call_auth_1
 434        .response
 435        .send(tool_call_auth_1.options[1].id.clone())
 436        .unwrap();
 437    cx.run_until_parked();
 438
 439    // Reject the second
 440    tool_call_auth_2
 441        .response
 442        .send(tool_call_auth_1.options[2].id.clone())
 443        .unwrap();
 444    cx.run_until_parked();
 445
 446    let completion = fake_model.pending_completions().pop().unwrap();
 447    let message = completion.messages.last().unwrap();
 448    assert_eq!(
 449        message.content,
 450        vec![
 451            language_model::MessageContent::ToolResult(LanguageModelToolResult {
 452                tool_use_id: tool_call_auth_1.tool_call.id.0.to_string().into(),
 453                tool_name: ToolRequiringPermission::name().into(),
 454                is_error: false,
 455                content: "Allowed".into(),
 456                output: Some("Allowed".into())
 457            }),
 458            language_model::MessageContent::ToolResult(LanguageModelToolResult {
 459                tool_use_id: tool_call_auth_2.tool_call.id.0.to_string().into(),
 460                tool_name: ToolRequiringPermission::name().into(),
 461                is_error: true,
 462                content: "Permission to run tool denied by user".into(),
 463                output: None
 464            })
 465        ]
 466    );
 467
 468    // Simulate yet another tool call.
 469    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 470        LanguageModelToolUse {
 471            id: "tool_id_3".into(),
 472            name: ToolRequiringPermission::name().into(),
 473            raw_input: "{}".into(),
 474            input: json!({}),
 475            is_input_complete: true,
 476        },
 477    ));
 478    fake_model.end_last_completion_stream();
 479
 480    // Respond by always allowing tools.
 481    let tool_call_auth_3 = next_tool_call_authorization(&mut events).await;
 482    tool_call_auth_3
 483        .response
 484        .send(tool_call_auth_3.options[0].id.clone())
 485        .unwrap();
 486    cx.run_until_parked();
 487    let completion = fake_model.pending_completions().pop().unwrap();
 488    let message = completion.messages.last().unwrap();
 489    assert_eq!(
 490        message.content,
 491        vec![language_model::MessageContent::ToolResult(
 492            LanguageModelToolResult {
 493                tool_use_id: tool_call_auth_3.tool_call.id.0.to_string().into(),
 494                tool_name: ToolRequiringPermission::name().into(),
 495                is_error: false,
 496                content: "Allowed".into(),
 497                output: Some("Allowed".into())
 498            }
 499        )]
 500    );
 501
 502    // Simulate a final tool call, ensuring we don't trigger authorization.
 503    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 504        LanguageModelToolUse {
 505            id: "tool_id_4".into(),
 506            name: ToolRequiringPermission::name().into(),
 507            raw_input: "{}".into(),
 508            input: json!({}),
 509            is_input_complete: true,
 510        },
 511    ));
 512    fake_model.end_last_completion_stream();
 513    cx.run_until_parked();
 514    let completion = fake_model.pending_completions().pop().unwrap();
 515    let message = completion.messages.last().unwrap();
 516    assert_eq!(
 517        message.content,
 518        vec![language_model::MessageContent::ToolResult(
 519            LanguageModelToolResult {
 520                tool_use_id: "tool_id_4".into(),
 521                tool_name: ToolRequiringPermission::name().into(),
 522                is_error: false,
 523                content: "Allowed".into(),
 524                output: Some("Allowed".into())
 525            }
 526        )]
 527    );
 528}
 529
 530#[gpui::test]
 531async fn test_tool_hallucination(cx: &mut TestAppContext) {
 532    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
 533    let fake_model = model.as_fake();
 534
 535    let mut events = thread
 536        .update(cx, |thread, cx| {
 537            thread.send(UserMessageId::new(), ["abc"], cx)
 538        })
 539        .unwrap();
 540    cx.run_until_parked();
 541    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
 542        LanguageModelToolUse {
 543            id: "tool_id_1".into(),
 544            name: "nonexistent_tool".into(),
 545            raw_input: "{}".into(),
 546            input: json!({}),
 547            is_input_complete: true,
 548        },
 549    ));
 550    fake_model.end_last_completion_stream();
 551
 552    let tool_call = expect_tool_call(&mut events).await;
 553    assert_eq!(tool_call.title, "nonexistent_tool");
 554    assert_eq!(tool_call.status, acp::ToolCallStatus::Pending);
 555    let update = expect_tool_call_update_fields(&mut events).await;
 556    assert_eq!(update.fields.status, Some(acp::ToolCallStatus::Failed));
 557}
 558
// Verifies that a turn interrupted by the tool-use limit can be resumed with a
// synthetic "Continue where you left off" message, and that resume() errors
// when the limit was not actually reached.
#[gpui::test]
async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.end_last_completion_stream();

    cx.run_until_parked();
    // messages[0] is the system prompt; the tail should be the user message,
    // the tool use, and the (cached) tool result.
    let completion = fake_model.pending_completions().pop().unwrap();
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use.clone())],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result.clone())],
                cache: true
            },
        ]
    );

    // Simulate reaching tool use limit.
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The stream should terminate with a ToolUseLimitReachedError.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    // Resuming appends a synthetic user message asking the model to continue.
    let events = thread.update(cx, |thread, cx| thread.resume(cx)).unwrap();
    cx.run_until_parked();
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["Continue where you left off".into()],
                cache: true
            }
        ]
    );

    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::Text("Done".into()));
    fake_model.end_last_completion_stream();
    events.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.last_message().unwrap().to_markdown(),
            indoc! {"
                ## Assistant

                Done
            "}
        )
    });

    // Ensure we error if calling resume when tool use limit was *not* reached.
    let error = thread
        .update(cx, |thread, cx| thread.resume(cx))
        .unwrap_err();
    assert_eq!(
        error.to_string(),
        "can only resume after tool use limit is reached"
    )
}
 676
// Verifies that after hitting the tool-use limit, sending a brand-new user
// message works: the interrupted tool use/result pair stays in the history and
// the new message becomes the (cached) tail of the next request.
#[gpui::test]
async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(EchoTool);
            thread.send(UserMessageId::new(), ["abc"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    let tool_use = LanguageModelToolUse {
        id: "tool_id_1".into(),
        name: EchoTool::name().into(),
        raw_input: "{}".into(),
        input: serde_json::to_value(&EchoToolInput { text: "def".into() }).unwrap(),
        is_input_complete: true,
    };
    let tool_result = LanguageModelToolResult {
        tool_use_id: "tool_id_1".into(),
        tool_name: EchoTool::name().into(),
        is_error: false,
        content: "def".into(),
        output: Some("def".into()),
    };
    // Emit a tool use immediately followed by the tool-use-limit status.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
    ));
    fake_model.end_last_completion_stream();
    // The stream should terminate with a ToolUseLimitReachedError.
    let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
    assert!(
        last_event
            .unwrap_err()
            .is::<language_model::ToolUseLimitReachedError>()
    );

    thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), vec!["ghi"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // messages[0] is the system prompt; the new message is the only cached one.
    let completion = fake_model.pending_completions().pop().unwrap();
    assert_eq!(
        completion.messages[1..],
        vec![
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["abc".into()],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(tool_use)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec!["ghi".into()],
                cache: true
            }
        ]
    );
}
 750
 751async fn expect_tool_call(events: &mut UnboundedReceiver<Result<ThreadEvent>>) -> acp::ToolCall {
 752    let event = events
 753        .next()
 754        .await
 755        .expect("no tool call authorization event received")
 756        .unwrap();
 757    match event {
 758        ThreadEvent::ToolCall(tool_call) => tool_call,
 759        event => {
 760            panic!("Unexpected event {event:?}");
 761        }
 762    }
 763}
 764
 765async fn expect_tool_call_update_fields(
 766    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
 767) -> acp::ToolCallUpdate {
 768    let event = events
 769        .next()
 770        .await
 771        .expect("no tool call authorization event received")
 772        .unwrap();
 773    match event {
 774        ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(update)) => update,
 775        event => {
 776            panic!("Unexpected event {event:?}");
 777        }
 778    }
 779}
 780
 781async fn next_tool_call_authorization(
 782    events: &mut UnboundedReceiver<Result<ThreadEvent>>,
 783) -> ToolCallAuthorization {
 784    loop {
 785        let event = events
 786            .next()
 787            .await
 788            .expect("no tool call authorization event received")
 789            .unwrap();
 790        if let ThreadEvent::ToolCallAuthorization(tool_call_authorization) = event {
 791            let permission_kinds = tool_call_authorization
 792                .options
 793                .iter()
 794                .map(|o| o.kind)
 795                .collect::<Vec<_>>();
 796            assert_eq!(
 797                permission_kinds,
 798                vec![
 799                    acp::PermissionOptionKind::AllowAlways,
 800                    acp::PermissionOptionKind::AllowOnce,
 801                    acp::PermissionOptionKind::RejectOnce,
 802                ]
 803            );
 804            return tool_call_authorization;
 805        }
 806    }
 807}
 808
/// E2E: asks the real model to invoke `DelayTool` twice in one message and
/// verifies the turn ends normally with the tool output echoed back in the
/// final agent message.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    // Test concurrent tool calls with different delay times
    let events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(DelayTool);
            thread.send(
                UserMessageId::new(),
                [
                    "Call the delay tool twice in the same message.",
                    "Once with 100ms. Once with 300ms.",
                    "When both timers are complete, describe the outputs.",
                ],
                cx,
            )
        })
        .unwrap()
        .collect()
        .await;

    // Both tool calls and the follow-up response should complete in one turn.
    let stop_reasons = stop_events(events);
    assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);

    thread.update(cx, |thread, _cx| {
        let last_message = thread.last_message().unwrap();
        let agent_message = last_message.as_agent_message().unwrap();
        // Concatenate only the text chunks of the agent's final message.
        let text = agent_message
            .content
            .iter()
            .filter_map(|content| {
                if let AgentMessageContent::Text(text) = content {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<String>();

        // "Ding" is presumably part of DelayTool's output (defined in
        // test_tools) — the model is expected to repeat it when describing
        // the outputs.
        assert!(text.contains("Ding"));
    });
}
 853
/// Verifies that agent profiles defined in the settings file control which
/// tools are sent to the model, and that switching profiles mid-thread takes
/// effect on the next completion request.
#[gpui::test]
async fn test_profiles(cx: &mut TestAppContext) {
    let ThreadTest {
        model, thread, fs, ..
    } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Register all three tools; the active profile decides which subset the
    // model actually sees.
    thread.update(cx, |thread, _cx| {
        thread.add_tool(DelayTool);
        thread.add_tool(EchoTool);
        thread.add_tool(InfiniteTool);
    });

    // Override profiles and wait for settings to be loaded.
    // Note: `json!` accepts these `Tool::name()` calls as expression keys.
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "profiles": {
                    "test-1": {
                        "name": "Test Profile 1",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                        }
                    },
                    "test-2": {
                        "name": "Test Profile 2",
                        "tools": {
                            InfiniteTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;
    cx.run_until_parked();

    // Test that test-1 profile (default) has echo and delay tools
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-1".into()));
            thread.send(UserMessageId::new(), ["test"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Inspect the completion request captured by the fake model to see
    // exactly which tools were offered.
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![DelayTool::name(), EchoTool::name()]);
    fake_model.end_last_completion_stream();

    // Switch to test-2 profile, and verify that it has only the infinite tool.
    thread
        .update(cx, |thread, cx| {
            thread.set_profile(AgentProfileId("test-2".into()));
            thread.send(UserMessageId::new(), ["test2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    let mut pending_completions = fake_model.pending_completions();
    assert_eq!(pending_completions.len(), 1);
    let completion = pending_completions.pop().unwrap();
    let tool_names: Vec<String> = completion
        .tools
        .iter()
        .map(|tool| tool.name.clone())
        .collect();
    assert_eq!(tool_names, vec![InfiniteTool::name()]);
}
 933
/// E2E: cancelling a send while a tool (InfiniteTool) is still running must
/// close the event stream with `Cancelled`, and the thread must remain usable
/// for a subsequent send.
#[gpui::test]
#[cfg_attr(not(feature = "e2e"), ignore)]
async fn test_cancellation(cx: &mut TestAppContext) {
    let ThreadTest { thread, .. } = setup(cx, TestModel::Sonnet4).await;

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.add_tool(InfiniteTool);
            thread.add_tool(EchoTool);
            thread.send(
                UserMessageId::new(),
                ["Call the echo tool, then call the infinite tool, then explain their output"],
                cx,
            )
        })
        .unwrap();

    // Wait until both tools are called.
    // The prompt asks for echo first, so tool calls are expected in this order.
    let mut expected_tools = vec!["Echo", "Infinite Tool"];
    let mut echo_id = None;
    let mut echo_completed = false;
    while let Some(event) = events.next().await {
        match event.unwrap() {
            ThreadEvent::ToolCall(tool_call) => {
                assert_eq!(tool_call.title, expected_tools.remove(0));
                if tool_call.title == "Echo" {
                    echo_id = Some(tool_call.id);
                }
            }
            // Watch for the echo tool call reaching Completed status.
            ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
                acp::ToolCallUpdate {
                    id,
                    fields:
                        acp::ToolCallUpdateFields {
                            status: Some(acp::ToolCallStatus::Completed),
                            ..
                        },
                },
            )) if Some(&id) == echo_id.as_ref() => {
                echo_completed = true;
            }
            _ => {}
        }

        if expected_tools.is_empty() && echo_completed {
            break;
        }
    }

    // Cancel the current send and ensure that the event stream is closed, even
    // if one of the tools is still running.
    thread.update(cx, |thread, cx| thread.cancel(cx));
    let events = events.collect::<Vec<_>>().await;
    let last_event = events.last();
    assert!(
        matches!(
            last_event,
            Some(Ok(ThreadEvent::Stop(acp::StopReason::Cancelled)))
        ),
        "unexpected event {last_event:?}"
    );

    // Ensure we can still send a new message after cancellation.
    let events = thread
        .update(cx, |thread, cx| {
            thread.send(
                UserMessageId::new(),
                ["Testing: reply with 'Hello' then stop."],
                cx,
            )
        })
        .unwrap()
        .collect::<Vec<_>>()
        .await;
    thread.update(cx, |thread, _cx| {
        let message = thread.last_message().unwrap();
        let agent_message = message.as_agent_message().unwrap();
        assert_eq!(
            agent_message.content,
            vec![AgentMessageContent::Text("Hello".to_string())]
        );
    });
    assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
}
1018
/// Starting a second send while the first is still streaming should cancel
/// the first send (its stream ends with `Cancelled`) and let the second
/// complete normally.
#[gpui::test]
async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First send: stream a partial chunk but never finish the completion.
    let events_1 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
    cx.run_until_parked();

    // Second send while the first is still in flight.
    let events_2 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();

    let events_1 = events_1.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
    let events_2 = events_2.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
}
1049
/// Counterpart to the cancellation test above: when the first send finishes
/// cleanly before the second begins, both should stop with `EndTurn` and
/// neither should be cancelled.
#[gpui::test]
async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // First send runs to completion before the second is issued.
    let events_1 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 1"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 1!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    let events_1 = events_1.collect::<Vec<_>>().await;

    let events_2 = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello 2"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey 2!");
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
    fake_model.end_last_completion_stream();
    let events_2 = events_2.collect::<Vec<_>>().await;

    assert_eq!(stop_events(events_1), vec![acp::StopReason::EndTurn]);
    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
}
1082
/// When the model stops with `Refusal`, the thread should roll back all
/// messages after (and including) the last user message, leaving the thread
/// empty in this single-exchange case.
#[gpui::test]
async fn test_refusal(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    // Before any model output, the thread holds just the user message.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello
            "}
        );
    });

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    cx.run_until_parked();
    // The partial assistant reply is visible while streaming.
    thread.read_with(cx, |thread, _| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello

                ## Assistant

                Hey!
            "}
        );
    });

    // If the model refuses to continue, the thread should remove all the messages after the last user message.
    fake_model
        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::Refusal));
    let events = events.collect::<Vec<_>>().await;
    assert_eq!(stop_events(events), vec![acp::StopReason::Refusal]);
    thread.read_with(cx, |thread, _| {
        assert_eq!(thread.to_markdown(), "");
    });
}
1131
1132#[gpui::test]
1133async fn test_truncate_first_message(cx: &mut TestAppContext) {
1134    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1135    let fake_model = model.as_fake();
1136
1137    let message_id = UserMessageId::new();
1138    thread
1139        .update(cx, |thread, cx| {
1140            thread.send(message_id.clone(), ["Hello"], cx)
1141        })
1142        .unwrap();
1143    cx.run_until_parked();
1144    thread.read_with(cx, |thread, _| {
1145        assert_eq!(
1146            thread.to_markdown(),
1147            indoc! {"
1148                ## User
1149
1150                Hello
1151            "}
1152        );
1153        assert_eq!(thread.latest_token_usage(), None);
1154    });
1155
1156    fake_model.send_last_completion_stream_text_chunk("Hey!");
1157    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
1158        language_model::TokenUsage {
1159            input_tokens: 32_000,
1160            output_tokens: 16_000,
1161            cache_creation_input_tokens: 0,
1162            cache_read_input_tokens: 0,
1163        },
1164    ));
1165    cx.run_until_parked();
1166    thread.read_with(cx, |thread, _| {
1167        assert_eq!(
1168            thread.to_markdown(),
1169            indoc! {"
1170                ## User
1171
1172                Hello
1173
1174                ## Assistant
1175
1176                Hey!
1177            "}
1178        );
1179        assert_eq!(
1180            thread.latest_token_usage(),
1181            Some(acp_thread::TokenUsage {
1182                used_tokens: 32_000 + 16_000,
1183                max_tokens: 1_000_000,
1184            })
1185        );
1186    });
1187
1188    thread
1189        .update(cx, |thread, cx| thread.truncate(message_id, cx))
1190        .unwrap();
1191    cx.run_until_parked();
1192    thread.read_with(cx, |thread, _| {
1193        assert_eq!(thread.to_markdown(), "");
1194        assert_eq!(thread.latest_token_usage(), None);
1195    });
1196
1197    // Ensure we can still send a new message after truncation.
1198    thread
1199        .update(cx, |thread, cx| {
1200            thread.send(UserMessageId::new(), ["Hi"], cx)
1201        })
1202        .unwrap();
1203    thread.update(cx, |thread, _cx| {
1204        assert_eq!(
1205            thread.to_markdown(),
1206            indoc! {"
1207                ## User
1208
1209                Hi
1210            "}
1211        );
1212    });
1213    cx.run_until_parked();
1214    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
1215    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
1216        language_model::TokenUsage {
1217            input_tokens: 40_000,
1218            output_tokens: 20_000,
1219            cache_creation_input_tokens: 0,
1220            cache_read_input_tokens: 0,
1221        },
1222    ));
1223    cx.run_until_parked();
1224    thread.read_with(cx, |thread, _| {
1225        assert_eq!(
1226            thread.to_markdown(),
1227            indoc! {"
1228                ## User
1229
1230                Hi
1231
1232                ## Assistant
1233
1234                Ahoy!
1235            "}
1236        );
1237
1238        assert_eq!(
1239            thread.latest_token_usage(),
1240            Some(acp_thread::TokenUsage {
1241                used_tokens: 40_000 + 20_000,
1242                max_tokens: 1_000_000,
1243            })
1244        );
1245    });
1246}
1247
1248#[gpui::test]
1249async fn test_truncate_second_message(cx: &mut TestAppContext) {
1250    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
1251    let fake_model = model.as_fake();
1252
1253    thread
1254        .update(cx, |thread, cx| {
1255            thread.send(UserMessageId::new(), ["Message 1"], cx)
1256        })
1257        .unwrap();
1258    cx.run_until_parked();
1259    fake_model.send_last_completion_stream_text_chunk("Message 1 response");
1260    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
1261        language_model::TokenUsage {
1262            input_tokens: 32_000,
1263            output_tokens: 16_000,
1264            cache_creation_input_tokens: 0,
1265            cache_read_input_tokens: 0,
1266        },
1267    ));
1268    fake_model.end_last_completion_stream();
1269    cx.run_until_parked();
1270
1271    let assert_first_message_state = |cx: &mut TestAppContext| {
1272        thread.clone().read_with(cx, |thread, _| {
1273            assert_eq!(
1274                thread.to_markdown(),
1275                indoc! {"
1276                    ## User
1277
1278                    Message 1
1279
1280                    ## Assistant
1281
1282                    Message 1 response
1283                "}
1284            );
1285
1286            assert_eq!(
1287                thread.latest_token_usage(),
1288                Some(acp_thread::TokenUsage {
1289                    used_tokens: 32_000 + 16_000,
1290                    max_tokens: 1_000_000,
1291                })
1292            );
1293        });
1294    };
1295
1296    assert_first_message_state(cx);
1297
1298    let second_message_id = UserMessageId::new();
1299    thread
1300        .update(cx, |thread, cx| {
1301            thread.send(second_message_id.clone(), ["Message 2"], cx)
1302        })
1303        .unwrap();
1304    cx.run_until_parked();
1305
1306    fake_model.send_last_completion_stream_text_chunk("Message 2 response");
1307    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
1308        language_model::TokenUsage {
1309            input_tokens: 40_000,
1310            output_tokens: 20_000,
1311            cache_creation_input_tokens: 0,
1312            cache_read_input_tokens: 0,
1313        },
1314    ));
1315    fake_model.end_last_completion_stream();
1316    cx.run_until_parked();
1317
1318    thread.read_with(cx, |thread, _| {
1319        assert_eq!(
1320            thread.to_markdown(),
1321            indoc! {"
1322                ## User
1323
1324                Message 1
1325
1326                ## Assistant
1327
1328                Message 1 response
1329
1330                ## User
1331
1332                Message 2
1333
1334                ## Assistant
1335
1336                Message 2 response
1337            "}
1338        );
1339
1340        assert_eq!(
1341            thread.latest_token_usage(),
1342            Some(acp_thread::TokenUsage {
1343                used_tokens: 40_000 + 20_000,
1344                max_tokens: 1_000_000,
1345            })
1346        );
1347    });
1348
1349    thread
1350        .update(cx, |thread, cx| thread.truncate(second_message_id, cx))
1351        .unwrap();
1352    cx.run_until_parked();
1353
1354    assert_first_message_state(cx);
1355}
1356
/// A dedicated summarization model should be asked to generate a title after
/// the first exchange only; the title is the first line of its output, and
/// later exchanges must not regenerate it.
#[gpui::test]
async fn test_title_generation(cx: &mut TestAppContext) {
    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    // Separate fake model so title-generation requests can be driven
    // independently of the main completion stream.
    let summary_model = Arc::new(FakeLanguageModel::default());
    thread.update(cx, |thread, cx| {
        thread.set_summarization_model(Some(summary_model.clone()), cx)
    });

    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    // Title is still the default until the summary model responds.
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "New Thread"));

    // Ensure the summary model has been invoked to generate a title.
    // Only the text before the newline becomes the title.
    summary_model.send_last_completion_stream_text_chunk("Hello ");
    summary_model.send_last_completion_stream_text_chunk("world\nG");
    summary_model.send_last_completion_stream_text_chunk("oodnight Moon");
    summary_model.end_last_completion_stream();
    send.collect::<Vec<_>>().await;
    cx.run_until_parked();
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));

    // Send another message, ensuring no title is generated this time.
    let send = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Hello again"], cx)
        })
        .unwrap();
    cx.run_until_parked();
    fake_model.send_last_completion_stream_text_chunk("Hey again!");
    fake_model.end_last_completion_stream();
    cx.run_until_parked();
    assert_eq!(summary_model.pending_completions(), Vec::new());
    send.collect::<Vec<_>>().await;
    thread.read_with(cx, |thread, _| assert_eq!(thread.title(), "Hello world"));
}
1402
/// Exercises the `NativeAgentConnection` surface end to end with a fake
/// model: model listing/selection, thread creation, prompting, cancellation,
/// and that dropping the ACP thread tears down the underlying session.
#[gpui::test]
async fn test_agent_connection(cx: &mut TestAppContext) {
    cx.update(settings::init);
    let templates = Templates::new();

    // Initialize language model system with test provider
    cx.update(|cx| {
        gpui_tokio::init(cx);
        client::init_settings(cx);

        // 404-only HTTP client: nothing in this test should hit the network.
        let http_client = FakeHttpClient::with_404_response();
        let clock = Arc::new(clock::FakeSystemClock::new());
        let client = Client::new(clock, http_client, cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        Project::init_settings(cx);
        agent_settings::init(cx);
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);
        LanguageModelRegistry::test(cx);
    });
    cx.executor().forbid_parking();

    // Create a project for new_thread
    let fake_fs = cx.update(|cx| fs::FakeFs::new(cx.background_executor().clone()));
    fake_fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fake_fs.clone(), [Path::new("/test")], cx).await;
    let cwd = Path::new("/test");
    let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
    let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));

    // Create agent and connection
    let agent = NativeAgent::new(
        project.clone(),
        history_store,
        templates.clone(),
        None,
        fake_fs.clone(),
        &mut cx.to_async(),
    )
    .await
    .unwrap();
    let connection = NativeAgentConnection(agent.clone());

    // Test model_selector returns Some
    let selector_opt = connection.model_selector();
    assert!(
        selector_opt.is_some(),
        "agent2 should always support ModelSelector"
    );
    let selector = selector_opt.unwrap();

    // Test list_models
    let listed_models = cx
        .update(|cx| selector.list_models(cx))
        .await
        .expect("list_models should succeed");
    let AgentModelList::Grouped(listed_models) = listed_models else {
        panic!("Unexpected model list type");
    };
    assert!(!listed_models.is_empty(), "should have at least one model");
    // The test registry exposes the fake provider's model as "fake/fake".
    assert_eq!(
        listed_models[&AgentModelGroupName("Fake".into())][0].id.0,
        "fake/fake"
    );

    // Create a thread using new_thread
    let connection_rc = Rc::new(connection.clone());
    let acp_thread = cx
        .update(|cx| connection_rc.new_thread(project, cwd, cx))
        .await
        .expect("new_thread should succeed");

    // Get the session_id from the AcpThread
    let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());

    // Test selected_model returns the default
    let model = cx
        .update(|cx| selector.selected_model(&session_id, cx))
        .await
        .expect("selected_model should succeed");
    let model = cx
        .update(|cx| agent.read(cx).models().model_from_id(&model.id))
        .unwrap();
    let model = model.as_fake();
    assert_eq!(model.id().0, "fake", "should return default model");

    // Prompt through the ACP thread and stream a reply from the fake model.
    let request = acp_thread.update(cx, |thread, cx| thread.send(vec!["abc".into()], cx));
    cx.run_until_parked();
    model.send_last_completion_stream_text_chunk("def");
    cx.run_until_parked();
    acp_thread.read_with(cx, |thread, cx| {
        assert_eq!(
            thread.to_markdown(cx),
            indoc! {"
                ## User

                abc

                ## Assistant

                def

            "}
        )
    });

    // Test cancel
    cx.update(|cx| connection.cancel(&session_id, cx));
    request.await.expect("prompt should fail gracefully");

    // Ensure that dropping the ACP thread causes the native thread to be
    // dropped as well.
    cx.update(|_| drop(acp_thread));
    let result = cx
        .update(|cx| {
            connection.prompt(
                Some(acp_thread::UserMessageId::new()),
                acp::PromptRequest {
                    session_id: session_id.clone(),
                    prompt: vec!["ghi".into()],
                },
                cx,
            )
        })
        .await;
    assert_eq!(
        result.as_ref().unwrap_err().to_string(),
        "Session not found",
        "unexpected result: {:?}",
        result
    );
}
1535
/// Streams a tool use with partial-then-complete input and verifies the exact
/// sequence of ACP events: initial ToolCall (Pending), then field updates for
/// raw input, InProgress, content, and finally Completed with raw output.
#[gpui::test]
async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.send(UserMessageId::new(), ["Think"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // Simulate streaming partial input.
    let input = json!({});
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: ThinkingTool::name().into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: false,
        },
    ));

    // Input streaming completed
    let input = json!({ "content": "Thinking hard!" });
    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
        LanguageModelToolUse {
            id: "1".into(),
            name: "thinking".into(),
            raw_input: input.to_string(),
            input,
            is_input_complete: true,
        },
    ));
    fake_model.end_last_completion_stream();
    cx.run_until_parked();

    // 1) Tool call created in Pending state with empty input.
    let tool_call = expect_tool_call(&mut events).await;
    assert_eq!(
        tool_call,
        acp::ToolCall {
            id: acp::ToolCallId("1".into()),
            title: "Thinking".into(),
            kind: acp::ToolKind::Think,
            status: acp::ToolCallStatus::Pending,
            content: vec![],
            locations: vec![],
            raw_input: Some(json!({})),
            raw_output: None,
        }
    );
    // 2) Completed input is reflected as a raw_input update.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                title: Some("Thinking".into()),
                kind: Some(acp::ToolKind::Think),
                raw_input: Some(json!({ "content": "Thinking hard!" })),
                ..Default::default()
            },
        }
    );
    // 3) Execution starts.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::InProgress),
                ..Default::default()
            },
        }
    );
    // 4) The tool streams its content.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                content: Some(vec!["Thinking hard!".into()]),
                ..Default::default()
            },
        }
    );
    // 5) Tool completes with its final raw output.
    let update = expect_tool_call_update_fields(&mut events).await;
    assert_eq!(
        update,
        acp::ToolCallUpdate {
            id: acp::ToolCallId("1".into()),
            fields: acp::ToolCallUpdateFields {
                status: Some(acp::ToolCallStatus::Completed),
                raw_output: Some("Finished thinking.".into()),
                ..Default::default()
            },
        }
    );
}
1637
/// A completion that succeeds on the first attempt must emit zero
/// `ThreadEvent::Retry` events, even in Burn mode (which enables retries).
#[gpui::test]
async fn test_send_no_retry_on_success(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();

    // Drain the event stream until the turn stops, collecting any retries.
    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 0);
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey!
            "}
        )
    });
}
1681
/// A `ServerOverloaded` error should trigger exactly one retry (after the
/// provider-supplied `retry_after` delay), and the retried completion's
/// output should end up in the thread as if nothing failed.
#[gpui::test]
async fn test_send_retry_on_error(cx: &mut TestAppContext) {
    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
    let fake_model = model.as_fake();

    let mut events = thread
        .update(cx, |thread, cx| {
            // Burn mode enables automatic retries.
            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
            thread.send(UserMessageId::new(), ["Hello!"], cx)
        })
        .unwrap();
    cx.run_until_parked();

    // First attempt fails with a retryable provider error.
    fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
        provider: LanguageModelProviderName::new("Anthropic"),
        retry_after: Some(Duration::from_secs(3)),
    });
    fake_model.end_last_completion_stream();

    // Skip past the retry backoff so the second attempt starts.
    cx.executor().advance_clock(Duration::from_secs(3));
    cx.run_until_parked();

    fake_model.send_last_completion_stream_text_chunk("Hey!");
    fake_model.end_last_completion_stream();

    let mut retry_events = Vec::new();
    while let Some(Ok(event)) = events.next().await {
        match event {
            ThreadEvent::Retry(retry_status) => {
                retry_events.push(retry_status);
            }
            ThreadEvent::Stop(..) => break,
            _ => {}
        }
    }

    assert_eq!(retry_events.len(), 1);
    assert!(matches!(
        retry_events[0],
        acp_thread::RetryStatus { attempt: 1, .. }
    ));
    thread.read_with(cx, |thread, _cx| {
        assert_eq!(
            thread.to_markdown(),
            indoc! {"
                ## User

                Hello!

                ## Assistant

                Hey!
            "}
        )
    });
}
1738
1739#[gpui::test]
1740async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
1741    let ThreadTest { thread, model, .. } = setup(cx, TestModel::Fake).await;
1742    let fake_model = model.as_fake();
1743
1744    let mut events = thread
1745        .update(cx, |thread, cx| {
1746            thread.set_completion_mode(agent_settings::CompletionMode::Burn, cx);
1747            thread.send(UserMessageId::new(), ["Hello!"], cx)
1748        })
1749        .unwrap();
1750    cx.run_until_parked();
1751
1752    for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
1753        fake_model.send_last_completion_stream_error(
1754            LanguageModelCompletionError::ServerOverloaded {
1755                provider: LanguageModelProviderName::new("Anthropic"),
1756                retry_after: Some(Duration::from_secs(3)),
1757            },
1758        );
1759        fake_model.end_last_completion_stream();
1760        cx.executor().advance_clock(Duration::from_secs(3));
1761        cx.run_until_parked();
1762    }
1763
1764    let mut errors = Vec::new();
1765    let mut retry_events = Vec::new();
1766    while let Some(event) = events.next().await {
1767        match event {
1768            Ok(ThreadEvent::Retry(retry_status)) => {
1769                retry_events.push(retry_status);
1770            }
1771            Ok(ThreadEvent::Stop(..)) => break,
1772            Err(error) => errors.push(error),
1773            _ => {}
1774        }
1775    }
1776
1777    assert_eq!(
1778        retry_events.len(),
1779        crate::thread::MAX_RETRY_ATTEMPTS as usize
1780    );
1781    for i in 0..crate::thread::MAX_RETRY_ATTEMPTS as usize {
1782        assert_eq!(retry_events[i].attempt, i + 1);
1783    }
1784    assert_eq!(errors.len(), 1);
1785    let error = errors[0]
1786        .downcast_ref::<LanguageModelCompletionError>()
1787        .unwrap();
1788    assert!(matches!(
1789        error,
1790        LanguageModelCompletionError::ServerOverloaded { .. }
1791    ));
1792}
1793
1794/// Filters out the stop events for asserting against in tests
1795fn stop_events(result_events: Vec<Result<ThreadEvent>>) -> Vec<acp::StopReason> {
1796    result_events
1797        .into_iter()
1798        .filter_map(|event| match event.unwrap() {
1799            ThreadEvent::Stop(stop_reason) => Some(stop_reason),
1800            _ => None,
1801        })
1802        .collect()
1803}
1804
/// Bundle of entities produced by [`setup`] and shared by the tests in this file.
struct ThreadTest {
    /// The language model the thread is configured with (fake or real).
    model: Arc<dyn LanguageModel>,
    /// The thread under test.
    thread: Entity<Thread>,
    /// Project context handed to the thread at construction.
    project_context: Entity<ProjectContext>,
    /// Fake filesystem backing the project and the settings file.
    fs: Arc<FakeFs>,
}
1811
/// Which language model a test runs against: a real Anthropic model
/// (requires authentication) or the in-process fake used by most tests.
enum TestModel {
    Sonnet4,
    Fake,
}
1816
1817impl TestModel {
1818    fn id(&self) -> LanguageModelId {
1819        match self {
1820            TestModel::Sonnet4 => LanguageModelId("claude-sonnet-4-latest".into()),
1821            TestModel::Fake => unreachable!(),
1822        }
1823    }
1824}
1825
/// Builds a [`Thread`] plus its supporting entities for the tests above.
///
/// Writes a settings file that enables all the test tools under a dedicated
/// "test-profile", initializes the global settings/model registries, and
/// resolves the requested model: a [`FakeLanguageModel`] constructed directly
/// for [`TestModel::Fake`], or a registry lookup plus provider authentication
/// for a real model.
async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
    cx.executor().allow_parking();

    // Seed the fake filesystem with a settings file granting the test
    // profile access to every tool the tests exercise.
    let fs = FakeFs::new(cx.background_executor.clone());
    fs.create_dir(paths::settings_file().parent().unwrap())
        .await
        .unwrap();
    fs.insert_file(
        paths::settings_file(),
        json!({
            "agent": {
                "default_profile": "test-profile",
                "profiles": {
                    "test-profile": {
                        "name": "Test Profile",
                        "tools": {
                            EchoTool::name(): true,
                            DelayTool::name(): true,
                            WordListTool::name(): true,
                            ToolRequiringPermission::name(): true,
                            InfiniteTool::name(): true,
                        }
                    }
                }
            }
        })
        .to_string()
        .into_bytes(),
    )
    .await;

    // Initialize the global stores. NOTE(review): this ordering (settings →
    // project → agent → http client → client → language models) appears
    // deliberate; later initializers read state set up by earlier ones.
    cx.update(|cx| {
        settings::init(cx);
        Project::init_settings(cx);
        agent_settings::init(cx);
        gpui_tokio::init(cx);
        let http_client = ReqwestClient::user_agent("agent tests").unwrap();
        cx.set_http_client(Arc::new(http_client));

        client::init_settings(cx);
        let client = Client::production(cx);
        let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client.clone(), cx);

        // Keep the global SettingsStore in sync with edits the tests make
        // to the settings file on the fake filesystem.
        watch_settings(fs.clone(), cx);
    });

    let templates = Templates::new();

    // Single empty worktree rooted at /test.
    fs.insert_tree(path!("/test"), json!({})).await;
    let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;

    let model = cx
        .update(|cx| {
            if let TestModel::Fake = model {
                // Fake models bypass the registry entirely.
                Task::ready(Arc::new(FakeLanguageModel::default()) as Arc<_>)
            } else {
                // Real models: look up by id, then authenticate with the
                // provider before handing the model back.
                let model_id = model.id();
                let models = LanguageModelRegistry::read_global(cx);
                let model = models
                    .available_models(cx)
                    .find(|model| model.id() == model_id)
                    .unwrap();

                let provider = models.provider(&model.provider_id()).unwrap();
                let authenticated = provider.authenticate(cx);

                cx.spawn(async move |_cx| {
                    authenticated.await.unwrap();
                    model
                })
            }
        })
        .await;

    let project_context = cx.new(|_cx| ProjectContext::default());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            project_context.clone(),
            context_server_registry,
            templates,
            Some(model.clone()),
            cx,
        )
    });
    ThreadTest {
        model,
        thread,
        project_context,
        fs,
    }
}
1922
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Opt-in logging: only wire up env_logger when the caller set RUST_LOG,
    // the same variable env_logger itself reads for filtering.
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}
1930
1931fn watch_settings(fs: Arc<dyn Fs>, cx: &mut App) {
1932    let fs = fs.clone();
1933    cx.spawn({
1934        async move |cx| {
1935            let mut new_settings_content_rx = settings::watch_config_file(
1936                cx.background_executor(),
1937                fs,
1938                paths::settings_file().clone(),
1939            );
1940
1941            while let Some(new_settings_content) = new_settings_content_rx.next().await {
1942                cx.update(|cx| {
1943                    SettingsStore::update_global(cx, |settings, cx| {
1944                        settings.set_user_settings(&new_settings_content, cx)
1945                    })
1946                })
1947                .ok();
1948            }
1949        }
1950    })
1951    .detach();
1952}